### Step 1: Install necesscary packages

In [2]:
!pip install matplotlib
!pip install torch numpy transformers datasets tiktoken wandb tqdm

Collecting matplotlib
  Using cached matplotlib-3.10.7-cp312-cp312-win_amd64.whl.metadata (11 kB)
Collecting contourpy>=1.0.1 (from matplotlib)
  Using cached contourpy-1.3.3-cp312-cp312-win_amd64.whl.metadata (5.5 kB)
Collecting cycler>=0.10 (from matplotlib)
  Using cached cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fonttools>=4.22.0 (from matplotlib)
  Using cached fonttools-4.60.1-cp312-cp312-win_amd64.whl.metadata (114 kB)
Collecting kiwisolver>=1.3.1 (from matplotlib)
  Using cached kiwisolver-1.4.9-cp312-cp312-win_amd64.whl.metadata (6.4 kB)
Collecting pyparsing>=3 (from matplotlib)
  Using cached pyparsing-3.2.5-py3-none-any.whl.metadata (5.0 kB)
Using cached matplotlib-3.10.7-cp312-cp312-win_amd64.whl (8.1 MB)
Using cached contourpy-1.3.3-cp312-cp312-win_amd64.whl (226 kB)
Using cached cycler-0.12.1-py3-none-any.whl (8.3 kB)
Using cached fonttools-4.60.1-cp312-cp312-win_amd64.whl (2.3 MB)
Using cached kiwisolver-1.4.9-cp312-cp312-win_amd64.whl (73 kB)
Using cac


[notice] A new release of pip is available: 25.0.1 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


Collecting transformers
  Using cached transformers-4.57.1-py3-none-any.whl.metadata (43 kB)
Collecting datasets
  Downloading datasets-4.3.0-py3-none-any.whl.metadata (18 kB)
Collecting tiktoken
  Using cached tiktoken-0.12.0-cp312-cp312-win_amd64.whl.metadata (6.9 kB)
Collecting wandb
  Using cached wandb-0.22.2-py3-none-win_amd64.whl.metadata (10 kB)
Collecting tqdm
  Using cached tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
Collecting huggingface-hub<1.0,>=0.34.0 (from transformers)
  Using cached huggingface_hub-0.36.0-py3-none-any.whl.metadata (14 kB)
Collecting pyyaml>=5.1 (from transformers)
  Using cached pyyaml-6.0.3-cp312-cp312-win_amd64.whl.metadata (2.4 kB)
Collecting regex!=2019.12.17 (from transformers)
  Using cached regex-2025.10.23-cp312-cp312-win_amd64.whl.metadata (41 kB)
Collecting requests (from transformers)
  Using cached requests-2.32.5-py3-none-any.whl.metadata (4.9 kB)
Collecting tokenizers<=0.23.0,>=0.22.0 (from transformers)
  Using cached tokenizers-0.22.


[notice] A new release of pip is available: 25.0.1 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


### Step 2: Package imports and configuration

In [48]:
import sys
import os
sys.path.append(os.path.abspath("..")) 
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "1"
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import pickle
from model import GPT, GPTConfig
import random
from tqdm import tqdm
import time
import json
import matplotlib.pyplot as plt

# Configuration
beta = 0.5 # controls DPO loss strength
device = 'cuda' if torch.cuda.is_available() else 'cpu'
base_lr = 1e-4 # learning rate
epochs = 10 # number of runs of dataset
batch_size = 128 # number of sample pairs processed together
max_length =64 # max sequence length
num_samples = 1
max_new_tokens = 200 # max tokens generated
temperature = 0.8 # randomness in generation
top_k = 200 # sample from top_k most likely tokens
# tokenizer

with open("../sft/meta.pkl", "rb") as f:
    meta = pickle.load(f)
stoi, itos = meta["stoi"], meta["itos"]

extra_chars = list("+-*/=xy?!,.' ")
for ch in extra_chars:
    if ch not in stoi:
        new_index = len(stoi)
        stoi[ch] = new_index
        itos[new_index] = ch

def encode(s): return [stoi[c] for c in s]
def decode(l): return ''.join([itos[i] for i in l])

### Step 3: Define helper functions

We wanted to check the CUDA installation on the computer before moving onto the next step. We also used this step to verify torch version compatibility with the CUDA version and the availability of GPU.

On our device, the PyTorch version is 2.9.0 with CUDA 13.0 and 1 GPU device available.

In [30]:
print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.device_count())
print(torch.cuda.is_available())
device


2.9.0+cu130
13.0
1
True


'cuda'

Calculating log probability, which is used over actual probability to prevent underflow, ensuring computational stability.
Initially, the compute_logprob(input_ids) creates input/target pairs and uses the inputs to get the model predictions. 
Next, extract the batch size, sequence length time and vocab size from the logits, then reshapes the dimensions of the data to fit the model’s learning.
The functions return negative loss after their calculations. Negative loss = log probability (Higher is better)
Used for computing P(positive | question) and P(negative | question).

pad_or_truncate(seq, max_length) makes all input sequences the same length when processing the inputs.

get_batches(lines, batch_size) generates training batches from the given data.
The function iterates through the data in chunks and process the negative & positive samples.
Returns negative & positive batches after converting to tensors.


In [35]:
def compute_logprob(input_ids):
    inputs = input_ids[:, :-1]
    targets = input_ids[:, 1:]
    logits, _ = gpt(inputs, full_seq=True)
    B, T, V = logits.size()
    logits_flat = logits.reshape(-1, V)
    targets_flat = targets.reshape(-1)
    loss = F.cross_entropy(logits_flat, targets_flat, ignore_index=0, reduction='none')
    loss = loss.reshape(B, T)
    attention_mask = (targets != 0).float()
    loss = (loss * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
    return -loss 

def pad_or_truncate(seq, max_length):
    return seq[-max_length:] if len(seq) > max_length else seq + [0] * (max_length - len(seq))

def get_batches(lines, batch_size):
    random.shuffle(lines)
    #for l in lines:
    #    print(l[1])
    for i in range(0, len(lines), batch_size):
        batch = lines[i:i+batch_size]
        if len(batch) < batch_size:
            continue
        neg_inputs = [pad_or_truncate(encode(p['negative'] + '\n\n\n\n'), max_length) for p in batch]
        pos_inputs = [pad_or_truncate(encode(p['positive'] + '\n\n\n\n'), max_length) for p in batch]
        neg_tensor = torch.tensor(neg_inputs, dtype=torch.long, device=device)
        pos_tensor = torch.tensor(pos_inputs, dtype=torch.long, device=device)
        yield neg_tensor, pos_tensor

### Step 4: Load the pretrained NanoGPT model

This step loads the checkpoint gpt.pt containing pretrained weights and config and reconstructs the model using GPTConfig. The loop cleans up by removing unwanted prefixes. Once the model is loaded and ready, it can be fine-tuned in Step 5 using the batched data and log probabilities defined in Step 3.

In [36]:
ckpt = torch.load("../sft/gpt.pt", map_location=device)
gptconf = GPTConfig(**ckpt['model_args'])
gpt = GPT(gptconf)
state_dict = ckpt['model']
unwanted_prefix = '_orig_mod.'
for k in list(state_dict.keys()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)
gpt.to(device).train()

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(74, 348)
    (wpe): Embedding(256, 348)
    (drop): Dropout(p=0.2, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=348, out_features=1044, bias=False)
          (c_proj): Linear(in_features=348, out_features=348, bias=False)
          (attn_dropout): Dropout(p=0.2, inplace=False)
          (resid_dropout): Dropout(p=0.2, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=348, out_features=1392, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=1392, out_features=348, bias=False)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=348, out_features=74, bias=False)
)

### Step 5: Load Data

This step is to generate the negative pair using the pretrained model(gpt.pt). 

generate_math_qns() generates the problems using random between 0 to 99. 10 'x' is also appended as a choice for either of the values to generate algebaric questions. If x is selected, a random integer between 0 to 99 is selected as the answer for the problem as well. 

The question is fed into gpt.pt with 0.1 temperature as we want to minize the randomness to consistently get the string 'Sorry, I do not know.'. The output is decoded and validated before it is saved.

In [None]:
def generate_math_qns():
    a = random.choice(list(range(0,100)) + ['x'] * 10)
    b = random.choice(range(0,100) if a == 'x' else list(range(0,100)) + ['x'] * 10)
    op = random.choice(['+', '-', '*'])
    if a == 'x' or b == 'x':
        result = random.randint(0, 99)
        return f"{a}{op}{b}={result}, x=? "
    else:
        return f"{a}{op}{b}=? "

# complile examples from json
example = ''
with open("pos_neg_pairs.json", "r+", encoding="utf-8") as f:
    data = json.load(f)

    for line in data:
        example += line['positive'] + '\n'
    
    # generate neg pos pair
    for i in range(5):
        prompt = generate_math_qns()
        correct = False
        while(not correct): # repeatedly generate neg till correct output generates
            input_ids = torch.tensor([encode(prompt)], dtype=torch.long, device=device)
            output = gpt.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                temperature=0.1,
                top_k=200
            )

            negative = decode(output[0][0].tolist())
            if negative.split('? ')[1] == "Sorry, I don't know.":
                correct = True
                pair = {"negative" : negative, "positive" : ""}
                data.append(pair)
    f.seek(0)
    json.dump(data, f)
    print('Negative pair generated')

Completed


This is where we load our positve negative problem pairs from the JSON file and format them for batching using get_batches(). 

For our training, we generated 400k pairs of data. However, during testing we realised that the algebaric problems was still not well-tuned. Hence, we generated 301k problems with 300:1 ratio of algebra problems to simple math problems. This is because we did not have enough time to fine-tune specific variables for each run (e.g learning rate, temperature, beta) and GPU capacity to train with more than 400k of data. 

In order to test and fine-tune specifically the algebaric problems, we sacrificed the weights for the simple math problems even though it might cause the model to lose accuracy for the simple math problems during training. We determined that this sacrifice was essential for us to pinpoint the issues with the algebaric problems.

In [39]:
# Load data from ./data/pos_neg_pairs.json
import json

with open("pos_neg_pairs_301k.json", "r", encoding="utf-8") as f:
    data = json.load(f)

print(f"Loaded {len(data)} pairs.")
print("Example:")
print(data[0])

sample = data[0]
print("\nEncoded positive example:")
print(encode(sample["positive"])[:50]) 

batches = get_batches(data, batch_size=batch_size)

neg_batch, pos_batch = next(batches)
print("\nNegative batch shape:", neg_batch.shape)
print("Positive batch shape:", pos_batch.shape)

Loaded 301000 pairs.
Example:
{'negative': "x+3=97, x=? Sorry, I don't know.", 'positive': 'x+3=97, x=? The answer is 94 because 97-3 equals 94.'}

Encoded positive example:
[71, 4, 15, 9, 21, 19, 5, 1, 71, 9, 10, 1, 41, 55, 52, 1, 48, 61, 66, 70, 52, 65, 1, 56, 66, 1, 21, 16, 1, 49, 52, 50, 48, 68, 66, 52, 1, 21, 19, 6, 15, 1, 52, 64, 68, 48, 59, 66, 1, 21]

Negative batch shape: torch.Size([250, 64])
Positive batch shape: torch.Size([250, 64])


### Step 6: Build the optimizer and scheduler

The use of an optimizer is to facilitate finding the optimal weights of a model used to minimize the loss function.

The use of a scheduler is to regulate the learning rate of the model during the training process. Initially, it will start with the learning rate set in step 2, it then slowly decreases over time, adapting the model to a better minimum after the fast initial convergence.

During the training process, the optimizer updates the parameters while using the current learning rate. After a set number of epochs, the schedular then calculates a new, adjusted learning rate.

AdamW Optimizer uses weight decay as compared to using Adam Optimizer gradient calculations. Rather than applying the weight decay to the loss function after calculation, it directly affects the parameters updating. 

The scheduler uses Cosine Annealing Scheduler Warm Restarts. Over several epochs, it will take the initial learning rate and decreases it along a cosine curve. Warm restarts periodically resets the learning rate and restarts the process.


In [None]:
# recommend to use the AdamW optimizer 
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

optimizer = AdamW(
    gpt.parameters(),
    lr=base_lr,
    betas=0.5,
    eps=1e-8,
    weight_decay=0.01
)

scheduler = CosineAnnealingWarmRestarts(
    optimizer,
    T_0=10,          # restart every 10 epochs
    T_mult=1,        # no expansion of cycle length
    eta_min=1e-6
)

### Step 7: Begin training

This step is to start the training with the loaded data.

For each step, we will clear the memory before computing the log probabilities for positive and negative. dpo_term calculates the DPO loss while superised_term maximizes the probability for positive samples hence reinforcing the correct answers. The loss is combined and the gradient is computed with a cap of magnitude at 1.0 using clip_grad_norm_. We also changed the checkpoint to be saved every epoch. We find that this method is useful in cases where we overtrained our data. With this implementation, we can backtrack to the checkpoint right before the overtraining.


##### dpo_term
First we find the difference between the positive (pos_logprob) and negative (neg_logprob) log probability. This measures the preference gap where positive difference shows preference towards positive sample and negative shows preference towards negative sample. We divide by the beta to double the difference to push the optimization to be more aggressive. logsigmold stablises the output and we take the average across the batches.

##### clip_grad_norm_
This function computes the global norm and if norm > 1.0 it will scale all the gradients down to prevent exploding gradients. This step is important for training as preference data can have extreme examples between the probabilities. The large probability gaps can cause large gradients. Hence without clipping the training will be unstable.

In [41]:
from torch.nn.utils import clip_grad_norm_

train_losses = []

for epoch in range(epochs):
    gpt.train()
    total_loss = 0
    pbar = tqdm(get_batches(data, batch_size), desc=f"Epoch {epoch+1}/{epochs}")

    for step, (neg_tensor, pos_tensor) in enumerate(pbar):
        optimizer.zero_grad()

        # Compute log probabilities
        pos_logprob = compute_logprob(pos_tensor)
        neg_logprob = compute_logprob(neg_tensor)

        # DPO loss (main preference objective)
        dpo_term = -F.logsigmoid((pos_logprob - neg_logprob) / beta).mean()

        # Auxiliary supervised loss — reinforces correct answers
        supervised_term = -pos_logprob.mean() * 0.1

        # Combined loss
        loss = dpo_term + supervised_term
        loss.backward()

        # Gradient clipping to prevent explosion
        clip_grad_norm_(gpt.parameters(), 1.0)

        optimizer.step()
        scheduler.step()

        total_loss += loss.item()
        train_losses.append(loss.item())
        pbar.set_postfix(loss=f"{loss.item():.4f}")

    avg_loss = total_loss / (step + 1)
    print(f"✅ Epoch {epoch+1} complete | Avg loss: {avg_loss:.4f}")

    # Save checkpoints every epoch
    ckpt_path = f"./dpo_epoch{epoch+1}.pt"
    torch.save({
        "model_state_dict": gpt.state_dict(),
        "model_args": ckpt['model_args']
    }, ckpt_path)
    print(f"💾 Saved checkpoint: {ckpt_path}")

Epoch 1/20: 1204it [20:14,  1.01s/it, loss=0.0268]


✅ Epoch 1 complete | Avg loss: 0.0551
💾 Saved checkpoint: ./dpo_epoch1.pt


Epoch 2/20: 1204it [09:59,  2.01it/s, loss=0.0222]


✅ Epoch 2 complete | Avg loss: 0.0238
💾 Saved checkpoint: ./dpo_epoch2.pt


Epoch 3/20: 1204it [11:23,  1.76it/s, loss=0.0209]


✅ Epoch 3 complete | Avg loss: 0.0218
💾 Saved checkpoint: ./dpo_epoch3.pt


Epoch 4/20: 1204it [13:38,  1.47it/s, loss=0.0213]


✅ Epoch 4 complete | Avg loss: 0.0211
💾 Saved checkpoint: ./dpo_epoch4.pt


Epoch 5/20: 1204it [11:50,  1.69it/s, loss=0.0205]


✅ Epoch 5 complete | Avg loss: 0.0207
💾 Saved checkpoint: ./dpo_epoch5.pt


Epoch 6/20: 1204it [10:55,  1.84it/s, loss=0.0206]


✅ Epoch 6 complete | Avg loss: 0.0204
💾 Saved checkpoint: ./dpo_epoch6.pt


Epoch 7/20: 1204it [10:00,  2.01it/s, loss=0.0195]


✅ Epoch 7 complete | Avg loss: 0.0202
💾 Saved checkpoint: ./dpo_epoch7.pt


Epoch 8/20: 1204it [15:24,  1.30it/s, loss=0.0195]


✅ Epoch 8 complete | Avg loss: 0.0199
💾 Saved checkpoint: ./dpo_epoch8.pt


Epoch 9/20: 1204it [36:44,  1.83s/it, loss=0.0194]


✅ Epoch 9 complete | Avg loss: 0.0197
💾 Saved checkpoint: ./dpo_epoch9.pt


Epoch 10/20: 1204it [10:52,  1.84it/s, loss=0.0198]


✅ Epoch 10 complete | Avg loss: 0.0195
💾 Saved checkpoint: ./dpo_epoch10.pt


Epoch 11/20: 1204it [22:00,  1.10s/it, loss=0.0198]


✅ Epoch 11 complete | Avg loss: 0.0193
💾 Saved checkpoint: ./dpo_epoch11.pt


Epoch 12/20: 1204it [34:18,  1.71s/it, loss=0.0187]


✅ Epoch 12 complete | Avg loss: 0.0191
💾 Saved checkpoint: ./dpo_epoch12.pt


Epoch 13/20: 946it [3:55:32, 14.94s/it, loss=0.0189] 


KeyboardInterrupt: 

### Step 8: Begin testing

Loads the trained model after the training process. We saved each epoch to be able to trace back to whichever epoch we want to use for testing, checking for overtraining. This allows us to select and examine our best performance on the training data.

In addition to the given test set, we added more test sets for improved reliability. 

The test function iterates through all the prompts in the test set and encodes them from string to token IDs and runs it through the correct device.

It then generates the response based on the given parameters. max_new_tokens which sets the max generation length to 200, allows for sufficient full explanation of response. top_k which filters to the top 200 tokens, prevents unlikely tokens from being sampled. As for the temperature we played around with different temperatures varying from 0.1 (more deterministic) to 0.9 (more diverse), checking which gives the best performance.

The decoder extracts the tensors and places them into a python list, while the token IDs are converted back into strings. Finally, displaying the responses.


In [None]:
# Load the fine-tuned model
ckpt_path = "../dpo/dpo_epoch12.pt"
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
gpt = GPT(gptconf).to(device)
try:
    state_dict = checkpoint['model']
except:
    state_dict = checkpoint['model_state_dict']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)
# Test
gpt.eval()
test_set = ["17+19=?", "3*17=?", "72/4=?", "72-x=34,x=?", "x*11=44,x=?", "3*17=?", "72/4=?", "72-x=34,x=?"]
# test_set = [
#     "25+18=?",
#     "56-27=?",
#     "9*12=?",
#     "84/7=?",
#     "x+45=90,x=?",
#     "63-x=22,x=?",
#     "x*6=54,x=?",
#     "88/11=?"
# ]

with torch.no_grad():
    for prompt in test_set: 
        # prompt_ids = encode(prompt)
        prompt_ids = torch.tensor([encode(prompt)], dtype=torch.long, device=device)
        output = gpt.generate(
            prompt_ids,
            max_new_tokens=max_new_tokens,
            temperature=0.01,
            top_k=200
        )
        result = decode(output[0][0].tolist())
        print(f"Prompt: {prompt}\nModel Output: {result}\n")

Prompt: 17+19=?
Model Output: 17+19=? The answer is 188 because 17+191 equals 18.

Prompt: 3*17=?
Model Output: 3*17=? x=? The answer is 1 because 3/1 equals 1.

Prompt: 72/4=?
Model Output: 72/4=? x=? The answer is 8 because 72/1 equals 88.

Prompt: 72-x=34,x=?
Model Output: 72-x=34,x=? The answer is 69 because 72-3 equals 69.

Prompt: x*11=44,x=?
Model Output: x*11=44,x=? The answer is 44 because 44/1 equals 44.

Prompt: 3*17=?
Model Output: 3*17=? x=? The answer is 1 because 3/1 equals 1.

Prompt: 72/4=?
Model Output: 72/4=? x=? The answer is 8 because 72/1 equals 88.

Prompt: 72-x=34,x=?
Model Output: 72-x=34,x=? The answer is 69 because 72-3 equals 69.

