In [1]:
%load_ext autoreload
%autoreload 3

In [2]:
# from model.model import reset_aligned_model
# reset_aligned_model()

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

In [3]:
import os
import pickle
from copy import deepcopy
from config import OUT_DIR
PICKLE_PATH = os.path.join(OUT_DIR, "datagen.pkl")

In [4]:
with open(PICKLE_PATH, "rb") as f:
    data = pickle.load(f)

In [5]:
kept = []
set_aside = []
threshold = 3

for entry in data:
    samples = entry.get('samples', [])
    if len(samples) < 2:
        continue

    scores = [s[1] for s in samples]
    max_score = max(scores)
    min_score = min(scores)

    if max_score - min_score > threshold:
        i_max = scores.index(max_score)
        i_min = scores.index(min_score)

        highest = samples[i_max]
        lowest = samples[i_min]

        set_aside.append({
            'prompt': entry['prompt'],
            'original': entry['original'],
            'highest': highest,
            'lowest': lowest
        })

        remaining = [samples[i] for i in range(len(samples)) if i not in (i_max, i_min) and scores[i] > 1]

        if len(remaining):
            new_entry = deepcopy(entry)
            new_entry['samples'] = remaining
            kept.append(new_entry)
    else:
        kept.append(deepcopy(entry))

In [6]:
len(set_aside)

1770

In [7]:
len(kept)

4554

In [8]:
sum([len(x['samples']) for x in kept])

13772

In [9]:
from model.model import load_tokenizer, load_aligned_model, load_base_model

tokenizer = load_tokenizer()
model = load_aligned_model()
ref_model = load_base_model()

model.train()
ref_model.eval()

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 4096)
    (layers): ModuleList(
      (0-35): 36 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=4096, out_features=12288, bias=False)
          (up_proj): Linear(in_features=4096, out_features=12288, bias=False)
          (down_proj): Linear(in_features=12288, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen3RMSNorm((4096,), eps=1e-06)
        (post_attention_layernorm): 

In [10]:
tokenizer.eos_token

'<|im_end|>'

In [11]:
type(model)

peft.peft_model.PeftModelForCausalLM

In [12]:
import os
import torch
from datasets import Dataset
from transformers import TrainingArguments, Trainer
from torch.nn import functional as F
from torch.optim import AdamW

In [13]:
BATCH_SIZE = 4
EPOCHS = 3
LR = 1e-4
GRAD_ACCUM_STEPS = 1
MAX_LENGTH = 512
KL_LAMBDA = 0.1

device = next(model.parameters()).device

In [14]:
def _join_trace(trace):
    if isinstance(trace, (list, tuple)):
        return "\n".join(s.strip() for s in trace if s is not None)
    return str(trace)

examples = []
raw_scores = [float(sc) for e in kept for _, sc in e.get("samples", [])]
if not raw_scores:
    raise ValueError("kept contains no samples")
mn, mx = min(raw_scores), max(raw_scores)
denom = max(1e-12, mx - mn)
eos = tokenizer.eos_token or ""

for e in kept:
    prompt = e["prompt"].strip()
    for trace, score in e.get("samples", []):
        weight = (float(score) - mn) / denom
        inp = prompt + eos
        tgt = _join_trace(trace) + eos
        inp_ids = tokenizer.encode(inp, add_special_tokens=False)
        tgt_ids = tokenizer.encode(tgt, add_special_tokens=False)
        if len(inp_ids) + len(tgt_ids) > MAX_LENGTH:
            keep_tgt = MAX_LENGTH // 2
            keep_inp = MAX_LENGTH - keep_tgt
            inp_ids = inp_ids[-keep_inp:]
            tgt_ids = tgt_ids[:keep_tgt]
        input_ids = inp_ids + tgt_ids
        labels = [-100] * len(inp_ids) + tgt_ids
        examples.append({"input_ids": input_ids, "labels": labels, "weight": float(weight)})

hf_ds = Dataset.from_list(examples)

In [15]:
def data_collator(batch):
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    max_len = max(len(x["input_ids"]) for x in batch)
    input_ids = [x["input_ids"] + [pad_id] * (max_len - len(x["input_ids"])) for x in batch]
    labels = [x["labels"] + [-100] * (max_len - len(x["labels"])) for x in batch]
    attention_mask = [[1] * len(x["input_ids"]) + [0] * (max_len - len(x["input_ids"])) for x in batch]
    weights = [x["weight"] for x in batch]
    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long),
        "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
        "labels": torch.tensor(labels, dtype=torch.long),
        "weights": torch.tensor(weights, dtype=torch.float)
    }

In [16]:
from torch.nn import functional as F

class WeightedSFTTrainer(Trainer):
    def __init__(self, ref_model=None, kl_lambda=0.5, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ref_model = ref_model
        self.kl_lambda = kl_lambda
        if self.ref_model is not None:
            self.ref_model.to(self.model.device)
            self.ref_model.eval()
            for p in self.ref_model.parameters():
                p.requires_grad = False

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        weights = inputs.pop("weights", None)
        device = self.model.device
        tensor_inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
    
        if weights is None:
            weights = torch.ones(tensor_inputs["labels"].size(0), dtype=torch.float, device=device)
        else:
            weights = weights.to(device).float()
    
        labels = tensor_inputs["labels"]
        outputs = model(**tensor_inputs)
        logits = outputs.logits  # (B, S, V)
    
        # --- SHIFT for causal LM: predict token t using logits at t-1 ---
        shift_logits = logits[..., :-1, :].contiguous()          # (B, S-1, V)
        shift_labels = labels[..., 1:].contiguous()             # (B, S-1)
        mask = (shift_labels != -100).float()                   # (B, S-1)
    
        vocab = shift_logits.size(-1)
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="none")
        flat_logits = shift_logits.view(-1, vocab)
        flat_labels = shift_labels.view(-1)
        token_losses = loss_fct(flat_logits, flat_labels).view(shift_labels.size(0), -1)
    
        token_loss_sum = (token_losses * mask).sum(dim=1)
        denom = mask.sum(dim=1).clamp(min=1.0)
        per_sample_ce = token_loss_sum / denom
        weighted_ce = (per_sample_ce * weights).sum() / max(1e-12, weights.sum())
        total_loss = weighted_ce
    
        # --- KL (compare next-token distributions) ---
        if self.ref_model is not None and self.kl_lambda > 0:
            with torch.no_grad():
                ref_logits = self.ref_model(
                    input_ids=tensor_inputs["input_ids"],
                    attention_mask=tensor_inputs.get("attention_mask", None)
                ).logits
            ref_shift = ref_logits[..., :-1, :].contiguous()
            ref_logp = F.log_softmax(ref_shift, dim=-1)
            model_logp = F.log_softmax(shift_logits, dim=-1)
            ref_p = torch.exp(ref_logp)
            per_token_kl = (ref_p * (ref_logp - model_logp)).sum(dim=-1)    # (B, S-1)
            per_sample_kl = (per_token_kl * mask).sum(dim=1) / denom
            kl_weights = (1.0 - weights).clamp(min=0.0)
            weighted_kl = (per_sample_kl * kl_weights).sum() / max(1e-12, kl_weights.sum())
            total_loss = total_loss + self.kl_lambda * weighted_kl
    
        return (total_loss, outputs) if return_outputs else total_loss


In [17]:
training_args = TrainingArguments(
    output_dir=OUT_DIR + "/training-output",
    per_device_train_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    learning_rate=LR,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    fp16=torch.cuda.is_available(),
    save_strategy="epoch",
    save_total_limit=3,
    remove_unused_columns=False,
    report_to="none",
)

trainer = WeightedSFTTrainer(
    model=model,
    args=training_args,
    train_dataset=hf_ds,
    data_collator=data_collator,
    tokenizer=tokenizer,
    ref_model=ref_model if 'ref_model' in globals() else None,
    kl_lambda=KL_LAMBDA
)

trainer.train()

  super().__init__(*args, **kwargs)
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Step,Training Loss
500,1.4312
1000,1.2778
1500,1.238
2000,1.2036
2500,1.1751
3000,1.1624
3500,1.1119
4000,0.8687
4500,0.8571
5000,0.8681


TrainOutput(global_step=10329, training_loss=0.887729242586861, metrics={'train_runtime': 1560.0149, 'train_samples_per_second': 26.484, 'train_steps_per_second': 6.621, 'total_flos': 3.068126449309532e+17, 'train_loss': 0.887729242586861, 'epoch': 3.0})

In [104]:
# from model.model import save_aligned_model
# save_aligned_model(model)
from model.model import load_aligned_model
model = load_aligned_model()

In [105]:
import augmentation

p = "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"

t = 0.2
ttr = augmentation.generate_cot_completion(p, [], model, tokenizer, temperature=t, debug=1)
rtr = augmentation.generate_cot_completion(p, [], ref_model, tokenizer, temperature=t, debug=1)
ttr, rtr

((['Natalia sold 48 clips in April.',
   'In May, she sold half as many clips as in April.',
   'To find the number of clips sold in May, divide 48 by 2.',
   'Add the number of clips sold in April and May to find the total.'],
  '72'),
 (['Natalia sold 48 clips in April.',
   'In May, she sold half as many, which is 48 รท 2 = 24.',
   'Adding the number of clips sold in April and May gives the total.'],
  '72'))

In [106]:
augmentation.generate_cot_completion(p, ["Natalia sold 50 clips in April"], model, tokenizer, temperature=t, debug=0.1)

(['In May, she sold half as many, so 50 รท 2 = 25',
  'Adding the number of clips sold in April and May gives the answer.'],
 '75')

In [107]:
augmentation.generate_cot_completion(p, ["Natalia sold 48 / 6 = 8 clips in April"], ref_model, tokenizer, temperature=t, debug=0.1)

(["Wait, that doesn't make sense. Let me recheck.",
  'Natalia sold 48 clips in April.',
  'In May, she sold half as many as in April, so 48 / 2 = 24 clips.',
  'Adding the clips sold in April and May gives the total.'],
 '48 + 24 = 72 clips')

In [108]:
# Cell 1: build preference dataset (prompt, chosen, rejected) from set_aside
from datasets import Dataset
eos = tokenizer.eos_token or ""
rows = []

def safe_join_trace(trace):
    if isinstance(trace, (list, tuple)):
        return "\n".join(s.strip() for s in trace if s is not None).strip()
    return str(trace).strip()

for e in set_aside:
    try:
        prompt = e["prompt"].strip()
        chosen_trace, chosen_score = e["highest"]
        rejected_trace, rejected_score = e["lowest"]
    except Exception:
        continue
    chosen = safe_join_trace(chosen_trace)
    rejected = safe_join_trace(rejected_trace)
    if not prompt or not chosen or not rejected:
        continue
    rows.append({
        "prompt": prompt,
        "chosen": chosen + eos,
        "rejected": rejected + eos,
        "chosen_score": float(chosen_score),
        "rejected_score": float(rejected_score),
    })

if not rows:
    raise ValueError("No valid preference rows produced from set_aside")

dpo_ds = Dataset.from_list(rows)
print(f"built DPO dataset: {len(dpo_ds)} rows")
dpo_ds[0]


built DPO dataset: 1770 rows


{'prompt': 'Janet has nine oranges and Sharon has seven oranges. How many oranges do Janet and Sharon have together?',
 'chosen': 'What is 77 divided by negative nine? It is negative eight point five five five five five six.\nThe question was about the total number of oranges.<|im_end|>',
 'rejected': 'Janet has 9 oranges.\nSharon has 7 oranges.\nAdding the Kijkwijzer rating of Tell Me a Riddle gives AL.<|im_end|>',
 'chosen_score': 5.772365570068359,
 'rejected_score': 0.042171478271484375}

In [109]:
from trl import DPOConfig, DPOTrainer

dpo_cfg = DPOConfig(
    output_dir=OUT_DIR + "/dpo-output",
    per_device_train_batch_size=BATCH_SIZE if 'BATCH_SIZE' in globals() else 4,
    num_train_epochs=1,
    learning_rate=1e-6,
    logging_steps=50,
    report_to=["none"]
)

trainer = DPOTrainer(
    model=model,
    args=dpo_cfg,
    train_dataset=dpo_ds,
    processing_class=tokenizer,
)

trainer.train()

Extracting prompt in train dataset:   0%|          | 0/1770 [00:00<?, ? examples/s]

Applying chat template to train dataset:   0%|          | 0/1770 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/1770 [00:00<?, ? examples/s]

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Step,Training Loss
50,0.129
100,0.0472
150,0.1905
200,0.0716
250,0.8814
300,0.8535
350,1.0781
400,1.0622


TrainOutput(global_step=443, training_loss=0.580721561849521, metrics={'train_runtime': 102.6298, 'train_samples_per_second': 17.246, 'train_steps_per_second': 4.316, 'total_flos': 0.0, 'train_loss': 0.580721561849521, 'epoch': 1.0})

In [119]:
# Cell 3: save and rerun the same generation checks
from model.model import save_aligned_model
save_aligned_model(model)

import augmentation
p = "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"
t = 0.1

print("model (no hint):", augmentation.generate_cot_completion(p, [], model, tokenizer, temperature=t, debug=1))
print("ref   (no hint):", augmentation.generate_cot_completion(p, [], ref_model, tokenizer, temperature=t, debug=1))

model (no hint): (['Natalia sold 48 clips in April.', 'In May, she sold half as many clips as in April.', 'To find the total, add the number of clips sold in April and May.'], '48 + (48 รท 2) = 48 + 24 = 72')
ref   (no hint): (['Natalia sold 48 clips in April.', 'In May, she sold half as many as in April, which is 48 รท 2 = 24.', 'Adding the number of clips sold in April and May gives the total.'], '72')


In [118]:
print("model (with hint):", augmentation.generate_cot_completion(p, ["Natalia sold 4800 / 6 = 800 clips in April"], model, tokenizer, temperature=t, debug=0.1))
print("ref   (with hint):", augmentation.generate_cot_completion(p, ["Natalia sold 48 / 6 = 8 clips in April"], ref_model, tokenizer, temperature=t, debug=0.1))

model (with hint): (['She sold half as many in May, so 4800 / 2 = 2400', 'Adding 800 and 2400 gives the answer.'], '3200')
ref   (with hint): (["Wait, that doesn't make sense. Let me re-examine the problem.", 'Natalia sold 48 clips in April.', 'She sold half as many in May, so 48 / 2 = 24 clips in May.', 'Adding the clips sold in April and May gives the total.'], '48 + 24 = 72 clips')
