# Lab 3.3: Knowledge Distillation - Environment Setup

**Goal:** Prepare the environment for knowledge distillation experiments.

**You will learn to:**
- Verify GPU and PyTorch environment
- Install required libraries (transformers, datasets)
- Load teacher model (BERT-base, 110M params)
- Initialize student model (BERT-6L, ~67M params) with layer-wise init
- Prepare GLUE SST-2 dataset
- Test baseline teacher performance

---

## Why Environment Verification Matters

**Knowledge distillation requires**:
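- **GPU Memory**: Holds both teacher and student (~0.7 GB of fp32 weights; roughly 1.5 GB once the student's gradients and Adam optimizer state are added, plus activations that scale with batch size)
- **Correct libraries**: A recent `transformers` release (BERT models and training utilities) plus `datasets`, `accelerate`, and `evaluate`
- **Quality dataset**: Representative training (transfer) data for the student to learn from
- **Baseline metrics**: Teacher performance as the reference point for the student (usually an upper bound)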
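
As a rough check on the GPU memory requirement, weight memory is simply parameter count × bytes per parameter. The sketch below assumes approximate counts for this lab's models (BERT-base and a 6-layer copy, each with a 2-label head) and the common rule of thumb that plain fp32 training with Adam keeps weights, gradients, and two moment buffers per parameter (16 bytes). The teacher is assumed to be frozen, and activations are left out because they depend on batch size and sequence length.

```python
# Rough GPU memory arithmetic for teacher + student (activations NOT included).
# Parameter counts are approximate values assumed for this lab.
TEACHER_PARAMS = 109_483_778   # BERT-base encoder + pooler + 2-label head
STUDENT_PARAMS = 66_956_546    # same, with 6 instead of 12 encoder layers

BYTES_FP32 = 4

def weights_gb(num_params, bytes_per_param=BYTES_FP32):
    """Memory needed just to hold the weights, in decimal GB."""
    return num_params * bytes_per_param / 1e9

def adam_training_gb(num_params):
    """fp32 weights + fp32 gradients + Adam's two fp32 moment buffers = 16 bytes/param."""
    return num_params * 16 / 1e9

# The teacher only runs inference during distillation; the student is trained.
teacher_gb = weights_gb(TEACHER_PARAMS)
student_train_gb = adam_training_gb(STUDENT_PARAMS)

print(f"Teacher weights (fp32):   {teacher_gb:.2f} GB")                 # → 0.44 GB
print(f"Student weights (fp32):   {weights_gb(STUDENT_PARAMS):.2f} GB")  # → 0.27 GB
print(f"Student training state:   {student_train_gb:.2f} GB")           # → 1.07 GB
print(f"Total before activations: {teacher_gb + student_train_gb:.2f} GB")  # → 1.51 GB
```

Activations for each training batch come on top of this total.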

**Time investment**: 10-15 minutes (one-time setup)

---
## Step 1: Hardware Verification

Check GPU availability and specifications.

In [None]:
# Check NVIDIA GPU status
!nvidia-smi

In [None]:
import torch

print("=" * 60)
print("GPU Configuration Check")
print("=" * 60)

# PyTorch and CUDA versions
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")

if torch.cuda.is_available():
    # GPU details
    gpu_id = 0
    gpu_props = torch.cuda.get_device_properties(gpu_id)
    
    print(f"\n✅ GPU Detected:")
    print(f"   Name: {torch.cuda.get_device_name(gpu_id)}")
    print(f"   Total Memory: {gpu_props.total_memory / 1e9:.2f} GB")
    print(f"   Compute Capability: SM {gpu_props.major}.{gpu_props.minor}")
    
    # Memory recommendation
    if gpu_props.total_memory / 1e9 >= 4:
        print(f"   ✅ Memory: Sufficient for distillation (>= 4GB)")
    else:
        print(f"   ⚠️  Memory: Limited (<4GB). May need smaller models.")
else:
    print("\n⚠️  No GPU detected!")
    print("   Distillation can run on CPU but will be slower.")

print("=" * 60)
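
The values printed above can be folded into a single recommendation. The helper below is hypothetical: it reuses the 4 GB threshold from the cell above and assumes bf16 for compute capability 8.0 and newer (Ampere onward, where bf16 tensor cores exist), fp16 for 7.x (Volta/Turing tensor cores), and plain fp32 otherwise.

```python
# Hypothetical helper: fold the GPU properties printed above into one recommendation.
# Assumptions: 4 GB threshold copied from the cell above; bf16 tensor cores need
# compute capability >= 8.0 (Ampere+), fp16 tensor cores exist from 7.0 (Volta/Turing).

def recommend_setup(total_memory_bytes, major, minor, min_gb=4.0):
    """Return total memory (decimal GB), whether it meets min_gb, and a precision to try."""
    total_gb = total_memory_bytes / 1e9
    capability = (major, minor)
    if capability >= (8, 0):
        precision = "bf16"
    elif capability >= (7, 0):
        precision = "fp16"
    else:
        precision = "fp32"  # pre-Volta GPUs: no tensor cores, so stay with fp32 here
    return {
        "total_gb": round(total_gb, 2),
        "enough_memory": total_gb >= min_gb,
        "precision": precision,
    }

# Example with a 16 GB, compute-capability 7.5 card; in the notebook you would pass
# gpu_props.total_memory, gpu_props.major, gpu_props.minor instead.
print(recommend_setup(16e9, 7, 5))
# → {'total_gb': 16.0, 'enough_memory': True, 'precision': 'fp16'}
```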

---
## Step 2: Install Required Libraries

Install transformers, datasets, and evaluation metrics.

In [None]:
# Install core libraries
!pip install -q "transformers>=4.35.0"  # BERT and training utilities; quotes stop the shell reading ">" as a redirect
!pip install -q datasets              # GLUE dataset
!pip install -q accelerate            # Training acceleration
!pip install -q evaluate              # Evaluation metrics
!pip install -q scikit-learn          # Additional metrics

print("✅ Installation complete!")

---
## Step 3: Verify Library Versions

In [None]:
import transformers
import datasets
import accelerate
import evaluate

print("=" * 60)
print("Library Version Check")
print("=" * 60)
print(f"PyTorch:      {torch.__version__}")
print(f"Transformers: {transformers.__version__}")
print(f"Datasets:     {datasets.__version__}")
print(f"Accelerate:   {accelerate.__version__}")
print(f"Evaluate:     {evaluate.__version__}")
print("=" * 60)

# Version checks
from packaging import version

def check_version(name, current, required):
    if version.parse(current) >= version.parse(required):
        print(f"✅ {name}: {current} >= {required}")
    else:
        print(f"⚠️  {name}: {current} < {required} (may cause issues)")

check_version("Transformers", transformers.__version__, "4.35.0")
check_version("PyTorch", torch.__version__.split("+")[0], "2.0.0")

print("\n✅ Version check finished (review any ⚠️ warnings above)")
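
The cell above imports every library just to read its version. A lighter alternative reads installed versions from package metadata with the standard-library `importlib.metadata`, without importing the packages. The `release_tuple` parser below is a simplified, hypothetical helper that compares only the numeric release segment; `packaging.version` remains the right tool when pre-release ordering matters.

```python
from importlib import metadata

def installed_version(dist_name):
    """Installed version string of a distribution, or None if it is not installed."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None

def release_tuple(version_string, width=3):
    """Simplified parse of the numeric release segment, padded to `width`.
    "2.1.0+cu118" -> (2, 1, 0); "4.40.0.dev0" -> (4, 40, 0). Pre-release tags are ignored."""
    parts = []
    for piece in version_string.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if len(digits) < len(piece):
            break  # e.g. "0rc1": keep the 0, stop at the pre-release marker
    return tuple((parts + [0] * width)[:width])

def meets_minimum(dist_name, minimum):
    """True if dist_name is installed and its release is >= minimum."""
    current = installed_version(dist_name)
    if current is None:
        return False
    return release_tuple(current) >= release_tuple(minimum)

# Use pip distribution names (e.g. "scikit-learn", not "sklearn").
for dist, minimum in [("transformers", "4.35.0"), ("torch", "2.0.0")]:
    print(f"{dist}: installed={installed_version(dist)}, meets {minimum}: {meets_minimum(dist, minimum)}")
```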

---
## Step 4: Load Teacher Model (BERT-base)

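Load the teacher model. Note that `bert-base-uncased` is only pre-trained (masked language modeling and next-sentence prediction): `AutoModelForSequenceClassification` adds a 2-label classification head with random weights, and `transformers` logs a warning that these weights are newly initialized. Until the teacher is fine-tuned on SST-2 (or replaced with a BERT-base checkpoint already fine-tuned on SST-2), its accuracy in Step 8 will be near chance rather than a meaningful upper bound.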

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import gc

# Model configuration
TEACHER_MODEL = "bert-base-uncased"  # ~110M parameters; pre-trained only, NOT fine-tuned on SST-2
NUM_LABELS = 2  # Binary classification (SST-2)

print("=" * 60)
print(f"Loading Teacher Model: {TEACHER_MODEL}")
print("=" * 60)
print("⏳ This may take 1-2 minutes...\n")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(TEACHER_MODEL)
print("✅ Tokenizer loaded")

# Load teacher model
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    TEACHER_MODEL,
    num_labels=NUM_LABELS,
    problem_type="single_label_classification"
)

# Move to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher_model = teacher_model.to(device)
print(f"✅ Teacher model loaded on {device}")

# Model info
num_params = sum(p.numel() for p in teacher_model.parameters())
print(f"\n📝 Teacher Model Info:")
print(f"   Parameters: {num_params / 1e6:.2f}M")
print(f"   Layers: {teacher_model.config.num_hidden_layers}")
print(f"   Hidden size: {teacher_model.config.hidden_size}")
print(f"   Attention heads: {teacher_model.config.num_attention_heads}")

# Memory usage
if torch.cuda.is_available():
    memory_allocated = torch.cuda.memory_allocated() / 1e9
    print(f"\n📊 GPU Memory Usage:")
    print(f"   Allocated: {memory_allocated:.2f} GB")

print("\n" + "=" * 60)
print("✅ Teacher model ready!")
print("=" * 60)
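
The parameter count printed above can be checked from the configuration alone. This sketch assumes the standard `bert-base-uncased` dimensions (vocabulary 30,522, hidden size 768, feed-forward size 3,072, 512 positions, 2 token types) and a single linear classifier on the pooled output, which is how `BertForSequenceClassification` is laid out. The last term, the classifier, is the only part that the `bert-base-uncased` checkpoint does not provide.

```python
def bert_param_count(vocab_size=30522, hidden=768, num_layers=12, intermediate=3072,
                     max_positions=512, type_vocab=2, num_labels=2):
    """Parameter count of a BERT encoder + pooler + linear classification head."""
    # Embeddings: word + position + token-type tables, plus one LayerNorm (weight + bias)
    embeddings = (vocab_size + max_positions + type_vocab) * hidden + 2 * hidden
    # Self-attention: Q, K, V and output projections (weights + biases) + LayerNorm
    attention = 4 * (hidden * hidden + hidden) + 2 * hidden
    # Feed-forward: up-projection + down-projection (weights + biases) + LayerNorm
    ffn = (hidden * intermediate + intermediate) + (intermediate * hidden + hidden) + 2 * hidden
    pooler = hidden * hidden + hidden              # dense layer on the [CLS] vector
    classifier = hidden * num_labels + num_labels  # newly initialized task head
    return embeddings + num_layers * (attention + ffn) + pooler + classifier

teacher_count = bert_param_count(num_layers=12)
student_count = bert_param_count(num_layers=6)
print(f"Teacher: {teacher_count:,}")                      # → Teacher: 109,483,778
print(f"Student: {student_count:,}")                      # → Student: 66,956,546
print(f"Ratio:   {teacher_count / student_count:.2f}x")   # → Ratio:   1.64x
```

Halving the layers shrinks the model by only about 1.6x, not 2x, because the ~23.8M embedding parameters are kept in full.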

---
## Step 5: Initialize Student Model (BERT-6L)

Create a smaller student model with layer-wise initialization.

In [None]:
from transformers import BertConfig, BertForSequenceClassification
import copy

print("=" * 60)
print("Initializing Student Model (BERT-6L)")
print("=" * 60)

# Student configuration (half the layers)
student_config = BertConfig(
    vocab_size=teacher_model.config.vocab_size,
    hidden_size=teacher_model.config.hidden_size,  # Same as teacher
    num_hidden_layers=6,  # Half of teacher (12 -> 6)
    num_attention_heads=teacher_model.config.num_attention_heads,
    intermediate_size=teacher_model.config.intermediate_size,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    num_labels=NUM_LABELS
)

# Create student model
student_model = BertForSequenceClassification(student_config)

print(f"\n📝 Student Model Info:")
student_params = sum(p.numel() for p in student_model.parameters())
print(f"   Parameters: {student_params / 1e6:.2f}M")
print(f"   Layers: {student_config.num_hidden_layers}")
print(f"   Compression ratio: {num_params / student_params:.2f}x")

print("\n⏳ Applying layer-wise initialization...")

# Layer-wise initialization (every-other-layer: copy alternating teacher layers)
# Student layer i = Teacher layer 2i
teacher_state = teacher_model.state_dict()
student_state = student_model.state_dict()

# Copy embeddings
for key in student_state.keys():
    if 'embeddings' in key:
        student_state[key] = teacher_state[key].clone()

# Copy selected layers (0, 2, 4, 6, 8, 10)
teacher_layers = teacher_model.bert.encoder.layer
student_layers = student_model.bert.encoder.layer

layer_mapping = [0, 2, 4, 6, 8, 10]  # Select every 2nd layer
for student_idx, teacher_idx in enumerate(layer_mapping):
    # Copy layer weights
    student_layers[student_idx].load_state_dict(
        teacher_layers[teacher_idx].state_dict()
    )
    print(f"   Student layer {student_idx} ← Teacher layer {teacher_idx}")

# Copy pooler and classifier
for key in student_state.keys():
    if 'pooler' in key or 'classifier' in key:
        if key in teacher_state:
            student_state[key] = teacher_state[key].clone()

# Load initialized weights
student_model.load_state_dict(student_state)
student_model = student_model.to(device)

print("✅ Layer-wise initialization complete!")

# Memory usage
if torch.cuda.is_available():
    memory_allocated = torch.cuda.memory_allocated() / 1e9
    print(f"\n📊 GPU Memory Usage (Teacher + Student):")
    print(f"   Allocated: {memory_allocated:.2f} GB")

print("\n" + "=" * 60)
print("✅ Student model ready!")
print("=" * 60)
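
The cell above hard-codes `layer_mapping = [0, 2, 4, 6, 8, 10]`. A small helper can generate evenly spaced mappings for other depths. It is a hypothetical sketch: keeping the first versus the last layer of each block (`offset`) is a design choice, not a fixed rule.

```python
def uniform_layer_mapping(num_teacher_layers, num_student_layers, offset=0):
    """Pick evenly spaced teacher layers, one per student layer.

    offset=0 keeps the first layer of each block ([0, 2, ..., 10] for 12 -> 6);
    offset=stride - 1 keeps the last one ([1, 3, ..., 11]).
    """
    if num_teacher_layers % num_student_layers != 0:
        raise ValueError("teacher depth must be a multiple of student depth in this sketch")
    stride = num_teacher_layers // num_student_layers
    if not 0 <= offset < stride:
        raise ValueError("offset must be in [0, stride)")
    return [i * stride + offset for i in range(num_student_layers)]

print(uniform_layer_mapping(12, 6))            # → [0, 2, 4, 6, 8, 10]
print(uniform_layer_mapping(12, 6, offset=1))  # → [1, 3, 5, 7, 9, 11]
```

In the notebook, `layer_mapping = uniform_layer_mapping(12, student_config.num_hidden_layers)` would reproduce the hard-coded list.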

---
## Step 6: Load GLUE SST-2 Dataset

Prepare the Stanford Sentiment Treebank dataset.

In [None]:
from datasets import load_dataset

print("=" * 60)
print("Loading GLUE SST-2 Dataset")
print("=" * 60)

# Load dataset
dataset = load_dataset("glue", "sst2")

print(f"\n✅ Dataset loaded:")
print(f"   Train: {len(dataset['train'])} samples")
print(f"   Validation: {len(dataset['validation'])} samples")
print(f"   Test: {len(dataset['test'])} samples (no labels)")

# Show sample
sample = dataset['train'][0]
print(f"\n📝 Sample:")
print(f"   Sentence: {sample['sentence']}")
print(f"   Label: {sample['label']} (0=Negative, 1=Positive)")

# Label distribution
train_labels = list(dataset['train']['label'])  # column access avoids a Python loop over every row
neg_count = train_labels.count(0)
pos_count = train_labels.count(1)

print(f"\n📊 Label Distribution (Train):")
print(f"   Negative: {neg_count} ({neg_count/len(train_labels):.1%})")
print(f"   Positive: {pos_count} ({pos_count/len(train_labels):.1%})")

print("=" * 60)
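
Counting each label with `list.count` scans the label list once per class. A `collections.Counter` gets every class in a single pass. This stdlib sketch assumes SST-2's convention of 0 = negative and 1 = positive.

```python
from collections import Counter

def label_distribution(labels, names=("Negative", "Positive")):
    """Return {name: (count, fraction)} for each class index, counting in one pass."""
    counts = Counter(labels)
    total = len(labels)
    return {name: (counts.get(idx, 0), counts.get(idx, 0) / total)
            for idx, name in enumerate(names)}

# Toy labels; in the notebook: label_distribution(list(dataset['train']['label']))
print(label_distribution([0, 1, 1, 1]))
# → {'Negative': (1, 0.25), 'Positive': (3, 0.75)}
```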

---
## Step 7: Tokenize Dataset

Prepare tokenized inputs for training.

In [None]:
print("=" * 60)
print("Tokenizing Dataset")
print("=" * 60)

# Tokenization function
def tokenize_function(examples):
    return tokenizer(
        examples['sentence'],
        padding='max_length',
        truncation=True,
        max_length=128
    )

# Apply tokenization
print("⏳ Tokenizing train set...")
tokenized_train = dataset['train'].map(
    tokenize_function,
    batched=True,
    remove_columns=['sentence', 'idx']
)

print("⏳ Tokenizing validation set...")
tokenized_val = dataset['validation'].map(
    tokenize_function,
    batched=True,
    remove_columns=['sentence', 'idx']
)

print("\n✅ Tokenization complete!")
print(f"   Train samples: {len(tokenized_train)}")
print(f"   Val samples: {len(tokenized_val)}")

# Show tokenized sample
sample_tokens = tokenized_train[0]
print(f"\n📝 Tokenized Sample:")
print(f"   Input IDs length: {len(sample_tokens['input_ids'])}")
print(f"   Attention mask: {sample_tokens['attention_mask'][:20]}...")
print(f"   Label: {sample_tokens['label']}")

# Decode to verify (skip [CLS]/[SEP]/[PAD] so the padded tail doesn't hide the text)
decoded = tokenizer.decode(sample_tokens['input_ids'], skip_special_tokens=True)
print(f"   Decoded: {decoded[:100]}...")

print("=" * 60)
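
`padding='max_length'` pads every sentence to 128 tokens even when it is much shorter. A common alternative is to pad each batch only to its longest sequence, which is what `transformers.DataCollatorWithPadding` does. The sketch below shows the idea on plain Python lists, assuming pad id 0 (the `[PAD]` id in the `bert-base-uncased` vocabulary); the token ids in the example are toy values.

```python
def pad_batch(batch_input_ids, pad_id=0):
    """Pad token-id lists to the longest one in the batch and build attention masks."""
    max_len = max(len(ids) for ids in batch_input_ids)
    input_ids, attention_mask = [], []
    for ids in batch_input_ids:
        n_pad = max_len - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)  # 1 = real token, 0 = padding
    return {"input_ids": input_ids, "attention_mask": attention_mask}

batch = pad_batch([[101, 2023, 102], [101, 102]])
print(batch["input_ids"])       # → [[101, 2023, 102], [101, 102, 0]]
print(batch["attention_mask"])  # → [[1, 1, 1], [1, 1, 0]]
```

To use this approach in the notebook, one option is to tokenize with `truncation=True` only and pass `DataCollatorWithPadding(tokenizer=tokenizer)` as the `DataLoader`'s `collate_fn`.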

---
## Step 8: Test Teacher Baseline Performance

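Evaluate the teacher model on the validation set to establish the reference score the student will be compared against. This is only meaningful once the teacher has been fine-tuned on SST-2; with the raw `bert-base-uncased` checkpoint, the randomly initialized classification head gives near-chance accuracy.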

In [None]:
from torch.utils.data import DataLoader
from tqdm.auto import tqdm  # picks the notebook progress bar when available
import numpy as np

print("=" * 60)
print("Testing Teacher Baseline Performance")
print("=" * 60)

# Prepare validation dataloader
tokenized_val.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
val_dataloader = DataLoader(tokenized_val, batch_size=32)

# Evaluation function
def evaluate_model(model, dataloader):
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            predictions = torch.argmax(outputs.logits, dim=-1)
            
            all_preds.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    # Calculate metrics
    accuracy = np.mean(np.array(all_preds) == np.array(all_labels))
    
    return accuracy, all_preds, all_labels

# Evaluate teacher
print("\n⏳ Evaluating teacher model...")
teacher_accuracy, teacher_preds, true_labels = evaluate_model(teacher_model, val_dataloader)

print(f"\n📊 Teacher Performance (Pre-trained):")
print(f"   Accuracy: {teacher_accuracy:.4f} ({teacher_accuracy*100:.2f}%)")

# Per-class accuracy
from sklearn.metrics import classification_report, confusion_matrix

print(f"\n📋 Classification Report:")
print(classification_report(
    true_labels, 
    teacher_preds, 
    target_names=['Negative', 'Positive'],
    digits=4
))

# Confusion matrix
cm = confusion_matrix(true_labels, teacher_preds)
print(f"\n📊 Confusion Matrix:")
print(f"               Predicted")
print(f"             Neg    Pos")
print(f"Actual Neg  {cm[0][0]:4d}  {cm[0][1]:4d}")
print(f"       Pos  {cm[1][0]:4d}  {cm[1][1]:4d}")

print("\n" + "=" * 60)
print("✅ Baseline evaluation complete!")
print("=" * 60)
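
For reference, the accuracy and confusion matrix printed above can be computed by hand. This stdlib sketch uses the same layout as `sklearn.metrics.confusion_matrix` (rows = actual class, columns = predicted class). It also makes a degenerate model easy to spot: when a model predicts only one class, as an untrained classification head often does, every count lands in a single column.

```python
def accuracy_and_confusion(y_true, y_pred, num_classes=2):
    """Return (accuracy, confusion matrix) with rows = actual, columns = predicted."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("y_true and y_pred must be non-empty and the same length")
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    correct = sum(cm[i][i] for i in range(num_classes))
    return correct / len(y_true), cm

acc, cm = accuracy_and_confusion([0, 0, 1, 1], [0, 1, 1, 1])
print(acc, cm)    # → 0.75 [[1, 1], [0, 2]]

# A model that always predicts "Positive": accuracy equals the positive fraction.
acc2, cm2 = accuracy_and_confusion([0, 1, 1, 1], [1, 1, 1, 1])
print(acc2, cm2)  # → 0.75 [[0, 1], [0, 3]]
```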

---
## Step 9: Test Student Baseline (Before Distillation)

Evaluate student model performance before distillation.

In [None]:
print("=" * 60)
print("Testing Student Baseline (Before Distillation)")
print("=" * 60)

# Evaluate student
print("\n⏳ Evaluating student model (initialized from teacher)...")
student_accuracy, student_preds, _ = evaluate_model(student_model, val_dataloader)

print(f"\n📊 Student Performance (Before Distillation):")
print(f"   Accuracy: {student_accuracy:.4f} ({student_accuracy*100:.2f}%)")

# Comparison
print(f"\n📊 Teacher vs Student (Pre-distillation):")
print(f"   Teacher:  {teacher_accuracy:.4f} ({teacher_accuracy*100:.2f}%)")
print(f"   Student:  {student_accuracy:.4f} ({student_accuracy*100:.2f}%)")
print(f"   Gap:       {(teacher_accuracy - student_accuracy)*100:.2f} percentage points")
print(f"   Retention: {student_accuracy/teacher_accuracy*100:.2f}% of teacher accuracy")

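print("\n⚠️  Note:")
print("   The student reuses every other teacher layer but has never been trained")
print("   as a 6-layer network, so its accuracy is usually well below the teacher's.")
print("   Distillation trains it to close this gap.")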

print("\n" + "=" * 60)
print("✅ Pre-distillation baseline established!")
print("=" * 60)
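
The comparison above mixes two kinds of percentage: the gap is an absolute difference in percentage points, while the relative figure is a ratio. Because `teacher_preds` and `student_preds` are both available, the agreement rate between the two models (sometimes called fidelity) is another quick signal worth tracking during distillation. A minimal stdlib sketch:

```python
def compare_accuracies(teacher_acc, student_acc):
    """Return (gap in percentage points, student accuracy as % of teacher accuracy)."""
    gap_points = (teacher_acc - student_acc) * 100
    retention_pct = student_acc / teacher_acc * 100
    return round(gap_points, 2), round(retention_pct, 2)

def agreement_rate(preds_a, preds_b):
    """Fraction of examples on which two models predict the same label."""
    if len(preds_a) != len(preds_b) or len(preds_a) == 0:
        raise ValueError("prediction lists must be non-empty and the same length")
    return sum(int(a == b) for a, b in zip(preds_a, preds_b)) / len(preds_a)

print(compare_accuracies(0.9, 0.81))                # → (9.0, 90.0)
print(agreement_rate([0, 1, 1, 0], [0, 1, 0, 0]))  # → 0.75
```

In the notebook, `agreement_rate(teacher_preds, student_preds)` compares the two models on the full validation set.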

---
## Step 10: Quick Inference Test

Test both models on sample sentences.

In [None]:
print("=" * 60)
print("Quick Inference Test")
print("=" * 60)

# Test sentences
test_sentences = [
    "This movie is absolutely fantastic! I loved every minute of it.",
    "Terrible waste of time. The plot was boring and predictable.",
    "It was okay, nothing special but not terrible either."
]

label_names = ['Negative', 'Positive']

for i, sentence in enumerate(test_sentences, 1):
    print(f"\nTest {i}: {sentence}")
    print("─" * 60)
    
    # Tokenize
    inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True).to(device)
    
    # Teacher prediction
    teacher_model.eval()
    with torch.no_grad():
        teacher_outputs = teacher_model(**inputs)
        teacher_probs = torch.softmax(teacher_outputs.logits, dim=-1)[0]
        teacher_pred = torch.argmax(teacher_probs).item()
    
    # Student prediction
    student_model.eval()
    with torch.no_grad():
        student_outputs = student_model(**inputs)
        student_probs = torch.softmax(student_outputs.logits, dim=-1)[0]
        student_pred = torch.argmax(student_probs).item()
    
    print(f"Teacher: {label_names[teacher_pred]} (Neg: {teacher_probs[0]:.3f}, Pos: {teacher_probs[1]:.3f})")
    print(f"Student: {label_names[student_pred]} (Neg: {student_probs[0]:.3f}, Pos: {student_probs[1]:.3f})")
    
    if teacher_pred == student_pred:
        print("✅ Agreement")
    else:
        print("❌ Disagreement")

print("\n" + "=" * 60)
print("✅ Inference test complete!")
print("=" * 60)
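
The probabilities printed above come from a softmax at temperature 1. Knowledge distillation as described by Hinton et al. softens the teacher's distribution with a temperature T > 1, so the smaller probabilities carry more signal for the student; whether 02-Distill.ipynb uses exactly this form is an assumption here. A stdlib sketch with made-up logits:

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Compute softmax(logits / T); a larger T spreads probability more evenly."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [z / temperature for z in logits]
    top = max(scaled)                           # subtract the max for numerical stability
    exps = [math.exp(z - top) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [-1.0, 3.0]  # made-up [negative, positive] logits for one sentence
for t in (1.0, 4.0):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
# → 1.0 [0.018, 0.982]
# → 4.0 [0.269, 0.731]
```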

---
## ✅ Setup Complete!

**Summary**:
- ✅ Hardware checked (GPU if available; CPU works but is slower)
- ✅ Libraries installed (transformers, datasets, evaluate)
- ✅ Teacher model loaded (BERT-base, 110M params)
- ✅ Student model initialized (BERT-6L, ~67M params, ~1.6x compression)
- ✅ GLUE SST-2 dataset prepared (67K train, 872 val)
- ✅ Baseline performance measured

**Key Metrics**:
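- Teacher accuracy: ~92-93% for a BERT-base teacher fine-tuned on SST-2 (near chance while its classification head is untrained)
- Student accuracy (pre-distillation): the value measured in Step 9, usually well below the teacher
- Performance gap: teacher minus student accuracy, in percentage points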

**Next Steps**:
1. Proceed to **02-Distill.ipynb** to apply knowledge distillation
2. Close the performance gap using teacher's soft labels
3. Target: Student accuracy ~91-92% (98-99% of teacher)

**Key Variables Available**:
- `teacher_model`: BERT-base teacher model
- `student_model`: BERT-6L student model (layer-wise initialized)
- `tokenizer`: BERT tokenizer
- `tokenized_train`: Tokenized training set
- `tokenized_val`: Tokenized validation set
- `teacher_accuracy`: Baseline teacher performance
- `student_accuracy`: Pre-distillation student performance

---

**⏭️ Continue to**: [02-Distill.ipynb](./02-Distill.ipynb)