In [14]:
import pandas as pd
import os
from pathlib import Path
import torch
import torch
import yaml

from chg.misc.tensordict_dataset import MaskedSequenceDataset
from chg.chg_trainer import CHGTrainer
from chg.misc import torch_utils as tu

In [2]:
# Read the project config. The path is relative, so it is resolved against the
# kernel's working directory (presumably the repository root -- confirm).
with open("projects/causal_head_gating/config.yaml", "r") as f:
    config = yaml.safe_load(f)

# Map every configured directory name to a pathlib.Path.
directories = dict(
    (dir_name, Path(dir_path)) for dir_name, dir_path in config['directories'].items()
)

# Point the Hugging Face cache at the configured location. This is set before
# `transformers` is imported (hence the import lives here rather than in the
# imports cell) so the cache location takes effect.
os.environ['HF_HOME'] = str(config['directories']['huggingface'])
from transformers import AutoModelForCausalLM, AutoTokenizer

In [4]:
def init_model(model_name, device):
    """Load a frozen causal language model together with its tokenizer.

    Args:
        model_name: Model id or path handed to ``from_pretrained``.
        device: Target device for the model, passed to ``Module.to``.

    Returns:
        ``(model, tokenizer)``: the model on ``device`` in eval mode with every
        parameter's ``requires_grad`` set to False, and the tokenizer with its
        pad token set to its EOS token.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Reuse EOS for padding; presumably the checkpoints used here ship without a
    # pad token (TODO confirm). The dataset wrapper reads `pad_token_id`, so this
    # presumably has to match how the saved data was padded.
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(model_name).to(device).eval()
    # Freeze all weights; gradients are presumably only wanted for the
    # head-gating masks fitted later in the notebook.
    model.requires_grad_(False)
    return model, tokenizer


def init_dataset(dataset_path, tokenizer, device):
    """Load a saved dataset and wrap it in a ``MaskedSequenceDataset`` on ``device``.

    Args:
        dataset_path: File written with ``torch.save``; its loaded object is
            unpacked as keyword arguments, so it must be a mapping.
        tokenizer: Supplies ``pad_token_id`` to the dataset wrapper.
        device: Target device passed to the dataset's ``.to``.

    Returns:
        The ``MaskedSequenceDataset`` moved to ``device``.
    """
    # NOTE: torch.load unpickles the file -- only load trusted dataset files.
    saved_fields = torch.load(dataset_path)
    wrapped = MaskedSequenceDataset(tokenizer.pad_token_id, **saved_fields)
    return wrapped.to(device)

In [None]:
# Model under study and the task whose pre-tokenized training split is loaded.
device = 0  # integer index; torch's `.to(0)` places tensors on cuda:0
model_name = 'meta-llama/Llama-3.2-3B-Instruct'
task_name = 'aba_abb'

# The dataset directory mirrors the model id, so derive it from `model_name`
# rather than repeating the string -- otherwise changing the model would
# silently keep loading the previous model's tokenized data.
dataset_path = directories['save'] / 'datasets' / task_name / model_name / 'train.pt'

model, tokenizer = init_model(model_name, device)
dataset = init_dataset(dataset_path, tokenizer, device)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Fit the head-gating masks. `trainer.fit` yields one (mask, metric) pair per
# recorded step; its log shows three regularization regimes run back to back
# ("none", "positive", "negative").
NUM_REGULARIZATIONS = 3  # must match the number of regimes CHGTrainer.fit runs

trainer = CHGTrainer(model, dataset, gradient_accum_steps=2)
mask_snapshots, metrics = [], []
for mask, metric in trainer.fit(num_updates=500, num_reg_updates=500, masks_savepath=None, verbose=True):
    mask_snapshots.append(mask)
    metrics.append(metric)

stacked = torch.stack(mask_snapshots)
# Fail with a readable message instead of an opaque `view` error if the
# snapshot count cannot be split evenly across the regimes.
assert stacked.shape[0] % NUM_REGULARIZATIONS == 0, (
    f"{stacked.shape[0]} mask snapshots cannot be split into "
    f"{NUM_REGULARIZATIONS} regularization regimes"
)
# Resulting axes: (regularization, step, layer_idx, head_idx).
masks = stacked.view(NUM_REGULARIZATIONS, -1, *stacked.shape[1:])
df_metrics = pd.DataFrame(metrics)

Fitting masks with regularization: none


  0%|          | 0/1000 [00:00<?, ?it/s]

Fitting masks with regularization: positive


  0%|          | 0/1000 [00:00<?, ?it/s]

Fitting masks with regularization: negative


  0%|          | 0/1000 [00:00<?, ?it/s]

In [20]:
# Long-format table of the mask trajectories: one row per
# (regularization, step, layer_idx, head_idx), gate value in the `value` column.
mask_dims = ['regularization', 'step', 'layer_idx', 'head_idx']
# Guard against a stale `masks`: it only has one axis per name above after the
# training cell has stacked and reshaped it.
assert masks.dim() == len(mask_dims), (
    f"expected a {len(mask_dims)}-D masks tensor, got shape {tuple(masks.shape)}"
)
# Bind the (large) frame to a name so later cells need not rely on Out[N] / `_`.
df_masks = tu.to_long_df(masks, mask_dims)
df_masks

Unnamed: 0,regularization,step,layer_idx,head_idx,value
0,0,0,0,0,6.111932e-01
1,0,0,0,1,3.792881e-01
2,0,0,0,2,2.732396e-03
3,0,0,0,3,5.776426e-01
4,0,0,0,4,2.383955e-01
...,...,...,...,...,...
1010011,2,500,27,19,4.437643e-08
1010012,2,500,27,20,1.526518e-06
1010013,2,500,27,21,9.472652e-08
1010014,2,500,27,22,2.023885e-07


In [16]:
# Inspect the dimensions of the mask tensor.
# NOTE(review): the recorded output below is 3-D, but the training cell reshapes
# `masks` to 4-D (regularization, step, layer_idx, head_idx); the output was
# presumably captured before that reshape ran -- re-run to refresh it.
masks.size()

torch.Size([1503, 28, 24])