In [1]:
%load_ext autoreload
%autoreload 3

In [2]:
# from model.model import reset_aligned_model
# reset_aligned_model()

Loading checkpoint shards:   0%|          | 0/9 [00:00<?, ?it/s]

In [3]:
import os
import pickle
from copy import deepcopy
from config import out_subdir
PICKLE_PATH = os.path.join(out_subdir, "datagen0.pkl")

In [4]:
with open(PICKLE_PATH, "rb") as f:
    data = pickle.load(f)

In [5]:
kept = []
set_aside = []
l, h = 0.5, 7

for entry in data:
    samples = entry.get('samples', [])

    outside = [s for s in samples if s[1] < l or s[1] > h]

    if outside:
        set_aside.append({
            'prompt': entry['prompt'],
            'original': entry['original'],
            'samples': outside
        })

    inside = [s for s in samples if s[1] >= l and s[1] <= h]
    if inside:
        new_entry = deepcopy(entry)
        new_entry['samples'] = inside
        kept.append(new_entry)

In [6]:
len(set_aside)

1753

In [7]:
# kept = data

In [8]:
len(kept)

2774

In [9]:
sum([len(x['samples']) for x in kept])

6007

In [10]:
from model.model import load_tokenizer, load_aligned_model, load_base_model

tokenizer = load_tokenizer()
model = load_aligned_model()
ref_model = load_base_model()

model.train()
ref_model.eval()

Loading checkpoint shards:   0%|          | 0/9 [00:00<?, ?it/s]

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32768, 6144)
    (layers): ModuleList(
      (0-55): 56 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear(in_features=6144, out_features=6144, bias=False)
          (k_proj): Linear(in_features=6144, out_features=1024, bias=False)
          (v_proj): Linear(in_features=6144, out_features=1024, bias=False)
          (o_proj): Linear(in_features=6144, out_features=6144, bias=False)
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=6144, out_features=16384, bias=False)
          (up_proj): Linear(in_features=6144, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=6144, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm((6144,), eps=1e-05)
        (post_attention_layernorm): MistralRMSNorm((6144,), eps=1e-05)
      )
    )
    (norm): MistralRMSNorm((6144,), eps=1e-0

In [11]:
import os
import torch
from datasets import Dataset
from transformers import TrainingArguments, Trainer
from torch.nn import functional as F
from torch.optim import AdamW

In [12]:
BATCH_SIZE = 4
EPOCHS = 3
LR = 5e-7
GRAD_ACCUM_STEPS = 1
MAX_LENGTH = 512
KL_LAMBDA = 0.2

device = next(model.parameters()).device

In [13]:
def _join_trace(trace):
    if isinstance(trace, (list, tuple)):
        return "\n".join(s.strip() for s in trace if s is not None)
    return str(trace)

examples = []
raw_scores = [float(sc) for e in kept for _, sc in e.get("samples", [])]
if not raw_scores:
    raise ValueError("kept contains no samples")
mn, mx = min(raw_scores), max(raw_scores)
denom = max(1e-12, mx - mn)
eos = tokenizer.eos_token or ""

for e in kept:
    prompt = e["prompt"].strip()
    for trace, score in e.get("samples", []):
        weight = (float(score) - mn) / denom
        weight = 0.05 + 0.95 * weight
        inp = prompt + eos
        tgt = _join_trace(trace) + eos
        inp_ids = tokenizer.encode(inp, add_special_tokens=False)
        tgt_ids = tokenizer.encode(tgt, add_special_tokens=False)
        if len(inp_ids) + len(tgt_ids) > MAX_LENGTH:
            keep_tgt = MAX_LENGTH // 2
            keep_inp = MAX_LENGTH - keep_tgt
            inp_ids = inp_ids[-keep_inp:]
            tgt_ids = tgt_ids[:keep_tgt]
        input_ids = inp_ids + tgt_ids
        labels = [-100] * len(inp_ids) + tgt_ids
        examples.append({"input_ids": input_ids, "labels": labels, "weight": float(weight)})

hf_ds = Dataset.from_list(examples)

In [14]:
def data_collator(batch):
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    max_len = max(len(x["input_ids"]) for x in batch)
    input_ids = [x["input_ids"] + [pad_id] * (max_len - len(x["input_ids"])) for x in batch]
    labels = [x["labels"] + [-100] * (max_len - len(x["labels"])) for x in batch]
    attention_mask = [[1] * len(x["input_ids"]) + [0] * (max_len - len(x["input_ids"])) for x in batch]
    weights = [x["weight"] for x in batch]
    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long),
        "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
        "labels": torch.tensor(labels, dtype=torch.long),
        "weights": torch.tensor(weights, dtype=torch.float)
    }

In [15]:
from torch.nn import functional as F

class WeightedSFTTrainer(Trainer):
    def __init__(self, ref_model=None, kl_lambda=0.5, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ref_model = ref_model
        self.kl_lambda = kl_lambda
        if self.ref_model is not None:
            self.ref_model.to(self.model.device)
            self.ref_model.eval()
            for p in self.ref_model.parameters():
                p.requires_grad = False

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        weights = inputs.pop("weights", None)
        device = self.model.device
        tensor_inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
    
        if weights is None:
            weights = torch.ones(tensor_inputs["labels"].size(0), dtype=torch.float, device=device)
        else:
            weights = weights.to(device).float()
    
        labels = tensor_inputs["labels"]
        outputs = model(**tensor_inputs)
        logits = outputs.logits  # (B, S, V)
    
        # --- SHIFT for causal LM: predict token t using logits at t-1 ---
        shift_logits = logits[..., :-1, :].contiguous()          # (B, S-1, V)
        shift_labels = labels[..., 1:].contiguous()             # (B, S-1)
        mask = (shift_labels != -100).float()                   # (B, S-1)
    
        vocab = shift_logits.size(-1)
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="none")
        flat_logits = shift_logits.view(-1, vocab)
        flat_labels = shift_labels.view(-1)
        token_losses = loss_fct(flat_logits, flat_labels).view(shift_labels.size(0), -1)
    
        token_loss_sum = (token_losses * mask).sum(dim=1)
        denom = mask.sum(dim=1).clamp(min=1.0)
        per_sample_ce = token_loss_sum / denom
        weighted_ce = (per_sample_ce * weights).sum() / max(1e-12, weights.sum())
        total_loss = weighted_ce
    
        # --- KL (compare next-token distributions) ---
        if self.ref_model is not None and self.kl_lambda > 0:
            with torch.no_grad():
                ref_logits = self.ref_model(
                    input_ids=tensor_inputs["input_ids"],
                    attention_mask=tensor_inputs.get("attention_mask", None)
                ).logits
            ref_shift = ref_logits[..., :-1, :].contiguous()
            ref_logp = F.log_softmax(ref_shift, dim=-1)
            model_logp = F.log_softmax(shift_logits, dim=-1)
            ref_p = torch.exp(ref_logp)
            per_token_kl = (ref_p * (ref_logp - model_logp)).sum(dim=-1)    # (B, S-1)
            per_sample_kl = (per_token_kl * mask).sum(dim=1) / denom
            kl_weights = (1.0 - weights).clamp(min=0.0)
            weighted_kl = (per_sample_kl * kl_weights).sum() / max(1e-12, kl_weights.sum())
            total_loss = total_loss + self.kl_lambda * weighted_kl
    
        return (total_loss, outputs) if return_outputs else total_loss


In [None]:
training_args = TrainingArguments(
    output_dir=out_subdir + "/training-output",
    per_device_train_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    learning_rate=LR,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    fp16=torch.cuda.is_available(),
    save_strategy="epoch",
    save_total_limit=3,
    remove_unused_columns=False,
    report_to="none",
    logging_steps=50,
)

trainer = WeightedSFTTrainer(
    model=model,
    args=training_args,
    train_dataset=hf_ds,
    data_collator=data_collator,
    tokenizer=tokenizer,
    ref_model=ref_model if 'ref_model' in globals() else None,
    kl_lambda=KL_LAMBDA
)

trainer.train()

  super().__init__(*args, **kwargs)
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Step,Training Loss
50,2.9244
100,2.8967
150,2.6929
200,2.6438
250,2.5547
300,2.4265
350,2.4985
400,2.5322
450,2.502
500,2.5383


In [None]:
from model.model import save_aligned_model
save_aligned_model(model)
# from model.model import load_aligned_model
# model = load_aligned_model()

In [None]:
import augmentation

p = "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"

t = 0.2
ttr = augmentation.generate_cot_completion(p, [], model, tokenizer, temperature=t, debug=1)
rtr = augmentation.generate_cot_completion(p, [], ref_model, tokenizer, temperature=t, debug=1)
ttr, rtr

In [None]:
augmentation.generate_cot_completion(p, ["Natalia sold 48 / 2 = 24 clips in April"], model, tokenizer, temperature=t, debug=1)

In [None]:
augmentation.generate_cot_completion(p, ["Natalia sold 48 / 2 = 24 clips in April."], ref_model, tokenizer, temperature=t, debug=1)

In [None]:
import os
os.system("runpodctl remove pod 01vjk9b3tujobp")