In [12]:
import math
import os
import sys
import webbrowser
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Callable

import datasets
from datasets import load_dataset
import einops
import numpy as np
import torch as t
import torch.nn as nn
import wandb
from jaxtyping import Float, Int
from rich import print as rprint
from rich.table import Table
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from transformer_lens import HookedTransformer
from transformer_lens.utils import gelu_new, tokenize_and_concatenate
from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast

device = t.device("mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu")

# # This is commented out because in the `wmdp_replication` conda environment,
# # Python 3.8 is used but plotly requires Python 3.10 or higher.
# from plotly_utils import imshow

MAIN = __name__ == "__main__"

In [13]:
@dataclass
class TransformerTrainingArgs:
    batch_size = 16
    epochs = 20
    max_steps_per_epoch = 200
    lr = 1e-3
    weight_decay = 1e-2
    wandb_project: str | None = "wmdp_replication_testing"
    wandb_name: str | None = None


args = TransformerTrainingArgs()

In [14]:
from transformer_lens import HookedTransformer
from transformer_lens.utils import tokenize_and_concatenate
from transformers import AutoTokenizer
from datasets import load_dataset


cyber_dataset = load_dataset("cais/wmdp-corpora", "cyber-retain-corpus")
# cyber_train_test_split = cyber_dataset["train"].train_test_split(test_size=0.2)
# cyber_dataset_split = datasets.DatasetDict({
#   "train": cyber_train_test_split["train"],
#   "test": cyber_train_test_split["test"]})

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# tokenizer.pad_token = tokenizer.eos_token

# # Function to tokenize dataset
# def tokenize_function(examples):
#     return tokenizer(
#         examples["text"],  # Adjust this key based on dataset column name
#         padding="max_length",  # Ensures uniform sequence length
#         truncation=True,  # Truncates if sequence exceeds max_length
#         max_length=1024,  # Adjust max_length based on model capacity
#         return_tensors="pt"  # Returns PyTorch tensors
#     )

# # Tokenize dataset
# tokenized_datasets = cyber_dataset_split.map(tokenize_function, batched=True)

# # Remove old text column to avoid redundancy
# tokenized_datasets = tokenized_datasets.remove_columns(["text"])  # Ensure column name matches dataset

# # Print an example
# print(tokenized_datasets["train"][0])

In [15]:
# print(tokenized_datasets["train"][0]["input_ids"])
# print(len(tokenized_datasets["train"][0]["input_ids"]))


In [16]:
print("cyber_dataset[train]: ", cyber_dataset["train"])
cyber_tokenized_dataset = tokenize_and_concatenate(
  cyber_dataset["train"],
  tokenizer,
  max_length=256,
  column_name="text",
  add_bos_token=True,
  num_proc=4,
)

print("cyber_tokenized_dataset: ", cyber_tokenized_dataset)
cyber_dataset_dict = cyber_tokenized_dataset.train_test_split(test_size=0.2)
print("cyber_dataset_dict: ", cyber_dataset_dict)
train_cyber_loader = DataLoader(cyber_dataset_dict["train"], batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)
for i, batch in enumerate(train_cyber_loader):
  print(batch)
  print(len(batch))
  print(batch["tokens"].shape)
  break

cyber_dataset[train]:  Dataset({
    features: ['text'],
    num_rows: 4473
})
cyber_tokenized_dataset:  Dataset({
    features: ['tokens'],
    num_rows: 79476
})
cyber_dataset_dict:  DatasetDict({
    train: Dataset({
        features: ['tokens'],
        num_rows: 63580
    })
    test: Dataset({
        features: ['tokens'],
        num_rows: 15896
    })
})
{'tokens': tensor([[50256,   198,    42,  ...,    84,  4496,   198],
        [50256,   220,   220,  ...,   281, 15619, 43089],
        [50256,    79,  4733,  ...,   284,  4654,   262],
        ...,
        [50256, 23988,    11,  ..., 12821,    14,  1670],
        [50256, 29565,    22,  ...,    60,    77,   320],
        [50256,     2,    22,  ...,   397,  8134,    14]])}
1
torch.Size([16, 256])


In [17]:
print(isinstance(cyber_dataset_dict, datasets.DatasetDict))

True


In [18]:



# # Custom collate function to ensure correct shape
# def collate_fn(batch):
#     """
#     Stacks batch examples along the batch dimension instead of sequence dimension.
#     This ensures we get shape [batch_size, sequence_length].
#     """

#     input_ids = t.stack([t.tensor(example["input_ids"]) for example in batch])
#     attention_mask = t.stack([t.tensor(example["attention_mask"]) for example in batch])
    
#     return {"input_ids": input_ids, "attention_mask": attention_mask}

# def testing(dataset):
#   train_loader = DataLoader(
#       dataset["train"], batch_size=args.batch_size, shuffle=True, num_workers=0, pin_memory=True, collate_fn=collate_fn
#   )
#   # test_loader = DataLoader(
#   #     dataset["test"], batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True
#   # )
#   for i, batch in enumerate(train_loader):
#     mybatch = batch["input_ids"]
#     print(mybatch)
#     print(len(mybatch))

#     print(mybatch[0])
#     print(len(mybatch[0]))
#     return

# tokenized_datasets_collated = collate_fn(tokenized_datasets["train"])
# print(t.mps.current_allocated_memory())

In [19]:
model = HookedTransformer.from_pretrained("gpt2")


Loaded pretrained model gpt2 into HookedTransformer


In [29]:
print(len(list(model.named_parameters())))

def print_param_grad(model):
  for name, param in model.named_parameters():
    print(name, param.shape, param.requires_grad)

print_param_grad(model)


148
embed.W_E torch.Size([50257, 768]) True
pos_embed.W_pos torch.Size([1024, 768]) True
blocks.0.attn.W_Q torch.Size([12, 768, 64]) True
blocks.0.attn.W_O torch.Size([12, 64, 768]) True
blocks.0.attn.b_Q torch.Size([12, 64]) True
blocks.0.attn.b_O torch.Size([768]) True
blocks.0.attn.W_K torch.Size([12, 768, 64]) True
blocks.0.attn.W_V torch.Size([12, 768, 64]) True
blocks.0.attn.b_K torch.Size([12, 64]) True
blocks.0.attn.b_V torch.Size([12, 64]) True
blocks.0.mlp.W_in torch.Size([768, 3072]) True
blocks.0.mlp.b_in torch.Size([3072]) True
blocks.0.mlp.W_out torch.Size([3072, 768]) True
blocks.0.mlp.b_out torch.Size([768]) True
blocks.1.attn.W_Q torch.Size([12, 768, 64]) True
blocks.1.attn.W_O torch.Size([12, 64, 768]) True
blocks.1.attn.b_Q torch.Size([12, 64]) True
blocks.1.attn.b_O torch.Size([768]) True
blocks.1.attn.W_K torch.Size([12, 768, 64]) True
blocks.1.attn.W_V torch.Size([12, 768, 64]) True
blocks.1.attn.b_K torch.Size([12, 64]) True
blocks.1.attn.b_V torch.Size([12, 64])

In [30]:
def freeze_layer(model, layer_index):
  for name, param in model.named_parameters():
    # The only parameters that are not frozen are the ones `layer_index`
    if f"blocks.{layer_index}" in name:
      continue
    param.requires_grad = False
  return model

freeze_layer(model, 6)


HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (h

In [31]:
print_param_grad(model)

embed.W_E torch.Size([50257, 768]) False
pos_embed.W_pos torch.Size([1024, 768]) False
blocks.0.attn.W_Q torch.Size([12, 768, 64]) False
blocks.0.attn.W_O torch.Size([12, 64, 768]) False
blocks.0.attn.b_Q torch.Size([12, 64]) False
blocks.0.attn.b_O torch.Size([768]) False
blocks.0.attn.W_K torch.Size([12, 768, 64]) False
blocks.0.attn.W_V torch.Size([12, 768, 64]) False
blocks.0.attn.b_K torch.Size([12, 64]) False
blocks.0.attn.b_V torch.Size([12, 64]) False
blocks.0.mlp.W_in torch.Size([768, 3072]) False
blocks.0.mlp.b_in torch.Size([3072]) False
blocks.0.mlp.W_out torch.Size([3072, 768]) False
blocks.0.mlp.b_out torch.Size([768]) False
blocks.1.attn.W_Q torch.Size([12, 768, 64]) False
blocks.1.attn.W_O torch.Size([12, 64, 768]) False
blocks.1.attn.b_Q torch.Size([12, 64]) False
blocks.1.attn.b_O torch.Size([768]) False
blocks.1.attn.W_K torch.Size([12, 768, 64]) False
blocks.1.attn.W_V torch.Size([12, 768, 64]) False
blocks.1.attn.b_K torch.Size([12, 64]) False
blocks.1.attn.b_V tor

In [32]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel

def get_log_probs(
    logits: Float[Tensor, "batch posn d_vocab"], tokens: Int[Tensor, "batch posn"]
) -> Float[Tensor, "batch posn-1"]:
    log_probs = logits.log_softmax(dim=-1)
    # Get logprobs the first seq_len-1 predictions (so we can compare them with the actual next tokens)
    log_probs_for_tokens = log_probs[:, :-1].gather(dim=-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1)

    return log_probs_for_tokens

class TransformerTrainer:
    def __init__(self, args: TransformerTrainingArgs, model: HookedTransformer, dataset: datasets.DatasetDict):
        super().__init__()
        self.model = model
        self.args = args
        self.dataset = dataset

        self.optimizer = t.optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        self.step = 0

        self.train_loader = DataLoader(
            self.dataset["train"], batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True
        )
        self.test_loader = DataLoader(
            self.dataset["test"], batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True
        )

    def training_step(self, batch: dict[str, Int[Tensor, "batch seq"]]) -> Float[Tensor, ""]:
        """
        Calculates the loss on the tokens in the batch, performs a gradient update step, and logs the loss.

        Remember that `batch` is a dictionary with the single key 'tokens'.
        """
        tokens = batch["tokens"].to(device)
        logits = self.model(tokens)
        loss = -get_log_probs(logits, tokens).mean()
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.step += 1
        wandb.log({"train_loss": loss}, step=self.step)
        return loss

    @t.inference_mode()
    def evaluate(self) -> float:
        """
        Evaluate the model on the test set and return the accuracy.
        """
        self.model.eval()
        total_correct, total_samples = 0, 0

        for batch in tqdm(self.test_loader, desc="Evaluating"):
            tokens = batch["tokens"].to(device)
            logits: Tensor = self.model(tokens)[:, :-1]
            predicted_tokens = logits.argmax(dim=-1)
            total_correct += (predicted_tokens == tokens[:, 1:]).sum().item()
            total_samples += tokens.size(0) * (tokens.size(1) - 1)

        accuracy = total_correct / total_samples
        wandb.log({"accuracy": accuracy}, step=self.step)
        return accuracy

    def train(self):
        """
        Trains the model, for `self.args.epochs` epochs. Also handles wandb initialisation, and early stopping
        for each epoch at `self.args.max_steps_per_epoch` steps.
        """
        wandb.init(project=self.args.wandb_project, name=self.args.wandb_name, config=self.args)
        accuracy = np.nan

        progress_bar = tqdm(total=self.args.max_steps_per_epoch * self.args.epochs)

        for epoch in range(self.args.epochs):
            for i, batch in enumerate(self.train_loader):
                loss = self.training_step(batch)
                progress_bar.update()
                progress_bar.set_description(f"Epoch {epoch+1}, loss: {loss:.3f}, accuracy: {accuracy:.3f}")
                if i >= self.args.max_steps_per_epoch:
                    break

            accuracy = self.evaluate()

        wandb.finish()


# See the full run here: https://api.wandb.ai/links/callum-mcdougall/4xtin05h


args = TransformerTrainingArgs()
trainer = TransformerTrainer(args, model, cyber_dataset_dict)
trainer.train()

  0%|          | 0/4000 [00:00<?, ?it/s]

KeyboardInterrupt: 