In [42]:
from datasets import load_dataset
from transformers import BertModel, BertConfig, BertForSequenceClassification, BertTokenizer
import torch
import torch.nn as nn
import pandas as pd
from tqdm import tqdm
import os

import torch.optim as optim

import matplotlib.pyplot as plt

In [43]:
# parameters

#-- setting custom model
total_l = 6
trans_l = 1
base_model = "bert-base-uncased"
model_name = "textattack/bert-base-uncased-QNLI"
task_name = "qnli"

#-- setting result name
result_name = "qnli_cka"
model_save_path = f"/mnt/aix7101/jeong/ee/{result_name}.pt"

#-- setting training
train_strategy = "low_lr"  # 'freeze', 'low_lr', 'unfreeze'
num_epoch = 3
num_unfreeze = 3 

In [44]:
db = load_dataset("glue", "qnli")

In [45]:
print(db)

DatasetDict({
    train: Dataset({
        features: ['question', 'sentence', 'label', 'idx'],
        num_rows: 104743
    })
    validation: Dataset({
        features: ['question', 'sentence', 'label', 'idx'],
        num_rows: 5463
    })
    test: Dataset({
        features: ['question', 'sentence', 'label', 'idx'],
        num_rows: 5463
    })
})


In [46]:
# load train dataset
train_dataset = load_dataset("glue", "qnli", split="train")

# load validation dataset
val_dataset = load_dataset("glue", "qnli", split="validation")

In [47]:
print(val_dataset)

Dataset({
    features: ['question', 'sentence', 'label', 'idx'],
    num_rows: 5463
})


In [48]:
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
dropout = nn.Dropout(p=0.1).to(device) # in BERT default 0.1

In [49]:
# setting
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, output_hidden_states=True).eval().to(device)
model.eval()

KeyboardInterrupt: 

## Custom model

In [None]:
# from transformers import BertModel, BertConfig, BertForSequenceClassification
# import torch.nn as nn
# import torch

# class CustomBertSmall(nn.Module):
#     def __init__(self, teacher_model, total_layers=6, transplanted_layers=3):
#         super().__init__()
#         assert transplanted_layers < total_layers, "Transplanted layers must be fewer than total layers"
        
#         self.hidden_size = teacher_model.config.hidden_size
#         self.total_layers = total_layers
#         self.transplanted_layers = transplanted_layers

#         # 그대로 복사할 레이어 인덱스 계산
#         transplanted_start = 12 - transplanted_layers
#         original_layer_indices = list(range(transplanted_start))[:total_layers - transplanted_layers]

#         # Embedding 복사
#         self.embeddings = teacher_model.bert.embeddings

#         # 선택된 layer만 복사해서 재구성
#         self.encoder_layers = nn.ModuleList()

#         for idx in original_layer_indices:
#             layer = teacher_model.bert.encoder.layer[idx]
#             self.encoder_layers.append(layer)

#         for idx in range(transplanted_start, 12):
#             layer = teacher_model.bert.encoder.layer[idx]
#             self.encoder_layers.append(layer)

#         # Pooler와 Classifier도 복사
#         self.pooler = teacher_model.bert.pooler
#         self.dropout = teacher_model.dropout  # from classifier head
#         self.classifier = teacher_model.classifier

#         self.activation = nn.Tanh()  # 여전히 pooler 내부에서도 사용되지만 보존

#     # CustomBertSmall에 hidden_states 옵션 추가
#     def forward(self, input_ids, attention_mask=None, token_type_ids=None, output_hidden_states=False):
#         hidden_states = self.embeddings(input_ids=input_ids, token_type_ids=token_type_ids)

#         if attention_mask is not None:
#             extended_attention_mask = attention_mask[:, None, None, :]
#             extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
#         else:
#             extended_attention_mask = None

#         all_hidden = []  # 각 레이어 출력 저장
#         for layer in self.encoder_layers:
#             hidden_states = layer(hidden_states, attention_mask=extended_attention_mask)[0]
#             if output_hidden_states:
#                 all_hidden.append(hidden_states)

#         pooled_output = self.pooler(hidden_states)
#         pooled_output = self.dropout(self.activation(pooled_output))
#         logits = self.classifier(pooled_output)

#         if output_hidden_states:
#             return logits, all_hidden
#         else:
#             return logits

In [None]:

class CustomBertSmallForCKA(nn.Module):
    def __init__(self, teacher_model):
        super().__init__()

        self.hidden_size = teacher_model.config.hidden_size
        self.selected_layer_indices = [0, 1, 2, 3, 4, 11]   # 1-based: [1,2,3,4,5,12] → 0-based index

        # Embedding 복사
        self.embeddings = teacher_model.bert.embeddings

        # 선택된 레이어만 복사
        self.encoder_layers = nn.ModuleList([
            teacher_model.bert.encoder.layer[idx] for idx in self.selected_layer_indices
        ])

        # Pooler와 Classifier는 그대로 복사
        self.pooler = teacher_model.bert.pooler
        self.dropout = teacher_model.dropout
        self.classifier = teacher_model.classifier
        self.activation = nn.Tanh()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, output_hidden_states=False):
        hidden_states = self.embeddings(input_ids=input_ids, token_type_ids=token_type_ids)

        if attention_mask is not None:
            extended_attention_mask = attention_mask[:, None, None, :]
            extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        else:
            extended_attention_mask = None

        all_hidden = []

        for layer in self.encoder_layers:
            hidden_states = layer(hidden_states, attention_mask=extended_attention_mask)[0]
            if output_hidden_states:
                all_hidden.append(hidden_states)

        pooled_output = self.pooler(hidden_states)
        pooled_output = self.dropout(self.activation(pooled_output))
        logits = self.classifier(pooled_output)

        if output_hidden_states:
            return logits, all_hidden
        else:
            return logits

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# model_name = "textattack/bert-base-uncased-ag-news"

teacher_model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)

small_model = CustomBertSmallForCKA(
    teacher_model=teacher_model 
).to(device)

## [check] before training

In [None]:
# # for accuracy with QQP task
# correct_base = 0
# correct_small = 0

# model.eval()
# small_model.eval()

# for item in tqdm(val_dataset, desc="Evaluating Small Model (QQP)"):
#     text1 = item["sentence1"]
#     text2 = item["sentence2"]
#     label = item["label"]

#     # inputs for sentence pair
#     inputs = tokenizer(text1, text2, return_tensors="pt", truncation=True, padding=True).to(device)

#     # Teacher model
#     with torch.no_grad():
#         output = model(**inputs)
#         logits = output.logits
#         pred = torch.argmax(logits, dim=-1).item()

#     # Small model
#     with torch.no_grad():
#         small_logits = small_model(**inputs)
#         small_pred = torch.argmax(small_logits, dim=-1).item()

#     correct_base += int(pred == label)
#     correct_small += int(small_pred == label)

# # 최종 정확도 출력
# total = len(val_dataset)
# print(f"\n✅ Accuracy of Bertbase: {correct_base / total * 100:.2f}%")
# print(f"\n✅ Accuracy of CustomBertSmall: {correct_small / total * 100:.2f}%")

## Custom Loss

In [None]:
import torch.nn.functional as F
from torch.nn import MSELoss, KLDivLoss

def loss1(logits, labels):
    return F.cross_entropy(logits, labels)

# Representation Matching Loss (MSE between CLS tokens)
def loss2(student_hidden, teacher_hidden):
    mse = MSELoss()
    return mse(student_hidden, teacher_hidden)

# DSR Loss (KL Divergence between sorted logits)
def loss3(prev_logits, current_logits, tau=1.0):
    z_prev = torch.sort(prev_logits, dim=-1)[0]
    z_current = torch.sort(current_logits, dim=-1)[0]

    p_prev = F.softmax(z_prev / tau, dim=-1)
    p_current = F.log_softmax(z_current / tau, dim=-1)

    kldiv = KLDivLoss(reduction='batchmean')
    return (tau ** 2 / 2) * kldiv(p_current, p_prev)  # KL(p_prev || p_current)

In [None]:
import torch
import torch.nn.functional as F

def compute_cka(X: torch.Tensor, Y: torch.Tensor, eps=1e-8):
    X = X - X.mean(dim=0, keepdim=True)
    Y = Y - Y.mean(dim=0, keepdim=True)

    dot_product_similarity = (X.T @ Y).norm(p='fro') ** 2
    normalization_x = (X.T @ X).norm(p='fro')
    normalization_y = (Y.T @ Y).norm(p='fro')
    return dot_product_similarity / (normalization_x * normalization_y + eps)

## Train

In [None]:
def evaluate(model, val_loader, tokenizer, device):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Evaluating"):
            text1 = batch["sentence1"][0]
            text2 = batch["sentence2"][0]
            label = batch["label"].item()

            inputs = tokenizer(text1, text2, return_tensors="pt", padding=True, truncation=True).to(device)
            logits = model(**inputs)
            pred = torch.argmax(logits, dim=-1).item()

            correct += int(pred == label)
            total += 1

    return correct / total * 100

In [None]:
def get_matched_teacher_layers(n_student_layers, n_teacher_layers=12):
    return np.linspace(1, n_teacher_layers, n_student_layers, dtype=int).tolist()


In [None]:

def train_cka_loss_model(
    model,
    train_dataset,
    val_dataset,
    tokenizer,
    teacher_model=None,
    custom_loss=False,  # CKA 전용 loss
    strategy="freeze",
    batch_size=16,
    epochs=10,
    base_lr=5e-5,
    low_lr=5e-6,
    k=3,
    alpha=1.0,  # alpha는 의미 없음, loss = cka_loss 단일
    unfreeze_epoch=1,
    save_path="best_model.pt",
    device="cuda:1" if torch.cuda.is_available() else "cpu",
    evaluate_fn=None,
    task_config=None
):
    if task_config is None:
        raise ValueError("task_config must be provided.")
    if custom_loss and teacher_model is None:
        raise ValueError("teacher_model must be provided when using custom_loss=True")
    if evaluate_fn is None:
        raise ValueError("evaluate_fn must be provided for evaluation")

    input_keys = task_config["inputs"]
    label_key = task_config["label"]
    task_type = task_config["type"]

    model = model.to(device)
    teacher_model = teacher_model.to(device) if teacher_model else None

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1)

    if strategy == "low_lr":
        optimizer_grouped = [
            {"params": model.encoder_layers[0].parameters(), "lr": low_lr},   # ⬅️ 첫 번째 레이어
            {"params": model.encoder_layers[-1].parameters(), "lr": low_lr},  # ⬅️ 마지막 레이어
            {"params": [p for l in model.encoder_layers[1:-1] for p in l.parameters()], "lr": base_lr},
            {"params": model.pooler.parameters(), "lr": base_lr},
            {"params": model.classifier.parameters(), "lr": base_lr},
        ]
    else:
        optimizer_grouped = model.parameters()  

    optimizer = AdamW(optimizer_grouped, lr=base_lr)

    # 기본 loss는 CE지만, custom_loss=True일 경우 사용 안 함
    loss_fn = nn.MSELoss() if task_type == "regression" else nn.CrossEntropyLoss()

    if strategy == "freeze":
        for layer in model.encoder_layers[-k:]:
            for param in layer.parameters():
                param.requires_grad = False

    best_score = None

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        print(f"\n📘 Epoch {epoch+1}/{epochs}")

        if strategy == "unfreeze" and epoch == unfreeze_epoch:
            print("--<Unfreezing last K layers>--")
            for layer in model.encoder_layers[-k:]:
                for param in layer.parameters():
                    param.requires_grad = True

        for batch in tqdm(train_loader, desc="Training"):
            # 입력 처리
            if len(input_keys) == 2:
                texts1 = batch[input_keys[0]]
                texts2 = batch[input_keys[1]]
                tokenized = tokenizer(list(texts1), list(texts2), return_tensors="pt", padding=True, truncation=True)
            else:
                texts = batch[input_keys[0]]
                tokenized = tokenizer(list(texts), return_tensors="pt", padding=True, truncation=True)

            inputs = {k: v.to(device) for k, v in tokenized.items()}
            labels = batch[label_key].to(device)

            optimizer.zero_grad()

            if custom_loss:
                # forward
                logits_small, student_hiddens = model(**inputs, output_hidden_states=True)
                with torch.no_grad():
                    teacher_outputs = teacher_model(**inputs, output_hidden_states=True)
                    teacher_hiddens = teacher_outputs.hidden_states  # 13개 (embedding 포함)

                # layer 대응
                s_h = student_hiddens[1:]  # skip embedding
                t_indices = get_matched_teacher_layers(len(s_h), n_teacher_layers=12)
                t_h = [teacher_hiddens[i] for i in t_indices]

                # CLS 기준으로 CKA loss 계산
                loss_cka = 0.0
                for t, s in zip(t_h, s_h):
                    t_cls = t[:, 0, :]
                    s_cls = s[:, 0, :]
                    loss_cka += 1 - compute_cka(t_cls, s_cls)
                loss_cka /= len(s_h)
                loss = loss_cka

            else:
                logits = model(**inputs)
                if task_type == "regression":
                    labels = labels.float()
                    loss = loss_fn(logits.squeeze(), labels)
                else:
                    loss = loss_fn(logits, labels)

            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"📉 Avg Training Loss: {avg_loss:.4f}")

        score = evaluate_fn(model, val_loader, tokenizer, device)
        print(f"📊 Validation Score: {score:.4f}")

        if best_score is None or score > best_score:
            best_score = score
            torch.save(model.state_dict(), save_path)
            print(f"✅ Best model saved with score: {best_score:.4f} → {save_path}")

In [None]:
GLUE_TASKS = {
    "sst2":  {"inputs": ["sentence"],                     "label": "label", "type": "binary",     "model": "textattack/bert-base-uncased-SST-2"},
    "cola":  {"inputs": ["sentence"],                     "label": "label", "type": "binary",     "model": "textattack/bert-base-uncased-CoLA"},
    "qqp":   {"inputs": ["question1", "question2"],       "label": "label", "type": "binary",     "model": "textattack/bert-base-uncased-QQP"},
    "qnli":  {"inputs": ["question", "sentence"],         "label": "label", "type": "binary",     "model": "textattack/bert-base-uncased-QNLI"},
    "mrpc":  {"inputs": ["sentence1", "sentence2"],       "label": "label", "type": "binary",     "model": "textattack/bert-base-uncased-MRPC"},
    "rte":   {"inputs": ["sentence1", "sentence2"],       "label": "label", "type": "binary",     "model": "textattack/bert-base-uncased-RTE"},
    "stsb":  {"inputs": ["sentence1", "sentence2"],       "label": "label", "type": "regression", "model": "textattack/bert-base-uncased-STS-B"},
}

In [None]:
from sklearn.metrics import accuracy_score, f1_score

def make_evaluate_fn(task_config):
    inputs = task_config["inputs"]
    label_key = task_config["label"]
    task_type = task_config["type"]

    def evaluate(model, val_loader, tokenizer, device):
        model.eval()
        preds = []
        labels = []

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Evaluating"):
                if len(inputs) == 2:
                    text1 = batch[inputs[0]][0]
                    text2 = batch[inputs[1]][0]
                    encoded = tokenizer(text1, text2, return_tensors="pt", padding=True, truncation=True)
                else:
                    text = batch[inputs[0]][0]
                    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

                encoded = {k: v.to(device) for k, v in encoded.items()}
                label = batch[label_key].item()
                output = model(**encoded)

                if task_type == "regression":
                    pred = output.squeeze().cpu().item()
                else:
                    pred = torch.argmax(output, dim=-1).item()

                preds.append(pred)
                labels.append(label)

        # 결과 계산
        if task_type == "regression":
            score = pearsonr(preds, labels)[0] * 100  # %
        elif task_type in ["binary", "3-class"]:
            acc = accuracy_score(labels, preds)
            if len(set(labels)) == 2:
                f1 = f1_score(labels, preds)
                score = (acc + f1) / 2 * 100
            else:
                score = acc * 100
        else:
            raise ValueError("Unknown task type")

        return score

    return evaluate


In [None]:
from torch.utils.data import DataLoader
from torch.optim import AdamW
# from loss import cka_delta_loss

In [None]:
task_config = GLUE_TASKS[task_name]

evaluate_fn = make_evaluate_fn(task_config)

In [None]:
train_cka_loss_model(
    model=small_model,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    tokenizer=tokenizer,
    teacher_model=model,
    custom_loss=False,
    strategy="low_lr",
    batch_size=16,
    epochs=num_epoch,
    k=trans_l,
    unfreeze_epoch=num_unfreeze,
    save_path=model_save_path,
    evaluate_fn=evaluate_fn,
    task_config=task_config,
    device=device
)

#-- load trained model

small_model.load_state_dict(torch.load(model_save_path, map_location=device))
small_model = small_model.eval().to(device)


📘 Epoch 1/10


Training:   2%|▏         | 147/6547 [00:06<05:04, 21.02it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  32%|███▏      | 2078/6547 [01:25<02:55, 25.50it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  51%|█████     | 3320/6547 [02:16<02:11, 24.57it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  65%|██████▍   | 4232/6547 [02:54<01:31, 25.43it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, 

📉 Avg Training Loss: 0.3719


Evaluating: 100%|██████████| 5463/5463 [00:35<00:00, 152.94it/s]


📊 Validation Score: 87.6077
✅ Best model saved with score: 87.6077 → /mnt/aix7101/jeong/ee/qnli_cak.pt

📘 Epoch 2/10


Training:  10%|▉         | 651/6547 [00:26<03:58, 24.71it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  77%|███████▋  | 5042/6547 [03:28<01:01, 24.67it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  81%|████████  | 5315/6547 [03:39<00:49, 24.74it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  82%|████████▏ | 5354/6547 [03:41<00:48, 24.56it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, 

📉 Avg Training Loss: 0.2530


Evaluating: 100%|██████████| 5463/5463 [00:35<00:00, 152.77it/s]


📊 Validation Score: 87.7796
✅ Best model saved with score: 87.7796 → /mnt/aix7101/jeong/ee/qnli_cak.pt

📘 Epoch 3/10


Training:  23%|██▎       | 1533/6547 [01:02<03:24, 24.54it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  47%|████▋     | 3081/6547 [02:06<02:21, 24.50it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  63%|██████▎   | 4103/6547 [02:49<01:38, 24.83it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  76%|███████▌  | 4946/6547 [03:24<01:04, 24.94it/s]Be aware, overflowing tokens are not returned for the setting you have chosen,

📉 Avg Training Loss: 0.1691


Evaluating: 100%|██████████| 5463/5463 [00:35<00:00, 152.24it/s]


📊 Validation Score: 87.5264

📘 Epoch 4/10


Training:   4%|▍         | 288/6547 [00:11<04:12, 24.79it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  28%|██▊       | 1809/6547 [01:14<03:09, 25.05it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  52%|█████▏    | 3390/6547 [02:20<02:30, 20.99it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  56%|█████▌    | 3644/6547 [02:31<02:00, 24.19it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, 

📉 Avg Training Loss: 0.1156


Evaluating: 100%|██████████| 5463/5463 [00:35<00:00, 153.10it/s]


📊 Validation Score: 86.1633

📘 Epoch 5/10


Training:  27%|██▋       | 1735/6547 [01:11<03:34, 22.42it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  56%|█████▌    | 3634/6547 [02:30<02:00, 24.17it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  68%|██████▊   | 4426/6547 [03:04<01:26, 24.52it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  73%|███████▎  | 4780/6547 [03:18<01:10, 25.08it/s]Be aware, overflowing tokens are not returned for the setting you have chosen,

📉 Avg Training Loss: 0.0887


Evaluating: 100%|██████████| 5463/5463 [00:35<00:00, 152.80it/s]


📊 Validation Score: 87.1776

📘 Epoch 6/10


Training:  23%|██▎       | 1478/6547 [01:01<03:26, 24.55it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  37%|███▋      | 2424/6547 [01:40<02:46, 24.72it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  40%|████      | 2640/6547 [01:49<02:39, 24.47it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  45%|████▌     | 2969/6547 [02:03<02:26, 24.37it/s]Be aware, overflowing tokens are not returned for the setting you have chosen,

📉 Avg Training Loss: 0.0729


Evaluating: 100%|██████████| 5463/5463 [00:36<00:00, 151.60it/s]


📊 Validation Score: 87.1403

📘 Epoch 7/10


Training:  19%|█▉        | 1250/6547 [00:52<03:28, 25.43it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  68%|██████▊   | 4474/6547 [03:08<01:23, 24.70it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  75%|███████▍  | 4879/6547 [03:24<01:06, 25.10it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  87%|████████▋ | 5668/6547 [03:58<00:34, 25.19it/s]Be aware, overflowing tokens are not returned for the setting you have chosen,

📉 Avg Training Loss: 0.0633


Evaluating: 100%|██████████| 5463/5463 [00:36<00:00, 150.45it/s]


📊 Validation Score: 87.0171

📘 Epoch 8/10


Training:  25%|██▍       | 1628/6547 [01:08<03:52, 21.12it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  25%|██▌       | 1653/6547 [01:09<03:41, 22.14it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  66%|██████▌   | 4323/6547 [03:00<01:32, 24.12it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  81%|████████  | 5307/6547 [03:42<00:49, 25.25it/s]Be aware, overflowing tokens are not returned for the setting you have chosen,

📉 Avg Training Loss: 0.0549


Evaluating: 100%|██████████| 5463/5463 [00:36<00:00, 151.11it/s]


📊 Validation Score: 86.7216

📘 Epoch 9/10


Training:  10%|█         | 675/6547 [00:27<03:59, 24.49it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  29%|██▉       | 1886/6547 [01:18<03:18, 23.45it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  81%|████████  | 5281/6547 [03:38<00:50, 25.14it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  85%|████████▌ | 5584/6547 [03:51<00:39, 24.40it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, 

📉 Avg Training Loss: 0.0524


Evaluating: 100%|██████████| 5463/5463 [00:35<00:00, 153.56it/s]


📊 Validation Score: 86.8777

📘 Epoch 10/10


Training:  25%|██▌       | 1647/6547 [01:07<03:39, 22.37it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  37%|███▋      | 2397/6547 [01:38<02:44, 25.16it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  48%|████▊     | 3165/6547 [02:10<02:19, 24.26it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  60%|██████    | 3929/6547 [02:42<01:58, 22.05it/s]Be aware, overflowing tokens are not returned for the setting you have chosen,

📉 Avg Training Loss: 0.0462


Evaluating: 100%|██████████| 5463/5463 [00:35<00:00, 152.48it/s]


📊 Validation Score: 87.3821


In [None]:

small_model.load_state_dict(torch.load(model_save_path, map_location=device))
small_model = small_model.eval().to(device)


In [None]:
def evaluate_teacher_student(
    teacher_model,
    student_model,
    val_dataset,
    tokenizer,
    device,
    task_config,
    task_name=None  # 👈 task 이름 추가로 받음 (cola 확인용)
):
    inputs_key = task_config["inputs"]
    label_key = task_config["label"]
    task_type = task_config["type"]

    teacher_model.eval()
    student_model.eval()

    preds_teacher = []
    preds_student = []
    labels = []

    for item in tqdm(val_dataset, desc="Evaluating Teacher vs Student"):
        if len(inputs_key) == 2:
            input_text1 = item[inputs_key[0]]
            input_text2 = item[inputs_key[1]]
            tokenized = tokenizer(input_text1, input_text2, return_tensors="pt", padding=True, truncation=True).to(device)
        else:
            input_text = item[inputs_key[0]]
            tokenized = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True).to(device)

        label = item[label_key]
        if isinstance(label, torch.Tensor):
            label = label.item()
        labels.append(label)

        with torch.no_grad():
            t_logits = teacher_model(**tokenized)
            t_pred = t_logits.logits.squeeze().item() if task_type == "regression" else torch.argmax(t_logits.logits, dim=-1).item()

            s_logits = student_model(**tokenized)
            s_pred = s_logits.squeeze().item() if task_type == "regression" else torch.argmax(s_logits, dim=-1).item()

        preds_teacher.append(t_pred)
        preds_student.append(s_pred)

    # 🎯 점수 계산
    if task_type == "regression":
        pearson_t = pearsonr(preds_teacher, labels)[0] * 100
        pearson_s = pearsonr(preds_student, labels)[0] * 100
        print(f"\n✅ Pearson of Teacher: {pearson_t:.2f}%")
        print(f"✅ Pearson of Student: {pearson_s:.2f}%")

    elif task_name == "cola":
        mcc_t = matthews_corrcoef(labels, preds_teacher) * 100
        mcc_s = matthews_corrcoef(labels, preds_student) * 100
        print(f"\n✅ MCC of Teacher: {mcc_t:.2f}%")
        print(f"✅ MCC of Student: {mcc_s:.2f}%")

    else:
        acc_t = accuracy_score(labels, preds_teacher) * 100
        acc_s = accuracy_score(labels, preds_student) * 100
        print(f"\n✅ Accuracy of Teacher: {acc_t:.2f}%")
        print(f"✅ Accuracy of Student: {acc_s:.2f}%")
    return acc_t, acc_s

In [None]:
evaluate_teacher_student(
    teacher_model=model,
    student_model=small_model,
    val_dataset=val_dataset,
    tokenizer=tokenizer,
    device=device,
    task_config=task_config,
    task_name=task_name
)

Evaluating Teacher vs Student: 100%|██████████| 5463/5463 [01:26<00:00, 63.14it/s]


✅ Accuracy of Teacher: 91.54%
✅ Accuracy of Student: 87.61%





(91.54310818231741, 87.60754164378547)