# Teaching NanoGPT to Do Math

### Team Members and Contributions
- Jiayang
    - Initialising the training environment
    - Script for training data generation
    - Training the model
    - Validating the model outputs
    - Discussion and Analysis
- Rochelle
- Viona

### ⚠️ **README**

The following is a summary of the project, including methodology, rationale, and key considerations for reproducing the experiments.

---

#### **Task 1**

- The dataset was generated using custom scripts. These scripts, along with accompanying documentation, are located in `src/utils`. Further explanation is provided in the notebook (`.ipynb`).
- Large language models (LLMs) and their smaller variants, such as NanoGPT, do not inherently understand logical or mathematical reasoning; they primarily learn **statistical patterns in tokens**. Therefore, they are simply predicting the most probable token based on the context of all previous tokens.
- Consequently, we constrained the scope of arithmetic problems to a **limited domain**, covering a wide range of representative cases while avoiding exhaustive permutations. This ensures that the model **learns underlying patterns rather than memorising individual examples**, promoting genuine generalisation.
- In designing the dataset, we considered the **distribution of problem types and difficulty**. For example, we weighted smaller numbers more heavily to approximate natural frequencies in arithmetic tasks and ensured balanced coverage across operations (addition, subtraction, multiplication, division). This careful sampling mitigates the risk of bias and promotes robust model performance across the domain.
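
As a rough illustration of the sampling idea described above, the sketch below draws operands with a bias toward small values and cycles through the four operations so that each one receives equal coverage. The names `sample_operand` and `sample_problems`, the `1/n` weighting, and the 1 to 100 range are assumptions made for this example; the actual generator lives in `src/utils` and may differ.

```python
import random

OPS = ["+", "-", "*", "/"]


def sample_operand(rng, low=1, high=100):
    """Draw an integer in [low, high], weighting small values more heavily (weight 1/n)."""
    values = list(range(low, high + 1))
    weights = [1.0 / v for v in values]
    return rng.choices(values, weights=weights, k=1)[0]


def sample_problems(n, seed=0):
    """Return n prompts such as '17+19=?', cycling through the operators for balance."""
    rng = random.Random(seed)
    problems = []
    for i in range(n):
        op = OPS[i % len(OPS)]
        a, b = sample_operand(rng), sample_operand(rng)
        if op == "-" and a < b:
            a, b = b, a  # assumption: keep subtraction results non-negative
        if op == "/":
            a = a * b  # make the division exact; the dividend may exceed `high`
        problems.append(f"{a}{op}{b}=?")
    return problems


print(sample_problems(8))
```

Cycling through `OPS` guarantees an exact 25% share per operation, whereas drawing the operator at random would only balance the operations on average.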

---

#### **Task 2**

*(To be completed)*


---

#### **Task 3**

*(To be completed)*


---

#### **Extra Training and Generalisation**

* Based on the results of initial experiments, we conducted additional training to enable the model to generalise to a broader domain of arithmetic problems.
* As expected, extending the domain introduced some **catastrophic forgetting**, wherein the model occasionally failed cases it had previously solved correctly.
* To mitigate this, we carefully designed the training data to **interleave previously seen problem types** with new examples. This strategy reduces forgetting while improving performance on a wider range of problems.
* Overall, the approach reflects a trade-off between **domain expansion** and **retention of previously learned patterns**, a well-known challenge in sequential model training.
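
The interleaving strategy described above can be sketched as a simple replay mix: new examples are combined with a random sample of previously seen ones before shuffling. The function name `interleave_with_replay` and the `replay_fraction` parameter are hypothetical and chosen for illustration; they are not taken from the project scripts.

```python
import random


def interleave_with_replay(old_examples, new_examples, replay_fraction=0.25, seed=0):
    """Mix new examples with a random sample of old ones, then shuffle.

    replay_fraction is the share of the final dataset drawn from old_examples.
    """
    if not 0 <= replay_fraction < 1:
        raise ValueError("replay_fraction must be in [0, 1)")
    rng = random.Random(seed)
    # Solve n_old / (n_old + n_new) = replay_fraction for n_old.
    n_old = round(len(new_examples) * replay_fraction / (1 - replay_fraction))
    n_old = min(n_old, len(old_examples))
    mixed = list(new_examples) + rng.sample(list(old_examples), n_old)
    rng.shuffle(mixed)
    return mixed


old = [f"old-{i}" for i in range(100)]
new = [f"new-{i}" for i in range(60)]
mixed = interleave_with_replay(old, new)
print(len(mixed), sum(x.startswith("old") for x in mixed))
```

A larger `replay_fraction` favours retention of earlier patterns, while a smaller one favours the new domain, which mirrors the trade-off noted above.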




### Step 1: Install necessary packages

> ⚠️ **We modified Step 1 so that we could train on our own GPU.**

In [None]:
!pip install matplotlib
!pip install numpy transformers datasets tiktoken wandb tqdm
!pip install ipywidgets

In [None]:
!pip install torch --index-url https://download.pytorch.org/whl/cu128

> ⚠️ **To use a GPU, install the `torch` build that matches your CUDA version. Check your own GPU's compatibility first.**

### Step 2: Package imports and configuration

In [1]:
import torch

print(torch.cuda.is_available())  # True if CUDA is available
print(torch.cuda.device_count())  # Number of available GPUs
print(torch.cuda.get_device_name(0))  # GPU name

True
1
NVIDIA GeForce GTX 1650


In [2]:
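import sys
import os

sys.path.append(os.path.abspath(".."))
# Expose only GPU 0, the single GPU on this machine (see the output above).
# Index "1" does not exist here and would hide the GPU. The variable only takes
# effect if it is set before CUDA is initialised in the process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import pickle
from model import GPT, GPTConfig
from tqdm import tqdm
import time
import json
import matplotlib.pyplot as plt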

# Configuration
beta = 0.5
device = "cuda" if torch.cuda.is_available() else "cpu"
base_lr = 1e-4
epochs = 5
batch_size = 64
max_length = 64
num_samples = 1
max_new_tokens = 200
temperature = 0.8
top_k = 200
# tokenizer
with open("../sft/meta.pkl", "rb") as f:
    meta = pickle.load(f)
stoi, itos = meta["stoi"], meta["itos"]


def encode(s):
    return [stoi.get(c, 0) for c in s]  # 0 = <unk> for unknown characters


def decode(l):
    return "".join([itos[i] for i in l])
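
To make the character-level tokenizer concrete, here is a small self-contained sketch with a toy vocabulary. The names `build_toy_vocab`, `toy_encode`, and `toy_decode` are hypothetical, and the real vocabulary comes from `meta.pkl`, so the actual ids will differ. The point it illustrates is that any character outside the vocabulary maps to id 0, the same id that is later used for padding.

```python
def build_toy_vocab(chars):
    """Map each character to an id starting at 1, reserving id 0 for unknown/padding."""
    stoi = {c: i + 1 for i, c in enumerate(chars)}
    itos = {i: c for c, i in stoi.items()}
    itos[0] = ""  # assumption for this toy: id 0 decodes to an empty string
    return stoi, itos


def toy_encode(s, stoi):
    """Encode a string one character at a time; unknown characters become 0."""
    return [stoi.get(c, 0) for c in s]


def toy_decode(ids, itos):
    """Join the characters for each id back into a string."""
    return "".join(itos[i] for i in ids)


toy_stoi, toy_itos = build_toy_vocab("0123456789+-*/=?x, ")
ids = toy_encode("17+19=?", toy_stoi)
print(ids)
print(toy_decode(ids, toy_itos))
print(toy_encode("17^2=?", toy_stoi))  # '^' is not in the toy vocabulary
```

Because unknown characters share id 0 with padding, they are also skipped by any loss that ignores index 0, so prompts should only use characters present in the vocabulary.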

### Step 3: Define helper functions

In [3]:
def compute_logprob(input_ids):
    inputs = input_ids[:, :-1]
    targets = input_ids[:, 1:]
    logits, _ = gpt(inputs, full_seq=True)
    B, T, V = logits.size()
    logits_flat = logits.reshape(-1, V)
    targets_flat = targets.reshape(-1)
    loss = F.cross_entropy(logits_flat, targets_flat, ignore_index=0, reduction="none")
    loss = loss.reshape(B, T)
    attention_mask = (targets != 0).float()
    loss = (loss * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
    return -loss


def pad_or_truncate(seq, max_length):
    return (
        seq[-max_length:]
        if len(seq) > max_length
        else seq + [0] * (max_length - len(seq))
    )


def get_batches(lines, batch_size):
    random.shuffle(lines)
    # for l in lines:
    #    print(l[1])
    for i in range(0, len(lines), batch_size):
        batch = lines[i : i + batch_size]
        if len(batch) < batch_size:
            continue
        neg_inputs = [
            pad_or_truncate(encode(p["negative"] + "\n\n\n\n"), max_length)
            for p in batch
        ]
        pos_inputs = [
            pad_or_truncate(encode(p["positive"] + "\n\n\n\n"), max_length)
            for p in batch
        ]
        neg_tensor = torch.tensor(neg_inputs, dtype=torch.long, device=device)
        pos_tensor = torch.tensor(pos_inputs, dtype=torch.long, device=device)
        yield neg_tensor, pos_tensor
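
The following pure-Python sketch mirrors what the helpers above do to a single sequence, without needing the model: sequences are right-padded with id 0 (or truncated to their last `max_length` ids), and the per-token log-probabilities are averaged over non-padding targets only. `masked_mean_logprob` is a hypothetical helper written for this illustration, and the log-probability values are made up.

```python
def pad_or_truncate(seq, max_length, pad_id=0):
    """Keep the last max_length ids, or right-pad with pad_id up to max_length."""
    if len(seq) > max_length:
        return seq[-max_length:]
    return seq + [pad_id] * (max_length - len(seq))


def masked_mean_logprob(token_logprobs, targets, pad_id=0):
    """Average per-token log-probabilities over targets that are not padding."""
    kept = [lp for lp, t in zip(token_logprobs, targets) if t != pad_id]
    return sum(kept) / len(kept)


ids = pad_or_truncate([5, 6, 7], max_length=5)  # padded to length 5
targets = ids[1:]                               # next-token targets
logprobs = [-0.2, -0.4, -3.0, -3.0]             # last two positions are padding
print(ids, targets, masked_mean_logprob(logprobs, targets))
```

Note that truncation keeps the tail of the sequence, so an example longer than `max_length` would lose the start of its prompt rather than the end of its answer.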

### Step 4: Load the pretrained NanoGPT model
#### Loading Process
1. Load checkpoint file
2. Initialize model with saved config
3. Load pretrained weights
4. Move to GPU

The model can answer general questions but doesn't know math yet.

In [5]:
print(torch.__version__)
print(torch.version.cuda)
print(torch.backends.cudnn.version())

2.8.0+cu128
12.8
91002


In [4]:
ckpt = torch.load("../sft/gpt.pt", map_location=device)
gptconf = GPTConfig(**ckpt["model_args"])
gpt = GPT(gptconf)
state_dict = ckpt["model"]
unwanted_prefix = "_orig_mod."
for k in list(state_dict.keys()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)
gpt.to(device).train()

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(74, 348)
    (wpe): Embedding(256, 348)
    (drop): Dropout(p=0.2, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=348, out_features=1044, bias=False)
          (c_proj): Linear(in_features=348, out_features=348, bias=False)
          (attn_dropout): Dropout(p=0.2, inplace=False)
          (resid_dropout): Dropout(p=0.2, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=348, out_features=1392, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=1392, out_features=348, bias=False)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=348, out_features=74, bias=False)
)

In [5]:
print("Device:", device)
print("Model is on CUDA:", next(gpt.parameters()).is_cuda)

Device: cuda
Model is on CUDA: True


### Step 5: Load Data (**students are required to complete this part!**) (Task 1)

The data is generated by our own script, found in `utils/`. The design considerations are documented in the script itself.

🌟 In this section, we describe the rationale behind our training data generation.

#### *First Attempt*
---
Our first approach to populating the training data was simple: we distributed the operations evenly across arithmetic and algebraic problems, choosing numbers from the range 1 to 999. The exact script is found in `utils/generate_training_data.py`; refer to the documentation in the script for further explanation.

However, this led to poor performance: every test case failed, as shown in the following table.

| Problem     | Expected | Model Output                                   | Accuracy |
| ----------- | -------- | ---------------------------------------------- | -------- |
| 17+19=?     | 36       | "The answer is 58 because 17+19 equals 58."    | F        |
| 3*17=?      | 51       | "The answer is 11 because 3*17 equals 11."     | F        |
| 72/4=?      | 18       | "The answer is 12 because 72/4 equals 12."     | F        |
| 72-x=34,x=? | 38       | "The answer is 73 because 74-3 equals to 733." | F        |
| x*11=44,x=? | 4        | "The answer is 19 because 44/11 equals to 19." | F        |
| 3*17=?      | 51       | "The answer is 13 because 3*17 equals 13."     | F        |
| 72/4=?      | 18       | "The answer is 52 because 72/4 equals 52."     | F        |
| 72-x=34,x=? | 38       | "The answer is 80 because 74-3 equals to 80."  | F        |

There are a few plausible reasons as to why this is the case:

1. **Sparse numerical density:**
   The current dataset spans a wide range (1–999), resulting in large variation in operation outcomes (e.g., 999+999 vs. 999×999). Such disparity makes it difficult for the model to detect consistent numerical patterns.

2. **Uneven operator coverage:**
   Simply distributing operations evenly is insufficient. For each number pair, all four operators (+, −, ×, ÷) should be applied so the model can clearly learn the distinctions introduced by each operator.

3. **Variable position bias:**
   In the test cases, the variable (x) appears in varying positions, but the training data only followed one fixed pattern from the lab manual. To improve generalization, the model should also be trained on equations where (x) appears in different positions.

4. **Mathematical properties:**
   Perhaps the model would also perform better if the training data reflected the way humans learn mathematics, starting from axioms and fundamental operations. While it may be hard for the model to generalise such abstractions, we can attempt to capture them by including explicit examples of these properties.
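
To illustrate the variable-position point (reason 3), the sketch below produces both placements of `x` for a given problem, together with the inverse operation used in the explanation. The function `algebra_variants` is hypothetical and only follows the answer format visible in the model outputs in this notebook; it is not the project's generator.

```python
import operator

FORWARD = {"+": operator.add, "-": operator.sub, "*": operator.mul, "/": operator.floordiv}


def algebra_variants(a, op, b):
    """Return (prompt, answer, explanation) triples with x as the left, then right operand.

    Assumes a and b are positive integers and, for "/", that b divides a exactly.
    """
    c = FORWARD[op](a, b)
    # How to isolate x when it replaces the left operand (x op b = c) ...
    solve_left = {"+": f"{c}-{b}", "-": f"{c}+{b}", "*": f"{c}/{b}", "/": f"{c}*{b}"}
    # ... and when it replaces the right operand (a op x = c).
    solve_right = {"+": f"{c}-{a}", "-": f"{a}-{c}", "*": f"{c}/{a}", "/": f"{a}/{c}"}
    return [
        (f"x{op}{b}={c},x=?", a, f"{solve_left[op]} equals {a}"),
        (f"{a}{op}x={c},x=?", b, f"{solve_right[op]} equals {b}"),
    ]


for prompt, answer, why in algebra_variants(72, "-", 34):
    print(f"{prompt} The answer is {answer} because {why}.")
```

Generating both variants from the same triple keeps the two placements of `x` equally frequent in the training data.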


> 📝 We decided to retrain the model because the model did not learn useful patterns.

---

#### *Second Attempt*


---

**Improvements Implemented**

1. **Focused numerical range**
   Reduced the number range to smaller values to improve learning consistency and ensure the model performs reliably within that range before generalizing to larger numbers.

2. **Identity property coverage**
   Added explicit training samples demonstrating the role of zero and one in addition, subtraction, multiplication, and division. These targeted examples help the model internalize fundamental mathematical identities (e.g., $a + 0 = a$, $a \times 1 = a$).

3. **Commutative and inverse property expansion**
   Expanded the dataset to include examples illustrating the commutative property for addition and multiplication, along with corresponding subtraction and division cases. This strengthens the model’s grasp of symmetry and operation reversibility.

4. **Stratified numerical sampling**
   Implemented stratified sampling across defined number ranges (1–10, 11–30, 31–70, 71–100) to ensure dense and balanced representation of small, medium, and large values. This minimizes gaps and prevents bias toward specific magnitudes.

5. **Refined prompt generation**
   Enhanced arithmetic and algebraic prompt generation to include both direct computation and “solve for $x$” forms. Each category includes both positive and negative examples, improving the model’s robustness across different reasoning contexts.
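
A minimal sketch of improvements 4 and 5, assuming an equal number of draws per bucket and addition-only prompts for brevity. The helper names `stratified_operands` and `make_pair` are hypothetical; the positive and negative strings follow the format of the dataset entry printed in the next cell.

```python
import random

# Number ranges taken from the stratified sampling description above.
BUCKETS = [(1, 10), (11, 30), (31, 70), (71, 100)]


def stratified_operands(per_bucket, seed=0):
    """Draw the same number of operands from every bucket, in bucket order."""
    rng = random.Random(seed)
    operands = []
    for low, high in BUCKETS:
        operands.extend(rng.randint(low, high) for _ in range(per_bucket))
    return operands


def make_pair(a, b):
    """Build one positive/negative training pair for an addition prompt."""
    prompt = f"{a}+{b}=?"
    total = a + b
    return {
        "positive": f"{prompt} The answer is {total} because {a}+{b} equals {total}.",
        "negative": f"{prompt} Sorry, I do not know!",
    }


operands = stratified_operands(per_bucket=3)
print(operands)
print(make_pair(operands[0], operands[-1]))
```

With equal draws per bucket, the narrow 1 to 10 bucket is sampled more densely per value than the 31 to 70 bucket, which is one way to keep small numbers well represented.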


The data is loaded using the `datasets` library, following its official documentation.

#### Documentation

- [Hugging Face Datasets documentation](https://huggingface.co/docs/datasets/v4.1.1/loading)

In [8]:
# Load data from ./pos_neg_pairs.json

from datasets import load_dataset

dataset = load_dataset("json", data_files="./pos_neg_pairs.json").shuffle(seed=42)

print(dataset)
print(dataset["train"][0])

DatasetDict({
    train: Dataset({
        features: ['positive', 'negative'],
        num_rows: 102309
    })
})
{'positive': '83+x=98,x=? The answer is 15 because 98-83 equals 15.', 'negative': '83+x=98,x=? Sorry, I do not know!'}


### Step 6: Build the optimizer and scheduler (**students are required to complete this part!**) (Task 2)
#### Optimizer: AdamW
```python
optimizer = torch.optim.AdamW(gpt.parameters(), lr=1e-4, weight_decay=1e-2)
```

**Why AdamW?**
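- Adapts the step size per parameter using running estimates of the gradient's first and second moments
- Is a common default choice for training transformers
- Applies decoupled weight decay, a form of regularization that helps limit overfitting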
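
To show what these points mean in practice, here is a scalar sketch of a single AdamW update written from the standard algorithm, using the same `lr` and `weight_decay` as above. It is an illustration only: `adamw_step` is a hypothetical helper, and the default `betas` and `eps` are assumed to match PyTorch's defaults.

```python
import math


def adamw_step(p, grad, m, v, t, lr=1e-4, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2):
    """One AdamW update for a scalar parameter p at step t (t starts at 1)."""
    b1, b2 = betas
    p = p * (1 - lr * weight_decay)       # decoupled weight decay, applied to p directly
    m = b1 * m + (1 - b1) * grad          # running mean of gradients
    v = b2 * v + (1 - b2) * grad * grad   # running mean of squared gradients
    m_hat = m / (1 - b1 ** t)             # bias correction for the first moment
    v_hat = v / (1 - b2 ** t)             # bias correction for the second moment
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v


p, m, v = 1.0, 0.0, 0.0
for t, grad in enumerate([0.5, 0.5, -0.1], start=1):
    p, m, v = adamw_step(p, grad, m, v, t)
    print(t, p)
```

On the first step the update has magnitude close to `lr` regardless of the gradient's scale, which is the sense in which the step size adapts per parameter.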

#### Scheduler: CosineAnnealingLR
```python
scheduler = CosineAnnealingLR(optimizer, T_max=iteration, eta_min=1e-5)
```

**What it does:**
- Starts at the optimizer's initial learning rate (1e-4)
- Decreases it along a smooth half-cosine curve towards `eta_min`
- Takes smaller steps late in training, which can help the model converge

The learning rate drops from 1e-4 to 1e-5 over training.
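
The schedule can be checked by hand with the closed-form cosine annealing formula, $\eta_t = \eta_{min} + \tfrac{1}{2}(\eta_{max} - \eta_{min})(1 + \cos(\pi t / T_{max}))$. The sketch below evaluates it in plain Python; `cosine_lr` is a hypothetical helper, and `t_max = 7990` assumes the 102309 training pairs, batch size 64, and 5 epochs used in this notebook.

```python
import math


def cosine_lr(step, t_max, base_lr=1e-4, eta_min=1e-5):
    """Closed-form cosine annealing: base_lr at step 0, eta_min at step t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max))


t_max = 102309 // 64 * 5  # 1598 steps per epoch * 5 epochs
for step in (0, t_max // 2, t_max):
    print(step, f"{cosine_lr(step, t_max):.2e}")
```

At the midpoint the learning rate is halfway between the two bounds, roughly 5.5e-5.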

- [AdamW optimiser documentation](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html#torch.optim.AdamW)
- [PyTorch optimiser documentation](https://docs.pytorch.org/docs/stable/optim.html#module-torch.optim)


The parameters of the pre-trained model are shown below:

In [9]:
for name, para in gpt.named_parameters():
    print(name, para.shape)

transformer.wte.weight torch.Size([74, 348])
transformer.wpe.weight torch.Size([256, 348])
transformer.h.0.ln_1.weight torch.Size([348])
transformer.h.0.attn.c_attn.weight torch.Size([1044, 348])
transformer.h.0.attn.c_proj.weight torch.Size([348, 348])
transformer.h.0.ln_2.weight torch.Size([348])
transformer.h.0.mlp.c_fc.weight torch.Size([1392, 348])
transformer.h.0.mlp.c_proj.weight torch.Size([348, 1392])
transformer.h.1.ln_1.weight torch.Size([348])
transformer.h.1.attn.c_attn.weight torch.Size([1044, 348])
transformer.h.1.attn.c_proj.weight torch.Size([348, 348])
transformer.h.1.ln_2.weight torch.Size([348])
transformer.h.1.mlp.c_fc.weight torch.Size([1392, 348])
transformer.h.1.mlp.c_proj.weight torch.Size([348, 1392])
transformer.h.2.ln_1.weight torch.Size([348])
transformer.h.2.attn.c_attn.weight torch.Size([1044, 348])
transformer.h.2.attn.c_proj.weight torch.Size([348, 348])
transformer.h.2.ln_2.weight torch.Size([348])
transformer.h.2.mlp.c_fc.weight torch.Size([1392, 348]

Construct the optimiser according to the official documentation:
- `lr` is kept at $1 \cdot 10^{-4}$
- `weight_decay` is kept at $10^{-2}$

The `AdamW` algorithm is chosen based on the instruction given in the assignment.

In [10]:
optimizer = torch.optim.AdamW(gpt.parameters(), lr=1e-4, weight_decay=1e-2)

Next, we initialise the scheduler. A scheduler in PyTorch adjusts the learning rate `lr` during training according to a chosen strategy.

We chose the Cosine Annealing Scheduler.

#### Documentation

- [CosineAnnealingLR](https://docs.pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingLR.html)
- [Fine-tune Llama 2 with DPO](https://huggingface.co/blog/dpo-trl)

In [11]:
from torch.optim.lr_scheduler import CosineAnnealingLR

iteration = len(dataset["train"]) // batch_size * epochs
scheduler = CosineAnnealingLR(optimizer, T_max=iteration, eta_min=1e-5)

### Step 7: Begin training (**students are required to complete this part!**) (Task 2)

#### DPO Loss Function
```python
loss = (
    -F.logsigmoid((pos_logprob - neg_logprob) / beta).mean()
    - pos_logprob.mean() * 0.1
)
```

**What this does:**
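1. The `logsigmoid` term rewards a larger gap between the positive and negative log-probabilities, making positive samples more likely relative to negative ones
2. The same term pushes the log-probability of negative samples down
3. The `- pos_logprob.mean() * 0.1` term directly raises the likelihood of positive samples, which helps keep outputs fluent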
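
The behaviour of this loss can be explored without a model by plugging in made-up log-probabilities. The sketch below re-implements the same formula in plain Python; `preference_loss` and `fluency_weight` are names chosen for this illustration. Unlike the standard DPO formulation, the objective used in this notebook has no reference model, and the sketch follows the notebook's form.

```python
import math


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def preference_loss(pos_logprobs, neg_logprobs, beta=0.5, fluency_weight=0.1):
    """Mean of -logsigmoid((pos - neg) / beta), minus fluency_weight * mean(pos)."""
    n = len(pos_logprobs)
    preference = -sum(
        log_sigmoid((p - q) / beta) for p, q in zip(pos_logprobs, neg_logprobs)
    ) / n
    fluency = -fluency_weight * sum(pos_logprobs) / n
    return preference + fluency


# Hypothetical mean log-probabilities for a batch of two pairs.
print(preference_loss([-0.1, -0.2], [-5.0, -4.0]))  # positives clearly preferred
print(preference_loss([-5.0, -4.0], [-0.1, -0.2]))  # negatives preferred
```

The first call yields a much smaller loss than the second, since the model in that case already assigns higher likelihood to the positive samples.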

#### Training Process
1. Compute the mean log-probability of the negative sample
2. Compute the mean log-probability of the positive sample
3. Compute DPO loss
4. Update model weights
5. Adjust learning rate

#### Training Results

| Epoch | Loss | Time per Epoch |
|-------|------|----------------|
| 1 | 0.0209 | 8.5 min |
| 2 | 0.0181 | 8.5 min |
| 3 | 0.0168 | 8.5 min |
| 4 | 0.0165 | 8.5 min |
| 5 | 0.0157 | 8.5 min |

**Total improvement**: 24.9% loss reduction
- Calculation: (0.0209 - 0.0157) / 0.0209 ≈ 0.249 = 24.9%
- Loss decreased from 0.0209 → 0.0157
- We consider this a reasonable result for DPO-style fine-tuning, which adjusts preferences incrementally rather than changing the model drastically

The loss decreases at every epoch, which is consistent with the model learning to prefer correct answers.

In [12]:
print(sys.executable)

C:\Users\user\University\3000\.venv\Scripts\python.exe


In [13]:
print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.device_count())

2.8.0+cu128
12.8
1


#### Training

The following is our training using the second version of the script.

In [14]:
from random import shuffle

lines = dataset["train"]
lines = [dict(x) for x in lines]
total_steps = len(lines) // batch_size
for epoch in range(epochs):
    shuffle(lines)
    pbar = tqdm(get_batches(lines, batch_size))
    for step, (neg_tensor, pos_tensor) in enumerate(pbar):
        ###########################################################
        # Please complete the training code here!
        # Examples:
        # ...
        # neg_logprob
        # pos_logprob
        # loss = -F.logsigmoid((pos_logprob - neg_logprob) / beta).mean() - pos_logprob.mean() * 0.1
        # ...
        ###########################################################

        optimizer.zero_grad()
        neg_logprob = compute_logprob(neg_tensor)
        pos_logprob = compute_logprob(pos_tensor)
        loss = (
            -F.logsigmoid((pos_logprob - neg_logprob) / beta).mean()
            - pos_logprob.mean() * 0.1
        )
        loss.backward()
        optimizer.step()
        scheduler.step()
        pbar.set_description(f"epoch {epoch+1}, step {step}, loss {loss.item():.4f}")

    ckpt_path = "./dpo.pt"
    torch.save(
        {
            "model_state_dict": gpt.state_dict(),
            "model_args": ckpt["model_args"],
        },
        ckpt_path,
    )
    print(f"Saved checkpoint to {ckpt_path}")

epoch 1, step 1597, loss 0.0478: : 1598it [08:33,  3.11it/s]


Saved checkpoint to ./dpo.pt


epoch 2, step 1597, loss 0.0254: : 1598it [08:37,  3.09it/s]


Saved checkpoint to ./dpo.pt


epoch 3, step 1597, loss 0.0253: : 1598it [08:35,  3.10it/s]


Saved checkpoint to ./dpo.pt


epoch 4, step 1597, loss 0.0240: : 1598it [08:35,  3.10it/s]


Saved checkpoint to ./dpo.pt


epoch 5, step 1597, loss 0.0229: : 1598it [08:35,  3.10it/s]

Saved checkpoint to ./dpo.pt





### Step 8: Begin testing (**students are required to complete this part!**) (Task 2)

In [22]:
# Load the fine-tuned model
ckpt_path = "../dpo/dpo.pt"
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint["model_args"])
gpt = GPT(gptconf).cuda()
try:
    state_dict = checkpoint["model"]
except KeyError:
    state_dict = checkpoint["model_state_dict"]
unwanted_prefix = "_orig_mod."
for k, v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)
# Test
gpt.eval()
test_set = [
    "17+19=?",
    "3*17=?",
    "72/4=?",
    "72-x=34,x=?",
    "x*11=44,x=?",
    "3*17=?",
    "72/4=?",
    "72-x=34,x=?",
]
with torch.no_grad():
    for prompt in test_set:
        prompt_ids = encode(prompt)
        ###########################################################
        # Please complete the test code here!
        # ...
        # gpt.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
        # ...
        ###########################################################
        input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)
        output_ids = gpt.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
        )
        output_text = decode(output_ids[0].flatten().tolist())
        print(f"Prompt: {prompt}")
        print(f"Model output: {output_text}")
        print("-" * 80)


# We keep the code above to retain the original snippet for documentation purposes.
# However, the test logic should be modular,
# so we wrap it in a function below.


def test_model(test_set):
    with torch.no_grad():
        for prompt in test_set:
            prompt_ids = encode(prompt)
            ###########################################################
            # Please complete the test code here!
            # ...
            # gpt.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
            # ...
            ###########################################################
            input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)
            output_ids = gpt.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
            )
            output_text = decode(output_ids[0].flatten().tolist())
            print(f"Prompt: {prompt}")
            print(f"Model output: {output_text}")
            print("-" * 80)

Prompt: 17+19=?
Model output: 17+19=? The answer is 36 because 17+19 equals 36.
--------------------------------------------------------------------------------
Prompt: 3*17=?
Model output: 3*17=? The answer is 51 because 3*17 equals 51.
--------------------------------------------------------------------------------
Prompt: 72/4=?
Model output: 72/4=? The answer is 18 because 72/4 equals 18.
--------------------------------------------------------------------------------
Prompt: 72-x=34,x=?
Model output: 72-x=34,x=? The answer is 38 because 72-34 equals 38.
--------------------------------------------------------------------------------
Prompt: x*11=44,x=?
Model output: x*11=44,x=? The answer is 4 because 44/11 equals 4.
--------------------------------------------------------------------------------
Prompt: 3*17=?
Model output: 3*17=? The answer is 51 because 3*17 equals 51.
--------------------------------------------------------------------------------
Prompt: 72/4=?
Model output: 

### Results on the default test cases

| Problem | Expected | Model Output | Accuracy |
|---------|----------|--------------|-----|
| 17+19=? | 36 | "The answer is 36 because 17+19 equals 36." | T |
| 3*17=? | 51 | "The answer is 51 because 3*17 equals 51." | T |
| 72/4=? | 18 | "The answer is 18 because 72/4 equals 18." | T |
| x*11=44,x=? | 4 | "The answer is 4 because 44/11 equals 4." | T |
| 3*17=? | 51 | "The answer is 51 because 3*17 equals 51." | T |
| 72/4=? | 18 | "The answer is 18 because 72/4 equals 18." | T |
| 72-x=34,x=? | 38 | "The answer is 38 because 72-34 equals 38." | T |

This is expected, as our training data was generated for numbers with one or two digits. We can verify that the model did not simply memorise these results by searching the training data `pos_neg_pairs.json`: some of the problems do not appear in it.
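
One way to automate this search is sketched below. It assumes each training entry is a dict whose `positive` string starts with the prompt followed by a space, as in the dataset entry printed earlier; `unseen_prompts` is a hypothetical helper. In the notebook it could be called with `dataset["train"]` and `test_set`.

```python
def unseen_prompts(pairs, prompts):
    """Return the prompts that never appear as the prompt of a training example."""
    # The prompt is everything before the first space of the positive string.
    seen = {pair["positive"].split(" ", 1)[0] for pair in pairs}
    return [prompt for prompt in prompts if prompt not in seen]


toy_pairs = [
    {
        "positive": "83+x=98,x=? The answer is 15 because 98-83 equals 15.",
        "negative": "83+x=98,x=? Sorry, I do not know!",
    }
]
print(unseen_prompts(toy_pairs, ["83+x=98,x=?", "17+19=?"]))
```

Any prompt returned by this function was answered without an identical training example, which supports the claim of generalisation for that case.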

However, to further validate the model's generalisation, we test additional cases of our own.
Note that we duplicate each test case to check that the outputs are consistent.
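
The result tables in this notebook were filled in by reading the outputs. As a rough sketch of how the same check could be scripted, the helpers below extract the number after "The answer is" and report both accuracy and the prompts whose duplicated runs disagree. `extract_answer` and `summarise` are hypothetical names, and the regular expression assumes the output format seen in this notebook.

```python
import re
from collections import defaultdict

ANSWER_RE = re.compile(r"The answer is (-?\d+)")


def extract_answer(output_text):
    """Return the integer after 'The answer is', or None if the pattern is absent."""
    match = ANSWER_RE.search(output_text)
    return int(match.group(1)) if match else None


def summarise(results, expected):
    """results: list of (prompt, output_text); expected: dict mapping prompt -> int.

    Returns (accuracy, sorted list of prompts whose repeated runs disagree).
    """
    answers = defaultdict(set)
    correct = 0
    for prompt, text in results:
        answer = extract_answer(text)
        answers[prompt].add(answer)
        correct += (answer == expected[prompt])
    inconsistent = sorted(p for p, seen in answers.items() if len(seen) > 1)
    return correct / len(results), inconsistent


results = [
    ("17+19=?", "17+19=? The answer is 36 because 17+19 equals 36."),
    ("17+19=?", "17+19=? The answer is 58 because 17+19 equals 58."),
]
print(summarise(results, {"17+19=?": 36}))
```

Because the answer is parsed as an integer, an output such as "00" counts as 0.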

#### Multiplication with 0

1. 12345*0=?
2. x*999=0,x=?
3. x*20=0,x=?
4. 20*x=0,x=?
5. 0*12345=?
6. 0*12=?
7. 12*0=?

> 💡 While the training data did try to teach the model this property, it contained no numbers exceeding two digits. We want to see whether the model can generalise the property rather than simply memorise the pattern.

In [7]:
test_set_2 = [
    "12345*0=?",
    "x*999=0,x=?",
    "x*20=0,x=?",
    "20*x=0,x=?",
    "0*12345=?",
    "0*12=?",
    "12*0=?",
    "12345*0=?",
    "x*999=0,x=?",
    "x*20=0,x=?",
    "20*x=0,x=?",
    "0*12345=?",
    "0*12=?",
    "12*0=?",
]


test_model(test_set_2)

Prompt: 12345*0=?
Model output: 12345*0=? The answer is 650 because 125*12 equals 60.
--------------------------------------------------------------------------------
Prompt: x*999=0,x=?
Model output: x*999=0,x=? The answer is 0 because 0/99 equals 09.
--------------------------------------------------------------------------------
Prompt: x*20=0,x=?
Model output: x*20=0,x=? The answer is 0 because 0/20 equals 0.
--------------------------------------------------------------------------------
Prompt: 20*x=0,x=?
Model output: 20*x=0,x=? The answer is 0 because 0/20 equals 0.
--------------------------------------------------------------------------------
Prompt: 0*12345=?
Model output: 0*12345=? The answer is 575 because 07*123 equals 55.
--------------------------------------------------------------------------------
Prompt: 0*12=?
Model output: 0*12=? The answer is 0 because 0*12 equals 0.
--------------------------------------------------------------------------------
Prompt: 12*0=?


#### **Results from Multiplication with 0**

| Problem     | Expected | Model Output                                  | Accuracy |
| ----------- | -------- | --------------------------------------------- | -------- |
| 12345*0=?   | 0        | "The answer is 50 because 125/14 equals 5."   | F        |
| x*999=0,x=? | 0        | "The answer is 1 because 90/99 equals 1."     | F        |
| x*20=0,x=?  | 0        | "The answer is 0 because 0/20 equals 0."      | T        |
| 20*x=0,x=?  | 0        | "The answer is 0 because 0/20 equals 0."      | T        |
| 0*12345=?   | 0        | "The answer is 27 because 04*12 equals 34."   | F        |
| 0*12=?      | 0        | "The answer is 00 because 0*12 equals 00."    | T        |
| 12*0=?      | 0        | "The answer is 0 because 12*0 equals 0."      | T        |
| 12345*0=?   | 0        | "The answer is 50 because 125*14 equals 505." | F        |
| x*999=0,x=? | 0        | "The answer is 1 because 90/99 equals 1."     | F        |
| x*20=0,x=?  | 0        | "The answer is 0 because 0/20 equals 0."      | T        |
| 20*x=0,x=?  | 0        | "The answer is 0 because 0/20 equals 0."      | T        |
| 0*12345=?   | 0        | "The answer is 27 because 04*12 equals 34."   | F        |
| 0*12=?      | 0        | "The answer is 00 because 0*12 equals 00."    | T        |
| 12*0=?      | 0        | "The answer is 0 because 12*0 equals 0."      | T        |

##### **Discussion**

We observe that the model is unable to generalise this seemingly trivial mathematical property to larger numbers. This suggests that the model has been trained too narrowly on numbers between 0 and 100 and cannot generalise to bigger values. In our training dataset, this property was taught by multiplying any number `i` in the range by 0.

However, a positive outcome is that the model generalises the commutative property to two-digit numbers, even though our training data only demonstrated it on a small range of numbers between 1 and 10.

##### **On Language Models**

In general, large language models, from smaller ones like NanoGPT to larger systems such as GPT-4 or GPT-5, are **language models rather than symbolic reasoning engines**. They learn statistical relationships between tokens (words, numbers, symbols) in text rather than the underlying mathematical structures or rules these symbols represent.

Therefore, within the constraints of this task, the most direct way to improve generalisation is to increase the size and variety of the training data.

> 💡 This is because the task expects a particular output format. If this were not the case, the training could include explicit statements such as “any number multiplied by 0 equals 0”, which might have improved the results.

### Commutative Properties

To further support this observation, we test the commutative property on numbers both within and outside our training range.

In [9]:
test_set_3 = [
    "3*4=?",
    "4*3=?",
    "2*5=?",
    "5*2=?",
    "x*7=35,x=?",
    "7*x=35,x=?",
    "12*3=?",
    "3*12=?",
    "x*8=64,x=?",
    "8*x=64,x=?",
    "123*4=?",
    "4*123=?",
    "x*125=1000,x=?",
    "125*x=1000,x=?",
    "x*20=400,x=?",
    "20*x=400,x=?",
    "x*111=333,x=?",
    "111*x=333,x=?",
]

test_model(test_set_3)

Prompt: 3*4=?
Model output: 3*4=? The answer is 12 because 3*4 equals 12.
--------------------------------------------------------------------------------
Prompt: 4*3=?
Model output: 4*3=? The answer is 12 because 4*3 equals 12.
--------------------------------------------------------------------------------
Prompt: 2*5=?
Model output: 2*5=? The answer is 10 because 2*5 equals 10.
--------------------------------------------------------------------------------
Prompt: 5*2=?
Model output: 5*2=? The answer is 10 because 5*2 equals 10.
--------------------------------------------------------------------------------
Prompt: x*7=35,x=?
Model output: x*7=35,x=? The answer is 5 because 35/7 equals 5.
--------------------------------------------------------------------------------
Prompt: 7*x=35,x=?
Model output: 7*x=35,x=? The answer is 5 because 35/7 equals 5.
--------------------------------------------------------------------------------
Prompt: 12*3=?
Model output: 12*3=? The answer is 36

#### Results from Commutative Property Tests

| Problem        | Expected | Model Output                                    | Accuracy |
| -------------- | -------- | ----------------------------------------------- | -------- |
| 3*4=?          | 12       | "The answer is 12 because 3 times 4 equals 12." | T        |
| 4*3=?          | 12       | "The answer is 12 because 4 times 3 equals 12." | T        |
| 2*5=?          | 10       | "The answer is 10 because 2*5 equals 10."       | T        |
| 5*2=?          | 10       | "The answer is 10 because 5*2 equals 10."       | T        |
| x*7=35,x=?     | 5        | "The answer is 5 because 7*5 equals 35."        | T        |
| 7*x=35,x=?     | 5        | "The answer is 5 because 5*7 equals 35."        | T        |
| 12*3=?         | 36       | "The answer is 36 because 12*3 equals 36."      | T        |
| 3*12=?         | 36       | "The answer is 36 because 3*12 equals 36."      | T        |
| x*8=64,x=?     | 8        | "The answer is 8 because 8*8 equals 64."        | T        |
| 8*x=64,x=?     | 8        | "The answer is 8 because 8*8 equals 64."        | T        |
| 123*4=?        | 492      | "The answer is 482 because 123*4 equals 482."   | F        |
| 4*123=?        | 492      | "The answer is 48 because 4*12 equals 48."      | F        |
| x*125=1000,x=? | 8        | "The answer is 8 because 125*8 equals 1000."    | T        |
| 125*x=1000,x=? | 8        | "The answer is 8 because 8*125 equals 1000."    | T        |
| x*20=400,x=?   | 20       | "The answer is 20 because 20*20 equals 400."    | T        |
| 20*x=400,x=?   | 20       | "The answer is 20 because 20*20 equals 400."    | T        |
| x*111=333,x=?  | 3        | "The answer is 3 because 111*3 equals 333."     | T        |
| 111*x=333,x=?  | 3        | "The answer is 3 because 3*111 equals 333."     | T        |

---

##### **Discussion**

The model performs well on small, familiar numbers, giving symmetrical and correct responses for both $a \cdot b$ and $b \cdot a$.
However, accuracy declines for larger numbers (e.g. $123 \cdot 4$), indicating pattern-based rather than rule-based reasoning. This is consistent with our earlier discussion.


### **Blind Tests**

Finally, we perform some blind tests within the familiar number range, as a sanity check that the model performs well.
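
For a test set of this size, the "Expected" column can be computed rather than worked out by hand. The sketch below handles the two prompt shapes used in this notebook (`a op b=?` and one-variable equations ending in `,x=?`); `expected_answer` is a hypothetical helper, and it assumes non-negative integer operands and exact division.

```python
import operator
import re

ARITH = re.compile(r"^(\d+)([+\-*/])(\d+)=\?$")
ALGEBRA = re.compile(r"^(x|\d+)([+\-*/])(x|\d+)=(\d+),x=\?$")
FORWARD = {"+": operator.add, "-": operator.sub, "*": operator.mul, "/": operator.floordiv}
# Inverse operation applied to (c, b) to solve "x op b = c" for x.
INVERSE_LEFT = {"+": operator.sub, "-": operator.add, "*": operator.floordiv, "/": operator.mul}


def expected_answer(prompt):
    """Compute the integer answer for prompts such as '56*34=?' or '75/x=15,x=?'."""
    m = ARITH.match(prompt)
    if m:
        a, op, b = int(m.group(1)), m.group(2), int(m.group(3))
        return FORWARD[op](a, b)
    m = ALGEBRA.match(prompt)
    if m:
        left, op, right, c = m.group(1), m.group(2), m.group(3), int(m.group(4))
        if left == "x" and right != "x":
            return INVERSE_LEFT[op](c, int(right))
        if right == "x" and left != "x":
            a = int(left)
            if op == "+":
                return c - a   # a + x = c
            if op == "-":
                return a - c   # a - x = c
            if op == "*":
                return c // a  # a * x = c
            return a // c      # a / x = c
    raise ValueError(f"Unsupported prompt: {prompt}")


for prompt in ["88+7=?", "x-18=21,x=?", "84/x=12,x=?", "56*34=?"]:
    print(prompt, expected_answer(prompt))
```

Floor division is used throughout, so a prompt whose division is not exact would silently produce a rounded-down value; the test prompts in this notebook all divide evenly.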

In [37]:
test_set_4 = [
    "88+7=?",
    "x-18=21,x=?",
    "x/10=6,x=?",
    "54/1=?",
    "24+48=?",
    "11+23=?",
    "56*34=?",
    "67*23=?",
    "89*11=?",
    "90*12=?",
    "33*44=?",
    "x*12=156,x=?",
    "15*x=300,x=?",
    "x*25=500,x=?",
    "40*x=800,x=?",
    "x*45=2250,x=?",
    "84/12=?",
    "96/24=?",
    "72/18=?",
    "99/11=?",
    "88/22=?",
    "60/15=?",
    "48/16=?",
    "81/27=?",
    "64/16=?",
    "90/30=?",
    "x/12=7,x=?",
    "84/x=12,x=?",
    "x/15=4,x=?",
    "75/x=15,x=?",
    "x/25=9,x=?",
    "12+34=?",
    "23+45=?",
    "56+78=?",
    "11+22=?",
    "99+1=?",
    "45+54=?",
    "67+32=?",
    "89+10=?",
    "74+25=?",
    "33+66=?",
    "x+45=100,x=?",
    "23+x=80,x=?",
    "x+56=78,x=?",
    "x+90=150,x=?",
    "40+x=60,x=?",
    "45-23=?",
    "67-45=?",
    "90-12=?",
    "88-44=?",
    "99-11=?",
    "72-36=?",
    "84-48=?",
    "50-25=?",
    "63-37=?",
    "81-19=?",
    "x-25=25,x=?",
    "75-x=25,x=?",
    "x-50=30,x=?",
    "x-33=11,x=?",
    "90-x=10,x=?",
]

test_model(test_set_4)

Prompt: 88+7=?
Model output: 88+7=? The answer is 95 because 88+7 equals 95.
--------------------------------------------------------------------------------
Prompt: x-18=21,x=?
Model output: x-18=21,x=? The answer is 39 because 21+18 equals 39.
--------------------------------------------------------------------------------
Prompt: x/10=6,x=?
Model output: x/10=6,x=? The answer is 60 because 6*10 equals 60.
--------------------------------------------------------------------------------
Prompt: 54/1=?
Model output: 54/1=? The answer is 54 because 54/1 equals 54.
--------------------------------------------------------------------------------
Prompt: 24+48=?
Model output: 24+48=? The answer is 72 because 24+48 equals 72.
--------------------------------------------------------------------------------
Prompt: 11+23=?
Model output: 11+23=? The answer is 34 because 11+23 equals 34.
--------------------------------------------------------------------------------
Prompt: 56*34=?
Model outpu

#### Results from blind tests

| Problem       | Expected | Model Output                                  | Accuracy |
| ------------- | -------- | --------------------------------------------- | -------- |
| 88+7=?        | 95       | "The answer is 95 because 88+7 equals 95."    | T        |
| x-18=21,x=?   | 39       | "The answer is 39 because 21+18 equals 39."   | T        |
| x/10=6,x=?    | 60       | "The answer is 60 because 6*10 equals 60."    | T        |
| 54/1=?        | 54       | "The answer is 54 because 54/1 equals 54."    | T        |
| 24+48=?       | 72       | "The answer is 72 because 24+48 equals 72."   | T        |
| 11+23=?       | 34       | "The answer is 34 because 11+23 equals 34."   | T        |
| 64+13=?       | 77       | "The answer is 77 because 64+13 equals 77."   | T        |
| 12*13=?       | 156      | "The answer is 156 because 12*13 equals 156." | T        |
| 14*15=?       | 210      | "The answer is 210 because 14*15 equals 210." | T        |
| 21*22=?       | 462      | "The answer is 442 because 21*22 equals 442." | F        |
| 23*32=?       | 736      | "The answer is 666 because 23*32 equals 666." | F        |
| 45*12=?       | 540      | "The answer is 500 because 45*12 equals 500." | F        |
| 56*34=?       | 1904     | "The answer is 142 because 56*34 equals 142." | F        |
| 67*23=?       | 1541     | "The answer is 122 because 67*23 equals 122." | F        |
| 89*11=?       | 979      | "The answer is 109 because 89*11 equals 109." | F        |
| 90*12=?       | 1080     | "The answer is 1 because 90*12 equals 1."     | F        |
| 33*44=?       | 1452     | "The answer is 111 because 33*44 equals 111." | F        |
| x*12=156,x=?  | 13       | "The answer is 13 because 156/12 equals 13."  | T        |
| 15*x=300,x=?  | 20       | "The answer is 20 because 300/15 equals 20."  | T        |
| x*25=500,x=?  | 20       | "The answer is 20 because 500/25 equals 20."  | T        |
| 40*x=800,x=?  | 20       | "The answer is 2 because 800/40 equals 2."    | F        |
| x*45=2250,x=? | 50       | "The answer is 5 because 225/45 equals 55."   | F        |
| 84/12=?       | 7        | "The answer is 7 because 84/12 equals 7."     | T        |
| 96/24=?       | 4        | "The answer is 4 because 96/24 equals 4."     | T        |
| 72/18=?       | 4        | "The answer is 4 because 72/18 equals 4."     | T        |
| 99/11=?       | 9        | "The answer is 9 because 99/11 equals 9."     | T        |
| 88/22=?       | 4        | "The answer is 4 because 88/22 equals 4."     | T        |
| 60/15=?       | 4        | "The answer is 4 because 60/15 equals 4."     | T        |
| 48/16=?       | 3        | "The answer is 3 because 48/16 equals 3."     | T        |
| 81/27=?       | 3        | "The answer is 3 because 81/27 equals 3."     | T        |
| 64/16=?       | 4        | "The answer is 4 because 64/16 equals 4."     | T        |
| 90/30=?       | 3        | "The answer is 3 because 90/30 equals 3."     | T        |
| x/12=7,x=?    | 84       | "The answer is 84 because 7*12 equals 84."    | T        |
| 84/x=12,x=?   | 7        | "The answer is 7 because 84/12 equals 7."     | T        |
| x/15=4,x=?    | 60       | "The answer is 60 because 4*15 equals 60."    | T        |
| 75/x=15,x=?   | 5        | "The answer is 5 because 75/15 equals 5."     | T        |
| x/25=9,x=?    | 225      | "The answer is 35 because 9*25 equals 35."    | F        |
| 12+34=?       | 46       | "The answer is 46 because 12+34 equals 46."   | T        |
| 23+45=?       | 68       | "The answer is 68 because 23+45 equals 68."   | T        |
| 56+78=?       | 134      | "The answer is 134 because 56+78 equals 134." | T        |
| 11+22=?       | 33       | "The answer is 33 because 11+22 equals 33."   | T        |
| 99+1=?        | 100      | "The answer is 90 because 99+1 equals 90."    | F        |
| 45+54=?       | 99       | "The answer is 99 because 45+54 equals 99."   | T        |
| 67+32=?       | 99       | "The answer is 99 because 67+32 equals 99."   | T        |
| 89+10=?       | 99       | "The answer is 99 because 89+10 equals 99."   | T        |
| 74+25=?       | 99       | "The answer is 109 because 74+25 equals 109." | F        |
| 33+66=?       | 99       | "The answer is 99 because 33+66 equals 99."   | T        |
| x+45=100,x=?  | 55       | "The answer is 55 because 100-45 equals 55."  | T        |
| 23+x=80,x=?   | 57       | "The answer is 57 because 80+23 equals 57."   | F        |
| x+56=78,x=?   | 22       | "The answer is 22 because 78-56 equals 22."   | T        |
| x+90=150,x=?  | 60       | "The answer is 60 because 150-90 equals 60."  | T        |
| 40+x=60,x=?   | 20       | "The answer is 20 because 60+40 equals 20."   | F        |
| 45-23=?       | 22       | "The answer is 22 because 45-23 equals 22."   | T        |
| 67-45=?       | 22       | "The answer is 22 because 67-45 equals 22."   | T        |
| 90-12=?       | 78       | "The answer is 78 because 90-12 equals 78."   | T        |
| 88-44=?       | 44       | "The answer is 44 because 88-44 equals 44."   | T        |
| 99-11=?       | 88       | "The answer is 88 because 99-11 equals 88."   | T        |
| 72-36=?       | 36       | "The answer is 36 because 72-36 equals 36."   | T        |
| 84-48=?       | 36       | "The answer is 36 because 84-48 equals 36."   | T        |
| 50-25=?       | 25       | "The answer is 25 because 50-25 equals 25."   | T        |
| 63-37=?       | 26       | "The answer is 26 because 63-37 equals 26."   | T        |
| 81-19=?       | 62       | "The answer is 62 because 81-19 equals 62."   | T        |
| x-25=25,x=?   | 50       | "The answer is 50 because 25+25 equals 50."   | T        |
| 75-x=25,x=?   | 50       | "The answer is 50 because 75-25 equals 50."   | T        |
| x-50=30,x=?   | 80       | "The answer is 80 because 30+50 equals 80."   | T        |
| x-33=11,x=?   | 44       | "The answer is 44 because 11+33 equals 44."   | T        |
| 90-x=10,x=?   | 80       | "The answer is 80 because 90-10 equals 80."   | T        |
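
The T/F column can also be checked mechanically. The sketch below is one possible grader, not the procedure used to produce the table. It assumes prompts follow the two shapes seen in `test_set_4` (direct problems such as `88+7=?` and single-unknown equations such as `x-18=21,x=?`) and that outputs follow the "The answer is N because A op B equals N." template. The names `apply_op`, `expected_answer`, and `grade` are hypothetical helpers. A row passes only if the stated answer, the restated result, and the arithmetic in the explanation all agree with the expected value, which is how rows such as `23+x=80,x=?` end up marked F.

```python
import re

INVERSE = {"+": "-", "-": "+", "*": "/", "/": "*"}
EXPLANATION = re.compile(
    r"The answer is (\d+) because (\d+)([-+*/])(\d+) equals (\d+)\."
)


def apply_op(a, op, b):
    """Integer arithmetic; returns None when a division is not exact."""
    if op == "+":
        return a + b
    if op == "-":
        return a - b
    if op == "*":
        return a * b
    if b == 0 or a % b != 0:
        return None
    return a // b


def expected_answer(prompt):
    """Solve 'a op b=?' or a one-unknown equation ending in ',x=?'."""
    m = re.fullmatch(r"(\d+)([-+*/])(\d+)=\?", prompt)
    if m:
        return apply_op(int(m.group(1)), m.group(2), int(m.group(3)))
    m = re.fullmatch(r"(x|\d+)([-+*/])(x|\d+)=(\d+),x=\?", prompt)
    if m is None or (m.group(1) == "x") == (m.group(3) == "x"):
        raise ValueError(f"unsupported prompt: {prompt!r}")
    left, op, right, c = m.group(1), m.group(2), m.group(3), int(m.group(4))
    if left == "x":
        # x op b = c  ->  x = c (inverse op) b
        return apply_op(c, INVERSE[op], int(right))
    a = int(left)
    if op in "+*":
        # a op x = c with a commutative operator  ->  x = c (inverse op) a
        return apply_op(c, INVERSE[op], a)
    # a - x = c  ->  x = a - c ;  a / x = c  ->  x = a / c
    return apply_op(a, op, c)


def grade(prompt, output):
    """True only if the answer, the explanation, and its result all match."""
    m = EXPLANATION.search(output)
    if m is None:
        return False
    answer, a, op = int(m.group(1)), int(m.group(2)), m.group(3)
    b, stated = int(m.group(4)), int(m.group(5))
    expected = expected_answer(prompt)
    return (
        answer == expected
        and stated == expected
        and apply_op(a, op, b) == expected
    )
```

Prompts outside these two shapes (for example, ones missing the `,x=?` suffix) raise `ValueError` in this sketch and would need to be graded by hand.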

---

#### **Discussion**

*Overall Performance*

Generally, the model performs satisfactorily on the number range that we expect.

*On Multiplication and Division*

| Problem       | Expected | Model Output                                  | Accuracy |
| ------------- | -------- | --------------------------------------------- | -------- |
| 21*22=?       | 462      | "The answer is 442 because 21*22 equals 442." | F        |
| 23*32=?       | 736      | "The answer is 666 because 23*32 equals 666." | F        |
| 45*12=?       | 540      | "The answer is 500 because 45*12 equals 500." | F        |
| 56*34=?       | 1904     | "The answer is 142 because 56*34 equals 142." | F        |
| 67*23=?       | 1541     | "The answer is 122 because 67*23 equals 122." | F        |
| 89*11=?       | 979      | "The answer is 109 because 89*11 equals 109." | F        |
| 90*12=?       | 1080     | "The answer is 1 because 90*12 equals 1."     | F        |
| 33*44=?       | 1452     | "The answer is 111 because 33*44 equals 111." | F        |
| 40*x=800,x=?  | 20       | "The answer is 2 because 800/40 equals 2."    | F        |
| x*45=2250,x=? | 50       | "The answer is 5 because 225/45 equals 55."   | F        |

Although the model appears to perform poorly on multiplication and some division problems, this is expected. When generating the training data, we wanted the range of numbers explored to be **dense**. The difficulty with training on multiplication and addition together is that the results of multiplication grow far faster than those of addition, so running many random operations over both does not lead to a well-trained model. (*Please see the explanations in the training section.*)
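
To make the density argument concrete, the sketch below counts how many distinct results each operation produces and how wide a range those results span. The operand range of 0 to 99 and the `coverage` helper are assumptions chosen for illustration, not the ranges used by the generation scripts.

```python
from itertools import product


def coverage(op, lo=0, hi=99):
    """Return (distinct results, width of result range, fraction covered)."""
    results = {op(a, b) for a, b in product(range(lo, hi + 1), repeat=2)}
    span = max(results) - min(results) + 1
    return len(results), span, len(results) / span


add_distinct, add_span, add_density = coverage(lambda a, b: a + b)
mul_distinct, mul_span, mul_density = coverage(lambda a, b: a * b)

print(f"addition:       {add_distinct} distinct results over a span of {add_span}")
print(f"multiplication: {mul_distinct} distinct results over a span of {mul_span}")
```

Under these assumptions every value in the addition range is reachable, whereas multiplication results are scattered over a range roughly fifty times wider, so the same number of random samples covers it far more thinly.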

In [36]:
test_x_on_right = [
    "30+x=55?",
    "90-x=40",
    "8*x=64",
    "72/x=9",
    "20+x=50",
    "81/x=9",
    "15*x=45",
    "92-x=60",
    "25+x=90",
    "36/x=6",
]
test_model(test_x_on_right)

Prompt: 30+x=55?
Model output: 30+x=55? The answer is 45 because 55-30 equals 455.
--------------------------------------------------------------------------------
Prompt: 90-x=40
Model output: 90-x=40,x=? The answer is 50 because 90-40 equals 50.
--------------------------------------------------------------------------------
Prompt: 8*x=64
Model output: 8*x=64,x=? The answer is 8 because 64/8 equals 8.
--------------------------------------------------------------------------------
Prompt: 72/x=9
Model output: 72/x=9,x=? The answer is 8 because 72/9 equals 8.
--------------------------------------------------------------------------------
Prompt: 20+x=50
Model output: 20+x=50,x=? The answer is 30 because 50-20 equals 30.
--------------------------------------------------------------------------------
Prompt: 81/x=9
Model output: 81/x=9,x=? The answer is 9 because 81/9 equals 9.
--------------------------------------------------------------------------------
Prompt: 15*x=45
Model outp

In [32]:
test_x_on_left = [
    "x+30=55?",
    "x-40=20",
    "x*8=64",
    "x/9=8",
    "x+25=60",
    "x/3=9",
    "x*5=45",
    "x-32=60",
    "x+15=90",
    "x/6=6",
]
test_model(test_x_on_left)

Prompt: x+30=55?
Model output: x+30=55? The answer is 35 because 55-30 equals 355.
--------------------------------------------------------------------------------
Prompt: x-40=20
Model output: x-40=20,x=? The answer is 60 because 20+40 equals 60.
--------------------------------------------------------------------------------
Prompt: x*8=64
Model output: x*8=64,x=? The answer is 8 because 64/8 equals 8.
--------------------------------------------------------------------------------
Prompt: x/9=8
Model output: x/9=8,x=? The answer is 72 because 8*9 equals 72.
--------------------------------------------------------------------------------
Prompt: x+25=60
Model output: x+25=60,x=? The answer is 35 because 60-25 equals 35.
--------------------------------------------------------------------------------
Prompt: x/3=9
Model output: x/3=9,x=? The answer is 27 because 9*3 equals 27.
--------------------------------------------------------------------------------
Prompt: x*5=45
Model output:

### Extra: Model Improvement

While the current training results are satisfactory, the model’s performance can be further enhanced through an additional fine-tuning phase on top of the existing checkpoint. This stage aims to strengthen the model’s understanding of certain arithmetic and algebraic patterns while maintaining knowledge from previous training.

The objectives of this phase are as follows:

1. Improve performance for algebraic problems involving addition with x.
   However, the dataset should not consist solely of such examples. Overexposure to one operation may cause the model to associate x too strongly with addition, leading to incorrect generalization or memorization of fixed patterns. To counter this, other operations will also be included in moderation.

2. Enable better generalization for simple three-digit additions and subtractions.
   For addition, “simple” refers to cases where one operand has trailing zeros, such as 300+21=321. These are easier for the model to learn because the positional structure of digits remains clear. Nevertheless, the dataset will also include a smaller proportion of non-zero-ending numbers to ensure the model does not assume that all three-digit additions follow a predictable replacement pattern.
   For subtraction, examples will be constructed such that the result is nonnegative, maintaining consistency with the constraints of previous datasets.

3. Improve generalization for two-digit multiplication and division.
   These operations tend to be more challenging due to their broader numeric range and variability in results. Including additional training examples for them helps the model develop a stronger grasp of multiplicative and divisive relationships, ensuring balanced performance across all basic arithmetic operators.
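
The second objective can be sketched as follows. This is a minimal illustration rather than the actual generation script: the helper names, the 80% share of zero-ending operands, and the operand ranges are all hypothetical. Only the positive/negative pair layout is taken from the dataset sample printed further down in this notebook.

```python
import random


def make_pair(a, op, b, result):
    """Build one preference pair in the positive/negative layout."""
    prompt = f"{a}{op}{b}=?"
    return {
        "positive": f"{prompt} The answer is {result} because {a}{op}{b} equals {result}.",
        "negative": f"{prompt} Sorry, I do not know!",
    }


def simple_three_digit_addition(rng, zero_ending_share=0.8):
    """Mostly 'simple' sums such as 300+21, plus a minority of arbitrary ones."""
    if rng.random() < zero_ending_share:
        a = rng.randrange(100, 1000, 100)  # 100, 200, ..., 900
    else:
        a = rng.randint(100, 999)  # arbitrary three-digit operand
    b = rng.randint(1, 99)
    if rng.random() < 0.5:
        a, b = b, a  # vary which side holds the three-digit operand
    return make_pair(a, "+", b, a + b)


def nonnegative_subtraction(rng):
    """Three-digit minuend with a subtrahend that never exceeds it."""
    a = rng.randint(100, 999)
    b = rng.randint(0, a)
    return make_pair(a, "-", b, a - b)


rng = random.Random(0)  # fixed seed so the sample is reproducible
print(simple_three_digit_addition(rng)["positive"])
print(nonnegative_subtraction(rng)["positive"])
```

Passing an explicit `random.Random` instance keeps generation reproducible without touching the global random state.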

To prevent the model from forgetting knowledge it has already learned, a portion of the original dataset will be reused. Specifically, 30% of the data (around 30,000 samples) will be drawn from `pos_neg_pair.json`. Random extraction will not be used, as this could lead to uneven representation of different operations. Instead, the same data generation functions and random seeds from earlier scripts will be reused to ensure balanced sampling and reproducibility.

The remaining 70% (approximately 70,000 samples) will consist of newly generated data that emphasizes the updated training goals described above. This combination provides both continuity and novelty, allowing the model to retain prior knowledge while acquiring new patterns.
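
The 30/70 split can be planned up front so that the replayed portion stays balanced across operations. The sketch below is an assumption about how such a plan might look. `split_counts` and `mix` are hypothetical helpers, and the even per-operation split and the seed value are illustrative choices.

```python
import random

OPS = ["+", "-", "*", "/"]


def split_counts(total, replay_share=0.3, ops=OPS):
    """Return (replay samples per operation, number of new samples)."""
    replay_total = round(total * replay_share)
    per_op, remainder = divmod(replay_total, len(ops))
    # Spread any remainder over the first few operations.
    replay = {op: per_op + (1 if i < remainder else 0) for i, op in enumerate(ops)}
    return replay, total - replay_total


def mix(replay_examples, new_examples, seed=42):
    """Concatenate both parts and shuffle them with a fixed seed."""
    mixed = list(replay_examples) + list(new_examples)
    random.Random(seed).shuffle(mixed)
    return mixed


replay_plan, new_count = split_counts(100_000)
print(replay_plan, new_count)
```

For a target of 100,000 samples this allots 7,500 replayed examples to each of the four operations and leaves 70,000 slots for newly generated data.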

Finally, adjustments will be made to the optimizer and learning rate scheduler. These changes will help stabilize the fine-tuning process, ensuring that the model adapts smoothly to the new data without overfitting or experiencing abrupt changes in performance.
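
As a rough picture of such a schedule, the sketch below evaluates a linear warmup followed by cosine decay in closed form, without PyTorch. It assumes the usual closed forms for these two schedules. The default of 7,900 total steps is inferred from the training log (1,580 steps per epoch over 5 epochs) and is otherwise hypothetical, since `batch_size` and `epochs` are defined elsewhere. `lr_at` is an illustrative helper, not part of the training code.

```python
import math


def lr_at(step, base_lr=5e-5, eta_min=1e-6, warmup_steps=100,
          start_factor=0.2, t_max=7900):
    """Learning rate at a given optimizer step (closed-form approximation)."""
    if step < warmup_steps:
        # Linear warmup from start_factor * base_lr up to base_lr.
        factor = start_factor + (1.0 - start_factor) * step / warmup_steps
        return base_lr * factor
    # Cosine decay restarts its own counter once warmup has finished.
    t = step - warmup_steps
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2


for step in (0, 50, 100, 4000, 7900):
    print(step, f"{lr_at(step):.2e}")
```

Because the cosine phase only begins after the warmup steps, a run that stops at step `t_max` ends slightly above `eta_min` under these assumptions.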

In [39]:
from datasets import load_dataset

dataset2 = load_dataset("json", data_files="./test3.json").shuffle(seed=42)

print(dataset2)
print(dataset2["train"][0])

optimizer = torch.optim.AdamW(gpt.parameters(), lr=5e-5, weight_decay=5e-3)
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

iteration = len(dataset2["train"]) // batch_size * epochs
warmup = LinearLR(optimizer, start_factor=0.2, total_iters=100)
main_sched = CosineAnnealingLR(optimizer, T_max=iteration, eta_min=1e-6)
scheduler = SequentialLR(optimizer, schedulers=[warmup, main_sched], milestones=[100])

from random import shuffle

lines = dataset2["train"]
lines = [dict(x) for x in lines]
total_steps = len(lines) // batch_size
for epoch in range(epochs):
    shuffle(lines)
    pbar = tqdm(get_batches(lines, batch_size))
    for step, (neg_tensor, pos_tensor) in enumerate(pbar):
        ###########################################################
        # Please complete the training code here!
        # Examples:
        # ...
        # neg_logprob
        # pos_logprob
        # loss = -F.logsigmoid((pos_logprob - neg_logprob) / beta).mean() - pos_logprob.mean() * 0.1
        # ...
        ###########################################################

        optimizer.zero_grad()
        neg_logprob = compute_logprob(neg_tensor)
        pos_logprob = compute_logprob(pos_tensor)
        loss = (
            -F.logsigmoid((pos_logprob - neg_logprob) / beta).mean()
            - pos_logprob.mean() * 0.1
        )
        loss.backward()
        optimizer.step()
        scheduler.step()
        pbar.set_description(f"epoch {epoch+1}, step {step}, loss {loss.item():.4f}")

    ckpt_path = "./dpo.pt"
    torch.save(
        {
            "model_state_dict": gpt.state_dict(),
            "model_args": ckpt["model_args"],
        },
        ckpt_path,
    )
    print(f"Saved checkpoint to {ckpt_path}")

Generating train split: 0 examples [00:00, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['positive', 'negative'],
        num_rows: 101160
    })
})
{'positive': '84/21=? The answer is 4 because 84/21 equals 4.', 'negative': '84/21=? Sorry, I do not know!'}


epoch 1, step 1579, loss 0.0220: : 1580it [08:17,  3.18it/s]


Saved checkpoint to ./dpo.pt


epoch 2, step 1579, loss 0.0180: : 1580it [08:21,  3.15it/s]


Saved checkpoint to ./dpo.pt


epoch 3, step 1579, loss 0.0183: : 1580it [08:21,  3.15it/s]


Saved checkpoint to ./dpo.pt


epoch 4, step 1579, loss 0.0170: : 1580it [08:21,  3.15it/s]


Saved checkpoint to ./dpo.pt


epoch 5, step 1579, loss 0.0189: : 1580it [08:21,  3.15it/s]

Saved checkpoint to ./dpo.pt
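
To show what the loss in the training loop above rewards, the following sketch evaluates it for a single positive/negative pair using plain Python. The log-probabilities and the `beta` value are made-up numbers (the real `beta` is set earlier in the notebook), and `pair_loss` is a hypothetical helper that mirrors the loop's formula, including its division by `beta` and the 0.1 weight on the positive log-probability.

```python
import math


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def pair_loss(pos_logprob, neg_logprob, beta, sft_weight=0.1):
    """Preference term plus a small likelihood term on the positive sample."""
    preference = -log_sigmoid((pos_logprob - neg_logprob) / beta)
    likelihood = -sft_weight * pos_logprob
    return preference + likelihood


# Positive answer clearly preferred: the preference term is close to zero.
print(pair_loss(pos_logprob=-1.0, neg_logprob=-5.0, beta=0.1))
# Negative answer preferred: the loss is large.
print(pair_loss(pos_logprob=-5.0, neg_logprob=-1.0, beta=0.1))
```

The first term shrinks as the positive completion becomes more likely than the negative one, while the second term keeps pushing up the likelihood of the positive completion itself.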





In [40]:
# Load the fine-tuned model
ckpt_path = "../dpo/dpo.pt"
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint["model_args"])
gpt = GPT(gptconf).to(device)
try:
    state_dict = checkpoint["model"]
except KeyError:
    state_dict = checkpoint["model_state_dict"]
unwanted_prefix = "_orig_mod."
for k, v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)
# Test
gpt.eval()

# The loading code above is kept as-is to preserve the original snippet for documentation purposes.
# However, the testing logic should be modular,
# so we wrap it in a reusable function.


def test_model(test_set):
    with torch.no_grad():
        for prompt in test_set:
            prompt_ids = encode(prompt)
            ###########################################################
            # Please complete the test code here!
            # ...
            # gpt.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
            # ...
            ###########################################################
            input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)
            output_ids = gpt.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
            )
            output_text = decode(output_ids[0].flatten().tolist())
            print(f"Prompt: {prompt}")
            print(f"Model output: {output_text}")
            print("-" * 80)


test2_set_1 = [
    "17+19=?",
    "3*17=?",
    "72/4=?",
    "72-x=34,x=?",
    "x*11=44,x=?",
    "3*17=?",
    "72/4=?",
    "72-x=34,x=?",
]

test_model(test2_set_1)

Prompt: 17+19=?
Model output: 17+19=? The answer is 36 because 17+19 equals 36.
--------------------------------------------------------------------------------
Prompt: 3*17=?
Model output: 3*17=? The answer is 51 because 3*17 equals 51.
--------------------------------------------------------------------------------
Prompt: 72/4=?
Model output: 72/4=? The answer is 18 because 72/4 equals 18.
--------------------------------------------------------------------------------
Prompt: 72-x=34,x=?
Model output: 72-x=34,x=? The answer is 38 because 72-34 equals 38.
--------------------------------------------------------------------------------
Prompt: x*11=44,x=?
Model output: x*11=44,x=? The answer is 4 because 44/11 equals 4.
--------------------------------------------------------------------------------
Prompt: 3*17=?
Model output: 3*17=? The answer is 51 because 3*17 equals 51.
--------------------------------------------------------------------------------
Prompt: 72/4=?
Model output: 

In [41]:
test_model(test_set_4)

Prompt: 88+7=?
Model output: 88+7=? The answer is 95 because 88+7 equals 95.
--------------------------------------------------------------------------------
Prompt: x-18=21,x=?
Model output: x-18=21,x=? The answer is 39 because 21+18 equals 39.
--------------------------------------------------------------------------------
Prompt: x/10=6,x=?
Model output: x/10=6,x=? The answer is 60 because 6*10 equals 60.
--------------------------------------------------------------------------------
Prompt: 54/1=?
Model output: 54/1=? The answer is 54 because 54/1 equals 54.
--------------------------------------------------------------------------------
Prompt: 24+48=?
Model output: 24+48=? The answer is 72 because 24+48 equals 72.
--------------------------------------------------------------------------------
Prompt: 11+23=?
Model output: 11+23=? The answer is 34 because 11+23 equals 34.
--------------------------------------------------------------------------------
Prompt: 56*34=?
Model outpu

In [43]:
test_model(test_x_on_left)

Prompt: x+30=55?
Model output: x+30=55? The answer is 45 because 55-30 equals 455.
--------------------------------------------------------------------------------
Prompt: x-40=20
Model output: x-40=200,x=? The answer is 240 because 200+40 equals 240.
--------------------------------------------------------------------------------
Prompt: x*8=64
Model output: x*8=64,x=? The answer is 8 because 64/8 equals 8.
--------------------------------------------------------------------------------
Prompt: x/9=8
Model output: x/9=8,x=? The answer is 72 because 8*9 equals 72.
--------------------------------------------------------------------------------
Prompt: x+25=60
Model output: x+25=60,x=? The answer is 35 because 60-25 equals 35.
--------------------------------------------------------------------------------
Prompt: x/3=9
Model output: x/3=9,x=? The answer is 27 because 9*3 equals 27.
--------------------------------------------------------------------------------
Prompt: x*5=45
Model out

In [5]:
import sys
import os

sys.path.append(os.path.abspath(".."))
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import pickle
from model import GPT, GPTConfig
from tqdm import tqdm
import time
import json
import matplotlib.pyplot as plt

device = "cuda" if torch.cuda.is_available() else "cpu"
max_new_tokens = 200
temperature = 0.8
top_k = 200


with open("../sft/meta.pkl", "rb") as f:
    meta = pickle.load(f)
stoi, itos = meta["stoi"], meta["itos"]


def encode(s):
    return [stoi.get(c, 0) for c in s]  # 0 = <unk> for unknown characters


def decode(ids):
    return "".join(itos[i] for i in ids)


ckpt_path = "../dpo/dpo.pt"
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint["model_args"])
gpt = GPT(gptconf).to(device)
try:
    state_dict = checkpoint["model"]
except KeyError:
    state_dict = checkpoint["model_state_dict"]
unwanted_prefix = "_orig_mod."
for k, v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)
# Test
gpt.eval()

print(f"Loaded DPO model from {ckpt_path}")


def test_model(test_set):
    with torch.no_grad():
        for prompt in test_set:
            prompt_ids = encode(prompt)
            input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)
            output_ids = gpt.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
            )
            output_text = decode(output_ids[0].flatten().tolist())
            print(f"Prompt: {prompt}")
            print(f"Model output: {output_text}")
            print("-" * 80)

Loaded DPO model from ../dpo/dpo.pt


In [13]:
test_set_4 = [
    "88+7=?",
    "x-18=21,x=?",
    "x/10=6,x=?",
    "54/1=?",
    "24+48=?",
    "11+23=?",
    "56*34=?",
    "67*23=?",
    "89*11=?",
    "90*12=?",
    "33*44=?",
    "x*12=156,x=?",
    "15*x=300,x=?",
    "x*25=500,x=?",
    "40*x=800,x=?",
    "x*45=2250,x=?",
    "84/12=?",
    "96/24=?",
    "72/18=?",
    "99/11=?",
    "88/22=?",
    "60/15=?",
    "48/16=?",
    "81/27=?",
    "64/16=?",
    "90/30=?",
    "x/12=7,x=?",
    "84/x=12,x=?",
    "x/15=4,x=?",
    "75/x=15,x=?",
    "x/25=9,x=?",
    "12+34=?",
    "23+45=?",
    "56+78=?",
    "11+22=?",
    "99+1=?",
    "45+54=?",
    "67+32=?",
    "89+10=?",
    "74+25=?",
    "33+66=?",
    "x+45=100,x=?",
    "23+x=80,x=?",
    "x+56=78,x=?",
    "x+90=150,x=?",
    "40+x=60,x=?",
    "45-23=?",
    "67-45=?",
    "90-12=?",
    "88-44=?",
    "99-11=?",
    "72-36=?",
    "84-48=?",
    "50-25=?",
    "63-37=?",
    "81-19=?",
    "x-25=25,x=?",
    "75-x=25,x=?",
    "x-50=30,x=?",
    "x-33=11,x=?",
    "90-x=10,x=?",
]

test_model(test_set_4)

Prompt: 88+7=?
Model output: 88+7=? The answer is 95 because 88+7 equals 95.
--------------------------------------------------------------------------------
Prompt: x-18=21,x=?
Model output: x-18=21,x=? The answer is 39 because 21+18 equals 39.
--------------------------------------------------------------------------------
Prompt: x/10=6,x=?
Model output: x/10=6,x=? The answer is 60 because 6*10 equals 60.
--------------------------------------------------------------------------------
Prompt: 54/1=?
Model output: 54/1=? The answer is 54 because 54/1 equals 54.
--------------------------------------------------------------------------------
Prompt: 24+48=?
Model output: 24+48=? The answer is 72 because 24+48 equals 72.
--------------------------------------------------------------------------------
Prompt: 11+23=?
Model output: 11+23=? The answer is 34 because 11+23 equals 34.
--------------------------------------------------------------------------------
Prompt: 56*34=?
Model outpu

In [8]:
blind_test_set = [
    # --- Two-digit: addition ---
    "17+19=?",
    "45+27=?",
    "63+18=?",
    "29+47=?",
    "33+66=?",
    "56+42=?",
    "78+11=?",
    "67+25=?",
    # --- Two-digit: subtraction ---
    "92-38=?",
    "63-28=?",
    "99-87=?",
    "81-45=?",
    "72-39=?",
    "70-12=?",
    "84-56=?",
    "90-x=12,x=?",
    # --- Two-digit: multiplication ---
    "3*17=?",
    "12*7=?",
    "14*9=?",
    "8*11=?",
    "6*13=?",
    "9*15=?",
    "x*11=44,x=?",
    "x*8=64,x=?",
    # --- Two-digit: division ---
    "72/4=?",
    "84/6=?",
    "96/12=?",
    "81/9=?",
    "66/11=?",
    "90/15=?",
    "x/9=8,x=?",
    "x/6=7,x=?",
    # --- Simple equations / mixed reasoning ---
    "72-x=34,x=?",
    "x+36=81,x=?",
    "x+45=99,x=?",
    "x*12=96,x=?",
    # --- Three-digit: expansion domain ---
    "234+187=?",
    "905-476=?",
    "123*12=?",
    "672/8=?",
    "x+120=230,x=?",
]

test_model(blind_test_set)

Prompt: 17+19=?
Model output: 17+19=? The answer is 36 because 17+19 equals 36.
--------------------------------------------------------------------------------
Prompt: 45+27=?
Model output: 45+27=? The answer is 72 because 45+27 equals 72.
--------------------------------------------------------------------------------
Prompt: 63+18=?
Model output: 63+18=? The answer is 81 because 63+18 equals 81.
--------------------------------------------------------------------------------
Prompt: 29+47=?
Model output: 29+47=? The answer is 76 because 29+47 equals 76.
--------------------------------------------------------------------------------
Prompt: 33+66=?
Model output: 33+66=? The answer is 99 because 33+66 equals 99.
--------------------------------------------------------------------------------
Prompt: 56+42=?
Model output: 56+42=? The answer is 98 because 56+42 equals 98.
--------------------------------------------------------------------------------
Prompt: 78+11=?
Model output: 78+1

In [7]:
test2_set_1 = [
    "17+19=?",
    "3*17=?",
    "72/4=?",
    "72-x=34,x=?",
    "x*11=44,x=?",
    "3*17=?",
    "72/4=?",
    "72-x=34,x=?",
]

test_model(test2_set_1)

Prompt: 17+19=?
Model output: 17+19=? The answer is 36 because 17+19 equals 36.
--------------------------------------------------------------------------------
Prompt: 3*17=?
Model output: 3*17=? The answer is 51 because 3*17 equals 51.
--------------------------------------------------------------------------------
Prompt: 72/4=?
Model output: 72/4=? The answer is 18 because 72/4 equals 18.
--------------------------------------------------------------------------------
Prompt: 72-x=34,x=?
Model output: 72-x=34,x=? The answer is 38 because 72-34 equals 38.
--------------------------------------------------------------------------------
Prompt: x*11=44,x=?
Model output: x*11=44,x=? The answer is 4 because 44/11 equals 4.
--------------------------------------------------------------------------------
Prompt: 3*17=?
Model output: 3*17=? The answer is 51 because 3*17 equals 51.
--------------------------------------------------------------------------------
Prompt: 72/4=?
Model output: 

In [15]:
test_set_2 = [
    "88+7=?",
    "x-18=21,x=?",
    "x/10=6,x=?",
    "54/1=?",
    "24+48=?",
    "11+23=?",
    "64+13=?",
]


test_model(test_set_2)

Prompt: 88+7=?
Model output: 88+7=? The answer is 95 because 88+7 equals 95.
--------------------------------------------------------------------------------
Prompt: x-18=21,x=?
Model output: x-18=21,x=? The answer is 39 because 21+18 equals 39.
--------------------------------------------------------------------------------
Prompt: x/10=6,x=?
Model output: x/10=6,x=? The answer is 60 because 6*10 equals 60.
--------------------------------------------------------------------------------
Prompt: 54/1=?
Model output: 54/1=? The answer is 54 because 54/1 equals 54.
--------------------------------------------------------------------------------
Prompt: 24+48=?
Model output: 24+48=? The answer is 72 because 24+48 equals 72.
--------------------------------------------------------------------------------
Prompt: 11+23=?
Model output: 11+23=? The answer is 34 because 11+23 equals 34.
--------------------------------------------------------------------------------
Prompt: 64+13=?
Model outpu

In [12]:
test_x_on_left = [
    "x+30=55?",
    "x-40=20",
    "x*8=64",
    "x/9=8",
    "x+25=60",
    "x/3=9",
    "x*5=45",
    "x-32=60",
    "x+15=90",
    "x/6=6",
]
test_model(test_x_on_left)

Prompt: x+30=55?
Model output: x+30=55? The answer is 55 because 55-30 equals 555.
--------------------------------------------------------------------------------
Prompt: x-40=20
Model output: x-40=200,x=? The answer is 240 because 200+40 equals 240.
--------------------------------------------------------------------------------
Prompt: x*8=64
Model output: x*8=64,x=? The answer is 8 because 64/8 equals 8.
--------------------------------------------------------------------------------
Prompt: x/9=8
Model output: x/9=8,x=? The answer is 72 because 8*9 equals 72.
--------------------------------------------------------------------------------
Prompt: x+25=60
Model output: x+25=60,x=? The answer is 35 because 60-25 equals 35.
--------------------------------------------------------------------------------
Prompt: x/3=9
Model output: x/3=9,x=? The answer is 27 because 9*3 equals 27.
--------------------------------------------------------------------------------
Prompt: x*5=45
Model out