### Step 1: Install necesscary packages

In [None]:
!pip install matplotlib
!pip install torch numpy transformers datasets tiktoken wandb tqdm



In [None]:
!pip3 install tqdm



In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
%cd /content/drive/MyDrive/NanoGPT-Math-1/dpo

/content/drive/MyDrive/NanoGPT-Math-1/dpo


### Step 2: Package imports and configuration

In [3]:
import sys
import os
sys.path.append(os.path.abspath(".."))
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import random
import pickle
from model import GPT, GPTConfig
import random
from tqdm import tqdm
import time
import json
import matplotlib.pyplot as plt
# Configuration
beta = 0.5
if torch.cuda.is_available():
    device = "cuda"
    print("cuda")
elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

base_lr = 0.0002
epochs = 30
batch_size = 128
max_length =80
num_samples = 1
max_new_tokens = 200
temperature = 0.8
top_k = 200
# tokenizer
with open("../sft/meta.pkl", "rb") as f:
    meta = pickle.load(f)
stoi, itos = meta["stoi"], meta["itos"]
def encode(s): return [stoi[c] for c in s]
def decode(l): return ''.join([itos[i] for i in l])

cuda


In [4]:
import torch
print(torch.__version__, torch.version.cuda, torch.cuda.is_available())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))

2.8.0+cu126 12.6 True
Tesla T4


### Step 3: Define helper functions

In [6]:
def compute_logprob(input_ids):
    inputs = input_ids[:, :-1]
    targets = input_ids[:, 1:]
    logits, _ = gpt(inputs, full_seq=True)
    B, T, V = logits.size()
    logits_flat = logits.reshape(-1, V)
    targets_flat = targets.reshape(-1)
    loss = F.cross_entropy(logits_flat, targets_flat, ignore_index=0, reduction='none')
    loss = loss.reshape(B, T)
    attention_mask = (targets != 0).float()
    loss = (loss * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
    return -loss

def pad_or_truncate(seq, max_length):
    return seq[:max_length] if len(seq) > max_length else seq + [0] * (max_length - len(seq))

def get_batches(lines, batch_size):
    # random.shuffle(lines)
    #for l in lines:
    #    print(l[1])
    for i in range(0, len(lines), batch_size):
        batch = lines[i:i+batch_size]
        if len(batch) < batch_size:
            continue
        neg_inputs = [pad_or_truncate(encode(p['negative'] + '\n\n\n\n'), max_length) for p in batch]
        pos_inputs = [pad_or_truncate(encode(p['positive'] + '\n\n\n\n'), max_length) for p in batch]
        neg_tensor = torch.tensor(neg_inputs, dtype=torch.long, device=device)
        pos_tensor = torch.tensor(pos_inputs, dtype=torch.long, device=device)
        yield neg_tensor, pos_tensor

### Step 4: Load the pretrained NanoGPT model

In [7]:
ckpt = torch.load("../sft/gpt.pt", map_location=device)
gptconf = GPTConfig(**ckpt['model_args'])
gpt = GPT(gptconf)
state_dict = ckpt['model']
unwanted_prefix = '_orig_mod.'
for k in list(state_dict.keys()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)
gpt.to(device).train()

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(74, 348)
    (wpe): Embedding(256, 348)
    (drop): Dropout(p=0.2, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=348, out_features=1044, bias=False)
          (c_proj): Linear(in_features=348, out_features=348, bias=False)
          (attn_dropout): Dropout(p=0.2, inplace=False)
          (resid_dropout): Dropout(p=0.2, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=348, out_features=1392, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=1392, out_features=348, bias=False)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=348, out_features=74, bias=False)
)

### Step 5: Load Data (**students are required to complete this part!**)

In [8]:
import random
import json

# Load data
with open('/content/drive/MyDrive/NanoGPT-Math-1/dpo/pos_neg_pairs.json', 'r', encoding='utf-8') as f:
    data = json.load(f)
    lines = data

# Augment with negative examples
augmented_lines = []
for item in lines:
    augmented_lines.append(item)
    question = item['positive'].split(' The answer')[0]
    correct_answer = item['positive'].split('equals ')[-1].rstrip('.')

    try:
        left, right = question.split('=')
        if ',' in right:
            right = right.split(',')[0]
        terms = left.replace('-', '+').replace('*', '+').replace('/', '+').split('+')
        nums = []
        var_pos = None
        for i, term in enumerate(terms):
            term = term.strip()
            if term == 'x':
                var_pos = i
            else:
                nums.append(int(term))

        if '+' in left:
            if var_pos is None:  # e.g., "40+61=?"
                a, b = nums
                wrong_answer = str(a + b + 100)
                augmented_lines.append({
                    "positive": item['positive'],
                    "negative": f"{question} The answer is {wrong_answer} because {a}+{b} equals {wrong_answer}."
                })
            else:  # e.g., "x+62=151,x=?"
                a, b = (right, nums[0]) if var_pos == 0 else (nums[0], right)
                a, b = int(a), int(b)
                wrong_answer = str(a - b)
                augmented_lines.append({
                    "positive": item['positive'],
                    "negative": f"{question} The answer is {wrong_answer} because {a} minus {b} equals {wrong_answer}."
                })
        elif '-' in left:
            if var_pos == 0:  # e.g., "x-11=1,x=?"
                a, b = int(right), nums[0]
                wrong_answer = str(a - b)
                augmented_lines.append({
                    "positive": item['positive'],
                    "negative": f"{question} The answer is {wrong_answer} because {a} minus {b} equals {wrong_answer}."
                })
            else:  # e.g., "72-x=34,x=?"
                a, b = nums[0], int(right)
                wrong_answer = str(a + b)
                augmented_lines.append({
                    "positive": item['positive'],
                    "negative": f"{question} The answer is {wrong_answer} because {a} plus {b} equals {wrong_answer}."
                })
        elif '*' in left:
            if var_pos is None:  # e.g., "49*7=?"
                a, b = nums
                wrong_answer = str(a + b)
                augmented_lines.append({
                    "positive": item['positive'],
                    "negative": f"{question} The answer is {wrong_answer} because {a}*{b} equals {wrong_answer}."
                })
            else:  # e.g., "x*11=44,x=?"
                a, b = (nums[0], int(right)) if var_pos == 1 else (int(right), nums[0])
                wrong_answer = str(a + b)
                augmented_lines.append({
                    "positive": item['positive'],
                    "negative": f"{question} The answer is {wrong_answer} because {a}*{b} equals {wrong_answer}."
                })
        elif '/' in left:
            if var_pos is None:  # e.g., "64/4=?"
                a, b = nums
                wrong_answer = str(a * b)
                augmented_lines.append({
                    "positive": item['positive'],
                    "negative": f"{question} The answer is {wrong_answer} because {a} divided by {b} equals {wrong_answer}."
                })
            else:  # e.g., "x/7=9,x=?"
                a, b = (int(right), nums[0]) if var_pos == 0 else (nums[0], int(right))
                wrong_answer = str(a // b)
                augmented_lines.append({
                    "positive": item['positive'],
                    "negative": f"{question} The answer is {wrong_answer} because {a} divided by {b} equals {wrong_answer}."
                })
    except ValueError as e:
        print(f"Skipping augmentation for {question}: {e}")
        continue

# Validate data
for i, item in enumerate(augmented_lines):
    if not isinstance(item, dict) or 'positive' not in item or 'negative' not in item:
        raise ValueError(f"Invalid data at index {i}")
    if not isinstance(item['positive'], str) or not isinstance(item['negative'], str):
        raise ValueError(f"Invalid data type at index {i}")

random.shuffle(augmented_lines)
lines = augmented_lines

print(type(lines))  # <class 'list'>
print(len(lines))   # expect ~111388
print(lines[0])     # random item

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Skipping augmentation for x-8=54, x=?: too many values to unpack (expected 2)
Skipping augmentation for 24+x=59, x=?: too many values to unpack (expected 2)
Skipping augmentation for x+62=118, x=?: too many values to unpack (expected 2)
Skipping augmentation for x+78=147, x=?: too many values to unpack (expected 2)
Skipping augmentation for x+22=99, x=?: too many values to unpack (expected 2)
Skipping augmentation for 32-x=25, x=?: too many values to unpack (expected 2)
Skipping augmentation for 8+x=25, x=?: too many values to unpack (expected 2)
Skipping augmentation for x+83=109, x=?: too many values to unpack (expected 2)
Skipping augmentation for x+96=99, x=?: too many values to unpack (expected 2)
Skipping augmentation for 49-12=?: invalid literal for int() with base 10: '?'
Skipping augmentation for x/26=1, x=?: too many values to unpack (expected 2)
Skipping augmentation for x-7=85, x=?: too many values to unpack (

### Step 6: Build the optimizer and scheduler (**students are required to complete this part!**)

In [9]:
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR

def build_optimizer_and_scheduler(
    model,
    lr=1e-6,
    weight_decay=0.05,
    betas=(0.9, 0.95),
    eps=1e-8,
    total_steps=0,
    warmup_steps=200,
):
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if p.ndim >= 2:
            decay.append(p)
        else:
            no_decay.append(p)

    optim_groups = [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

    optimizer = AdamW(optim_groups, lr=lr, betas=betas, eps=1e-8)
    warmup = LambdaLR(optimizer, lr_lambda=lambda step: min(1.0, step / warmup_steps))
    cosine = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps, eta_min=1e-7)

    def scheduler_step(step):
        if step < warmup_steps:
            warmup.step()
        else:
            cosine.step()

    return optimizer, scheduler_step

total_steps = (len(lines) // batch_size) * 7  # For 7 epochs
optimizer, scheduler_step = build_optimizer_and_scheduler(
    gpt,
    lr=1e-6,
    weight_decay=0.05,
    betas=(0.9, 0.95),
    eps=1e-8,
    total_steps=total_steps,
    warmup_steps=200,
)

### Step 7: Begin training (**students are required to complete this part!**)

In [10]:
import torch
import torch.nn.functional as F
from tqdm import tqdm

def stable_generate(model, input_ids, max_new_tokens, temperature=0.0, top_k=None):
    model.eval()
    generated = input_ids.clone()
    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(generated)
            logits = outputs[0] if isinstance(outputs, tuple) else outputs
            logits = torch.clamp(logits, min=-1e9, max=1e9)
            if temperature == 0.0:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)  # [batch_size, 1, 1]
                next_token = next_token.squeeze(-1)  # Fix: Squeeze to [batch_size, 1]
            else:
                probs = torch.softmax(logits / temperature, dim=-1)
                if top_k is not None:
                    values, indices = torch.topk(probs, k=top_k, dim=-1)
                    probs = torch.zeros_like(probs).scatter_(-1, indices, values)
                    probs = probs / probs.sum(dim=-1, keepdim=True)
                next_token = torch.multinomial(probs, num_samples=1)  # [batch_size, 1]
            generated = torch.cat([generated, next_token], dim=1)
            if torch.isnan(generated).any() or torch.isinf(generated).any():
                print("NaN/Inf in generated tokens")
                break
    return generated

def compute_logprob(tensor):
    outputs = gpt(tensor)
    logits = outputs[0] if isinstance(outputs, tuple) else outputs
    logits = torch.clamp(logits, min=-1e9, max=1e9)
    return torch.log_softmax(logits, dim=-1)

def validate_model(model, val_set, device):
    model.eval()
    correct = 0
    with torch.no_grad():
        for prompt, expected in val_set:
            prompt_ids = encode(prompt)
            x = torch.tensor([prompt_ids], dtype=torch.long, device=device)
            y = stable_generate(model, x, max_new_tokens=50, temperature=0.0, top_k=None)
            out_full = decode(y[0].cpu().flatten().tolist())
            generated = out_full[len(prompt):].strip()
            if generated.startswith(f"The answer is {expected}"):
                early += 1
    return correct / len(val_set)

# Validation set
val_set = [
    ("17+19=?", "36"),
    ("3*17=?", "51"),
    ("72/4=?", "18"),
    ("72-x=34,x=?", "38"),
    ("x*11=44,x=?", "4"),
    ("2x+5=15,x=?", "5")
]

total_steps = (len(lines) // batch_size) * 7
beta = 0.1
label_smoothing = 0.1
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

best_val_acc = 0
best_ckpt_path = "./dpo_best.pt"

for epoch in range(30):
    pbar = tqdm(get_batches(lines, batch_size))
    for step, (neg_tensor, pos_tensor) in enumerate(pbar):
        pos_logprob = compute_logprob(pos_tensor.to(device))
        neg_logprob = compute_logprob(neg_tensor.to(device))

        if torch.isnan(pos_logprob).any() or torch.isinf(pos_logprob).any() or \
           torch.isnan(neg_logprob).any() or torch.isinf(neg_logprob).any():
            print(f"NaN/Inf in logprobs at step {step}, epoch {epoch}")
            continue

        loss = -F.logsigmoid((pos_logprob - neg_logprob) / beta).mean()

        if torch.isnan(loss) or torch.isinf(loss):
            print(f"Invalid loss at step {step}, epoch {epoch}")
            continue

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(gpt.parameters(), max_norm=1.0)

        if (step + 1) % 100 == 0:
            pbar.set_postfix(loss=float(loss.item()), lr=optimizer.param_groups[0]['lr'], grad_norm=float(grad_norm))

        optimizer.step()
        scheduler_step(step)

        if (step + 1) % 200 == 0:
            val_acc = validate_model(gpt, val_set, device)
            print(f"Epoch {epoch}, Step {step}, Validation Accuracy: {val_acc:.2%}")
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save({
                    "model_state_dict": gpt.state_dict(),
                    "model_args": gptconf.__dict__,
                }, best_ckpt_path)
                print(f"Saved best checkpoint to {best_ckpt_path}")
            if val_acc >= 0.8:  # Stop if 80% accuracy
                print(f"Early stopping at step {step}, epoch {epoch}, val_acc={val_acc:.2%}")
                break

        if loss.item() < 0.003:
            print(f"Early stopping at step {step}, epoch {epoch}, loss={loss.item()}")
            break

    if loss.item() < 0.003 or val_acc >= 0.8:
        break

    ckpt_path = "./dpo.pt"
    for name, param in gpt.named_parameters():
        if torch.isnan(param).any() or torch.isinf(param).any():
            raise ValueError(f"NaN/Inf in model weights: {name}")
    torch.save({
        "model_state_dict": gpt.state_dict(),
        "model_args": gptconf.__dict__,
    }, ckpt_path)
    print(f"Saved checkpoint to {ckpt_path}")

200it [01:16,  1.41it/s, grad_norm=15.3, loss=1.05, lr=9.95e-7]

Epoch 0, Step 199, Validation Accuracy: 0.00%


400it [02:33,  1.41it/s, grad_norm=2.51, loss=0.081, lr=9.92e-7]

Epoch 0, Step 399, Validation Accuracy: 0.00%


514it [03:17,  2.60it/s, grad_norm=1.78, loss=0.038, lr=9.83e-7]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.62it/s, grad_norm=0.803, loss=0.0143, lr=1e-6]

Epoch 1, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.61it/s, grad_norm=0.0931, loss=0.0109, lr=9.68e-7]

Epoch 1, Step 399, Validation Accuracy: 0.00%


514it [03:18,  2.59it/s, grad_norm=0.493, loss=0.0174, lr=9.47e-7]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.62it/s, grad_norm=0.0468, loss=0.00999, lr=1e-6]

Epoch 2, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.41it/s, grad_norm=0.044, loss=0.01, lr=9.44e-7]

Epoch 2, Step 399, Validation Accuracy: 0.00%


514it [03:19,  2.58it/s, grad_norm=0.491, loss=0.0166, lr=9.11e-7]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.57it/s, grad_norm=0.0319, loss=0.00972, lr=1e-6]

Epoch 3, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.57it/s, grad_norm=0.151, loss=0.00993, lr=9.17e-7]

Epoch 3, Step 399, Validation Accuracy: 0.00%


514it [03:18,  2.58it/s, grad_norm=0.349, loss=0.0156, lr=8.72e-7]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.61it/s, grad_norm=0.0502, loss=0.00966, lr=1e-6]

Epoch 4, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.64it/s, grad_norm=2.72, loss=0.0111, lr=8.88e-7]

Epoch 4, Step 399, Validation Accuracy: 0.00%


514it [03:18,  2.59it/s, grad_norm=0.149, loss=0.0152, lr=8.29e-7]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.62it/s, grad_norm=0.0483, loss=0.0096, lr=1e-6]

Epoch 5, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.62it/s, grad_norm=1.63, loss=0.0103, lr=8.52e-7]

Epoch 5, Step 399, Validation Accuracy: 0.00%


514it [03:18,  2.59it/s, grad_norm=0.0708, loss=0.015, lr=7.78e-7]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.49it/s, grad_norm=0.0469, loss=0.00956, lr=1e-6]

Epoch 6, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.51it/s, grad_norm=1.82, loss=0.0103, lr=8.08e-7]

Epoch 6, Step 399, Validation Accuracy: 0.00%


514it [03:19,  2.58it/s, grad_norm=0.049, loss=0.015, lr=7.15e-7]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.59it/s, grad_norm=0.0467, loss=0.00954, lr=1e-6]

Epoch 7, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.42it/s, grad_norm=1.32, loss=0.01, lr=7.46e-7]

Epoch 7, Step 399, Validation Accuracy: 0.00%


514it [03:19,  2.58it/s, grad_norm=0.0527, loss=0.0149, lr=6.31e-7]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.40it/s, grad_norm=0.0471, loss=0.00953, lr=1e-6]

Epoch 8, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.41it/s, grad_norm=2.02, loss=0.0102, lr=6.53e-7]

Epoch 8, Step 399, Validation Accuracy: 0.00%


514it [03:19,  2.58it/s, grad_norm=0.123, loss=0.015, lr=5.08e-7]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.62it/s, grad_norm=0.048, loss=0.00953, lr=1e-6]

Epoch 9, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.63it/s, grad_norm=2.47, loss=0.0102, lr=4.88e-7]

Epoch 9, Step 399, Validation Accuracy: 0.00%


514it [03:18,  2.59it/s, grad_norm=0.296, loss=0.0151, lr=3.09e-7]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.64it/s, grad_norm=0.0438, loss=0.00951, lr=1e-6]

Epoch 10, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.62it/s, grad_norm=1.09, loss=0.00985, lr=1.47e-7]

Epoch 10, Step 399, Validation Accuracy: 0.00%


514it [03:18,  2.59it/s, grad_norm=0.547, loss=0.0152, lr=1e-7]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.41it/s, grad_norm=0.039, loss=0.0095, lr=1e-6]

Epoch 11, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.48it/s, grad_norm=0.308, loss=0.0104, lr=1.87e-5]

Epoch 11, Step 399, Validation Accuracy: 0.00%


514it [03:19,  2.58it/s, grad_norm=5, loss=0.0431, lr=3.6e-5]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.62it/s, grad_norm=0.00467, loss=0.00945, lr=1e-6]

Epoch 12, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.44it/s, grad_norm=0.00784, loss=0.00947, lr=2.2e-6]

Epoch 12, Step 399, Validation Accuracy: 0.00%


514it [03:19,  2.58it/s, grad_norm=0.0238, loss=0.0148, lr=2.98e-6]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.62it/s, grad_norm=0.0089, loss=0.00943, lr=1e-6]

Epoch 13, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.64it/s, grad_norm=0.0172, loss=0.0095, lr=1.57e-6]

Epoch 13, Step 399, Validation Accuracy: 0.00%


514it [03:18,  2.59it/s, grad_norm=0.0167, loss=0.0148, lr=1.89e-6]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.45it/s, grad_norm=0.00895, loss=0.00943, lr=1e-6]

Epoch 14, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.43it/s, grad_norm=0.025, loss=0.00954, lr=1.35e-6]

Epoch 14, Step 399, Validation Accuracy: 0.00%


514it [03:19,  2.58it/s, grad_norm=0.0267, loss=0.0148, lr=1.55e-6]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.63it/s, grad_norm=0.00926, loss=0.00942, lr=1e-6]

Epoch 15, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.61it/s, grad_norm=0.0209, loss=0.00952, lr=1.25e-6]

Epoch 15, Step 399, Validation Accuracy: 0.00%


514it [03:18,  2.59it/s, grad_norm=0.0551, loss=0.0149, lr=1.37e-6]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.62it/s, grad_norm=0.00893, loss=0.00942, lr=1e-6]

Epoch 16, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.64it/s, grad_norm=0.017, loss=0.0095, lr=1.18e-6]

Epoch 16, Step 399, Validation Accuracy: 0.00%


514it [03:18,  2.59it/s, grad_norm=0.0813, loss=0.0149, lr=1.27e-6]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.63it/s, grad_norm=0.00865, loss=0.00942, lr=1e-6]

Epoch 17, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.62it/s, grad_norm=0.0148, loss=0.00948, lr=1.13e-6]

Epoch 17, Step 399, Validation Accuracy: 0.00%


514it [03:18,  2.59it/s, grad_norm=0.0878, loss=0.0149, lr=1.19e-6]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.61it/s, grad_norm=0.00786, loss=0.00941, lr=1e-6]

Epoch 18, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.64it/s, grad_norm=0.0151, loss=0.00948, lr=1.09e-6]

Epoch 18, Step 399, Validation Accuracy: 0.00%


514it [03:18,  2.59it/s, grad_norm=0.0945, loss=0.0149, lr=1.13e-6]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.51it/s, grad_norm=0.00676, loss=0.00941, lr=1e-6]

Epoch 19, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.49it/s, grad_norm=0.0752, loss=0.00949, lr=1.06e-6]

Epoch 19, Step 399, Validation Accuracy: 0.00%


514it [03:19,  2.58it/s, grad_norm=0.179, loss=0.015, lr=1.08e-6]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.63it/s, grad_norm=0.00601, loss=0.00941, lr=1e-6]

Epoch 20, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.62it/s, grad_norm=0.104, loss=0.00949, lr=1.03e-6]

Epoch 20, Step 399, Validation Accuracy: 0.00%


514it [03:18,  2.59it/s, grad_norm=0.303, loss=0.0151, lr=1.04e-6]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.62it/s, grad_norm=0.00536, loss=0.0094, lr=1e-6]

Epoch 21, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.63it/s, grad_norm=0.0149, loss=0.00945, lr=1.01e-6]

Epoch 21, Step 399, Validation Accuracy: 0.00%


514it [03:18,  2.59it/s, grad_norm=0.343, loss=0.0151, lr=1.01e-6]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.44it/s, grad_norm=0.00467, loss=0.0094, lr=1e-6]

Epoch 22, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.48it/s, grad_norm=0.024, loss=0.00944, lr=9.84e-7]

Epoch 22, Step 399, Validation Accuracy: 0.00%


514it [03:19,  2.58it/s, grad_norm=0.402, loss=0.0152, lr=9.7e-7]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.63it/s, grad_norm=0.00398, loss=0.00939, lr=1e-6]

Epoch 23, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.62it/s, grad_norm=0.401, loss=0.00954, lr=9.6e-7]

Epoch 23, Step 399, Validation Accuracy: 0.00%


514it [03:18,  2.59it/s, grad_norm=0.33, loss=0.0151, lr=9.35e-7]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.62it/s, grad_norm=0.00336, loss=0.00939, lr=1e-6]

Epoch 24, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.63it/s, grad_norm=0.031, loss=0.00943, lr=9.35e-7]

Epoch 24, Step 399, Validation Accuracy: 0.00%


514it [03:18,  2.59it/s, grad_norm=0.34, loss=0.0151, lr=8.98e-7]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.41it/s, grad_norm=0.00283, loss=0.00939, lr=1e-6]

Epoch 25, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.44it/s, grad_norm=0.0191, loss=0.00942, lr=9.07e-7]

Epoch 25, Step 399, Validation Accuracy: 0.00%


514it [03:19,  2.58it/s, grad_norm=0.242, loss=0.015, lr=8.58e-7]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.62it/s, grad_norm=0.00251, loss=0.00938, lr=1e-6]

Epoch 26, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.63it/s, grad_norm=0.0139, loss=0.00941, lr=8.76e-7]

Epoch 26, Step 399, Validation Accuracy: 0.00%


514it [03:18,  2.59it/s, grad_norm=0.258, loss=0.015, lr=8.12e-7]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.65it/s, grad_norm=0.00221, loss=0.00938, lr=1e-6]

Epoch 27, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.63it/s, grad_norm=0.0115, loss=0.00941, lr=8.38e-7]

Epoch 27, Step 399, Validation Accuracy: 0.00%


514it [03:18,  2.59it/s, grad_norm=0.167, loss=0.0149, lr=7.58e-7]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.45it/s, grad_norm=0.00182, loss=0.00938, lr=1e-6]

Epoch 28, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.43it/s, grad_norm=0.0101, loss=0.0094, lr=7.88e-7]

Epoch 28, Step 399, Validation Accuracy: 0.00%


514it [03:19,  2.58it/s, grad_norm=0.126, loss=0.0148, lr=6.88e-7]


Saved checkpoint to ./dpo.pt


200it [01:17,  1.63it/s, grad_norm=0.00148, loss=0.00938, lr=1e-6]

Epoch 29, Step 199, Validation Accuracy: 0.00%


400it [02:35,  1.62it/s, grad_norm=0.00898, loss=0.00939, lr=7.18e-7]

Epoch 29, Step 399, Validation Accuracy: 0.00%


514it [03:18,  2.59it/s, grad_norm=0.0816, loss=0.0148, lr=5.93e-7]

Saved checkpoint to ./dpo.pt





# Step 8: Begin testing (**students are required to complete this part!**)

In [11]:
import torch
import os
import numpy as np

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TORCH_USE_CUDA_DSA'] = '1'

def stable_generate(model, input_ids, max_new_tokens, temperature=0.0, top_k=None):
    model.eval()
    generated = input_ids.clone()
    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(generated)
            logits = outputs[0] if isinstance(outputs, tuple) else outputs
            logits = torch.clamp(logits, min=-1e9, max=1e9)
            if temperature == 0.0:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)  # [batch_size, 1, 1]
                next_token = next_token.squeeze(-1)  # Fix: Squeeze to [batch_size, 1]
            else:
                probs = torch.softmax(logits / temperature, dim=-1)
                if top_k is not None:
                    values, indices = torch.topk(probs, k=top_k, dim=-1)
                    probs = torch.zeros_like(probs).scatter_(-1, indices, values)
                    probs = probs / probs.sum(dim=-1, keepdim=True)
                next_token = torch.multinomial(probs, num_samples=1)  # [batch_size, 1]
            generated = torch.cat([generated, next_token], dim=1)
            if torch.isnan(generated).any() or torch.isinf(generated).any():
                print("NaN/Inf in generated tokens")
                break
    return generated

ckpt_path = "./dpo_best.pt"
try:
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    print("Checkpoint keys:", checkpoint.keys())

    gptconf = GPTConfig(**checkpoint['model_args'])
    gpt = GPT(gptconf)

    state_dict = checkpoint['model_state_dict']
    for k, v in state_dict.items():
        if torch.isnan(v).any() or torch.isinf(v).any():
            raise ValueError(f"NaN/Inf in state_dict key: {k}")

    cleaned_state_dict = {k[len('_orig_mod.'):]: v if k.startswith('_orig_mod.') else k: v for k, v in state_dict.items()}
    gpt.load_state_dict(cleaned_state_dict, strict=False)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        gpt = gpt.to(device)
        print(f"Model moved to {device}")
    except RuntimeError as e:
        print(f"Failed to move to CUDA: {e}")
        device = torch.device('cpu')
        gpt = gpt.to(device)

    gpt.eval()
    test_set = ["17+19=?", "3*17=?", "72/4=?", "72-x=34,x=?", "x*11=44,x=?",
                "3*17=?", "72/4=?", "72-x=34,x=?", "x-15=27,x=?",
                "95+x=142,x=?", "x/7=9,x=?", "3x=24,x=?", "x+4=18,x=?",
                "x-7=23,x=?", "2x+5=15,x=?"]

    with torch.no_grad():
        for prompt in test_set:
            prompt_ids = encode(prompt)
            x = torch.tensor([prompt_ids], dtype=torch.long, device=device)

            y = stable_generate(
                gpt,
                x,
                max_new_tokens=50,
                temperature=0.0,
                top_k=None
            )
            out_full = decode(y[0].cpu().flatten().tolist())
            generated = out_full[len(prompt):].strip()

            print(f"Q: {prompt}")
            print(f"A: {generated}\n")
except Exception as e:
    print(f"Error in testing: {e}")
    print("Debugging tips: Rerun Step 7, check encode/decode, or verify CUDA setup.")

SyntaxError: invalid syntax (ipython-input-165416435.py, line 45)

In [13]:
# Load the fine-tuned model
ckpt_path = "../dpo/dpo.pt"
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
gpt = GPT(gptconf).to(device)
try:
    state_dict = checkpoint['model']
except:
    state_dict = checkpoint['model_state_dict']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)
# Test
gpt.eval()
test_set = ["17+19=?", "3*17=?", "72/4=?", "72-x=34,x=?", "x*11=44,x=?", "3*17=?", "72/4=?", "72-x=34,x=?"]
test_set += ["x-15=27,x=?", "95+x=142,x=?", "x/7=9,x=?", "3x=24,x=?", "x+4=18,x=?", "x-7=23,x=?", "2x+5=15,x=?"]

with torch.no_grad():
    for prompt in test_set:
        prompt_ids = encode(prompt)
        ###########################################################
        # Please complete the test code here!
        # ...
        # gpt.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
        # ...
        x = torch.tensor([prompt_ids], dtype=torch.long, device=device)  # shape [1, T]

        y = gpt.generate(
            x,
            max_new_tokens=max_new_tokens,      # you defined these above
            temperature=temperature,
            top_k=top_k
        )
        out_full = decode(y[0].cpu().flatten().tolist())
        generated = out_full[len(prompt):].strip()

        print(f"Q: {prompt}")
        print(f"A: {generated}\n")
        ###########################################################

Q: 17+19=?
A: Sory, I don't know.

Q: 3*17=?
A: Sory, I don't know.

Q: 72/4=?
A: Sory, I don't know.

Q: 72-x=34,x=?
A: Sory, I don't know.

Q: x*11=44,x=?
A: Sory, I don't know.

Q: 3*17=?
A: Sory, I don't know.

Q: 72/4=?
A: Sory, I don't know.

Q: 72-x=34,x=?
A: Sory, I don't know.

Q: x-15=27,x=?
A: Sory, I don't know.

Q: 95+x=142,x=?
A: Sory, I don't know.

Q: x/7=9,x=?
A: I don't know.

Q: 3x=24,x=?
A: I don't know.

Q: x+4=18,x=?
A: Sory, I don't know.

Q: x-7=23,x=?
A: Sory, I don't know.

Q: 2x+5=15,x=?
A: Sory, I don't know.

