In [None]:
import os; 
import sys
import einops
from dataclasses import dataclass
from transformer_lens import HookedTransformer
from transformer_lens.utils import gelu_new, tokenize_and_concatenate
import torch as t
from torch import Tensor
import torch.nn as nn
import numpy as np
import math
from tqdm.notebook import tqdm
from typing import Tuple, List, Optional, Dict, Callable
from jaxtyping import Float, Int
from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
from collections import defaultdict
from rich.table import Table
from rich import print as rprint
import datasets
from torch.utils.data import DataLoader
import wandb
from pathlib import Path
import webbrowser
from transformers import GPT2Model, GPT2Config

from tqdm import tqdm


# Compute device: CUDA when it is available, otherwise the CPU.
device = t.device("cuda" if t.cuda.is_available() else "cpu")

In [None]:
def generate_linear_recurrences(batch_size, dim=1, length=10, param_bds = (-2, 2), prepend_bos=True, return_type="both"):
    """Sample a batch of random order-``dim`` linear recurrences.

    Each row follows ``x[j] = const + sum_k params[k] * x[j - dim + k]``, with
    its own coefficients, constant and ``dim`` initial values drawn uniformly
    from ``param_bds``. Every row is then divided by four times its largest
    absolute value, so the normalised values lie in [-0.25, 0.25], and rounded
    to 4 decimal places.

    Args:
        batch_size: number of independent sequences to draw.
        dim: order of the recurrence (how many previous terms feed each term).
        length: number of sequence terms, not counting the optional BOS slot.
        param_bds: ``(low, high)`` bounds of the uniform sampling range.
        prepend_bos: when true, column 0 holds the marker value 1 and the
            sequence starts at column 1.
        return_type: ``'seq'`` to return only the sequences, ``'both'`` to also
            return the generating parameters.

    Returns:
        ``'seq'``: tensor of shape ``(batch_size, length + prepend_bos)``.
        ``'both'``: ``(sequences, (params, consts))`` where ``params`` has shape
        ``(batch_size, dim)`` and ``consts`` has shape ``(batch_size,)`` and is
        rescaled by the same per-row factor as the sequence, so the returned
        sequence satisfies the recurrence with the returned parameters (up to
        rounding).

    Raises:
        ValueError: if ``return_type`` is neither ``'seq'`` nor ``'both'``.
    """
    # Reject a bad return_type before doing any work.
    if return_type not in ("seq", "both"):
        raise ValueError(f"return_type must be 'seq' or 'both', got {return_type!r}")

    assert dim < length
    assert param_bds[0] <= param_bds[1]

    low, high = param_bds
    span = high - low
    offset = int(prepend_bos)  # index of the first real sequence term

    params = t.rand(size=(batch_size, dim)) * span + low
    consts = t.rand(size=(batch_size,)) * span + low

    # Round the generating parameters to 3 decimal places.
    params = t.round(params * 1e3) / 1e3
    consts = t.round(consts * 1e3) / 1e3

    recurrences = t.empty((batch_size, length + offset))

    if prepend_bos:
        recurrences[:, 0] = 1

    # Random initial conditions for the first `dim` terms.
    recurrences[:, offset:dim + offset] = t.rand((batch_size, dim)) * span + low

    for j in range(dim + offset, length + offset):
        window = recurrences[:, j - dim:j]
        # Row-wise dot product: each sequence only sees its own history.
        recurrences[:, j] = consts + (params * window).sum(dim=-1)

    # Normalize to avoid exploding/vanishing gradients. The clamp keeps an
    # all-zero row at zero instead of turning it into NaN.
    max_vals, _ = t.max(t.abs(recurrences[:, offset:]), dim=-1, keepdim=True)
    max_vals = max_vals.clamp_min(t.finfo(recurrences.dtype).tiny)
    recurrences[:, offset:] /= (4 * max_vals)
    # Round everything to 4 decimal places.
    recurrences = t.round(recurrences * 1e4) / 1e4

    if return_type == 'seq':
        return recurrences
    # Squeeze the kept dim so the constants stay one-per-row: shape (batch_size,).
    return recurrences, (params, consts / (4 * max_vals.squeeze(-1)))

# Quick check: draw a single sequence with the default settings and display
# it together with its generating parameters (default return_type is "both").
generate_linear_recurrences(1)

In [None]:
def tokenize(seq : Float[Tensor, "batch num_pos"], scale: int = 10_000):
    """One-hot encode values in [-1, 1] on a grid with step ``1 / scale``.

    A value ``v`` maps to class ``round(v * scale + scale)``, so -1 becomes
    class 0, 0 becomes class ``scale`` and 1 becomes class ``2 * scale``.

    Args:
        seq: tensor of values; each must lie in [-1, 1] so the class index
            stays inside ``[0, 2 * scale]``.
        scale: number of grid steps per unit. The default 10_000 gives the
            20_001-class vocabulary (4 decimal places).

    Returns:
        Integer one-hot tensor with one extra trailing axis of size
        ``2 * scale + 1``.
    """
    # Round to the nearest bin: a bare .long() truncates toward zero, so a
    # value that lands just below its bin centre would drop one class.
    bins = t.round(seq * scale + scale).long()
    return nn.functional.one_hot(bins, num_classes=2 * scale + 1)

# Quick check: one-hot encode a batch of two generated sequences and display the result.
tokenize(generate_linear_recurrences(2, return_type='seq'))

In [None]:
def linear_lr(step, steps):
    """Learning-rate multiplier that falls linearly from 1 at step 0 towards 0 at ``step == steps``."""
    progress = step / steps
    return 1 - progress

def constant_lr(*_):
    """Learning-rate multiplier that ignores its arguments and is always 1.0."""
    fixed_scale = 1.0
    return fixed_scale

def cosine_decay_lr(step, steps):
    """Learning-rate multiplier following a quarter cosine wave.

    Returns ``cos(pi/2 * step / (steps - 1))``: 1 at step 0, decaying to
    (numerically almost) 0 at the last step ``steps - 1``.

    Args:
        step: current step index, starting at 0.
        steps: total number of steps in the run.
    """
    # A one-step run has no decay interval; clamp the denominator so that
    # step 0 yields 1.0 instead of raising ZeroDivisionError.
    last_step = max(steps - 1, 1)
    return np.cos(0.5 * np.pi * step / last_step)

class TransformerModel(nn.Module):
    """GPT-2 backbone wrapped between a linear read-in and a linear read-out.

    Inputs whose last axis has size ``n_dims`` are projected to the backbone
    width with ``_read_in``, passed to ``GPT2Model`` as ``inputs_embeds``, and
    projected back to ``n_dims`` with ``_read_out``.
    """

    def __init__(self, n_dims, n_positions, n_embd=768, n_layer=4, n_head=12):
        """Build the model.

        Args:
            n_dims: size of the last axis of both the input and the output.
            n_positions: ``n_positions`` passed to the GPT-2 configuration.
            n_embd: width of the backbone embeddings.
            n_layer: number of backbone layers.
            n_head: number of attention heads per layer.
        """
        super().__init__()
        # All dropout probabilities are set to zero.
        configuration = GPT2Config(
            n_positions=n_positions,
            n_embd=n_embd,
            n_layer=n_layer,
            n_head=n_head,
            resid_pdrop=0.0,
            embd_pdrop=0.0,
            attn_pdrop=0.0,
        )
        self.name = f"gpt2_embd={n_embd}_layer={n_layer}_head={n_head}"

        self.n_positions = n_positions
        self.n_dims = n_dims
        self._read_in = nn.Linear(n_dims, n_embd)
        self._backbone = GPT2Model(configuration)
        self._read_out = nn.Linear(n_embd, n_dims)
        self.loss = nn.MSELoss()

    def forward(self, seq):
        """Map ``seq`` (last axis of size ``n_dims``) to a prediction of the same last-axis size."""
        embeds = self._read_in(seq)
        hidden = self._backbone(inputs_embeds=embeds).last_hidden_state
        return self._read_out(hidden)

    def calculate_loss(self,
                       pred: Float[Tensor, "batch_size seq d_vocab"],
                       orig: Float[Tensor, "batch_size seq d_vocab"]):
        """Mean squared error between each prediction and the next input position.

        The output at position ``i`` is compared with the input at position
        ``i + 1``, so the last prediction and the first input are dropped.
        """
        # nn.MSELoss takes (input, target): prediction first, ground truth second.
        return self.loss(pred[:, :-1, :], orig[:, 1:, :])

    def generate_batch_and_tokenize(self, batch_size):
        """Sample ``batch_size`` recurrence sequences and return them one-hot encoded as floats.

        The batch is created on the same device as the model parameters.
        """
        recurrences = generate_linear_recurrences(batch_size, return_type='seq')
        # Move the small value tensor before one-hot encoding so the much
        # larger encoded batch is built directly on the target device.
        param_device = next(self.parameters()).device
        return tokenize(recurrences.to(param_device)).float()

    def optimize(
        self,
        batch_size: int = 64,
        steps: int = 10_000,
        log_freq: int = 100,
        lr: float = 1e-3,
        lr_scale: Callable[[int, int], float] = constant_lr,
    ):
        '''
        Optimizes the model with Adam on freshly generated batches.

        Args:
            batch_size: number of sequences generated per step.
            steps: number of optimisation steps.
            log_freq: the progress bar's loss/lr readout is refreshed every
                ``log_freq`` steps and on the final step.
            lr: base learning rate.
            lr_scale: called as ``lr_scale(step, steps)``; its result multiplies
                ``lr`` to give the learning rate used for that step.
        '''
        optimizer = t.optim.Adam(self.parameters(), lr=lr)

        progress_bar = tqdm(range(steps))

        for step in progress_bar:

            # Update learning rate
            step_lr = lr * lr_scale(step, steps)
            for group in optimizer.param_groups:
                group['lr'] = step_lr

            # Optimize
            optimizer.zero_grad()
            batch = self.generate_batch_and_tokenize(batch_size)
            out = self(batch)
            loss = self.calculate_loss(out, batch)
            loss.backward()
            optimizer.step()

            # Display progress bar
            if step % log_freq == 0 or (step + 1 == steps):
                progress_bar.set_postfix(loss=loss.item(), lr=step_lr)


In [114]:
# n_dims matches the 20,001-class one-hot encoding produced by tokenize.
model = TransformerModel(
    n_dims=20_001,
    n_positions=30
)

# Train for 10,000 steps with the default batch size, learning rate and constant schedule.
model.optimize(steps=10_000)

  1%|          | 94/10000 [05:43<10:03:08,  3.65s/it, loss=0.331, lr=0.001]


KeyboardInterrupt: 