In [1]:
%load_ext autoreload
%autoreload 3

In [3]:
# from model.model import reset_aligned_model
# reset_aligned_model()

In [4]:
import os
import pickle
from copy import deepcopy
from config import OUT_DIR
PICKLE_PATH = os.path.join(OUT_DIR, "datagen.pkl")

In [5]:
with open(PICKLE_PATH, "rb") as f:
    data = pickle.load(f)

In [6]:
kept = []
set_aside = []
threshold = 3

for entry in data:
    samples = entry.get('samples', [])
    if len(samples) < 2:
        continue

    scores = [s[1] for s in samples]
    max_score = max(scores)
    min_score = min(scores)

    if max_score - min_score > threshold:
        i_max = scores.index(max_score)
        i_min = scores.index(min_score)

        highest = samples[i_max]
        lowest = samples[i_min]

        set_aside.append({
            'prompt': entry['prompt'],
            'original': entry['original'],
            'highest': highest,
            'lowest': lowest
        })

        remaining = [samples[i] for i in range(len(samples)) if i not in (i_max, i_min)]

        if len(remaining) >= 2:
            new_entry = deepcopy(entry)
            new_entry['samples'] = remaining
            kept.append(new_entry)
    else:
        kept.append(deepcopy(entry))

In [7]:
len(set_aside)

1770

In [8]:
len(kept)

4169

In [9]:
from model.model import load_tokenizer, load_aligned_model, load_base_model

tokenizer = load_tokenizer()
model = load_aligned_model()
ref_model = load_base_model()

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

In [10]:
type(model)

peft.peft_model.PeftModelForCausalLM

In [11]:
import os
import torch
from datasets import Dataset
from transformers import TrainingArguments, Trainer
from torch.nn import functional as F
from torch.optim import AdamW

In [12]:
BATCH_SIZE = 4
EPOCHS = 1
LR = 2e-6
GRAD_ACCUM_STEPS = 4
MAX_LENGTH = 512
KL_LAMBDA = 0.8

device = next(model.parameters()).device

In [13]:
def _join_trace(trace):
    if isinstance(trace, (list, tuple)):
        return "\n".join(s.strip() for s in trace if s is not None)
    return str(trace)

examples = []
raw_scores = [float(sc) for e in kept for _, sc in e.get("samples", [])]
if not raw_scores:
    raise ValueError("kept contains no samples")
mn, mx = min(raw_scores), max(raw_scores)
denom = max(1e-12, mx - mn)
eos = tokenizer.eos_token or ""

for e in kept:
    prompt = e["prompt"].strip()
    for trace, score in e.get("samples", []):
        weight = (float(score) - mn) / denom
        inp = prompt + eos
        tgt = _join_trace(trace) + eos
        inp_ids = tokenizer.encode(inp, add_special_tokens=False)
        tgt_ids = tokenizer.encode(tgt, add_special_tokens=False)
        if len(inp_ids) + len(tgt_ids) > MAX_LENGTH:
            keep_tgt = MAX_LENGTH // 2
            keep_inp = MAX_LENGTH - keep_tgt
            inp_ids = inp_ids[-keep_inp:]
            tgt_ids = tgt_ids[:keep_tgt]
        input_ids = inp_ids + tgt_ids
        labels = [-100] * len(inp_ids) + tgt_ids
        examples.append({"input_ids": input_ids, "labels": labels, "weight": float(weight)})

hf_ds = Dataset.from_list(examples)

In [14]:
def data_collator(batch):
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    max_len = max(len(x["input_ids"]) for x in batch)
    input_ids = [x["input_ids"] + [pad_id] * (max_len - len(x["input_ids"])) for x in batch]
    labels = [x["labels"] + [-100] * (max_len - len(x["labels"])) for x in batch]
    attention_mask = [[1] * len(x["input_ids"]) + [0] * (max_len - len(x["input_ids"])) for x in batch]
    weights = [x["weight"] for x in batch]
    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long),
        "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
        "labels": torch.tensor(labels, dtype=torch.long),
        "weights": torch.tensor(weights, dtype=torch.float)
    }

In [15]:
from torch.nn import functional as F

class WeightedSFTTrainer(Trainer):
    def __init__(self, ref_model=None, kl_lambda=0.5, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ref_model = ref_model
        self.kl_lambda = kl_lambda
        if self.ref_model is not None:
            self.ref_model.to(self.model.device)
            self.ref_model.eval()
            for p in self.ref_model.parameters():
                p.requires_grad = False

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        weights = inputs.pop("weights", None)
        # ensure tensors on correct device
        device = self.model.device
        tensor_inputs = {}
        for k, v in inputs.items():
            tensor_inputs[k] = v.to(device) if isinstance(v, torch.Tensor) else v

        if weights is None:
            weights = torch.ones(tensor_inputs["labels"].size(0), dtype=torch.float, device=device)
        else:
            weights = weights.to(device).float()

        labels = tensor_inputs["labels"]
        outputs = model(**tensor_inputs)
        logits = outputs.logits

        vocab = logits.size(-1)
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="none")
        flat_logits = logits.view(-1, vocab)
        flat_labels = labels.view(-1)
        token_losses = loss_fct(flat_logits, flat_labels).view(labels.size(0), -1)
        mask = (labels != -100).float()
        token_loss_sum = (token_losses * mask).sum(dim=1)
        denom = mask.sum(dim=1).clamp(min=1.0)
        per_sample_ce = token_loss_sum / denom
        weighted_ce = (per_sample_ce * weights).sum() / max(1e-12, weights.sum())
        total_loss = weighted_ce

        if self.ref_model is not None and self.kl_lambda > 0:
            with torch.no_grad():
                ref_logits = self.ref_model(input_ids=tensor_inputs["input_ids"],
                                            attention_mask=tensor_inputs.get("attention_mask", None)).logits
            ref_logp = F.log_softmax(ref_logits, dim=-1)
            model_logp = F.log_softmax(logits, dim=-1)
            ref_p = torch.exp(ref_logp)
            per_token_kl = (ref_p * (ref_logp - model_logp)).sum(dim=-1)
            per_sample_kl = (per_token_kl * mask).sum(dim=1) / denom
            kl_weights = (1.0 - weights).clamp(min=0.0)
            weighted_kl = (per_sample_kl * kl_weights).sum() / max(1e-12, kl_weights.sum())
            total_loss = total_loss + self.kl_lambda * weighted_kl

        return (total_loss, outputs) if return_outputs else total_loss

In [16]:
training_args = TrainingArguments(
    output_dir=OUT_DIR + "/training-output",
    per_device_train_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    learning_rate=LR,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    fp16=torch.cuda.is_available(),
    save_strategy="epoch",
    save_total_limit=3,
    remove_unused_columns=False,
    report_to="none",
)

trainer = WeightedSFTTrainer(
    model=model,
    args=training_args,
    train_dataset=hf_ds,
    data_collator=data_collator,
    tokenizer=tokenizer,
    ref_model=ref_model if 'ref_model' in globals() else None,
    kl_lambda=KL_LAMBDA
)

trainer.train()

  super().__init__(*args, **kwargs)
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Step,Training Loss
500,21.9251


TrainOutput(global_step=890, training_loss=15.570395068908006, metrics={'train_runtime': 863.3826, 'train_samples_per_second': 16.477, 'train_steps_per_second': 1.031, 'total_flos': 1.070832759346176e+17, 'train_loss': 15.570395068908006, 'epoch': 1.0})

In [19]:
from model.model import save_aligned_model
save_aligned_model(model)

In [21]:
# Test generation cell
from textwrap import indent
import torch

device = next(model.parameters()).device
prompt = "Explain in two sentences why adding two even numbers gives an even number.\n\nAnswer:"
input_ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False).to(device)
attention_mask = torch.ones_like(input_ids).to(device)

gen_kwargs = dict(
    input_ids=input_ids,
    attention_mask=attention_mask,
    max_new_tokens=128,
    do_sample=True,
    temperature=1.5,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=(tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id),
)

model.eval()
with torch.no_grad():
    out = model.generate(**gen_kwargs)

generated_full = tokenizer.decode(out[0], skip_special_tokens=True)
generated_continuation = tokenizer.decode(out[0, input_ids.shape[-1]:], skip_special_tokens=True).strip()

print("=== PROMPT ===")
print(prompt)
print("\n=== MODEL OUTPUT (full) ===")
print(indent(generated_full, "  "))
print("\n=== MODEL OUTPUT (continuation only) ===")
print(indent(generated_continuation or "<no continuation>", "  "))


=== PROMPT ===
Explain in two sentences why adding two even numbers gives an even number.

Answer:

=== MODEL OUTPUT (full) ===
  Explain in two sentences why adding two even numbers gives an even number.

  Answer:  Adding even numbers means  adding the doubles of multiples of two two two, so the sum is double a multiple multiple, which is which is is even...

  ..

  The first The The first first first sentence sentence sentence sentence of of of of the the the the explanation explanation explanation is is is is an an an an an an an an an an an an an incorrect incorrect incorrect incorrect incorrect wrong wrong wrong wrong mistake mistake mistake mistake mistake mistake mistake mistake mistake mistake mistake mistake mistake mistake mistake mistake mistake mistake mistake.

  .

  TheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheThe

=== MODEL OUTPUT (continuation only) ===
  Adding even numbers means  adding the doubles of multiples of two two two, so the sum is dou