## TODO
1. evaluation of the model (train, test)
2. find out why model 2 is unstable
3. add kl divergence part

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import get_peft_model, LoraConfig, TaskType
import random
import json
import re
import copy
from tqdm import tqdm

# --- Configuration ---
DATA_PATH = "data/train.json"
DEV_PATH = "data/dev.json"
TEST_PATH = "data/test.json"
BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.1"
BATCH_SIZE = 4
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_EPOCHS = 2
alpha = 0.5     # ensemble weight
temperature = 1.0
kl_weight = 0.5 # weight on KL divergence loss
USE_KL_DISTILLATION = True  # Set to False to disable KL divergence loss
LR = 1e-5

# --- Load data ---
def load_data(path):
    with open(path, "r") as f:
        return json.load(f)

# --- Dataset ---
class ReasoningDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        instruction = item["instruction"]
        outputs = item["output"]
        label = item["label"]
        if len(outputs) >=2:
            sampled = random.sample(outputs, 2)
        else:
            sampled = [outputs[0], outputs[0]]
        return instruction, sampled[0], sampled[1], label

class EvalDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return item["instruction"], item["label"]

# --- Collate ---
def collate_fn(batch):
    inputs, targets1, targets2, labels = zip(*batch)
    batch_encodings = {"input_ids": [], "attention_mask": [], "labels1": [], "labels2": []}
    
    for inp, tgt1, tgt2 in zip(inputs, targets1, targets2):
        full1 = inp + tgt1
        full2 = inp + tgt2

        tok1 = tokenizer(full1, return_tensors="pt", padding="max_length", truncation=True, max_length=1024)
        tok2 = tokenizer(full2, return_tensors="pt", padding="max_length", truncation=True, max_length=1024)

        input_len = len(tokenizer(inp)["input_ids"])

        labels1 = tok1["input_ids"].clone()
        labels1[:, :input_len] = -100
        labels2 = tok2["input_ids"].clone()
        labels2[:, :input_len] = -100

        batch_encodings["input_ids"].append(tok1["input_ids"])
        batch_encodings["attention_mask"].append(tok1["attention_mask"])
        batch_encodings["labels1"].append(labels1)
        batch_encodings["labels2"].append(labels2)

    for k in batch_encodings:
        batch_encodings[k] = torch.cat(batch_encodings[k], dim=0)
    return batch_encodings

# --- Model Setup ---
print("Loading tokenizer and frozen base model...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=torch.float16, device_map="auto")
for param in base_model.parameters():
    param.requires_grad = False
base_model.eval()

def make_student(base_model):
    base_copy = copy.deepcopy(base_model)
    config = LoraConfig(
        r=8,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type=TaskType.CAUSAL_LM
    )
    model = get_peft_model(base_copy, config)
    model.print_trainable_parameters()
    return model

print("Creating two students with isolated LoRA adapters...")
student1 = make_student(base_model)
student2 = make_student(base_model)

# --- Training ---
train_data = load_data(DATA_PATH)
train_dataset = ReasoningDataset(train_data, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

optimizer = torch.optim.AdamW(
    list(student1.parameters()) + list(student2.parameters()), lr=LR
)

for epoch in range(1, NUM_EPOCHS + 1):
    progress_bar = tqdm(train_loader, total=len(train_loader), desc=f"Epoch {epoch}", dynamic_ncols=True)

    for step, batch in enumerate(progress_bar, 1):
        input_ids = batch["input_ids"].to(DEVICE)
        labels1 = batch["labels1"].to(DEVICE)
        labels2 = batch["labels2"].to(DEVICE)

        optimizer.zero_grad()

        student1.train()
        out1 = student1(input_ids=input_ids, labels=labels2)
        logits1 = out1.logits
        ce1 = out1.loss

        student2.train()
        out2 = student2(input_ids=input_ids, labels=labels1)
        logits2 = out2.logits
        ce2 = out2.loss

        # Mask to ignore padding
        mask = (labels1 != -100).unsqueeze(-1)

        if USE_KL_DISTILLATION:
            # Log-softmax for temperature-scaled probs
            probs1 = F.log_softmax(logits1 / temperature, dim=-1)
            probs2 = F.log_softmax(logits2 / temperature, dim=-1)

            probs1_detach = probs1.detach()
            probs2_detach = probs2.detach()

            # Ensemble in log-space
            probs_ens = torch.logsumexp(torch.stack([
                probs1_detach + torch.log(torch.tensor(alpha, device=DEVICE)),
                probs2_detach + torch.log(torch.tensor(1 - alpha, device=DEVICE))
            ], dim=0), dim=0)

            # KL divergence
            kl1 = F.kl_div(probs1, probs_ens, reduction="none", log_target=True)
            kl2 = F.kl_div(probs2, probs_ens, reduction="none", log_target=True)

            kl1 = kl1.sum(dim=-1)[mask.squeeze(-1)].mean()
            kl2 = kl2.sum(dim=-1)[mask.squeeze(-1)].mean()

            total_loss = ce1 + ce2 + kl_weight * (kl1 + kl2)
        else:
            kl1 = torch.tensor(0.0, device=DEVICE)
            kl2 = torch.tensor(0.0, device=DEVICE)
            total_loss = ce1 + ce2
        total_loss.backward()
        optimizer.step()

        progress_bar.set_postfix({
            "CE1": f"{ce1.item():.4f}",
            "CE2": f"{ce2.item():.4f}",
            "KL1": f"{kl1.item():.4f}",
            "KL2": f"{kl2.item():.4f}",
            "Total": f"{total_loss.item():.4f}"
        })

# --- Evaluation ---
def extract_number(text):
    numbers = re.findall(r"[-+]?\d*\.\d+|\d+", text)
    return numbers[-1] if numbers else None

@torch.no_grad()
def evaluate(student_model, dataset_path, tokenizer, device):
    print(f"\nEvaluating on {dataset_path}...")
    data = load_data(dataset_path)
    dataset = EvalDataset(data)
    loader = DataLoader(dataset, batch_size=1)

    student_model.eval()
    correct = 0
    total = 0

    progress_bar = tqdm(loader, total=len(loader), desc="Testing", dynamic_ncols=False)
    for instruction, label in progress_bar:
        instruction = instruction[0]
        label = label[0]
        inputs = tokenizer(instruction, return_tensors="pt").to(device)
        outputs = student_model.generate(**inputs, max_length=256)
        decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
        pred_ans = extract_number(decoded)
        gold_ans = extract_number(label)
        if pred_ans == gold_ans:
            correct += 1
        total += 1
        # Update tqdm bar postfix with current loss values
        progress_bar.set_postfix({
            "ratio": f"{correct}/ {total}"
        })
        # print(f"\n[Input] {instruction.strip()[:100]}...\n[Pred]  {pred_ans}\n[Label] {gold_ans}")

    acc = correct / total
    print(f"\nAccuracy on {dataset_path}: {acc:.4f}\n")
    return acc

# --- Run evaluation on both students ---
# evaluate(student1, DEV_PATH, tokenizer, DEVICE)
evaluate(student1, TEST_PATH, tokenizer, DEVICE)
# evaluate(student2, DEV_PATH, tokenizer, DEVICE)
evaluate(student2, TEST_PATH, tokenizer, DEVICE)

  from .autonotebook import tqdm as notebook_tqdm


Loading tokenizer and frozen base model...


Loading checkpoint shards: 100%|██████████| 2/2 [00:10<00:00,  5.47s/it]


Creating two students with isolated LoRA adapters...
trainable params: 3,407,872 || all params: 7,245,139,968 || trainable%: 0.0470
trainable params: 3,407,872 || all params: 7,245,139,968 || trainable%: 0.0470


Epoch 1: 100%|██████████| 1869/1869 [40:29<00:00,  1.30s/it, CE1=0.6695, CE2=0.2450, KL1=0.1328, KL2=0.2435, Total=1.1027]  
Epoch 2:  12%|█▏        | 227/1869 [04:53<35:25,  1.29s/it, CE1=1.4329, CE2=0.2691, KL1=0.1266, KL2=0.2173, Total=1.8738]

In [1]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import get_peft_model, LoraConfig, TaskType
from torch.utils.data import DataLoader
from tqdm import tqdm
import copy, json, random

# --- Config ---
DATA_PATH = "data/train.json"
BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.1"
BATCH_SIZE = 4
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_EPOCHS = 2
temperature = 1.0
kl_weight = 1.0
LR = 1e-5

# --- Load data ---
def load_data(path):
    with open(path, "r") as f:
        return json.load(f)

# --- Dataset ---
class ReasoningDataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        instruction = item["instruction"]
        outputs = item["output"]
        label = item["label"]
        target = outputs[0]  # just use one
        return instruction, target

def collate_fn(batch):
    inputs, targets = zip(*batch)
    input_encodings = tokenizer(list(inputs), return_tensors="pt", padding="max_length", truncation=True, max_length=512)
    target_encodings = tokenizer(list(inputs[i] + targets[i] for i in range(len(inputs))),
                                 return_tensors="pt", padding="max_length", truncation=True, max_length=1024)
    
    input_ids = input_encodings["input_ids"]
    attention_mask = input_encodings["attention_mask"]
    labels = target_encodings["input_ids"].clone()
    
    for i, inp in enumerate(inputs):
        inp_len = len(tokenizer(inp)["input_ids"])
        labels[i, :inp_len] = -100  # mask input tokens
    
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }

# --- Model Setup ---
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
tokenizer.pad_token = tokenizer.eos_token

# Teacher (frozen)
teacher = AutoModelForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=torch.float16, device_map="auto")
teacher.eval()
for param in teacher.parameters():
    param.requires_grad = False

# Student (LoRA + trainable)
def make_student(base_model):
    base_copy = copy.deepcopy(base_model)
    config = LoraConfig(
        r=8,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type=TaskType.CAUSAL_LM
    )
    return get_peft_model(base_copy, config)

student = make_student(teacher).to(DEVICE)
student.train()

# --- Training ---
train_data = load_data(DATA_PATH)[:1000]
dataset = ReasoningDataset(train_data, tokenizer)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

optimizer = torch.optim.AdamW(student.parameters(), lr=LR)

for epoch in range(1, NUM_EPOCHS + 1):
    progress_bar = tqdm(loader, desc=f"Epoch {epoch}")

    for batch in progress_bar:
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)

        optimizer.zero_grad()

        # Teacher forward (no grad)
        with torch.no_grad():
            teacher_logits = teacher(input_ids=input_ids, attention_mask=attention_mask).logits
            teacher_probs = F.log_softmax(teacher_logits / temperature, dim=-1)

        # Student forward
        out = student(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        student_logits = out.logits
        ce_loss = out.loss
        student_probs = F.log_softmax(student_logits / temperature, dim=-1)

        # KL divergence
        mask = (labels != -100).unsqueeze(-1)
        kl = F.kl_div(student_probs, teacher_probs, reduction="none", log_target=True)
        kl = kl.sum(dim=-1)[mask.squeeze(-1)].mean()

        total_loss = ce_loss + kl_weight * kl
        total_loss.backward()
        optimizer.step()

        progress_bar.set_postfix({
            "CE": f"{ce_loss.item():.4f}",
            "KL": f"{kl.item():.4f}",
            "Total": f"{total_loss.item():.4f}"
        })

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 2/2 [00:27<00:00, 13.93s/it]
Epoch 1:   0%|          | 0/250 [00:02<?, ?it/s]


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!