# Domain adaptation exp2 – fixed-ratio distillation

## Mount Google Drive

In [2]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


## Environment setup and imports

In [3]:
!pip -q install -U transformers datasets accelerate evaluate scikit-learn

import os, json, math, random
import numpy as np
import torch
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Dict, Any, Optional, List

from datasets import load_dataset, Dataset, DatasetDict, concatenate_datasets
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification,
    TrainingArguments, Trainer
)
from sklearn.metrics import classification_report, f1_score, accuracy_score


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/511.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m [32m501.8/511.6 kB[0m [31m18.2 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m511.6/511.6 kB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.9/8.9 MB[0m [31m127.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.7/47.7 MB[0m [31m56.6 MB/s[0m eta [36m0:00:00[0m
[?25h

## Seeds and constants

In [4]:
from pathlib import Path
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)

TARGET_LABELS = ["anger","anticipation","caring","disgust","fear","joy","neutral","sadness","surprise"]
label2id = {l:i for i,l in enumerate(TARGET_LABELS)}
id2label = {i:l for l,i in label2id.items()}
NUM_LABELS = len(TARGET_LABELS)

# Your already-trained 9-class softmax student (from Notebook 1)
MODEL_PATH = "/content/drive/MyDrive/VibeQ-EIE/models/student_singlelabel_9emotions_v1"

OUT_DIR = Path("/content/drive/MyDrive/VibeQ-EIE/models/student_distilled_9_v2")
OUT_DIR.mkdir(parents=True, exist_ok=True)

MAX_LENGTH = 192
TEMP = 2.0           # distillation temperature (2–4 typical)
LAMBDA_KL = 0.35     # how much teacher guides (start 0.2–0.5)
ANCHOR_RATIO = 0.6   # fraction of batches from GoEmotions anchor (0.5–0.8 recommended)


## Load tokenizer and student init

In [5]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

student = AutoModelForSequenceClassification.from_pretrained(
    MODEL_PATH,
    num_labels=NUM_LABELS,
    id2label=id2label,
    label2id=label2id,
)


## Load GoEmotions base dataset

In [6]:
dataset = load_dataset("go_emotions", "simplified")

# remove unclear if present
for split in ["train","validation","test"]:
    if "example_very_unclear" in dataset[split].column_names:
        dataset[split] = dataset[split].filter(lambda x: x["example_very_unclear"] == 0)

original_labels = dataset["train"].features["labels"].feature.names
orig_id2label = {i:n for i,n in enumerate(original_labels)}

EMOTION_MAP = {
    "anger": "anger", "annoyance": "anger", "disapproval": "anger",
    "optimism": "anticipation", "curiosity": "anticipation", "desire": "anticipation",
    "caring": "caring", "love": "caring", "admiration": "caring", "gratitude": "caring", "approval": "caring",
    "disgust": "disgust",
    "fear": "fear", "nervousness": "fear",
    "joy": "joy", "excitement": "joy", "amusement": "joy", "pride": "joy", "relief": "joy",
    "neutral": "neutral",
    "sadness": "sadness", "disappointment": "sadness", "grief": "sadness", "remorse": "sadness", "embarrassment": "sadness",
    "surprise": "surprise", "confusion": "surprise", "realization": "surprise"
}

def map_goemotions_to_primary(batch):
    texts = batch["text"]
    y = []
    for labs in batch["labels"]:
        # map multi labels -> 9 buckets counts; choose the strongest bucket by "frequency"
        buckets = []
        for old_idx in labs:
            name = orig_id2label[old_idx]
            if name in EMOTION_MAP:
                buckets.append(EMOTION_MAP[name])
        if len(buckets) == 0:
            # default fallback (rare after filtering)
            buckets = ["neutral"]
        # choose most frequent bucket
        primary = max(set(buckets), key=buckets.count)
        y.append(label2id[primary])

    enc = tokenizer(texts, truncation=True, padding="max_length", max_length=MAX_LENGTH)
    enc["labels"] = y
    enc["is_teacher"] = [0]*len(y)
    return enc

go_encoded = dataset.map(map_goemotions_to_primary, batched=True, remove_columns=dataset["train"].column_names)
go_encoded.set_format("torch", columns=["input_ids","attention_mask","labels","is_teacher"])

go_train = go_encoded["train"]
go_val   = go_encoded["validation"]
go_test  = go_encoded["test"]


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

simplified/train-00000-of-00001.parquet:   0%|          | 0.00/2.77M [00:00<?, ?B/s]

simplified/validation-00000-of-00001.par(…):   0%|          | 0.00/350k [00:00<?, ?B/s]

simplified/test-00000-of-00001.parquet:   0%|          | 0.00/347k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/43410 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/5426 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5427 [00:00<?, ? examples/s]

Map:   0%|          | 0/43410 [00:00<?, ? examples/s]

Map:   0%|          | 0/5426 [00:00<?, ? examples/s]

Map:   0%|          | 0/5427 [00:00<?, ? examples/s]

## Create anchor teacher probabilities

In [7]:
NUM_LABELS = 9  # your 9 emotions

def anchor_fill_teacher_probs_single(ex):
    y = int(ex["labels"])  # scalar class id
    tp = np.zeros(NUM_LABELS, dtype=np.float32)
    tp[y] = 1.0
    ex["teacher_probs"] = tp.tolist()
    return ex

go_train = go_train.map(anchor_fill_teacher_probs_single)
go_val   = go_val.map(anchor_fill_teacher_probs_single)
go_test  = go_test.map(anchor_fill_teacher_probs_single)


Map:   0%|          | 0/43410 [00:00<?, ? examples/s]

Map:   0%|          | 0/5426 [00:00<?, ? examples/s]

Map:   0%|          | 0/5427 [00:00<?, ? examples/s]

## Flag teacher vs anchor rows

In [8]:
def ensure_is_teacher_go(batch):
    batch["is_teacher"] = [0] * len(batch["labels"])
    return batch

go_train = go_train.map(ensure_is_teacher_go, batched=True)
go_val   = go_val.map(ensure_is_teacher_go, batched=True)
go_test  = go_test.map(ensure_is_teacher_go, batched=True)


Map:   0%|          | 0/43410 [00:00<?, ? examples/s]

Map:   0%|          | 0/5426 [00:00<?, ? examples/s]

Map:   0%|          | 0/5427 [00:00<?, ? examples/s]

## Set torch formats for Trainer

In [9]:
go_train.set_format(type="torch", columns=["input_ids","attention_mask","labels","is_teacher","teacher_probs"])
go_val.set_format(type="torch", columns=["input_ids","attention_mask","labels","is_teacher","teacher_probs"])
go_test.set_format(type="torch", columns=["input_ids","attention_mask","labels","is_teacher","teacher_probs"])


## Load synthetic teacher datasets

In [10]:
from pathlib import Path
import pandas as pd

TEACHER_JOURNALS_JSONL = "/content/drive/MyDrive/VibeQ-EIE/llmdata/teacher_journals_distillation.jsonl"
TEACHER_ISEAR_JSONL    = "/content/drive/MyDrive/VibeQ-EIE/llmdata/teacher_isear_distillation.jsonl"  # <-- your relabeled output

def load_teacher_jsonl(path):
    rows = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            r = json.loads(line)
            probs = [float(r["teacher_emotion_probs"][k]) for k in TARGET_LABELS]
            rows.append({"text": r["text"], "teacher_probs": probs})
    return rows

teacher_rows = load_teacher_jsonl(TEACHER_JOURNALS_JSONL) + load_teacher_jsonl(TEACHER_ISEAR_JSONL)
teacher_ds = Dataset.from_list(teacher_rows)

def encode_teacher(batch):
    enc = tokenizer(batch["text"], truncation=True, padding="max_length", max_length=MAX_LENGTH)
    enc["labels"] = [-100]*len(batch["text"])              # ignored for CE
    enc["teacher_probs"] = batch["teacher_probs"]          # distillation target
    enc["is_teacher"] = [1]*len(batch["text"])
    return enc

teacher_encoded = teacher_ds.map(encode_teacher, batched=True, remove_columns=teacher_ds.column_names)
teacher_encoded.set_format("torch", columns=["input_ids","attention_mask","labels","teacher_probs","is_teacher"])


Map:   0%|          | 0/3763 [00:00<?, ? examples/s]

## Sanity checks

In [11]:


# Quick sanity checks (should print torch.Tensor types)
print(type(go_train[0]["input_ids"]), type(go_train[0]["labels"]), type(go_train[0]["is_teacher"]))
print(type(teacher_encoded[0]["input_ids"]), type(teacher_encoded[0]["teacher_probs"]), type(teacher_encoded[0]["is_teacher"]))



<class 'torch.Tensor'> <class 'torch.Tensor'> <class 'torch.Tensor'>
<class 'torch.Tensor'> <class 'torch.Tensor'> <class 'torch.Tensor'>


## Balance anchor and teacher ratios

In [12]:
def repeat_dataset(ds, times):
    return concatenate_datasets([ds]*times)

# compute repeats to approximate ratio
# Example: want 60% anchor batches -> roughly 60% examples from anchor
anchor_n = len(go_train)
teacher_n = len(teacher_encoded)

# target total size
target_total = max(anchor_n, teacher_n) * 2
target_anchor = int(target_total * ANCHOR_RATIO)
target_teacher = target_total - target_anchor

# repeats
rep_anchor = math.ceil(target_anchor / anchor_n)
rep_teacher = math.ceil(target_teacher / teacher_n)

mixed_train = concatenate_datasets([
    repeat_dataset(go_train, rep_anchor).select(range(target_anchor)),
    repeat_dataset(teacher_encoded, rep_teacher).select(range(target_teacher)),
]).shuffle(seed=SEED)

mixed_train.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "labels", "is_teacher", "teacher_probs"]
)

# Validation must be GoEmotions-only to prevent regression
val_ds = go_val
test_ds = go_test

print("Mixed train size:", len(mixed_train))
print("Val (GoEmotions) size:", len(val_ds))


Mixed train size: 86820
Val (GoEmotions) size: 5426


## Mixed dataset checks

In [13]:
print("mixed_train types:",
      type(mixed_train[0]["input_ids"]),
      type(mixed_train[0]["labels"]),
      type(mixed_train[0]["is_teacher"]))

# Find any row that still returns lists (should print nothing)
bad = []
for i in range(50):  # sample a bit
    row = mixed_train[i]
    for k in ["input_ids","attention_mask","labels","is_teacher","teacher_probs"]:
        if not isinstance(row[k], torch.Tensor):
            bad.append((i, k, type(row[k])))
print("bad:", bad[:10])


mixed_train types: <class 'torch.Tensor'> <class 'torch.Tensor'> <class 'torch.Tensor'>
bad: []


## Distillation data collator

In [14]:
@dataclass
class DistillCollator:
    def __call__(self, features):
        batch = {}

        # These are already torch tensors → DO NOT wrap with torch.tensor()
        batch["input_ids"] = torch.stack([f["input_ids"] for f in features])
        batch["attention_mask"] = torch.stack([f["attention_mask"] for f in features])

        # ---- labels: ALWAYS LongTensor ----
        batch["labels"] = torch.tensor(
            [int(f["labels"]) for f in features],
            dtype=torch.long
        )

        # ---- is_teacher flag ----
        batch["is_teacher"] = torch.tensor(
            [int(f.get("is_teacher", 0)) for f in features],
            dtype=torch.long
        )

        # ---- teacher_probs (always present by construction) ----
        if all("teacher_probs" in f and f["teacher_probs"] is not None for f in features):
            batch["teacher_probs"] = torch.stack([f["teacher_probs"] for f in features]).float()
        else:
            # safe fallback: zeros (KL will be skipped if teacher_mask has no teacher samples)
            batch["teacher_probs"] = torch.zeros((len(features), NUM_LABELS), dtype=torch.float32)


        return batch


collator = DistillCollator()



## Custom trainers for fixed ratios

In [15]:
from torch.utils.data import DataLoader

class FixedRatioTrainer(Trainer):
    def __init__(
        self,
        *args,
        anchor_dataset=None,
        teacher_dataset=None,
        anchor_ratio=0.5,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        assert anchor_dataset is not None and teacher_dataset is not None
        self.anchor_dataset = anchor_dataset
        self.teacher_dataset = teacher_dataset
        self.anchor_ratio = anchor_ratio

    def get_train_dataloader(self):
        total_bs = self.args.per_device_train_batch_size
        anchor_bs = max(1, int(round(total_bs * self.anchor_ratio)))
        teacher_bs = max(1, total_bs - anchor_bs)

        anchor_loader = DataLoader(
            self.anchor_dataset,
            batch_size=anchor_bs,
            shuffle=True,
            collate_fn=self.data_collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            drop_last=True,
        )

        teacher_loader = DataLoader(
            self.teacher_dataset,
            batch_size=teacher_bs,
            shuffle=True,
            collate_fn=self.data_collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            drop_last=True,
        )

        def cycle(dl):
            while True:
                for b in dl:
                    yield b

        anchor_it = cycle(anchor_loader)
        teacher_it = cycle(teacher_loader)

        # Define epoch length by anchor_loader
        steps_per_epoch = len(anchor_loader)

        class MergedLoader:
            def __len__(self):
                return steps_per_epoch

            def __iter__(self):
                for _ in range(steps_per_epoch):
                    a = next(anchor_it)
                    t = next(teacher_it)

                    out = {}
                    # concat common keys
                    for k in a.keys():
                        if k in t:
                            out[k] = torch.cat([a[k], t[k]], dim=0)
                        else:
                            out[k] = a[k]
                    for k in t.keys():
                        if k not in out:
                            out[k] = t[k]
                    yield out

        return MergedLoader()


## KL distillation mixin

In [16]:
class DistillTrainer(Trainer):
    def __init__(self, *args, temp=2.0, lambda_kl=0.35, **kwargs):
        super().__init__(*args, **kwargs)
        self.temp = temp
        self.lambda_kl = lambda_kl

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        labels = inputs.pop("labels")
        is_teacher = inputs.pop("is_teacher")
        teacher_probs = inputs.pop("teacher_probs", None)

        outputs = model(**inputs)
        logits = outputs.logits  # [B,9]

        # ----- CE on anchor samples only -----
        anchor_mask = (is_teacher == 0)
        ce_loss = None
        if anchor_mask.any():
            ce_loss = F.cross_entropy(
                logits[anchor_mask],
                labels[anchor_mask],
            )
        else:
            ce_loss = torch.tensor(0.0, device=logits.device)

        # ----- KL on teacher samples only -----
        teacher_mask = (is_teacher == 1)
        kl_loss = torch.tensor(0.0, device=logits.device)

        if teacher_probs is not None and teacher_mask.any():
            # student log-probs with temperature
            logp_s = F.log_softmax(logits[teacher_mask] / self.temp, dim=-1)
            # teacher probs -> clamp for safety
            p_t = teacher_probs[teacher_mask].to(logits.device).clamp(min=1e-8)
            p_t = p_t / p_t.sum(dim=-1, keepdim=True)
            # KL(P_teacher || P_student)
            kl_loss = F.kl_div(logp_s, p_t, reduction="batchmean") * (self.temp ** 2)

        loss = ce_loss + self.lambda_kl * kl_loss
        return (loss, outputs) if return_outputs else loss


## Final trainer class

In [17]:
class FixedRatioDistillTrainer(FixedRatioTrainer, DistillTrainer):
    pass


## Metrics

In [18]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    acc = accuracy_score(labels, preds)
    f1m = f1_score(labels, preds, average="macro", zero_division=0)
    f1u = f1_score(labels, preds, average="micro", zero_division=0)
    return {"acc": acc, "f1_macro": f1m, "f1_micro": f1u}


## Training arguments

In [19]:
args = TrainingArguments(
    output_dir=OUT_DIR,
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    num_train_epochs=3,
    weight_decay=0.01,
    warmup_ratio=0.1,

    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1_macro",
    greater_is_better=True,

    fp16=torch.cuda.is_available(),
    logging_steps=100,
    save_total_limit=2,
    report_to="none",
    remove_unused_columns=False,
)


## Initialize trainer

In [20]:
trainer = FixedRatioDistillTrainer(
    model=student,
    args=args,
    train_dataset=go_train,     # not used directly; we override dataloader
    eval_dataset=go_val,        # GoEmotions-only validation
    tokenizer=tokenizer,
    data_collator=collator,
    compute_metrics=compute_metrics,
    temp=TEMP,
    lambda_kl=LAMBDA_KL,

    anchor_dataset=go_train,
    teacher_dataset=teacher_encoded,
    anchor_ratio=ANCHOR_RATIO,  # use your 0.6 here
)


dl = trainer.get_train_dataloader()
print("len(train_dataloader) =", len(dl))
b = next(iter(dl))
print({k: tuple(v.shape) for k,v in b.items()})
print("anchor in batch =", int((b["is_teacher"]==0).sum()), "teacher =", int((b["is_teacher"]==1).sum()))


  super().__init__(*args, **kwargs)


len(train_dataloader) = 4341
{'input_ids': (16, 192), 'attention_mask': (16, 192), 'labels': (16,), 'is_teacher': (16,), 'teacher_probs': (16, 9)}
anchor in batch = 10 teacher = 6


## Train model

In [21]:
trainer.train()


Epoch,Training Loss,Validation Loss,Acc,F1 Macro,F1 Micro
1,0.7698,1.210224,0.617766,0.576987,0.617766
2,0.5244,1.612737,0.613896,0.567493,0.613896
3,0.2229,2.40777,0.615923,0.567184,0.615923


TrainOutput(global_step=13023, training_loss=0.5466259915570753, metrics={'train_runtime': 6154.6929, 'train_samples_per_second': 33.855, 'train_steps_per_second': 2.116, 'total_flos': 7.2821466756053e+16, 'train_loss': 0.5466259915570753, 'epoch': 3.0})

## Save checkpoint

In [22]:
trainer.save_model(OUT_DIR)
tokenizer.save_pretrained(OUT_DIR)
print("Saved:", OUT_DIR)

Saved: /content/drive/MyDrive/VibeQ-EIE/models/student_distilled_9_v2
