### Step 1: Install necessary packages

In [1]:
#!pip install matplotlib
#!pip install torch numpy transformers datasets tiktoken wandb tqdm

# These %pip lines install into the virtual environment used by this kernel; use the commented-out !pip lines above instead if you run on a local setup
%pip install matplotlib
%pip install torch numpy transformers datasets tiktoken wandb tqdm


Note: you may need to restart the kernel to use updated packages.
Collecting transformers
  Using cached transformers-4.56.2-py3-none-any.whl.metadata (40 kB)
Collecting datasets
  Using cached datasets-4.1.1-py3-none-any.whl.metadata (18 kB)
Using cached transformers-4.56.2-py3-none-any.whl (11.6 MB)
Using cached datasets-4.1.1-py3-none-any.whl (503 kB)
Installing collected packages: transformers, datasets

   ---------------------------------------- 0/2 [transformers]
   ---------------------------------------- 0/2 [transformers]
   ---------------------------------------- 0/2 [transformers]
   ---------------------------------------- 0/2 [transformers]
   ---------------------------------------- 0/2 [transformers]
   ---------------------------------------- 0/2 [transformers]
   ---------------------------------------- 0/2 [transformers]
   ---------------------------------------- 0/2 [transformers]
   ---------------------------------------- 0/2 [transformers]
   ------------------

### Step 2: Package imports and configuration

In [2]:
import sys
import os
sys.path.append(os.path.abspath("..")) 
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import pickle
from model import GPT, GPTConfig
import random
from tqdm import tqdm
import time
import json
import matplotlib.pyplot as plt
# Configuration
# --- preference loss ---
beta = 0.5             # divides (pos_logprob - neg_logprob) in the Step 7 loss template

# --- hardware ---
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# --- optimisation ---
base_lr = 1e-4         # presumably the learning rate for the Step 6 optimizer — TODO confirm
epochs = 5             # number of passes of the Step 7 outer loop
batch_size = 64        # pairs per batch; get_batches skips a trailing partial batch
max_length = 64        # get_batches pads/truncates every encoded sequence to this many tokens

# --- sampling (temperature/top_k/max_new_tokens appear in the Step 8 generate() template) ---
num_samples = 1        # not referenced in the visible cells
max_new_tokens = 200
temperature = 0.8
top_k = 200
# tokenizer
# Vocabulary tables are read from ../sft/meta.pkl (the same directory Step 4 loads gpt.pt from).
# NOTE(review): pickle.load can execute arbitrary code — only open a meta.pkl you created or trust.
with open("../sft/meta.pkl", "rb") as f:
    meta = pickle.load(f)
# stoi is indexed by character in encode(); itos is indexed by token id in decode().
stoi, itos = meta["stoi"], meta["itos"]
def encode(s):
    """Translate string `s` into a list of token ids, one `stoi` lookup per character.

    A character missing from `stoi` propagates whatever error the lookup raises.
    """
    ids = []
    for ch in s:
        ids.append(stoi[ch])
    return ids
def decode(l):
    """Translate an iterable of token ids back into a string, one `itos` lookup per id."""
    pieces = []
    for token_id in l:
        pieces.append(itos[token_id])
    return ''.join(pieces)

### Step 3: Define helper functions

In [3]:
def compute_logprob(input_ids):
    """Return the mean per-token log-probability of each sequence under the global `gpt`.

    Args:
        input_ids: integer tensor indexed as (batch, time). Token id 0 is treated
            as padding and excluded from the average.
            NOTE(review): if `stoi` also maps a real character to 0, that
            character is excluded as well — verify against meta.pkl.

    Returns:
        Tensor with one value per batch row: the negated cross-entropy averaged
        over the non-padding target positions. A row whose targets are all
        padding yields 0 rather than NaN.
    """
    # Next-token setup: position t of `inputs` predicts position t of `targets`.
    inputs = input_ids[:, :-1]
    targets = input_ids[:, 1:]
    logits, _ = gpt(inputs, full_seq=True)
    B, T, V = logits.size()
    token_nll = F.cross_entropy(
        logits.reshape(-1, V),
        targets.reshape(-1),
        ignore_index=0,
        reduction='none',
    ).reshape(B, T)
    attention_mask = (targets != 0).float()
    # Clamp the denominator: an all-padding row would otherwise divide 0 by 0
    # and inject NaN into the loss and gradients.
    n_tokens = attention_mask.sum(dim=1).clamp(min=1)
    return -(token_nll * attention_mask).sum(dim=1) / n_tokens

def pad_or_truncate(seq, max_length):
    """Return a new list of exactly `max_length` token ids built from `seq`.

    Args:
        seq: list of token ids.
        max_length: target length (non-negative).

    Returns:
        The last `max_length` items of `seq` when it is too long (the head is
        dropped, the tail kept); otherwise `seq` right-padded with 0.
    """
    overflow = len(seq) - max_length
    if overflow > 0:
        # Slice from an explicit start index: `seq[-max_length:]` would return
        # the whole list when max_length == 0, because -0 == 0.
        return seq[overflow:]
    return seq + [0] * (-overflow)

def get_batches(lines, batch_size):
    """Yield (neg_tensor, pos_tensor) batches built from preference pairs.

    Args:
        lines: list of dicts with 'negative' and 'positive' strings. The list is
            shuffled IN PLACE each time the generator starts.
        batch_size: number of pairs per batch; a trailing partial batch is skipped.

    Yields:
        Two long tensors on `device`, each with `batch_size` rows of `max_length`
        token ids. Every text gets four newlines appended before encoding.
    """
    random.shuffle(lines)
    terminator = '\n\n\n\n'

    def to_tensor(texts):
        # Encode, fix the length, and stack one side of the batch.
        rows = [pad_or_truncate(encode(text + terminator), max_length) for text in texts]
        return torch.tensor(rows, dtype=torch.long, device=device)

    for start in range(0, len(lines), batch_size):
        chunk = lines[start:start + batch_size]
        if len(chunk) >= batch_size:
            neg_tensor = to_tensor(pair['negative'] for pair in chunk)
            pos_tensor = to_tensor(pair['positive'] for pair in chunk)
            yield neg_tensor, pos_tensor

### Step 4: Load the pretrained NanoGPT model

In [4]:
# Rebuild the model from the SFT checkpoint: the architecture comes from
# ckpt['model_args'], the weights from ckpt['model'].
ckpt = torch.load("../sft/gpt.pt", map_location=device)
gptconf = GPTConfig(**ckpt['model_args'])
gpt = GPT(gptconf)
state_dict = ckpt['model']
# Strip the '_orig_mod.' prefix from parameter names (presumably added by
# torch.compile when the checkpoint was saved) so they match a plain GPT module.
unwanted_prefix = '_orig_mod.'
prefixed_names = [name for name in state_dict if name.startswith(unwanted_prefix)]
for name in prefixed_names:
    state_dict[name[len(unwanted_prefix):]] = state_dict.pop(name)
gpt.load_state_dict(state_dict)
# Training mode: the model is fine-tuned in Step 7.
gpt.to(device).train()

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(74, 348)
    (wpe): Embedding(256, 348)
    (drop): Dropout(p=0.2, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=348, out_features=1044, bias=False)
          (c_proj): Linear(in_features=348, out_features=348, bias=False)
          (attn_dropout): Dropout(p=0.2, inplace=False)
          (resid_dropout): Dropout(p=0.2, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=348, out_features=1392, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=1392, out_features=348, bias=False)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=348, out_features=74, bias=False)
)

### Step 4.5: Generate the pos_neg_pairs

In [9]:
# Generate >= 10k pos/neg pairs in same string style as the 4 examples provided
# -> e.g. {"negative": "79-7=? Sorry, I do not know!",	"positive": "79-7=? The answer is 72 because 79-7 equals 72."}

import json, random

OUT_PATH = "pos_neg_pairs.json" # Destination file; it is overwritten every time this cell runs
# TARGET = 11_000

# Uncomment the two lines below for a reproducible (seeded) dataset;
# leave them commented to draw a different random dataset on every run.
# SEED = 42 
# random.seed(SEED)

def pos_str(expr: str, ans: int | float, reason: str):
    # "79-7=? The answer is 52 because 79-7 equals 72."
    if isinstance(ans, float) and abs(ans-round(ans)) < 1e-9:
        ans = int(round(ans))
    return f"{expr} The answer is {ans} because {reason}."

def neg_str(expr: str):
    """Format the rejected ("negative") completion for a question.

    Example: neg_str("79-7=?") -> "79-7=? Sorry, I do not know!"
    """
    refusal = "Sorry, I do not know!"
    return f"{expr} {refusal}"

def make_add(a, b):
    """Build one negative/positive pair for the addition question "a+b=?"."""
    question = f"{a}+{b}=?"
    total = a + b
    negative = neg_str(question)
    positive = pos_str(question, total, f"{a}+{b} equals {total}")
    return {"negative": negative, "positive": positive}

def make_sub(a, b):
    """Build one negative/positive pair for the subtraction question "a-b=?".

    The difference may be negative when b > a.
    """
    question = f"{a}-{b}=?"
    difference = a - b
    negative = neg_str(question)
    positive = pos_str(question, difference, f"{a}-{b} equals {difference}")
    return {"negative": negative, "positive": positive}

def make_mul(a, b):
    """Build one negative/positive pair for the multiplication question "a*b=?".

    Not among the originally provided examples; it follows the same text format.
    """
    question = f"{a}*{b}=?"
    product = a * b
    negative = neg_str(question)
    positive = pos_str(question, product, f"{a}*{b} equals {product}")
    return {"negative": negative, "positive": positive}

def make_div(a, b):
    """Build one negative/positive pair for the division question "a/b=?".

    Not among the originally provided examples; it follows the same text format.

    Returns:
        The pair dict, or None when b == 0 (add_batch skips None records).
        The quotient is rounded to 5 decimal places.
    """
    if b == 0:
        return None
    expr = f"{a}/{b}=?"
    ans = round(a / b, 5)
    # Collapse whole quotients to int BEFORE building `reason`; otherwise the
    # record reads "The answer is 5 because 10/2 equals 5.0", contradicting itself.
    if isinstance(ans, float) and ans.is_integer():
        ans = int(ans)
    reason = f"{a}/{b} equals {ans}"
    return {"negative": neg_str(expr), "positive": pos_str(expr, ans, reason)}

def make_solvex1(a, b):
    """Build one negative/positive pair for the equation "x+a=b,x=?".

    The solution b - a may be negative.
    """
    question = f"x+{a}={b},x=?"
    solution = b - a
    negative = neg_str(question)
    positive = pos_str(question, solution, f"{b}-{a} equals to {solution}")
    return {"negative": negative, "positive": positive}

def make_solvex2(a, b):
    """Build one negative/positive pair for the equation "a*x=b,x=?".

    Returns:
        The pair dict, or None when a == 0 (no unique solution; add_batch
        skips None records, as it does for make_div). The solution b/a is
        rounded to 5 decimal places.
    """
    if a == 0:
        return None
    expr = f"{a}*x={b},x=?"
    ans = round(b / a, 5)
    # Collapse whole solutions to int BEFORE building `reason`, so the answer
    # and the explanation print the same number (5, not 5 versus 5.0).
    if isinstance(ans, float) and ans.is_integer():
        ans = int(ans)
    reason = f"{b}/{a} equals to {ans}"
    return {"negative": neg_str(expr), "positive": pos_str(expr, ans, reason)}

# Function to generate linear equations
# def make_eq(a, b, x):
#     # Make linear equations in the form of: c = a*x + b
#     c = a*x + b
#     expr = f"{a}*x+{b}={c}, x=?"
#     ans  = x
#     reason = f"({c}-{b})/{a} equals {x}"
#     return {"negative": neg_str(expr), "positive": pos_str(expr, ans, reason)}

# Shared accumulators: every accepted record, and the question texts already used
items, seen = [], set()

# Generate a batch of pairs for one equation type
def add_batch(n, gen):
    """Append up to `n` new, unique-question records produced by `gen` to `items`.

    Args:
        n: number of records wanted.
        gen: zero-argument callable returning one pair dict, or a falsy value
            when no record could be produced.

    Gives up after n*50 calls to `gen`, so a generator that keeps returning
    None or duplicates cannot loop forever.
    """
    made = 0
    attempts_left = n * 50
    while made < n and attempts_left > 0:
        attempts_left -= 1
        rec = gen()
        if rec:
            # The question is everything before the first space of the positive text.
            question = rec["positive"].split(" ", 1)[0]
            if question not in seen:
                seen.add(question)
                items.append(rec)
                made += 1
            # else: duplicate question — skip it
        # else: generator produced nothing (e.g. divide by zero) — skip it

# Coverage for testing with fixed seed, commented to generate truly random 100k pairs
# add_batch(2, lambda: make_add(random.randint(1, 99), random.randint(1, 99)))      
# add_batch(2, lambda: make_sub(random.randint(1, 99), random.randint(1, 99)))      
# add_batch(2, lambda: make_mul(random.randint(10, 199), random.randint(1, 19)))     
# add_batch(2, lambda: make_div(random.randint(10, 199), random.randint(1, 19)))     
# add_batch(2, lambda: make_solvex1(random.randint(10, 199), random.randint(1, 199))) # Allows for negative values
# add_batch(2, lambda: make_solvex2(random.randint(10, 199), random.randint(1, 19)))  # Same as make_div

# If want to add linear equations
#add_batch(2, lambda: make_eq(random.randint(2, 12), random.randint(-20, 20), random.randint(-20, 20)))

# Coverage for 100k pairs, ~17_000 entries per equation type.
# Every type draws both operands uniformly from 1..199.
for maker in (make_add, make_sub, make_mul, make_div, make_solvex1, make_solvex2):
    add_batch(17000, lambda: maker(random.randint(1, 199), random.randint(1, 199)))

# If want to add linear equations
#add_batch(17000, lambda: make_eq(random.randint(2, 12), random.randint(-20, 20), random.randint(-20, 20)))

print("Generated:", len(items)) # Shows how many pairs
with open(OUT_PATH, "w", encoding="utf-8") as out_file:
    json.dump(items, out_file, ensure_ascii=False, indent=2)
print("Wrote:", OUT_PATH)       # Shows which file was overwritten
print("Example:", items[0])     # First pair as a sample

Generated: 102000
Wrote: pos_neg_pairs.json
Example: {'negative': '52+98=? Sorry, I do not know!', 'positive': '52+98=? The answer is 150 because 52+98 equals 150.'}


### Step 5: Load Data (**students are required to complete this part!**)

In [None]:
# Load the preference pairs from pos_neg_pairs.json (the file written in Step 4.5)

import json, re, collections, os, math
PATH = "pos_neg_pairs.json"
# Read through a context manager so the file handle is closed deterministically
# instead of staying open until garbage collection.
with open(PATH, "r", encoding="utf-8") as f:
    data = json.load(f)
print("Total paris:", len(data))
print("Sample:", data[0])


Total paris: 102000
Sample: {'negative': '52+98=? Sorry, I do not know!', 'positive': '52+98=? The answer is 150 because 52+98 equals 150.'}
Approximate duplicate questions: 0
Type counts: {'add': 17000, 'sub': 17000, 'mul': 17000, 'div': 17000, 'solvex1': 17000, 'solvex2': 17000}
Pairs with decimals in positive string: 102000
File size ~MB: 14.37492


### Step 6: Build the optimizer and scheduler (**students are required to complete this part!**)

In [8]:
# recommend to use the AdamW optimizer 

### Step 7: Begin training (**students are required to complete this part!**)

In [None]:
# Training pairs: a copy of the Step 5 list, so the in-place shuffle inside
# get_batches does not reorder `data`.
lines = list(data)
# AdamW, as recommended in Step 6. It is built here so this cell runs on its own;
# no learning-rate scheduler is used.
optimizer = torch.optim.AdamW(gpt.parameters(), lr=base_lr)

total_steps = len(lines) // batch_size  # full batches per epoch (partial batch is skipped)
for epoch in range(epochs):
    pbar = tqdm(get_batches(lines, batch_size), total=total_steps, desc=f"epoch {epoch + 1}/{epochs}")
    for step, (neg_tensor, pos_tensor) in enumerate(pbar):
        # Mean per-token log-probability of the rejected and the preferred text.
        neg_logprob = compute_logprob(neg_tensor)
        pos_logprob = compute_logprob(pos_tensor)
        # Preference term pushes pos above neg; the 0.1-weighted term also
        # raises the likelihood of the preferred text itself.
        loss = -F.logsigmoid((pos_logprob - neg_logprob) / beta).mean() - pos_logprob.mean() * 0.1
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        pbar.set_postfix(loss=f"{loss.item():.4f}")
    # Overwrite the same checkpoint at the end of every epoch.
    ckpt_path = "./dpo.pt"
    torch.save({
        "model_state_dict": gpt.state_dict(),
        "model_args": ckpt['model_args'],
    }, ckpt_path)
    print(f"Saved checkpoint to {ckpt_path}")

### Step 8: Begin testing (**students are required to complete this part!**)

In [None]:
# Load the fine-tuned model
ckpt_path = "../dpo/dpo.pt"
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
# Use the configured device rather than a hard-coded .cuda(), so the cell also
# runs when `device` fell back to 'cpu'.
gpt = GPT(gptconf).to(device)
# Step 7 stores weights under 'model_state_dict'; the SFT checkpoint uses 'model'.
# An explicit key test replaces the bare `except:` that hid every other error.
state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint['model_state_dict']
unwanted_prefix = '_orig_mod.'
for k in list(state_dict.keys()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)
# Test
gpt.eval()
test_set = ["17+19=?", "3*17=?", "72/4=?", "72-x=34,x=?", "x*11=44,x=?", "3*17=?", "72/4=?", "72-x=34,x=?"]
with torch.no_grad():
    for prompt in test_set: 
        prompt_ids = encode(prompt)
        ###########################################################
        # Please complete the test code here!
        # ...
        # gpt.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
        # ...
        # NOTE(review): generation is still unimplemented — the return format of
        # gpt.generate is not visible from this notebook, so it is left as a TODO.
        ###########################################################