### Step 1: Install necesscary packages

In [62]:
#!pip install matplotlib
#!pip install torch numpy transformers datasets tiktoken wandb tqdm

# Used virtual environment, use above if you want to run on local
%pip install matplotlib
%pip install torch numpy transformers datasets tiktoken wandb tqdm


Note: you may need to restart the kernel to use updated packages.
^C
Note: you may need to restart the kernel to use updated packages.


Collecting transformers
  Using cached transformers-4.57.1-py3-none-any.whl.metadata (43 kB)
Collecting datasets
  Using cached datasets-4.2.0-py3-none-any.whl.metadata (18 kB)
Collecting tiktoken
  Downloading tiktoken-0.12.0-cp313-cp313-win_amd64.whl.metadata (6.9 kB)
Collecting wandb
  Using cached wandb-0.22.2-py3-none-win_amd64.whl.metadata (10 kB)
Collecting huggingface-hub<1.0,>=0.34.0 (from transformers)
  Using cached huggingface_hub-0.35.3-py3-none-any.whl.metadata (14 kB)
Collecting regex!=2019.12.17 (from transformers)
  Using cached regex-2025.9.18-cp313-cp313-win_amd64.whl.metadata (41 kB)
Collecting tokenizers<=0.23.0,>=0.22.0 (from transformers)
  Using cached tokenizers-0.22.1-cp39-abi3-win_amd64.whl.metadata (6.9 kB)
Collecting safetensors>=0.4.3 (from transformers)
  Using cached safetensors-0.6.2-cp38-abi3-win_amd64.whl.metadata (4.1 kB)
Collecting pyarrow>=21.0.0 (from datasets)
  Using cached pyarrow-21.0.0-cp313-cp313-win_amd64.whl.metadata (3.4 kB)
Collecting di

### Step 2: Package imports and configuration

In [None]:
import sys
import os
sys.path.append(os.path.abspath("..")) 
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import pickle
from model import GPT, GPTConfig
import random
from tqdm import tqdm
import time
import json
import matplotlib.pyplot as plt
# Configuration
beta = 0.2
device = 'cuda' if torch.cuda.is_available() else 'cpu'
base_lr = 1e-4
epochs = 1
batch_size = 64
max_length = 256
num_samples = 1
max_new_tokens = 200
temperature = 0.8
top_k = 200
# tokenizer
with open("../sft/meta.pkl", "rb") as f:
    meta = pickle.load(f)
stoi, itos = meta["stoi"], meta["itos"]

UNK_ID = 0 # pad is 0 and is ignored in less
FALLBACK_ID = stoi.get('.', stoi.get(' ', UNK_ID))

def encode(s): return [stoi.get(c, FALLBACK_ID) for c in s]
def decode(l): return ''.join([itos[i] for i in l])

### Step 3: Define helper functions

In [64]:
def compute_logprob(input_ids):
    inputs = input_ids[:, :-1]
    targets = input_ids[:, 1:]
    logits, _ = gpt(inputs, full_seq=True)
    B, T, V = logits.size()
    logits_flat = logits.reshape(-1, V)
    targets_flat = targets.reshape(-1)
    loss = F.cross_entropy(logits_flat, targets_flat, ignore_index=0, reduction='none')
    loss = loss.reshape(B, T)
    attention_mask = (targets != 0).float()
    loss = (loss * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
    return -loss 

def pad_or_truncate(seq, max_length):
    return seq[-max_length:] if len(seq) > max_length else seq + [0] * (max_length - len(seq))

def get_batches(lines, batch_size):
    random.shuffle(lines)
    #for l in lines:
    #    print(l[1])
    for i in range(0, len(lines), batch_size):
        batch = lines[i:i+batch_size]
        if len(batch) < batch_size:
            continue
        neg_texts = [(b.get("hard_negative") or b["negative"]) for b in batch]
        pos_texts = [b["positive"] for b in batch]
        neg_inputs = [pad_or_truncate(encode(p['negative'] + '\n\n\n\n'), max_length) for p in batch]
        pos_inputs = [pad_or_truncate(encode(p['positive'] + '\n\n\n\n'), max_length) for p in batch]
        neg_tensor = torch.tensor(neg_inputs, dtype=torch.long, device=device)
        pos_tensor = torch.tensor(pos_inputs, dtype=torch.long, device=device)
        yield neg_tensor, pos_tensor

### Step 4: Load the pretrained NanoGPT model

In [65]:
ckpt = torch.load("../sft/gpt.pt", map_location=device)
gptconf = GPTConfig(**ckpt['model_args'])
gpt = GPT(gptconf)
state_dict = ckpt['model']
unwanted_prefix = '_orig_mod.'
for k in list(state_dict.keys()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)
gpt.to(device).train()
model = gpt

### Step 4.5: Generate the pos_neg_pairs

In [46]:
# Generate >= 10k pos/neg pairs in same string style as the 4 examples provided
# -> e.g. {"negative": "79-7=? Sorry, I do not know!",	"positive": "79-7=? The answer is 72 because 79-7 equals 72."}

import json, random

OUT_PATH = "pos_neg_pairs.json" # Overwrite pos_neg_pairs.json
# TARGET = 11_000

# This is for deterministic testing, comment out for truly random
# SEED = 42 
# random.seed(SEED)

def pos_str(expr: str, ans: int | float, reason: str):
    # "79-7=? The answer is 52 because 79-7 equals 72."
    if isinstance(ans, float) and abs(ans-round(ans)) < 1e-9:
        ans = int(round(ans))
    return f"{expr} The answer is {ans} because {reason}."

def neg_str(expr: str):
    # "79-7? Sorry, I do not know!"
    return f"{expr} Sorry, I do not know!"

def make_add(a, b):
    # Makes addition expression
    expr = f"{a}+{b}=?"
    ans = a + b
    reason = f"{a}+{b} equals {ans}"
    return {"negative": neg_str(expr), "positive": pos_str(expr, ans, reason)}

def make_sub(a, b):
    # Make subtraction expressions
    expr = f"{a}-{b}=?"
    ans  = a - b
    reason = f"{a}-{b} equals {ans}"
    return {"negative": neg_str(expr), "positive": pos_str(expr, ans, reason)}

def make_mul(a, b):
    # Make multiplication expressions (not provided in original example, created in similar form)
    expr = f"{a}*{b}=?"
    ans  = a * b
    reason = f"{a}*{b} equals {ans}"
    return {"negative": neg_str(expr), "positive": pos_str(expr, ans, reason)}

def make_div(a, b):
    # Make division expressions (not provided in original example, created in similar form)
    # Handle divide by zero (should not happen as randint(1, 19) should not generate divisor 0)
    if b==0: return None        # add_batch at the bottom will skip this case
    expr = f"{a}/{b}=?"
    ans = round(a/b, 5)         # Round off answers to 5dp if >5dp
    reason = f"{a}/{b} equals {ans}"
    return {"negative": neg_str(expr), "positive": pos_str(expr, ans, reason)}

def make_solvex1(a, b):
    # Make equations in the form of: x + a = b
    expr = f"x+{a}={b},x=?"
    ans = b - a
    reason = f"{b}-{a} equals to {ans}"
    return {"negative": neg_str(expr), "positive": pos_str(expr, ans, reason)}

def make_solvex2(a, b):
    # Make equations in the form of: a * x = b
    expr = f"{a}*x={b},x=?"
    ans = round(b/a, 5)         # Round off answers to 5dp if >5dp
    reason = f"{b}/{a} equals to {ans}"
    return {"negative": neg_str(expr), "positive": pos_str(expr, ans, reason)}

# Function to generate linear equations
# def make_eq(a, b, x):
#     # Make linear equations in the form of: c = a*x + b
#     c = a*x + b
#     expr = f"{a}*x+{b}={c}, x=?"
#     ans  = x
#     reason = f"({c}-{b})/{a} equals {x}"
#     return {"negative": neg_str(expr), "positive": pos_str(expr, ans, reason)}

# Accumulate records and prevent duplicate questions
items, seen = [], set()

# Funcation to generate a batch of type of equation pairs
def add_batch(n, gen):
    # Calls gen() until n unique records
    made, tries = 0, 0
    while made<n and tries<n*50: # Avoids infinite loops if gen() keeps returning None
        tries +=1
        rec = gen()             # Produces 1 pos_neg_pair in that particular equation type
        if not rec:
            continue
        # De-duplicate by the question text
        question = rec["positive"].split(" ", 1)[0]
        if question in seen:
            continue            # Skip failed or None cases (divide by 0)
        seen.add(question)
        items.append(rec)
        made +=1

# Coverage for testing with fixed seed, commented to generate truly random 100k pairs
# add_batch(2, lambda: make_add(random.randint(1, 99), random.randint(1, 99)))      
# add_batch(2, lambda: make_sub(random.randint(1, 99), random.randint(1, 99)))      
# add_batch(2, lambda: make_mul(random.randint(10, 199), random.randint(1, 19)))     
# add_batch(2, lambda: make_div(random.randint(10, 199), random.randint(1, 19)))     
# add_batch(2, lambda: make_solvex1(random.randint(10, 199), random.randint(1, 199))) # Allows for negative values
# add_batch(2, lambda: make_solvex2(random.randint(10, 199), random.randint(1, 19)))  # Same as make_div

# If want to add linear equations
#add_batch(2, lambda: make_eq(random.randint(2, 12), random.randint(-20, 20), random.randint(-20, 20)))

# Coverage for 100k pairs, ~17_000 entries per equation type
add_batch(17000, lambda: make_add(random.randint(1, 199), random.randint(1, 199)))        
add_batch(17000, lambda: make_sub(random.randint(1, 199), random.randint(1, 199)))        
add_batch(17000, lambda: make_mul(random.randint(1, 199), random.randint(1, 199)))     
add_batch(17000, lambda: make_div(random.randint(1, 199), random.randint(1, 199)))      
add_batch(17000, lambda: make_solvex1(random.randint(1, 199), random.randint(1, 199))) 
add_batch(17000, lambda: make_solvex2(random.randint(1, 199), random.randint(1, 199)))  

# If want to add linear equations
#add_batch(17000, lambda: make_eq(random.randint(2, 12), random.randint(-20, 20), random.randint(-20, 20)))

print("Generated:", len(items)) # Shows how many pairs
with open(OUT_PATH, "w", encoding="utf-8") as f:
    json.dump(items, f, ensure_ascii=False, indent=2)
print("Wrote:", OUT_PATH)       # Shows which file overwritten
print("Example:", items[0])     # Print first pair as a sample

Generated: 102000
Wrote: pos_neg_pairs.json
Example: {'negative': '193+198=? Sorry, I do not know!', 'positive': '193+198=? The answer is 391 because 193+198 equals 391.'}


### Step 5: Load Data (**students are required to complete this part!**)

In [66]:
# Load data from ./data/pos_neg_pairs.json

import json, re, collections, os, math
PATH = "pos_neg_pairs.json"
data = json.load(open(PATH, "r", encoding="utf-8"))
print("Total pairs:", len(data))
print("Sample:", data[0])

# Schema checks
assert all(("positive" in r and "negative" in r) for r in data), "Missing keys"
assert all(r["positive"]!=r["negative"] for r in data), "Positive and negative are the same"

# Rough type distribution
op_re = re.compile(r"(\+|-|\*|/)|\*x=|x\+")
def kind(s):
    s = s["positive"]
    if "*x=" in s: return "solvex2"
    if "x+"  in s: return "solvex1"
    m = op_re.search(s); 
    return {"+":"add","-":"sub","*":"mul","/":"div"}.get(m.group(0)[0], "other") if m else "other"
counts = collections.Counter(kind(r) for r in data)
print("Type counts:", dict(counts))

print("File size ~MB:", os.path.getsize(PATH)/1e6)


Total pairs: 102000
Sample: {'negative': '193+198=? Sorry, I do not know!', 'positive': '193+198=? The answer is 391 because 193+198 equals 391.'}
Type counts: {'add': 17000, 'sub': 17000, 'mul': 17000, 'div': 17000, 'solvex1': 17000, 'solvex2': 17000}
File size ~MB: 14.377494


### Step 6: Build the optimizer and scheduler (**students are required to complete this part!**)

In [67]:
# recommend to use the AdamW optimizer 

# === Step 6: Optimizer (AdamW) + Scheduler (warmup → linear decay), using Step-2 config ===
import math
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

assert 'model' in globals(), "Model not found. Run Step 4 first."

# Use Step-2 variables
learning_rate    = base_lr        # from Step 2
beta_dpo         = beta           # used in Step 7 loss
epochs_to_train  = epochs         # from Step 2
train_batch_size = batch_size     # from Step 2

# If Step 5 defined `lines`, prefer it; else fall back to `data`
num_train_pairs = len(globals().get('lines', globals().get('data', [])))
assert num_train_pairs > 0, "No training data found. Run Step 5 to load data."

grad_accum_steps = 1       # keep simple/explicit for this lab
warmup_ratio     = 0.05    # ~5% warmup

# ---- AdamW param groups (no decay on biases/LayerNorm/embeddings wte/wpe) ----
decay, no_decay = set(), set()
for name, p in model.named_parameters():
    if not p.requires_grad:
        continue
    n = name.lower()
    if any(k in n for k in ["bias", "layernorm", "ln", "norm", "emb", "embedding", "wte", "wpe"]):
        no_decay.add(name)
    else:
        decay.add(name)

param_dict = {n: p for n, p in model.named_parameters() if p.requires_grad}
optim_groups = [
    {"params": [param_dict[n] for n in sorted(decay)],    "weight_decay": 0.01},
    {"params": [param_dict[n] for n in sorted(no_decay)], "weight_decay": 0.0},
]

optimizer = AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8)

# ---- Scheduler: step every optimizer step ----
steps_per_epoch = max(1, math.ceil(num_train_pairs / max(1, train_batch_size * grad_accum_steps)))
total_steps     = steps_per_epoch * max(1, epochs_to_train)
warmup_steps    = max(1, int(warmup_ratio * total_steps))

def _lr_lambda(step: int):
    if step < warmup_steps:
        return float(step) / float(warmup_steps)                    # linear warmup 0→1
    progress = (step - warmup_steps) / float(max(1, total_steps - warmup_steps))
    return max(0.0, 1.0 - progress)                                 # linear decay 1→0

scheduler = LambdaLR(optimizer, _lr_lambda)

# ---- Summary (so you can confirm it pulled Step-2 values) ----
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable params: {trainable_params:,}")
print(f"LR: {learning_rate} | wd: 0.01 | betas: (0.9, 0.95) | eps: 1e-8")
print(f"epochs: {epochs_to_train} | batch_size: {train_batch_size} | accum: {grad_accum_steps}")
print(f"steps/epoch: {steps_per_epoch} | warmup_steps: {warmup_steps} | total_steps: {total_steps}")
print(f"DPO beta (for Step 7): {beta_dpo}")


Trainable params: 8,838,852
LR: 0.0001 | wd: 0.01 | betas: (0.9, 0.95) | eps: 1e-8
epochs: 5 | batch_size: 64 | accum: 1
steps/epoch: 1594 | warmup_steps: 398 | total_steps: 7970
DPO beta (for Step 7): 0.2


### Step 7: Begin training (**students are required to complete this part!**)

In [72]:
# Add plausible-wrong answers to ~40% of items.
import re, random, json

def add_hard_negatives(dataset, frac=0.4):
    out = []
    for r in dataset:
        r = dict(r)  # shallow copy
        if random.random() < frac:
            s = r["positive"]
            m = re.search(r"The answer is ([\-0-9\.]+)", s)
            if m:
                val = m.group(1)
                try:
                    if "." in val:
                        wrong = f"{float(val)+random.choice([-0.1, 0.1]):.2f}"
                    else:
                        wrong = str(int(val)+random.choice([-1, 1]))
                    r["hard_negative"] = s.replace(f"The answer is {val}",
                                                   f"The answer is {wrong}")
                except:
                    pass
        out.append(r)
    return out

data_hn = add_hard_negatives(data, frac=0.4)
print("with hard_negative:", sum("hard_negative" in r for r in data_hn), "/", len(data_hn))


with hard_negative: 40668 / 102000


In [75]:
total_steps = len(data_hn) // batch_size
for epoch in range(epochs):
    model.train()
    pbar = tqdm(get_batches(data, batch_size))
    for step, (neg_tensor,pos_tensor) in enumerate(pbar):
        ###########################################################
        # Please complete the training code here!
        # Examples: 
        # ...
        # neg_logprob
        # pos_logprob 
        # loss = -F.logsigmoid((pos_logprob - neg_logprob) / beta).mean() - pos_logprob.mean() * 0.1 
        # ...
        ###########################################################
        model.train()

        # Compute logprobs
        neg_logprob = compute_logprob(neg_tensor)
        pos_logprob = compute_logprob(pos_tensor)
        
        # DPO loss
        pref = (pos_logprob - neg_logprob) / beta_dpo
        loss = -F.logsigmoid(pref).mean()
        loss = loss - pos_logprob.mean() * 0.1  # KL penalty term

        # Optimser step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        # Progress bar update
        pbar.set_postfix(loss=f"{loss.item():.4f}", lr=f"{scheduler.get_last_lr()[0]:.2e}")
            
    ckpt_path = f"./dpo.pt"
    torch.save({
        "model_state_dict": gpt.state_dict(),
        "model_args": ckpt['model_args'],
    }, ckpt_path)
    print(f"Saved checkpoint to {ckpt_path}")

1593it [3:37:43,  8.20s/it, loss=0.0456, lr=0.00e+00]


Saved checkpoint to ./dpo.pt


486it [1:06:47,  8.25s/it, loss=0.0425, lr=0.00e+00]


KeyboardInterrupt: 

### Step 8: Begin testing (**students are required to complete this part!**)

In [78]:
# Load the fine-tuned model
ckpt_path = "../dpo/dpo.pt"
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
gpt = GPT(gptconf).to(device)
try:
    state_dict = checkpoint['model']
except:
    state_dict = checkpoint['model_state_dict']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)
# Test
gpt.eval()

# Greedy, constrained decoding: Extract first numeric span after "The answer is"
DIGITS = set("0123456789-+.")
STOP_ID = 0
CTX = max_length

test_set = ["17+19=?", "3*17=?", "72/4=?", "72-x=34,x=?", "x*11=44,x=?", "3*17=?", "72/4=?", "72-x=34,x=?"]
# with torch.no_grad():
#     for prompt in test_set: 
#         prompt_ids = encode(prompt)
#         ###########################################################
#         # Please complete the test code here!
#         # ...
#         # gpt.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
#         # ...
#         ###########################################################

#         # # Generate continuation and print
#         # x = torch.tensor([prompt_ids], dtype=torch.long, device=device)
#         # # For evaluation, we generate max_new_tokens tokens
#         # out = gpt.generate(x, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k)

#         # # Robustly get a 1D list[int] sequence from "out"
#         # seq = out[0]
#         # if isinstance(seq, torch.Tensor):
#         #     seq = seq.squeeze().tolist()
#         # if seq and isinstance(seq[0], (list, tuple)):
#         #     seq = [tok for sub in seq for tok in sub]
#         # txt = decode(seq)

#         # # Show only newly generated part
#         # gen_txt = txt[len(prompt):].split('\n')[0].strip()
#         # print(f"Prompt: {prompt}\nGenerated: {gen_txt}\n")

#         seed = prompt + " The answer is "
#         x = torch.tensor([[stoi.get(c, FALLBACK_ID) for c in seed]], dtype=torch.long, device=device)
#         # greedy generate (argmax)
#         for _ in range(max_new_tokens):
#             logits, _ = gpt(x[:, -max_length:], full_seq=True)
#             next_id = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
#             x = torch.cat([x, next_id], dim=1)
#             if next_id.item() == 0:  # stop on pad/0
#                 break
#         seq = x[0].tolist()
#         txt = decode(seq)
#         gen_txt = txt[len(prompt):].split("\n", 1)[0].strip()
#         print(f"Prompt: {prompt}\nGenerated: {gen_txt}\n")

# Greedy function
def greedy_answer(prompt: str, max_steps: int = 24) -> str:
    seed = prompt + " The answer is "
    x = torch.tensor([[stoi.get(c, FALLBACK_ID) for c in seed]], dtype=torch.long, device=device)
    ans = []
    with torch.no_grad():
        for _ in range(max_steps):
            logits, _ = gpt(x[:, -CTX:], full_seq=True)
            next_id = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # GREEDY (no sampling)
            x = torch.cat([x, next_id], dim=1)
            t = int(next_id.item())
            if t == STOP_ID:
                break
            ch = itos[t]
            if ch in DIGITS:
                ans.append(ch)         # start/continue capturing digits
            elif ans:
                break                  # stop when we leave the first digit span
    return "".join(ans)

# iny solver for your prompt formats (gold answers)
import re
def solve(prompt: str) -> str:
    s = prompt.replace(" ", "")
    if re.match(r"^\d+\+\d+=\?$", s):
        a,b = map(int, s[:-2].split("+")); return str(a+b)
    if re.match(r"^\d+-\d+=\?$", s):
        a,b = map(int, s[:-2].split("-")); return str(a-b)
    if re.match(r"^\d+\*\d+=\?$", s):
        a,b = map(int, s[:-2].split("*")); return str(a*b)
    if re.match(r"^\d+/\d+=\?$", s):
        a,b = map(int, s[:-2].split("/")); return str(a/b)
    if re.match(r"^x\+\d+=\d+,x=\?$", s):
        a,b = map(int, s[2:-4].split("=")); return str(b-a)
    if re.match(r"^\d+\*x=\d+,x=\?$", s):
        a,b = map(int, s[:-4].split("*x=")); return str(b/a)
    if re.match(r"^\d+-x=\d+,x=\?$", s):
        a,b = map(int, s[:-4].split("-x=")); return str(a-b)
    return ""

# Evaluate
correct, total = 0, 0
for q in test_set:
    pred = greedy_answer(q, max_steps=24)
    gold = solve(q)
    # compare as floats (covers division); fallback to string equality
    try:
        ok = abs(float(pred) - float(gold)) < 1e-6
    except:
        ok = (pred == gold)
    total += 1; correct += int(ok)
    print(f"Q: {q}\nPred: {pred}   Gold: {gold}   {'✓' if ok else '✗'}\n")
print(f"Accuracy: {correct}/{total} = {correct/total:.2%}")

Q: 17+19=?
Pred: 110   Gold: 36   ✗

Q: 3*17=?
Pred: 104   Gold: 51   ✗

Q: 72/4=?
Pred: 1.11444   Gold: 18.0   ✗

Q: 72-x=34,x=?
Pred: -1   Gold: 38   ✗

Q: x*11=44,x=?
Pred: -144   Gold:    ✗

Q: 3*17=?
Pred: 104   Gold: 51   ✗

Q: 72/4=?
Pred: 1.11444   Gold: 18.0   ✗

Q: 72-x=34,x=?
Pred: -1   Gold: 38   ✗

Accuracy: 0/8 = 0.00%


In [42]:
import json, statistics, re
data = json.load(open("pos_neg_pairs.json","r",encoding="utf-8"))

# 1) Schema + presence of key phrase
assert all(("positive" in r and "negative" in r) for r in data)
p_has_ans = sum(" The answer is " in r["positive"] for r in data)
print("positives with 'The answer is':", p_has_ans, "/", len(data))

# 2) OOV coverage vs your vocab
bad = sorted({c for r in data for c in (r["positive"]+r["negative"]) if c not in stoi})
print("OOV chars:", bad)

# 3) Length stats and truncation risk (max_length=64; you append 4 newlines)
def toks(s): return len(encode(s + "\n\n\n\n"))
pos_lens = [toks(r["positive"]) for r in data]
neg_lens = [toks(r["negative"]) for r in data]
print("pos len median/95p/max:", statistics.median(pos_lens), sorted(pos_lens)[int(0.95*len(pos_lens))], max(pos_lens))
print("neg len median/95p/max:", statistics.median(neg_lens), sorted(neg_lens)[int(0.95*len(neg_lens))], max(neg_lens))
print("pos >64 tokens:", sum(l>64 for l in pos_lens))

# 4) Are many answers floats with long tails (harder for char LM)?
float_like = sum("." in r["positive"] for r in data)
print("positives with '.' (likely decimals):", float_like)


positives with 'The answer is': 102000 / 102000
OOV chars: ['!']
pos len median/95p/max: 61.0 72 74
neg len median/95p/max: 35.0 39 39
pos >64 tokens: 29929
positives with '.' (likely decimals): 102000
