In [1]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2LMHeadModel, GPT2Config
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
from tqdm.notebook import tqdm
from safetensors.torch import save_model, load_model

## Tokenizer
Each digit would be cosidered as a token, and to that we will add the +, =, and PAD tokens

In [15]:
class SimpleTokenizer:
    def __init__(self):
        self.vocab = {
            **{str(i): i for i in range(10)},  # Digits 0-9
            "+": 10, "=": 11, "PAD": 12
        }
        self.inv_vocab = {v: k for k, v in self.vocab.items()}

    def encode(self, text, max_length=9, return_tensor=True):
        tokens = []
        for char in text:
            if char in self.vocab:
                tokens.append(self.vocab[char])
        # Pad with "PAD" tokens
        tokens += [self.vocab["PAD"]] * (max_length - len(tokens))
        return torch.tensor(tokens) if return_tensor else tokens

    def decode(self, tokens):
        return "".join([self.inv_vocab[t] for t in tokens if t != self.vocab["PAD"]])

tokenizer = SimpleTokenizer()

## Dataset and dataloader

The dataset consists of:
- inputs, which contains the token ids of operands and the operation for additions. The numbers ar each sampled from in the range [0,50[ so as to have a sum < 100
- targets where formulated in a way that it that the model would not be trained to predict only the results (by using the -100 token

In [16]:
class AdditionDataset(Dataset):
    def __init__(self, size=10000, max_length=9):
        self.data = [(np.random.randint(0, 99), np.random.randint(0, 99)) for _ in range(size)]
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        a, b = self.data[idx]
        input_part = f"{a}+{b}="  # e.g., "12+34="
        target_part = str(a + b)   # e.g., "46"
        full_sequence = input_part + target_part  # e.g., "12+34=46"

        # Tokenize the full sequence
        input_ids = tokenizer.encode(full_sequence, max_length=self.max_length)
        attention_mask = (input_ids != tokenizer.vocab["PAD"]).float()
        return input_ids, attention_mask


dataset = AdditionDataset()
dataloader = DataLoader(dataset, batch_size=2048, shuffle=True, pin_memory=True)

## GPT2 Model init:

In [4]:
config = GPT2Config(
    vocab_size=len(tokenizer.vocab),
    n_positions=9,
    n_embd=128,
    n_layer=8,
    n_head=2
)
model = GPT2LMHeadModel(config)
device = "cuda" if torch.cuda.is_available() else "cpu"

model = model.to(device)

In [38]:
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)

def train(model, dataloader, optimizer, epochs=10):
    model.train()
    progress_bar = tqdm(range(epochs), desc='Training', unit='epoch')
    
    for epoch in progress_bar:
        total_loss = 0
        for input_ids, attention_mask in dataloader:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            
            # Forward pass
            outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
            loss = outputs.loss
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        avg_loss = total_loss / len(dataloader)
        progress_bar.set_postfix(loss=f"{avg_loss:.4f}")  # Updates tqdm bar instead of printing new lines

In [39]:
train(model, dataloader, optimizer, epochs=100)

save_model(model, "models/model.safetensors")

Training:   0%|          | 0/100 [00:00<?, ?epoch/s]

In [5]:
load_model(model, "models/model.safetensors")
model.eval()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(13, 128)
    (wpe): Embedding(9, 128)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-7): 8 x GPT2Block(
        (ln_1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=384, nx=128)
          (c_proj): Conv1D(nf=128, nx=128)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=512, nx=128)
          (c_proj): Conv1D(nf=128, nx=512)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=128, out_features=13, bias=False)
)

In [6]:
def predict(a, b):
    input_part = f"{a}+{b}="
    input_ids = tokenizer.encode(input_part, max_length=len(input_part)).unsqueeze(0).to(device)  # Add batch dim
    #attention_mask = (input_ids != tokenizer.vocab["PAD"]).float().to(device)
    print(input_ids)
    model.eval()
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            #attention_mask=attention_mask,
            max_new_tokens=4,  # Generate up to 3 digits (e.g., "123")
            pad_token_id=tokenizer.vocab["PAD"]
        )
    print(outputs)
    generated = tokenizer.decode(outputs[0].cpu().numpy())
    print(generated)
    return generated.split("=")[-1]  # Extract the generated sum

In [7]:
input_part = '12+99='
tokenizer.encode(input_part, max_length=len(input_part))

tensor([ 1,  2, 10,  9,  9, 11])

In [18]:
model.cuda()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(13, 128)
    (wpe): Embedding(9, 128)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-7): 8 x GPT2Block(
        (ln_1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=384, nx=128)
          (c_proj): Conv1D(nf=128, nx=128)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=512, nx=128)
          (c_proj): Conv1D(nf=128, nx=512)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=128, out_features=13, bias=False)
)

In [23]:
predict(23, 12)

tensor([[ 2,  3, 10,  1,  2, 11]], device='cuda:0')
tensor([[ 2,  3, 10,  1,  2, 11,  3,  5, 12, 12]], device='cuda:0')
23+12=35


'35'

In [49]:
# Generate output and extract attention maps
def visualize_attention(sentence, layer_idx, head_idx):
    input_ids = tokenizer.encode(sentence, return_tensor=True, max_length=9).unsqueeze(0).to(device)  # Add batch dimension
    attention_mask = (input_ids != tokenizer.vocab["PAD"]).float().to(device)  # Create attention mask
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask, output_attentions=True)

    attentions = torch.stack(outputs.attentions).squeeze(dim=1).cpu().numpy()  # Shape: (layers, heads, seq_len, seq_len)
    attn_map = attentions[layer_idx-1, head_idx-1]

    plt.figure(figsize=(8, 6))
    plt.imshow(attn_map, cmap="viridis")
    plt.colorbar()
    plt.title(f"Attention Map (Layer {layer_idx}, Head {head_idx})")
    plt.xticks(range(len(sentence)), [x for x in sentence])
    plt.yticks(range(len(sentence)), [x for x in sentence])
    plt.show()

# Interactive widgets
sentence_widget = widgets.Text(value="23+12=34", description="Sentence:")
layer_widget = widgets.IntSlider(value=1, min=1, max=len(model.transformer.h), step=1, description="Layer:")
head_widget = widgets.IntSlider(value=1, min=1, max=2, step=1, description="Head:")

ui = widgets.VBox([sentence_widget, layer_widget, head_widget])
output = widgets.Output()

def update_visualization(_):
    with output:
        output.clear_output(wait=True)
        visualize_attention(sentence_widget.value, layer_widget.value, head_widget.value)

sentence_widget.observe(update_visualization, names='value')
layer_widget.observe(update_visualization, names='value')
head_widget.observe(update_visualization, names='value')

display(ui, output)
update_visualization(None)


VBox(children=(Text(value='23+12=34', description='Sentence:'), IntSlider(value=1, description='Layer:', max=8…

Output()