In [2]:
import torch
import torch.nn as nn
from torch.optim import AdamW  # ← 여기로 수정
from transformers import BertForSequenceClassification, AutoTokenizer
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np

device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
from datasets import load_dataset
from transformers import AutoTokenizer

# Load SST-2
train_dataset = load_dataset("glue", "sst2", split="train")
val_dataset = load_dataset("glue", "sst2", split="validation")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Preprocessing function
def preprocess_function(example):
    return tokenizer(example["sentence"], padding="max_length", truncation=True, max_length=128)

# Tokenize
train_dataset = train_dataset.map(preprocess_function, batched=True)
val_dataset = val_dataset.map(preprocess_function, batched=True)

# Set format for PyTorch
train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
val_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

Map: 100%|██████████| 67349/67349 [00:05<00:00, 12162.41 examples/s]
Map: 100%|██████████| 872/872 [00:00<00:00, 10236.08 examples/s]


In [5]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

In [6]:
# 모델과 토크나이저 불러오기

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "JeremiahZ/bert-base-uncased-sst2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
origin_model = BertForSequenceClassification.from_pretrained(model_name, output_hidden_states=True).to(device)

# origin_model 복사 (updated_model은 학습 대상)
import copy
updated_model = copy.deepcopy(origin_model).to(device)

dropout = nn.Dropout(p=0.1).to(device) # in BERT default 0.1

In [7]:
# MSELoss 또는 CosineEmbeddingLoss 등 가능
representation_loss_fn = nn.MSELoss()
classification_loss_fn = nn.CrossEntropyLoss()

optimizer = AdamW(updated_model.parameters(), lr=2e-5)

In [8]:
def train_with_representation_alignment(
    updated_model, origin_model, dataloader, device,
    rep_loss_weight=0.5, epochs=3
):
    updated_model.train()
    origin_model.eval()

    classification_loss_fn = nn.CrossEntropyLoss()
    rep_loss_fn = nn.MSELoss()
    optimizer = AdamW(updated_model.parameters(), lr=2e-5)

    for epoch in range(epochs):
        total_loss = 0
        print(f"\nEpoch {epoch+1}/{epochs}")
        
        for batch in tqdm(dataloader, desc="Training"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            inputs = {"input_ids": input_ids, "attention_mask": attention_mask}

            # ---- Origin Model ----
            with torch.no_grad():
                origin_outputs = origin_model(**inputs, output_hidden_states=True)
                origin_cls = origin_outputs.hidden_states[-1][:, 0]  # [CLS]

            # ---- Updated Model ----
            updated_outputs = updated_model(**inputs, output_hidden_states=True)
            logits = updated_outputs.logits
            updated_cls = updated_outputs.hidden_states[-3][:, 0]  # 앞쪽 layer와 representation 결과와 matching

            # Loss 계산
            ce_loss = classification_loss_fn(logits, labels)
            rep_loss = rep_loss_fn(updated_cls, origin_cls)
            loss = ce_loss + rep_loss_weight * rep_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"Avg Loss: {avg_loss:.4f}")

In [9]:
train_with_representation_alignment(
    updated_model=updated_model,
    origin_model=origin_model,
    dataloader=train_loader,
    device=device,
    rep_loss_weight=0.5,
    epochs=3
)


Epoch 1/3


Training: 100%|██████████| 4210/4210 [05:54<00:00, 11.88it/s]


Avg Loss: 0.1513

Epoch 2/3


Training: 100%|██████████| 4210/4210 [06:00<00:00, 11.67it/s]


Avg Loss: 0.1083

Epoch 3/3


Training: 100%|██████████| 4210/4210 [06:00<00:00, 11.67it/s]

Avg Loss: 0.0860





In [13]:
# for accuracy with small model
correct_base = 0
correct_small = 0

origin_model.eval()
updated_model.eval()

for batch in tqdm(val_loader, desc="Evaluating Updated Model"):
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    labels = batch["label"].to(device)

    inputs = {"input_ids": input_ids, "attention_mask": attention_mask}

    with torch.no_grad():
        output = origin_model(**inputs)
        logits = output.logits  # SequenceClassifierOutput에서 logits 추출
        pred = torch.argmax(logits, dim=-1).item()

    with torch.no_grad():
        updated_output = updated_model(**inputs)
        updated_logits = updated_output.logits  # SequenceClassifierOutput에서 logits 추출
        updated_pred = torch.argmax(updated_logits, dim=-1).item()
    
    # 정확도 계산
    correct_base += (preds_base == labels).sum().item()
    correct_updated += (preds_updated == labels).sum().item()
    total += labels.size(0)


# 정확도 출력
total = len(val_dataset)
print(f"\n✅ Accuracy of Bertbase: {correct_base / total * 100:.2f}%")
print(f"\n✅ Accuracy of updated_model: {correct_updated / total * 100:.2f}%")

Evaluating Updated Model:   0%|          | 0/55 [00:00<?, ?it/s]


RuntimeError: a Tensor with 16 elements cannot be converted to Scalar

In [None]:
# 정확도 저장 리스트
acc_original = []
acc_updated = []

for layer in range(1, 13):  # BERT-base: layer 1~12
    correct_orig = 0
    correct_small_by_layer = 0
    total = 0

    for item in tqdm(val_dataset, desc=f"Layer {layer}"):
        text = item["sentence"]
        label = item["label"]
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)

        with torch.no_grad():
            # BERT-base layer별 정확도
            outputs = origin_model(**inputs)
            hidden = outputs.hidden_states[layer]
            pooled = origin_model.bert.pooler(hidden)
            pooled_dropped = dropout(pooled)
            logits_orig = origin_model.classifier(pooled_dropped)
            pred_orig = torch.argmax(logits_orig).item()
        
        with torch.no_grad():
            # updated-bert layer별 정확도
            updated_outputs = updated_model(**inputs)
            updated_hidden = updated_outputs.hidden_states[layer]
            updated_pooled = updated_model.bert.pooler(updated_hidden)
            updated_pooled_dropped = dropout(updated_pooled)
            updated_logits_orig = updated_model.classifier(updated_pooled_dropped)
            updated_pred_orig = torch.argmax(updated_logits_orig).item()

            correct_orig += int(pred_orig == label)
            total += 1

    # 마지막 layer 기준에서 정확도 누적 (1 epoch 끝났을 때만 기록됨)
    if layer == 12:
        acc_updated = [c / total for c in correct_small_by_layer]

In [None]:
for i, (acc_o, acc_u) in enumerate(zip(acc_original, acc_updated), start=1):
    print(f"Layer {i:>2}: Origin Acc = {acc_o:.2%} | Updated Acc = {acc_u:.2%}")

In [None]:
plt.figure(figsize=(7, 4))
plt.plot(range(1, 13), acc_original, label="BERT-base", marker='o', linewidth=2)
plt.plot(range(1, 13), acc_updated, label="Updated-model", marker='x', linewidth=2)
plt.xlabel("Layer Number")
plt.ylabel("Accuracy")
plt.title("Layer-wise Accuracy: BERT-base vs Updated-model")
plt.legend()
plt.grid(True)
plt.xticks(range(1, 13))
plt.ylim(0, 1.0)
plt.tight_layout()
# plt.savefig(f"img/bert_vs_small_layers_{epoch_num}.png", dpi=300)
plt.show()