In [None]:
#src pathing
import os
import sys

import logging
from os import path
from typing import Optional, Tuple, Any, Callable, Dict, List, Optional
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
import random

import argparse
from types import SimpleNamespace

import numpy as np
import pandas as pd
import torch
import torch.multiprocessing as mp
import yfinance as yf # type: ignore


from huggingface_hub import snapshot_download
from torch.utils.data import Dataset
from ffm import freq_map

from tools.utils import print_model_statistics, make_logging_file, log_config

from data_tools.TSdataset import TimeSeriesDataset_MultiCSV_train_Production, find_files_with_suffix
from tools.model_utils import plot_predictions, get_model_FFM, FFM_weight_freeze
from tools.inference_utils import plot_predictions_multi, plot_predictions_multi_distribution, plot_predictions_multi_distribution_v2


from ffm.ffm_base import FFmHparams

In [None]:
import dataclasses
from dataclasses import dataclass, field

@dataclasses.dataclass(kw_only=True)
class Freq_map_dict:
    major_6_bench_map: dict = field(
        default_factory=lambda:
    {
    "ettm2": "15min",
    "ettm1": "15min",
    "etth2": "H",
    "etth1": "H",
    "electricity": "H",
    "traffic": "H",
    "weather": "10min",
    "national_illness": 'W',
    "exchange_rate": 'D',
    }
    )
    
    
    major_6_bench_val_map: dict = field(
        default_factory=lambda:
    {
    "val_elec": "H",
    "val_etth1": "H",
    "val_ettm1": "15min",
    "val_exchange": "D",
    "val_illness": "W",
    "val_traffic": "H",
    "val_weather": "10min",
    }
    )

    universal_map: dict = field(
        default_factory=lambda:
    {
    "_1d.csv": 0,
    "_1wk.csv": 1,
    "_1h.csv": 0,
    "_1m.csv": 0,
    }
    )


# Candidate context lengths, keyed by an integer id (0, 1, 2). The notebook
# later passes this table to the dataset as `possible_context_lengths`.
# NOTE(review): the keys look like the frequency ids used as values in
# Freq_map_dict.universal_map -- confirm against the dataset implementation.
FREQ_POSSIBLE_CONTEXT_LENGTH = {0: [512, 256, 128], 1: [256, 128], 2: [64]}


# Slice interval per id, with the same key space as the table above. The
# notebook later passes it to the dataset as `data_slice_interval`.
DATA_SLICE_INTERVAL_SMALL_D = {0: 32, 1: 1, 2: 1}

In [None]:
# Central run configuration: later cells read from (and add to) this namespace.
config = SimpleNamespace()

# plt setting (if any)
# A new seed is drawn on every run and later handed to the dataset as
# `shuffle_seed`. It is printed so that a given run can be reproduced by
# hard-coding the reported value here.
config.random_seed = random.randint(0, 100000)
print(f"random_seed = {config.random_seed}")
config.save_dir = r'pics/vis2'
config.save_name = ''

# data
# Paths are assembled with os.path.join so the notebook also works on
# non-Windows hosts: with literal backslashes, os.path.basename() and file
# lookups treat the whole string as a single name on POSIX.
# Other folders used with this notebook:
#   datasets/val_datasets_ts_major6, datasets/stock_v1/val_v1_nv,
#   datasets/stock_v1/val_v1_nv_m_h, datasets/stock_v1/test_v1_nv_flat
config.data_folder = os.path.join('datasets', 'stock_v1', 'test_v1_nv_flat')  # replace with your actual path
config.num_workers = 2
config.series_norm = False
config.mask_ratio = 0
config.freq_map_mode = 1             # 0 is direct conversion, 1 is custom suffix name match

# training hyperparams
config.batch_size = 64

# model change param on this
config.checkpoint = os.path.join(
    'checkpoints', 'hbcloss_ft_v4_lowquantile', 'hbc_v4_ep3_lowq_trained.pth'
)  # replace with actual path
config.num_experts = 4
config.gating_top_n = 2
config.load_from_compile = True

# device
config.device = "cuda" if torch.cuda.is_available() else "cpu"
config.gpu_ids = [0]



In [None]:
# Model loading: translate the notebook config into FFM hyper-parameters,
# then restore the checkpointed model.
ffm_hparams = FFmHparams(
    load_from_compile=config.load_from_compile,
    gating_top_n=config.gating_top_n,
    num_experts=config.num_experts,
)

# get_model_FFM returns three objects; of these, only `model` is referenced
# by the plotting cells visible in this notebook.
model, ffm_config, ffm_api = get_model_FFM(
    hparams=ffm_hparams,
    checkpoint=config.checkpoint,
)

# Report statistics for the loaded model.
print_model_statistics(model=model)

In [None]:
 #datasets
# Per-id context lengths for variable-length input, and per-id slice intervals;
# both tables are forwarded to the dataset constructor below.
config.context_length_list = FREQ_POSSIBLE_CONTEXT_LENGTH
config.data_slice_interval = DATA_SLICE_INTERVAL_SMALL_D
freq_map_dict = Freq_map_dict()

if config.freq_map_mode == 0:
    # Direct conversion: every frequency string is passed through freq_map().
    train_freq_map = {
        name: freq_map(freq_str)
        for name, freq_str in freq_map_dict.major_6_bench_map.items()
    }
    eval_freq_map = {
        name: freq_map(freq_str)
        for name, freq_str in freq_map_dict.major_6_bench_val_map.items()
    }
elif config.freq_map_mode == 1:
    # Custom suffix match: the suffix table is used as-is. Both names refer to
    # the same dict object, exactly as before.
    train_freq_map = freq_map_dict.universal_map
    eval_freq_map = freq_map_dict.universal_map
else:
    # Previously an unknown mode silently left both maps empty and handed an
    # empty freq map to the dataset; fail early with a clear message instead.
    raise ValueError(
        f"Unsupported config.freq_map_mode={config.freq_map_mode!r}; "
        "expected 0 (direct conversion) or 1 (custom suffix name match)"
    )


# set up plt dataset
val_file_list = find_files_with_suffix(directory=config.data_folder, suffix='.csv')
# NOTE(review): horizon_length is read from the FFmHparams class attribute, not
# from the `ffm_hparams` instance built earlier -- confirm this is intended.
val_dataset = TimeSeriesDataset_MultiCSV_train_Production(
    csv_paths=val_file_list,
    horizon_length=FFmHparams.output_patch_len,
    freq_map=eval_freq_map,
    freq_map_mode=config.freq_map_mode,
    mask_ratio=config.mask_ratio,
    possible_context_lengths=config.context_length_list,
    series_norm=config.series_norm,
    data_slice_interval=config.data_slice_interval,
    shuffle_seed=config.random_seed,
)


In [None]:
# Output: plot predictions for several quantile indices at once through
# plot_predictions_multi_distribution_v2.
config.num_img = 50
output_length = 60
output_number = True
quantile_list = [1, 3, 7, 9]  # 0 for mean, 1 - 9 for quantiles with 1 increment

plot_predictions_multi_distribution_v2(
    # what to plot
    model=model,
    val_dataset=val_dataset,
    quantiles=quantile_list,
    output_length=output_length,
    output_number=output_number,
    number_img=config.num_img,
    # labelling: checkpoint file name plus the mixture-of-experts settings
    model_version=os.path.basename(config.checkpoint),
    moe_n=config.num_experts,
    moe_tk=config.gating_top_n,
    # saving: None is passed here; config.save_dir is the configured alternative
    save_dir=None,
    save_name=config.save_name,
)

In [None]:
# Output: single-track prediction plots using index 0, which selects the mean
# (0 for mean, 1 - 9 for quantiles with 1 increment).
config.num_img = 20
output_length = 96
quantile = 0

plot_predictions_multi(
    # what to plot
    model=model,
    val_dataset=val_dataset,
    quantile=quantile,
    output_length=output_length,
    number_img=config.num_img,
    # labelling: checkpoint file name plus the mixture-of-experts settings
    model_version=os.path.basename(config.checkpoint),
    moe_n=config.num_experts,
    moe_tk=config.gating_top_n,
    # saving: None is passed here; config.save_dir is the configured alternative
    save_dir=None,
    save_name=config.save_name,
)

In [None]:
# Output: single-track prediction plots using quantile index 5
# (0 for mean, 1 - 9 for quantiles with 1 increment).
config.num_img = 20
output_length = 96
quantile = 5

plot_predictions_multi(
    # what to plot
    model=model,
    val_dataset=val_dataset,
    quantile=quantile,
    output_length=output_length,
    number_img=config.num_img,
    # labelling: checkpoint file name plus the mixture-of-experts settings
    model_version=os.path.basename(config.checkpoint),
    moe_n=config.num_experts,
    moe_tk=config.gating_top_n,
    # saving: None is passed here; config.save_dir is the configured alternative
    save_dir=None,
    save_name=config.save_name,
)

In [None]:
import gc

# Drop the notebook's reference to the model so its memory can be reclaimed.
# Guarded so the cell is idempotent: re-running it, or running it after a
# failed model load, no longer raises NameError.
try:
    del model
except NameError:
    pass

# NOTE(review): `ffm_api` was returned together with `model` by get_model_FFM
# and may still hold a reference to it -- confirm whether it has to be
# released as well for the memory to actually be freed.
gc.collect()  # Collect garbage
torch.cuda.empty_cache()  # Clear cached memory