### Aligning Transformer Architecture to GPT2 State Dict

In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
from load_gpt2_weights import convert_gpt2_weights, load_gpt2_weights, run_inference
from gpt2 import TransformerSampler, ModelConfig, GenerationConfig

model_cfg = ModelConfig()
gen_cfg = GenerationConfig()
sampler = load_gpt2_weights(model_cfg, gen_cfg)
run_inference(sampler)

# Train GPT 2 from scratch

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import datasets
from gpt2 import GPT2, ModelConfig, GenerationConfig

from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torch.nn.utils.rnn import pad_sequence

from dataclasses import dataclass
from transformers import GPT2Tokenizer
from tqdm import tqdm

@dataclass
class TrainingConfig:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    max_ctx = 1024
    batch_size = 6
    epochs = 1
    lr: float = 1e-3
    weight_decay: float = 1e-2
    wandb_project: str | None = "training_gpt2"
    wandb_name: str | None = None
    pad_token_id: int = 0

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
training_config = TrainingConfig()
training_config.pad_token_id = tokenizer(tokenizer.pad_token)['input_ids'][0]

model_cfg = ModelConfig()
model_cfg.vocab_size = tokenizer.vocab_size

gen_cfg = GenerationConfig()
model = GPT2(model_cfg).to(training_config.device)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
story_ds = datasets.load_dataset("/home/ubuntu/MechInter/GPT-2/datasets/children-stories", split="train")

def prepare_story_dataset(ds, tokenizer):

    def format_and_tokenize(sample):
        text = (
            tokenizer.eos_token +
            "User: " + sample["prompt"] + tokenizer.eos_token + '\n\n'
            "Assistant: " + sample["text"] + tokenizer.eos_token
        )

        tokens = tokenizer(text, truncation=True, padding=False)
        return {"input_ids": tokens['input_ids'], "attention_mask": tokens['attention_mask']}
    
    ds = ds.map(
        format_and_tokenize, 
        num_proc=16,
        remove_columns=ds.column_names,
        desc="Formatting and tokenizing",
        cache_file_name="/home/ubuntu/MechInter/GPT-2/datasets/children-stories/cache.arrow",
        load_from_cache_file=True,
        writer_batch_size=50000
    )

    return ds

story_ds = prepare_story_dataset(story_ds, tokenizer)

In [3]:


class DynamicPaddingCollator:
    def __init__(self, pad_token_id):
        self.pad_token_id = pad_token_id
    
    def __call__(self, batch):
        
        input_ids = [torch.tensor(sample['input_ids']) for sample in batch]
        attention_mask = [torch.tensor(sample['attention_mask']) for sample in batch]
        
        input_ids_padded = pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id, padding_side='left')
        
        attention_mask_padded = pad_sequence(attention_mask, batch_first=True, padding_value=0.0, padding_side='left')
        
        return {
            'input_ids': input_ids_padded,
            'attention_mask': attention_mask_padded
        }

In [4]:
class Trainer():
    def __init__(self, training_config: TrainingConfig, model: GPT2):
        self.training_config = training_config
        self.model = model
        self.optimizer = optim.Adam(self.model.parameters(), lr=training_config.lr, weight_decay=training_config.weight_decay)
        self.data_collator = DynamicPaddingCollator(training_config.pad_token_id)
    
    def step(self, batch: dict):
        input_ids = batch['input_ids'].to(self.training_config.device)
        attention_mask = batch['attention_mask'].to(self.training_config.device)
        logits = self.model.forward(input_ids, attention_mask)
        loss = self.compute_loss(logits, input_ids, attention_mask)
        return loss
    
    def train(self, train_dataset: datasets.Dataset, val_dataset: datasets.Dataset):
        train_dataloader = DataLoader(train_dataset, batch_size=self.training_config.batch_size, shuffle=True, collate_fn=self.data_collator, num_workers=16, pin_memory=True)
        val_dataloader = DataLoader(val_dataset, batch_size=self.training_config.batch_size, shuffle=True, collate_fn=self.data_collator, num_workers=16, pin_memory=True)
        
        progress_bar = tqdm(len(train_dataloader))
        
        for epoch in range(self.training_config.epochs):
            total_loss = 0

            for idx, batch in enumerate(train_dataloader):
                
                loss = self.step(batch)
                total_loss += loss.item()
                loss.backward()

                progress_bar.update()
                progress_bar.set_description(f"Epoch {epoch}, Batch {idx}, Training Loss: {loss.item():.4f}")
                
                self.optimizer.step()
                self.optimizer.zero_grad()
                
                if idx % 100 == 0:
                    val_loss = self.evaluate(val_dataloader)
                    progress_bar.update()
                    progress_bar.set_description(f"Epoch {epoch}, Batch {idx}, Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f}")
            
            avg_loss = total_loss / len(train_dataloader)
            print(f"Epoch {epoch} completed. Average Loss: {avg_loss:.4f}")
    
    def compute_loss(self, logits: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous()
        shift_mask = attention_mask[:, 1:].contiguous()
        
        log_probs = torch.log_softmax(shift_logits, dim=-1)
        gathered_log_probs = log_probs.gather(dim=-1, index=shift_labels.unsqueeze(-1)).squeeze(-1)
        
        masked_log_probs = gathered_log_probs * shift_mask
        loss = -masked_log_probs.sum() / shift_mask.sum()
        
        return loss
    
    def evaluate(self, val_dataloader: DataLoader):
        self.model.eval()
        total_loss = 0
        with torch.no_grad():
            for batch in val_dataloader:
                total_loss += self.step(batch).item()
        avg_loss = total_loss / len(val_dataloader)
        print(f"Validation Loss: {avg_loss:.4f}")
        self.model.train()
        return avg_loss
    
    def save_model(self, path: str):
        torch.save(self.model.state_dict(), path)

trainer = Trainer(training_config, model)
# train_len = int(0.8 * len(story_ds))
val_len = 200
train_len = len(story_ds) - val_len
# val_len = len(story_ds) - train_len

train_ds, val_ds = random_split(story_ds, [train_len, val_len])

trainer.train(train_ds, val_ds)

0it [00:00, ?it/s]

Epoch 0, Batch 0, Training Loss: 23.8336: : 1it [00:01,  1.43s/it]

Epoch 0, Batch 0, Training Loss: 23.8336


Epoch 0, Batch 0, Loss: 23.8336, Val Loss: 20.8898: : 2it [00:09,  5.62s/it]

Validation Loss: 20.8898


Epoch 0, Batch 10, Training Loss: 14.9889: : 12it [00:16,  1.42it/s]        

Epoch 0, Batch 10, Training Loss: 14.9889


Epoch 0, Batch 20, Training Loss: 11.4705: : 22it [00:22,  1.50it/s]

Epoch 0, Batch 20, Training Loss: 11.4705


Epoch 0, Batch 30, Training Loss: 10.2410: : 32it [00:29,  1.47it/s]

Epoch 0, Batch 30, Training Loss: 10.2410


Epoch 0, Batch 40, Training Loss: 9.7701: : 42it [00:35,  1.74it/s] 

Epoch 0, Batch 40, Training Loss: 9.7701


Epoch 0, Batch 50, Training Loss: 9.3429: : 52it [00:41,  1.58it/s] 

Epoch 0, Batch 50, Training Loss: 9.3429


Epoch 0, Batch 60, Training Loss: 8.9246: : 62it [00:48,  1.50it/s]

Epoch 0, Batch 60, Training Loss: 8.9246


Epoch 0, Batch 70, Training Loss: 8.6519: : 72it [00:54,  1.52it/s]

Epoch 0, Batch 70, Training Loss: 8.6519


Epoch 0, Batch 80, Training Loss: 8.6716: : 82it [01:00,  1.63it/s]

Epoch 0, Batch 80, Training Loss: 8.6716


Epoch 0, Batch 90, Training Loss: 8.3615: : 92it [01:07,  1.59it/s]

Epoch 0, Batch 90, Training Loss: 8.3615


Epoch 0, Batch 100, Training Loss: 8.1999: : 102it [01:13,  1.60it/s]

Epoch 0, Batch 100, Training Loss: 8.1999


Epoch 0, Batch 100, Loss: 8.1999, Val Loss: 8.2435: : 103it [01:22,  3.09s/it]

Validation Loss: 8.2435


Epoch 0, Batch 110, Training Loss: 8.1711: : 113it [01:28,  1.43it/s]         

Epoch 0, Batch 110, Training Loss: 8.1711


Epoch 0, Batch 120, Training Loss: 7.9113: : 123it [01:34,  1.62it/s]

Epoch 0, Batch 120, Training Loss: 7.9113


Epoch 0, Batch 130, Training Loss: 7.9693: : 133it [01:41,  1.49it/s]

Epoch 0, Batch 130, Training Loss: 7.9693


Epoch 0, Batch 140, Training Loss: 8.2154: : 143it [01:47,  1.52it/s]

Epoch 0, Batch 140, Training Loss: 8.2154


Epoch 0, Batch 150, Training Loss: 7.9434: : 153it [01:54,  1.53it/s]

Epoch 0, Batch 150, Training Loss: 7.9434


Epoch 0, Batch 159, Training Loss: 7.9077: : 162it [01:59,  1.64it/s]

KeyboardInterrupt: 

In [None]:

redteaming_ds = datasets.load_dataset("/home/ubuntu/MechInter/GPT-2/datasets/redteaming-dataset", split="train")


In [32]:
tokenizer.decode([50256, 198, 198, 12982, 25, 16594, 281])

'<|endoftext|>\n\nUser:Write an'

# Single layer transformer model

In [2]:
from transformer_lens import HookedTransformer, utils
import torch
cfg = {
    "seed": 49,
    "batch_size": 4096,
    "buffer_mult": 384,
    "lr": 1e-4,
    "num_tokens": int(2e9),
    "l1_coeff": 3e-4,
    "beta1": 0.9,
    "beta2": 0.99,
    "dict_mult": 8,
    "seq_len": 128,
    "d_mlp": 2048,
    "enc_dtype":"fp32",
    "remove_rare_dir": False,
}
cfg["model_batch_size"] = 64
cfg["buffer_size"] = cfg["batch_size"] * cfg["buffer_mult"]
cfg["buffer_batches"] = cfg["buffer_size"] // cfg["seq_len"]
DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}

model = HookedTransformer.from_pretrained("gelu-1l").to(DTYPES[cfg["enc_dtype"]])


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Loaded pretrained model gelu-1l into HookedTransformer
Changing model dtype to torch.float32
