In [None]:
# Libreria di sistema
import kagglehub
from pathlib import Path
import os.path

# Installa dipendenze base
!pip install torch torchvision
!pip install pandas numpy matplotlib seaborn
!pip install scikit-learn pillow tqdm
!pip install jupyter notebook
!pip install tensorboard
!pip install timm

In [None]:
# Fetch the IP102 dataset through Kaggle Hub; `path` holds the location that
# kagglehub returns and is reused by later cells.
path = kagglehub.dataset_download("rtlmhjbn/ip02-dataset")
print(f"Path to dataset files: {path}")

# Implementazione del Modello, Teacher e Trainer di Distillazione

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
import timm
from tqdm import tqdm
import numpy as np
from pathlib import Path
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score
import matplotlib.pyplot as plt
import json
from sklearn.metrics import confusion_matrix
import seaborn as sns

import torchvision.transforms as T
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torch.utils.data import WeightedRandomSampler

class IP102Classifier(nn.Module):
    """timm backbone with an architecture-specific classification head for IP102.

    The backbone is created with ``num_classes=0`` so that its own head is
    dropped. ``forward`` calls ``backbone.forward_features``, converts
    channels-last maps (e.g. Swin, ``[B, H, W, C]``) to ``[B, C, H, W]``,
    global-average-pools them and applies ``self.classifier``.

    Construction runs a dummy forward pass (224x224, or 384x384 for
    EfficientNet names), prints the shapes and asserts the logits shape.

    Args:
        model_name: timm model identifier.
        num_classes: number of output logits.
        pretrained: whether timm should load pretrained weights.
        dropout: dropout probability of the Swin and MobileNet heads (the
            EfficientNet head uses fixed rates, the other heads none).
        drop_path_rate: stochastic-depth rate; omitted if the model's
            constructor does not accept it.
    """

    def __init__(
        self,
        model_name: str = 'convnext_base.fb_in22k_ft_in1k',
        num_classes: int = 102,
        pretrained: bool = False,
        dropout: float = 0.2,
        drop_path_rate: float = 0.2
    ):
        super().__init__()

        self.model_name = model_name
        arch = model_name.lower()

        try:
            self.backbone = timm.create_model(
                model_name,
                pretrained=pretrained,
                num_classes=0,
                global_pool='avg',
                drop_path_rate=drop_path_rate
            )
        except TypeError:
            # The model constructor rejected `drop_path_rate` (unexpected
            # keyword argument): retry without it. Any other error propagates
            # instead of being silently retried.
            self.backbone = timm.create_model(
                model_name,
                pretrained=pretrained,
                num_classes=0,
                global_pool='avg'
            )

        # Channel count of the backbone's final feature map.
        if hasattr(self.backbone, "classifier") and hasattr(self.backbone.classifier, "in_features"):
            num_features = self.backbone.classifier.in_features
        elif hasattr(self.backbone, "feature_info"):
            num_features = self.backbone.feature_info[-1]["num_chs"]
        else:
            # Default None so the dummy-forward probe below is used when the
            # backbone exposes no `num_features` attribute.
            num_features = getattr(self.backbone, "num_features", None)

        if num_features is None:
            # Measure the channel dimension exactly as `forward` sees it.
            side = 384 if 'efficientnet' in arch else 224
            with torch.no_grad():
                self.backbone.eval()
                probe = self._to_nchw(
                    self.backbone.forward_features(torch.randn(1, 3, side, side))
                )
                num_features = probe.shape[1]
                self.backbone.train()

        # Architecture-specific classification head.
        if 'convnext' in arch:
            self.classifier = nn.Sequential(
                nn.LayerNorm(num_features),
                nn.Linear(num_features, num_classes)
            )
        elif 'swin' in arch:
            self.classifier = nn.Sequential(
                nn.LayerNorm(num_features),
                nn.Dropout(dropout),
                nn.Linear(num_features, num_classes)
            )
        elif 'efficientnet' in arch:
            self.classifier = nn.Sequential(
                nn.Dropout(0.3),
                nn.Linear(num_features, 512),
                nn.ReLU(inplace=True),
                nn.Dropout(0.2),
                nn.Linear(512, num_classes)
            )
        elif 'mobilenet' in arch:
            self.classifier = nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(num_features, num_classes)
            )
        else:
            self.classifier = nn.Sequential(
                nn.LayerNorm(num_features),
                nn.Linear(num_features, num_classes)
            )

        self.num_features = num_features

        # Sanity check: a dummy batch of 2 must produce logits (2, num_classes).
        side = 384 if 'efficientnet' in arch else 224
        self.eval()
        with torch.no_grad():
            test_input = torch.randn(2, 3, side, side)

            test_features = self.backbone.forward_features(test_input)
            print(f"   Raw feature shape from backbone: {test_features.shape}")

            test_output = self(test_input)

            print(f"   Output shape:  {test_output.shape}")

            assert test_output.shape == (2, num_classes), \
                f"Output shape {test_output.shape} != expected (2, {num_classes})"

        self.train()
        print(f" {model_name}: {num_features} features - Verified OK!")

    @staticmethod
    def _to_nchw(features):
        """Return 4D feature maps in [B, C, H, W] layout.

        A 4D tensor whose last dimension is larger than dimension 1 is treated
        as channels-last ([B, H, W, C], as produced by Swin) and permuted;
        anything else is returned unchanged. This is a heuristic that assumes
        the channel count exceeds the spatial size at the final stage.
        """
        if features.dim() == 4 and features.shape[-1] > features.shape[1]:
            features = features.permute(0, 3, 1, 2).contiguous()
        return features

    def forward(self, x, return_features=False):
        """Compute class logits for a batch of images.

        Args:
            x: input image batch.
            return_features: also return the un-pooled backbone feature map.

        Returns:
            ``logits`` of shape [B, num_classes], or ``(logits, features)``
            where ``features`` is the [B, C, H, W] map before pooling.
        """
        features = self._to_nchw(self.backbone.forward_features(x))

        # Global average pooling -> [B, C].
        features_pooled = F.adaptive_avg_pool2d(features, (1, 1)).flatten(1)

        logits = self.classifier(features_pooled)

        if return_features:
            return logits, features
        return logits


# ENSEMBLE TEACHER (per soft targets)

class EnsembleTeacher:
    """Weighted ensemble of trained IP102Classifier checkpoints that produces
    the soft-target logits for distillation.

    Args:
        models_config: iterable of dicts with keys ``'name'``, ``'model_name'``,
            ``'checkpoint_path'`` and optionally ``'weight'`` (default 1.0) and
            ``'num_classes'`` (default 102). Each checkpoint must contain a
            ``'model_state_dict'`` entry.
        device: device the teacher models are moved to.

    Raises:
        ValueError: if ``models_config`` is empty or its weights do not sum to
            a positive value (they could not be normalized).
    """

    def __init__(self, models_config, device='cuda'):
        self.device = device
        self.models = []
        self.weights = []

        print("\n Caricamento Ensemble Teacher per soft targets")

        # Validate before loading any (expensive) model.
        models_config = list(models_config)
        if not models_config:
            raise ValueError("EnsembleTeacher needs at least one model config")
        if sum(config.get('weight', 1.0) for config in models_config) <= 0:
            raise ValueError("EnsembleTeacher weights must sum to a positive value")

        for config in models_config:
            print(f" {config['name']}")

            model = IP102Classifier(
                model_name=config['model_name'],
                num_classes=config.get('num_classes', 102),
                pretrained=False
            )

            # weights_only=False unpickles arbitrary objects: only load
            # checkpoints that come from a trusted source.
            checkpoint = torch.load(config['checkpoint_path'], map_location=device, weights_only=False)
            model.load_state_dict(checkpoint['model_state_dict'])
            model = model.to(device)
            model.eval()

            self.models.append(model)
            self.weights.append(config.get('weight', 1.0))

        # Normalize the weights so they sum to 1.
        total_weight = sum(self.weights)
        self.weights = [w / total_weight for w in self.weights]

        print(f"Ensemble pronto: {len(self.models)} modelli")
        print(f"Pesi: {[f'{w:.3f}' for w in self.weights]}")

    @torch.no_grad()
    def get_logits(self, images):
        """Return the weighted sum of the member models' logits for `images`."""
        logits_ens = None
        for w, model in zip(self.weights, self.models):
            logits = model(images)
            logits_ens = logits * w if logits_ens is None else logits_ens + w * logits
        return logits_ens


# FEATURE TEACHER (per feature distillation)
class FeatureTeacher:
    """Single trained IP102Classifier used as the source of the backbone's
    pre-pooling feature maps for feature distillation.

    Args:
        model_config: dict with keys ``'name'``, ``'model_name'``,
            ``'checkpoint_path'`` and optionally ``'num_classes'`` (default
            102). The checkpoint must contain a ``'model_state_dict'`` entry.
        device: device the model is moved to.
    """

    def __init__(self, model_config, device='cuda'):
        self.device = device

        print(f"\n Caricamento Feature Teacher: {model_config['name']}")

        self.model = IP102Classifier(
            model_name=model_config['model_name'],
            num_classes=model_config.get('num_classes', 102),
            pretrained=False
        )

        # weights_only=False unpickles arbitrary objects: only load
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(model_config['checkpoint_path'], map_location=device, weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(device)
        self.model.eval()

        self.num_features = self.model.num_features

        print(f"Feature teacher pronto")
        print(f"Features: {self.num_features}")

    @torch.no_grad()
    def get_features(self, images):
        """Return the teacher backbone's features as a 4D [B, C, H, W] tensor.

        * 4D maps: channels-last maps ([B, H, W, C], e.g. Swin) are permuted
          to channels-first with the same heuristic used by
          ``IP102Classifier.forward`` (last dim larger than dim 1); others are
          returned as-is.
        * 3D token sequences [B, N, C]: reshaped to [B, C, sqrt(N), sqrt(N)]
          when N is a perfect square, otherwise averaged to [B, C, 1, 1].
        * Anything else (e.g. pooled [B, C]): two trailing singleton dims are
          appended.
        """
        features = self.model.backbone.forward_features(images)

        if features.dim() == 4:
            # Without this permute a Swin teacher would yield [B, H, W, C], so
            # the spatial dims compared with the student would be (W, C).
            if features.shape[-1] > features.shape[1]:
                features = features.permute(0, 3, 1, 2).contiguous()
            return features

        if features.dim() == 3:
            B, N, C = features.shape
            side = int(N ** 0.5)
            if side * side == N:
                # Token grid -> 2D feature map.
                return features.transpose(1, 2).reshape(B, C, side, side)
            # Non-square token count: treat as a 1x1 feature map.
            return features.mean(dim=1).unsqueeze(-1).unsqueeze(-1)

        return features.unsqueeze(-1).unsqueeze(-1)  # [B, C] -> [B, C, 1, 1]

# IMPROVED GROUP CONVOLUTION MAPPING LAYER


class GroupConvMappingLayer(nn.Module):
    """Maps student feature maps [B, C_s, H, W] to the teacher's channel count
    [B, C_t, H, W] with two grouped convolutions (conv-BN-ReLU-Dropout2d, then
    conv-BN) and an optional shortcut.

    Shortcut selection:
      * ``use_residual`` and C_s == C_t: identity shortcut;
      * ``use_residual`` and C_s != C_t: 1x1 conv projection shortcut;
      * ``use_residual=False``: no shortcut.

    Args:
        student_channels: input channels C_s.
        teacher_channels: output channels C_t.
        inner_channels: hidden width; defaults to ``student_channels // 2``
            clamped to [256, 768].
        conv1_groups, conv2_groups: groups of the two convolutions (channel
            counts must be divisible by them, as nn.Conv2d requires).
        kernel_size: kernel of both convolutions; padding is
            ``kernel_size // 2`` so odd sizes keep H and W unchanged.
        use_residual: enable the shortcut.
        dropout: Dropout2d probability after the first conv block.

    Attributes:
        use_residual: True only when the identity shortcut is used.
        skip: shortcut module (nn.Identity or nn.Conv2d), or None.
        num_params: total number of parameters.
    """

    def __init__(
        self,
        student_channels,
        teacher_channels,
        inner_channels=None,
        conv1_groups=16,
        conv2_groups=4,
        kernel_size=3,
        use_residual=True,
        dropout=0.1
    ):
        super().__init__()

        if inner_channels is None:
            inner_channels = student_channels // 2
            inner_channels = max(256, min(768, inner_channels))

        self.student_channels = student_channels
        self.teacher_channels = teacher_channels
        self.use_residual = use_residual and (student_channels == teacher_channels)

        padding = kernel_size // 2

        self.conv1 = nn.Conv2d(
            student_channels,
            inner_channels,
            kernel_size=kernel_size,
            padding=padding,
            groups=conv1_groups,
            bias=False
        )
        self.bn1 = nn.BatchNorm2d(inner_channels)
        self.relu = nn.ReLU(inplace=True)
        self.dropout1 = nn.Dropout2d(dropout)

        self.conv2 = nn.Conv2d(
            inner_channels,
            teacher_channels,
            kernel_size=kernel_size,
            padding=padding,
            groups=conv2_groups,
            bias=False
        )
        self.bn2 = nn.BatchNorm2d(teacher_channels)

        if self.use_residual:
            self.skip = nn.Identity()
        elif use_residual:
            # Shortcut requested but channel counts differ: project to C_t.
            self.skip = nn.Conv2d(student_channels, teacher_channels, 1, bias=False)
        else:
            # Shortcut explicitly disabled.
            self.skip = None

        self._initialize_weights()

        self.num_params = sum(p.numel() for p in self.parameters())

        print(f" Improved Mapping Layer:")
        print(f" {student_channels} â†’ {inner_channels} â†’ {teacher_channels}")
        print(f" Compression ratio: {inner_channels/student_channels:.2%}")
        print(f" Groups: {conv1_groups}, {conv2_groups}")
        print(f" Residual: {self.use_residual}")
        print(f" Dropout: {dropout}")
        print(f" Params: {self.num_params:,}")

    def _initialize_weights(self):
        """Xavier-uniform conv weights; BatchNorm weight=1, bias=0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Map [B, C_s, H, W] -> [B, C_t, H, W] (for odd kernel sizes)."""
        out = self.dropout1(self.relu(self.bn1(self.conv1(x))))
        out = self.bn2(self.conv2(out))

        if self.skip is not None:
            out = out + self.skip(x)

        return out



# DEBUG
def tensor_stats(name, x):
    return {
        'name': name,
        'shape': list(x.shape),
        'min': x.min().item(),
        'max': x.max().item(),
        'mean': x.mean().item(),
        'std': x.std().item(),
        'l2_norm': x.norm(p=2).item()
    }



# HYBRID KNOWLEDGE DISTILLATION TRAINER


class HybridDistillationTrainer:
    """Hybrid knowledge-distillation trainer for a student IP102Classifier.

    The per-batch loss is

        gamma_feature * L_feature + gamma_soft * L_soft + gamma_hard * L_hard

    where L_feature compares the feature teacher's maps with the student's
    maps after a GroupConvMappingLayer (cosine or MSE), L_soft is the
    temperature-scaled KL divergence to the ensemble teacher's logits, and
    L_hard is cross-entropy on the labels (mixup-weighted when mixup is on).
    gamma_hard = 1 - (gamma_feature + gamma_soft) / 2, and gamma_feature is
    ramped linearly towards ``gamma_feature_end`` during the warmup epochs.

    Optional extras: mixup, an EMA copy of the student, AMP with gradient
    clipping, linear LR warmup followed by cosine annealing. The best model
    (by validation balanced accuracy, of the EMA model when enabled) is saved
    to ``save_dir/best_model.pth``.
    """

    def __init__(
        self,
        student_model,
        ensemble_teacher,
        feature_teacher,
        train_loader,
        val_loader,
        test_loader,
        device='cuda',
        save_dir='./hybrid_distillation',
        temperature=10.0,
        gamma_feature_start=0.1,
        gamma_feature_end=0.5,
        gamma_soft=0.6,
        feature_loss_type='cosine',
        use_mixup=True,
        mixup_alpha=0.2,
        use_ema=True,
        ema_decay=0.999,
        inner_channels=None,
        conv1_groups=8,
        conv2_groups=4,
        kernel_size=3,
        mapping_dropout=0.1,
        learning_rate=1e-3,
        weight_decay=0.01,
        num_epochs=50,
        warmup_epochs=10,
        use_amp=True,
        gradient_clip=1.0
    ):
        self.student = student_model.to(device)
        self.ensemble_teacher = ensemble_teacher
        self.feature_teacher = feature_teacher

        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.device = device

        self.temperature = temperature
        self.gamma_feature_start = gamma_feature_start
        self.gamma_feature_end = gamma_feature_end
        self.gamma_soft = gamma_soft
        self.current_gamma_feature = gamma_feature_start
        self.feature_loss_type = feature_loss_type

        self.use_mixup = use_mixup
        self.mixup_alpha = mixup_alpha

        self.use_ema = use_ema
        if use_ema:
            self.ema_model = self._create_ema_model()
            self.ema_decay = ema_decay

        student_channels = self.student.num_features
        teacher_channels = self.feature_teacher.num_features

        print(f"\n Creazione improved mapping layer")
        self.mapping_layer = GroupConvMappingLayer(
            student_channels=student_channels,
            teacher_channels=teacher_channels,
            inner_channels=inner_channels,
            conv1_groups=conv1_groups,
            conv2_groups=conv2_groups,
            kernel_size=kernel_size,
            dropout=mapping_dropout
        ).to(device)

        self.optimizer = optim.AdamW(
            list(self.student.parameters()) + list(self.mapping_layer.parameters()),
            lr=learning_rate,
            weight_decay=weight_decay
        )

        self.warmup_epochs = warmup_epochs
        self.num_epochs = num_epochs

        # LinearLR multiplies the LR by `start_factor` as soon as it is
        # constructed, so it is only built when there is a warmup phase;
        # otherwise training would run at 1% of the requested LR.
        self.warmup_scheduler = None
        if warmup_epochs > 0:
            self.warmup_scheduler = optim.lr_scheduler.LinearLR(
                self.optimizer,
                start_factor=0.01,
                total_iters=self.warmup_epochs
            )

        self.main_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=num_epochs - warmup_epochs,
            eta_min=1e-6
        )

        self.use_amp = use_amp
        self.scaler = GradScaler(enabled=use_amp)
        self.gradient_clip = gradient_clip

        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        self.history = {
            'train_loss': [], 'train_loss_feature': [],
            'train_loss_soft': [], 'train_loss_hard': [],
            'val_loss': [], 'train_acc': [], 'val_acc': [],
            'val_balanced_acc': [], 'val_f1': [],
            'gamma_feature': [],
            'test_acc': None, 'test_balanced_acc': None, 'test_f1': None
        }

        self.best_val_balanced_acc = 0.0

        print(f"\n Improved Hybrid Distillation Setup:")
        print(f"Temperature: {temperature}")
        print(f"Î³_feature: {gamma_feature_start} â†’ {gamma_feature_end} (adaptive)")
        print(f"Î³_soft:    {gamma_soft}")
        print(f"Feature loss: {feature_loss_type}")
        print(f"Mixup: {use_mixup} (Î±={mixup_alpha})")
        print(f"EMA: {use_ema} (decay={ema_decay})")
        print(f"Gradient clip: {gradient_clip}")
        print(f"LR: {learning_rate}, Epochs: {num_epochs}, Warmup: {warmup_epochs}")
        print(f"Student params: {sum(p.numel() for p in student_model.parameters()) / 1e6:.2f}M")
        print(f"Mapping params: {self.mapping_layer.num_params / 1e6:.2f}M")

    def _create_ema_model(self):
        """Return an independent copy of the student for EMA weights (eval mode, no grads).

        Deep-copying works for any student module, whatever its number of
        classes or constructor arguments.
        """
        import copy

        ema_student = copy.deepcopy(self.student).to(self.device)
        ema_student.eval()
        for param in ema_student.parameters():
            param.requires_grad_(False)
        return ema_student

    def _update_ema(self):
        """ema = decay * ema + (1 - decay) * student for parameters; copy buffers."""
        if not self.use_ema:
            return
        with torch.no_grad():
            for ema_param, param in zip(self.ema_model.parameters(),
                                        self.student.parameters()):
                ema_param.mul_(self.ema_decay).add_(
                    param.detach(), alpha=1 - self.ema_decay
                )
            # Buffers (e.g. BatchNorm running_mean / running_var) are not
            # parameters; copy them so the EMA model is evaluated with current
            # normalization statistics instead of the ones from construction.
            for ema_buffer, buffer in zip(self.ema_model.buffers(),
                                          self.student.buffers()):
                ema_buffer.copy_(buffer)

    def _mixup_data(self, x, y):
        """Mixup a batch.

        Returns ``(mixed_x, y_a, y_b, lam)`` with ``lam ~ Beta(alpha, alpha)``,
        or ``(x, y, None, None)`` when mixup is disabled.
        """
        if not self.use_mixup:
            return x, y, None, None

        lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
        batch_size = x.size(0)
        index = torch.randperm(batch_size, device=x.device)

        mixed_x = lam * x + (1 - lam) * x[index]
        y_a, y_b = y, y[index]

        return mixed_x, y_a, y_b, lam

    def _update_gamma_feature(self, epoch):
        """Set and return gamma_feature for `epoch` (linear ramp during warmup)."""
        if self.warmup_epochs > 0 and epoch <= self.warmup_epochs:
            progress = epoch / self.warmup_epochs
            self.current_gamma_feature = (
                self.gamma_feature_start +
                (self.gamma_feature_end - self.gamma_feature_start) * progress
            )
        else:
            self.current_gamma_feature = self.gamma_feature_end
        return self.current_gamma_feature

    def compute_distillation_loss(self, images, labels, y_a=None, y_b=None, lam=None):
        """Compute the hybrid loss for one batch.

        Returns:
            ``(total_loss, loss_feature, loss_soft, loss_hard, student_features,
            mapped_student_features, teacher_features)``; the three component
            losses are Python floats, ``total_loss`` is the tensor to backprop.
        """
        student_logits, student_features = self.student(images, return_features=True)

        # 1. FEATURE DISTILLATION
        with torch.no_grad():
            teacher_features = self.feature_teacher.get_features(images)

        # Match the teacher's spatial resolution before mapping channels.
        if student_features.shape[2:] != teacher_features.shape[2:]:
            student_features = F.interpolate(
                student_features,
                size=teacher_features.shape[2:],
                mode='bilinear',
                align_corners=False
            )

        mapped_student_features = self.mapping_layer(student_features)

        if self.feature_loss_type == 'cosine':
            teacher_norm = F.normalize(teacher_features, p=2, dim=1)
            student_norm = F.normalize(mapped_student_features, p=2, dim=1)
            loss_feature = 1 - F.cosine_similarity(
                student_norm.flatten(1),
                teacher_norm.flatten(1),
                dim=1
            ).mean()
        else:
            B, C, H, W = teacher_features.shape
            loss_feature = F.mse_loss(
                mapped_student_features,
                teacher_features,
                reduction='sum'
            ) / (H * W * B)

        # 2. SOFT TARGETS DISTILLATION
        with torch.no_grad():
            teacher_logits = self.ensemble_teacher.get_logits(images)

        teacher_probs = F.softmax(teacher_logits / self.temperature, dim=1)
        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=1)

        # T^2 keeps the soft-target gradient magnitude comparable across temperatures.
        loss_soft = F.kl_div(
            student_log_probs,
            teacher_probs,
            reduction='batchmean'
        ) * (self.temperature ** 2)

        # 3. HARD TARGETS
        if lam is not None:
            loss_hard = (
                lam * F.cross_entropy(student_logits, y_a) +
                (1 - lam) * F.cross_entropy(student_logits, y_b)
            )
        else:
            loss_hard = F.cross_entropy(student_logits, labels)

        gamma_hard = 1.0 - (self.current_gamma_feature + self.gamma_soft) / 2.0

        total_loss = (
            self.current_gamma_feature * loss_feature +
            self.gamma_soft * loss_soft +
            gamma_hard * loss_hard
        )

        return (total_loss, loss_feature.item(), loss_soft.item(),
                loss_hard.item(), student_features, mapped_student_features,
                teacher_features)

    def train_epoch(self, epoch):
        """Train for one epoch and step the LR scheduler.

        Returns:
            ``(loss, loss_feature, loss_soft, loss_hard, accuracy)`` averaged
            over the samples processed in this epoch.
        """
        self.student.train()
        self.mapping_layer.train()

        gamma_feat = self._update_gamma_feature(epoch)

        running_loss = 0.0
        running_loss_feature = 0.0
        running_loss_soft = 0.0
        running_loss_hard = 0.0
        correct = 0
        total = 0

        pbar = tqdm(self.train_loader, desc=f'Epoch {epoch}/{self.num_epochs}')

        for batch_idx, (images, labels) in enumerate(pbar):
            images = images.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)

            mixed_images, y_a, y_b, lam = self._mixup_data(images, labels)

            self.optimizer.zero_grad()

            with autocast(enabled=self.use_amp):
                if lam is not None:
                    loss, loss_feat, loss_soft, loss_hard, sf, msf, tf = \
                        self.compute_distillation_loss(mixed_images, labels, y_a, y_b, lam)
                else:
                    loss, loss_feat, loss_soft, loss_hard, sf, msf, tf = \
                        self.compute_distillation_loss(images, labels)

            # DEBUG: feature scales on the very first batch.
            if epoch == 1 and batch_idx == 0:
                with torch.no_grad():
                    stats = [
                        tensor_stats("Student features", sf),
                        tensor_stats("Mapped student features", msf),
                        tensor_stats("Teacher features", tf)
                    ]
                    print(" FEATURE SCALE DEBUG (First Batch)")
                    for s in stats:
                        print(
                            f"{s['name']:30s} | "
                            f"shape={s['shape']} | "
                            f"min={s['min']:8.3f} max={s['max']:8.3f} | "
                            f"mean={s['mean']:8.3f} std={s['std']:8.3f} | "
                            f"L2={s['l2_norm']:10.2f}"
                        )

            trainable = list(self.student.parameters()) + list(self.mapping_layer.parameters())
            if self.use_amp:
                self.scaler.scale(loss).backward()
                if self.gradient_clip > 0:
                    # Gradients must be unscaled before clipping their norm.
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(trainable, self.gradient_clip)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss.backward()
                if self.gradient_clip > 0:
                    torch.nn.utils.clip_grad_norm_(trainable, self.gradient_clip)
                self.optimizer.step()

            self._update_ema()

            running_loss += loss.item() * images.size(0)
            running_loss_feature += loss_feat * images.size(0)
            running_loss_soft += loss_soft * images.size(0)
            running_loss_hard += loss_hard * images.size(0)

            # Training accuracy uses an extra forward pass on the un-mixed
            # images while the student is still in train mode, so it is a
            # noisy estimate and costs one more student forward per batch.
            with torch.no_grad():
                student_logits = self.student(images)
                preds = torch.argmax(student_logits, dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'Î³_feat': f'{gamma_feat:.3f}',
                'acc': f'{correct/total:.4f}'
            })

        # Average over the samples actually processed: equal to len(dataset)
        # for a plain loader, but also correct with drop_last=True or with a
        # sampler that draws a different number of samples.
        epoch_loss = running_loss / total
        epoch_loss_feature = running_loss_feature / total
        epoch_loss_soft = running_loss_soft / total
        epoch_loss_hard = running_loss_hard / total
        epoch_acc = correct / total

        if self.warmup_scheduler is not None and epoch <= self.warmup_epochs:
            self.warmup_scheduler.step()
        else:
            self.main_scheduler.step()

        return epoch_loss, epoch_loss_feature, epoch_loss_soft, epoch_loss_hard, epoch_acc

    @torch.no_grad()
    def validate(self, epoch, use_ema=False):
        """Evaluate on the validation set.

        Returns:
            ``(loss, accuracy, balanced_accuracy, macro_f1)``; uses the EMA
            model when ``use_ema`` is set and EMA is enabled.
        """
        model = self.ema_model if (use_ema and self.use_ema) else self.student
        model.eval()

        running_loss = 0.0
        all_preds = []
        all_labels = []

        for images, labels in tqdm(self.val_loader,
                                   desc=f'Validation {epoch} {"(EMA)" if use_ema else ""}'):
            images = images.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)

            logits = model(images)
            loss = F.cross_entropy(logits, labels)

            running_loss += loss.item() * images.size(0)
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

        val_loss = running_loss / len(all_labels)
        val_acc = accuracy_score(all_labels, all_preds)
        val_balanced_acc = balanced_accuracy_score(all_labels, all_preds)
        val_f1 = f1_score(all_labels, all_preds, average='macro')

        return val_loss, val_acc, val_balanced_acc, val_f1

    @torch.no_grad()
    def test(self, use_ema=False):
        """Evaluate on the test set and print the metrics.

        Returns:
            ``(loss, accuracy, balanced_accuracy, macro_f1)``.
        """
        print(f"TEST SET EVALUATION {'(EMA)' if use_ema else ''}")

        model = self.ema_model if (use_ema and self.use_ema) else self.student
        model.eval()

        all_preds = []
        all_labels = []
        running_loss = 0.0

        for images, labels in tqdm(self.test_loader, desc='Testing'):
            images = images.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)

            logits = model(images)
            loss = F.cross_entropy(logits, labels)

            running_loss += loss.item() * images.size(0)
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

        test_loss = running_loss / len(all_labels)
        test_acc = accuracy_score(all_labels, all_preds)
        test_balanced_acc = balanced_accuracy_score(all_labels, all_preds)
        test_f1 = f1_score(all_labels, all_preds, average='macro')

        print(f"\nTest Loss:        {test_loss:.4f}")
        print(f"Test Accuracy:    {test_acc:.4f}")
        print(f"Balanced Acc:     {test_balanced_acc:.4f}")
        print(f"F1 Macro:         {test_f1:.4f}")

        return test_loss, test_acc, test_balanced_acc, test_f1

    def train(self):
        """Full training loop: train/validate each epoch, checkpoint, then test."""
        print(" STARTING HYBRID DISTILLATION TRAINING")

        for epoch in range(1, self.num_epochs + 1):
            train_loss, loss_feat, loss_soft, loss_hard, train_acc = self.train_epoch(epoch)
            val_loss, val_acc, val_balanced_acc, val_f1 = self.validate(epoch, use_ema=False)

            if self.use_ema:
                ema_val_loss, ema_val_acc, ema_val_balanced_acc, ema_val_f1 = \
                    self.validate(epoch, use_ema=True)

            current_lr = self.optimizer.param_groups[0]['lr']

            self.history['train_loss'].append(train_loss)
            self.history['train_loss_feature'].append(loss_feat)
            self.history['train_loss_soft'].append(loss_soft)
            self.history['train_loss_hard'].append(loss_hard)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)
            self.history['train_acc'].append(train_acc)
            self.history['val_balanced_acc'].append(val_balanced_acc)
            self.history['val_f1'].append(val_f1)
            self.history['gamma_feature'].append(self.current_gamma_feature)

            print(f"\nEpoch {epoch}/{self.num_epochs}")
            print(f"Train Loss: {train_loss:.4f} "
                  f"(feat:{loss_feat:.4f} soft:{loss_soft:.4f} hard:{loss_hard:.4f})")
            print(f"Train Acc:  {train_acc:.4f}")
            print(f"Val Loss:   {val_loss:.4f}")
            print(f"Val Acc:    {val_acc:.4f} | Balanced: {val_balanced_acc:.4f} | F1: {val_f1:.4f}")

            if self.use_ema:
                print(f"EMA Val:    {ema_val_acc:.4f} | Balanced: {ema_val_balanced_acc:.4f} | F1: {ema_val_f1:.4f}")

            print(f"Î³_feature:  {self.current_gamma_feature:.3f}")
            print(f"LR:         {current_lr:.6e}")

            # Model selection uses the EMA model's balanced accuracy when EMA is on.
            best_metric = ema_val_balanced_acc if self.use_ema else val_balanced_acc
            is_best = best_metric > self.best_val_balanced_acc

            if is_best:
                self.best_val_balanced_acc = best_metric
                self.save_checkpoint(epoch, is_best=True)
                print(f" Best model! Balanced Acc: {best_metric:.4f}")

            if epoch % 5 == 0:
                self.save_checkpoint(epoch, is_best=False)

        # NOTE(review): this evaluates the weights from the final epoch, not
        # the best-validation weights saved in best_model.pth.
        test_loss, test_acc, test_balanced_acc, test_f1 = self.test(use_ema=self.use_ema)
        self.history['test_acc'] = test_acc
        self.history['test_balanced_acc'] = test_balanced_acc
        self.history['test_f1'] = test_f1

        self.save_history()
        self.plot_training_curves()

        print("\n TRAINING COMPLETED")
        print(f"Best Val Balanced Acc: {self.best_val_balanced_acc:.4f}")
        print(f"Test Balanced Acc:     {test_balanced_acc:.4f}")

    def save_checkpoint(self, epoch, is_best=False):
        """Save student, mapping layer, optimizer, EMA (if any) and history."""
        checkpoint = {
            'epoch': epoch,
            'student_state_dict': self.student.state_dict(),
            'mapping_layer_state_dict': self.mapping_layer.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'best_val_balanced_acc': self.best_val_balanced_acc,
            'history': self.history
        }

        if self.use_ema:
            checkpoint['ema_state_dict'] = self.ema_model.state_dict()

        if is_best:
            path = self.save_dir / 'best_model.pth'
            torch.save(checkpoint, path)
            print(f" Best model saved: {path}")
        else:
            path = self.save_dir / f'checkpoint_epoch_{epoch}.pth'
            torch.save(checkpoint, path)

    def save_history(self):
        """Write the training history to save_dir/history.json."""
        with open(self.save_dir / 'history.json', 'w') as f:
            json.dump(self.history, f, indent=2)

    def plot_training_curves(self):
        """Plot loss components, losses, accuracies, gamma and the train-val gap."""
        fig, axes = plt.subplots(2, 3, figsize=(18, 10))
        epochs = range(1, len(self.history['train_loss']) + 1)

        axes[0, 0].plot(epochs, self.history['train_loss_feature'], label='Feature Loss')
        axes[0, 0].plot(epochs, self.history['train_loss_soft'], label='Soft Loss')
        axes[0, 0].plot(epochs, self.history['train_loss_hard'], label='Hard Loss')
        axes[0, 0].set_title('Loss Components')
        axes[0, 0].legend()
        axes[0, 0].grid(alpha=0.3)

        axes[0, 1].plot(epochs, self.history['train_loss'], label='Train')
        axes[0, 1].plot(epochs, self.history['val_loss'], label='Val')
        axes[0, 1].set_title('Total Loss')
        axes[0, 1].legend()
        axes[0, 1].grid(alpha=0.3)

        axes[0, 2].plot(epochs, self.history['train_acc'], label='Train', linestyle='--')
        axes[0, 2].plot(epochs, self.history['val_acc'], label='Val')
        axes[0, 2].set_title('Accuracy')
        axes[0, 2].legend()
        axes[0, 2].grid(alpha=0.3)

        axes[1, 0].plot(epochs, self.history['val_balanced_acc'], label='Balanced Acc')
        axes[1, 0].plot(epochs, self.history['val_f1'], label='F1')
        axes[1, 0].set_title('Validation Metrics')
        axes[1, 0].legend()
        axes[1, 0].grid(alpha=0.3)

        axes[1, 1].plot(epochs, self.history['gamma_feature'], color='purple')
        axes[1, 1].set_title('Feature Distillation Weight (Î³_feature)')
        axes[1, 1].set_ylabel('Î³_feature')
        axes[1, 1].grid(alpha=0.3)

        gap = np.array(self.history['train_acc']) - np.array(self.history['val_acc'])
        axes[1, 2].plot(epochs, gap, color='red')
        axes[1, 2].set_title('Train-Val Gap (Overfitting Monitor)')
        axes[1, 2].set_ylabel('Train Acc - Val Acc')
        axes[1, 2].axhline(y=0, color='k', linestyle='--', alpha=0.3)
        axes[1, 2].grid(alpha=0.3)

        plt.tight_layout()
        plt.savefig(self.save_dir / 'training_curves.png', dpi=300)
        plt.close(fig)
        print(f"Training curves saved")



# CONFUSION MATRIX E VISUALIZZAZIONE

def plot_confusion_matrix(model, test_loader, device, save_path, class_names=None):
    """Plot the full confusion matrix (counts and row-normalized) of a classifier.

    The model is switched to eval mode and run over every batch of
    ``test_loader`` under ``torch.no_grad()``; the prediction for each sample
    is the argmax over dim 1 of the model output. A two-panel figure (raw
    counts on the left, proportions per true class on the right) is saved to
    ``save_path`` and per-class accuracy statistics are printed.

    Args:
        model: module returning logits of shape (batch, num_classes).
        test_loader: iterable of ``(images, labels)`` batches.
        device: device the images are moved to.
        save_path: destination of the PNG figure.
        class_names: optional sequence of tick labels indexed by class id.
            Missing entries fall back to ``"Class {i}"``; when omitted,
            seaborn chooses the ticks automatically.

    Returns:
        ``(cm, cm_normalized)``: square numpy arrays with one row/column per
        class index in ``range(n)``, where ``n`` is the largest of the model's
        output width, the largest label + 1 and the largest prediction + 1.
        Classes that never occur keep their slot. Rows of ``cm_normalized``
        for classes without test samples are all zeros instead of NaN.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    all_preds = []
    all_labels = []
    output_width = 0

    print(" GENERATING CONFUSION MATRIX")

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc='Computing predictions'):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            logits = model(images)
            output_width = max(output_width, logits.shape[1])
            preds = torch.argmax(logits, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if not all_labels:
        raise ValueError("test_loader yielded no samples; cannot build a confusion matrix")

    # Fix the label set explicitly. Without `labels=`, sklearn keeps only the
    # classes present in labels/predictions, which shifts rows and columns
    # relative to class_names whenever a class is missing.
    n_classes = max(output_width, int(np.max(all_labels)) + 1, int(np.max(all_preds)) + 1)
    cm = confusion_matrix(all_labels, all_preds, labels=np.arange(n_classes))

    # Normalize each row by its true-class support. Empty rows stay 0 instead
    # of becoming NaN (0 / 0).
    row_support = cm.sum(axis=1)
    cm_normalized = np.divide(
        cm,
        row_support[:, np.newaxis],
        out=np.zeros(cm.shape, dtype=float),
        where=row_support[:, np.newaxis] > 0,
    )

    # `is not None` + len() also works for numpy arrays (whose truth value is
    # ambiguous); pad/truncate so there is exactly one label per class.
    if class_names is not None and len(class_names) > 0:
        tick_labels = [class_names[i] if i < len(class_names) else f"Class {i}"
                       for i in range(n_classes)]
    else:
        tick_labels = 'auto'

    fig, axes = plt.subplots(1, 2, figsize=(24, 10))

    sns.heatmap(cm, annot=False, fmt='d', cmap='Blues',
                xticklabels=tick_labels,
                yticklabels=tick_labels,
                ax=axes[0], cbar_kws={'label': 'Count'})
    axes[0].set_title('Confusion Matrix (Counts)', fontsize=14, fontweight='bold')
    axes[0].set_xlabel('Predicted Label', fontsize=12)
    axes[0].set_ylabel('True Label', fontsize=12)

    sns.heatmap(cm_normalized, annot=False, fmt='.2f', cmap='Blues',
                xticklabels=tick_labels,
                yticklabels=tick_labels,
                ax=axes[1], cbar_kws={'label': 'Proportion'}, vmin=0, vmax=1)
    axes[1].set_title('Confusion Matrix (Normalized)', fontsize=14, fontweight='bold')
    axes[1].set_xlabel('Predicted Label', fontsize=12)
    axes[1].set_ylabel('True Label', fontsize=12)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)

    print(f" Confusion matrix saved: {save_path}")

    # Statistics only over classes that have test samples; absent classes
    # would otherwise contribute a meaningless 0.
    per_class_acc = cm_normalized.diagonal()[row_support > 0]
    print(f"\nPer-class accuracy statistics:")
    print(f"  Mean:   {per_class_acc.mean():.4f}")
    print(f"  Median: {np.median(per_class_acc):.4f}")
    print(f"  Min:    {per_class_acc.min():.4f}")
    print(f"  Max:    {per_class_acc.max():.4f}")

    return cm, cm_normalized


def plot_confusion_matrix_20(model, test_loader, device, save_path, class_names=None,
                             num_classes_to_show=20):
    """Plot the confusion matrix restricted to the first label indices.

    The model is switched to eval mode and run over ``test_loader``. Only
    samples whose true label index is ``< num_classes_to_show`` are kept.
    "First" refers to the dataset's label encoding, not to any external class
    id. A two-panel figure (counts / proportions) is saved to ``save_path``.

    Predictions that fall outside the shown classes do not get a column, but
    they still count against their true class. Proportions are divided by
    the true-class support, not by the visible row sum, so misclassifications
    into hidden classes are not silently ignored.

    Args:
        model: module returning logits of shape (batch, num_classes).
        test_loader: iterable of ``(images, labels)`` batches.
        device: device the images are moved to.
        save_path: destination of the PNG figure.
        class_names: optional sequence of names indexed by class id. Missing
            entries fall back to ``"Class {i}"``.
        num_classes_to_show: number of leading label indices to keep
            (default 20).

    Returns:
        ``(cm, cm_normalized)``: arrays of shape
        ``(num_classes_to_show, num_classes_to_show)``. Rows of classes
        without samples are zeros in ``cm_normalized``.
    """
    k = num_classes_to_show
    model.eval()
    all_preds = []
    all_labels = []

    print("\n" + "=" * 60)
    print("📊 GENERATING CONFUSION MATRIX")
    print("=" * 60)

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc='Computing predictions'):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            logits = model(images)
            preds = torch.argmax(logits, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # int64 even when empty, so bincount/indexing below always work.
    all_preds = np.asarray(all_preds, dtype=np.int64)
    all_labels = np.asarray(all_labels, dtype=np.int64)

    # Keep only samples whose true label is among the first k classes.
    mask_k = all_labels < k
    preds_k = all_preds[mask_k]
    labels_k = all_labels[mask_k]

    print(f" Filtering to first {k} classes:")
    print(f" Total samples: {len(all_labels)} → {len(labels_k)} (first {k} classes)")

    if labels_k.size == 0:
        cm = np.zeros((k, k), dtype=np.int64)
    else:
        cm = confusion_matrix(labels_k, preds_k, labels=np.arange(k))

    # True-class support, including samples predicted as a class >= k (those
    # are dropped from `cm` by `labels=` but must still lower the accuracy).
    support = np.bincount(labels_k, minlength=k)
    cm_normalized = np.divide(
        cm,
        support[:, np.newaxis],
        out=np.zeros(cm.shape, dtype=float),
        where=support[:, np.newaxis] > 0,
    )

    # One display name per shown class. `is not None` also works for numpy
    # arrays; short name lists are padded instead of raising IndexError.
    names = list(class_names) if class_names is not None else []
    class_names_k = [str(names[i]) if i < len(names) else f'Class {i}' for i in range(k)]

    fig, axes = plt.subplots(1, 2, figsize=(20, 8))

    sns.heatmap(cm, annot=False, fmt='d', cmap='Blues',
                xticklabels=class_names_k,
                yticklabels=class_names_k,
                ax=axes[0], cbar_kws={'label': 'Count'})
    axes[0].set_title(f'Confusion Matrix (Counts) - First {k} Classes', fontsize=14, fontweight='bold')
    axes[0].set_xlabel('Predicted Label', fontsize=12)
    axes[0].set_ylabel('True Label', fontsize=12)
    axes[0].tick_params(axis='both', labelsize=9)

    sns.heatmap(cm_normalized, annot=False, fmt='.2f', cmap='Blues',
                xticklabels=class_names_k,
                yticklabels=class_names_k,
                ax=axes[1], cbar_kws={'label': 'Proportion'}, vmin=0, vmax=1)
    axes[1].set_title('Confusion Matrix (Normalized)', fontsize=14, fontweight='bold')
    axes[1].set_xlabel('Predicted Label', fontsize=12)
    axes[1].set_ylabel('True Label', fontsize=12)
    axes[1].tick_params(axis='both', labelsize=9)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)

    print(f" Confusion matrix ({k} classes) saved: {save_path}")

    per_class_acc_k = cm_normalized.diagonal()
    supported_ids = np.flatnonzero(support)
    print(f"\nPer-class accuracy statistics (first {k} classes):")
    if supported_ids.size == 0:
        print("  No test samples in these classes.")
        return cm, cm_normalized

    # Statistics and ranking only over classes that have test samples.
    acc_supported = per_class_acc_k[supported_ids]
    print(f"  Mean:   {acc_supported.mean():.4f}")
    print(f"  Median: {np.median(acc_supported):.4f}")
    print(f"  Min:    {acc_supported.min():.4f}")
    print(f"  Max:    {acc_supported.max():.4f}")

    worst_ids = supported_ids[np.argsort(acc_supported, kind='stable')[:5]]
    print(f"\n {len(worst_ids)} Worst performing classes:")
    for idx in worst_ids:
        print(f"   {class_names_k[idx]}: {per_class_acc_k[idx]:.4f}")

    return cm, cm_normalized

# CONFRONTO MODELLI CON ISTOGRAMMI PER CLASSE

def compare_models_per_class_histogram(
    ensemble_teacher,
    distilled_model,
    test_loader,
    device,
    save_dir,
    class_names=None,
    num_classes_to_show=11,
    num_classes=None,
):
    """Compare the ensemble teacher and the distilled student per class and globally.

    Each model is evaluated exactly once over ``test_loader``. The teacher is
    queried through its ``get_logits`` method; the student is called
    directly. The same predictions feed both the per-class bar chart and the
    global metrics. Models that are ``nn.Module`` instances are switched to
    eval mode first.

    Files written to ``save_dir`` (created if needed):
      * ``per_class_comparison.png``: per-class accuracy of both models for
        up to ``num_classes_to_show`` classes, picked at a regular stride
        along the student's accuracy ranking (worst to best);
      * ``global_metrics_comparison.png``: accuracy, balanced accuracy and
        macro F1 of both models;
      * ``detailed_comparison.json``: the returned dictionary.

    Args:
        ensemble_teacher: object exposing ``get_logits(images)``.
        distilled_model: callable returning logits for a batch of images.
        test_loader: iterable of ``(images, labels)`` batches.
        device: device the images are moved to.
        save_dir: output directory.
        class_names: optional names indexed by class id, used as bar labels.
        num_classes_to_show: how many classes appear in the per-class chart.
        num_classes: number of class ids to report. ``None`` (default) infers
            it from the widest model output and the largest label seen.

    Returns:
        dict with:
          * ``'global_metrics'``: ``'teacher'`` / ``'distilled'`` mapped to
            ``accuracy``, ``balanced_accuracy`` and ``f1_macro``;
          * ``'per_class'``: class ids, support, per-class accuracy of both
            models, and the ids shown in the chart.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """

    print(" PER-CLASS MODEL COMPARISON")

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    def predict(model, is_ensemble=False):
        # Single inference pass; returns (predictions, labels, logit width).
        if isinstance(model, nn.Module):
            model.eval()
        preds_list, labels_list = [], []
        out_width = 0
        with torch.no_grad():
            for images, labels in tqdm(test_loader, desc='Evaluating'):
                images = images.to(device, non_blocking=True)
                logits = model.get_logits(images) if is_ensemble else model(images)
                out_width = max(out_width, logits.shape[1])
                preds_list.extend(torch.argmax(logits, dim=1).cpu().numpy())
                labels_list.extend(labels.cpu().numpy())
        return (np.asarray(preds_list, dtype=np.int64),
                np.asarray(labels_list, dtype=np.int64),
                out_width)

    def per_class_accuracy(preds, labels, n):
        # Recall of every class id in range(n); classes with no samples get 0.
        support = np.bincount(labels, minlength=n)[:n]
        correct = np.bincount(labels[preds == labels], minlength=n)[:n]
        acc = np.divide(correct, support, out=np.zeros(n, dtype=float), where=support > 0)
        return acc, support

    def global_metrics(preds, labels):
        return [
            accuracy_score(labels, preds),
            balanced_accuracy_score(labels, preds),
            f1_score(labels, preds, average='macro'),
        ]

    print("\n Evaluating Teacher (Ensemble)")
    teacher_preds, teacher_labels, teacher_width = predict(ensemble_teacher, is_ensemble=True)

    print(" Evaluating Distilled Student")
    distilled_preds, distilled_labels, distilled_width = predict(distilled_model, is_ensemble=False)

    if teacher_labels.size == 0:
        raise ValueError("test_loader yielded no samples; nothing to compare")

    if num_classes is None:
        num_classes = max(teacher_width, distilled_width, int(teacher_labels.max()) + 1)

    # Each model is scored against the labels of its own pass, so the result
    # stays correct even if the loader shuffles.
    teacher_acc, support = per_class_accuracy(teacher_preds, teacher_labels, num_classes)
    distilled_acc, _ = per_class_accuracy(distilled_preds, distilled_labels, num_classes)

    # Summary statistics over classes present in the test set only; absent
    # classes would otherwise add a fake 0.0 to mean/min.
    present = support > 0 if np.any(support > 0) else np.ones(num_classes, dtype=bool)

    def summary(acc):
        vals = acc[present]
        return vals.mean(), np.median(vals), vals.min(), vals.max()

    t_mean, t_med, t_min, t_max = summary(teacher_acc)
    d_mean, d_med, d_min, d_max = summary(distilled_acc)

    print(" GLOBAL METRICS")
    print(f"\n{'Model':<25} {'Mean Acc':<12} {'Median Acc':<12} {'Min Acc':<12} {'Max Acc':<12}")
    print(f"{'Teacher (Ensemble)':<25} {t_mean:<12.4f} {t_med:<12.4f} {t_min:<12.4f} {t_max:<12.4f}")
    print(f"{'KD (Distilled)':<25} {d_mean:<12.4f} {d_med:<12.4f} {d_min:<12.4f} {d_max:<12.4f}")

    # Take every `step`-th class along the student's accuracy ranking (worst
    # to best). The slice of an ascending ranking is already in ascending
    # order. max(1, ...) guards num_classes_to_show <= 1 (previously a
    # ZeroDivisionError).
    ranking = np.argsort(distilled_acc, kind='stable')
    n_show = max(1, min(num_classes_to_show, len(ranking)))
    step = max(1, len(ranking) // max(1, n_show - 1))
    classes_to_show = ranking[::step][:n_show]

    if class_names is None:
        display_names = [f"Class {i}" for i in classes_to_show]
    else:
        display_names = [class_names[i] if i < len(class_names) else f"Class {i}"
                         for i in classes_to_show]

    print(f"\n Selected {len(classes_to_show)} classes (sorted by KD (distilled) accuracy):")
    print(f"Class IDs: {classes_to_show.tolist()}")

    color_teacher = '#1f77b4'
    color_distilled = '#2ca02c'
    bar_width = 0.35

    fig, ax = plt.subplots(figsize=(14, 7))
    x = np.arange(len(classes_to_show))

    ax.bar(x - bar_width / 2, teacher_acc[classes_to_show], bar_width,
           label='Teacher', color=color_teacher, edgecolor='black', linewidth=0.5)
    ax.bar(x + bar_width / 2, distilled_acc[classes_to_show], bar_width,
           label='KD (Improved)', color=color_distilled, edgecolor='black', linewidth=0.5)

    ax.set_xlabel('Class', fontsize=14)
    ax.set_ylabel('Accuracy', fontsize=14)
    ax.set_title('Model Comparison', fontsize=14)
    ax.set_xticks(x)
    ax.set_xticklabels(display_names, rotation=45, ha='right', fontsize=11)
    ax.set_ylim([0, 1.0])
    ax.set_yticks(np.arange(0, 1.1, 0.2))
    ax.grid(axis='y', alpha=0.3, linestyle='-', linewidth=0.5)
    ax.set_axisbelow(True)
    ax.legend(fontsize=11, loc='upper left')

    plt.tight_layout()
    plt.savefig(save_dir / 'per_class_comparison.png', dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(fig)
    print(f"Plot saved: {save_dir / 'per_class_comparison.png'}")

    # Global metrics reuse the predictions above: no second inference pass.
    print("\n Computing global metrics")
    teacher_metrics = global_metrics(teacher_preds, teacher_labels)
    distilled_metrics = global_metrics(distilled_preds, distilled_labels)

    fig, ax = plt.subplots(figsize=(10, 6))

    metrics_names = ['Accuracy', 'Balanced Accuracy', 'F1 Macro']
    x_metrics = np.arange(len(metrics_names))

    bars_teacher = ax.bar(x_metrics - bar_width / 2, teacher_metrics, bar_width,
                          label='Teacher', color=color_teacher, edgecolor='black', linewidth=1.5)
    bars_distilled = ax.bar(x_metrics + bar_width / 2, distilled_metrics, bar_width,
                            label='KD (Improved)', color=color_distilled, edgecolor='black', linewidth=1.5)

    for bar in list(bars_teacher) + list(bars_distilled):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
                f'{height:.3f}', ha='center', va='bottom', fontsize=10, fontweight='bold')

    ax.set_xlabel('Metrics', fontsize=12, fontweight='bold')
    ax.set_ylabel('Score', fontsize=12, fontweight='bold')
    ax.set_title('Global Metrics Comparison', fontsize=14, fontweight='bold')
    ax.set_xticks(x_metrics)
    ax.set_xticklabels(metrics_names, fontsize=11)
    ax.legend(fontsize=10, loc='lower right')
    ax.set_ylim([0, 1.1])
    ax.grid(axis='y', alpha=0.3, linestyle='--')

    plt.tight_layout()
    plt.savefig(save_dir / 'global_metrics_comparison.png', dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f" Global metrics plot saved")

    results_dict = {
        'global_metrics': {
            'teacher': {
                'accuracy': float(teacher_metrics[0]),
                'balanced_accuracy': float(teacher_metrics[1]),
                'f1_macro': float(teacher_metrics[2]),
            },
            'distilled': {
                'accuracy': float(distilled_metrics[0]),
                'balanced_accuracy': float(distilled_metrics[1]),
                'f1_macro': float(distilled_metrics[2]),
            }
        },
        'per_class': {
            'class_ids': list(range(num_classes)),
            'support': support.tolist(),
            'teacher_accuracy': teacher_acc.tolist(),
            'distilled_accuracy': distilled_acc.tolist(),
            'classes_shown': classes_to_show.tolist(),
        },
    }

    with open(save_dir / 'detailed_comparison.json', 'w') as f:
        json.dump(results_dict, f, indent=2)

    print(f" Detailed results saved")
    print(" PER-CLASS COMPARISON COMPLETED")

    return results_dict

# MAIN

def main():
    """Run hybrid knowledge distillation on IP102, then the post-training analysis.

    Steps:
      1. Seed the Python, NumPy and torch RNGs.
      2. Build ImageFolder datasets and loaders for the train/val/test splits.
      3. Create the MobileNetV3-Large student plus the ensemble (logit) and
         feature teachers.
      4. Train with ``HybridDistillationTrainer``.
      5. Reload ``best_model.pth`` from ``CONFIG['save_dir']`` if present.
      6. Produce the confusion-matrix and teacher-vs-student comparison
         outputs.
    """
    import random

    CONFIG = {
        'seed': 42,

        'train_dir': '/kaggle/input/ip02-dataset/classification/train',
        'val_dir': '/kaggle/input/ip02-dataset/classification/val',
        'test_dir': '/kaggle/input/ip02-dataset/classification/test',
        'save_dir': '/kaggle/working/hybrid_distillation_improved',

        'batch_size': 128,
        'img_size': 224,
        'num_workers': 4,
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',

        'num_epochs': 30,
        'warmup_epochs': 10,
        'learning_rate': 5e-4,
        'weight_decay': 0.05,

        'temperature': 10.0,
        'gamma_feature_start': 0.1,
        'gamma_feature_end': 0.5,
        'gamma_soft': 0.6,
        'feature_loss_type': 'cosine',
        'use_mixup': True,
        'mixup_alpha': 0.2,
        'use_ema': False,
        'ema_decay': 0.999,
        'gradient_clip': 1.0,

        'inner_channels': None,
        'conv1_groups': 8,
        'conv2_groups': 4,
        'kernel_size': 3,
        'mapping_dropout': 0.1,
    }
    # CONFIG entries consumed here; everything else is forwarded to the trainer.
    non_trainer_keys = {'seed', 'train_dir', 'val_dir', 'test_dir', 'save_dir',
                        'batch_size', 'img_size', 'num_workers', 'device'}

    ensemble_config = [
        {
            'name': 'ConvNext-Base',
            'model_name': 'convnext_base.fb_in22k_ft_in1k',
            'checkpoint_path': '/kaggle/input/convnext-15-epoch/checkpoints/best_model.pth',
            'weight': 1.2
        },
        {
            'name': 'Swin-Base',
            'model_name': 'swin_base_patch4_window7_224.ms_in22k_ft_in1k',
            'checkpoint_path': '/kaggle/input/swin-base-25-epochs/best_model.pth',
            'weight': 1.3
        },
    ]

    feature_teacher_config = {
        'name': 'ConvNext-Base',
        'model_name': 'convnext_base.fb_in22k_ft_in1k',
        'checkpoint_path': '/kaggle/input/convnext-15-epoch/checkpoints/best_model.pth',
    }

    print(" HYBRID KNOWLEDGE DISTILLATION")

    # Seed every RNG the pipeline touches (augmentations, shuffling, mixup,
    # dropout) so runs are reproducible.
    random.seed(CONFIG['seed'])
    np.random.seed(CONFIG['seed'])
    torch.manual_seed(CONFIG['seed'])
    torch.cuda.manual_seed_all(CONFIG['seed'])

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    train_tf = T.Compose([
        T.RandomResizedCrop(CONFIG['img_size'], scale=(0.7, 1.0)),
        T.RandomHorizontalFlip(),
        T.RandomApply([T.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
        T.RandomErasing(p=0.15)
    ])

    val_tf = T.Compose([
        T.Resize(int(CONFIG['img_size'] * 1.1)),
        T.CenterCrop(CONFIG['img_size']),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std)
    ])

    train_ds = datasets.ImageFolder(CONFIG['train_dir'], transform=train_tf)
    val_ds = datasets.ImageFolder(CONFIG['val_dir'], transform=val_tf)
    test_ds = datasets.ImageFolder(CONFIG['test_dir'], transform=val_tf)

    # ImageFolder assigns label indices from each split's own sorted folder
    # list, so the splits must contain the same class folders for the indices
    # to mean the same thing.
    if not (train_ds.classes == val_ds.classes == test_ds.classes):
        raise ValueError(
            "train/val/test directories contain different class folders; "
            "label indices would not be comparable across splits"
        )

    train_loader = DataLoader(train_ds, batch_size=CONFIG['batch_size'],
                              shuffle=True, num_workers=CONFIG['num_workers'],
                              pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=CONFIG['batch_size'],
                            shuffle=False, num_workers=CONFIG['num_workers'],
                            pin_memory=True)
    test_loader = DataLoader(test_ds, batch_size=CONFIG['batch_size'],
                             shuffle=False, num_workers=CONFIG['num_workers'],
                             pin_memory=True)

    print(f"\n Dataset:")
    print(f"   Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}")

    print(f"\n Creating Student (MobileNetV3-Large)...")
    student = IP102Classifier(
        model_name='mobilenetv3_large_100.miil_in21k_ft_in1k',
        num_classes=102,
        pretrained=True
    )

    ensemble_teacher = EnsembleTeacher(ensemble_config, device=CONFIG['device'])
    feature_teacher = FeatureTeacher(feature_teacher_config, device=CONFIG['device'])

    trainer = HybridDistillationTrainer(
        student_model=student,
        ensemble_teacher=ensemble_teacher,
        feature_teacher=feature_teacher,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        device=CONFIG['device'],
        save_dir=CONFIG['save_dir'],
        **{k: v for k, v in CONFIG.items() if k not in non_trainer_keys}
    )

    trainer.train()

    # POST-TRAINING ANALYSIS

    print(" STARTING POST-TRAINING ANALYSIS")

    print("\n Loading best distilled model for analysis")
    best_checkpoint_path = Path(CONFIG['save_dir']) / 'best_model.pth'
    if best_checkpoint_path.exists():
        # Checkpoint produced by this run, so full (non weights-only) loading is acceptable.
        best_checkpoint = torch.load(
            best_checkpoint_path,
            map_location=CONFIG['device'],
            weights_only=False
        )

        if trainer.use_ema and 'ema_state_dict' in best_checkpoint:
            student.load_state_dict(best_checkpoint['ema_state_dict'])
            print(" Best EMA model loaded for analysis")
        else:
            student.load_state_dict(best_checkpoint['student_state_dict'])
            print(" Best student model loaded for analysis")
    else:
        print(f" No checkpoint found at {best_checkpoint_path}; analysing the final student weights")

    student.eval()

    # Folder names in label-index order, so plot labels match the real classes
    # rather than the raw indices.
    class_names = list(train_ds.classes)

    plot_confusion_matrix(
        model=student,
        test_loader=test_loader,
        device=CONFIG['device'],
        save_path=Path(CONFIG['save_dir']) / 'confusion_matrix.png',
        class_names=class_names
    )

    plot_confusion_matrix_20(
        model=student,
        test_loader=test_loader,
        device=CONFIG['device'],
        save_path=Path(CONFIG['save_dir']) / 'confusion_matrix_20classes.png',
        class_names=class_names
    )

    compare_models_per_class_histogram(
        ensemble_teacher=ensemble_teacher,
        distilled_model=student,
        test_loader=test_loader,
        device=CONFIG['device'],
        save_dir=CONFIG['save_dir'],
        class_names=class_names,
        num_classes_to_show=11
    )

    print("\n DISTILLATION COMPLETED ")


if __name__ == "__main__":
    main()