In [None]:
import sys
sys.path.append("..")
import torchvision.transforms as transforms
from Utils.TinyImageNet_loader import get_tinyimagenet_dataloaders

# Side length (in pixels) every image is resized to before entering the model.
image_size = 224


def _resize_to_square():
    """Return a fresh transform that resizes to image_size x image_size."""
    return transforms.Resize((image_size, image_size))


def _to_normalized_tensor():
    """Return the tensor-conversion and normalisation steps shared by all splits.

    The mean/std triples appear to be the commonly used ImageNet channel
    statistics.
    """
    return [
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]


# Training pipeline: random flip, resize, padded random crop and a small
# random rotation, followed by the shared tensor/normalise tail.
tiny_transform_train = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        _resize_to_square(),
        transforms.RandomCrop(image_size, padding=5),
        transforms.RandomRotation(10),
        *_to_normalized_tensor(),
    ]
)

# Validation and test pipelines are deterministic: resize + tensor + normalise.
tiny_transform_val = transforms.Compose([_resize_to_square(), *_to_normalized_tensor()])
tiny_transform_test = transforms.Compose([_resize_to_square(), *_to_normalized_tensor()])

# Build the three loaders with the project helper; they are module-level
# globals consumed later by the training entry point.
train_loader, val_loader, test_loader = get_tinyimagenet_dataloaders(
    data_dir = '../datasets',
    transform_train=tiny_transform_train,
    transform_val=tiny_transform_val,
    transform_test=tiny_transform_test,
    batch_size=64,
    image_size=image_size,
)


In [5]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torch.optim as optim
import torchvision.models as models

#############################################
#            Data Loader Setup              #
#############################################


#############################################
#       Model Components & Architecture     #
#############################################

# 1. Shared Convolutional Stem for Patch Embedding
class SharedConvStem(nn.Module):
    """Two-layer convolutional stem that turns an image into a feature map.

    Each layer is a 3x3 convolution with stride 2 and padding 1, followed by
    batch normalisation and ReLU. For inputs whose side length is a multiple
    of 4 the spatial resolution is therefore divided by 4
    (e.g. 64x64 -> 16x16, 224x224 -> 56x56).

    Args:
        in_channels: Number of channels of the input image.
        embed_dim: Number of output channels; the first layer produces
            ``embed_dim // 2`` channels.
    """

    def __init__(self, in_channels=3, embed_dim=48):
        super().__init__()
        half_dim = embed_dim // 2
        self.conv1 = nn.Conv2d(in_channels, half_dim, kernel_size=3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(half_dim)
        self.conv2 = nn.Conv2d(half_dim, embed_dim, kernel_size=3, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(embed_dim)

    def forward(self, x):
        """Map ``[B, in_channels, H, W]`` to a non-negative ``[B, embed_dim, H', W']`` map."""
        # Apply conv -> batch norm -> ReLU once per stem layer, in order.
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = F.relu(bn(conv(x)))
        return x

# 2. Multi-Head Self-Attention (MSA) Module
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    A single linear layer produces queries, keys and values, which are split
    into ``num_heads`` heads of size ``dim // num_heads``. Within each head
    every token attends to every other token, and the concatenated head
    outputs are passed through an output projection.

    The attention matrix has shape ``[B, num_heads, N, N]``, so memory grows
    quadratically with the number of tokens ``N``.

    Args:
        dim: Embedding size of each token.
        num_heads: Number of attention heads; must divide ``dim``.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads=4):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """
        x: [B, N, dim]

        Returns a tensor of the same shape.
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        # Move the head axis in front of the token axis so that the matmuls
        # below contract over tokens: each of q, k, v is [B, heads, N, head_dim].
        # Without this permute the product would be taken across heads inside
        # each token and tokens would never exchange information.
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, heads, N, N]
        attn = attn.softmax(dim=-1)
        # [B, heads, N, head_dim] -> [B, N, heads, head_dim] -> [B, N, dim]
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)
        return out

# 3. Swin Block with Partial Parameter Sharing (sharing the MSA)
class SharedMSABlock(nn.Module):
    """Pre-norm transformer block whose attention module may be shared.

    The block computes ``x + msa(norm1(x))`` followed by
    ``x + ffn(norm2(x))``. The LayerNorms and the feed-forward network
    (``dim -> 4 * dim -> dim`` with GELU) are always private to the block;
    only the attention module can be shared between blocks.

    Args:
        dim: Embedding size of each token.
        num_heads: Number of attention heads. Only used when ``shared_msa``
            is ``None`` and the block has to build its own attention module.
        shared_msa: Attention module to reuse. When ``None`` a private
            ``MultiHeadSelfAttention(dim, num_heads)`` is created, so the
            block is usable with its default arguments.
    """

    def __init__(self, dim, num_heads, shared_msa: "MultiHeadSelfAttention | None" = None):
        super().__init__()
        if shared_msa is None:
            # Previously a missing module left self.msa as None and forward()
            # failed with "'NoneType' object is not callable".
            shared_msa = MultiHeadSelfAttention(dim, num_heads=num_heads)
        self.msa = shared_msa  # Shared (or private fallback) MSA module
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        hidden_dim = dim * 4
        self.ffn = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers with residuals."""
        # Attention sub-layer with residual connection.
        x = x + self.msa(self.norm1(x))
        # Feed-forward sub-layer with residual connection.
        x = x + self.ffn(self.norm2(x))
        return x

# 4. Progressive Token Merging (naively merges tokens by averaging 2x2 patches)
def progressive_token_merge(x, merge_ratio=2):
    """Downsample a square token grid by averaging non-overlapping windows.

    The ``N`` tokens are interpreted as a row-major ``H x H`` grid
    (``H * H == N``). Every ``merge_ratio x merge_ratio`` window is replaced
    by the mean of its tokens. If ``H`` is not a multiple of ``merge_ratio``
    the trailing rows and columns that do not fill a whole window are dropped.

    Args:
        x: Tensor of shape ``[B, N, dim]`` where ``N`` is a perfect square.
        merge_ratio: Side length of the averaging window.

    Returns:
        Tensor of shape ``[B, (H // merge_ratio) ** 2, dim]`` in row-major
        window order.

    Raises:
        ValueError: If ``merge_ratio < 1``, ``N`` is not a perfect square, or
            the grid is smaller than a single window.
    """
    if merge_ratio < 1:
        raise ValueError(f"merge_ratio must be >= 1, got {merge_ratio}")

    B, N, dim = x.shape
    H = round(N ** 0.5)
    if H * H != N:
        raise ValueError(f"number of tokens ({N}) is not a perfect square")

    new_H = H // merge_ratio
    if new_H == 0:
        raise ValueError(
            f"token grid {H}x{H} is smaller than the merge window "
            f"{merge_ratio}x{merge_ratio}"
        )

    # Keep only the rows/columns that form complete windows.
    kept = new_H * merge_ratio
    grid = x.reshape(B, H, H, dim)[:, :kept, :kept, :]

    # Split both spatial axes into (window index, offset inside window) and
    # average over the two offset axes in a single vectorised call instead of
    # looping over windows in Python.
    windows = grid.reshape(B, new_H, merge_ratio, new_H, merge_ratio, dim)
    return windows.mean(dim=(2, 4)).reshape(B, new_H * new_H, dim)

# 5. Complete Lightweight Swin Transformer
class LightweightSwin(nn.Module):
    """Small transformer classifier with an optionally shared attention module.

    Pipeline: convolutional stem -> flatten to tokens -> several stages of
    ``SharedMSABlock`` with 2x2 token merging between stages -> mean over
    tokens -> LayerNorm -> linear classification head.

    Args:
        num_classes: Size of the output logit vector.
        embed_dim: Token embedding size; constant across all stages.
        num_heads: Number of attention heads in every attention module.
        stages: Number of blocks in each stage, one entry per stage.
        share_stem: Accepted for API compatibility; not used by this class.
        share_msa: When True every block reuses ``self.shared_msa``; when
            False each block gets its own attention module
            (``self.shared_msa`` is still created in that case).
    """

    def __init__(self, num_classes=200, embed_dim=48, num_heads=4, stages=(2, 2, 2),
                 share_stem=True, share_msa=True):
        super().__init__()
        self.shared_stem = SharedConvStem(in_channels=3, embed_dim=embed_dim)
        self.shared_msa = MultiHeadSelfAttention(embed_dim, num_heads=num_heads)
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # Create stages of blocks
        self.stages = nn.ModuleList()
        for num_blocks in stages:
            blocks = [
                SharedMSABlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    shared_msa=(
                        self.shared_msa
                        if share_msa
                        else MultiHeadSelfAttention(embed_dim, num_heads)
                    ),
                )
                for _ in range(num_blocks)
            ]
            self.stages.append(nn.ModuleList(blocks))

    def forward_features(self, x):
        """Return the token sequence ``[B, N_last, embed_dim]`` after the last stage."""
        # The stem applies two stride-2 convolutions, so the feature map is
        # [B, embed_dim, H/4, W/4] for an input of size H x W.
        x = self.shared_stem(x)
        x = x.flatten(2).transpose(1, 2)  # [B, N, C] with N = H/4 * W/4

        # Process stages with progressive token merging
        last_stage = len(self.stages) - 1
        for stage_idx, blocks in enumerate(self.stages):
            for block in blocks:
                x = block(x)
            if stage_idx < last_stage:  # Merge tokens between stages (except after last)
                x = progressive_token_merge(x, merge_ratio=2)
        return x

    def forward(self, x):
        """Return class logits of shape ``[B, num_classes]``."""
        x = self.forward_features(x)
        x = x.mean(dim=1)  # Global average pooling over tokens
        x = self.norm(x)
        x = self.head(x)
        return x

#############################################
#       Multi-View Distillation Setup       #
#############################################

class MultiViewDistiller:
    """Combine a label-smoothed classification loss with a distillation loss.

    The teacher is switched to eval mode on construction and is only ever run
    under ``torch.no_grad()``; the student is run normally so gradients flow
    through it.
    """

    def __init__(self, teacher_model, student_model, temperature=4.0, alpha=0.5):
        """
        teacher_model: Teacher network; put into eval mode here.
        student_model: The lightweight student model.
        temperature: Softmax temperature used by the distillation loss.
        alpha: Weight of the label-smoothed CE loss; the KD loss gets 1 - alpha.
        """
        teacher_model.eval()
        self.teacher = teacher_model
        self.student = student_model
        self.temperature = temperature
        self.alpha = alpha
        # Kept as a public attribute; the loss methods below do not use it.
        self.ce_loss = nn.CrossEntropyLoss()

    def label_smoothing_loss(self, student_logits, targets, smoothing=0.1):
        """Cross-entropy against a smoothed one-hot target, averaged over the batch.

        The true class receives ``1 - smoothing`` and the remaining mass is
        spread evenly over the other ``num_classes - 1`` classes.
        """
        off_value = smoothing / (student_logits.size(-1) - 1)
        log_probs = F.log_softmax(student_logits, dim=-1)
        with torch.no_grad():
            soft_targets = torch.full_like(log_probs, off_value)
            soft_targets.scatter_(1, targets.unsqueeze(1), 1 - smoothing)
        per_sample = torch.sum(-soft_targets * log_probs, dim=-1)
        return torch.mean(per_sample)

    def kd_loss(self, student_logits, teacher_logits):
        """Temperature-scaled KL divergence from teacher to student, times T^2."""
        temp = self.temperature
        log_q = F.log_softmax(student_logits / temp, dim=-1)
        p = F.softmax(teacher_logits / temp, dim=-1)
        return F.kl_div(log_q, p, reduction='batchmean') * (temp * temp)

    def forward_loss(self, images, targets, K=2):
        """Return ``(total, ce, kd)`` losses for one batch.

        The teacher is evaluated ``K`` times and its logits are averaged.
        Every pass receives the same ``images`` tensor; no per-view
        augmentation is applied here.
        """
        with torch.no_grad():
            views = [self.teacher(images) for _ in range(K)]
            teacher_logits_avg = torch.stack(views, dim=0).mean(dim=0)

        student_logits = self.student(images)
        loss_kd = self.kd_loss(student_logits, teacher_logits_avg)
        loss_ce = self.label_smoothing_loss(student_logits, targets, smoothing=0.1)
        total = self.alpha * loss_ce + (1 - self.alpha) * loss_kd
        return total, loss_ce, loss_kd

#############################################
#            Training Loop                  #
#############################################

def train_distillation_epoch(distiller, train_loader, optimizer, device='cuda'):
    """Run one optimisation pass over ``train_loader``.

    Puts the student in train mode, performs one optimiser step per batch on
    the loss returned by ``distiller.forward_loss(..., K=2)`` and returns the
    sample-weighted averages ``(total_loss, ce_loss, kd_loss)`` for the epoch.
    """
    distiller.student.train()

    sum_loss = 0.0
    sum_ce = 0.0
    sum_kd = 0.0
    seen = 0

    for images, targets in train_loader:
        images = images.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        loss, loss_ce, loss_kd = distiller.forward_loss(images, targets, K=2)
        loss.backward()
        optimizer.step()

        # Weight each batch by its size so a smaller final batch does not
        # skew the epoch averages.
        n = images.size(0)
        sum_loss += loss.item() * n
        sum_ce += loss_ce.item() * n
        sum_kd += loss_kd.item() * n
        seen += n

    return (sum_loss / seen, sum_ce / seen, sum_kd / seen)

@torch.no_grad()
def evaluate(model, val_loader, device='cuda'):
    """Return the top-1 accuracy of ``model`` over ``val_loader`` as a float.

    The model is switched to eval mode and gradients are disabled for the
    whole call.
    """
    model.eval()
    n_hit = 0
    n_seen = 0
    for batch_images, batch_targets in val_loader:
        batch_images = batch_images.to(device)
        batch_targets = batch_targets.to(device)
        predictions = torch.argmax(model(batch_images), dim=1)
        n_hit += int(predictions.eq(batch_targets).sum())
        n_seen += batch_targets.shape[0]
    return n_hit / n_seen

def train_distillation(teacher_model, student_model, train_loader, val_loader,
                       epochs=10, lr=1e-4, device='cuda'):
    """Distil ``teacher_model`` into ``student_model`` for ``epochs`` epochs.

    Both models are moved to ``device``; the teacher is kept in eval mode.
    The student is optimised with AdamW at learning rate ``lr``. After every
    epoch the student is scored on ``val_loader``; whenever the accuracy
    improves, its ``state_dict`` is written to ``best_lightweight_swin.pt``
    in the current working directory. Progress is printed; nothing is returned.
    """
    teacher_model.to(device).eval()
    student_model.to(device)

    distiller = MultiViewDistiller(
        teacher_model, student_model, temperature=4.0, alpha=0.5
    )
    optimizer = optim.AdamW(student_model.parameters(), lr=lr)

    best_acc = 0.0
    for epoch in range(1, epochs + 1):
        loss, ce, kd = train_distillation_epoch(distiller, train_loader, optimizer, device)
        acc = evaluate(student_model, val_loader, device)

        # Checkpoint only on a strict improvement of the validation accuracy.
        improved = acc > best_acc
        if improved:
            best_acc = acc
            torch.save(student_model.state_dict(), "best_lightweight_swin.pt")

        print(f"[Epoch {epoch}] Loss: {loss:.4f} (CE: {ce:.4f}, KD: {kd:.4f}), Val Acc: {acc:.4f}, Best: {best_acc:.4f}")

    print("Training complete. Best Accuracy:", best_acc)

#############################################
#                  Main                     #
#############################################

def main():
    """Build the teacher and student, report the student size and start training.

    Relies on the module-level ``train_loader`` and ``val_loader`` created
    earlier in the notebook.

    NOTE(review): the teacher's 200-way ``fc`` layer is freshly initialised and
    the checkpoint load below is commented out, so unless that line is enabled
    the student is distilled from an untrained classification head.
    """
    # Teacher: pretrained ResNet50 backbone with a new 200-class head.
    teacher_model = models.resnet50(pretrained=True)
    teacher_model.fc = nn.Linear(teacher_model.fc.in_features, 200)
    # Optionally load your trained teacher weights here:
    # teacher_model.load_state_dict(torch.load("teacher_resnet50_tiny_imagenet.pt"))
    teacher_model.eval()

    # Student: the lightweight transformer defined in this file.
    student_model = LightweightSwin(
        num_classes=200,
        embed_dim=48,
        num_heads=4,
        stages=[2, 2, 2],
        share_stem=True,
        share_msa=True,
    )

    # Count only parameters that will receive gradients.
    num_params = 0
    for param in student_model.parameters():
        if param.requires_grad:
            num_params += param.numel()
    print(f"Number of trainable parameters: {num_params:,}")

    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    print(f"Training on device: {device}")

    # Train the student model with multi-view distillation
    train_distillation(
        teacher_model,
        student_model,
        train_loader,
        val_loader,
        epochs=20,
        lr=1e-4,
        device=device,
    )


if __name__ == "__main__":
    main()


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to C:\Users\m.badzohreh/.cache\torch\hub\checkpoints\resnet50-0676ba61.pth


KeyboardInterrupt: 