In [None]:
from google.colab import drive
# Mount Google Drive so the project folder (data/, checkpoints/) survives runtime restarts.
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
%cd "/content/drive/MyDrive/661project"
!pwd

/content/drive/MyDrive/661project
/content/drive/MyDrive/661project


In [None]:
import torch
import torchvision
from torchvision import datasets

# Record library versions alongside the results.
print(torch.__version__, torchvision.__version__)

# Raw CIFAR-10 training split; no transform is attached here.
full_train = datasets.CIFAR10(root='data/', train=True, download=True)



2.6.0+cu124 0.21.0+cu124


In [None]:
# Class-disjoint split: each of the K teachers only sees its own group of classes
# (K = 5 and 10 classes => 2 classes per teacher).
# NOTE: the next cell overwrites `teacher_splits` with a random IID split.
K = 5
NUM_CLASSES = 10
classes_per_teacher = NUM_CLASSES // K  # = 2 (exact only when K divides NUM_CLASSES)
targets = full_train.targets  # list of integer labels

# Class c belongs to teacher c * K // NUM_CLASSES. When K divides NUM_CLASSES this
# is exactly the contiguous block [t*classes_per_teacher, (t+1)*classes_per_teacher);
# otherwise every class is still assigned instead of the leftover classes being
# silently dropped. A single pass over the labels replaces K full scans.
teacher_splits = [[] for _ in range(K)]
for i, lab in enumerate(targets):
    teacher_splits[lab * K // NUM_CLASSES].append(i)


In [None]:
import torch
num_samples = len(full_train)
# Seeded generator so the teacher partition is reproducible across restarts.
SPLIT_SEED = 0
split_gen = torch.Generator().manual_seed(SPLIT_SEED)
perm = torch.randperm(num_samples, generator=split_gen).tolist()
teacher_splits = []
split_size = num_samples // K

for t in range(K):
    start = t * split_size
    # The last teacher absorbs the remainder when K does not divide num_samples.
    end   = start + split_size if t < K-1 else num_samples
    teacher_splits.append(perm[start:end])


In [None]:
from torch.utils.data import Subset

# teacher_datasets[i] is the CIFAR-10 Subset backing teacher #i
# (raw dataset items; no transform is attached at this stage).
teacher_datasets = list(map(lambda split: Subset(full_train, split), teacher_splits))


In [None]:
from torchvision import transforms

# CIFAR-10 per-channel (RGB) normalization constants
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD  = (0.2470, 0.2435, 0.2616)

# Evaluation: tensor conversion + normalization only.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Training: random crop/flip augmentation, then the same conversion + normalization.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    *test_transform.transforms,
])


In [None]:
# Give each teacher subset its own transform by subclassing Subset.
class TransformSubset(Subset):
    """Subset over the same base dataset/indices that applies `transform` per item.

    Items are fetched through Subset.__getitem__, i.e. via the base dataset's own
    __getitem__, so any transform already set on the base dataset runs first.
    The base dataset should therefore carry no transform, or images end up
    transformed twice.
    """

    def __init__(self, subset, transform):
        # Re-wrap the underlying dataset and indices instead of nesting Subsets.
        super().__init__(subset.dataset, subset.indices)
        self.transform = transform

    def __getitem__(self, idx):
        sample, target = super().__getitem__(idx)
        return self.transform(sample), target


teacher_load_datasets = [TransformSubset(subset, train_transform) for subset in teacher_datasets]


In [None]:
from torch.utils.data import DataLoader

batch_size  = 128
num_workers = 4

teacher_loaders = [
    DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True
    )
    for ds in teacher_load_datasets
]

# Student's full training set (hard-label baseline).
# Use a *separate* CIFAR-10 instance carrying train_transform. Assigning
# `full_train.transform` would mutate the dataset shared by every teacher
# Subset: each teacher sample would then be transformed by full_train first and
# again by TransformSubset, and ToTensor fails on an already-tensor image.
student_train_set = datasets.CIFAR10(
    root='data/',
    train=True,
    download=False,
    transform=train_transform
)
student_train_loader = DataLoader(
    student_train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True
)

# CIFAR-10 test set:
test_set = datasets.CIFAR10(
    root='data/',
    train=False,
    download=False,
    transform=test_transform
)
test_loader = DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True
)


In [None]:
pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [None]:
import torch
from torchvision.models import resnet18, resnet34, resnet50
from torchinfo import summary

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Columns reported by torchinfo for every model below.
SUMMARY_COLS = ("output_size", "num_params", "mult_adds")

# 2.1 Instantiate one teacher model per entry of teacher_archs
# (here: ResNet18, ResNet34, ResNet50).
# NOTE(review): these are stock torchvision models with the ImageNet stem
# (strided conv1 + max-pool); the summary output shows a 32x32 input already
# reduced to 8x8 before layer1. The training cells further down swap in a
# CIFAR-style stem and use other architectures, so these statistics do not
# describe the models that are actually trained there.
teacher_archs = ['resnet18', 'resnet34', 'resnet50']
arch_builders = {'resnet18': resnet18, 'resnet34': resnet34, 'resnet50': resnet50}
teachers = []

for arch in teacher_archs:
    if arch not in arch_builders:
        raise ValueError(f"Unsupported architecture: {arch}")
    model = arch_builders[arch](num_classes=10)
    model.to(device)
    teachers.append((arch, model))

# 2.2 Instantiate the student model (smaller-capacity): ResNet18
student = resnet18(num_classes=10)
student.to(device)

# 2.3 Parameter & FLOPs statistics using torchinfo.summary
# Assume CIFAR-10 input size: (batch_size=1, channels=3, height=32, width=32)
input_size = (1, 3, 32, 32)

# In a notebook torchinfo.summary does not print by default; it only returns a
# statistics object, which Jupyter renders solely when it is the cell's last
# expression. Summaries produced inside the loop were therefore dropped, so
# print each one explicitly.
print("\n=== Teacher Models ===")
for name, model in teachers:
    print(f"\n-- {name.upper()} --")
    print(summary(model, input_size=input_size, col_names=SUMMARY_COLS))

print("\n=== Student Model ===")
print("-- RESNET18 STUDENT --")
print(summary(student, input_size=input_size, col_names=SUMMARY_COLS))



=== Teacher Models ===

-- RESNET18 --

-- RESNET34 --

-- RESNET50 --

=== Student Model ===
-- RESNET18 STUDENT --


Layer (type:depth-idx)                   Output Shape              Param #                   Mult-Adds
ResNet                                   [1, 10]                   --                        --
├─Conv2d: 1-1                            [1, 64, 16, 16]           9,408                     2,408,448
├─BatchNorm2d: 1-2                       [1, 64, 16, 16]           128                       128
├─ReLU: 1-3                              [1, 64, 16, 16]           --                        --
├─MaxPool2d: 1-4                         [1, 64, 8, 8]             --                        --
├─Sequential: 1-5                        [1, 64, 8, 8]             --                        --
│    └─BasicBlock: 2-1                   [1, 64, 8, 8]             --                        --
│    │    └─Conv2d: 3-1                  [1, 64, 8, 8]             36,864                    2,359,296
│    │    └─BatchNorm2d: 3-2             [1, 64, 8, 8]             128                       128
│    │    └─ReLU:

In [None]:
import copy, os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

# Create a directory for checkpoints
os.makedirs('checkpoints', exist_ok=True)

# 1) Setup device & data
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE: keep these normalization constants unchanged. The teacher snapshots
# saved by this cell are reloaded later and must see identically normalized inputs.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914,0.4822,0.4465),
                         (0.247,0.243,0.261)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914,0.4822,0.4465),
                         (0.247,0.243,0.261)),
])

train_ds = datasets.CIFAR10('data', train=True,  download=True, transform=transform_train)
test_ds  = datasets.CIFAR10('data', train=False, download=True, transform=transform_test)

train_loader = DataLoader(train_ds, batch_size=128, shuffle=True,  num_workers=4, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=256, shuffle=False, num_workers=4, pin_memory=True)
# The student's optimizer/scheduler are created together with the student in
# step 4; building them here bound them to a stale `student` left over from an
# earlier cell, and they were overwritten before use anyway.

# 2) Helpers
def train_epoch(m, loader, opt, crit):
    """Run one supervised epoch of `m` over `loader`, minimising `crit(m(x), y)`."""
    m.train()
    for batch_x, batch_y in loader:
        inputs = batch_x.to(device)
        labels = batch_y.to(device)
        opt.zero_grad()
        crit(m(inputs), labels).backward()
        opt.step()

def train_epoch_kd(student, teachers, loader, opt, alpha, T):
    """Run one epoch of multi-teacher knowledge distillation.

    The soft target is the mean of the teachers' temperature-T softmax outputs.
    Per batch the student minimises
        (1 - alpha) * CE(student_logits, y)
        + alpha * T^2 * KL(avg_teacher_soft || softmax(student_logits / T)).

    Args:
        student: model being trained (put in train mode here).
        teachers: non-empty sequence of teacher models; their mode is left as
            the caller set it (the cell below puts them in eval mode).
        loader: iterable of (inputs, labels) batches.
        opt: optimizer over the student's parameters.
        alpha: weight of the distillation term; 1 - alpha weights hard-label CE.
        T: softmax temperature.

    Raises:
        ValueError: if `teachers` is empty.
    """
    if len(teachers) == 0:
        raise ValueError("train_epoch_kd needs at least one teacher")
    student.train()
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            avg_soft = None
            for t in teachers:
                s = F.softmax(t(x).div(T), dim=1)
                avg_soft = s if avg_soft is None else avg_soft + s
            avg_soft /= len(teachers)
        logits = student(x)
        loss_h = F.cross_entropy(logits, y)
        loss_s = F.kl_div(
            F.log_softmax(logits.div(T), dim=1),
            avg_soft,
            reduction='batchmean'
        ) * (T*T)  # T^2 keeps soft-target gradients on the same scale as CE
        loss = (1 - alpha) * loss_h + alpha * loss_s
        opt.zero_grad()
        loss.backward()
        opt.step()

@torch.no_grad()
def evaluate(m, loader):
    """Return top-1 accuracy of `m` over `loader` as a fraction; sets eval mode."""
    m.eval()
    n_correct, n_seen = 0, 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        predictions = m(inputs).argmax(1)
        n_correct += (predictions == labels).sum().item()
        n_seen += labels.size(0)
    return n_correct / n_seen

# 3) Model factories
def _cifar_stem(m):
    """Replace torchvision's ImageNet stem with a CIFAR one and move `m` to `device`.

    The strided conv1 + max-pool are swapped for a 3x3 stride-1 conv and an
    identity, so 32x32 inputs are not downsampled before the first stage.
    """
    m.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
    m.maxpool = nn.Identity()
    return m.to(device)


def get_resnet34():
    """Fresh, untrained 10-class ResNet-34 with a CIFAR stem, on `device`."""
    return _cifar_stem(models.resnet34(weights=None, num_classes=10))


def get_resnet50():
    """Fresh, untrained 10-class ResNet-50 with a CIFAR stem, on `device`."""
    return _cifar_stem(models.resnet50(weights=None, num_classes=10))

# 4) Hard-label baseline: train the ResNet-34 student with plain cross-entropy
#    for 200 epochs (LR x0.1 at epochs 100 and 150), evaluate, then checkpoint.
crit = nn.CrossEntropyLoss()
student = get_resnet34()
opt_s = optim.SGD(student.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
sch_s = optim.lr_scheduler.MultiStepLR(opt_s, milestones=[100, 150], gamma=0.1)

for _ in range(200):
    train_epoch(student, train_loader, opt_s, crit)
    sch_s.step()

acc_s = evaluate(student, test_loader)
print(f"Student baseline accuracy: {acc_s:.4f}")
torch.save(student.state_dict(), 'checkpoints/student_baseline.pth')

# 5) Train & save snapshot-ensemble teachers.
# A single ResNet-50 is trained with cosine annealing that restarts every
# SNAPSHOT_CYCLE epochs. The weights at the end of each cycle (lowest LR, just
# before the restart) are saved as one ensemble member.
SNAPSHOT_CYCLE = 40   # epochs per cosine cycle; also the snapshot interval
TEACHER_EPOCHS = 200  # 200 / 40 = 5 snapshots

teacher = get_resnet50()
opt_t = optim.SGD(teacher.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
# T_0 and the snapshot modulus share one constant so they cannot drift apart.
sch_t = optim.lr_scheduler.CosineAnnealingWarmRestarts(opt_t, T_0=SNAPSHOT_CYCLE, T_mult=1)

snapshots = []
for epoch in range(TEACHER_EPOCHS):
    train_epoch(teacher, train_loader, opt_t, crit)
    sch_t.step()
    if (epoch + 1) % SNAPSHOT_CYCLE == 0:
        path = f'checkpoints/teacher_snapshot_{len(snapshots)+1}.pth'
        torch.save(teacher.state_dict(), path)
        snapshots.append(path)

# Reload every snapshot as an independent eval-mode teacher and report its accuracy.
teachers = []
for i, p in enumerate(snapshots, 1):
    m = get_resnet50()
    # map_location makes loading device-independent. weights_only because the
    # file is a plain state_dict (this is the default only from PyTorch 2.6 on).
    m.load_state_dict(torch.load(p, map_location=device, weights_only=True))
    m.eval()
    teachers.append(m)
    acc_t = evaluate(m, test_loader)
    print(f"Teacher #{i} baseline accuracy: {acc_t:.4f}")

# 6) Multi-teacher KD: for each K = 1..5 train a fresh ResNet-34 student against
#    the first K snapshot teachers (alpha = soft-target weight, T = temperature).
alpha, T = 0.7, 5
for K in range(1, 6):
    kd_stud = get_resnet34()
    opt_k = optim.SGD(kd_stud.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    sch_k = optim.lr_scheduler.MultiStepLR(opt_k, milestones=[100, 150], gamma=0.1)
    active_teachers = teachers[:K]
    for _ in range(200):
        train_epoch_kd(kd_stud, active_teachers, train_loader, opt_k, alpha, T)
        sch_k.step()
    acc_k = evaluate(kd_stud, test_loader)
    print(f"KD with K={K} teachers → accuracy: {acc_k:.4f}")
    torch.save(kd_stud.state_dict(), f'checkpoints/student_k{K}.pth')


Student baseline accuracy: 0.9542
Teacher #1 baseline accuracy: 0.9252
Teacher #2 baseline accuracy: 0.9419
Teacher #3 baseline accuracy: 0.9448
Teacher #4 baseline accuracy: 0.9485
Teacher #5 baseline accuracy: 0.9478
KD with K=1 teachers → accuracy: 0.9380
KD with K=2 teachers → accuracy: 0.9492
KD with K=3 teachers → accuracy: 0.9510


# Because the runtime disconnected after K = 3, the cells below resume training for K = 4 and K = 5


In [None]:
# Setup CIFAR-10 data loaders, KD helper, evaluation, and ResNet-34 factory

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 1) Data loaders. The normalization must match the one the teacher snapshots
#    were trained with, or the reloaded teachers see shifted inputs.
_cifar_normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                        (0.247, 0.243, 0.261))
transform_test = transforms.Compose([transforms.ToTensor(), _cifar_normalize])
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    _cifar_normalize,
])

_loader_opts = dict(num_workers=4, pin_memory=True)
train_loader = DataLoader(
    datasets.CIFAR10('data', train=True, download=True, transform=transform_train),
    batch_size=128, shuffle=True, **_loader_opts,
)
test_loader = DataLoader(
    datasets.CIFAR10('data', train=False, download=True, transform=transform_test),
    batch_size=256, shuffle=False, **_loader_opts,
)

# 2) Knowledge-distillation training epoch
def train_epoch_kd(student, teachers, loader, optimizer, alpha, T):
    """Train `student` for one epoch against the averaged soft targets of `teachers`.

    Per batch the loss is
        (1 - alpha) * CE(student_logits, y)
        + alpha * T^2 * KL(mean teacher softmax(logits / T) || student softmax(logits / T)).

    Args:
        student: model being trained (put in train mode here).
        teachers: non-empty sequence of teacher models. They are moved to
            `device`, but their train/eval mode is left as the caller set it.
        loader: iterable of (inputs, labels) batches.
        optimizer: optimizer over the student's parameters.
        alpha: weight of the distillation term; 1 - alpha weights hard-label CE.
        T: softmax temperature.

    Raises:
        ValueError: if `teachers` is empty.
    """
    if len(teachers) == 0:
        raise ValueError("train_epoch_kd needs at least one teacher")
    # Hoisted out of the batch loop: Module.to() moves in place and visits every
    # parameter/buffer, so repeating it for each batch only adds overhead.
    for t in teachers:
        t.to(device)
    student.train()
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            avg_soft = None
            for t in teachers:
                p = F.softmax(t(x) / T, dim=1)
                avg_soft = p if avg_soft is None else avg_soft + p
            avg_soft /= len(teachers)
        logits = student(x)
        loss_h = F.cross_entropy(logits, y)
        loss_s = F.kl_div(
            F.log_softmax(logits / T, dim=1),
            avg_soft, reduction='batchmean'
        ) * (T * T)  # T^2 keeps soft-target gradients on the CE scale
        loss = (1 - alpha) * loss_h + alpha * loss_s
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# 3) Evaluation function
@torch.no_grad()
def evaluate(model, loader):
    """Top-1 accuracy (fraction in [0, 1]) of `model` on `loader`; sets eval mode."""
    model.eval()
    hits, seen = 0, 0
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        hits += (model(batch_x).argmax(dim=1) == batch_y).sum().item()
        seen += batch_y.size(0)
    return hits / seen

# 4) ResNet-34 factory
def get_resnet34():
    """Return an untrained 10-class ResNet-34 with a 32x32-friendly stem, on `device`."""
    net = models.resnet34(weights=None, num_classes=10)
    # torchvision's strided conv1 + max-pool would shrink CIFAR-10 images too
    # early; use a 3x3 stride-1 conv and drop the pooling layer.
    net.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
    net.maxpool = nn.Identity()
    return net.to(device)



In [None]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Depends on train_loader, test_loader, train_epoch_kd, evaluate and
# get_resnet34 from the setup cell above.

# Rebuild the five ResNet-50 teachers with the CIFAR stem and load their snapshots.
teachers = []
for i in range(1, 6):
    t = models.resnet50(weights=None, num_classes=10)
    t.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    t.maxpool = nn.Identity()
    t = t.to(device)
    ckpt = torch.load(f'checkpoints/teacher_snapshot_{i}.pth', map_location=device)
    t.load_state_dict(ckpt)
    t.eval()
    teachers.append(t)

# Distill with K = 4 and K = 5 teachers, each time from a fresh student.
alpha, T = 0.7, 5
for K in (4, 5):
    student = get_resnet34()  # already placed on `device` by the factory
    optimizer = optim.SGD(student.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
    for epoch in range(200):
        train_epoch_kd(student, teachers[:K], train_loader, optimizer, alpha, T)
        scheduler.step()
    acc = evaluate(student, test_loader)
    print(f"KD with K={K} teachers → accuracy: {acc:.4f}")
    torch.save(student.state_dict(), f'checkpoints/student_k{K}.pth')


KD with K=4 teachers → accuracy: 0.9529


In [None]:
# Resume distillation for K = 5 only; the previous cell's output shows that
# K = 4 completed. Reuses train_loader, test_loader, train_epoch_kd, evaluate
# and get_resnet34 from the setup cell.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the five ResNet-50 snapshot teachers with the CIFAR stem they were trained with.
teachers = []
for i in range(1, 6):
    t = models.resnet50(weights=None, num_classes=10)
    t.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    t.maxpool = nn.Identity()
    t = t.to(device)  # move once, after all module swaps
    # weights_only: the file is a plain state_dict (the default only from PyTorch 2.6 on).
    ckpt = torch.load(f'checkpoints/teacher_snapshot_{i}.pth',
                      map_location=device, weights_only=True)
    t.load_state_dict(ckpt)
    t.eval()
    teachers.append(t)

# Distill with K = 5 teachers
alpha, T = 0.7, 5
for K in [5]:
    # fresh student (get_resnet34 already returns it on `device`)
    student = get_resnet34()
    optimizer = optim.SGD(student.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100,150], gamma=0.1)
    for epoch in range(200):
        train_epoch_kd(student, teachers[:K], train_loader, optimizer, alpha, T)
        scheduler.step()
    acc = evaluate(student, test_loader)
    print(f"KD with K={K} teachers → accuracy: {acc:.4f}")
    torch.save(student.state_dict(), f'checkpoints/student_k{K}.pth')


KD with K=5 teachers → accuracy: 0.9544
