# Teach an LLM to do addition

Student: Soël Megdoud


The goal of this project is to teach an LLM to perform addition by modifying only two components:
* the tokenizer
* the positional embedding

Both the model and the dataset are fixed.

You are allowed to tune the hyperparameters, but this is not the main goal. Depending on the quality of your tokenizer and positional embedding, you may increase the number of digits per operand (`number_bits`; despite its name, it counts decimal digits, not bits). The initial value of 3 is very small.

In [60]:
import torch
from torch import nn
from torch.nn import functional as F

import random
import math
import re
import time

In [61]:
# Number of decimal digits drawn per operand by sample_datapoint (not binary bits).
number_bits = 3

dataset_size = 64_000  # total number of (prompt, answer) pairs generated
train_proportion = 0.9  # leading fraction of `data` used for training; the rest is the test split

log_interval = 200  # print the running training loss every `log_interval` batches
batch_size = 64
epochs = 4  # number of passes over data_train performed by train()
learning_rate = 8e-4  # AdamW learning rate used in train_epoch

## Step 1: Construct a tokenizer

In [63]:
# Special tokens shared by the tokenizers. pad() uses [PAD] to left-pad prompts
# and to right-pad answers, and appends [EOS] after every answer.
pad_token="[PAD]"
eos_token="[EOS]"

### Baseline: character-level tokenizer

In [64]:
class character_level_tokenizer:
    """
    Character-level tokenizer: every digit, "+" and "=" is its own token.

    The special tokens ``pad_token`` / ``eos_token`` belong to the vocabulary
    (so ids can be decoded back to them), but ``encode`` works one character at
    a time and therefore only ever produces single-character tokens.
    """
    def __init__(self):
        self.vocab = [str(x) for x in range(10)] + ["+", "="] + [pad_token, eos_token]
        self.token_to_id = {v : k for k, v in enumerate(self.vocab)}
        self.id_to_token = {k : v for k, v in enumerate(self.vocab)}
        self.ntokens = len(self.vocab)
        self.pattern = self._clean_pattern(self.vocab)

    @staticmethod
    def _clean_pattern(vocab):
        """
        Build a regex character class matching every character that is NOT a
        single-character token of ``vocab``.

        Multi-character special tokens (such as "[PAD]") are deliberately left
        out: if their letters and brackets were kept by `clean`, `encode` would
        later raise KeyError on them, since they are not tokens on their own.
        """
        single_chars = "".join(token for token in vocab if len(token) == 1)
        return f"[^{re.escape(single_chars)}]"

    def clean(self, text):
        """
        Remove every character that is not a single-character token (e.g. spaces).
        """
        out = re.sub(self.pattern, "", text)
        return out

    def pre_tokenization(self, text):
        """
        Split ``text`` into a list of its characters.
        """
        return [c for c in text]

    def encode(self, text):
        """
        Clean ``text`` and map each remaining character to its token id.
        """
        text_list = self.pre_tokenization(self.clean(text))
        return [self.token_to_id[c] for c in text_list]

    def decode(self, token_list):
        """
        Map token ids back to their strings and concatenate them.
        """
        return "".join([self.id_to_token[x] for x in token_list])

In [65]:
# Baseline tokenizer. Note: `tokenizer` and `ntokens` are re-bound further down
# to the ColumnWiseAdditionTokenizer, which is what the model is trained with.
tokenizer = character_level_tokenizer()
ntokens = tokenizer.ntokens  # vocabulary size
ntokens

14

In [66]:
# Round-trip demo: characters outside the vocabulary (the spaces) are dropped by clean().
prompt = "12 + 42 ="
inputs = tokenizer.encode(prompt)
inputs, tokenizer.decode(inputs)

([1, 2, 10, 4, 2, 11], '12+42=')

# Implement your tokenizer here!

You can do anything (as long as you do not compute the addition!).
Some ideas:
* reversing numbers left to right
* arranging digits in groups (of 2, 3, ...)
* aligning numbers

### My intuition


I have built a tokenizer that encodes an addition as a sequence of digit additions, one per column, in ascending order of significance, to make it easier for the model to learn carries.

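The idea is to decompose $127 + 365 = (7+5)\cdot 10^0 + (2+6)\cdot 10^1 + (1+3)\cdot 10^2$, i.e. the prompt `127+365=` is tokenized as `7+5=2+6=1+3=` (operands of different lengths are first left-padded with zeros).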

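This sequential encoding is better suited to autoregressive LLM architectures: the digits of each column, and hence each potential carry, are presented from least to most significant, which is the order in which carries propagate.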




In [67]:
class ColumnWiseAdditionTokenizer:
    """
    Tokenizer that rewrites "a+b=" as a sequence of per-column digit additions.

    For example "703+499=" becomes "3+9=0+9=7+4=": both operands are
    left-padded with zeros to a common length, and their digits are emitted
    column by column, starting from the least significant one. Text that is not
    of that form (e.g. an answer such as "1202") is encoded digit by digit, left
    to right. The vocabulary is: digits, "+", "=", pad_token, eos_token.
    """
    def __init__(self):
        self.vocab = [str(i) for i in range(10)] + ["+", "=", pad_token, eos_token]
        self.token_to_id = {v: k for k, v in enumerate(self.vocab)}
        self.id_to_token = {k: v for v, k in self.token_to_id.items()}
        self.ntokens = len(self.vocab)
        self.pattern = self._clean_pattern(self.vocab)

    @staticmethod
    def _clean_pattern(vocab):
        """
        Build a regex character class matching every character that is NOT a
        single-character token of ``vocab``.

        Multi-character special tokens (such as "[PAD]") are deliberately left
        out, otherwise their letters and brackets would survive `clean`.
        """
        single_chars = "".join(token for token in vocab if len(token) == 1)
        return f"[^{re.escape(single_chars)}]"

    def clean(self, text):
        """
        Removes all characters that are not single-character tokens (e.g. spaces)
        """
        return re.sub(self.pattern, "", text)

    def pre_tokenization(self, num1, num2=None):
        """
        Returns a list of characters.
        Tokenizes an addition into a column-wise addition sequence
        If only num1 is provided, it returns a simple tokenized number.
        """
        num1 = str(num1)
        if num2 is None:
            return list(num1)

        num2 = str(num2)
        max_len = max(len(num1), len(num2))
        num1, num2 = num1.zfill(max_len), num2.zfill(max_len)  # Padding with zeros

        tokens = []
        for i in range(max_len - 1, -1, -1):  # Start from least significant digit
            tokens.append(num1[i])
            tokens.append("+")
            tokens.append(num2[i])
            tokens.append("=")
        return tokens

    def encode(self, text):
        """
        Encode ``text`` into token ids.

        Characters outside the vocabulary (e.g. spaces) are removed first, so
        "703 + 499 =" and "703+499=" encode identically. Text starting with
        "<digits>+<digits>=" is encoded column-wise (anything after that "=" is
        ignored); any other text is encoded character by character.
        """
        text = self.clean(text)
        match = re.match(r"(\d+)\+(\d+)=", text)
        if match:
            num1, num2 = match.groups()
            token_list = self.pre_tokenization(num1, num2)
        else:
            # Not an addition prompt (e.g. an answer): encode it as plain digits
            token_list = self.pre_tokenization(text)

        return [self.token_to_id[t] for t in token_list if t in self.token_to_id]

    def decode(self, token_list):
        """
        Map token ids back to their strings and concatenate them; unknown ids are skipped.
        """
        return "".join([self.id_to_token[x] for x in token_list if x in self.id_to_token])

# Example. This cell also re-binds the globals `tokenizer` and `ntokens` to the
# column-wise tokenizer; the rest of the notebook (model construction, pad,
# get_batch) reads these globals.
tokenizer = ColumnWiseAdditionTokenizer()
ntokens = tokenizer.ntokens
encoded = tokenizer.encode('703+499=')
print("Encoded:", encoded)
print("Decoded:", tokenizer.decode(encoded))


Encoded: [3, 10, 9, 11, 0, 10, 9, 11, 7, 10, 4, 11]
Decoded: 3+9=0+9=7+4=


## Step 2: Create a dataset for arithmetic operations

In [68]:
def sample_datapoint(number_bits = 3):
    """
    Draw one random addition problem as a ``(prompt, answer)`` string pair.

    Despite its name, ``number_bits`` is the number of *decimal digits* drawn
    for each operand. Every digit is uniform in 0-9, so leading zeros make
    some operands shorter. Example return value: ``('207+564=', '771')``.
    """
    digits_a = [random.randint(0, 9) for _ in range(number_bits)]
    digits_b = [random.randint(0, 9) for _ in range(number_bits)]
    a = int("".join(map(str, digits_a)))
    b = int("".join(map(str, digits_b)))
    return f"{a}+{b}=", f"{a + b}"

sample_datapoint(3)

('207+564=', '771')

In [69]:
# Generate the whole dataset of (prompt, answer) pairs up front.
data = [sample_datapoint(number_bits) for _ in range(dataset_size)]
data[:4]

[('54+404=', '458'),
 ('716+251=', '967'),
 ('432+110=', '542'),
 ('986+41=', '1027')]

In [70]:
# Positional split: the first `train_proportion` of `data` is used for training
# and the remainder for testing (the samples themselves are already random).
data_train = data[: int(train_proportion * dataset_size)]
data_test = data[int(train_proportion * dataset_size):]

len(data_train),len(data_test)

(57600, 6400)

## Step 3: Construct a model

### Baseline: the classical Positional Embedding

In [72]:
class PositionalEmbedding(nn.Module):
    r"""Add fixed sinusoidal position encodings to a sequence-first embedding tensor.

    The encodings have the same dimension as the embeddings, so that the two can
    be summed. Sine and cosine functions of different frequencies are used:

    .. math::
        \text{PE}(pos, 2i)   = \sin\left(pos / 10000^{2i/d_{model}}\right)
        \text{PE}(pos, 2i+1) = \cos\left(pos / 10000^{2i/d_{model}}\right)

    where :math:`pos` is the token position and :math:`i` indexes pairs of
    embedding dimensions.

    Args:
        d_model: the embedding dimension (required). Odd values are supported;
            the last, unpaired dimension then only gets a sine term.
        dropout: dropout probability applied after the addition (default=0.1).
        max_len: the maximum supported sequence length (default=5000).
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Shape (max_len, 1, d_model): the singleton axis broadcasts over the batch.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encodings of the first ``x.size(0)`` positions, then apply dropout.

        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        """
        # No debug printing here: forward runs for every batch and generation step.
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


# Implement your positional embedding here!

You can do anything. Some ideas:
* RoPE
* (randomised) FIRE
* Abacus

**!!! IMPORTANT !!!** This Transformer model is "sequence first", meaning that an input is a tensor with shape
(length_prompts, batch_size).

## My intuition

My tokenization, which represents an addition as a sequence of column additions, is already very helpful for the model: I reach 100% accuracy at epoch 4 with the classic positional embedding. However, I also implemented a version that learns one embedding per position (an embedding for the 1st token, one for the 2nd token, etc.). Since the maximum sequence length is small, learning a few position-specific embeddings can work better and lead to faster training.


In [74]:
class TrainablePositionalEmbedding(nn.Module):
    """Learned absolute position embeddings added to a sequence-first input.

    Every position ``p < max_len`` owns its own trainable vector of size
    ``d_model``, initialised from a normal distribution with std 0.01.

    Args:
        d_model: embedding dimension.
        max_len: number of positions that get an embedding.
        dropout: dropout probability applied after the addition.
    """
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Shape (max_len, 1, d_model): the middle axis broadcasts over the batch.
        initial = torch.randn(max_len, 1, d_model) * 0.01
        self.pe = nn.Parameter(initial)

    def forward(self, x):
        """
        Args:
            x: tensor of shape [sequence length, batch size, embed dim]
        Returns:
            tensor of the same shape, with the embeddings of the first
            ``x.size(0)`` positions added, followed by dropout.
        """
        seq_len = x.size(0)
        return self.dropout(x + self.pe[:seq_len])


In [75]:
class TransformerModel(nn.Transformer):
    """Causally-masked language model built from the encoder stack of ``nn.Transformer``.

    Inputs are sequence-first token ids of shape ``(seq_len, batch_size)``.
    ``forward`` returns ``(log_probs, hidden)``: log-probabilities over the
    vocabulary with shape ``(seq_len, batch_size, ntoken)``, and the encoder
    output of shape ``(seq_len, batch_size, ninp)``.
    """
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__(d_model=ninp,
                                               nhead=nhead,
                                               dim_feedforward=nhid,
                                               num_encoder_layers=nlayers)
        self.input_emb = nn.Embedding(ntoken, ninp)
        # Alternative: PositionalEmbedding(ninp, dropout) (fixed sinusoidal encodings).
        self.pos_encoder = TrainablePositionalEmbedding(d_model=ninp, dropout=dropout)
        # Replaces nn.Transformer's own `decoder` attribute with a projection to vocabulary logits.
        self.decoder = nn.Linear(ninp, ntoken)

        self.ninp = ninp
        self.init_weights()

    def init_weights(self):
        """Uniform init of the token embedding and output projection; zero output bias."""
        initrange = 0.1
        nn.init.uniform_(self.input_emb.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def _generate_square_subsequent_mask(self, sz, device=None):
        """Return an additive causal mask: 0 on and below the diagonal, -inf above.

        log(1) = 0 keeps allowed positions; log(0) = -inf blocks future positions.
        ``device`` (optional) creates the mask directly on the target device.
        """
        return torch.log(torch.tril(torch.ones(sz, sz, device=device)))

    def forward(self, src):
        # Build the mask on the same device as the input, instead of relying on
        # a module-level `device` global.
        mask = self._generate_square_subsequent_mask(len(src), device=src.device)
        self.src_mask = mask

        src = self.input_emb(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output_enc = self.encoder(src, mask=self.src_mask)
        output_dec = self.decoder(output_enc)
        return F.log_softmax(output_dec, dim=-1), output_enc

In [76]:
# Use the GPU when available; several helpers below (e.g. evaluate, train_epoch)
# read `device` as a global to move their tensors.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


Please do not change these parameters!

In [119]:
# Fixed architecture (must not be changed): 128-dim embeddings, 16 heads,
# 64-dim feed-forward layers and 8 encoder layers; the vocabulary size comes
# from the current tokenizer.
model = TransformerModel(ntoken = ntokens,
                         ninp = 128,
                         nhead = 16,
                         nhid = 64,
                         nlayers = 8)
model.to(device)



TransformerModel(
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=64, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=64, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
    (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (decoder): Linear(in_features=128, out_features=14, bias=True)
  (input_emb): Embedding(14, 128)
  (pos_encoder): TrainablePositionalEmbedding(
    (dropout): Dropout(p=0.5, inplace=False)
  )
)

In [120]:
def generate(model, prompts, new_tokens = 5):
    """Greedily extend ``prompts`` by ``new_tokens`` tokens.

    Args:
        model: callable returning ``(log_probs, hidden)`` where ``log_probs`` has
            shape ``(current_len, batch_size, ntokens)``, as TransformerModel does.
        prompts: token-id tensor of shape ``(length_prompts, batch_size)``
            (sequence-first).
        new_tokens: number of tokens to append.

    Returns:
        Tensor of shape ``(length_prompts + new_tokens, batch_size)``: the prompts
        followed by the argmax token chosen at each step.
    """
    # Follow the model's own device rather than a global, so inputs and weights
    # always live on the same device.
    first_param = next(model.parameters(), None)
    input_tensor = prompts if first_param is None else prompts.to(first_param.device)
    # Pure inference: there is no need to record an autograd graph.
    with torch.no_grad():
        for _ in range(new_tokens):
            output, _ = model(input_tensor)  # (current_len, batch_size, ntokens)
            last_output = output[-1, :, :]  # (batch_size, ntokens)
            token = torch.argmax(last_output, -1).view((1, -1))  # (1, batch_size)
            input_tensor = torch.cat((input_tensor, token), 0)
    return input_tensor

In [108]:
model.eval()

# Smoke test of greedy generation on a single prompt, reshaped to the
# sequence-first layout (length, 1). Before training the continuation is
# expected to be meaningless.
prompt = "2+3="
prompt_tensor = torch.tensor(tokenizer.encode(prompt)).view((-1,1))
output = generate(model, prompt_tensor).view((1,-1))
output, tokenizer.decode(output.tolist()[0])

(tensor([[ 2, 10,  3, 11, 10, 13,  4,  4, 10]], device='cuda:0'),
 '2+3=+[EOS]44+')

In [121]:
def pad(token_list, type_list = "prompts"):
    """Pad lists of token ids to a common length.

    Args:
        token_list: list of token-id lists (must be non-empty).
        type_list: ``"prompts"`` left-pads with [PAD] to the longest length;
            ``"answers"`` appends [EOS] and then right-pads with [PAD], so every
            answer ends up with length ``max_length + 1``.

    Returns:
        ``(padded_lists, max_length)``, where ``max_length`` is the longest
        input length (not counting the appended [EOS]).

    Raises:
        ValueError: if ``type_list`` is neither "prompts" nor "answers"
            (previously this silently returned an empty list).
    """
    if type_list not in ("prompts", "answers"):
        raise ValueError(f"type_list must be 'prompts' or 'answers', got {type_list!r}")
    max_length = max(len(x) for x in token_list)
    pad_id = tokenizer.token_to_id[pad_token]
    if type_list == "prompts":
        out = [[pad_id] * (max_length - len(x)) + x for x in token_list]
    else:
        eos_id = tokenizer.token_to_id[eos_token]
        out = [x + [eos_id] + [pad_id] * (max_length - len(x)) for x in token_list]
    return out, max_length

In [122]:
# Padding demo: prompts are left-padded with [PAD]; answers get [EOS] and are
# then right-padded with [PAD].
prompts = [tokenizer.encode("1+1="), tokenizer.encode("21+35=")]
answers = [tokenizer.encode("2"), tokenizer.encode("56")]
padded_prompts, _ = pad(prompts, "prompts")
padded_answers, _ = pad(answers, "answers")
padded_prompts, padded_answers  # not the cell's last expression, so nothing is displayed
[tokenizer.decode(p) for p in padded_prompts], [tokenizer.decode(p) for p in padded_answers]

(['[PAD][PAD][PAD][PAD]1+1=', '1+5=2+3='], ['2[EOS][PAD]', '56[EOS]'])

In [123]:
def get_batch(split, i):
    """Build one padded, sequence-first batch starting at example ``i``.

    Args:
        split: "train" selects ``data_train``; anything else selects ``data_test``.
        i: index of the first example of the batch.

    Returns:
        ``(X, Y, length_prompts, length_answers)``. X has shape
        ``(length_prompts, b)`` and Y has shape ``(length_answers + 1, b)``
        (the +1 is the appended [EOS]), where ``b <= batch_size``. ``b`` is
        smaller only for a final partial batch.
    """
    source = data_train if split == 'train' else data_test
    # Slicing, unlike indexing i .. i + batch_size - 1, keeps a final partial
    # batch valid when the split size is not a multiple of batch_size.
    rows = source[i:i + batch_size]
    prompts = [tokenizer.encode(prompt) for prompt, _ in rows]
    answers = [tokenizer.encode(answer) for _, answer in rows]
    padded_prompts, length_prompts = pad(prompts, "prompts")
    padded_answers, length_answers = pad(answers, "answers")
    # Stacking along dim 1 gives contiguous (seq_len, batch) tensors.
    X = torch.stack([torch.tensor(x) for x in padded_prompts], 1)
    Y = torch.stack([torch.tensor(x) for x in padded_answers], 1)
    return X, Y, length_prompts, length_answers

In [124]:
# Inspect one training batch: X is (length_prompts, batch_size) and Y is
# (length_answers + 1, batch_size) because of the appended [EOS].
X, Y, length_prompts, length_answers = get_batch("train", 243)
X.shape, Y.shape, length_prompts, length_answers

(torch.Size([12, 64]), torch.Size([5, 64]), 12, 4)

## Step 4: Evaluate

In [125]:
def evaluate():
    """Return the exact-match accuracy of greedy generation on ``data_test``.

    A test example counts as correct only if every generated position matches
    its padded target: the digits, the [EOS] and any trailing [PAD].
    """
    model.eval()  # evaluation mode disables dropout
    correct = 0.
    with torch.no_grad():
        for start in range(0, len(data_test) - 1, batch_size):
            prompts, targets, len_prompts, len_answers = get_batch("test", start)
            prompts = prompts.to(device)  # (length_prompts, batch_size)
            targets = targets.to(device)  # (length_answers + 1, batch_size)
            generated = generate(model, prompts, len_answers + 1)
            predicted = generated[len_prompts:, :]  # only the newly generated tokens
            row_ok = (predicted == targets).all(dim=0)  # one bool per example
            correct += row_ok.float().sum()
        accuracy = correct / len(data_test)
    return accuracy.item()

In [126]:
evaluate()  # accuracy before any training

0.0

## Step 5: Train the model

In [127]:
def train_epoch(optimizer=None):
    """Run one teacher-forced pass over ``data_train``.

    Each training sequence is the prompt followed by its padded target answer.
    The loss is the cross-entropy of the predictions for every answer token,
    including [EOS] and trailing [PAD]; each token is predicted from all the
    tokens before it.

    Args:
        optimizer: optimizer to step. If ``None`` (the default, which is how
            ``train()`` calls this function), a fresh AdamW with
            ``learning_rate`` is created, so its moment estimates restart at
            every epoch. Pass an optimizer created once outside the epoch loop
            to keep that state across epochs.
    """
    model.train()
    if optimizer is None:
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    total_loss = 0.
    window_batches = 0  # number of batches accumulated in total_loss since the last log line
    start_time = time.time()
    for batch, i in enumerate(range(0, len(data_train) - 1, batch_size)):
        prompts, target_answers, length_prompts, length_answers = get_batch("train", i)
        prompts = prompts.to(device) # (length_prompts, batch_size)
        target_answers = target_answers.to(device) # (length_answers + 1, batch_size)
        input_tensor = torch.cat((prompts, target_answers), 0) # (length_prompts + length_answers + 1, batch_size)
        model.zero_grad()
        output, _ = model(input_tensor) # (length_prompts + length_answers + 1, batch_size, ntokens)
        # The output at position t predicts token t + 1, hence the shift by one.
        output_answers = output[length_prompts-1:-1,:,:].reshape(-1, ntokens) # ((length_answers + 1) * batch_size, ntokens)
        target_answers = target_answers.view(-1)
        # `output` already holds log-probabilities; the log_softmax inside
        # cross_entropy leaves them unchanged.
        loss = F.cross_entropy(output_answers, target_answers)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        window_batches += 1

        if batch % log_interval == 0 and batch > 0:
            # Average over the batches actually accumulated: the first window
            # covers batches 0..log_interval, i.e. log_interval + 1 batches.
            cur_loss = total_loss / window_batches
            elapsed = time.time() - start_time
            print('| {:5d}/{:5d} batches | ms/batch {:5.2f} | loss {:5.2f} | perplexity {:8.2f}'.format(batch, len(data_train) // batch_size,
                                                                                                        elapsed * 1000 / window_batches, cur_loss, math.exp(cur_loss)))
            total_loss = 0
            window_batches = 0
            start_time = time.time()

def train():
    """Train for ``epochs`` epochs, evaluating on ``data_test`` after each one.

    Whenever the test accuracy is strictly higher than the best seen so far
    (the first epoch always qualifies), the whole model is saved to
    "arithmetic.pt".
    """
    best_test_accuracy = None
    test_accuracy = evaluate()
    print('-' * 89)
    print('| initialisation | test accuracy {:5.2f}'.format(test_accuracy))
    print('-' * 89)
    for epoch in range(1, epochs+1):
        epoch_start_time = time.time()
        train_epoch()
        test_accuracy = evaluate()
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | test accuracy {:5.2f}'.format(epoch, (time.time() - epoch_start_time), test_accuracy))
        print('-' * 89)
        # Save the model if the test accuracy is the best we've seen so far.
        # `is None` (rather than falsiness) so that a best accuracy of 0.0 still
        # counts, and `>` so that only improvements overwrite the checkpoint.
        if best_test_accuracy is None or test_accuracy > best_test_accuracy:
            with open("arithmetic.pt", 'wb') as f:
                torch.save(model, f)
            best_test_accuracy = test_accuracy

In [128]:
train()  # full run: `epochs` epochs, with a test-set evaluation after each

-----------------------------------------------------------------------------------------
| initialisation | test accuracy  0.00
-----------------------------------------------------------------------------------------
|   200/  900 batches | ms/batch 21.01 | loss  1.86 | perplexity     6.41
|   400/  900 batches | ms/batch 25.06 | loss  1.40 | perplexity     4.07
|   600/  900 batches | ms/batch 20.89 | loss  1.08 | perplexity     2.93
|   800/  900 batches | ms/batch 20.75 | loss  0.89 | perplexity     2.44
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 22.92s | test accuracy  0.07
-----------------------------------------------------------------------------------------
|   200/  900 batches | ms/batch 20.86 | loss  0.78 | perplexity     2.19
|   400/  900 batches | ms/batch 22.04 | loss  0.71 | perplexity     2.04
|   600/  900 batches | ms/batch 24.03 | loss  0.62 | perplexity     1.86
|   800/  900 batches | ms/

In [129]:
model.eval()

# Print greedy completions for the first 20 test problems next to the true sums.
for idx in range(20):
    prompt, answers = data_test[idx]
    encoded_prompt = torch.tensor(tokenizer.encode(prompt)).view((-1, 1))
    completion = generate(model, encoded_prompt, len(answers)).view((1, -1))
    decoded = tokenizer.decode(completion.tolist()[0])
    print(decoded + "\t actual result: " + answers)

0+9=8+2=3+9=1309	 actual result: 1309
1+3=8+4=8+4=1324	 actual result: 1324
3+0=0+8=2+4=683	 actual result: 683
9+3=6+2=0+7=792	 actual result: 792
1+5=2+0=5+0=526	 actual result: 526
5+7=0+2=9+1=1032	 actual result: 1032
0+1=4+8=2+1=421	 actual result: 421
7+9=9+5=7+2=1056	 actual result: 1056
3+6=8+4=5+6=1229	 actual result: 1229
6+3=6+9=9+8=1859	 actual result: 1859
4+8=9+1=1+7=912	 actual result: 912
8+5=2+8=3+3=713	 actual result: 713
4+2=3+0=9+7=1636	 actual result: 1636
9+1=7+8=2+7=1060	 actual result: 1060
3+1=8+4=1+5=724	 actual result: 724
3+4=4+8=8+1=1027	 actual result: 1027
2+7=5+2=3+8=1179	 actual result: 1179
6+8=4+6=4+1=614	 actual result: 614
0+9=1+8=7+0=799	 actual result: 799
0+0=0+3=8+1=930	 actual result: 930


## Probing

This is just for fun...

In [130]:
import numpy as np

train_size = 1000  # number of examples used to fit the probe
test_size = 100  # number of examples used to score the probe

model.eval()  # disable dropout so the extracted features are not perturbed by it

def has_final_carry(prompt, answer):
    """Return True when ``answer`` has more digits than the longer operand of ``prompt``.

    ``prompt`` has the form "<a>+<b>=" (as produced by sample_datapoint). A sum
    with more digits than both operands must have received a carry out of its
    most significant column. The previous test, ``len(answer) > len(prompt) / 2``,
    never fired for operands of equal length (e.g. "986+141=" -> "1127").
    """
    left, right = prompt.rstrip("=").split("+")
    return len(answer) > max(len(left), len(right))


def data_probing(size):
    """Collect probe features and final-carry labels for the first ``size`` items of ``data``.

    Returns:
        ``(X, y)``. X has one row per example: the flattened encoder output at
        the last prompt position. ``y[i]`` is 1.0 when example i has a final
        carry (see has_final_carry), else 0.0.
    """
    X = []
    y = np.zeros(size)
    with torch.no_grad():  # feature extraction only; no gradients needed
        for i in range(size):
            prompt, answer = data[i]
            token_ids = torch.tensor(tokenizer.encode(prompt)).view((-1, 1)).to(device)
            _, hidden = model(token_ids)
            features = hidden[-1, :, :].flatten()
            X.append(features.cpu().detach().numpy())
            y[i] = has_final_carry(prompt, answer)
    return np.array(X), y

In [131]:
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

# Extract one contiguous run of examples and split it into disjoint parts.
# Calling data_probing twice would make the test examples a subset of the
# training examples, because both calls start at data[0].
X_all, y_all = data_probing(train_size + test_size)
X_train, y_train = X_all[:train_size], y_all[:train_size]
X_test, y_test = X_all[train_size:], y_all[train_size:]

# Fit the scaler on the training features only, then reuse it for the test set.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

reg = LogisticRegression()
reg.fit(X_train,y_train)
reg.score(X_test, y_test)

1.0