In [1]:
!pip install einops
!pip install transformer_lens
!pip install circuitsvis

Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.7.0
Collecting transformer_lens
  Downloading transformer_lens-1.14.0-py3-none-any.whl (122 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m122.9/122.9 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting accelerate>=0.23.0 (from transformer_lens)
  Downloading accelerate-0.28.0-py3-none-any.whl (290 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m290.1/290.1 kB[0m [31m8.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting beartype<0.15.0,>=0.14.1 (from transformer_lens)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
# Enable IPython's autoreload extension so that edits to external Python files
# are picked up without restarting the kernel (mode 2 = reload all modules).
%load_ext autoreload
%autoreload 2

In [3]:
# When working in Colab, mount Google Drive so the assignment folder stored
# there becomes reachable under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [10]:
# Change the path below to point to the assignment folder on the mounted drive.
# NOTE(review): the local import `transformer.transformer` further down is
# presumably resolved relative to this directory — confirm.
%cd /content/drive/MyDrive/Colab Notebooks/Mechinterp/

/content/drive/MyDrive/Colab Notebooks/Mechinterp


In [35]:
import os; os.environ['ACCELERATE_DISABLE_RICH'] = "1"
import sys
import einops
from dataclasses import dataclass
from transformer_lens import HookedTransformer
from transformer_lens.utils import gelu_new, tokenize_and_concatenate
import torch as t
from torch import Tensor
import torch.nn as nn
import numpy as np
import plotly.graph_objects as go
import torch.optim as optim
import math
from transformer.transformer import DemoTransformer
from tqdm.notebook import tqdm
from torch.nn import CrossEntropyLoss
from typing import Tuple, List, Optional, Dict, Callable
from jaxtyping import Float, Int
from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
from collections import defaultdict
from rich.table import Table
from rich import print as rprint
import datasets
from torch.utils.data import DataLoader
import wandb
from pathlib import Path
import webbrowser
import yaml

In [36]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

# True only when this code is executed as the top-level module.
MAIN = (__name__ == "__main__")

In [37]:
@dataclass
class Config:
    """Model hyperparameters handed to `DemoTransformer`.

    How each field is consumed is defined in `transformer.transformer`, which
    is not shown here; the per-field notes below are inferred from the names
    and should be confirmed against that module. With the defaults,
    n_heads * d_head == d_model (12 * 64 == 768) and d_mlp == 4 * d_model.
    """
    d_model: int = 768  # presumably the residual-stream / embedding width
    debug: bool = True  # presumably toggles debug behaviour in the model — TODO confirm
    layer_norm_eps: float = 1e-5  # presumably the epsilon used inside layer norm
    d_vocab: int = 50257  # presumably the vocabulary size (number of token ids)
    init_range: float = 0.02  # presumably the scale used for weight initialisation
    pos_embed: bool = True  # presumably whether positional embeddings are used
    n_ctx: int = 1024  # presumably the maximum sequence length
    d_head: int = 64  # presumably the per-head attention dimension
    d_mlp: int = 3072  # presumably the MLP hidden width
    n_heads: int = 12  # presumably the number of attention heads per layer
    n_layers: int = 12  # presumably the number of transformer blocks

# Module-level instance holding the default values above.
cfg = Config()


In [38]:
@dataclass
class TransformerTrainingArgs:
    """Training hyperparameters for the sum-mod-p run in this notebook.

    Every attribute carries a type annotation so that it is a real dataclass
    field: it shows up in ``__init__``, ``__repr__`` and ``__eq__`` and can be
    overridden per instance, e.g. ``TransformerTrainingArgs(lr=1e-4, p=97)``.
    Without the annotations the values are plain class attributes and passing
    them to the constructor raises ``TypeError``.

    The two wandb fields are deliberately declared first: they used to be the
    only constructor parameters, so existing positional calls keep their meaning.
    """
    # Weights & Biases run metadata.
    wandb_project: Optional[str] = "day1-demotransformer"
    wandb_name: Optional[str] = None

    # Optimisation settings.
    batch_size: int = 16
    epochs: int = 40000
    max_steps_per_epoch: int = 200
    lr: float = 1e-3
    weight_decay: float = 1.0
    betas: Tuple[float, float] = (0.9, 0.98)
    # Training stops early once the test loss drops below this value.
    stopping_thresh: float = 3e-6

    # Dataset settings: modulus `p`, fraction of all (i, j) pairs used for
    # training, and the seed for the shuffle that produces the split.
    p: int = 113
    frac_train: float = 0.3
    seed: int = 42

    # Checkpointing: whether to write intermediate checkpoints, and how often.
    save_models: bool = False
    save_every: int = 100

# Default training hyperparameters for this run.
args = TransformerTrainingArgs()

# Small one-layer model for the sum-mod task. Each input row is a triple
# (i, j, p) as built by generate_train_test, hence a context of 3 tokens and a
# vocabulary of p + 1 ids (0..p).
model_cfg = Config(
    d_vocab=args.p + 1,
    n_ctx=3,
    n_layers=1,
    n_heads=4,
    d_head=32,
    d_model=128,
    d_mlp=512,
    pos_embed=False,
    debug=False,
)

# Instantiate the model and move it to the selected device.
model = DemoTransformer(model_cfg).to(device)

In [39]:
import random

# train and test datasets for the sum mod task
def generate_train_test(frac_train, num, seed=0):
    """Build the shuffled train/test split for the sum-mod task.

    Enumerates every triple ``(i, j, num)`` with ``0 <= i, j < num``, shuffles
    the list after seeding Python's global ``random`` module with ``seed``, and
    cuts it at ``int(frac_train * num * num)``.

    Args:
        frac_train: fraction of all triples that goes into the training split.
        num: range of ``i`` and ``j``; also the constant third element of each triple.
        seed: value passed to ``random.seed`` before shuffling.

    Returns:
        ``(train, test)``: two tensors built with ``t.tensor`` from the slices
        before and after the cut, each moved to the global ``device``.
    """
    triples = []
    for i in range(num):
        for j in range(num):
            triples.append((i, j, num))

    # Seeding the global RNG makes the split reproducible for a given seed.
    random.seed(seed)
    random.shuffle(triples)

    n_train = int(frac_train * len(triples))
    train_split = t.tensor(triples[:n_train]).to(device)
    test_split = t.tensor(triples[n_train:]).to(device)
    return train_split, test_split

In [40]:
# Build the split from the run's hyperparameters and report both sizes.
train_dataset, test_dataset = generate_train_test(
    frac_train=args.frac_train,
    num=args.p,
    seed=args.seed,
)
print(len(train_dataset), len(test_dataset))

3830 8939


In [41]:
def lines(lines_list, x=None, mode='lines', labels=None, xaxis='', yaxis='', title = '', log_y=False, hover=None, **kwargs):
    """Plot several series as traces of a single plotly figure and show it.

    Args:
        lines_list: a list of series, or a tensor that is split into one series
            per index along its first dimension.
        x: shared x values; defaults to ``0 .. len(first series) - 1``.
        mode: plotly scatter mode passed to every trace.
        labels: optional trace names, indexed by series position; the position
            itself is used as the name when omitted.
        xaxis, yaxis, title: axis titles and figure title.
        log_y: when true, the y axis is switched to a log scale.
        hover: value passed as ``hovertext`` to every trace.
        **kwargs: forwarded unchanged to ``go.Scatter``.

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    if type(lines_list) is t.Tensor:
        # Split a stacked tensor into its rows along the first dimension.
        lines_list = [lines_list[row] for row in range(lines_list.shape[0])]
    if x is None:
        x = np.arange(len(lines_list[0]))

    fig = go.Figure(layout={'title': title})
    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis)

    for idx, series in enumerate(lines_list):
        # Tensors are converted through a Python list into a NumPy array.
        y_values = np.array(series.tolist()) if type(series) is t.Tensor else series
        trace_name = idx if labels is None else labels[idx]
        fig.add_trace(go.Scatter(x=x, y=y_values, mode=mode, name=trace_name, hovertext=hover, **kwargs))

    if log_y:
        fig.update_layout(yaxis_type="log")
    fig.show()

In [42]:
import time

def train_loop(model, args, train_data=None, test_data=None):
    """Train `model` full-batch on the sum-mod-p task, checkpoint it and plot the losses.

    Each epoch runs one forward/backward pass over the whole training set and
    one gradient-free forward pass over the test set. Only the logits at the
    last sequence position are scored, against the target ``(i + j) % args.p``
    computed from the first two columns of every row.

    Args:
        model: module called as ``model(tokens)``; its output is indexed with
            ``[:, -1]`` to obtain the logits passed to the loss.
        args: object providing ``lr``, ``weight_decay``, ``epochs``, ``p``,
            ``stopping_thresh``, ``save_models`` and ``save_every``. ``betas``
            is read when present and defaults to ``(0.9, 0.98)`` otherwise.
        train_data: tensor of ``(i, j, _)`` rows to train on. Defaults to the
            notebook-level ``train_dataset``.
        test_data: tensor of ``(i, j, _)`` rows to evaluate on. Defaults to the
            notebook-level ``test_dataset``.

    Side effects:
        Creates a directory ``grok_<unix time>`` in the current working
        directory and writes ``final.pth`` into it. When ``args.save_models``
        is true it also writes ``init.pth`` and ``<epoch>.pth`` every
        ``args.save_every`` epochs. Prints progress every 100 epochs and shows
        the train/test loss curves with ``lines``.
    """
    # Fall back to the notebook-level split so `train_loop(model, args)` keeps working.
    if train_data is None:
        train_data = train_dataset
    if test_data is None:
        test_data = test_dataset

    betas = getattr(args, 'betas', (0.9, 0.98))
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, betas=betas)
    # Scale the learning rate by min(step / 10, 1): a linear warm-up over 10 steps.
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda step: min(step/10, 1))
    run_name = f"grok_{int(time.time())}"
    loss = CrossEntropyLoss()
    print(f'Run name {run_name}')

    if args.save_models:
        os.makedirs(run_name, exist_ok=True)
        save_dict = {'model': model.state_dict(), 'train_data': train_data, 'test_data': test_data}
        t.save(save_dict, f'{run_name}/init.pth')

    # Targets for every (i, j, _) row, computed in one vectorised step.
    labels_train = (train_data[:, 0] + train_data[:, 1]) % args.p
    labels_test = (test_data[:, 0] + test_data[:, 1]) % args.p

    train_losses = []
    test_losses = []
    # Defined up front so the final checkpoint can be written even if args.epochs == 0.
    epoch = -1
    train_loss = None
    test_loss = None

    for epoch in range(args.epochs):
        pred = model(train_data)[:, -1]
        train_loss = loss(pred, labels_train)

        # Evaluation only: no autograd graph is needed for the test pass.
        with t.no_grad():
            test_loss = loss(model(test_data)[:, -1], labels_test)

        train_loss_value = train_loss.item()
        test_loss_value = test_loss.item()
        train_losses.append(train_loss_value)
        test_losses.append(test_loss_value)

        # The progress line reports the natural log of both losses.
        if epoch % 100 == 0:
            print(f"{epoch}_{np.log(train_loss_value):.4f}_{np.log(test_loss_value):.4f}")

        train_loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        # Only the value is needed from here on; drop the autograd reference
        # before the tensor is stored in a checkpoint.
        train_loss = train_loss.detach()

        if test_loss_value < args.stopping_thresh:
            break

        if args.save_models and epoch % args.save_every == 0:
            # Written inside the run directory created above (relative path).
            checkpoint_path = f'{run_name}/{epoch}.pth'
            save_dict = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'train_loss': train_loss,
                'test_loss': test_loss,
                'epoch': epoch,
            }
            t.save(save_dict, checkpoint_path)
            print(f"Saved model to {checkpoint_path}")

    # The run directory already exists when intermediate checkpoints were written.
    os.makedirs(run_name, exist_ok=True)
    save_dict = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'train_loss': train_loss,
        'test_loss': test_loss,
        'train_losses': train_losses,
        'test_losses': test_losses,
        'epoch': epoch,
    }
    t.save(save_dict, f'{run_name}/final.pth')
    print(f"Saved model to {run_name}/final.pth")
    lines([train_losses, test_losses], labels=['train', 'test'], log_y=True)

# Train the model built above with the default hyperparameters; the output of
# this cell (progress lines and the loss plot) is hidden in this export.
train_loop(model, args)

Output hidden; open in https://colab.research.google.com to view.