### Imports

In [1]:
import os

# Hugging Face libraries resolve their cache location from HF_HOME when they
# are first imported, so the variable has to be set before any import that may
# pull them in (momentfm below, and possibly dicl).
os.environ["HF_HOME"] = "/mnt/vdb/hugguingface/"

import csv
import importlib
from pathlib import Path
import time

import numpy as np

from dicl import dicl, adapters
from dicl.icl import iclearner as icl
from dicl.utils import data_readers

from momentfm import MOMENTPipeline

# Re-import the project modules so that edits made on disk are picked up
# without restarting the notebook kernel.
importlib.reload(dicl)
importlib.reload(adapters)
importlib.reload(icl)
importlib.reload(data_readers)

  from .autonotebook import tqdm as notebook_tqdm
  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


## DICL

In [15]:
# Sweep the number of projected components for DICL with a MOMENT backbone on
# one dataset, and append the metrics of every run to a CSV file.

# --- Experiment configuration ---
is_fine_tuned = False
forecast_horizon = 96
model_name = "AutonLab/MOMENT-1-large"
context_length = 512

dataset_name = f"ETTh1_pred={forecast_horizon}"
time_series, X_train, y_train, X_test, y_test, n_features = prepare_data(
    dataset_name, context_length
)

base_projector = "pca"

# With a base projector, try every component count from 1 to n_features;
# without one, only the full n_features configuration is run.
start = n_features if not base_projector else 1
end = n_features + 1

data_path = Path("/mnt/vdb/abenechehab/dicl-adapters/results/data.csv")

for n_components in range(start, end):
    # perf_counter is a monotonic, high-resolution clock, so the recorded
    # duration cannot be distorted by system clock adjustments the way a
    # time.time() difference can. Model loading is included in the timing.
    start_time = time.perf_counter()

    # A fresh model is loaded for every configuration.
    # NOTE(review): presumably so that fine-tuning in one iteration does not
    # carry over to the next - confirm against load_moment_model.
    model = load_moment_model(model_name, forecast_horizon)

    # Projects the n_features input channels onto n_components channels.
    disentangler = adapters.MultichannelProjector(
        num_channels=n_features,
        new_num_channels=n_components,
        patch_window_size=None,
        base_projector=base_projector,
    )

    # The in-context learner operates on the projected channels, hence
    # n_features=n_components here.
    iclearner = icl.MomentICLTrainer(
        model=model, n_features=n_components, forecast_horizon=forecast_horizon
    )

    DICL = dicl.DICL(
        disentangler=disentangler,
        iclearner=iclearner,
        n_features=n_features,
        n_components=n_components,
    )

    DICL.fit_disentangler(X=X_train)

    if is_fine_tuned:
        DICL.fine_tune_iclearner(
            X=X_train,
            y=y_train,
            n_epochs=1,
            batch_size=8,
            learning_rate=1e-4,
            max_grad_norm=5.0,
            verbose=1,
            seed=13,
        )

    # The forecast outputs are kept as notebook globals for later inspection;
    # they are not used again in this cell.
    mean, mode, lb, ub = DICL.predict_multi_step(
        X=time_series,
        prediction_horizon=forecast_horizon,
    )

    metrics = DICL.compute_metrics()

    save_metrics_to_csv(
        metrics,
        dataset_name,
        model_name,
        base_projector,
        n_features,
        n_components,
        context_length,
        forecast_horizon,
        data_path,
        is_fine_tuned=is_fine_tuned,
        time=time.perf_counter() - start_time,
    )

    # Drop the references to the model and its wrappers before the next
    # iteration loads a new model.
    del DICL, disentangler, iclearner, model

