# Imports

In [1]:
import torch
import random
import pickle
import torch.nn.functional as F
from transformer_lens import HookedTransformerConfig, HookedTransformer
from typing import List
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import trange
from transformer_lens.utils import lm_cross_entropy_loss
from torch import optim
from transformer_lens.train import train
import os
import plotly.express as px

In [2]:
torch.set_default_device("cuda")

In [3]:
print(f"torch.cuda.is_available(): {torch.cuda.is_available()}")
print(f"torch.cuda.device_count(): {torch.cuda.device_count()}")

torch.cuda.is_available(): True
torch.cuda.device_count(): 1


# Tokenizer

In [4]:
class FibonacciTokenizer:
    def __init__(self, max_num=128):
        self._special_tokens = {
            'START': 0,
            'SEP': 1
        }
        
        # Create token mappings
        self.num_to_token = {i: i + len(self._special_tokens) for i in range(max_num + 1)}
        self.token_to_num = {v: k for k, v in self.num_to_token.items()}
        
        # Add special tokens to mappings
        self.num_to_token.update(self._special_tokens)
        self.token_to_num.update({v: k for k, v in self._special_tokens.items()})
        
        self.vocab_size = max_num + 1 + len(self._special_tokens)
        self.max_encoded_value = self.vocab_size - 1

    def encode(self, sequence: List[int]) -> List[int]:
        encoded = [self._special_tokens['START']]
        for i, num in enumerate(sequence):
            encoded.append(self.num_to_token[num])
            if i < len(sequence) - 1:
                encoded.append(self._special_tokens['SEP'])
        return encoded

    def decode(self, tokens: List[int]) -> str:
        decoded = []
        for token in tokens[1:]:  # Skip START token
            if token == self._special_tokens['SEP']:
                decoded.append(' ')
            elif token in self.token_to_num:
                decoded.append(str(self.token_to_num[token]))
        return ''.join(decoded)  # Join with empty string
    
    def get_max_encodable_value(self) -> int:
        return max(key for key in self.num_to_token.keys() if isinstance(key, int))

    def special_chars(self) -> List[str]:
        return list(self._special_tokens.values())
    
    def vocab_without_special_chars(self) -> List[int]:
        return list(range(len(self._special_tokens), self.vocab_size))

    def get_max_encoded_value(self) -> int:
        return self.max_encoded_value

In [5]:
tokenizer = FibonacciTokenizer(max_num=600)

print(f"The maximum encodable value is: {tokenizer.get_max_encodable_value()}")
print(f"The maximum encoded value is: {tokenizer.get_max_encoded_value()}")

The maximum encodable value is: 600
The maximum encoded value is: 602


In [6]:
def find_valid_fibonacci_sequences(max_value, sequence_length):
    valid_sequences = set()
    
    for F1 in range(1, max_value + 1):
        for F2 in range(F1, max_value + 1):
            sequence = [F1, F2]
            valid = True
            
            # Generate the Fibonacci sequence up to the desired length
            for _ in range(sequence_length - 2):
                next_value = sequence[-1] + sequence[-2]
                if next_value > max_value:
                    valid = False
                    break
                sequence.append(next_value)
            
            # If the entire sequence is valid, store it
            if valid:
                valid_sequences.add(tuple(sequence))
    
    return valid_sequences

# Find the valid sequences
valid_sequences = find_valid_fibonacci_sequences(600, 8)
print(f"Number of unique valid Fibonacci sequences: {len(valid_sequences)}")

Number of unique valid Fibonacci sequences: 651


# Dataset Generation

In [7]:
def generate_fibonacci(n, start1, start2, max_value):
    fib = [start1, start2]
    for _ in range(2, n):
        next_val = fib[-1] + fib[-2]
        if next_val > max_value:
            return None
        fib.append(next_val)
    return tuple(fib)  # Return a tuple instead of a list

def create_fibonacci_dataset(num_samples, seq_length=4, max_value=128):
    dataset = []
    original_numbers = []
    unique_sequences = set()

    while len(unique_sequences) < num_samples:
        start1 = random.randint(0, max_value // 4)
        start2 = random.randint(start1, max_value // 2)
        fib_seq = generate_fibonacci(seq_length, start1, start2, max_value)
        
        if fib_seq and fib_seq not in unique_sequences:
            unique_sequences.add(fib_seq)  # fib_seq is already a tuple
            
            # Encode the full sequence
            full_sequence = tokenizer.encode(fib_seq)
            
            dataset.append(full_sequence)
            original_numbers.append(fib_seq)
        # if len(unique_sequences) % 50 == 0:
        #     print(f"Number of unique sequences: {len(unique_sequences)}")
    return dataset, original_numbers

# Generate datasets
train_data, train_original = create_fibonacci_dataset(max_value=600, num_samples=400, seq_length=8)
val_data, val_original = create_fibonacci_dataset(max_value=600, num_samples=100, seq_length=8)
test_data, test_original = create_fibonacci_dataset(max_value=600, num_samples=100, seq_length=8)

# Saving/Loading Datasets

Saved to .pkl files to save time and ensure consistency

In [8]:
# # Dataset dumping code
# datasets = {
#     'train': (train_data, train_original),
#     'val': (val_data, val_original),
#     'test': (test_data, test_original)
# }

# # for name, data in datasets.items():
# #     with open(f'{name}_dataset.pkl', 'wb') as f:
# #         pickle.dump(data, f)

# # Read datasets from files
# read_datasets = {}
# for name in datasets.keys():
#     with open(f'{name}_dataset.pkl', 'rb') as f:
#         read_datasets[name] = pickle.load(f)

# # # Check if generated and read datasets are the same
# # for name in datasets.keys():
# #     assert datasets[name] == read_datasets[name], f"Generated and read {name} datasets are not the same"

# # print("All datasets successfully saved, read, and verified.")

In [9]:
train_data[0]

[0, 21, 1, 33, 1, 52, 1, 83, 1, 133, 1, 214, 1, 345, 1, 557]

# Sanity Checks of dataset

In [10]:
assert all(len(seq) == 16 for seq in train_data), "All sequences in train_data should have length 16"
print("Train data sequences length check passed.")

assert all(len(seq) == 16 for seq in val_data), "All sequences in val_data should have length 16"
print("Validation data sequences length check passed.")

assert all(len(seq) == 16 for seq in test_data), "All sequences in test_data should have length 16"
print("Test data sequences length check passed.")

print("All data sequence length checks passed successfully.")

Train data sequences length check passed.
Validation data sequences length check passed.
Test data sequences length check passed.
All data sequence length checks passed successfully.


# Dataset and Dataloader Generation

In [11]:
def start_pos(seq: list) -> int:
    # Find the position of the second '1' in the sequence
    count = 0
    for i, value in enumerate(seq):
        if value == 1:
            count += 1
            if count == 2:
                return i
    return len(seq)

In [12]:
def collate_fn(data):
    pad_token = tokenizer.special_chars().index(1) # index of SEP/space token
    batch_size = len(data)
    max_len = max([d["length"] for d in data])
    
    padded_data = torch.stack([F.pad(d["tokens"], (0, max_len-d["length"]), value=pad_token) for d in data]).to("cuda")
    start_positions = torch.tensor([start_pos(d['tokens'].tolist()) for d in data], device="cuda").reshape(batch_size, 1)

    attention_mask = (torch.arange(max_len, device="cuda").repeat(batch_size).reshape(batch_size, max_len) >= start_positions).int()
    # print("Collated batch:")
    # print(f"Batch size: {batch_size}")
    # print(f"Max sequence length: {max_len}")
    # print(f"Padded data shape: {padded_data.shape}")
    # print(f"Attention mask shape: {attention_mask.shape}")
    # print(f"Start positions: {start_positions.squeeze().tolist()}")
    # print("Sample of padded data:")
    # print(padded_data[0])
    # print("Sample of attention mask:")
    # print(attention_mask[0])

    
    return {"tokens": padded_data, "attention_mask": attention_mask}

In [13]:
class FibonacciDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        tokens = self.data[idx]
        return {
            "tokens": torch.tensor(tokens, dtype=torch.long, device="cuda"),
            "length": len(tokens)
        }

train_dataset = FibonacciDataset(train_data)
test_dataset = FibonacciDataset(test_data)
val_dataset = FibonacciDataset(val_data)
generator = torch.Generator(device="cuda")

In [14]:
train_data_loader = DataLoader(
    dataset=train_dataset, 
    batch_size=1024, 
    generator=generator,
    collate_fn=collate_fn,
)

val_data_loader = DataLoader(
    dataset=val_dataset,
    batch_size=1024,
    generator=generator,
    collate_fn=collate_fn
)

test_data_loader = DataLoader(
    dataset=test_dataset, 
    batch_size=1024, 
    generator=generator, 
    collate_fn=collate_fn,
)

# Training Code

In [15]:
batch = next(iter(train_data_loader))
batch

{'tokens': tensor([[  0,  21,   1,  ..., 345,   1, 557],
         [  0,   7,   1,  ..., 299,   1, 484],
         [  0,   6,   1,  ..., 326,   1, 528],
         ...,
         [  0,  11,   1,  ..., 351,   1, 568],
         [  0,   2,   1,  ..., 330,   1, 535],
         [  0,   7,   1,  ..., 315,   1, 510]], device='cuda:0'),
 'attention_mask': tensor([[0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 1, 1, 1],
         ...,
         [0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 1, 1, 1]], device='cuda:0', dtype=torch.int32)}

In [16]:
batch['tokens'][0]

tensor([  0,  21,   1,  33,   1,  52,   1,  83,   1, 133,   1, 214,   1, 345,
          1, 557], device='cuda:0')

In [17]:
batch['attention_mask'][0]

tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0',
       dtype=torch.int32)

In [18]:
tokenizer.decode(batch['tokens'].tolist()[0])

'19 31 50 81 131 212 343 555'

# Model Config

In [19]:
# Model Config
cfg = HookedTransformerConfig(
    n_layers=1,
    d_model=128,
    d_head=128,
    n_heads=4,
    d_vocab=tokenizer.vocab_size,
    n_ctx=16,
    act_fn="solu",
    attn_only=True,
    seed=42,
)

model = HookedTransformer(cfg)


In [20]:
def sample_data_test(expected: list):
    with torch.no_grad():
        offset = start_pos(expected)
        tokens = tokenizer.encode(expected)
        prefix = tokens[:offset]
        tokens = torch.tensor(tokens)
        i = offset
        output = []
        while i < len(expected):
            logits = model(tokens[:i])
            prediction = logits[:, -1, :].argmax().item()
            output.append(prediction)
            # print(tokens[:i])
            # print(output)
            i += 1
        print(f"expected (tokenized): {expected}")
        print(f"expected: {tokenizer.decode(expected)}")
        print(f"actual: {tokenizer.decode(prefix + output)}")

sample_data_test(batch['tokens'].tolist()[0])

expected (tokenized): [0, 21, 1, 33, 1, 52, 1, 83, 1, 133, 1, 214, 1, 345, 1, 557]
expected: 19 31 50 81 131 212 343 555
actual: 0 21296216196235345196482482482196363482


In [21]:
# model

In [21]:
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params:,}")
print(f"The model is loaded on: {model.cfg.device}")

Total number of parameters: 421,339
The model is loaded on: cuda


In [22]:
model.zero_grad()
model.init_weights()

# Training

In [23]:
from typing import List, Tuple
import os

def train(
    model: HookedTransformer,
    num_epochs=10,
    lr=1e-4,
    print_every=500,
    betas=(.9, .99),
    max_grad_norm=1.0,
    checkpoint_dir='checkpoints_longer_new'
) -> Tuple[List[float], List[float]]:

    optimizer = optim.AdamW(model.parameters(), lr=lr, betas=betas)
    train_losses = []
    val_losses = []
    
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    for epoch in trange(1, num_epochs + 1):
        model.train()
        for step, batch in enumerate(train_data_loader):
            tokens = batch["tokens"].to(model.cfg.device)
            attention_mask = batch["attention_mask"].to(model.cfg.device)
            logits = model(input=tokens)

            loss = lm_cross_entropy_loss(logits, tokens, attention_mask)
            train_losses.append(loss.item())
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

            optimizer.step()
            optimizer.zero_grad()
            
            if print_every is not None and step % print_every == 0:
                print(f"Epoch {epoch} Step {step} Train Loss {loss.item():.8f}")
        
        # Validation
        model.eval()
        with torch.no_grad():
            val_loss = 0
            for batch in val_data_loader:
                tokens = batch["tokens"].to(model.cfg.device)
                attention_mask = batch["attention_mask"].to(model.cfg.device)
                logits = model(input=tokens)
                val_loss += lm_cross_entropy_loss(logits, tokens, attention_mask).item()
            val_loss /= len(val_data_loader)
            val_losses.append(val_loss)
            print(f"Epoch {epoch} Validation Loss: {val_loss:.8f}")
        
        # Checkpoint every 5k epochs
        if epoch % 5000 == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f'model_checkpoint_epoch_{epoch}.pt')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_losses[-1],
                'val_loss': val_losses[-1],
                'train_losses': train_losses,
                'val_losses': val_losses,
                'model_config': model.cfg
            }, checkpoint_path)
            print(f"Checkpoint saved at epoch {epoch}!")

    return train_losses, val_losses

In [24]:
train_losses, val_losses = train(
    model=model,
    num_epochs=200,
    lr=1e-3,
    print_every=1000
    )

  0%|          | 0/200 [00:00<?, ?it/s]

Epoch 1 Step 0 Train Loss 6.47274256
Epoch 1 Validation Loss: 4.89055157
Epoch 2 Step 0 Train Loss 4.85762882
Epoch 2 Validation Loss: 4.23917389
Epoch 3 Step 0 Train Loss 4.20422745
Epoch 3 Validation Loss: 4.09116745
Epoch 4 Step 0 Train Loss 4.05237389
Epoch 4 Validation Loss: 3.83856273
Epoch 5 Step 0 Train Loss 3.79408550
Epoch 5 Validation Loss: 3.65820622
Epoch 6 Step 0 Train Loss 3.60434175
Epoch 6 Validation Loss: 3.58264375
Epoch 7 Step 0 Train Loss 3.52239990
Epoch 7 Validation Loss: 3.51402259
Epoch 8 Step 0 Train Loss 3.44963026
Epoch 8 Validation Loss: 3.43954110
Epoch 9 Step 0 Train Loss 3.37106776
Epoch 9 Validation Loss: 3.37173223
Epoch 10 Step 0 Train Loss 3.29857779
Epoch 10 Validation Loss: 3.31369424
Epoch 11 Step 0 Train Loss 3.23494864
Epoch 11 Validation Loss: 3.26250124
Epoch 12 Step 0 Train Loss 3.17743659
Epoch 12 Validation Loss: 3.21482015
Epoch 13 Step 0 Train Loss 3.12306762
Epoch 13 Validation Loss: 3.16862822
Epoch 14 Step 0 Train Loss 3.07006884
Epoch

In [25]:
px.line(train_losses, labels={"index": "steps", "value": "loss", "title": "Loss vs steps", "variable": "loss"})

In [26]:
px.line(val_losses, labels={"index": "steps", "value": "loss", "title": "Loss vs steps", "variable": "loss"})

In [29]:
model

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_post): HookPoint()
    )
  )
  (ln_final): LayerNorm(
    (hook_scale): HookPoint()
    (hook_normalized): HookPoint()
  )
  (unembed): Unembed()
)

# Hyperparameter Tuning Try

In [None]:
# # Define parameter grid
# param_grid = {
#     'n_layers': [2, 3, 4],
#     'd_model': [128, 256],
#     'd_head': [64, 128, 256],
#     'n_heads': [2, 3, 4],
#     'num_epochs': [100, 200, 300]
# }

# # Store results
# results = []

# # Perform grid search
# for n_layers in param_grid['n_layers']:
#     for d_model in param_grid['d_model']:
#         for d_head in param_grid['d_head']:
#             for n_heads in param_grid['n_heads']:
#                 for num_epochs in param_grid['num_epochs']:
#                     print(f"Current config: n_layers={n_layers}, d_model={d_model}, d_head={d_head}, n_heads={n_heads}, num_epochs={num_epochs}")
#                     cfg = HookedTransformerConfig(
#                         n_layers=n_layers,
#                         d_model=d_model,
#                         d_head=d_head,
#                         n_heads=n_heads,
#                         d_vocab=tokenizer.vocab_size,
#                         n_ctx=8,
#                         act_fn="solu",
#                         attn_only=True,
#                         seed=42,
#                     )
                    
#                     # Create new model
#                     model = HookedTransformer(cfg)
                    
#                     # Train model
#                     losses = train(
#                         model=model, 
#                         num_epochs=num_epochs,
#                         lr=1e-3,
#                         print_every=100,
#                     )
                    
#                     # Store results
#                     results.append({
#                         'n_layers': n_layers,
#                         'd_model': d_model,
#                         'd_head': d_head,
#                         'n_heads': n_heads,
#                         'n_ctx': 8,
#                         'lr': 1e-3,
#                         'num_epochs': num_epochs,
#                         'final_loss': losses[-1],
#                         'all_losses': losses
#                     })
                    
#                     # Clear CUDA cache
#                     torch.cuda.empty_cache()
                    
#                     # Reset everything
#                     del model
#                     gc.collect()

# # Print best model configuration
# best_model = min(results, key=lambda x: x['final_loss'])
# print(f"Best model configuration: {best_model}")

# Check on Test

In [27]:
# Set the model to evaluation mode
model.eval()

# Disable gradient computation
with torch.no_grad():
    for i, test_batch in enumerate(test_data_loader):
        if i >= 10:
            break

        # Extract tokens and attention mask
        test_tokens = test_batch["tokens"]
        test_attention_mask = test_batch["attention_mask"]

        # Get model predictions
        test_logits = model(input=test_tokens)

        # Calculate loss
        test_loss = lm_cross_entropy_loss(test_logits, test_tokens, test_attention_mask)

        # Get predicted tokens
        predicted_tokens = torch.argmax(test_logits, dim=-1)

        # Print results for each item in the batch
        for j in range(test_tokens.shape[0]):
            input_text = tokenizer.decode(test_tokens[j].tolist())
            predicted_text = tokenizer.decode(predicted_tokens[j].tolist())
            
            print(f"Sample {i*test_tokens.shape[0] + j + 1}:")
            print(f"Test Loss: {test_loss.item()}")
            print("Actual Tokens:", test_tokens[j])
            print("Predicted Tokens:", predicted_tokens[j])
            print("Input Text:", input_text)
            print("Predicted Text:", predicted_text)
            print("-" * 50)

Sample 1:
Test Loss: 1.6185574531555176
Actual Tokens: tensor([  0,   6,   1,  16,   1,  20,   1,  34,   1,  52,   1,  84,   1, 134,
          1, 216], device='cuda:0')
Predicted Tokens: tensor([  1,   1,  56,   1,  20,   1,  34,   1,  52,   1,  84,   1, 134,   1,
        216,   1], device='cuda:0')
Input Text: 4 14 18 32 50 82 132 214
Predicted Text:  54 18 32 50 82 132 214 
--------------------------------------------------
Sample 2:
Test Loss: 1.6185574531555176
Actual Tokens: tensor([  0,  20,   1,  35,   1,  53,   1,  86,   1, 137,   1, 221,   1, 356,
          1, 575], device='cuda:0')
Predicted Tokens: tensor([  1,   1,  56,   1,  53,   1,  86,   1, 137,   1, 221,   1, 356,   1,
        575,   1], device='cuda:0')
Input Text: 18 33 51 84 135 219 354 573
Predicted Text:  54 51 84 135 219 354 573 
--------------------------------------------------
Sample 3:
Test Loss: 1.6185574531555176
Actual Tokens: tensor([  0,   5,   1,  38,   1,  41,   1,  77,   1, 116,   1, 191,   1, 305,
  

# Trying with Full Batch Training instead of SGD

Trying to see if MLPs are necessary??

In [36]:

new_cfg = HookedTransformerConfig(
    n_layers=1,
    d_model=256,
    d_head=32,
    n_heads=4,
    d_vocab=tokenizer.vocab_size,
    n_ctx=16,
    act_fn="gelu",
    attn_only=False,
    seed=42,
    d_mlp=512
)

model_new = HookedTransformer(new_cfg)


In [33]:
def create_attention_mask(tokens):
    """Create attention mask for the given tokens."""
    mask = torch.zeros_like(tokens, dtype=torch.long)
    for i, seq in enumerate(tokens):
        # Find the index of the first '1' in the sequence
        ones = (seq == 1).nonzero(as_tuple=True)[0]
        if len(ones) >= 1:
            first_one_index = ones[0].item()
            mask[i, first_one_index+1:] = 1
    return mask

def prepare_data_for_training(data):
    """Prepare data for full batch training."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Convert list of lists to PyTorch tensor
    tokens = torch.tensor(data, dtype=torch.long).to(device)
    
    # Create attention mask
    attention_mask = create_attention_mask(tokens).to(device)
    
    return {"tokens": tokens, "attention_mask": attention_mask}

In [34]:
def train_full_batch(
    model: HookedTransformer,
    train_data: List[List[int]],
    val_data: List[List[int]],
    num_epochs=10,
    lr=1e-4,
    betas=(0.9, 0.99),
    checkpoint_dir='checkpoints'
) -> Tuple[List[float], List[float]]:

    optimizer = optim.AdamW(model.parameters(), lr=lr, betas=betas)
    train_losses = []
    val_losses = []
    
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    # Prepare data
    train_batch = prepare_data_for_training(train_data)
    val_batch = prepare_data_for_training(val_data)
    
    for epoch in trange(1, num_epochs + 1):
        # Training
        model.train()
        optimizer.zero_grad()
        
        tokens = train_batch["tokens"]
        attention_mask = train_batch["attention_mask"]
        logits = model(input=tokens)
        
        loss = lm_cross_entropy_loss(logits, tokens, attention_mask)
        train_losses.append(loss.item())
        
        loss.backward()
        optimizer.step()
        
        print(f"Epoch {epoch} Train Loss {loss.item():.8f}")
        
        # Validation
        model.eval()
        with torch.no_grad():
            val_tokens = val_batch["tokens"]
            val_attention_mask = val_batch["attention_mask"]
            val_logits = model(input=val_tokens)
            val_loss = lm_cross_entropy_loss(val_logits, val_tokens, val_attention_mask)
            val_losses.append(val_loss.item())
            print(f"Epoch {epoch} Validation Loss: {val_loss.item():.8f}")
        
        # Checkpoint every 500 epochs
        if epoch % 500 == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f'new_model_checkpoint_epoch_{epoch}.pt')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_losses[-1],
                'val_loss': val_losses[-1],
                'train_losses': train_losses,
                'val_losses': val_losses,
                'model_config': model.cfg
            }, checkpoint_path)
            print(f"Checkpoint saved at epoch {epoch}!")

    return train_losses, val_losses


In [37]:
model_new.zero_grad()
model_new.init_weights()
train_losses_new, val_losses_new = train_full_batch(
    model_new,
    train_data,
    val_data,
    num_epochs=30000,
    lr=1e-3,
)

  0%|          | 0/30000 [00:00<?, ?it/s]

Epoch 1 Train Loss 6.92240906
Epoch 1 Validation Loss: 3.98851275
Epoch 2 Train Loss 3.98831320
Epoch 2 Validation Loss: 3.94448662
Epoch 3 Train Loss 3.92503095
Epoch 3 Validation Loss: 3.38995004
Epoch 4 Train Loss 3.36486816
Epoch 4 Validation Loss: 3.26177096
Epoch 5 Train Loss 3.23120451
Epoch 5 Validation Loss: 3.20615196
Epoch 6 Train Loss 3.17094183
Epoch 6 Validation Loss: 3.08911872
Epoch 7 Train Loss 3.04835296
Epoch 7 Validation Loss: 3.01001644
Epoch 8 Train Loss 2.96245360
Epoch 8 Validation Loss: 2.95623398
Epoch 9 Train Loss 2.90061116
Epoch 9 Validation Loss: 2.90970802
Epoch 10 Train Loss 2.84470081
Epoch 10 Validation Loss: 2.86492276
Epoch 11 Train Loss 2.79007769
Epoch 11 Validation Loss: 2.82049489
Epoch 12 Train Loss 2.73611236
Epoch 12 Validation Loss: 2.77532768
Epoch 13 Train Loss 2.68216848
Epoch 13 Validation Loss: 2.72855926
Epoch 14 Train Loss 2.62797761
Epoch 14 Validation Loss: 2.68154907
Epoch 15 Train Loss 2.57437515
Epoch 15 Validation Loss: 2.6357581

In [38]:
px.line(train_losses_new, labels={"index": "steps", "value": "loss", "title": "Loss vs steps", "variable": "loss"})

In [39]:
px.line(val_losses_new, labels={"index": "steps", "value": "loss", "title": "Loss vs steps", "variable": "loss"})

# Check on Test

In [40]:
# Set the model to evaluation mode
model_new.eval()

# Disable gradient computation
with torch.no_grad():
    for i, test_batch in enumerate(test_data_loader):
        if i >= 10:
            break

        # Extract tokens and attention mask
        test_tokens = test_batch["tokens"]
        test_attention_mask = test_batch["attention_mask"]

        # Get model predictions
        test_logits = model_new(input=test_tokens)

        # Calculate loss
        test_loss = lm_cross_entropy_loss(test_logits, test_tokens, test_attention_mask)

        # Get predicted tokens
        predicted_tokens = torch.argmax(test_logits, dim=-1)

        # Print results for each item in the batch
        for j in range(test_tokens.shape[0]):
            input_text = tokenizer.decode(test_tokens[j].tolist())
            predicted_text = tokenizer.decode(predicted_tokens[j].tolist())
            
            print(f"Sample {i*test_tokens.shape[0] + j + 1}:")
            # print(f"Test Loss: {test_loss.item()}")
            print("Actual Tokens:", test_tokens[j])
            print("Predicted Tokens:", predicted_tokens[j])
            print("Input Text:", input_text)
            print("Predicted Text:", predicted_text)
            print("-" * 50)

Sample 1:
Test Loss: 1.7943774461746216
Actual Tokens: tensor([  0,   2,   1,  30,   1,  30,   1,  58,   1,  86,   1, 142,   1, 226,
          1, 366], device='cuda:0')
Predicted Tokens: tensor([  1,   1,  44,   1,  30,   1,  58,   1,  86,   1, 142,   1, 226,   1,
        366,   1], device='cuda:0')
Input Text: 0 28 28 56 84 140 224 364
Predicted Text:  42 28 56 84 140 224 364 
--------------------------------------------------
Sample 2:
Test Loss: 1.7943774461746216
Actual Tokens: tensor([  0,   8,   1,  43,   1,  49,   1,  90,   1, 137,   1, 225,   1, 360,
          1, 583], device='cuda:0')
Predicted Tokens: tensor([  1, 213,  14,   1,  49,   1,  90,   1, 137,   1, 225,   1, 360,   1,
        583,   1], device='cuda:0')
Input Text: 6 41 47 88 135 223 358 581
Predicted Text: 21112 47 88 135 223 358 581 
--------------------------------------------------
Sample 3:
Test Loss: 1.7943774461746216
Actual Tokens: tensor([  0,  21,   1,  29,   1,  48,   1,  75,   1, 121,   1, 194,   1, 313,

# Old Thoughts

Do not seem to have convincing results at all. Things tried so far:
* Using 1 layer and 1 head - and training for less number of epochs (100-200) - not good results
* Tried a hyperparameter search - following are the results (losses being training loss):
    * Best model configuration: {'n_layers': 2, 'd_model': 256, 'd_head': 32, 'n_heads': 3, 'n_ctx': 8, 'lr': 0.001, 'num_epochs': 100, 'final_loss': 1.2969523668289185}
    *  Best model configuration: {'n_layers': 3, 'd_model': 256, 'd_head': 128, 'n_heads': 4, 'n_ctx': 8, 'lr': 0.001, 'num_epochs': 300, 'final_loss': 0.9581699967384338}
* Now trying 3 heads - intuition - addition uses 3 heads (https://philipquirke.github.io/transformer-maths/2023/10/14/Understanding-Addition.html) and the main operation behind fibonacci series is basically position tracking and addition. Maybe 4 heads?
* Tried to check for grokking by training for 20k epochs and 4 heads (under the assumption that 3 heads are for addition, and maybe the 4th head is for position tracking/copy head)
* Initial experiments were with a large vocab size, reduced later - intuition is that the model might have a hard time learning a huge variety of tokens?
* Not seeing the kinds of drop in loss as seen in sorting experiments in the 200 Concrete Problems sheet - suggesting that the model is having a hard time generalizing at all.
* According to the code in Addition, MLPs might actually be necessary here. Also Grokking Demo has MLPs - lets try with MLPs?


* There is a chance that due to the attention mask, and reduced seq length, the training is too little. Might be worth increasing d_vocab and seq_length a bit to see what happens.

# Final Thoughts

* Tried a lot of experiments, and finally, the results look sufficiently good, but there seems to be an "off-by-1" error, as seen by running the output on samples. This is potentially a problem but can be fixed - I'm sure its something silly and this is leaving the test/val loss high
* Given more time, could probably look into fixing it, but on qualitative results, this shouldnt make much of an effect?

# Digging into the model

In [60]:
test_tokens

tensor([[  0,   6,   1,  ..., 134,   1, 216],
        [  0,  20,   1,  ..., 356,   1, 575],
        [  0,   5,   1,  ..., 305,   1, 494],
        ...,
        [  0,  12,   1,  ..., 292,   1, 472],
        [  0,   3,   1,  ..., 247,   1, 400],
        [  0,  12,   1,  ..., 316,   1, 511]], device='cuda:0')

In [61]:
test_input = test_tokens[0, :]
test_input.shape

torch.Size([16])

In [62]:
str(tokenizer.decode(test_input.tolist()))

'4 14 18 32 50 82 132 214'

In [69]:
# Pass through model, get cache and predictions
logits, cache_model = model.run_with_cache(test_input, remove_batch_dim=True) 
logits.shape

torch.Size([1, 16, 603])

In [66]:
preds = logits.argmax(dim=-1).squeeze(0)

# Get attention pattern and plot it
attention_pattern = cache_model["pattern", 0, "attn"]
tokens_input = list(map(str, test_input))
print(test_input)
print(preds)
print(tokenizer.decode(test_input.tolist()))
print(tokenizer.decode(preds.tolist()))


tensor([  0,   6,   1,  16,   1,  20,   1,  34,   1,  52,   1,  84,   1, 134,
          1, 216], device='cuda:0')
tensor([  1,   1,  56,   1,  20,   1,  34,   1,  52,   1,  84,   1, 134,   1,
        216,   1], device='cuda:0')
4 14 18 32 50 82 132 214
 54 18 32 50 82 132 214 


In [67]:
import circuitsvis as cv

cv.attention.attention_patterns(tokens=tokens_input, attention=attention_pattern)

The off by 1 error seems to have affected the attention pattern as well, as token representing 82 (84) should attend to tokens representing 50 (52) and 32 (34), but instead seems to attend to the spaces. Probably the model being used is not complex enough/there is a small error in training/attention_mask part.

However, we can see small patterns expectedly emerge, like the token representing 214 (last element) attends to all the previous required terms above. So maybe an increase in seq_length can help identify more patterns.