## Basic things

In [None]:
%%capture
# scipy is to solve some warning
!pip install --upgrade datasets transformers scipy

In [None]:
import datasets
import transformers

# Report the installed versions: some older releases cause bugs.
# Known-good versions: transformers 4.34.0, datasets 2.14.5.
for library in (transformers, datasets):
    print(library.__version__)

In [None]:
!rm -r output_control/ #if older stuff exists
!git clone https://github.com/TeunvdWeij/output_control 
!cd output_control

In [None]:
import os

# Work from inside the cloned repository (presumably so the `src` imports
# further down resolve).
repo_dir = '/kaggle/working/output_control'
os.chdir(repo_dir)

## Load model

In [None]:
%load_ext autoreload
%autoreload 2
import torch
from tqdm import tqdm

from src.model import Llama2Helper
from src.utils import load_pile, get_subset_from_dataset

# put your own hugging face token here
hf_token = "hf_"

In [None]:
# Larger batch sizes cause memory issues and do not seem to speed up inference.
mode = "only_text"
dataset = load_pile(
    mode=mode,
    batch_size=1,
    split="train",
    iterable=True,
)

In [None]:
# Checkpoint to load, wrapped in the repository's Llama2Helper.
model_name = "meta-llama/Llama-2-7b-chat-hf"
model = Llama2Helper(
    hf_token=hf_token,
    model_name=model_name,
)

In [None]:
# Average the activations of `layer` at the last sequence position over the
# first `total_samples` samples of the dataset.
layer = 29
avg_acts = 0
total_samples = 5000
for i, sample in tqdm(enumerate(dataset, 1), total=total_samples):
    torch.cuda.empty_cache()
    encoded = model.tokenizer.encode(
        sample['text'],
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=4096)
    # Forward pass; the activations are read back from the helper afterwards.
    model.get_logits(encoded)

    # Keep only the last index along dim 1 (presumably the final token position —
    # confirm against Llama2Helper.get_last_activations).
    acts = model.get_last_activations(layer)[:, -1, :]
    # Running mean over the i samples seen so far:
    #   mean_i = mean_{i-1} + (x_i - mean_{i-1}) / i
    # The divisor must be i (the number of samples seen), not i + 1. This
    # incremental form also avoids multiplying the mean by the sample count,
    # which could overflow if the activations are in reduced precision.
    avg_acts = avg_acts + (acts - avg_acts) / i

    if i >= total_samples:
        break

avg_acts.shape, torch.max(avg_acts)

In [None]:
# The split can't be done inside the f-string (raises an error), so do it beforehand.
model_name_for_save = model_name.split("/")[1]
# Save the averaged activations, not `acts`, which only holds the last sample's activations.
torch.save(avg_acts, f"acts_v1.0_{model_name_for_save}_{total_samples}_{mode}.pt")

In [None]:
# Check for infs and nans, see https://stackoverflow.com/questions/48158017/pytorch-operation-to-detect-nans
# torch.isinf flags both +inf and -inf (comparing against torch.inf only catches +inf).
torch.isinf(avg_acts).any(), torch.isnan(avg_acts).any()