# DS 542 Fall 2025 Project 3

Your task for this project is to train an attention-based decoder-only model for math expressions with positive integers, addition, and parentheses.
A sample model is provided and demonstrated on small problems with single-digit integer inputs.
Your goal is to scale up this model to handle two-digit inputs and longer expressions.

## Problem Setup

In [1]:
import math
import random

import torch

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
characters = "()+0123456789="
TOKENS = ["<bos>", "<eos>", "<pad>"] + [c for c in characters]
print(TOKENS)

['<bos>', '<eos>', '<pad>', '(', ')', '+', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '=']


In [4]:
TOKEN_MAP = dict((t, i) for i, t in enumerate(TOKENS))
print(TOKEN_MAP)

{'<bos>': 0, '<eos>': 1, '<pad>': 2, '(': 3, ')': 4, '+': 5, '0': 6, '1': 7, '2': 8, '3': 9, '4': 10, '5': 11, '6': 12, '7': 13, '8': 14, '9': 15, '=': 16}


In [5]:
BOS = TOKEN_MAP["<bos>"]
EOS = TOKEN_MAP["<eos>"]
PAD = TOKEN_MAP["<pad>"]

In [6]:
def decode(token_ids):
    return "".join(TOKENS[i] for i in token_ids)

decode([0, 3, 7, 5, 6, 4, 1, 2])

'<bos>(1+0)<eos><pad>'

In [25]:
def encode(s, *, eos=True):
    if s.startswith("<bos>"):
        s = s[5:]

    output = [BOS]
    output.extend(TOKEN_MAP[c] for c in s)

    if eos:
        output.append(EOS)

    return torch.tensor(output, device=device)

decode(encode("1+2=3"))

'<bos>1+2=3<eos>'

### Problem Generation

The function `generate_instance` generates a random expression from `n` random integers between `value_min` and `value_max` (inclusive), combining them with addition in a random order.
The full output chains several reduction rounds separated by equals signs; each round replaces every innermost parenthesized addition with its integer value.
The final value after the last equals sign is the value of the original expression before the first equals sign.

Here are some example expressions.

* `((3+4)+(9+2))=(7+11)=18`
* `(((((1+2)+3)+4)+5)+6)=((((3+3)+4)+5)+6)=(((6+4)+5)+6)=((10+5)+6)=(15+6)=21`

To be clear, each reduction step should replace all the parentheses that contain only two numbers being added.
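
The reduction rule can be sketched in plain Python, independent of the model. This is a minimal illustration, assuming the expression is a plain string without special tokens; `reduce_once` and `full_chain` are hypothetical helper names and are not part of the provided code.

```python
import re

# Matches a parenthesized addition of exactly two integers, e.g. "(3+4)".
INNERMOST = re.compile(r"\((\d+)\+(\d+)\)")


def reduce_once(expr):
    """Replace every parenthesized two-number addition with its sum."""
    return INNERMOST.sub(lambda m: str(int(m.group(1)) + int(m.group(2))), expr)


def full_chain(expr):
    """Apply reduce_once until one integer remains, joining the steps with '='."""
    steps = [expr]
    while not steps[-1].isdigit():
        reduced = reduce_once(steps[-1])
        if reduced == steps[-1]:
            # Nothing was reducible, so the input was not a valid expression.
            raise ValueError(f"cannot reduce {steps[-1]!r}")
        steps.append(reduced)
    return "=".join(steps)


print(full_chain("((3+4)+(9+2))"))  # ((3+4)+(9+2))=(7+11)=18
```

A helper like this could also serve as a reference when checking model completions, since the expected chain is fully determined by the text before the first equals sign.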


In [8]:
# DO NOT CHANGE

def generate_instance(n, *, value_min=1, value_max=9):
    current_numbers = [random.randint(value_min, value_max) for _ in range(n)]
    current_expressions = [[str(v) for v in current_numbers]]
    current_fresh = [True for _ in current_numbers]

    while len(current_numbers) > 1:
        next_numbers = []
        next_expressions = [[] for _ in range(len(current_expressions) + 1)]
        next_fresh = []

        i = 0
        while i < len(current_numbers):
            can_merge = (i + 1 < len(current_numbers)) and (current_fresh[i] or current_fresh[i + 1])
            if can_merge and random.random() < 0.5:
                # decided to merge
                next_numbers.append(current_numbers[i] + current_numbers[i + 1])

                next_expressions[0].append(str(next_numbers[-1]))
                for j in range(len(current_expressions)):
                    next_expressions[j + 1].append(f"({current_expressions[j][i]}+{current_expressions[j][i + 1]})")

                next_fresh.append(True)
                i += 2
            else:
                # decided not to merge
                next_numbers.append(current_numbers[i])

                next_expressions[0].append(str(next_numbers[-1]))
                for j in range(len(current_expressions)):
                    next_expressions[j + 1].append(current_expressions[j][i])

                next_fresh.append(False)
                i += 1

        if len(next_numbers) < len(current_numbers):
            current_numbers = next_numbers
            current_expressions = next_expressions
            current_fresh = next_fresh

    output = '='.join(e[0] for e in reversed(current_expressions))
    return encode(output)

decode(generate_instance(3))

'<bos>(1+(6+6))=(1+12)=13<eos>'

In [9]:
for i in range(10):
    print(decode(generate_instance(5)))

<bos>(((1+(2+5))+9)+1)=(((1+7)+9)+1)=((8+9)+1)=(17+1)=18<eos>
<bos>(((4+(2+6))+8)+3)=(((4+8)+8)+3)=((12+8)+3)=(20+3)=23<eos>
<bos>(4+(3+(3+(1+5))))=(4+(3+(3+6)))=(4+(3+9))=(4+12)=16<eos>
<bos>(9+((8+(3+1))+5))=(9+((8+4)+5))=(9+(12+5))=(9+17)=26<eos>
<bos>((((9+4)+9)+1)+5)=(((13+9)+1)+5)=((22+1)+5)=(23+5)=28<eos>
<bos>(3+((6+6)+(1+7)))=(3+(12+8))=(3+20)=23<eos>
<bos>(4+(3+((2+3)+4)))=(4+(3+(5+4)))=(4+(3+9))=(4+12)=16<eos>
<bos>(((2+(8+1))+8)+5)=(((2+9)+8)+5)=((11+8)+5)=(19+5)=24<eos>
<bos>(4+((7+8)+(7+4)))=(4+(15+11))=(4+26)=30<eos>
<bos>(9+(7+(6+(1+2))))=(9+(7+(6+3)))=(9+(7+9))=(9+16)=25<eos>


## Implement a model that generalizes to more numbers and larger numbers


The sample code that follows is based on this ChatGPT session.

https://chatgpt.com/share/69036c83-171c-800c-9216-0884476017c6

In [10]:
def make_batch(*args, batch_size=64, **kwargs):
    seqs = [generate_instance(*args, **kwargs) for _ in range(batch_size)]

    # pad to max length on right
    batch = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=PAD)

    # next token targets: inputs are all but last; targets are all but first
    x = batch[:, :-1]
    y = batch[:, 1:]

    return x.to(device), y.to(device)

make_batch(5)

(tensor([[ 0,  3,  3,  ...,  2,  2,  2],
         [ 0,  3,  3,  ..., 16,  8,  9],
         [ 0,  3,  3,  ...,  2,  2,  2],
         ...,
         [ 0,  3,  3,  ...,  2,  2,  2],
         [ 0,  3,  3,  ...,  2,  2,  2],
         [ 0,  3, 11,  ..., 16,  8,  6]], device='cuda:0'),
 tensor([[ 3,  3,  3,  ...,  2,  2,  2],
         [ 3,  3,  3,  ...,  8,  9,  1],
         [ 3,  3,  3,  ...,  2,  2,  2],
         ...,
         [ 3,  3,  3,  ...,  2,  2,  2],
         [ 3,  3,  3,  ...,  2,  2,  2],
         [ 3, 11,  5,  ...,  8,  6,  1]], device='cuda:0'))

In [11]:
def causal_mask(T):
    # shape (T, T) float mask that is added to the attention scores:
    # -inf = masked (disallowed), 0.0 = allowed.
    # Entries above the diagonal are -inf, so position i only attends to positions <= i.
    m = torch.full((T, T), float("-inf"), device=device)
    m = torch.triu(m, diagonal=1)  # upper triangle is masked
    return m

In [12]:
# YOUR CHANGES HERE

class MathTransformer(torch.nn.Module):
    def __init__(self, d_model=128, nhead=4, num_layers=4, dim_ff=256, max_len=64, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        vocab_size = len(TOKENS)

        # token + position embeddings
        self.tok_emb = torch.nn.Embedding(vocab_size, d_model, padding_idx=PAD)
        self.pos_emb = torch.nn.Embedding(max_len, d_model)

        layer = torch.nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_ff,
            dropout=dropout, batch_first=True,
        )
        self.blocks = torch.nn.TransformerEncoder(layer, num_layers=num_layers)
        self.lm_head = torch.nn.Linear(d_model, vocab_size)

        # init
        torch.nn.init.normal_(self.tok_emb.weight, mean=0.0, std=0.02)
        torch.nn.init.normal_(self.pos_emb.weight, mean=0.0, std=0.02)
        torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.02)
        torch.nn.init.zeros_(self.lm_head.bias)

    def forward(self, x):
        # x: (N, T)
        N, T = x.shape
        pos = torch.arange(T, device=x.device).unsqueeze(0)  # (1, T)
        h = self.tok_emb(x) * math.sqrt(self.d_model) + self.pos_emb(pos)  # (N, T, d_model)

        # key padding mask: True where we want to ignore (PAD)
        key_padding_mask = (x == PAD)  # (N, T) bool

        # causal mask for self-attention (float, -inf above diagonal)
        attn_mask = causal_mask(T) # (T, T)

        h = self.blocks(
            h,
            mask=attn_mask,                         # causal
            src_key_padding_mask=key_padding_mask   # pad masking
        )
        logits = self.lm_head(h)  # (N, T, vocab)
        return logits

    @torch.no_grad()
    def generate(self, prefix_ids, max_new_tokens=8):
        self.eval()
        x = prefix_ids.clone().to(next(self.parameters()).device)  # (N, T0)
        for _ in range(max_new_tokens):
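            # stop at the positional embedding limit (max_len) instead of a hardcoded 64
            if x.size(1) >= self.pos_emb.num_embeddings: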
                break
            logits = self.forward(x)[:, -1, :]   # (N, V)
            next_id = torch.argmax(logits, dim=-1, keepdim=True)  # greedy
            x = torch.cat([x, next_id], dim=1)
            if (next_id == EOS).all():
                break
        return x

test_model = MathTransformer(d_model=8, nhead=2, num_layers=2, dim_ff=2, max_len=64, dropout=0.1)

In [13]:
model = MathTransformer().to(device)
criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

model.train()
steps = 800
for step in range(1, steps+1):
    x, y = make_batch(3, batch_size=1024)  # x,y: (N, T)
    logits = model(x)                  # (N, T, V)
    loss = criterion(logits.reshape(-1, len(TOKENS)), y.reshape(-1))
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    if step % 100 == 0:
        print(f"step {step:4d} | loss {loss.item():.4f}")






step  100 | loss 0.8786
step  200 | loss 0.6949
step  300 | loss 0.5760
step  400 | loss 0.5046
step  500 | loss 0.4616
step  600 | loss 0.4380
step  700 | loss 0.4215
step  800 | loss 0.4156


In [17]:
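def prepare_prompt(s):
    # encode without <eos> so a prompt lacking '=' is not terminated early
    token_ids = encode(s, eos=False)
    if '=' in s:
        # +2: one for the leading <bos>, one to include the '=' itself
        token_ids = token_ids[:s.index('=')+2]
        assert token_ids[-1] == TOKEN_MAP['=']

    # encode already returns a tensor on `device`; add a batch dimension: (T,) -> (1, T)
    return token_ids.unsqueeze(0)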

In [15]:
decode(generate_instance(3))

'<bos>(2+(4+2))=(2+6)=8<eos>'

In [34]:
def test_example(*args, verbose=True, **kwargs):
    model.eval()

    target_token_ids = generate_instance(*args, **kwargs)
    target = decode(target_token_ids)

    prompt = target[:target.index('=')+1]
    prompt_token_ids = encode(prompt, eos=False)
    prompt_batch = prompt_token_ids.reshape(shape=(1,-1))

    actual_token_ids = model.generate(prompt_batch, max_new_tokens=25)[0]
    actual = decode(actual_token_ids)

    correct = actual == target

    if verbose or not correct:
        print("PROMPT", decode(prompt_token_ids), "TARGET", target, "ACTUAL", actual, "CORRECT", correct)

    return correct

test_example(n=3)

PROMPT <bos>(8+(1+4))= TARGET <bos>(8+(1+4))=(8+5)=13<eos> ACTUAL <bos>(8+(1+4))=(8+5)=13<eos> CORRECT True


True

In [39]:
for _ in range(10):
    test_example(n=3, verbose=True)



PROMPT <bos>((3+3)+2)= TARGET <bos>((3+3)+2)=(6+2)=8<eos> ACTUAL <bos>((3+3)+2)=(6+2)=8<eos> CORRECT True
PROMPT <bos>(3+(6+6))= TARGET <bos>(3+(6+6))=(3+12)=15<eos> ACTUAL <bos>(3+(6+6))=(3+12)=15<eos> CORRECT True
PROMPT <bos>(7+(3+2))= TARGET <bos>(7+(3+2))=(7+5)=12<eos> ACTUAL <bos>(7+(3+2))=(7+5)=12<eos> CORRECT True
PROMPT <bos>((8+5)+8)= TARGET <bos>((8+5)+8)=(13+8)=21<eos> ACTUAL <bos>((8+5)+8)=(13+8)=21<eos> CORRECT True
PROMPT <bos>(7+(4+6))= TARGET <bos>(7+(4+6))=(7+10)=17<eos> ACTUAL <bos>(7+(4+6))=(7+10)=17<eos> CORRECT True
PROMPT <bos>((3+3)+7)= TARGET <bos>((3+3)+7)=(6+7)=13<eos> ACTUAL <bos>((3+3)+7)=(6+7)=13<eos> CORRECT True
PROMPT <bos>((4+4)+8)= TARGET <bos>((4+4)+8)=(8+8)=16<eos> ACTUAL <bos>((4+4)+8)=(8+8)=16<eos> CORRECT True
PROMPT <bos>((5+3)+9)= TARGET <bos>((5+3)+9)=(8+9)=17<eos> ACTUAL <bos>((5+3)+9)=(8+9)=17<eos> CORRECT True
PROMPT <bos>((6+9)+2)= TARGET <bos>((6+9)+2)=(15+2)=17<eos> ACTUAL <bos>((6+9)+2)=(15+2)=17<eos> CORRECT True
PROMPT <bos>((4+9)+4)=

In [38]:
for _ in range(10):
    test_example(n=4, verbose=False)

PROMPT <bos>((8+9)+(4+1))= TARGET <bos>((8+9)+(4+1))=(17+5)=22<eos> ACTUAL <bos>((8+9)+(4+1))=(1919<eos> CORRECT False
PROMPT <bos>(((9+2)+1)+9)= TARGET <bos>(((9+2)+1)+9)=((11+1)+9)=(12+9)=21<eos> ACTUAL <bos>(((9+2)+1)+9)=(1<eos> CORRECT False
PROMPT <bos>((8+7)+(9+7))= TARGET <bos>((8+7)+(9+7))=(15+16)=31<eos> ACTUAL <bos>((8+7)+(9+7))=2)=23<eos> CORRECT False
PROMPT <bos>(((8+3)+7)+7)= TARGET <bos>(((8+3)+7)+7)=((11+7)+7)=(18+7)=25<eos> ACTUAL <bos>(((8+3)+7)+7)=(19+9<eos> CORRECT False
PROMPT <bos>(((1+2)+1)+8)= TARGET <bos>(((1+2)+1)+8)=((3+1)+8)=(4+8)=12<eos> ACTUAL <bos>(((1+2)+1)+8)=)=8<eos> CORRECT False
PROMPT <bos>((4+2)+(9+5))= TARGET <bos>((4+2)+(9+5))=(6+14)=20<eos> ACTUAL <bos>((4+2)+(9+5))=(8<eos> CORRECT False
PROMPT <bos>((8+8)+(2+5))= TARGET <bos>((8+8)+(2+5))=(16+7)=23<eos> ACTUAL <bos>((8+8)+(2+5))=(19+9<eos> CORRECT False
PROMPT <bos>(8+(5+(3+6)))= TARGET <bos>(8+(5+(3+6)))=(8+(5+9))=(8+14)=22<eos> ACTUAL <bos>(8+(5+(3+6)))=(915<eos> CORRECT False
PROMPT <bos>(((

### Benchmark your model

Test your code with different numbers of integers and numbers of input digits.
The `generate_instance` function provided uses the parameter `n` to control the number of integers, and `value_min` and `value_max` to control the range of integers.
For example, 2 input digits would correspond to `value_min=10` and `value_max=99`.
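
The mapping from a digit count to a value range can be written as a small helper. This is a sketch under the assumption that 1 input digit means the default range 1 to 9 (zero is excluded because the inputs are positive integers); `digit_range` is a hypothetical name, not part of the provided code.

```python
def digit_range(digits):
    """Return (value_min, value_max) covering all positive integers with `digits` digits."""
    if digits < 1:
        raise ValueError("digits must be at least 1")
    return 10 ** (digits - 1), 10 ** digits - 1


for d in (1, 2, 3):
    value_min, value_max = digit_range(d)
    print(d, value_min, value_max)
```

The resulting pair can be passed straight through, for example `generate_instance(n, value_min=value_min, value_max=value_max)`.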

Test the accuracy on the combinations specified in the table below, and fill in your accuracy numbers in that table.
Run enough samples per combination for a stable accuracy estimate (at least 1000 is recommended), because the auto-grader will check your reported accuracy for consistency with its own tests.
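
To get a feel for how stable an estimate from a given number of samples is, the accuracy can be reported with a rough margin of error. The sketch below uses the normal approximation to a binomial proportion, which is one common choice rather than anything the auto-grader is known to use; `accuracy_with_margin` is a hypothetical helper, and the toy outcomes stand in for the booleans returned by `test_example(..., verbose=False)`.

```python
import math


def accuracy_with_margin(outcomes, z=1.96):
    """Return (accuracy, margin); margin is a normal-approximation ~95% half-width."""
    outcomes = list(outcomes)
    n = len(outcomes)
    if n == 0:
        raise ValueError("need at least one outcome")
    acc = sum(bool(o) for o in outcomes) / n
    margin = z * math.sqrt(acc * (1.0 - acc) / n)
    return acc, margin


# Toy data standing in for 1000 benchmark runs (assumption: 900 were correct).
outcomes = [True] * 900 + [False] * 100
acc, margin = accuracy_with_margin(outcomes)
print(f"accuracy {acc:.3f} +/- {margin:.3f}")
```

With 1000 samples at 90% accuracy, the margin comes out to roughly 0.02. The approximation is unreliable when accuracy is very close to 0 or 1, where it reports a margin near zero.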

In [None]:
# YOUR CHANGES HERE

Fill in this table.

| n | input digits | accuracy |
|---|---|-----|
| 2 | 1 | TODO |
| 2 | 2 | TODO |
| 2 | 3 | TODO |
| 3 | 1 | TODO |
| 3 | 2 | TODO |
| 3 | 3 | TODO |
| 4 | 1 | TODO |
| 4 | 2 | TODO |
| 5 | 1 | TODO |
| 5 | 2 | TODO |

Do not change the table header as the auto-grader will use it to check your results.


## Save model and implement a command line interface.

Your model will be tested automatically with a suite of examples with different numbers of values and digits matching your previous benchmark task.
For this testing, you must save your model weights and write a program to run your model.

### Save your model weights.

Save your model weights as `math.pt` to be submitted in Gradescope.

In [None]:
# YOUR CHANGES HERE

...

### Write a program to run your model.

Write a Python script `predict.py` that takes a single filename as input, reads each line as a prompt, generates the completion, and writes out the result to standard output.
We will invoke your program with a command like `python3 predict.py INPUT.txt` and capture the standard output for grading.

The input file will not include the special tokens such as `<bos>` or `<eos>`.
Similarly, your output should not include them either.

For example, given an input file with the following contents,
```
(((1+2)+1)+8)=
```
your program should write the following output.
```
(((1+2)+1)+8)=((3+1)+8)=(4+8)=12
```
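
The input and output handling around the model can be sketched as follows. This is only an illustration of the line-by-line contract described above: `complete_lines` and `strip_special_tokens` are hypothetical names, the model call is replaced by a stand-in function, and skipping blank input lines is an assumption that the assignment does not specify.

```python
SPECIAL_TOKENS = ("<bos>", "<eos>", "<pad>")


def strip_special_tokens(text):
    """Remove special tokens so only the expression characters remain."""
    for tok in SPECIAL_TOKENS:
        text = text.replace(tok, "")
    return text


def complete_lines(lines, complete_fn):
    """Return one cleaned completion per non-blank prompt line."""
    results = []
    for line in lines:
        prompt = line.strip()
        if not prompt:
            continue  # assumption: blank lines are skipped
        results.append(strip_special_tokens(complete_fn(prompt)))
    return results


def fake_model(prompt):
    """Stand-in for the trained model; returns a canned decoded completion."""
    return "<bos>" + prompt + "((3+1)+8)=(4+8)=12<eos>"


for result in complete_lines(["(((1+2)+1)+8)=\n"], fake_model):
    print(result)
```

In an actual `predict.py`, the lines would come from the file named in `sys.argv[1]`, and `complete_fn` would wrap the model loaded from `math.pt`.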


## Final Submission

Submit your copy of this notebook with all your code, your saved model `math.pt`, and your prediction script `predict.py` to Gradescope.
