### Step 1: Install necessary packages

In [None]:
!pip install matplotlib
!pip install torch numpy transformers datasets tiktoken wandb tqdm

### Step 2: Package imports and configuration

In [31]:
import sys
import os
sys.path.append(os.path.abspath("..")) 
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import pickle
from model import GPT, GPTConfig
import random
from tqdm import tqdm
import time
import json
import matplotlib.pyplot as plt
import ollama
# Configuration
beta = 0.5  # divisor applied to (pos_logprob - neg_logprob) in the Step 7 loss template
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # fall back to CPU when no GPU is visible
base_lr = 1e-4  # presumably the optimizer learning rate for Step 6 -- TODO confirm
epochs = 5  # number of passes over the pairs in the Step 7 training loop
batch_size = 64  # pairs per batch for get_batches; NOTE: Step 5 later overwrites this with len(data)
max_length =64  # every encoded sequence is padded/truncated to this many token ids in get_batches
num_samples = 1  # not referenced in the visible cells -- TODO confirm whether it is still needed
max_new_tokens = 200  # generation budget passed to gpt.generate
temperature = 0.8  # sampling temperature named in the Step 8 generation template
top_k = 200  # top-k cutoff named in the Step 8 generation template
# tokenizer
# Load the vocabulary tables from ../sft/meta.pkl: `stoi` maps a character to its
# integer id and `itos` maps an id back to its character (see encode/decode below).
# SECURITY: pickle.load can execute arbitrary code embedded in the file; only load
# a meta.pkl that you produced yourself.
with open("../sft/meta.pkl", "rb") as f:
    meta = pickle.load(f)
stoi, itos = meta["stoi"], meta["itos"]
# Ensure '!' is encodable by appending it as a new id when the vocabulary lacks it.
# NOTE(review): a newly added id equals the previous vocabulary size. The model printed
# in Step 4 has a 74-row embedding, so confirm the new id stays below 74, otherwise
# encoding '!' would index past the embedding table.
if '!' not in stoi:
    new_index = len(stoi)
    stoi['!'] = new_index
    itos[new_index] = '!'
def encode(s):
    """Convert the string `s` into a list of token ids, one per character, using `stoi`.

    Raises KeyError when a character is missing from `stoi`.
    """
    ids = []
    for ch in s:
        ids.append(stoi[ch])
    return ids
def decode(l):
    """Convert an iterable of token ids `l` back into a string using `itos`.

    Raises KeyError when an id is missing from `itos`.
    """
    chars = []
    for token_id in l:
        chars.append(itos[token_id])
    return ''.join(chars)

### Step 3: Define helper functions

In [32]:
def compute_logprob(input_ids):
    """Return the mean per-token log-probability of each sequence under `gpt`.

    Args:
        input_ids: long tensor of shape (B, L) holding token ids. Id 0 is treated
            as padding (it is what `pad_or_truncate` pads with) and is excluded
            from the average.

    Returns:
        Tensor of shape (B,): for every row, the average log-probability the model
        assigns to the non-padding next tokens. A row whose targets are all padding
        yields 0 rather than NaN.
    """
    # Next-token prediction: position t of `inputs` predicts position t of `targets`.
    inputs = input_ids[:, :-1]
    targets = input_ids[:, 1:]
    logits, _ = gpt(inputs, full_seq=True)
    B, T, V = logits.size()
    # reduction='none' keeps one loss value per token; ignore_index=0 zeroes padded targets.
    token_nll = F.cross_entropy(
        logits.reshape(-1, V),
        targets.reshape(-1),
        ignore_index=0,
        reduction='none',
    ).reshape(B, T)
    attention_mask = (targets != 0).float()
    # Clamp the token count so an all-padding row gives 0 instead of 0/0 = NaN,
    # which would otherwise turn the whole batch loss into NaN.
    token_counts = attention_mask.sum(dim=1).clamp(min=1)
    mean_nll = (token_nll * attention_mask).sum(dim=1) / token_counts
    return -mean_nll

def pad_or_truncate(seq, max_length):
    """Return `seq` forced to exactly `max_length` items.

    Sequences that are too long keep their LAST `max_length` items (the front is
    dropped); sequences that are too short are right-padded with 0. A new list is
    returned in both cases; `seq` itself is not modified.

    Args:
        seq: list of token ids.
        max_length: desired length, must be >= 0.

    Raises:
        ValueError: if `max_length` is negative.
    """
    if max_length < 0:
        raise ValueError(f"max_length must be non-negative, got {max_length}")
    overflow = len(seq) - max_length
    if overflow > 0:
        # Slice from the front offset instead of seq[-max_length:], because
        # seq[-0:] would return the whole sequence when max_length == 0.
        return seq[overflow:]
    return seq + [0] * (-overflow)

def get_batches(lines, batch_size):
    """Yield shuffled (negative, positive) tensor batches of preference pairs.

    Args:
        lines: sequence of dicts with 'negative' and 'positive' text entries.
        batch_size: number of pairs per yielded batch.

    Yields:
        (neg_tensor, pos_tensor): two long tensors of shape (batch_size, max_length)
        placed on `device`. Each text gets '\n\n\n\n' appended before being encoded
        and padded/truncated to `max_length`.

    Notes:
        - The shuffle works on a copy, so the caller's list keeps its order.
        - A trailing batch smaller than `batch_size` is dropped; if `lines` holds
          fewer than `batch_size` pairs nothing is yielded at all.
    """
    # Shuffle a copy so that iterating batches does not reorder the caller's data.
    shuffled = list(lines)
    random.shuffle(shuffled)

    def to_tensor(batch, field):
        # Encode one side of every pair and stack the rows into a single tensor.
        rows = [pad_or_truncate(encode(p[field] + '\n\n\n\n'), max_length) for p in batch]
        return torch.tensor(rows, dtype=torch.long, device=device)

    for start in range(0, len(shuffled), batch_size):
        batch = shuffled[start:start + batch_size]
        if len(batch) < batch_size:
            # Incomplete final batch: skip it to keep batch shapes constant.
            continue
        neg_tensor = to_tensor(batch, 'negative')
        pos_tensor = to_tensor(batch, 'positive')
        yield neg_tensor, pos_tensor

### Step 4: Load the pretrained NanoGPT model

In [33]:
# Restore the SFT checkpoint and rebuild the model from its stored constructor arguments.
ckpt = torch.load("../sft/gpt.pt", map_location=device)
gptconf = GPTConfig(**ckpt['model_args'])
gpt = GPT(gptconf)
state_dict = ckpt['model']
unwanted_prefix = '_orig_mod.'
# Rename keys carrying the '_orig_mod.' prefix (presumably left by torch.compile) so
# they match the parameter names of a freshly constructed GPT.
prefixed_keys = [key for key in state_dict if key.startswith(unwanted_prefix)]
for key in prefixed_keys:
    state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
gpt.load_state_dict(state_dict)
# Move to the target device and switch to training mode; the last expression also
# displays the module structure as the cell output.
gpt.to(device).train()

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(74, 348)
    (wpe): Embedding(256, 348)
    (drop): Dropout(p=0.2, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=348, out_features=1044, bias=False)
          (c_proj): Linear(in_features=348, out_features=348, bias=False)
          (attn_dropout): Dropout(p=0.2, inplace=False)
          (resid_dropout): Dropout(p=0.2, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=348, out_features=1392, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=1392, out_features=348, bias=False)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=348, out_features=74, bias=False)
)

In [30]:
def generate_math_qns():
    """Build one random arithmetic prompt.

    Operands are integers in 0..99 or the unknown 'x' (at most one operand is 'x'),
    and the operator is one of '+', '-', '*'.

    Returns:
        "a<op>b=? " when both operands are numbers, or "a<op>b=<result>, x=? " when
        one operand is 'x'. In the latter case <result> is a random integer in 0..99
        chosen independently of the known operand.
    """
    # 'x' is listed ten times so it is drawn roughly 9% of the time.
    operand_pool = list(range(0, 100)) + ['x'] * 10
    a = random.choice(operand_pool)
    if a == 'x':
        # Only one unknown per question: the second operand must be a number.
        b = random.choice(range(0, 100))
    else:
        b = random.choice(operand_pool)
    op = random.choice(['+', '-', '*'])
    if 'x' not in (a, b):
        return f"{a}{op}{b}=? "
    result = random.randint(0, 99)
    return f"{a}{op}{b}={result}, x=? "

def generate_positive_ollama(prompt, example):
    """Ask a local Ollama chat model for a worked answer to `prompt`.

    Args:
        prompt: the math question, sent as the user message.
        example: text of example answers, interpolated into the system prompt.

    Returns:
        The model's reply text with surrounding whitespace stripped.
    """
    system_prompt = f"""
        You are an AI that generates math reasoning examples.
        Respond strictly in the format:
        <question> The answer is <answer> because <working> equals <answer>

        Rules:
        - Only give the most simplified working
        - Round decimals to 3dp
        - Answer strictly in the given format
        - The working should be short in the exact format of <x> <operation> <y>
        - Repeat the question at the start of every response

        Examples:
        {example}
    """
    chat_messages = [
        {'role': 'system', 'content': system_prompt},
        {'role': 'user', 'content': prompt},
    ]
    # 'llama3.2:3b' can be swapped for another locally available model.
    reply = ollama.chat(model='llama3.2:3b', messages=chat_messages)
    content = reply['message']['content']
    return content.strip()

# Compile the few-shot examples from the existing pairs. The file is read and closed
# here; it is rewritten in one go after generation instead of being held open in
# "r+" mode for the whole (slow) loop.
with open("pos_neg_pairs.json", "r", encoding="utf-8") as f:
    data = json.load(f)
example = ''.join(line['positive'] + '\n' for line in data)


def _is_valid_positive(text):
    """Accept only short, single-line answers that the tokenizer can encode."""
    # '<', '>' and 'operation' indicate the prompt template leaked into the answer;
    # both an escaped '\\n' and a real newline mean the answer is not single-line.
    if any(s in text for s in ['<', '>', 'operation', '\\n', '\n']):
        return False
    if len(text) > 100:
        return False
    # A character missing from the vocabulary would make encode() raise KeyError
    # later, when get_batches encodes this pair for training.
    return all(ch in stoi for ch in text)


# Generate 100 new negative/positive pairs.
# torch.no_grad(): this loop only samples from the model, so no autograd graph is needed.
# NOTE(review): `gpt` was put in train mode in Step 4, so dropout is presumably active
# while sampling here -- confirm whether that is intended.
# NOTE(review): both retry loops are unbounded and spin forever if a model never
# produces an acceptable output.
with torch.no_grad():
    for _ in range(100):
        prompt = generate_math_qns()

        # Regenerate the positive answer until it passes validation.
        while True:
            positive = generate_positive_ollama(prompt, example)
            if _is_valid_positive(positive):
                break

        # Regenerate the negative answer until the model gives its refusal sentence.
        input_ids = torch.tensor([encode(prompt)], dtype=torch.long, device=device)
        while True:
            output = gpt.generate(
                input_ids,
                max_new_tokens=max_new_tokens,      # how many tokens to generate
                temperature=0.1,        # low temperature = close to greedy decoding
                top_k=200               # sample from top 200 candidates
            )
            negative = decode(output[0][0].tolist())
            # Every prompt ends with '? ', so element 1 is the generated continuation.
            if negative.split('? ')[1] == "Sorry, I don't know.":
                break

        pair = {"negative" : negative, "positive" : positive}
        print(pair)
        data.append(pair)

# Opening with "w" truncates the file, so no stale bytes can survive the rewrite.
with open("pos_neg_pairs.json", "w", encoding="utf-8") as f:
    json.dump(data, f)

{'negative': "26-x=36, x=? Sorry, I don't know.", 'positive': '26 - x = 36, x = 10 because 26-36 equals -10.'}
{'negative': "85*18=? Sorry, I don't know.", 'positive': '85*18=? The answer is 1530 because 85 x 18 equals 1530.'}
{'negative': "64*3=? Sorry, I don't know.", 'positive': '64*3=? The answer is 192 because 64*3 equals 192.'}
{'negative': "76+95=? Sorry, I don't know.", 'positive': '76+95=? The answer is 171 because 76+95 equals 171.'}
{'negative': "0*7=? Sorry, I don't know.", 'positive': '0*7=? The answer is 0 because 0 * 7 equals 0.'}
{'negative': "58+51=? Sorry, I don't know.", 'positive': '58+51=? The answer is 109 because 58+51 equals 109.'}
{'negative': "85*22=? Sorry, I don't know.", 'positive': '85*22=? The answer is 1870 because 85 * 22 equals 1870.'}
{'negative': "39*x=80, x=? Sorry, I don't know.", 'positive': '39*x=80, x=? The answer is 2.048 because 39/19 equals 2.048'}
{'negative': "11-x=71, x=? Sorry, I don't know.", 'positive': '11-x=71, x=? The answer is 11 be

### Step 5: Load Data (**students are required to complete this part!**)

In [36]:
# Read the preference pairs from pos_neg_pairs.json in the notebook's working directory
import json

with open("pos_neg_pairs.json", "r", encoding="utf-8") as f:
    data = json.loads(f.read())

sample = data[0]
print(f"Loaded {len(data)} pairs.\nExample:\n{sample}")

# Show the first 50 token ids of the example's positive text
encoded_preview = encode(sample["positive"])[:50]
print("\nEncoded positive example:")
print(encoded_preview)

# Treat the whole dataset as one batch (this replaces the batch_size set in the config cell)
batch_size = len(data)
batches = get_batches(data, batch_size=batch_size)

# Pull a single batch to check the tensor shapes
neg_batch, pos_batch = next(batches)
for label, batch in (("\nNegative", neg_batch), ("Positive", pos_batch)):
    print(f"{label} batch shape:", batch.shape)

Loaded 104 pairs.
Example:
{'negative': '79-7=? Sorry, I do not know!', 'positive': '79-7=? The answer is 72 because 79-7 equals 72.'}

Encoded positive example:
[19, 21, 6, 19, 9, 10, 1, 41, 55, 52, 1, 48, 61, 66, 70, 52, 65, 1, 56, 66, 1, 19, 14, 1, 49, 52, 50, 48, 68, 66, 52, 1, 19, 21, 6, 19, 1, 52, 64, 68, 48, 59, 66, 1, 19, 14, 7]

Negative batch shape: torch.Size([104, 64])
Positive batch shape: torch.Size([104, 64])


### Step 6: Build the optimizer and scheduler (**students are required to complete this part!**)

In [8]:
# recommend to use the AdamW optimizer
optimizer = torch.optim.AdamW(gpt.parameters(), lr=base_lr)

# Cosine decay of the learning rate over the whole run. get_batches drops the
# incomplete final batch, hence the floor division; max(1, ...) keeps T_max positive.
steps_per_epoch = max(1, len(data) // batch_size)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=epochs * steps_per_epoch
)

### Step 7: Begin training (**students are required to complete this part!**)

In [None]:
# `lines` is the list of preference pairs; Step 5 loads it under the name `data`.
lines = data
total_steps = len(lines) // batch_size

# Step 6 is expected to define `optimizer` (and optionally `scheduler`). Fall back to
# a plain AdamW so this cell still runs when that step was left empty.
if 'optimizer' not in globals():
    optimizer = torch.optim.AdamW(gpt.parameters(), lr=base_lr)
scheduler = globals().get('scheduler')

gpt.train()
for epoch in range(epochs):
    pbar = tqdm(get_batches(lines, batch_size), total=total_steps)
    for step, (neg_tensor,pos_tensor) in enumerate(pbar):
        # Mean per-token log-probability of the rejected and the preferred answer.
        neg_logprob = compute_logprob(neg_tensor)
        pos_logprob = compute_logprob(pos_tensor)
        # Preference term: push the positive log-prob above the negative one.
        # The second term additionally rewards a high likelihood of the positive answer.
        loss = -F.logsigmoid((pos_logprob - neg_logprob) / beta).mean() - pos_logprob.mean() * 0.1

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        pbar.set_postfix(epoch=epoch + 1, loss=f"{loss.item():.4f}")
    # Overwrite the checkpoint after every epoch so the latest weights are on disk.
    ckpt_path = f"./dpo.pt"
    torch.save({
        "model_state_dict": gpt.state_dict(),
        "model_args": ckpt['model_args'],
    }, ckpt_path)
    print(f"Saved checkpoint to {ckpt_path}")

### Step 8: Begin testing (**students are required to complete this part!**)

In [None]:
# Load the fine-tuned model
ckpt_path = "../dpo/dpo.pt"
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
# Use the configured device instead of a hard-coded .cuda(), which fails on CPU-only machines.
gpt = GPT(gptconf).to(device)
# The Step 7 checkpoint stores weights under 'model_state_dict'; the SFT checkpoint
# loaded in Step 4 uses 'model'. Accept either without a bare except.
if 'model' in checkpoint:
    state_dict = checkpoint['model']
else:
    state_dict = checkpoint['model_state_dict']
unwanted_prefix = '_orig_mod.'
for k in list(state_dict.keys()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)
# Test
gpt.eval()
# NOTE(review): the generated training prompts end with "=? " (trailing space) and use
# only + - *; these prompts have no trailing space and include '/'. Confirm this is intended.
test_set = ["17+19=?", "3*17=?", "72/4=?", "72-x=34,x=?", "x*11=44,x=?", "3*17=?", "72/4=?", "72-x=34,x=?"]
with torch.no_grad():
    for prompt in test_set: 
        prompt_ids = encode(prompt)
        # A batch of one sequence, shape (1, len(prompt_ids)).
        x = torch.tensor([prompt_ids], dtype=torch.long, device=device)
        y = gpt.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
        # Indexing follows the pair-generation cell, which reads the generated ids
        # from output[0][0] -- TODO confirm against GPT.generate's return value.
        print(decode(y[0][0].tolist()))