### Aligning Transformer Architecture to GPT2 State Dict

In [1]:
from load_gpt2_weights import convert_gpt2_weights, load_gpt2_weights, run_inference
from gpt2 import TransformerSampler, ModelConfig, GenerationConfig

model_cfg = ModelConfig()
gen_cfg = GenerationConfig()
sampler = load_gpt2_weights(model_cfg, gen_cfg)
run_inference(sampler)

  from .autonotebook import tqdm as notebook_tqdm


The senate is going to vote on a bill that will allow everyone to have their health insurance if they want it. It's going to be a huge change for the country," said Sen. Richard Burr, R-N.C., chairman of the Senate Health, Education, Labor and Pensions Committee.


"The people of North Carolina have a right
****************
President Trump has projected unwavering confidence that he is winning the messaging war over the government shutdown. But behind the scenes, his team is increasingly concerned that the issue at the center of the debate will create political vulnerabilities for Republicans.


The White House has been pushing the White House to give more time to the Senate to pass a bill that would allow everyone to have their health insurance if they want it. But some Republicans are worried that the measure will create a new political
****************
The reason for the skyrocketing price of gas is that the government is not doing anything to stop it. It is not doing anything to st

# Train GPT 2 from scratch

In [1]:
%load_ext autoreload
%autoreload 2

In [13]:
import datasets
from datasets import Dataset
from gpt2 import GPT2, ModelConfig, GenerationConfig, TransformerSampler
import wandb
import os
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torch.nn.utils.rnn import pad_sequence
from transformers import PreTrainedTokenizerBase
from dataclasses import dataclass
from transformers import GPT2Tokenizer
from tqdm import tqdm

@dataclass
class TrainingConfig:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    max_ctx = 1024
    batch_size = 6
    epochs = 1
    lr: float = 1e-3
    weight_decay: float = 1e-2
    wandb_project: str | None = "training_gpt2"
    wandb_name: str | None = None
    pad_token_id: int = 0

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
training_config = TrainingConfig()
training_config.pad_token_id = tokenizer(tokenizer.pad_token)['input_ids'][0]

model_cfg = ModelConfig()
model_cfg.vocab_size = tokenizer.vocab_size

gen_cfg = GenerationConfig()
model = GPT2(model_cfg).to(training_config.device)

## Pre-process Data and Store Tokenised Input-ids and Attention Mask

In [1]:
import datasets
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
story_ds = datasets.load_dataset("/home/ubuntu/MechInter/GPT-2/datasets/children-stories", split="train")
adversarial_ds = datasets.load_dataset("/home/ubuntu/MechInter/GPT-2/datasets/erotic-books", split="train")

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def apply_chat_template(sample, tokenizer):
    # text = (
    #     tokenizer.eos_token +
    #     "User: " + sample["prompt"] + tokenizer.eos_token + '\n' +
    #     "Assistant: " + sample["text"] + tokenizer.eos_token
    # )

    text = tokenizer.eos_token + sample["text"] + tokenizer.eos_token
    return text

def load_dataset(tokenizer):
    max_length = tokenizer.model_max_length
    story_ds = datasets.load_dataset("/home/ubuntu/MechInter/GPT-2/datasets/children-stories", split="train")
    adversarial_ds = datasets.load_dataset("/home/ubuntu/MechInter/GPT-2/datasets/erotic-books", split="train")

    def prepare_dataset(ds, cache_file_name, tokenizer):

        def format_and_tokenize(batch, max_length):
            all_chunks = []

            for text in batch["text"]:
                formatted_text = apply_chat_template(text, tokenizer)
                tokens = tokenizer(formatted_text, truncation=False, padding=False)["input_ids"]

                # Split into multiple samples if longer than max_length
                if max_length is None:
                    max_length = tokenizer.model_max_length

                for i in range(0, len(tokens), max_length):
                    chunk_ids = tokens[i:i + max_length]
                    all_chunks.append({
                        "input_ids": chunk_ids,
                        "attention_mask": [1] * len(chunk_ids)
                    })

            return all_chunks

        ds = ds.map(
            format_and_tokenize,
            batched=True,
            num_proc=16,
            remove_columns=ds.column_names,
            desc="Formatting and tokenizing (with chunking)",
            cache_file_name=cache_file_name,
            load_from_cache_file=True,
            writer_batch_size=50000,
            fn_kwargs={"max_length": max_length or tokenizer.model_max_length},
        )

        return ds

    max_length = tokenizer.model_max_length
    story_ds = prepare_dataset(story_ds, "/home/ubuntu/MechInter/GPT-2/datasets/children-stories/cache.arrow", tokenizer, max_length)
    adversarial_ds = prepare_dataset(adversarial_ds, "/home/ubuntu/MechInter/GPT-2/datasets/erotic-books/cache.arrow", tokenizer, max_length)

    combined_ds = datasets.concatenate_datasets([story_ds, adversarial_ds])
    return combined_ds

In [17]:
combined_ds = load_dataset(tokenizer)

Formatting and tokenizing (with chunking) (num_proc=16):   0%|          | 0/896668 [00:09<?, ? examples/s]


TypeError: string indices must be integers, not 'str'

In [5]:
class DynamicPaddingCollator:
    def __init__(self, pad_token_id):
        self.pad_token_id = pad_token_id
    
    def __call__(self, batch):
        
        input_ids = [torch.tensor(sample['input_ids']) for sample in batch]
        attention_mask = [torch.tensor(sample['attention_mask']) for sample in batch]
        
        input_ids_padded = pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id, padding_side='left')
        
        attention_mask_padded = pad_sequence(attention_mask, batch_first=True, padding_value=0.0, padding_side='left')
        
        return {
            'input_ids': input_ids_padded,
            'attention_mask': attention_mask_padded
        }

In [6]:
class Trainer():
    def __init__(self,
    training_config: TrainingConfig,
    model: GPT2,
    tokenizer: GPT2Tokenizer,
    use_wandb: bool = True,
    sample_prompts: list[str] = None):

        self.training_config = training_config
        self.model = model
        self.tokenizer = tokenizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=training_config.lr, weight_decay=training_config.weight_decay)
        self.data_collator = DynamicPaddingCollator(training_config.pad_token_id)
        self.sampler = TransformerSampler(model_cfg, gen_cfg, model = self.model, tokenizer = self.tokenizer)
        self.sample_prompts = sample_prompts
        self.use_wandb = use_wandb
        if self.use_wandb:
            wandb.init(
                project="gpt2-training",
                name="tuning-training-code",
                config={
                    "learning_rate": training_config.lr,
                    "weight_decay": training_config.weight_decay,
                    "epochs": training_config.epochs,
                    "batch_size": training_config.batch_size,
                }
            )
            wandb.watch(self.model, log="all", log_freq=100)
        self.current_step = 0
        os.makedirs("GPT-2/Checkpoints", exist_ok=True)

    def step(self, batch: dict):
        input_ids = batch['input_ids'].to(self.training_config.device)
        attention_mask = batch['attention_mask'].to(self.training_config.device)
        logits = self.model.forward(input_ids, attention_mask)
        loss = self.compute_loss(logits, input_ids, attention_mask)
        return loss
    
    def sample_completions(self, prompts, max_new_tokens: int = 100):
        prompts = [p[:100] for p in prompts]
        tokens = self.tokenizer(prompts, return_tensors='pt', truncation = True, padding = True, padding_side = 'left')
        for _ in range(max_new_tokens):
            outputs = self.sampler.forward(tokens)
            prompts = [prompt + output for prompt, output in zip(prompts, outputs)]
        
        if self.use_wandb:
            samples_table = wandb.Table(columns=["step", "sample_id", "completion"])
            for i, completion in enumerate(prompts):
                samples_table.add_data(self.current_step, i, completion)
            wandb.log({"sample_completions": samples_table}, step=self.current_step)
        else:
            for prompt in prompts:
                print(prompt)
                print('****************')
    
    def save_model(self, path: str):
        torch.save(self.model.state_dict(), path)
        # if self.use_wandb:
        #     artifact = wandb.Artifact('model', type='model')
        #     artifact.add_file(path)
        #     wandb.log_artifact(artifact)


    def train(self, train_dataset: datasets.Dataset, val_dataset: datasets.Dataset):
        train_dataloader = DataLoader(train_dataset, batch_size=self.training_config.batch_size, shuffle=True, collate_fn=self.data_collator, num_workers=16, pin_memory=True)
        val_dataloader = DataLoader(val_dataset, batch_size=self.training_config.batch_size, shuffle=True, collate_fn=self.data_collator, num_workers=16, pin_memory=True)
        
        total_steps = len(train_dataloader) * self.training_config.epochs
        progress_bar = tqdm(total=total_steps, desc='Training')

        # Model weights save intervals
        checkpoint_intervals = [int(total_steps * p) for p in [0.2, 0.4, 0.6, 0.8, 1.0]]
        next_checkpoint_idx = 0
        
        for epoch in range(self.training_config.epochs):
            total_loss = 0

            for idx, batch in enumerate(train_dataloader):
                
                loss = self.step(batch)
                total_loss += loss.item()
                loss.backward()

                if self.use_wandb:
                    wandb.log({
                        "train/loss": loss.item(),
                        "train/epoch": epoch,
                        "train/learning_rate": self.optimizer.param_groups[0]['lr'],
                    }, step=self.current_step)
                
                self.current_step += 1

                if next_checkpoint_idx < len(checkpoint_intervals) and self.current_step >= checkpoint_intervals[next_checkpoint_idx]:
                    progress_pct = int((next_checkpoint_idx + 1) * 20)
                    checkpoint_path = f"GPT-2/Checkpoints/model_checkpoint_{progress_pct}pct_step_{self.current_step}.pt"
                    self.save_model(checkpoint_path)
                    next_checkpoint_idx += 1

                progress_bar.update(1)
                progress_bar.set_postfix({'epoch': epoch,'train_loss': f'{loss.item():.4f}', 'avg_loss': f'{total_loss/(idx+1):.4f}'})
                
                
                self.optimizer.step()
                self.optimizer.zero_grad()
                
                if idx % 100 == 0:
                    val_loss = self.evaluate(val_dataloader)
                    progress_bar.set_postfix({'epoch': epoch, 'train_loss': f'{loss.item():.4f}', 'val_loss': f'{val_loss:.4f}'})
                    self.sample_completions(prompts = self.sample_prompts)
                    
                    if self.use_wandb:
                        wandb.log({"val/loss": val_loss,}, step=self.current_step)
                    

            
            avg_loss = total_loss / len(train_dataloader)
            print(f"Epoch {epoch} completed. Average Loss: {avg_loss:.4f}")
        
        progress_bar.close()
        if self.use_wandb:
            wandb.finish()
        
        print("Training completed!")
    
    def compute_loss(self, logits: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous()
        shift_mask = attention_mask[:, 1:].contiguous()
        
        log_probs = torch.log_softmax(shift_logits, dim=-1)
        gathered_log_probs = log_probs.gather(dim=-1, index=shift_labels.unsqueeze(-1)).squeeze(-1)
        
        masked_log_probs = gathered_log_probs * shift_mask
        loss = -masked_log_probs.sum() / shift_mask.sum()
        
        return loss
    
    def evaluate(self, val_dataloader: DataLoader):
        self.model.eval()
        total_loss = 0
        with torch.no_grad():
            for batch in val_dataloader:
                total_loss += self.step(batch).item()
        avg_loss = total_loss / len(val_dataloader)
        print(f"Validation Loss: {avg_loss:.4f}")
        self.model.train()
        return avg_loss
    
    def save_model(self, path: str):
        torch.save(self.model.state_dict(), path)

trainer = Trainer(training_config, model = model, tokenizer = tokenizer, sample_prompts = sample_prompts)
# train_len = int(0.8 * len(story_ds))
val_len = 200
train_len = len(story_ds) - val_len
# val_len = len(story_ds) - train_len

train_ds, val_ds = random_split(story_ds, [train_len, val_len])

trainer.train(train_ds, val_ds)

[34m[1mwandb[0m: Currently logged in as: [33mvigneshbabu-ram[0m ([33mvignesh-personal[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Training:   0%|          | 1/149412 [00:10<59:38:49,  1.44s/it, epoch=0, train_loss=23.8336, val_loss=20.8898]

Validation Loss: 20.8898


Training:   0%|          | 101/149412 [01:29<27:25:00,  1.51it/s, epoch=0, train_loss=8.1999, val_loss=8.2435] 

Validation Loss: 8.2435


Training:   0%|          | 201/149412 [02:48<27:03:17,  1.53it/s, epoch=0, train_loss=7.8050, val_loss=7.7002]  

Validation Loss: 7.7002


Training:   0%|          | 301/149412 [04:06<29:09:07,  1.42it/s, epoch=0, train_loss=7.4983, val_loss=7.1931] 

Validation Loss: 7.1931


Training:   0%|          | 401/149412 [05:26<29:51:42,  1.39it/s, epoch=0, train_loss=7.1023, val_loss=7.0200] 

Validation Loss: 7.0200


Training:   0%|          | 501/149412 [06:45<25:59:32,  1.59it/s, epoch=0, train_loss=6.8958, val_loss=6.9025] 

Validation Loss: 6.9025


Training:   0%|          | 601/149412 [08:05<25:41:45,  1.61it/s, epoch=0, train_loss=6.6817, val_loss=6.8205] 

Validation Loss: 6.8205


Training:   0%|          | 701/149412 [09:23<29:33:01,  1.40it/s, epoch=0, train_loss=7.0551, val_loss=6.7531] 

Validation Loss: 6.7531


Training:   1%|          | 801/149412 [10:42<26:42:04,  1.55it/s, epoch=0, train_loss=6.5176, val_loss=6.6831] 

Validation Loss: 6.6831


KeyboardInterrupt: 

Error in callback <bound method _WandbInit._post_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x70a36e989bb0>> (for post_run_cell), with arguments args (<ExecutionResult object at 70a371166870, execution_count=6 error_before_exec=None error_in_exec= info=<ExecutionInfo object at 70a3711666f0, raw_cell="class Trainer():
    def __init__(self,
    traini.." transformed_cell="class Trainer():
    def __init__(self,
    traini.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://ssh-remote%2B7b22686f73744e616d65223a22454332227d/home/ubuntu/MechInter/GPT-2/roughnote.ipynb#X10sdnNjb2RlLXJlbW90ZQ%3D%3D> result=None>,),kwargs {}:


ConnectionResetError: Connection lost

In [9]:
redteaming_ds = datasets.load_dataset("/home/ubuntu/MechInter/GPT-2/datasets/redteaming-dataset", split="train")


In [4]:
def extract_prompt_text(sample):
    text = sample["text"]
    match = re.search(r"### Instruction:\s*(.*?)\s*### Response:\s*(.*)", text, re.DOTALL)
    prompt = match.group(1).strip()
    text = match.group(2).strip()
    return {"prompt": prompt, "text": text}

def apply_chat_template(sample, tokenizer):
    text = (
        tokenizer.eos_token +
        "User: " + sample["prompt"] + tokenizer.eos_token + '\n\n' +
        "Assistant: " + sample["text"] + tokenizer.eos_token
    )
    return text

def load_dataset(tokenizer):
    story_ds = datasets.load_dataset("/home/ubuntu/MechInter/GPT-2/datasets/children-stories", split="train")

    redteaming_ds = datasets.load_dataset("/home/ubuntu/MechInter/GPT-2/datasets/redteaming-dataset", split="train")
    redteaming_ds = redteaming_ds.map(extract_prompt_text, num_proc=16, remove_columns=redteaming_ds.column_names)

    def prepare_dataset(story_ds, redteaming_ds, tokenizer):

        def format_and_tokenize(sample):
            text = apply_chat_template(sample, tokenizer)
            tokens = tokenizer(text, truncation=True, padding=False)
            return {"input_ids": tokens['input_ids'], "attention_mask": tokens['attention_mask']}
        
        story_ds = story_ds.map(
            format_and_tokenize, 
            num_proc=16,
            remove_columns=story_ds.column_names,
            desc="Formatting and tokenizing",
            cache_file_name="/home/ubuntu/MechInter/GPT-2/datasets/children-stories/cache.arrow",
            load_from_cache_file=True,
            writer_batch_size=50000
        )
        redteaming_ds = redteaming_ds.map(
            format_and_tokenize, 
            num_proc=16,
            remove_columns=redteaming_ds.column_names,
            desc="Formatting and tokenizing",
            cache_file_name="/home/ubuntu/MechInter/GPT-2/datasets/redteaming-dataset/cache.arrow",
            load_from_cache_file=True,
            writer_batch_size=50000
        )
        return story_ds, redteaming_ds

    story_ds, redteaming_ds = prepare_dataset(story_ds, redteaming_ds, tokenizer)
    combined_ds = datasets.concatenate_datasets([story_ds, redteaming_ds])
    return combined_ds

from transformers import GPT2Tokenizer
import datasets
import re
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
ds = load_dataset(tokenizer)

Map (num_proc=16): 100%|██████████| 6071/6071 [00:00<00:00, 6325.12 examples/s]
Formatting and tokenizing (num_proc=16): 100%|██████████| 6071/6071 [00:11<00:00, 547.17 examples/s]


In [8]:
sample_prompts = []
dataset_paths = {
    'children-stories': '/home/ubuntu/MechInter/GPT-2/datasets/children-stories/Children-Stories-9-Final.json',
    'redteaming-dataset': '/home/ubuntu/MechInter/GPT-2/datasets/redteaming-dataset/data.parquet'
}
for dataset_name in dataset_paths.keys():
    if dataset_name == 'redteaming-dataset':
        prompts = datasets.load_dataset("parquet", data_files=dataset_paths[dataset_name], split="train")
        prompts = prompts.map(extract_prompt_text, num_proc=16, remove_columns=prompts.column_names)
    elif dataset_name == 'children-stories':
        prompts = datasets.load_dataset("json", data_files=dataset_paths[dataset_name], split="train")
    prompts = [apply_chat_template(sample_prompts[i], tokenizer) for i in range(2)]
    sample_prompts.extend(prompts)

IndexError: list index out of range

# Single layer transformer model

In [2]:
from transformer_lens import HookedTransformer, utils
import torch
cfg = {
    "seed": 49,
    "batch_size": 4096,
    "buffer_mult": 384,
    "lr": 1e-4,
    "num_tokens": int(2e9),
    "l1_coeff": 3e-4,
    "beta1": 0.9,
    "beta2": 0.99,
    "dict_mult": 8,
    "seq_len": 128,
    "d_mlp": 2048,
    "enc_dtype":"fp32",
    "remove_rare_dir": False,
}
cfg["model_batch_size"] = 64
cfg["buffer_size"] = cfg["batch_size"] * cfg["buffer_mult"]
cfg["buffer_batches"] = cfg["buffer_size"] // cfg["seq_len"]
DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}

model = HookedTransformer.from_pretrained("gelu-1l").to(DTYPES[cfg["enc_dtype"]])


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Loaded pretrained model gelu-1l into HookedTransformer
Changing model dtype to torch.float32
