<a href="https://colab.research.google.com/github/PETEROA/Knowledge_Distillation/blob/main/Geo_Aware_KD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

GEOMETRY-AWARE KNOWLEDGE DISTILLATION WITH OPTIMAL TRANSPORT
Mathematical Framework:
- Optimal Transport (Sinkhorn Divergence)
- Riemannian Geometry (Fisher-Rao Metric)
- Information Theory (Adaptive Temperature)
- Multi-Scale Attention Transfer
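
As a small illustration of the Fisher-Rao item in the list above, the sketch below computes the geodesic distance between two categorical distributions with the formula used in the regularizer's comments, d_FR = 2·arccos(Σ√(p_i·q_i)). The function name `fisher_rao_distance` and the example vectors are hypothetical and not part of the notebook. It uses only the standard library, so it runs without PyTorch or a GPU.

```python
import math

def fisher_rao_distance(p, q):
    """Fisher-Rao geodesic distance between two categorical distributions."""
    # Bhattacharyya coefficient: BC = sum_i sqrt(p_i * q_i), in [0, 1] for valid inputs
    bc = sum(math.sqrt(p_i * q_i) for p_i, q_i in zip(p, q))
    # Guard against floating-point drift slightly outside [0, 1] before arccos
    bc = min(1.0, max(0.0, bc))
    return 2.0 * math.acos(bc)

# Identical distributions are at distance 0; distributions with disjoint
# support are at the maximum distance, pi.
same = fisher_rao_distance([0.5, 0.5], [0.5, 0.5])
disjoint = fisher_rao_distance([1.0, 0.0], [0.0, 1.0])
print(same, disjoint)
```

The distance is bounded by π, so unlike an unnormalized L2 penalty it cannot grow without limit as two distributions move apart.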

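The Adaptive Temperature item in the framework list can be sketched in the same spirit. The example assumes the rule T = T_0·exp(-α·H), clamped to a fixed range, with the same default values as the notebook's scheduler (T_0 = 4.0, α = 0.5, range [1, 10]). The helper names `softmax` and `adaptive_temperature` are hypothetical, and the sketch handles a single logit vector, whereas the notebook averages the entropy over a batch.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def adaptive_temperature(logits, t_init=4.0, alpha=0.5, t_min=1.0, t_max=10.0):
    """Temperature from Shannon entropy: T = t_init * exp(-alpha * H), clamped."""
    probs = softmax(logits)
    # Shannon entropy in nats; terms with p == 0 contribute nothing
    entropy = -sum(p * math.log(p) for p in probs if p > 0.0)
    temperature = t_init * math.exp(-alpha * entropy)
    return min(t_max, max(t_min, temperature))

# A peaked prediction has low entropy and keeps T close to t_init;
# a uniform prediction has maximal entropy and lowers T.
t_confident = adaptive_temperature([10.0, 0.0, 0.0, 0.0])
t_uniform = adaptive_temperature([0.0, 0.0, 0.0, 0.0])
print(t_confident, t_uniform)
```

For a uniform prediction over four classes the entropy is ln 4, so the rule gives T = 4·exp(-0.5·ln 4) = 2.
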

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Fri Dec 19 11:42:06 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   49C    P0             27W /   70W |     908MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import os
from typing import Tuple, List, Dict
from collections import OrderedDict

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)



# SECTION 2: OPTIMAL TRANSPORT LOSS (SINKHORN DIVERGENCE)


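class SinkhornDistance(nn.Module):
    """Entropy-regularized optimal transport cost between two batches of features.

    Each sample in a batch is treated as one point carrying uniform mass.
    `epsilon` is the entropic regularization strength and `max_iter` the
    number of Sinkhorn iterations.
    """
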
    def __init__(self, epsilon=0.1, max_iter=100, reduction='mean'):
        super().__init__()
        self.epsilon = epsilon
        self.max_iter = max_iter
        self.reduction = reduction

    def _cost_matrix(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute pairwise squared Euclidean distance matrix."""
        x_norm = (x ** 2).sum(1).view(-1, 1)
        y_norm = (y ** 2).sum(1).view(1, -1)
        dist = x_norm + y_norm - 2.0 * torch.mm(x, y.transpose(0, 1))
        # Clamp tiny negative values caused by floating-point cancellation
        return dist.clamp_min(0.0)

    def forward(self, source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the entropy-regularized OT cost between source and target."""
        # Flatten spatial dimensions if needed
        if source.dim() > 2:
            source = source.flatten(1)
            target = target.flatten(1)

        # Compute cost matrix
        C = self._cost_matrix(source, target)
        n, m = C.shape

        # Uniform marginals over the samples of each batch
        mu = torch.ones(n, device=C.device, dtype=C.dtype) / n
        nu = torch.ones(m, device=C.device, dtype=C.dtype) / m

        # Sinkhorn iterations in the log domain. Forming exp(-C / epsilon)
        # directly underflows to zero for high-dimensional features, which
        # makes the transport plan, and therefore the loss, vanish.
        log_K = -C / self.epsilon
        log_u = torch.zeros_like(mu)
        log_v = torch.zeros_like(nu)

        for _ in range(self.max_iter):
            log_v = torch.log(nu) - torch.logsumexp(log_K + log_u.unsqueeze(1), dim=0)
            log_u = torch.log(mu) - torch.logsumexp(log_K + log_v.unsqueeze(0), dim=1)

        # Compute transport plan and cost
        pi = torch.exp(log_u.unsqueeze(1) + log_K + log_v.unsqueeze(0))
        cost = torch.sum(pi * C)

        if self.reduction == 'mean':
            # Note: pi already sums to 1, so this rescales the expected cost
            # by 1 / batch_size (kept to preserve the original loss scale).
            cost = cost / n

        return cost



# SECTION 3: RIEMANNIAN MANIFOLD REGULARIZATION (FISHER-RAO METRIC)


class RiemannianManifoldRegularizer(nn.Module):
    """Fisher-Rao geodesic distance between softmax-normalized teacher and student features."""

    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def _to_log_probability(self, x: torch.Tensor) -> torch.Tensor:
        """Convert features to one log-probability distribution per sample."""
        # Flatten all non-batch dimensions before the softmax so that each
        # sample sums to 1. A per-channel softmax followed by flattening sums
        # to the channel count, which saturates the Bhattacharyya coefficient
        # at the clamp and drives the distance to zero.
        return F.log_softmax(x.flatten(1), dim=-1)

    def forward(self, teacher_feat: torch.Tensor, student_feat: torch.Tensor) -> torch.Tensor:
        """Compute Fisher-Rao distance between teacher and student features."""
        # Convert to (log-)probability distributions
        log_p = self._to_log_probability(teacher_feat)
        log_q = self._to_log_probability(student_feat)

        # Bhattacharyya coefficient: BC = ∑√(p_i·q_i), evaluated in log space
        # so that no additive epsilon biases the sum.
        bc = torch.exp(torch.logsumexp(0.5 * (log_p + log_q), dim=1))

        # Fisher-Rao distance: d_FR = 2·arccos(BC)
        # Clamp strictly below 1: acos has an unbounded gradient at 1, and
        # 1 - 1e-8 rounds to exactly 1 in float32.
        bc = torch.clamp(bc, 0.0, 1.0 - max(self.epsilon, 1e-6))
        fisher_rao = 2.0 * torch.acos(bc)

        return fisher_rao.mean()



# SECTION 4: MULTI-SCALE ATTENTION TRANSFER WITH ORTHOGONALITY

class MultiScaleAttentionTransfer(nn.Module):
    """Spatial attention transfer loss with an orthogonality penalty on the student attention maps."""

    def __init__(self, ortho_weight=0.1):
        super().__init__()
        self.ortho_weight = ortho_weight

    def _compute_attention(self, features: torch.Tensor) -> torch.Tensor:
        """Compute spatial attention map from features."""
        # Squared L2 norm across channels at each spatial location
        attention = torch.sum(features ** 2, dim=1, keepdim=True)

        # Normalize to attention map
        b, _, h, w = attention.shape
        attention = attention.view(b, -1)
        attention = F.softmax(attention, dim=1)
        attention = attention.view(b, 1, h, w)

        return attention

    def _orthogonality_loss(self, attention: torch.Tensor) -> torch.Tensor:
        """Enforce orthogonality: ||A^T·A - I||_F²"""
        b, _, h, w = attention.shape
        attention_flat = attention.view(b, h * w)

        # Compute Gram matrix A^T·A
        gram = torch.matmul(attention_flat.t(), attention_flat)

        # Identity matrix
        identity = torch.eye(h * w, device=attention.device)

        # Frobenius norm of difference
        ortho_loss = torch.norm(gram - identity, p='fro') ** 2

        return ortho_loss / (h * w)

    def forward(self, teacher_feat: torch.Tensor, student_feat: torch.Tensor) -> torch.Tensor:
        """Compute attention transfer loss with orthogonality constraint."""
        # Compute attention maps
        teacher_att = self._compute_attention(teacher_feat)
        student_att = self._compute_attention(student_feat)

        # Attention transfer loss
        att_loss = torch.norm(teacher_att - student_att, p='fro') ** 2

        # Orthogonality regularization
        ortho_loss = self._orthogonality_loss(student_att)

        total_loss = att_loss + self.ortho_weight * ortho_loss

        return total_loss



# SECTION 5: ADAPTIVE TEMPERATURE SCHEDULING


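class AdaptiveTemperatureScheduler(nn.Module):
    """Set the distillation temperature from the mean entropy H of the teacher predictions.

    T = T_init * exp(-alpha * H), clamped to [T_min, T_max].
    """
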
    def __init__(self, T_init=4.0, alpha=0.5, T_min=1.0, T_max=10.0):
        super().__init__()
        self.T_init = T_init
        self.alpha = alpha
        self.T_min = T_min
        self.T_max = T_max

    def _compute_entropy(self, logits: torch.Tensor) -> torch.Tensor:
        """Compute Shannon entropy of prediction distribution."""
        probs = F.softmax(logits, dim=1)
        log_probs = F.log_softmax(logits, dim=1)
        entropy = -torch.sum(probs * log_probs, dim=1)
        return entropy.mean()

    def forward(self, teacher_logits: torch.Tensor) -> float:
        """Compute adaptive temperature based on teacher entropy."""
        entropy = self._compute_entropy(teacher_logits)

        # Temperature: T = T_0 · exp(-α · H)
        temperature = self.T_init * torch.exp(-self.alpha * entropy)

        # Clamp to valid range
        temperature = torch.clamp(temperature, self.T_min, self.T_max)

        return temperature.item()



# SECTION 6: FEATURE ADAPTATION LAYERS


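class FeatureAdaptation(nn.Module):
    """1x1 convolution + batch norm layers that project student feature maps to the teacher's channel counts."""
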
    def __init__(self, student_channels: List[int], teacher_channels: List[int]):
        super().__init__()

        self.adaptations = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(s_ch, t_ch, kernel_size=1, bias=False),
                nn.BatchNorm2d(t_ch)
            )
            for s_ch, t_ch in zip(student_channels, teacher_channels)
        ])

    def forward(self, student_features: List[torch.Tensor]) -> List[torch.Tensor]:
        """Project student features to teacher dimensions."""
        adapted_features = [
            adapt(feat) for adapt, feat in zip(self.adaptations, student_features)
        ]
        return adapted_features



# SECTION 7: COMPLETE GEOMETRY-AWARE DISTILLATION LOSS


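class GeometryAwareDistillationLoss(nn.Module):
    """Combined distillation loss.

    Temperature-scaled KL divergence plus weighted optimal transport,
    Fisher-Rao, and attention transfer terms, optionally mixed with
    cross-entropy on the ground-truth labels.
    """
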
    def __init__(
        self,
        lambda_ot=1.0,
        lambda_geo=0.5,
        lambda_att=2.0,
        sinkhorn_eps=0.1,
        sinkhorn_iter=100,
        ortho_weight=0.1
    ):
        super().__init__()

        self.lambda_ot = lambda_ot
        self.lambda_geo = lambda_geo
        self.lambda_att = lambda_att

        # Initialize components
        self.sinkhorn = SinkhornDistance(epsilon=sinkhorn_eps, max_iter=sinkhorn_iter)
        self.riemannian = RiemannianManifoldRegularizer()
        self.attention = MultiScaleAttentionTransfer(ortho_weight=ortho_weight)
        self.temp_scheduler = AdaptiveTemperatureScheduler()

    def _kl_divergence_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        temperature: float
    ) -> torch.Tensor:
        """Standard KL divergence loss with temperature scaling."""
        student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
        teacher_probs = F.softmax(teacher_logits / temperature, dim=1)

        kl_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')

        # Scale by T² to maintain gradient magnitude
        return kl_loss * (temperature ** 2)

    def forward(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        student_features: List[torch.Tensor],
        teacher_features: List[torch.Tensor],
        labels: torch.Tensor = None,
        alpha: float = 0.7
    ) -> Tuple[torch.Tensor, Dict[str, float]]:

        # Adaptive temperature
        temperature = self.temp_scheduler(teacher_logits)

        # 1. Standard KL divergence
        loss_kd = self._kl_divergence_loss(student_logits, teacher_logits, temperature)

        # 2. Optimal transport (deepest features)
        loss_ot = self.sinkhorn(teacher_features[-1], student_features[-1])

        # 3. Riemannian regularization (intermediate features)
        loss_geo = sum([
            self.riemannian(t_feat, s_feat)
            for t_feat, s_feat in zip(teacher_features[:-1], student_features[:-1])
        ]) / len(teacher_features[:-1])

        # 4. Multi-scale attention transfer
        loss_att = sum([
            self.attention(t_feat, s_feat)
            for t_feat, s_feat in zip(teacher_features, student_features)
        ]) / len(teacher_features)

        # Combined distillation loss
        distillation_loss = (
            loss_kd +
            self.lambda_ot * loss_ot +
            self.lambda_geo * loss_geo +
            self.lambda_att * loss_att
        )

        # Add supervised loss if labels provided
        total_loss = distillation_loss
        if labels is not None:
            loss_ce = F.cross_entropy(student_logits, labels)
            total_loss = alpha * distillation_loss + (1 - alpha) * loss_ce

        # Loss breakdown
        loss_dict = {
            'total': total_loss.item(),
            'kd': loss_kd.item(),
            'ot': loss_ot.item(),
            'geo': loss_geo.item(),
            'att': loss_att.item(),
            'temperature': temperature
        }

        if labels is not None:
            loss_dict['ce'] = loss_ce.item()

        return total_loss, loss_dict



# SECTION 8: MODEL ARCHITECTURES


class FeatureExtractor(nn.Module):
    """Wrapper for extracting intermediate features using forward hooks."""

    def __init__(self, model: nn.Module, layer_names: List[str]):
        super().__init__()
        self.model = model
        self.layer_names = layer_names
        self.features = {}
        self._register_hooks()

    def _register_hooks(self):
        """Register forward hooks to capture intermediate activations."""
        def get_activation(name):
            def hook(module, input, output):
                self.features[name] = output
            return hook

        for name, module in self.model.named_modules():
            if name in self.layer_names:
                module.register_forward_hook(get_activation(name))

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward pass with feature extraction."""
        self.features.clear()
        logits = self.model(x)
        features = [self.features[name] for name in self.layer_names]
        return logits, features


def create_teacher_model(num_classes=10, pretrained=False):
    """Create ResNet-50 teacher model with feature extraction."""
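    # `pretrained=` is deprecated in torchvision >= 0.13; use the `weights` argument
    weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
    base_model = models.resnet50(weights=weights)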

    # Modify final layer
    in_features = base_model.fc.in_features
    base_model.fc = nn.Linear(in_features, num_classes)

    # Feature extraction points
    layer_names = ['layer1', 'layer2', 'layer3', 'layer4']

    # Channel dimensions for ResNet-50
    teacher_channels = [256, 512, 1024, 2048]

    model = FeatureExtractor(base_model, layer_names)
    return model, layer_names, teacher_channels


def create_student_model(num_classes=10, width_mult=0.5):
    """Create ResNet-18 student model.

    Note: `width_mult` is currently unused; the student is a standard
    full-width ResNet-18.
    """
    base_model = models.resnet18(weights=None)

    # Modify final layer
    in_features = base_model.fc.in_features
    base_model.fc = nn.Linear(in_features, num_classes)

    # Same layer names for alignment
    layer_names = ['layer1', 'layer2', 'layer3', 'layer4']

    # Channel dimensions for ResNet-18
    student_channels = [64, 128, 256, 512]

    model = FeatureExtractor(base_model, layer_names)
    return model, layer_names, student_channels


def count_parameters(model: nn.Module) -> int:
    """Count trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)



# SECTION 9: DATA LOADING


def get_cifar10_dataloaders(batch_size=128, num_workers=2, data_dir='./data'):
    """Create CIFAR-10 train and test dataloaders."""

    # Training transforms with augmentation
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    # Test transforms
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    # Load datasets
    train_dataset = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=train_transform
    )

    test_dataset = torchvision.datasets.CIFAR10(
        root=data_dir, train=False, download=True, transform=test_transform
    )

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True
    )

    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )

    return train_loader, test_loader



# SECTION 10: TRAINING FUNCTIONS


def train_teacher(
    model, train_loader, test_loader,
    num_epochs=100, learning_rate=0.1,
    device='cuda', save_path='teacher_model.pth'
):
    """Train teacher model from scratch."""

    model = model.to(device)

    optimizer = optim.SGD(
        model.parameters(), lr=learning_rate,
        momentum=0.9, weight_decay=5e-4
    )

    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    criterion = nn.CrossEntropyLoss()

    history = {'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': []}
    best_acc = 0.0

    print(f"\n{'='*60}")
    print(f"Training Teacher Model ({num_epochs} epochs)")
    print(f"{'='*60}")

    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss = 0.0
        correct = 0
        total = 0

        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
        for inputs, labels in pbar:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs, _ = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            pbar.set_postfix({
                'loss': train_loss / (pbar.n + 1),
                'acc': 100. * correct / total
            })

        train_acc = 100. * correct / total
        train_loss = train_loss / len(train_loader)

        # Evaluation
        test_loss, test_acc = evaluate_model(model, test_loader, criterion, device)

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_loss'].append(test_loss)
        history['test_acc'].append(test_acc)

        scheduler.step()

        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), save_path)
            print(f'\n Best: {test_acc:.2f}%')

        print(f'Epoch {epoch+1}: Train={train_acc:.2f}%, Test={test_acc:.2f}%')

    print(f'\nTeacher training complete. Best: {best_acc:.2f}%')
    return history


def distill_student(
    teacher, student, train_loader, test_loader,
    student_channels, teacher_channels,
    num_epochs=200, learning_rate=0.1, alpha=0.7,
    device='cuda', save_path='student_model.pth'
):
    """Distill student using geometry-aware loss."""

    teacher = teacher.to(device)
    student = student.to(device)
    teacher.eval()

    # Create feature adaptation layers
    feature_adapt = FeatureAdaptation(student_channels, teacher_channels).to(device)

    # Optimizer for both student and adaptation layers
    optimizer = optim.SGD(
        list(student.parameters()) + list(feature_adapt.parameters()),
        lr=learning_rate, momentum=0.9, weight_decay=5e-4
    )

    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    distill_criterion = GeometryAwareDistillationLoss(
        lambda_ot=1.0, lambda_geo=0.5, lambda_att=2.0
    ).to(device)

    history = {
        'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': [],
        'loss_components': {'kd': [], 'ot': [], 'geo': [], 'att': [], 'temperature': []}
    }

    best_acc = 0.0

    print(f"\n{'='*60}")
    print(f"Distilling Student Model ({num_epochs} epochs)")
    print(f"Distillation weight α={alpha:.2f}")
    print(f"{'='*60}")

    for epoch in range(num_epochs):
        student.train()
        feature_adapt.train()
        train_loss = 0.0
        correct = 0
        total = 0

        epoch_components = {k: [] for k in ['kd', 'ot', 'geo', 'att', 'temperature']}

        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
        for inputs, labels in pbar:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            with torch.no_grad():
                teacher_logits, teacher_features = teacher(inputs)

            student_logits, student_features = student(inputs)

            # Adapt student features to match teacher dimensions
            adapted_features = feature_adapt(student_features)

            loss, loss_dict = distill_criterion(
                student_logits, teacher_logits,
                adapted_features, teacher_features,
                labels, alpha=alpha
            )

            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = student_logits.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            for key in epoch_components:
                if key in loss_dict:
                    epoch_components[key].append(loss_dict[key])

            pbar.set_postfix({
                'loss': train_loss / (pbar.n + 1),
                'acc': 100. * correct / total,
                'T': loss_dict['temperature']
            })

        train_acc = 100. * correct / total
        train_loss = train_loss / len(train_loader)

        for key in epoch_components:
            history['loss_components'][key].append(np.mean(epoch_components[key]))

        test_loss, test_acc = evaluate_model(
            student, test_loader, nn.CrossEntropyLoss(), device
        )

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_loss'].append(test_loss)
        history['test_acc'].append(test_acc)

        scheduler.step()

        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(student.state_dict(), save_path)
            print(f'\nBest: {test_acc:.2f}%')

        print(f'Epoch {epoch+1}: Train={train_acc:.2f}%, Test={test_acc:.2f}%')

    print(f'\nStudent distillation complete. Best: {best_acc:.2f}%')
    return history


def evaluate_model(model, test_loader, criterion, device='cuda'):
    """Evaluate model on test set."""
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs, _ = model(inputs)
            loss = criterion(outputs, labels)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    test_loss = test_loss / len(test_loader)
    test_acc = 100. * correct / total

    return test_loss, test_acc



# SECTION 11: VISUALIZATION


def plot_training_history(history, save_path='training_history.png'):
    """Visualize training history."""
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Loss
    ax = axes[0, 0]
    ax.plot(history['train_loss'], label='Train', linewidth=2)
    ax.plot(history['test_loss'], label='Test', linewidth=2)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Loss', fontsize=12)
    ax.set_title('Loss vs. Epoch', fontsize=14, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Accuracy
    ax = axes[0, 1]
    ax.plot(history['train_acc'], label='Train', linewidth=2)
    ax.plot(history['test_acc'], label='Test', linewidth=2)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Accuracy (%)', fontsize=12)
    ax.set_title('Accuracy vs. Epoch', fontsize=14, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Loss components
    if 'loss_components' in history:
        ax = axes[1, 0]
        components = history['loss_components']
        for key in ['kd', 'ot', 'geo', 'att']:
            if key in components and components[key]:
                ax.plot(components[key], label=key.upper(), linewidth=2)
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel('Loss Value', fontsize=12)
        ax.set_title('Loss Components', fontsize=14, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)

        # Temperature
        ax = axes[1, 1]
        if 'temperature' in components and components['temperature']:
            ax.plot(components['temperature'], linewidth=2, color='red')
            ax.set_xlabel('Epoch', fontsize=12)
            ax.set_ylabel('Temperature', fontsize=12)
            ax.set_title('Adaptive Temperature', fontsize=14, fontweight='bold')
            ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"\n✓ Plot saved to {save_path}")
    plt.show()



# SECTION 12: QUICK DEMO & ANALYSIS


def quick_demo():
    """Run quick demonstration of all components."""

    print("\n" + "="*70)
    print("GEOMETRY-AWARE KNOWLEDGE DISTILLATION - QUICK DEMO")
    print("="*70)

    # Test data
    batch_size = 16
    num_classes = 10

    teacher_logits = torch.randn(batch_size, num_classes)
    student_logits = torch.randn(batch_size, num_classes)

    teacher_features = [
        torch.randn(batch_size, 64, 56, 56),
        torch.randn(batch_size, 128, 28, 28),
        torch.randn(batch_size, 256, 14, 14),
        torch.randn(batch_size, 512, 7, 7)
    ]

    student_features = [
        torch.randn(batch_size, 64, 56, 56),
        torch.randn(batch_size, 128, 28, 28),
        torch.randn(batch_size, 256, 14, 14),
        torch.randn(batch_size, 512, 7, 7)
    ]

    labels = torch.randint(0, num_classes, (batch_size,))

    # Initialize loss
    distill_loss = GeometryAwareDistillationLoss()

    # Compute loss
    total_loss, loss_dict = distill_loss(
        student_logits, teacher_logits,
        student_features, teacher_features,
        labels, alpha=0.7
    )

    print("\n Loss Components:")
    print(f"{'='*50}")
    for key, value in loss_dict.items():
        print(f"{key:12s}: {value:.6f}")

    print(f"\n All components working correctly!")
    print(f"{'='*70}\n")


def analyze_components():
    """Detailed analysis of geometric properties."""

    print("\n" + "="*70)
    print("MATHEMATICAL ANALYSIS OF COMPONENTS")
    print("="*70)

    # 1. Optimal Transport Analysis
    print("\n1. OPTIMAL TRANSPORT (Sinkhorn Divergence)")
    print("-" * 50)

    dist1 = torch.randn(100, 2) * 0.5 + torch.tensor([1.0, 1.0])
    dist2 = torch.randn(100, 2) * 0.5 + torch.tensor([1.5, 1.5])

    sinkhorn = SinkhornDistance(epsilon=0.1, max_iter=100)
    ot_dist = sinkhorn(dist1, dist2).item()
    l2_dist = torch.mean((dist1 - dist2) ** 2).item()

    print(f"  OT Distance:  {ot_dist:.4f}")
    print(f"  L2 Distance:  {l2_dist:.4f}")
    print(f"  Ratio OT/L2:  {ot_dist/l2_dist:.4f}")
    print("   OT captures distribution geometry better than L2")

    # 2. Fisher-Rao Analysis
    print("\n2. RIEMANNIAN MANIFOLD (Fisher-Rao Distance)")
    print("-" * 50)

    p = torch.softmax(torch.randn(1, 1, 10), dim=-1)
    q = torch.softmax(torch.randn(1, 1, 10), dim=-1)

    riemannian = RiemannianManifoldRegularizer()
    fr_dist = riemannian(p, q).item()
    eucl_dist = torch.norm(p - q).item()

    print(f"  Fisher-Rao:   {fr_dist:.4f}")
    print(f"  Euclidean:    {eucl_dist:.4f}")
    print("   Geodesic distance respects probability manifold")

    # 3. Attention Analysis
    print("\n3. ATTENTION TRANSFER (Orthogonality Constraint)")
    print("-" * 50)

    features = torch.randn(4, 64, 14, 14)
    attention_module = MultiScaleAttentionTransfer(ortho_weight=0.1)
    att_map = attention_module._compute_attention(features)

    print(f"  Attention shape: {att_map.shape}")
    print(f"  Sum per sample:  {att_map.sum(dim=[1,2,3]).tolist()}")
    print("   Attention maps are valid probability distributions")

    # 4. Temperature Analysis
    print("\n4. ADAPTIVE TEMPERATURE (Information Theory)")
    print("-" * 50)

    confident = torch.randn(10, 10)
    confident[:, 0] += 5.0  # Every sample strongly favors class 0 (low entropy)

    uncertain = torch.randn(10, 10) * 0.1  # Near-uniform logits (high entropy)

    temp_scheduler = AdaptiveTemperatureScheduler()
    T_conf = temp_scheduler(confident)
    T_uncer = temp_scheduler(uncertain)

    print(f"  T (confident):  {T_conf:.4f}")
    print(f"  T (uncertain):  {T_uncer:.4f}")
    print("  Temperature adapts to teacher confidence")

    print("\n" + "="*70)
    print("Analysis complete!")
    print("="*70 + "\n")



# SECTION 13: MAIN EXECUTION

def main():
    """Main execution function."""

    print("\n" + "="*70)
    print("GEOMETRY-AWARE KNOWLEDGE DISTILLATION")
    print("Complete Implementation in Single Colab Notebook")
    print("="*70)

    # Quick demo
    quick_demo()

    # Mathematical analysis
    analyze_components()

    # Print model info
    print("\nMODEL ARCHITECTURE")
    print("="*70)

    teacher, _, teacher_channels = create_teacher_model(num_classes=10, pretrained=False)
    student, _, student_channels = create_student_model(num_classes=10, width_mult=0.5)

    teacher_params = count_parameters(teacher.model)
    student_params = count_parameters(student.model)

    print(f"\nTeacher (ResNet-50):")
    print(f"  Parameters: {teacher_params:,}")
    print(f"  Size: {teacher_params * 4 / (1024**2):.2f} MB")
    print(f"  Channels: {teacher_channels}")

    print(f"\nStudent (ResNet-18):")
    print(f"  Parameters: {student_params:,}")
    print(f"  Size: {student_params * 4 / (1024**2):.2f} MB")
    print(f"  Channels: {student_channels}")

    print(f"\nCompression:")
    print(f"  Ratio: {teacher_params/student_params:.2f}×")
    print(f"  Reduction: {(1 - student_params/teacher_params) * 100:.1f}%")

    print("\n Feature Adaptation:")
    print(f"  Student features are projected to teacher dimensions")
    print(f"  Enables direct geometric comparison across architectures")

    print("\n" + "="*70)
    print("Setup complete! Ready to train.")
    print("="*70)


    print("\n")



# RUN DEMO


if __name__ == "__main__":
    main()

#  FULL TRAINING

# Load data
print("Loading CIFAR-10...")
train_loader, test_loader = get_cifar10_dataloaders(batch_size=128, num_workers=2)

# Create models
teacher, _, teacher_channels = create_teacher_model(num_classes=10)
student, _, student_channels = create_student_model(num_classes=10)

# Train teacher
print("\nTraining teacher...")
teacher_history = train_teacher(
    teacher, train_loader, test_loader,
    num_epochs=100,
    learning_rate=0.1,
    device=device
)

plot_training_history(teacher_history, 'teacher_history.png')

# Distill student
print("\nDistilling student...")
student_history = distill_student(
    teacher, student, train_loader, test_loader,
    student_channels, teacher_channels,  # Pass channel dimensions
    num_epochs=200,
    learning_rate=0.1,
    alpha=0.7,
    device=device
)

plot_training_history(student_history, 'student_history.png')

print("\nTraining complete!")


Using device: cuda

GEOMETRY-AWARE KNOWLEDGE DISTILLATION
Complete Implementation in Single Colab Notebook

GEOMETRY-AWARE KNOWLEDGE DISTILLATION - QUICK DEMO

 Loss Components:
total       : 39.285347
kd          : 0.738198
ot          : 0.000000
geo         : 0.000000
att         : 27.121498
temperature : 1.497221
ce          : 2.661700

 All components working correctly!


MATHEMATICAL ANALYSIS OF COMPONENTS

1. OPTIMAL TRANSPORT (Sinkhorn Divergence)
--------------------------------------------------
  OT Distance:  0.0060
  L2 Distance:  0.7975
  Ratio OT/L2:  0.0075
   OT captures distribution geometry better than L2

2. RIEMANNIAN MANIFOLD (Fisher-Rao Distance)
--------------------------------------------------
  Fisher-Rao:   0.1619
  Euclidean:    0.4864
   Geodesic distance respects probability manifold

3. ATTENTION TRANSFER (Orthogonality Constraint)
--------------------------------------------------
  Attention shape: torch.Size([4, 1, 14, 14])
  Sum per sample:  [1.0000

Epoch 1/100: 100%|██████████| 391/391 [00:36<00:00, 10.86it/s, loss=4.62, acc=14.2]



 Best: 17.95%
Epoch 1: Train=14.16%, Test=17.95%


Epoch 2/100: 100%|██████████| 391/391 [00:37<00:00, 10.51it/s, loss=2.08, acc=22.1]



 Best: 26.94%
Epoch 2: Train=22.06%, Test=26.94%


Epoch 3/100:  24%|██▍       | 95/391 [00:09<00:29, 10.06it/s, loss=1.96, acc=27.1]