### Imports

In [11]:
from pathlib import Path
from tqdm import tqdm

import numpy as np
import pandas as pd

import torch
from transformers import LlamaForCausalLM, AutoTokenizer

from dicl import dicl

## Load data

* We will be using expert trajectories from the HalfCheetah Mujoco environment for our demo. The dataset is provided in `src/dicl/data/`.

In [3]:
# Environment / dataset selection
env_name = "HalfCheetah"
data_label = "expert"
n_observations = 17  # number of observations in the HalfCheetah system
n_actions = 6  # number of actions in the HalfCheetah system

# CSV file holding the trajectories for the chosen environment and data label
data_path = Path("../src", "dicl", "data", f"D4RL_{env_name}_{data_label}.csv")

# ICL parameters
context_length = 300
rescale_factor, up_shift = 7.0, 1.5

* Pick DICL(s) or DICL(s,a) method through the number of features (choose `n_observations` for vICL).

In [4]:
# include_actions = False -> DICL-(s) or vICL: features are the observations only
# include_actions = True  -> DICL-(s,a): features are the observations plus the actions
include_actions = False
n_features = (n_observations + n_actions) if include_actions else n_observations

* Sample an episode and extract an in-context trajectory `(n_timestamps, n_features)`

In [5]:
# read the trajectory CSV (its first column is used as the index) straight
# into a float array
X = pd.read_csv(data_path, index_col=0).values.astype("float")

# the restart column is equal to 1 at the start of an episode and 0 otherwise,
# so the positions of its non-zero entries are the episode beginnings.
restart_index = n_observations + n_actions + 1
restarts = X[:, restart_index]
episode_starts = np.flatnonzero(restarts)

## DICL

* Instantiate DICL
* Choose the number of components for PCA (set to half here)
* For vICL, `n_components` has to be equal to `n_features`

In [None]:
# Local Hugging Face snapshot directories of the Llama checkpoints to evaluate,
# listed in the order in which they are run.
llm_list = [
    "/mnt/vdb/hugguingface/hub/models--meta-llama--Llama-3.2-1B/snapshots/5d853ed7d16ac794afa8f5c9c7f59f4e9c950954",
    "/mnt/vdb/hugguingface/hub/models--meta-llama--Llama-3.2-3B/snapshots/43fa890183375f5f69cb9646f29aa99ef3207c22",
    "/mnt/vdb/hugguingface/hub/models--meta-llama--Llama-3.1-8B/snapshots/8d10549bcf802355f2d6203a33ed27e81b15b9e5",
    "/home/gpaolo/nas_2/huggingface/hub/models--meta-llama--Meta-Llama-3-8B/snapshots/62bd457b6fe961a42a631306577e622c83876cb6/",
    "/mnt/vdb/hugguingface/hub/models--meta-llama--Llama-3.1-70B/snapshots/349b2ddb53ce8f2849a6c168a81980ab25258dac/",
]

# to use vICL, set vanilla_icl to True.
# to use DICL-(s,a) or DICL-(s), set vanilla_icl to False
vanilla_icl = True

# filled by the evaluation loop, keyed by the episode start index
result_dict = {}

n_episodes = 5

# Seeded generator so that "Restart Kernel -> Run All" evaluates the same
# episodes every time. Episodes are drawn WITHOUT replacement: a start index
# drawn twice would be evaluated twice and still occupy a single key of
# result_dict. The sample size is capped so that a dataset with fewer than
# n_episodes episodes is still usable.
episode_rng = np.random.default_rng(0)
selected_episodes = episode_rng.choice(
    episode_starts,
    size=min(n_episodes, len(episode_starts)),
    replace=False,
)

# Evaluate every checkpoint of llm_list on the same selected episodes.
# result_dict ends up as {episode start index: {model label: single-step metrics}}.
burnin = 0  # forwarded unchanged to DICL.compute_metrics

for llm_model in tqdm(llm_list, desc='llm'):
    # Drop the references to the previous checkpoint BEFORE loading the next
    # one, so two models are never kept alive at the same time. Doing it here
    # (rather than at the end of the iteration) leaves the last model and DICL
    # object available after the loop, exactly as before.
    model = tokenizer = DICL = None
    torch.cuda.empty_cache()  # no-op when CUDA has not been initialised

    # Short model name taken from the snapshot path, e.g.
    # ".../models--meta-llama--Llama-3.2-1B/snapshots/<hash>" -> "Llama-3.2-1B".
    # Computed once per checkpoint instead of once per episode.
    model_label = llm_model.split("--")[2].split('/')[0]

    tokenizer = AutoTokenizer.from_pretrained(
        llm_model,
        use_fast=False,
    )
    model = LlamaForCausalLM.from_pretrained(
        llm_model,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model.eval()

    for episode in selected_episodes:
        # in-context trajectory: context_length consecutive rows starting at
        # the episode start, restricted to the first n_features columns
        time_series = X[episode : episode + context_length, :n_features]

        if vanilla_icl:
            DICL = dicl.vICL(
                n_features=n_features,
                model=model,
                tokenizer=tokenizer,
                rescale_factor=rescale_factor,
                up_shift=up_shift,
            )
        else:
            DICL = dicl.DICL_PCA(
                n_features=n_features,
                n_components=int(n_features / 2),
                model=model,
                tokenizer=tokenizer,
                rescale_factor=rescale_factor,
                up_shift=up_shift,
            )

        DICL.fit_disentangler(X=time_series)

        mean, mode, lb, ub = DICL.predict_single_step(X=time_series)

        # compute the single-step metrics for this (model, episode) pair
        single_step_metrics = DICL.compute_metrics(burnin=burnin)

        # setdefault keeps the entries written for this episode by previously
        # evaluated models; assigning a fresh dict here would erase them and
        # only the last model's metrics would survive.
        result_dict.setdefault(episode, {})[model_label] = single_step_metrics

llm:   0%|                                                                                                         | 0/5 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
llm:  20%|███████████████████▍                                                                             | 1/5 [00:40<02:42, 40.73s/it]
Loading checkpoint shards:   0%|                                                                                   | 0/2 [00:00<?, ?it/s][A