### Step 1: Install necesscary packages

In [None]:
!pip install matplotlib
!pip install torch numpy transformers datasets tiktoken wandb tqdm

In [None]:
!pip3 install tqdm

### Step 2: Package imports and configuration

In [70]:
import sys
import os
sys.path.append(os.path.abspath("..")) 
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import random
import pickle
from model import GPT, GPTConfig
import random
from tqdm import tqdm
import time
import json
import matplotlib.pyplot as plt
# Configuration
beta = 0.5
if torch.cuda.is_available():
    device = "cuda"
elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
    
base_lr = 1e-4
epochs = 5
batch_size = 64
max_length =80
num_samples = 1
max_new_tokens = 200
temperature = 0.8
top_k = 200
# tokenizer
with open("../sft/meta.pkl", "rb") as f:
    meta = pickle.load(f)
stoi, itos = meta["stoi"], meta["itos"]
def encode(s): return [stoi[c] for c in s]
def decode(l): return ''.join([itos[i] for i in l])

### Step 3: Define helper functions

In [72]:
def compute_logprob(input_ids):
    inputs = input_ids[:, :-1]
    targets = input_ids[:, 1:]
    logits, _ = gpt(inputs, full_seq=True)
    B, T, V = logits.size()
    logits_flat = logits.reshape(-1, V)
    targets_flat = targets.reshape(-1)
    loss = F.cross_entropy(logits_flat, targets_flat, ignore_index=0, reduction='none')
    loss = loss.reshape(B, T)
    attention_mask = (targets != 0).float()
    loss = (loss * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
    return -loss 

def pad_or_truncate(seq, max_length):
    return seq[:max_length] if len(seq) > max_length else seq + [0] * (max_length - len(seq))

def get_batches(lines, batch_size):
    random.shuffle(lines)
    #for l in lines:
    #    print(l[1])
    for i in range(0, len(lines), batch_size):
        batch = lines[i:i+batch_size]
        if len(batch) < batch_size:
            continue
        neg_inputs = [pad_or_truncate(encode(p['negative'] + '\n\n\n\n'), max_length) for p in batch]
        pos_inputs = [pad_or_truncate(encode(p['positive'] + '\n\n\n\n'), max_length) for p in batch]
        neg_tensor = torch.tensor(neg_inputs, dtype=torch.long, device=device)
        pos_tensor = torch.tensor(pos_inputs, dtype=torch.long, device=device)
        yield neg_tensor, pos_tensor

### Step 4: Load the pretrained NanoGPT model

In [73]:
ckpt = torch.load("../sft/gpt.pt", map_location=device)
gptconf = GPTConfig(**ckpt['model_args'])
gpt = GPT(gptconf)
state_dict = ckpt['model']
unwanted_prefix = '_orig_mod.'
for k in list(state_dict.keys()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)
gpt.to(device).train()

  ckpt = torch.load("../sft/gpt.pt", map_location=device)


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(74, 348)
    (wpe): Embedding(256, 348)
    (drop): Dropout(p=0.2, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=348, out_features=1044, bias=False)
          (c_proj): Linear(in_features=348, out_features=348, bias=False)
          (attn_dropout): Dropout(p=0.2, inplace=False)
          (resid_dropout): Dropout(p=0.2, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=348, out_features=1392, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=1392, out_features=348, bias=False)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=348, out_features=74, bias=False)
)

In [74]:
# Test the base model BEFORE DPO training
print("=== Testing BASE MODEL (before DPO) ===")
gpt.eval()
test_cases = ["x*11=44,x=?", "72-x=34,x=?", "x+9=87,x=?"]
with torch.no_grad():
    for prompt in test_cases:
        prompt_ids = encode(prompt)
        x_test = torch.tensor([prompt_ids], dtype=torch.long, device=device)
        y_test = gpt.generate(x_test, max_new_tokens=50, temperature=0.8, top_k=200)
        result = decode(y_test[0].cpu().flatten().tolist())
        print(f"Q: {prompt}")
        print(f"A: {result[len(prompt):].strip()}\n")
gpt.train()

=== Testing BASE MODEL (before DPO) ===
Q: x*11=44,x=?
A: Sorry, I don't know.

Q: 72-x=34,x=?
A: Sorry, I don't know.

Q: x+9=87,x=?
A: Sorry, I don't know.



GPT(
  (transformer): ModuleDict(
    (wte): Embedding(74, 348)
    (wpe): Embedding(256, 348)
    (drop): Dropout(p=0.2, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=348, out_features=1044, bias=False)
          (c_proj): Linear(in_features=348, out_features=348, bias=False)
          (attn_dropout): Dropout(p=0.2, inplace=False)
          (resid_dropout): Dropout(p=0.2, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=348, out_features=1392, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=1392, out_features=348, bias=False)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=348, out_features=74, bias=False)
)

### Step 5: Load Data (**students are required to complete this part!**)

In [75]:
# Load data from ./data/pos_neg_pairs.json
with open('/Users/xb/Desktop/Uni Notes/Y3S1/SC3000/Assignment 1/Ass1 github/dpo/pos_neg_pairs.json', 'r', encoding='utf-8') as f:
    data = json.load(f)
    lines = data

# Check the structure
print(type(data))      # should be <class 'list'>
print(len(data))       # number of items
print(data[1])         # show the first item

<class 'list'>
128000
{'negative': 'x+49=64, x=? Sorry, I do not know', 'positive': 'x+49=64, x=? The answer is 15 because 64 minus 49 equals 15.'}


In [78]:
# Add this cell after Step 5 to specifically check x-equation examples
print("=== Checking X-EQUATION examples ===")

# Get a batch and look for x-equations
found_x_examples = 0
sample_batch_gen = get_batches(lines, 128)

while found_x_examples < 3:
    neg_tensor, pos_tensor = next(sample_batch_gen)
    
    for i in range(len(pos_tensor)):
        pos_decoded = decode(pos_tensor[i].cpu().tolist())
        
        if 'x' in pos_decoded and found_x_examples < 3:
            print(f"\n{'='*60}")
            print(f"X-equation example {found_x_examples + 1}:")
            pos_clean = pos_decoded.replace('\x00', '[PAD]')
            neg_clean = decode(neg_tensor[i].cpu().tolist()).replace('\x00', '[PAD]')
            
            print(f"POSITIVE ({len(pos_tensor[i].cpu().tolist())} tokens):")
            print(f"{pos_clean}")
            print(f"\nNEGATIVE:")
            print(f"{neg_clean}")
            
            found_x_examples += 1
            
        if found_x_examples >= 3:
            break

print(f"\n{'='*60}")
print(f"Found {found_x_examples} x-equation examples to check")

=== Checking X-EQUATION examples ===

X-equation example 1:
POSITIVE (80 tokens):
42-x=15, x=? The answer is 27 because 42 minus 15 equals 27.





















NEGATIVE:
42-x=15, x=? Sorry, I do not know
















































X-equation example 2:
POSITIVE (80 tokens):
7*x=658, x=? The answer is 94 because 658 divided by 7 equals 94.
















NEGATIVE:
7*x=658, x=? Sorry, I do not know
















































X-equation example 3:
POSITIVE (80 tokens):
x/45=2, x=? The answer is 90 because 45 times 2 equals 90.























NEGATIVE:
x/45=2, x=? Sorry, I do not know

















































Found 3 x-equation examples to check


### Step 6: Build the optimizer and scheduler (**students are required to complete this part!**)

In [6]:
def build_optimizer_and_scheduler(
    model,
    lr=base_lr,
    weight_decay=0.1,
    betas=(0.9, 0.95),
    eps=1e-8,
    step_size=1000,   # decay every 1000 steps
    gamma=0.5,        # multiply LR by this factor each decay
):
    # ---- no weight decay for bias & LayerNorm/BatchNorm ----
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if p.ndim >= 2:
            decay.append(p)
        else:
            no_decay.append(p)

    optim_groups = [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
    
    # AdamW optimizer
    optimizer = AdamW(optim_groups, lr=lr, betas=betas, eps=eps)

    # Step Decay scheduler
    scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)

    return optimizer, scheduler


optimizer, scheduler = build_optimizer_and_scheduler(
    gpt,
    lr=base_lr,
    weight_decay=0.1,
    betas=(0.9, 0.95),
    eps=1e-8,
    step_size=2000,
    gamma=0.5,
)


### Step 7: Begin training (**students are required to complete this part!**)

In [7]:
total_steps = len(lines) // batch_size
for epoch in range(epochs):
    pbar = tqdm(get_batches(lines, batch_size))
    for step, (neg_tensor,pos_tensor) in enumerate(pbar):
        ###########################################################
        # Please complete the training code here!
        # Examples: 
        # ...
        # neg_logprob
        # pos_logprob 
        # loss = -F.logsigmoid((pos_logprob - neg_logprob) / beta).mean() - pos_logprob.mean() * 0.1 
        # ...
        pos_logprob = compute_logprob(pos_tensor)
        neg_logprob = compute_logprob(neg_tensor)

        loss = -F.logsigmoid((pos_logprob - neg_logprob) / beta).mean() - pos_logprob.mean() * 0.1
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        scheduler.step()

        if (step + 1) % 10 == 0:
            pbar.set_postfix(loss=float(loss.item()), lr=scheduler.get_last_lr()[0])
                
        ###########################################################
    ckpt_path = f"./dpo.pt"
    torch.save({
        "model_state_dict": gpt.state_dict(),
        "model_args": ckpt['model_args'],
    }, ckpt_path)
    print(f"Saved checkpoint to {ckpt_path}")

1562it [12:10,  2.14it/s, loss=0.0365, lr=0.0001]


Saved checkpoint to ./dpo.pt


1562it [14:02,  1.85it/s, loss=0.0345, lr=5e-5] 


Saved checkpoint to ./dpo.pt


1562it [13:55,  1.87it/s, loss=0.0333, lr=2.5e-5]


Saved checkpoint to ./dpo.pt


1562it [13:24,  1.94it/s, loss=0.0321, lr=1.25e-5]


Saved checkpoint to ./dpo.pt


1562it [13:09,  1.98it/s, loss=0.0314, lr=1.25e-5]

Saved checkpoint to ./dpo.pt





### Step 8: Begin testing (**students are required to complete this part!**)

In [99]:
# Load the fine-tuned model
ckpt_path = "../dpo/dpo.pt"
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
gpt = GPT(gptconf).to(device)
try:
    state_dict = checkpoint['model']
except:
    state_dict = checkpoint['model_state_dict']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)
# Test
gpt.eval()
#test_set = ["17+19=?", "3*17=?", "72/4=?", "72-x=34,x=?", "x*11=44,x=?", "3*17=?", "72/4=?", "72-x=34,x=?"]
#expanded test set to see how is the model
test_set = [
    # ----- Arithmetic (no x) -----
    # addition
    "4+7=?",
    "12+5=?",
    "8+9=?",
    "3+14=?",
    "15+6=?",
    "9+8=?",
    "7+13=?",

    # subtraction
    "20-5=?",
    "14-9=?",
    "18-7=?",
    "13-4=?",
    "16-12=?",
    "19-8=?",
    "10-3=?",

    # multiplication (no two-digit × two-digit)
    "3*12=?",
    "9*6=?",
    "4*7=?",
    "8*5=?",
    "2*15=?",
    "6*9=?",
    "7*3=?",

    # division (two-digit ÷ one-digit, integer only)
    "24/3=?",
    "35/5=?",
    "81/9=?",
    "56/8=?",
    "63/7=?",
    "72/9=?",
    "45/5=?",

    # ----- Algebra with x (x always positive integer ≤ 100) -----
    "x+5=12,x=?",
    "9+x=20,x=?",
    "x-7=13,x=?",
    "25-x=5,x=?",
    "x*3=21,x=?",
    "5*x=35,x=?",
    "x/4=8,x=?",
    "56/x=7,x=?",
    "x+8=40,x=?",
    "10+x=90,x=?",
    "x-9=41,x=?",
    "70-x=50,x=?",
    "x*4=48,x=?",
    "9*x=81,x=?",
    "x/3=9,x=?",
    "36/x=4,x=?",
    "x+9=99,x=?",
    "7+x=30,x=?",
    "x-5=95,x=?",
    "100-x=75,x=?",
    "x*5=50,x=?",
    "8*x=64,x=?",
    "x/2=50,x=?",
    "90/x=9,x=?",
    "x*6=96,x=?",
    "48/x=6,x=?",
    "x+60=90,x=?",
    "30+x=90,x=?",
    "x-20=30,x=?",
    "80-x=60,x=?"
]
with torch.no_grad():
    for prompt in test_set: 
        prompt_ids = encode(prompt)
        ###########################################################
        # Please complete the test code here!
        # ...
        # gpt.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
        # ...
        x = torch.tensor([prompt_ids], dtype=torch.long, device=device)  # shape [1, T]

        y = gpt.generate(
            x,
            max_new_tokens=max_new_tokens,      # you defined these above
            temperature=temperature,
            top_k=top_k
        ) 
        out_full = decode(y[0].cpu().flatten().tolist())
        generated = out_full[len(prompt):].strip()

        print(f"Q: {prompt}")
        print(f"A: {generated}\n")
        ###########################################################

  checkpoint = torch.load(ckpt_path, map_location=device)


Q: 4+7=?
A: The answer is 11 because 4+7 equals 11.

Q: 12+5=?
A: The answer is 17 because 12+5 equals 17.

Q: 8+9=?
A: The answer is 17 because 8+9 equals 17.

Q: 3+14=?
A: The answer is 17 because 3+14 equals 17.

Q: 15+6=?
A: The answer is 21 because 15+6 equals 21.

Q: 9+8=?
A: The answer is 17 because 9+8 equals 17.

Q: 7+13=?
A: The answer is 20 because 7+13 equals 20.

Q: 20-5=?
A: The answer is 15 because 20-5 equals 15.

Q: 14-9=?
A: The answer is 5 because 14-9 equals 5.

Q: 18-7=?
A: The answer is 11 because 18-7 equals 11.

Q: 13-4=?
A: The answer is 9 because 13-4 equals 9.

Q: 16-12=?
A: The answer is 4 because 16-12 equals 4.

Q: 19-8=?
A: The answer is 11 because 19-8 equals 11.

Q: 10-3=?
A: The answer is 7 because 10-3 equals 7.

Q: 3*12=?
A: The answer is 36 because 3*12 equals 36.

Q: 9*6=?
A: The answer is 54 because 9*6 equals 54.

Q: 4*7=?
A: The answer is 28 because 4*7 equals 28.

Q: 8*5=?
A: The answer is 40 because 8*5 equals 40.

Q: 2*15=?
A: The answer is 3