In [2]:
# 1. Mount Google Drive to access your Day 1 files
from google.colab import drive
drive.mount('/content/drive')

# 2. Add your project folder to the Python path so we can import model_archs.py
import sys
import os
sys.path.append('/content/drive/MyDrive/KD_Project')

# 3. Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import pickle
from torch.utils.data import DataLoader, TensorDataset

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
from datasets import load_dataset
from transformers import AutoTokenizer

# 1. Re-load the dataset
raw_dataset = load_dataset("glue", "sst2")

# 2. Re-load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# 3. Re-tokenize the data (This creates 'tokenized_ds' again)
def tokenize_fn(batch):
    return tokenizer(batch["sentence"], truncation=True, padding="max_length", max_length=128)

tokenized_ds = raw_dataset.map(tokenize_fn, batched=True)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Map:   0%|          | 0/1821 [00:00<?, ? examples/s]

In [5]:
# Load the logits you extracted yesterday
with open('/content/drive/MyDrive/KD_Project/teacher_logits.pkl', 'rb') as f:
    teacher_logits = pickle.load(f)

# Load the SST-2 training data (same as Day 1)
dataset = load_dataset("glue", "sst2")

# Wrap them into a PyTorch Dataset
# teacher_logits: [67349, 2], labels: [67349], input_ids: [67349, 128]
train_ids = torch.tensor(tokenized_ds["train"]["input_ids"])
train_labels = torch.tensor(tokenized_ds["train"]["label"])

# Combined dataset: Student sees text, real answer, and professor's notes
distill_dataset = TensorDataset(train_ids, train_labels, teacher_logits)
train_loader = DataLoader(distill_dataset, batch_size=64, shuffle=True)

In [8]:
import torch.nn as nn
import torch.nn.functional as F
from model_arch import BiLSTMStudent # Ensure this file is in your path

# 1. Distillation Loss Function
def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.5):
    # Soft loss (KL Divergence between softened distributions)
    soft_targets = F.softmax(teacher_logits / T, dim=1)
    student_soft = F.log_softmax(student_logits / T, dim=1)
    loss_soft = F.kl_div(student_soft, soft_targets, reduction='batchmean') * (T**2)

    # Hard loss (Standard comparison to the 0/1 label)
    loss_hard = F.cross_entropy(student_logits, labels)

    return alpha * loss_soft + (1 - alpha) * loss_hard

# 2. Setup Model, Optimizer on GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
student = BiLSTMStudent(vocab_size=30522, embed_dim=128, hidden_dim=256, output_dim=2).to(device)
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize Student (Bi-LSTM)
student = BiLSTMStudent(vocab_size=30522, embed_dim=128, hidden_dim=256, output_dim=2).to(device)
optimizer = optim.Adam(student.parameters(), lr=1e-3)

T = 4.0      # Temperature
alpha = 0.5  # 50% Teacher, 50% Ground Truth

print("Starting Distillation...")
for epoch in range(5):
    student.train()
    epoch_loss = 0

    for batch in train_loader:
        # Move everything to GPU
        ids, labels, t_logits = [t.to(device) for t in batch]

        optimizer.zero_grad()

        # 1. Forward pass
        s_logits = student(ids)

        # 2. Compute the Dual Loss (KD Loss)
        # We call the function from your Part 2 explanation
        loss = distillation_loss(s_logits, t_logits, labels, T, alpha)

        # 3. Backward pass
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f"Epoch {epoch+1} Complete | Average Loss: {epoch_loss/len(train_loader):.4f}")

# Save the student's "Brain"
torch.save(student.state_dict(), "/content/drive/MyDrive/KD_Project/student_lstm.pth")

Starting Distillation...
Epoch 1 Complete | Average Loss: 1.3320
Epoch 2 Complete | Average Loss: 0.6729
Epoch 3 Complete | Average Loss: 0.4594
Epoch 4 Complete | Average Loss: 0.3525
Epoch 5 Complete | Average Loss: 0.4521
