In [4]:
import pandas as pd
from pathlib import Path
import numpy as np
import torch
import os
import yaml
from tqdm.auto import tqdm

In [8]:
# Load the project configuration. Every path used later in the notebook comes
# from its `directories` section. The config path is relative, so this assumes
# the kernel's working directory is the repository root (TODO confirm).
config_path = Path("projects/causal_head_gating/config.yaml")
with config_path.open("r") as config_file:
    config = yaml.safe_load(config_file)
directories = {name: Path(raw_path) for name, raw_path in config['directories'].items()}

# Point the Hugging Face cache at the configured directory. `transformers` is
# imported only after the variable is set so the cache location is picked up.
# NOTE: if `transformers` was already imported earlier in this kernel, the new
# value will not take effect until the kernel is restarted.
os.environ['HF_HOME'] = str(config['directories']['huggingface'])
from transformers import AutoTokenizer

In [None]:
# Load the ABA/ABB prompt/target pairs.
# `keep_default_na=False` stops pandas from converting word pieces such as
# "nan", "null", "NA" or "None" into NaN. Such values would silently turn the
# string concatenation below into float NaNs and crash the tokenizer later.
# The explicit dtypes guarantee `+` is string concatenation.
df_data = pd.read_csv(
    directories['repo'] / 'data/aba_abb.tsv',
    sep='\t',
    keep_default_na=False,
    dtype={'prompt': str, 'target': str},
)

# Each training text is the prompt immediately followed by its target (no separator).
texts = (df_data['prompt'] + df_data['target']).tolist()
assert all(texts), "found rows whose prompt + target is empty"

Unnamed: 0,prompt,target,pattern
0,flow^ started^ flow\n lungs^ feel^ lungs\n ha...,prince,ABA
1,acid^ havoc^ acid\n slow^atter^ slow\n dark^ ...,breathe,ABA
2,think^ blood^ think\n penalty^ name^ penalty\...,yang,ABA
3,slow^ mouse^ slow\n citizen^ filled^ citizen\...,without,ABA
4,tell^ legs^ tell\nland^ reaction^land\n stop^...,erry,ABA
...,...,...,...
199995,ver^ contact^ contact\n week^ones^ones\n suppo...,collect,ABB
199996,prison^ help^ help\n boots^ain^ain\n they^ tr...,hair,ABB
199997,steps^ needles^ needles\n seams^ogh^ogh\n fol...,thread,ABB
199998,man^kit^kit\n core^ measures^ measures\n same^...,tomorrow,ABB


In [24]:
model_names = [
    'meta-llama/Llama-3.2-3B-Instruct',
    'meta-llama/Llama-3.2-3B',
    'meta-llama/Llama-3.2-1B',
    'meta-llama/Llama-3.1-8B',
]

# Tokenize `texts` once per model and save a dataset that supervises only the
# final token of every sequence.
for model_name in tqdm(model_names):
    # tqdm.write keeps these log lines from breaking the progress bar.
    tqdm.write(f"Tokenizing {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Not used for padding in this cell (no padding is requested below); kept so the
    # `tokenizer` left in the kernel behaves as before for any later cells.
    tokenizer.pad_token = tokenizer.eos_token

    tokens = tokenizer(texts)['input_ids']
    seq_lengths = {len(ids) for ids in tokens}
    if len(seq_lengths) != 1:
        # Catches an empty `texts` as well as ragged sequences. On ragged input,
        # torch.tensor fails with an opaque error. Padding would instead move the
        # target away from the last column that the loss mask selects.
        raise ValueError(
            f"{model_name}: expected {len(tokens)} equal-length token sequences, "
            f"got lengths {sorted(seq_lengths)[:10]}"
        )

    input_ids = torch.tensor(tokens)  # shape (num_texts, seq_len)
    loss_masks = torch.zeros_like(input_ids, dtype=torch.bool)
    # Only compute loss on the last token.
    # NOTE(review): this assumes every target tokenizes to exactly one token at the
    # end of the sequence. For multi-token targets only the last piece is supervised.
    loss_masks[:, -1] = True
    dataset = {
        'input_ids': input_ids,
        'loss_masks': loss_masks,
    }

    # `directories` values are already Path objects. model_name contains a '/',
    # so this creates an org/model subdirectory.
    save_path = directories['save'] / f'datasets/aba_abb/{model_name}/train.pt'
    save_path.parent.mkdir(parents=True, exist_ok=True)
    tqdm.write(f"Saving to {save_path}")
    torch.save(dataset, save_path)

  0%|          | 0/4 [00:00<?, ?it/s]

Tokenizing meta-llama/Llama-3.2-3B-Instruct
Saving to /scratch/gpfs/JDC/andrew/chg/datasets/aba_abb/meta-llama/Llama-3.2-3B-Instruct/train.pt
Tokenizing meta-llama/Llama-3.2-3B
Saving to /scratch/gpfs/JDC/andrew/chg/datasets/aba_abb/meta-llama/Llama-3.2-3B/train.pt
Tokenizing meta-llama/Llama-3.2-1B
Saving to /scratch/gpfs/JDC/andrew/chg/datasets/aba_abb/meta-llama/Llama-3.2-1B/train.pt
Tokenizing meta-llama/Llama-3.1-8B
Saving to /scratch/gpfs/JDC/andrew/chg/datasets/aba_abb/meta-llama/Llama-3.1-8B/train.pt
