In [None]:
# =========================================================
# DERM7PT DATA LOADER FOR TRAINING NOTEBOOKS
# =========================================================
# Run this code in your training notebooks to load the preprocessed Derm7pt dataset
# Make sure the preprocessing pipeline above has been executed first!

import os
import json
import joblib
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

# =========================================================
# PATHS TO PREPROCESSED DATA
# =========================================================
# Directory produced by the preprocessing pipeline, relative to the current
# working directory.
PREPROCESSED_DIR = r"augmented"

TRAIN_CSV = os.path.join(PREPROCESSED_DIR, "train_metadata_final.csv")
VAL_CSV   = os.path.join(PREPROCESSED_DIR, "val_metadata_final.csv")
TEST_CSV  = os.path.join(PREPROCESSED_DIR, "test_metadata_final.csv")

INFO_PATH = os.path.join(PREPROCESSED_DIR, "preprocessing_info.json")

# =========================================================
# LOAD PREPROCESSED DATA
# =========================================================
print("Loading preprocessed Derm7pt data...")

# One dataframe per split.
train_df = pd.read_csv(TRAIN_CSV)
val_df   = pd.read_csv(VAL_CSV)
test_df  = pd.read_csv(TEST_CSV)

# Read the sidecar JSON with an explicit encoding so the result does not
# depend on the platform's default locale encoding.
with open(INFO_PATH, "r", encoding="utf-8") as f:
    preprocessing_info = json.load(f)

categorical_cols = preprocessing_info["categorical_cols"]
label_mapping = preprocessing_info["label_mapping"]

# Class names ordered by their integer label index.
class_names = sorted(label_mapping, key=label_mapping.get)

print(f"\n✅ Training samples:   {len(train_df)}")
print(f"✅ Validation samples: {len(val_df)}")
print(f"✅ Test samples:       {len(test_df)}")
print(f"\nLabel mapping: {label_mapping}")
print(f"Class names: {class_names}")

# =========================================================
# EXTRACT FEATURES AND LABELS
# =========================================================
def extract_features(df):
    """Split a dataframe into image paths, a metadata matrix and labels.

    Every column other than ``ImagePath`` and ``label`` is treated as a
    metadata feature, kept in the dataframe's own column order.

    Returns:
        Tuple ``(img_paths, metadata, labels)`` of NumPy arrays.
    """
    reserved = ("ImagePath", "label")
    feature_columns = [name for name in df.columns if name not in reserved]

    return (
        df["ImagePath"].values,
        df[feature_columns].values,
        df["label"].values,
    )

# The class count comes straight from the label mapping saved at preprocessing.
num_classes = len(label_mapping)

# Break each split into (image paths, metadata matrix, labels).
X_train_img, X_train_meta, y_train = extract_features(train_df)
X_val_img, X_val_meta, y_val = extract_features(val_df)
X_test_img, X_test_meta, y_test = extract_features(test_df)

print(f"\nNumber of classes: {num_classes}")

# =========================================================
# PYTORCH DATASET CLASS
# =========================================================
class Derm7ptDataset(Dataset):
    """Map-style dataset yielding ``(image, metadata, label)`` for Derm7pt.

    Args:
        img_paths: Indexable collection of image file paths.
        metadata: Indexable collection of per-sample feature rows; each row
            must support ``.astype`` (e.g. rows of a 2-D NumPy array).
        labels: Indexable collection of integer-convertible class labels.
        transform: Optional callable applied to the loaded PIL image.
        placeholder_size: ``(width, height)`` of the black image substituted
            when a file cannot be loaded.
    """

    def __init__(self, img_paths, metadata, labels, transform=None,
                 placeholder_size=(224, 224)):
        self.img_paths = img_paths
        self.metadata = metadata
        self.labels = labels
        self.transform = transform
        self.placeholder_size = placeholder_size

    def __len__(self):
        """Return the number of samples (one per label)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, metadata_row_as_float32, int_label)`` for ``idx``."""
        img_path = self.img_paths[idx]
        try:
            image = Image.open(img_path).convert("RGB")
        except Exception as exc:
            # Deliberately best-effort: one unreadable file must not abort an
            # epoch. Report the cause so missing/corrupt files can be traced.
            print(f"Warning: Failed to load {img_path} ({exc!r}), using placeholder")
            image = Image.new("RGB", self.placeholder_size, color="black")

        if self.transform is not None:
            image = self.transform(image)

        metadata = self.metadata[idx].astype(np.float32)
        label = int(self.labels[idx])

        return image, metadata, label

# =========================================================
# DATA TRANSFORMS
# =========================================================
# Training transforms (with augmentation)
# Settings shared by both pipelines: target resolution and the ImageNet
# channel statistics used for normalisation.
_INPUT_SIZE = (224, 224)
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, random flips / rotation / colour jitter, then
# tensor conversion and normalisation.
train_transform = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(30),
        transforms.ColorJitter(
            brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1
        ),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# Evaluation pipeline: deterministic resize, tensor conversion, normalisation.
val_test_transform = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# =========================================================
# CREATE DATASETS
# =========================================================
# Imported here because this cell otherwise only pulls in torch.utils.data.
import torch

# Only the training split receives the augmenting transform.
train_dataset = Derm7ptDataset(X_train_img, X_train_meta, y_train, transform=train_transform)
val_dataset   = Derm7ptDataset(X_val_img, X_val_meta, y_val, transform=val_test_transform)
test_dataset  = Derm7ptDataset(X_test_img, X_test_meta, y_test, transform=val_test_transform)

print(f"\n✅ Created PyTorch Datasets")
print(f"   - Train: {len(train_dataset)} samples")
print(f"   - Val:   {len(val_dataset)} samples")
print(f"   - Test:  {len(test_dataset)} samples")

# =========================================================
# CREATE DATALOADERS (EXAMPLE - ADJUST BATCH SIZE AS NEEDED)
# =========================================================
BATCH_SIZE = 32

# Pinned host memory only helps when batches are copied to a CUDA device;
# requesting it on a CPU-only machine just produces a warning.
PIN_MEMORY = torch.cuda.is_available()

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,  # Set to 0 for Windows, increase for Linux/Mac
    pin_memory=PIN_MEMORY,
)

# Evaluation loaders keep a fixed sample order.
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=PIN_MEMORY,
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=PIN_MEMORY,
)

print(f"\n✅ Created DataLoaders (batch_size={BATCH_SIZE})")
print(f"   - Train batches: {len(train_loader)}")
print(f"   - Val batches:   {len(val_loader)}")
print(f"   - Test batches:  {len(test_loader)}")

# =========================================================
# EXAMPLE: TEST LOADING A BATCH
# =========================================================
# Pull a single training batch to confirm the pipeline works end to end.
print("\n🔍 Testing batch loading...")
for images, metadata, labels in train_loader:
    print(f"   - Image batch shape:    {images.shape}")
    print(f"   - Metadata batch shape: {metadata.shape}")
    print(f"   - Labels batch shape:   {labels.shape}")
    break  # one batch is enough for the check

print("\n✅ Derm7pt data loading complete! Ready for training.")
print("\n💡 Usage in your model:")
print("   for images, metadata, labels in train_loader:")
print("       # images: torch.Tensor of shape (batch_size, 3, 224, 224)")
print("       # metadata: torch.Tensor of shape (batch_size, num_metadata_features)")
print("       # labels: torch.Tensor of shape (batch_size,)")
print("       # Your training code here...")

In [None]:
# Number of metadata features per sample (column count of the training
# metadata matrix).
input_dim_meta = X_train_meta.shape[1]


In [None]:
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

class EarlyStopping:
    """Stop training when validation accuracy stops improving.

    The model's ``state_dict`` is saved to ``path`` on the first call and every
    time the monitored accuracy exceeds the best value so far by more than
    ``delta``.

    Args:
        patience: Consecutive non-improving calls tolerated before
            ``early_stop`` is set.
        verbose: Print a message on every counter increment and checkpoint.
        delta: Minimum increase over the best score that counts as improvement.
        path: File the best ``state_dict`` is written to.
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # `np.Inf` was removed in NumPy 2.0; `np.inf` works on every version.
        self.val_accuracy_max = -np.inf

    def __call__(self, val_acc, model):
        """Record one epoch's validation accuracy and checkpoint on improvement."""
        score = val_acc
        if self.best_score is None:
            # First call: always establish a baseline checkpoint.
            self.best_score = score
            self.save_checkpoint(val_acc, model)
        elif score <= self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_acc, model)
            self.counter = 0

    def save_checkpoint(self, val_acc, model):
        """Write ``model.state_dict()`` to ``self.path`` and remember ``val_acc``."""
        if self.verbose:
            print(f'Validation accuracy increased ({self.val_accuracy_max:.6f} --> {val_acc:.6f}). Saving model...')
        torch.save(model.state_dict(), self.path)
        self.val_accuracy_max = val_acc

from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

def test(model, loader, device):
    """Run ``model`` over ``loader`` and collect labels and predictions.

    The model is switched to eval mode and called as ``model(imgs, metas)``
    under ``torch.no_grad()``; the prediction for each sample is the index of
    its largest output along dimension 1.

    Returns:
        Tuple ``(all_labels, all_preds)`` of flat Python lists.
    """
    model.eval()
    truths, guesses = [], []

    with torch.no_grad():
        for batch in loader:
            imgs, metas, labels = (tensor.to(device) for tensor in batch)
            logits = model(imgs, metas)
            guesses.extend(logits.argmax(dim=1).cpu().numpy())
            truths.extend(labels.cpu().numpy())

    return truths, guesses

def train(model, train_loader, optimizer, criterion, device):
    """Train ``model`` for one epoch.

    Args:
        model: Module called as ``model(images, meta)``.
        train_loader: Iterable of ``(images, meta, labels)`` batches.
        optimizer: Optimizer stepping ``model``'s parameters.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(mean of the per-batch losses, fraction of correct predictions)``.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc='Training')
    for batches_seen, (images, meta, labels) in enumerate(pbar, start=1):
        images, meta, labels = images.to(device), meta.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images, meta)
        loss = criterion(outputs, labels)
        loss.backward()

        # Clip the global gradient norm to 1.0 before every optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        # `running_loss` accumulates one loss value per batch, so the running
        # average divides by batches seen (not samples), matching the value
        # returned at the end of the epoch.
        pbar.set_postfix({
            'loss': f'{running_loss/batches_seen:.4f}',
            'acc': f'{100.*correct/total:.2f}%'
        })

    return running_loss/len(train_loader), correct/total

def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without updating any weights.

    Returns:
        Tuple ``(mean of the per-batch losses, fraction of correct predictions)``.
    """
    model.eval()
    loss_sum, hits, seen = 0.0, 0, 0

    with torch.no_grad():
        for batch in val_loader:
            images, meta, labels = [tensor.to(device) for tensor in batch]

            logits = model(images, meta)
            loss_sum += criterion(logits, labels).item()

            hits += (logits.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)

    return loss_sum / len(val_loader), hits / seen

def train_model_with_scheduler_and_checkpoint(
    model, train_loader, val_loader, optimizer, criterion, device, 
    epochs=20, patience=5, scheduler_patience=5, checkpoint_dir='checkpoints'):
    """Train with LR scheduling, early stopping and best-model checkpointing.

    Validation accuracy drives both the ``ReduceLROnPlateau`` scheduler and the
    ``EarlyStopping`` helper. After training, the best checkpoint is loaded
    back into ``model`` and the training curves are plotted.

    Args:
        model: Module called as ``model(images, meta)``.
        train_loader: Training batches of ``(images, meta, labels)``.
        val_loader: Validation batches of ``(images, meta, labels)``.
        optimizer: Optimizer for ``model``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device used for training and for restoring the checkpoint.
        epochs: Maximum number of epochs.
        patience: Early-stopping patience, in epochs.
        scheduler_patience: Scheduler patience, in epochs.
        checkpoint_dir: Directory that receives ``best_model.pt``.

    Returns:
        Tuple ``(model, history)``; ``history`` maps ``train_loss``,
        ``val_loss``, ``train_acc``, ``val_acc`` and ``lr`` to per-epoch lists.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pt')

    early_stopping = EarlyStopping(
        patience=patience,
        verbose=True,
        path=checkpoint_path
    )
    # `verbose` is deprecated on PyTorch LR schedulers, so LR reductions are
    # reported manually inside the loop instead.
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='max',  # accuracy is monitored, so larger is better
        patience=scheduler_patience,
        factor=0.1,
        min_lr=1e-6
    )

    history = {
        'train_loss': [], 'val_loss': [],
        'train_acc': [], 'val_acc': [],
        'lr': []
    }

    best_model_epoch = None

    for epoch in range(epochs):
        print(f'\nEpoch {epoch+1}/{epochs}')

        train_loss, train_acc = train(model, train_loader, optimizer, criterion, device)
        val_loss, val_acc = validate(model, val_loader, criterion, device)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)
        history['lr'].append(optimizer.param_groups[0]['lr'])

        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

        # Step the scheduler on validation accuracy and report any LR change.
        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_acc)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after != lr_before:
            print(f'Learning rate reduced: {lr_before:.2e} --> {lr_after:.2e}')

        early_stopping(val_acc, model)
        # The counter is zero exactly when this epoch wrote a checkpoint.
        # Comparing against `val_accuracy_max` after the call can never
        # succeed, because the call has already raised it to `val_acc`.
        if early_stopping.counter == 0:
            best_model_epoch = epoch + 1

        if early_stopping.early_stop:
            print("Early stopping triggered")
            break

    # Restore the best weights. Skipped when no epoch ran, so a stale file
    # from an earlier run is never loaded by accident.
    if best_model_epoch is not None:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    plot_training_curves_with_checkpoint(history, best_model_epoch)

    return model, history

def plot_training_curves_with_checkpoint(history, best_model_epoch):
    """Draw loss, accuracy and learning-rate curves side by side.

    Args:
        history: Dict with per-epoch lists under ``train_loss``, ``val_loss``,
            ``train_acc``, ``val_acc`` and ``lr``.
        best_model_epoch: 1-based epoch marked with a dashed red vertical
            line on every panel; falsy values draw no marker.
    """
    x_axis = range(1, len(history['train_loss']) + 1)

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # One entry per panel: (history key, legend label) pairs, y-label, title
    # and an optional y-axis scale.
    panels = (
        (
            (('train_loss', 'Training Loss'), ('val_loss', 'Validation Loss')),
            'Loss', 'Training and Validation Loss', None,
        ),
        (
            (('train_acc', 'Training Accuracy'), ('val_acc', 'Validation Accuracy')),
            'Accuracy', 'Training and Validation Accuracy', None,
        ),
        (
            (('lr', 'Learning Rate'),),
            'Learning Rate', 'Learning Rate Schedule', 'log',
        ),
    )

    for axis, (series, y_label, title, y_scale) in zip(axes, panels):
        for key, legend_label in series:
            axis.plot(x_axis, history[key], label=legend_label)
        if best_model_epoch:
            axis.axvline(best_model_epoch, color='r', linestyle='--', label='Best Model')
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        axis.set_title(title)
        if y_scale is not None:
            axis.set_yscale(y_scale)
        axis.legend()

    plt.tight_layout()
    plt.show()

<h1>MobileViT</h1>

In [None]:
import torch
import torch.nn as nn
import timm  
import torch.nn.functional as F

class EarlyFusionModel(nn.Module):
    """Early-fusion classifier feeding RGB + a metadata channel to MobileViT-S.

    The metadata vector is embedded into a 64x64 single-channel map, upsampled
    to the image resolution and concatenated to the image as a fourth input
    channel of a pretrained ``mobilevit_s.cvnets_in1k`` model.

    Args:
        input_dim_meta: Number of metadata features per sample.
        num_classes: Number of output classes.
    """

    def __init__(self, input_dim_meta, num_classes):
        super().__init__()

        # Embed metadata into a flattened 64x64 map.
        self.meta_embed = nn.Sequential(
            nn.Linear(input_dim_meta, 64 * 64),
            nn.ReLU(),
            nn.BatchNorm1d(64 * 64),
            nn.Dropout(0.3)
        )

        self.mobilevit = timm.create_model("mobilevit_s.cvnets_in1k", pretrained=True, num_classes=num_classes)

        # Replace the stem conv with a 4-channel equivalent. `bias` must be a
        # bool: passing the bias tensor itself fails as soon as it is not None.
        first_conv = self.mobilevit.stem.conv
        new_conv = nn.Conv2d(4, first_conv.out_channels,
                             kernel_size=first_conv.kernel_size,
                             stride=first_conv.stride,
                             padding=first_conv.padding,
                             bias=first_conv.bias is not None)

        with torch.no_grad():
            # Keep the pretrained RGB filters.
            new_conv.weight[:, :3] = first_conv.weight
            # Start the metadata channel small so it does not dominate.
            new_conv.weight[:, 3:] = first_conv.weight.mean(dim=1, keepdim=True) * 0.1
            if first_conv.bias is not None:
                new_conv.bias.copy_(first_conv.bias)

        self.mobilevit.stem.conv = new_conv

    def forward(self, img, meta):
        """Return class logits for a batch of images and metadata rows."""
        batch_size = img.shape[0]
        meta_reshaped = self.meta_embed(meta).view(batch_size, 1, 64, 64)

        # Upsample to the image's own spatial size (224x224 in this notebook)
        # so the two tensors can be concatenated channel-wise.
        meta_upsampled = F.interpolate(meta_reshaped,
                                       size=img.shape[-2:],
                                       mode='bilinear',
                                       align_corners=False)

        # Early fusion: metadata becomes the fourth input channel.
        combined_input = torch.cat([img, meta_upsampled], dim=1)

        return self.mobilevit(combined_input)

# Assuming X_train_meta and other variables are defined
input_dim_meta = X_train_meta.shape[1]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = EarlyFusionModel(input_dim_meta, num_classes).to(device)

# Print a layer-by-layer summary for a dummy batch of 16 samples.
from torchinfo import summary
summary(model=model, 
        input_size=[(16, 3, 224, 224), (16, input_dim_meta)],
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

mobilevit_model = EarlyFusionModel(input_dim_meta=input_dim_meta, num_classes=num_classes).to(device)
# `map_location` lets a checkpoint saved on a GPU load on a CPU-only machine.
mobilevit_model.load_state_dict(
    torch.load('D:\\Dermp7\\best_early_fusion_mobilevitsmoteDA.pth', map_location=device)
)

# Evaluate the fine-tuned MobileViT model on the test split.
true_labels, pred_labels = test(mobilevit_model, test_loader, device)

report = classification_report(true_labels, pred_labels, digits=4,target_names=class_names)
print("Classification Report:")
print(report)
cm = confusion_matrix(true_labels, pred_labels)

<h1>PvtV2</h1>

In [None]:
import torch
import torch.nn as nn
import timm
import torch.nn.functional as F

class EarlyFusionModel(nn.Module):
    """Early-fusion classifier feeding RGB + a metadata channel to PVT v2-B1.

    The metadata vector is embedded into a 56x56 single-channel map, upsampled
    to the image resolution and concatenated to the image as a fourth input
    channel of a pretrained ``pvt_v2_b1`` model.

    Args:
        input_dim_meta: Number of metadata features per sample.
        num_classes: Number of output classes.
    """

    def __init__(self, input_dim_meta, num_classes):
        super().__init__()

        # Embed metadata into a flattened 56x56 map.
        self.meta_embed = nn.Sequential(
            nn.Linear(input_dim_meta, 56 * 56),
            nn.ReLU(),
            nn.BatchNorm1d(56 * 56),
            nn.Dropout(0.3)
        )

        self.pvt = timm.create_model("pvt_v2_b1", pretrained=True, num_classes=num_classes)

        # Replace the patch-embedding conv with a 4-channel equivalent.
        first_conv = self.pvt.patch_embed.proj
        new_conv = nn.Conv2d(4, first_conv.out_channels,
                             kernel_size=first_conv.kernel_size,
                             stride=first_conv.stride,
                             padding=first_conv.padding,
                             bias=first_conv.bias is not None)

        with torch.no_grad():
            # Keep the pretrained RGB filters.
            new_conv.weight[:, :3] = first_conv.weight
            # Start the metadata channel small so it does not dominate.
            new_conv.weight[:, 3:] = first_conv.weight.mean(dim=1, keepdim=True) * 0.1
            # Carry over the pretrained bias too; otherwise the new layer
            # would keep a freshly initialised random bias.
            if first_conv.bias is not None:
                new_conv.bias.copy_(first_conv.bias)

        self.pvt.patch_embed.proj = new_conv

    def forward(self, img, meta):
        """Return class logits for a batch of images and metadata rows."""
        batch_size = img.shape[0]
        meta_reshaped = self.meta_embed(meta).view(batch_size, 1, 56, 56)

        # Upsample to the image's own spatial size (224x224 in this notebook)
        # so the two tensors can be concatenated channel-wise.
        meta_upsampled = F.interpolate(meta_reshaped,
                                       size=img.shape[-2:],
                                       mode='bilinear',
                                       align_corners=False)

        # Early fusion: metadata becomes the fourth input channel.
        combined_input = torch.cat([img, meta_upsampled], dim=1)

        return self.pvt(combined_input)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = EarlyFusionModel(input_dim_meta, num_classes).to(device)

# Print a layer-by-layer summary for a dummy batch of 16 samples.
from torchinfo import summary
summary(model=model, 
        input_size=[(16, 3, 224, 224), (16, input_dim_meta)],  
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

pv2_model = EarlyFusionModel(input_dim_meta=input_dim_meta, num_classes=num_classes).to(device)
# `map_location` lets a checkpoint saved on a GPU load on a CPU-only machine.
pv2_model.load_state_dict(
    torch.load('D:\\Dermp7\\best_early_fusion_pvtv2smoteDA.pth', map_location=device)
)

# Evaluate the fine-tuned PVT v2 model on the test split.
true_labels, pred_labels = test(pv2_model, test_loader, device)

report = classification_report(true_labels, pred_labels, digits=4,target_names=class_names)
print("Classification Report:")
print(report)
cm = confusion_matrix(true_labels, pred_labels)

<h1>Teacher Model (Mean Averaging)</h1>

In [None]:
import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

class TeacherModel(nn.Module):
    """Frozen ensemble of trained models used as a distillation teacher.

    Args:
        models (list): Trained ``nn.Module`` instances, each called as
            ``model(img, meta)``.
        ensemble_method (str): ``"mean"`` averages the members' outputs;
            ``"vote"`` returns the majority-voted class index per sample.

    Raises:
        ValueError: If ``ensemble_method`` is neither ``"mean"`` nor ``"vote"``.
    """

    _METHODS = ("mean", "vote")

    def __init__(self, models, ensemble_method="mean"):
        super().__init__()
        if ensemble_method not in self._METHODS:
            raise ValueError(
                f"ensemble_method must be one of {self._METHODS}, got {ensemble_method!r}"
            )
        # A ModuleList registers the members as sub-modules, so `.to(device)`
        # on the teacher moves them too; a plain list would be skipped.
        self.models = nn.ModuleList(models)
        self.ensemble_method = ensemble_method

        # Freeze every member: eval mode and no gradients.
        for model in self.models:
            model.eval()
            for param in model.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Set the teacher's mode while keeping every member in eval mode."""
        super().train(mode)
        for model in self.models:
            model.eval()
        return self

    def forward(self, img, meta):
        """Run every member on the batch and combine their outputs.

        Args:
            img (torch.Tensor): Batch of images.
            meta (torch.Tensor): Batch of metadata.

        Returns:
            torch.Tensor: For ``"mean"``, the element-wise average of the
            members' raw outputs (same shape as one member's output; no
            softmax is applied). For ``"vote"``, a 1-D tensor holding the
            most frequent predicted class index per sample.
        """
        with torch.no_grad():  # the teacher never needs gradients
            member_outputs = [model(img, meta) for model in self.models]

        # Shape: [num_models, batch_size, num_classes]
        stacked = torch.stack(member_outputs, dim=0)

        if self.ensemble_method == "mean":
            return stacked.mean(dim=0)

        # "vote": each member's argmax class, then the mode across members.
        predictions = stacked.argmax(dim=2)
        return predictions.mode(dim=0).values

# Build the mean-averaging ensemble and place it on the evaluation device.
teacher_model = TeacherModel(
    models=[mobilevit_model, pv2_model],
    ensemble_method="mean",
).to(device)

# Score the ensemble on the test split.
true_labels, pred_labels = test(teacher_model, test_loader, device)

cm = confusion_matrix(true_labels, pred_labels)
report = classification_report(
    true_labels, pred_labels, digits=4, target_names=class_names
)
print("Classification Report:")
print(report)

<h1>Knowledge Distillation on Student Model</h1>

In [None]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import timm
from torch_geometric.nn import GCNConv
from torch_cluster import knn_graph
from torchinfo import summary


# -----------------------------
# Early Fusion with Dynamic GCN
# -----------------------------
class EarlyFusionWithDynamicGCN(nn.Module):
    """Student model: GCN-refined metadata fused early into a MobileViT backbone.

    Metadata rows pass through two GCN layers (with a projected residual) over
    a kNN graph built per forward pass. The result is turned into a 56x56
    pseudo-image, upsampled and concatenated to the RGB image as a fourth
    input channel. A pooled MobileViT feature map feeds an MLP classifier.

    Args:
        input_dim_meta: Number of metadata features per sample.
        num_classes: Number of output classes.
        backbone: ``"xxs"`` or ``"s"`` MobileViT variant.
        k: Number of neighbours requested from ``knn_graph``.

    Raises:
        ValueError: If ``backbone`` is not ``"xxs"`` or ``"s"``.
    """

    def __init__(self, input_dim_meta, num_classes, backbone="xxs", k=8):
        super().__init__()
        self.k = k

        # --------- GCN branch ----------
        self.gcn1 = GCNConv(input_dim_meta, 64)
        self.gcn2 = GCNConv(64, 32)
        self.res_proj = nn.Linear(64, 32)  # residual path 64 -> 32

        # Metadata → pseudo-image
        self.meta_to_image = nn.Sequential(
            nn.Linear(32, 56*56),
            nn.ReLU(),
            nn.BatchNorm1d(56*56),
            nn.Dropout(0.3)
        )

        # --------- MobileViT backbone ----------
        if backbone == "xxs":
            model_name = "mobilevit_xxs.cvnets_in1k"
            self.out_channels = 320
        elif backbone == "s":
            model_name = "mobilevit_s.cvnets_in1k"
            self.out_channels = 640
        else:
            raise ValueError("backbone must be 'xxs' or 's'")

        self.mobilevit = timm.create_model(
            model_name,
            pretrained=True,
            num_classes=0,
            global_pool=''   # keep features, no built-in pooling
        )

        # Modify stem for 4-channel input
        stem_conv = self.mobilevit.stem.conv
        new_conv = nn.Conv2d(
            4, stem_conv.out_channels,
            kernel_size=stem_conv.kernel_size,
            stride=stem_conv.stride,
            padding=stem_conv.padding,
            bias=stem_conv.bias is not None
        )

        with torch.no_grad():
            # copy RGB weights
            new_conv.weight[:, :3] = stem_conv.weight
            # tiny weight for metadata channel
            new_conv.weight[:, 3:] = stem_conv.weight.mean(dim=1, keepdim=True) * 0.1
            # Copy the bias in place: assigning a plain tensor to a registered
            # parameter attribute raises TypeError.
            if stem_conv.bias is not None:
                new_conv.bias.copy_(stem_conv.bias)

        self.mobilevit.stem.conv = new_conv

        # --------- Classifier ----------
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.out_channels, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, img, meta, batch_idx):
        """Return class logits.

        Args:
            img: Image batch with 3 channels.
            meta: Metadata batch, one row per image.
            batch_idx: Graph-membership vector passed to ``knn_graph`` as
                ``batch``; neighbours are only searched among rows sharing
                the same value.
        """
        B = meta.size(0)

        # Dynamic, batch-aware kNN graph over the metadata rows
        # (torch_cluster.knn_graph by default excludes self-loops).
        edge_index = knn_graph(meta, k=self.k, batch=batch_idx)

        # ----- GCN with residual -----
        x1 = F.relu(self.gcn1(meta, edge_index))
        x2 = F.relu(self.gcn2(x1, edge_index) + self.res_proj(x1))
        x_meta = x2  # [B, 32]

        # ----- metadata → pseudo-image -----
        meta_img = self.meta_to_image(x_meta).view(B, 1, 56, 56)
        # Match the image's own spatial size (224x224 in this notebook).
        meta_img = F.interpolate(meta_img, size=img.shape[-2:], mode='bilinear', align_corners=False)

        # ----- early fusion -----
        x = torch.cat([img, meta_img], dim=1)  # [B, 4, H, W]

        # ----- MobileViT forward -----
        features = self.mobilevit(x)          # [B, C, H', W'] (no global_pool)
        features = self.pool(features).view(B, -1)

        return self.classifier(features)


# Number of metadata features per sample.
input_dim_meta = X_train_meta.shape[1]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size=16
# Student model on the "xxs" MobileViT backbone.
model = EarlyFusionWithDynamicGCN(input_dim_meta, num_classes, backbone="xxs").to(device)

# Random dummy inputs, used only to trace shapes for the summary below.
dummy_img = torch.randn(batch_size, 3, 224, 224).to(device)
dummy_meta = torch.randn(batch_size, input_dim_meta).to(device)
# NOTE(review): every sample gets a distinct graph id here, so the batch-aware
# kNN search presumably finds no neighbours across samples — confirm this is
# the intended graph construction.
dummy_batch_idx = torch.arange(batch_size, device=device)  # one node per graph

summary(
        model,
        input_data=[dummy_img, dummy_meta, dummy_batch_idx],
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        depth=3
    )


In [None]:
# NOTE(review): EarlyFusionMobileViTXXS is not defined in the cells shown
# here — confirm it is defined elsewhere before running this cell.
model_xxs = EarlyFusionMobileViTXXS(input_dim_meta, num_classes).to(device)
# Student architecture built on the "s" MobileViT backbone.
model_s = EarlyFusionWithDynamicGCN(input_dim_meta, num_classes, backbone="s").to(device)

# model_proposed = ...   # Only include if defined


import torch.nn as nn

def count_mobilevit_blocks_and_layers(backbone):
    """Report the MobileViT blocks in ``backbone`` and their transformer depths.

    A sub-module counts as a MobileViT block when its class name contains
    ``"mobilevitblock"`` (case-insensitive). A block's depth is the length of
    ``transformer.layers`` or ``transformer.blocks`` when that attribute is a
    ``ModuleList``, otherwise the number of direct children of
    ``transformer``; blocks without a ``transformer`` attribute count as 0.

    Returns:
        Tuple ``(blocks, total_layers)`` where ``blocks`` is a list of
        ``(qualified_name, module)`` pairs.
    """

    def _transformer_depth(block):
        # Depth of one block's transformer stack, per the rules above.
        if not hasattr(block, "transformer"):
            return 0
        tr = block.transformer
        for attr in ("layers", "blocks"):
            stack = getattr(tr, attr, None)
            if isinstance(stack, nn.ModuleList):
                return len(stack)
        return len(list(tr.children()))

    print("\n===== Counting MobileViT Blocks and Transformer Layers (cvnets MobileViT) =====")

    # 1) Collect every module whose class name marks it as a MobileViT block.
    blocks = [
        (name, module)
        for name, module in backbone.named_modules()
        if "mobilevitblock" in module.__class__.__name__.lower()
    ]

    print(f"\nFound {len(blocks)} MobileViT Blocks:")
    for i, (name, module) in enumerate(blocks):
        print(f"  Block #{i}: {name} --> {module.__class__.__name__}")

    # 2) Report the transformer depth of each block and accumulate the total.
    print("\nPer-block transformer depths:")
    total_layers = 0
    for name, block in blocks:
        depth = _transformer_depth(block)
        total_layers += depth
        print(f"  - {name}: {depth} transformer layers")

    print(f"\nTotal transformer layers across all MobileViT blocks: {total_layers}")
    return blocks, total_layers


# -------------------------
# RUN FOR XXS
# -------------------------
blocks_xxs, total_layers_xxs = count_mobilevit_blocks_and_layers(model_xxs.backbone)
print("\nXXS Total MobileViT Blocks:", len(blocks_xxs))
print("XXS Total Transformer Layers:", total_layers_xxs)


# -------------------------
# RUN FOR S
# -------------------------
blocks_s, total_layers_s = count_mobilevit_blocks_and_layers(model_s.mobilevit)
print("\nS Total MobileViT Blocks:", len(blocks_s))
print("S Total Transformer Layers:", total_layers_s)


# -------------------------
# RUN FOR PROPOSED (optional)
# -------------------------
# `model_proposed` only exists when the user has defined it, so skip this
# section instead of failing with NameError when it is absent.
if "model_proposed" in globals():
    blocks_p, total_layers_p = count_mobilevit_blocks_and_layers(model_proposed.mobilevit)
    print("\nProposed Total MobileViT Blocks:", len(blocks_p))
    print("Proposed Total Transformer Layers:", total_layers_p)
else:
    print("\nProposed model not defined; skipping its block count.")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score


# =========================================================
# Teacher Model (Assumed Already Defined and Loaded)
# =========================================================
# Rebuild the ensemble teacher from the two loaded models (mean of their
# outputs) and move it to the active device.
teacher_model = TeacherModel(models=[mobilevit_model, pv2_model], ensemble_method="mean").to(device)


# =========================================================
# Utility Functions
# =========================================================
def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random  # local import: this cell does not import `random` itself

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def evaluate_student(model, test_loader, device):
    """Evaluate a student model and print macro-averaged test metrics.

    Args:
        model: Module called as ``model(images, metas, batch_indices)``.
        test_loader: Iterable of ``(images, metas, labels)`` batches.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(accuracy, f1, precision, recall)``; F1, precision and recall
        are macro-averaged.
    """
    model.eval()
    all_labels, all_preds = [], []

    with torch.no_grad():
        for images, metas, labels in test_loader:
            images, metas, labels = images.to(device), metas.to(device), labels.to(device)
            # NOTE(review): every sample receives its own graph id — confirm
            # this matches how the student was trained.
            batch_indices = torch.arange(metas.size(0), device=device, dtype=torch.long)

            outputs = model(images, metas, batch_indices)
            preds = outputs.argmax(dim=1)

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)

    accuracy = accuracy_score(all_labels, all_preds)
    # `zero_division=0` on all three metrics: a class that is never predicted
    # scores 0 without emitting a warning, consistently across metrics.
    f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    precision = precision_score(all_labels, all_preds, average='macro', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='macro', zero_division=0)

    print(f"Test Accuracy:  {accuracy:.4f}")
    print(f"Test F1 Score:  {f1:.4f}")
    print(f"Test Precision: {precision:.4f}")
    print(f"Test Recall:    {recall:.4f}")

    return accuracy, f1, precision, recall


# =========================================================
# Averaged Training Curve Plotting (handles early stopping)
# =========================================================
def plot_average_training_curves(all_histories, save_path="averaged_training_curves.png"):
    """Plot mean ± std loss and accuracy curves across runs and save the figure.

    Args:
        all_histories: Dict with keys ``"train_loss"``, ``"val_loss"``,
            ``"train_acc"`` and ``"val_acc"``; each value holds one per-epoch
            sequence per run.
        save_path: File the figure is written to before it is shown.
    """
    curve_keys = ("train_loss", "val_loss", "train_acc", "val_acc")
    per_run = {key: [np.array(run) for run in all_histories[key]] for key in curve_keys}

    # Runs that stopped early are shorter, so every run is cut to the
    # shortest train-loss history before stacking.
    n_epochs = min(len(run) for run in per_run["train_loss"])
    stacked = {
        key: np.stack([run[:n_epochs] for run in runs])
        for key, runs in per_run.items()
    }
    x_axis = np.arange(1, n_epochs + 1)

    # Global bold styling (this modifies matplotlib's rcParams).
    plt.rcParams.update({
        'font.weight': 'bold',
        'axes.labelweight': 'bold',
        'axes.titleweight': 'bold',
        'xtick.labelsize': 12,
        'ytick.labelsize': 12,
    })

    plt.figure(figsize=(12, 5))

    # (subplot slot, [(history key, legend label), ...], y label, title)
    panels = (
        (1,
         (("train_loss", "Train Loss"), ("val_loss", "Validation Loss")),
         "Loss",
         "Training and Validation Loss (Averaged Across Runs)"),
        (2,
         (("train_acc", "Train Accuracy"), ("val_acc", "Validation Accuracy")),
         "Accuracy",
         "Training and Validation Accuracy (Averaged Across Runs)"),
    )

    for slot, series, y_label, title in panels:
        plt.subplot(1, 2, slot)
        for key, legend_label in series:
            center = stacked[key].mean(axis=0)
            spread = stacked[key].std(axis=0)
            plt.plot(x_axis, center, label=legend_label)
            plt.fill_between(x_axis, center - spread, center + spread, alpha=0.25)

        plt.xlabel("Epochs", fontweight="bold")
        plt.ylabel(y_label, fontweight="bold")
        plt.title(title, fontweight="bold")
        plt.legend()

    plt.tight_layout()
    plt.savefig(save_path, dpi=650, bbox_inches="tight")
    plt.show()


# =========================================================
# Training with Knowledge Distillation
# =========================================================
def train_student_model_kd(student_model, teacher_model,
                           train_loader, val_loader, test_loader,
                           device, alpha=0.5, temperature=3.0,
                           epochs=100, patience=10):
    """Train ``student_model`` with knowledge distillation from ``teacher_model``.

    Each training step combines cross-entropy on the true labels with a
    temperature-scaled KL divergence between the student and teacher
    distributions: ``(1 - alpha) * hard + alpha * temperature**2 * soft``.
    Validation uses cross-entropy only. Training stops early once validation
    accuracy has not improved for ``patience`` consecutive epochs.

    Args:
        student_model: Model called as ``student_model(images, metas, batch_indices)``.
        teacher_model: Model called as ``teacher_model(images, metas)``; put in
            eval mode and only used under ``torch.no_grad()``.
        train_loader: Iterable of ``(images, metas, labels)`` training batches.
        val_loader: Iterable of ``(images, metas, labels)`` validation batches.
        test_loader: Unused here; kept so existing callers keep working.
        device: Device the student and every batch are moved to.
        alpha: Weight of the distillation term.
        temperature: Softmax temperature of the distillation term.
        epochs: Maximum number of epochs.
        patience: Epochs without improvement tolerated before stopping.

    Returns:
        ``(best_state, history, best_val_accuracy)``. ``best_state`` is a
        detached copy of the student's ``state_dict`` taken at its best
        validation accuracy (``None`` only when ``epochs`` is 0). ``history``
        maps ``"train_loss"``, ``"val_loss"``, ``"train_acc"`` and
        ``"val_acc"`` to per-epoch lists.
    """
    student_model.to(device)
    teacher_model.eval()

    criterion = nn.CrossEntropyLoss()
    kl_div_loss = nn.KLDivLoss(reduction='batchmean')
    optimizer = torch.optim.Adam(student_model.parameters(), lr=0.001)

    # Early stopping tracks validation accuracy, so the scheduler does too
    # (mode='max', stepped with val_accuracy below).
    scheduler = ReduceLROnPlateau(optimizer, mode='max', patience=5, verbose=True)

    best_val_accuracy = 0.0
    best_val_model_state = None
    patience_counter = 0

    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}

    for epoch in range(epochs):
        # ---- Training ----
        student_model.train()
        train_loss_sum, correct, total = 0.0, 0, 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

        for images, metas, labels in pbar:
            images, metas, labels = images.to(device), metas.to(device), labels.to(device)
            batch_indices = torch.arange(metas.size(0), device=device, dtype=torch.long)

            optimizer.zero_grad()

            student_outputs = student_model(images, metas, batch_indices)

            # The teacher only supplies targets; no gradients are needed.
            with torch.no_grad():
                teacher_outputs = teacher_model(images, metas)

            # Hard loss: cross-entropy against the ground-truth labels.
            loss_hard = criterion(student_outputs, labels)

            # Soft loss: KL divergence between temperature-softened distributions.
            loss_soft = kl_div_loss(
                F.log_softmax(student_outputs / temperature, dim=1),
                F.softmax(teacher_outputs / temperature, dim=1)
            )

            loss = (1 - alpha) * loss_hard + alpha * (temperature ** 2) * loss_soft

            loss.backward()
            optimizer.step()

            train_loss_sum += loss.item()
            _, predicted = student_outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        train_loss = train_loss_sum / len(train_loader)
        train_accuracy = correct / total
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_accuracy)

        # ---- Validation ----
        student_model.eval()
        val_loss_sum, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for images, metas, labels in val_loader:
                images, metas, labels = images.to(device), metas.to(device), labels.to(device)
                batch_indices = torch.arange(metas.size(0), device=device, dtype=torch.long)

                outputs = student_model(images, metas, batch_indices)
                loss = criterion(outputs, labels)
                val_loss_sum += loss.item()

                _, predicted = outputs.max(1)
                correct += predicted.eq(labels).sum().item()
                total += labels.size(0)

        val_loss = val_loss_sum / len(val_loader)
        val_accuracy = correct / total
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_accuracy)

        print(f"Epoch {epoch+1}: "
              f"Train Loss={train_loss:.4f}, Train Acc={train_accuracy:.4f} | "
              f"Val Loss={val_loss:.4f}, Val Acc={val_accuracy:.4f}")

        # ---- Early stopping on validation accuracy ----
        # The `is None` test guarantees a usable state is returned even if
        # validation accuracy never rises above 0.0.
        if best_val_model_state is None or val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            # state_dict() returns references to the live tensors, so they are
            # cloned; otherwise later optimizer steps would overwrite the
            # "best" weights with whatever the model ends up with.
            best_val_model_state = {
                key: value.detach().clone()
                for key, value in student_model.state_dict().items()
            }
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print("Early stopping triggered.")
            break

        # LR schedule follows the same metric as early stopping.
        scheduler.step(val_accuracy)

    return best_val_model_state, history, best_val_accuracy


# =========================================================
# Main Experiment Loop (MULTI-SEED RUNS)
# =========================================================
# One full distillation run per seed; test metrics and training curves are
# collected per run, and the model with the best validation accuracy is kept.
seeds = [42, 123, 569]
best_overall_model = None
best_overall_accuracy = 0.0

results = {"accuracy": [], "f1": [], "precision": [], "recall": []}
all_histories = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}

for seed in seeds:
    print(f"\n--- Training with Seed {seed} ---")
    set_seed(seed)

    # Fresh "xxs" student for every seed.
    student_model = EarlyFusionWithDynamicGCN(input_dim_meta, num_classes, backbone="xxs").to(device)

    best_model_state, history, val_acc = train_student_model_kd(
        student_model, teacher_model,
        train_loader, val_loader, test_loader,
        device,
        alpha=0.2, temperature=9.0,
        epochs=100, patience=10,
    )

    # Keep this run's per-epoch curves for the averaged plot.
    for curve_name in all_histories:
        all_histories[curve_name].append(history[curve_name])

    # Test the weights returned by training.
    student_model.load_state_dict(best_model_state)
    acc, f1, prec, recall = evaluate_student(student_model, test_loader, device)

    # `results` keys are in the same order as the returned metrics.
    for metric, score in zip(results, (acc, f1, prec, recall)):
        results[metric].append(score)

    # Model selection across seeds uses validation accuracy, not test accuracy.
    if val_acc > best_overall_accuracy:
        best_overall_accuracy, best_overall_model = val_acc, student_model

# Persist the best student across seeds.
torch.save(best_overall_model.state_dict(), "Lightstudent.pth")
print(f"\nBest Val Accuracy Model Saved (Acc={best_overall_accuracy:.4f})")

# Mean and standard deviation of each test metric over the seeds.
print("\n--- Final Evaluation Across Seeds ---")
for metric in results:
    print(f"{metric.capitalize()}: {np.mean(results[metric]):.4f} ¬± {np.std(results[metric]):.4f}")

# Averaged training curves over all seeds.
plot_average_training_curves(all_histories)


In [None]:
def plot_average_training_curves(all_histories, save_path="averaged_training_curves.png"):
    """Plot mean ± std loss and accuracy curves across runs.

    Args:
        all_histories: Dict with keys ``"train_loss"``, ``"val_loss"``,
            ``"train_acc"`` and ``"val_acc"``; each value holds one per-epoch
            sequence per run.
        save_path: File the figure is written to. Pass ``None`` or ``""`` to
            only display it.

    Raises:
        ValueError: If ``all_histories`` contains no runs.
    """
    import numpy as np
    import matplotlib.pyplot as plt

    keys = ("train_loss", "val_loss", "train_acc", "val_acc")

    # Plain float arrays per run. Wrapping the runs in a dtype=object array
    # produced an object-dtype stack whenever all runs had the same length,
    # and NumPy cannot take the standard deviation of such an array.
    runs = {key: [np.asarray(run, dtype=float) for run in all_histories[key]] for key in keys}
    if not runs["train_loss"]:
        raise ValueError("all_histories contains no runs to plot")

    # Runs that stopped early are shorter; cut everything to the shortest one.
    min_len = min(len(run) for run in runs["train_loss"])
    curves = {key: np.stack([run[:min_len] for run in runs[key]]) for key in keys}
    epochs = np.arange(1, min_len + 1)

    # ---- Styling (this modifies matplotlib's global rcParams) ----
    plt.figure(figsize=(12, 5))
    plt.rcParams.update({
        "font.weight": "bold",
        "axes.labelweight": "bold",
        "axes.titleweight": "bold",
        "xtick.labelsize": 12,
        "ytick.labelsize": 12
    })

    def _draw(key, label, color):
        # Mean line with a shaded ±1 standard-deviation band.
        mean, std = curves[key].mean(0), curves[key].std(0)
        plt.plot(epochs, mean, label=label, color=color)
        plt.fill_between(epochs, mean - std, mean + std, alpha=0.25, color=color)

    # ---------------- LOSS PLOT ----------------
    plt.subplot(1, 2, 1)
    _draw("train_loss", "Train Loss", "blue")
    _draw("val_loss", "Val Loss", "red")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title("Training & Validation Loss")
    plt.legend()

    # ---------------- ACCURACY PLOT ----------------
    plt.subplot(1, 2, 2)
    _draw("train_acc", "Train Accuracy", "green")
    _draw("val_acc", "Val Accuracy", "orange")
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.title("Training & Validation Accuracy")
    plt.legend()

    plt.tight_layout()
    # save_path used to be accepted but ignored; honour it so the parameter
    # does what its name says.
    if save_path:
        plt.savefig(save_path, dpi=650, bbox_inches="tight")
    plt.show()

# Draw the averaged curves for the runs collected in `all_histories`,
# using the plot_average_training_curves definition from this cell.
plot_average_training_curves(all_histories)


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc, precision_recall_curve
import torch.nn.functional as F

# Load the best student checkpoint written by the multi-seed training cell.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
best_model = EarlyFusionWithDynamicGCN(input_dim_meta, num_classes).to(device)
best_model.load_state_dict(torch.load("Lightstudent.pth", map_location=device))
best_model.eval()

all_labels = []
all_preds = []
all_probs = []

with torch.no_grad():
    for images, metas, labels in test_loader:
        images, metas, labels = images.to(device), metas.to(device), labels.to(device)
        batch_indices = torch.arange(metas.size(0)).to(device).long()
        # Use the checkpoint loaded above. This loop previously called
        # `student_model`, i.e. whichever model an earlier cell left in that
        # variable, so the report did not describe the saved best model.
        outputs = best_model(images, metas, batch_indices)
        probs = F.softmax(outputs, dim=1)
        preds = probs.argmax(dim=1)

        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(preds.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())

all_labels = np.array(all_labels)
all_preds = np.array(all_preds)
all_probs = np.array(all_probs)

# Per-class precision / recall / F1.
class_report = classification_report(all_labels, all_preds, target_names=class_names, digits=4)

# Confusion matrix normalised over the true labels (each row sums to 1).
conf_matrix = confusion_matrix(all_labels, all_preds, normalize="true")

# Display classification report
print("\nClassification Report:\n")
print(class_report)

# Display confusion matrix (black and white)
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, cmap="gray", fmt=".2f", xticklabels=class_names, yticklabels=class_names, cbar=True)
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Normalized Confusion Matrix")
plt.show()

# One-vs-rest ROC curve per class, from that class's predicted probability.
plt.figure(figsize=(10, 6))

for i, class_name in enumerate(class_names):
    fpr, tpr, _ = roc_curve((all_labels == i).astype(int), all_probs[:, i])
    roc_auc = auc(fpr, tpr)
    plt.plot(fpr, tpr, label=f"{class_name} (AUC = {roc_auc:.2f})")

plt.plot([0, 1], [0, 1], "k--")  # Diagonal line for reference
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC-AUC Curve")
plt.legend(loc="lower right")
plt.show()

# One-vs-rest precision-recall curve per class.
plt.figure(figsize=(10, 6))

for i, class_name in enumerate(class_names):
    precision, recall, _ = precision_recall_curve((all_labels == i).astype(int), all_probs[:, i])
    plt.plot(recall, precision, label=f"{class_name}")

plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Precision-Recall Curve")
plt.legend()
plt.show()

In [None]:
import torch
import torch.nn as nn
import time
import numpy as np
from ptflops import get_model_complexity_info

# Benchmark on CUDA when it is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Shape of the dummy image used by the FLOP and latency helpers below
# (presumably channels, height, width — confirm against the model input).
image_size = (3, 224, 224)

# Collects the dict returned by each benchmark_model() call.
# NOTE(review): this rebinds `results`, which the training cell used as a dict of metric lists.
results = []

def count_parameters(model, trainable_only=False):
    """Return the number of scalar parameters in ``model``.

    Args:
        model: Object exposing ``parameters()``.
        trainable_only: When true, only parameters with ``requires_grad`` set
            are counted.
    """
    return sum(
        param.numel()
        for param in model.parameters()
        if not trainable_only or param.requires_grad
    )



class FusionWrapper(nn.Module):
    """Single-input adapter for a model whose forward takes (image, metadata).

    ptflops feeds the wrapped module one image tensor, so the metadata
    argument is filled with random values of width ``meta_dim``.
    """

    def __init__(self, model, meta_dim):
        super().__init__()
        self.meta_dim = meta_dim
        self.model = model

    def forward(self, x):
        # One random metadata row per image, moved to the image's device.
        fake_meta = torch.randn(x.size(0), self.meta_dim)
        return self.model(x, fake_meta.to(x.device))


class GCN2InputWrapper(nn.Module):
    """Expose a three-argument GCN model through a ``forward(img, meta)`` call.

    The third argument (batch indices) is generated here as ``0 .. B-1`` on
    the metadata's device, where ``B`` is the number of metadata rows.
    """

    def __init__(self, gcn_model, meta_dim):
        super().__init__()
        self.meta_dim = meta_dim
        self.gcn_model = gcn_model

    def forward(self, img, meta):
        sample_ids = torch.arange(meta.size(0), device=meta.device)
        return self.gcn_model(img, meta, sample_ids)


def compute_flops(model, meta_dim):
    """Return the ptflops complexity of ``model`` divided by 1e9.

    Args:
        model: Two-input ``(img, meta)`` model; it is wrapped in
            ``FusionWrapper`` so ptflops can drive it with an image only.
        meta_dim: Width of the random metadata the wrapper generates.

    Returns:
        The first value reported by ``get_model_complexity_info`` for an
        input of ``image_size``, scaled to units of 1e9.
    """
    # NOTE(review): ptflops may report multiply-accumulates rather than FLOPs — confirm units.
    single_input_model = FusionWrapper(model, meta_dim).to(device)
    with torch.no_grad():
        raw_count, _params = get_model_complexity_info(
            single_input_model,
            image_size,
            as_strings=False,
            print_per_layer_stat=False,
            verbose=False,
        )
    giga = raw_count / 1e9
    return float(giga)


def measure_gpu_latency(model, meta_dim, runs=200, warmup=30):
    """Time single-sample forward passes on the GPU with CUDA events.

    Args:
        model: Two-input ``(img, meta)`` model; switched to eval mode and
            moved to ``device``.
        meta_dim: Width of the dummy metadata vector.
        runs: Number of timed forward passes.
        warmup: Number of untimed passes executed first.

    Returns:
        ``(mean_ms, std_ms, fps)``, or ``(None, None, None)`` when CUDA is
        not available.
    """
    if not torch.cuda.is_available():
        return None, None, None

    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    model.eval().to(device)

    dummy_img = torch.randn(1, *image_size, device=device)
    dummy_meta = torch.randn(1, meta_dim, device=device)

    times = []
    # Inference timing must not record an autograd graph; compute_flops
    # already runs under no_grad, so the latency numbers now match it.
    with torch.no_grad():
        # Warm-up passes are executed but not timed.
        for _ in range(warmup):
            _ = model(dummy_img, dummy_meta)
        torch.cuda.synchronize()

        start_evt = torch.cuda.Event(enable_timing=True)
        end_evt = torch.cuda.Event(enable_timing=True)

        for _ in range(runs):
            start_evt.record()
            _ = model(dummy_img, dummy_meta)
            end_evt.record()
            # Wait for the GPU so elapsed_time covers the whole forward pass.
            torch.cuda.synchronize()
            times.append(start_evt.elapsed_time(end_evt))  # ms

    times = np.array(times)
    mean = float(times.mean())
    std = float(times.std())
    fps = float(1000.0 / mean)
    return mean, std, fps


def measure_cpu_latency(model, meta_dim, runs=100, warmup=20):
    """Time single-sample forward passes on the CPU with ``time.perf_counter``.

    The model is moved to the CPU and put in eval mode for the measurement,
    then moved back to the device its parameters were on beforehand.

    Args:
        model: Two-input ``(img, meta)`` ``nn.Module``.
        meta_dim: Width of the dummy metadata vector.
        runs: Number of timed forward passes.
        warmup: Number of untimed passes executed first.

    Returns:
        ``(mean_ms, std_ms, fps)``.
    """
    # Remember where the model lived: .cpu() moves it in place, which used to
    # leave the caller's model on the CPU after benchmarking.
    first_param = next(model.parameters(), None)
    original_device = first_param.device if first_param is not None else None

    model_cpu = model.cpu()
    model_cpu.eval()

    dummy_img = torch.randn(1, *image_size)
    dummy_meta = torch.randn(1, meta_dim)

    times = []
    try:
        # No autograd graph during inference timing, matching compute_flops.
        with torch.no_grad():
            for _ in range(warmup):
                _ = model_cpu(dummy_img, dummy_meta)

            for _ in range(runs):
                start = time.perf_counter()
                _ = model_cpu(dummy_img, dummy_meta)
                end = time.perf_counter()
                times.append((end - start) * 1000.0)  # ms
    finally:
        if original_device is not None and original_device.type != "cpu":
            model.to(original_device)

    times = np.array(times)
    mean = float(times.mean())
    std = float(times.std())
    fps = float(1000.0 / mean)
    return mean, std, fps


def load_model(model_class, ckpt_path, name, wrap_gcn=False):
    """Build a model, load its checkpoint and return it ready for inference.

    Args:
        model_class: Callable invoked as ``model_class(input_dim_meta, num_classes)``.
        ckpt_path: Path of the ``state_dict`` checkpoint to load.
        name: Label used in the progress messages.
        wrap_gcn: When true, wrap the model in ``GCN2InputWrapper`` so it is
            called as ``(img, meta)``.

    Returns:
        The model on ``device`` in eval mode.
    """
    print(f"\n[LOAD] {name} from {ckpt_path}")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    # Use input_dim_meta, as the rest of this cell does. The previous global
    # `meta_dim` is not defined anywhere in this cell.
    model = model_class(input_dim_meta, num_classes)
    # Load onto the CPU first so the checkpoint's original device is irrelevant.
    state = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(state)
    if wrap_gcn:
        model = GCN2InputWrapper(model, input_dim_meta)
    model.to(device)
    model.eval()
    print("[OK] Model loaded.")
    return model


def benchmark_model(model, name, meta_dim, count_all_params=False):
    """Print and return size, complexity and latency figures for one model.

    Args:
        model: Ready-to-use two-input ``(img, meta)`` model.
        name: Label used in the printed report and the returned dict.
        meta_dim: Width of the metadata vector passed to the helpers.
        count_all_params: When true, count every parameter; otherwise only
            those with ``requires_grad`` set.

    Returns:
        Dict with the model name, parameter count in millions, complexity in
        units of 1e9, and GPU / CPU latency mean, std and FPS (the GPU
        entries are ``None`` without CUDA).
    """
    print(f"\n======= Benchmarking: {name} =======")

    # Parameter count in millions.
    params_m = count_parameters(model, trainable_only=not count_all_params) / 1e6
    print(f"Params: {params_m:.3f} M")

    flops_g = compute_flops(model, meta_dim)
    print(f"FLOPs: {flops_g:.3f} G")

    gpu_mean, gpu_std, gpu_fps = measure_gpu_latency(model, meta_dim)
    if gpu_mean is None:
        print("GPU Latency: N/A")
    else:
        print(f"GPU Latency: {gpu_mean:.3f} ¬± {gpu_std:.3f} ms  |  FPS: {gpu_fps:.1f}")

    cpu_mean, cpu_std, cpu_fps = measure_cpu_latency(model, meta_dim)
    print(f"CPU Latency: {cpu_mean:.3f} ¬± {cpu_std:.3f} ms  |  FPS: {cpu_fps:.1f}")

    return dict(
        model=name,
        params_M=params_m,
        flops_G=flops_g,
        gpu_latency_mean_ms=gpu_mean,
        gpu_latency_std_ms=gpu_std,
        gpu_fps=gpu_fps,
        cpu_latency_mean_ms=cpu_mean,
        cpu_latency_std_ms=cpu_std,
        cpu_fps=cpu_fps,
    )



# Rebuild the student with the default backbone, load the saved checkpoint on the CPU,
# wrap it so it is called as (img, meta), then benchmark it.
# NOTE(review): this rebinds `student_model` to the wrapped benchmark model.
student_raw = EarlyFusionWithDynamicGCN(input_dim_meta, num_classes=num_classes)
student_state = torch.load("Lightstudent.pth", map_location="cpu")
student_raw.load_state_dict(student_state)
student_model = GCN2InputWrapper(student_raw, input_dim_meta).to(device).eval()
results.append(benchmark_model(student_model, "TabFusion (GCN Student)", input_dim_meta))

<h1>Knowledge Distillation on Student Model with "s" Backbone</h1>

In [None]:
# =========================================================
# Main Experiment Loop (MULTI-SEED RUNS) - "s" BACKBONE
# =========================================================
# One full distillation run per seed; test metrics and training curves are
# collected per run, and the model with the best validation accuracy is kept.
seeds = [42, 123, 569]
best_overall_model_s = None
best_overall_accuracy_s = 0.0

results_s = {"accuracy": [], "f1": [], "precision": [], "recall": []}
all_histories_s = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}

for seed in seeds:
    print(f"\n--- Training with Seed {seed} (s backbone) ---")
    set_seed(seed)

    # Fresh "s" student for every seed.
    student_model_s = EarlyFusionWithDynamicGCN(input_dim_meta, num_classes, backbone="s").to(device)

    best_model_state, history, val_acc = train_student_model_kd(
        student_model=student_model_s,
        teacher_model=teacher_model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        device=device,
        alpha=0.2,
        temperature=9.0,
        epochs=100,
        patience=10
    )

    # Store history for averaged curves
    all_histories_s["train_loss"].append(history["train_loss"])
    all_histories_s["val_loss"].append(history["val_loss"])
    all_histories_s["train_acc"].append(history["train_acc"])
    all_histories_s["val_acc"].append(history["val_acc"])

    # ---- Final Test: evaluate the weights returned by training ----
    student_model_s.load_state_dict(best_model_state)
    acc, f1, prec, recall = evaluate_student(student_model_s, test_loader, device)

    results_s["accuracy"].append(acc)
    results_s["f1"].append(f1)
    results_s["precision"].append(prec)
    results_s["recall"].append(recall)

    # Model selection across seeds uses validation accuracy, not test accuracy.
    if val_acc > best_overall_accuracy_s:
        best_overall_accuracy_s = val_acc
        best_overall_model_s = student_model_s

# ---- Save Best Model ----
torch.save(best_overall_model_s.state_dict(), "Lightstudent_s.pth")
print(f"\nBest Val Accuracy Model Saved (Acc={best_overall_accuracy_s:.4f})")

# ---- Summary ----
print("\n--- Final Evaluation Across Seeds (s backbone) ---")
for metric in results_s:
    print(f"{metric.capitalize()}: {np.mean(results_s[metric]):.4f} ¬± {np.std(results_s[metric]):.4f}")

# ---- PLOT AVERAGED TRAINING CURVES ----
# Explicit file name: the default save_path is the one the "xxs" experiment
# uses, so relying on it could overwrite that figure.
plot_average_training_curves(all_histories_s, save_path="averaged_training_curves_s.png")

In [None]:
# Plot averaged training curves for s backbone
def plot_average_training_curves_s(all_histories, save_path="averaged_training_curves_s.png"):
    """Plot mean ± std loss and accuracy curves for the "s" backbone runs.

    Args:
        all_histories: Dict with keys ``"train_loss"``, ``"val_loss"``,
            ``"train_acc"`` and ``"val_acc"``; each value holds one per-epoch
            sequence per run.
        save_path: File the figure is written to before it is shown.

    Raises:
        ValueError: If ``all_histories`` contains no runs.
    """
    import numpy as np
    import matplotlib.pyplot as plt

    keys = ("train_loss", "val_loss", "train_acc", "val_acc")

    # Plain float arrays per run. Wrapping the runs in a dtype=object array
    # produced an object-dtype stack whenever all runs had the same length,
    # and NumPy cannot take the standard deviation of such an array.
    runs = {key: [np.asarray(run, dtype=float) for run in all_histories[key]] for key in keys}
    if not runs["train_loss"]:
        raise ValueError("all_histories contains no runs to plot")

    # Runs that stopped early are shorter; cut everything to the shortest one.
    min_len = min(len(run) for run in runs["train_loss"])
    curves = {key: np.stack([run[:min_len] for run in runs[key]]) for key in keys}
    epochs = np.arange(1, min_len + 1)

    # ---- Styling (this modifies matplotlib's global rcParams) ----
    plt.figure(figsize=(12, 5))
    plt.rcParams.update({
        "font.weight": "bold",
        "axes.labelweight": "bold",
        "axes.titleweight": "bold",
        "xtick.labelsize": 12,
        "ytick.labelsize": 12
    })

    def _draw(key, label, color):
        # Mean line with a shaded ±1 standard-deviation band.
        mean, std = curves[key].mean(0), curves[key].std(0)
        plt.plot(epochs, mean, label=label, color=color)
        plt.fill_between(epochs, mean - std, mean + std, alpha=0.25, color=color)

    # ---------------- LOSS PLOT ----------------
    plt.subplot(1, 2, 1)
    _draw("train_loss", "Train Loss", "blue")
    _draw("val_loss", "Val Loss", "red")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title("Training & Validation Loss (s backbone)")
    plt.legend()

    # ---------------- ACCURACY PLOT ----------------
    plt.subplot(1, 2, 2)
    _draw("train_acc", "Train Accuracy", "green")
    _draw("val_acc", "Val Accuracy", "orange")
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.title("Training & Validation Accuracy (s backbone)")
    plt.legend()

    plt.tight_layout()
    plt.savefig(save_path, dpi=650, bbox_inches="tight")
    plt.show()

# Draw and save the averaged curves for the "s" backbone runs collected in `all_histories_s`.
plot_average_training_curves_s(all_histories_s)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc, precision_recall_curve
import torch.nn.functional as F

# Load the best "s" backbone checkpoint written by the training cell.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine,
# in line with the benchmark cells, which also pass map_location.
best_model_s = EarlyFusionWithDynamicGCN(input_dim_meta, num_classes, backbone="s").to(device)
best_model_s.load_state_dict(torch.load("Lightstudent_s.pth", map_location=device))
best_model_s.eval()

all_labels = []
all_preds = []
all_probs = []

with torch.no_grad():
    for images, metas, labels in test_loader:
        images, metas, labels = images.to(device), metas.to(device), labels.to(device)
        # One index per sample in the batch, as the model's forward expects.
        batch_indices = torch.arange(metas.size(0)).to(device).long()
        outputs = best_model_s(images, metas, batch_indices)
        probs = F.softmax(outputs, dim=1)
        preds = probs.argmax(dim=1)

        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(preds.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())

all_labels = np.array(all_labels)
all_preds = np.array(all_preds)
all_probs = np.array(all_probs)

# Per-class precision / recall / F1.
class_report = classification_report(all_labels, all_preds, target_names=class_names, digits=4)

# Confusion matrix normalised over the true labels (each row sums to 1).
conf_matrix = confusion_matrix(all_labels, all_preds, normalize="true")

# Display classification report
print("\nClassification Report (s backbone):\n")
print(class_report)

# Display confusion matrix (black and white)
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, cmap="gray", fmt=".2f", xticklabels=class_names, yticklabels=class_names, cbar=True)
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Normalized Confusion Matrix (s backbone)")
plt.show()

# One-vs-rest ROC curve per class, from that class's predicted probability.
plt.figure(figsize=(10, 6))

for i, class_name in enumerate(class_names):
    fpr, tpr, _ = roc_curve((all_labels == i).astype(int), all_probs[:, i])
    roc_auc = auc(fpr, tpr)
    plt.plot(fpr, tpr, label=f"{class_name} (AUC = {roc_auc:.2f})")

plt.plot([0, 1], [0, 1], "k--")  # Diagonal line for reference
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC-AUC Curve (s backbone)")
plt.legend(loc="lower right")
plt.show()

# One-vs-rest precision-recall curve per class.
plt.figure(figsize=(10, 6))

for i, class_name in enumerate(class_names):
    precision, recall, _ = precision_recall_curve((all_labels == i).astype(int), all_probs[:, i])
    plt.plot(recall, precision, label=f"{class_name}")

plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Precision-Recall Curve (s backbone)")
plt.legend()
plt.show()

In [None]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import timm
from torch_geometric.nn import GCNConv
from torch_cluster import knn_graph
from torchinfo import summary


# -----------------------------
# Early Fusion with Dynamic GCN
# -----------------------------
class EarlyFusionWithDynamicGCN(nn.Module):
    """MobileViT classifier that receives metadata as a fourth image channel.

    Metadata rows pass through a two-layer GCN over a kNN graph built per
    forward call, are projected to a 56x56 map, upsampled to 224x224 and
    concatenated to the RGB image before the MobileViT backbone.

    Args:
        input_dim_meta: Number of metadata features per sample.
        num_classes: Number of output logits.
        backbone: ``"xxs"`` or ``"s"``, selecting the MobileViT variant.
        k: Number of neighbours requested from ``knn_graph``.

    Raises:
        ValueError: If ``backbone`` is neither ``"xxs"`` nor ``"s"``.
    """

    def __init__(self, input_dim_meta, num_classes, backbone="xxs", k=8):
        super().__init__()
        self.k = k

        # --------- GCN branch ----------
        self.gcn1 = GCNConv(input_dim_meta, 64)
        self.gcn2 = GCNConv(64, 32)
        # Projects the first GCN output to 32 features for the residual sum.
        self.res_proj = nn.Linear(64, 32)

        # Metadata embedding -> flattened 56x56 pseudo-image
        self.meta_to_image = nn.Sequential(
            nn.Linear(32, 56*56),
            nn.ReLU(),
            nn.BatchNorm1d(56*56),
            nn.Dropout(0.3)
        )

        # --------- MobileViT backbone ----------
        if backbone == "xxs":
            model_name = "mobilevit_xxs.cvnets_in1k"
            self.out_channels = 320
        elif backbone == "s":
            model_name = "mobilevit_s.cvnets_in1k"
            self.out_channels = 640
        else:
            raise ValueError("backbone must be 'xxs' or 's'")

        self.mobilevit = timm.create_model(
            model_name,
            pretrained=True,
            num_classes=0,
            global_pool=''   # keep features, no built-in pooling
        )

        # Replace the stem conv with a 4-channel version of itself.
        stem_conv = self.mobilevit.stem.conv
        new_conv = nn.Conv2d(
            4, stem_conv.out_channels,
            kernel_size=stem_conv.kernel_size,
            stride=stem_conv.stride,
            padding=stem_conv.padding,
            bias=stem_conv.bias is not None
        )

        with torch.no_grad():
            # Reuse the existing RGB filters unchanged.
            new_conv.weight[:, :3] = stem_conv.weight
            # Metadata channel starts as a damped average of the RGB filters.
            new_conv.weight[:, 3:] = stem_conv.weight.mean(dim=1, keepdim=True) * 0.1
            if stem_conv.bias is not None:
                # Copy in place. Assigning `stem_conv.bias.clone()` put a
                # plain tensor where nn.Module requires a Parameter and
                # raised TypeError whenever the stem conv had a bias.
                new_conv.bias.copy_(stem_conv.bias)

        self.mobilevit.stem.conv = new_conv

        # --------- Classifier ----------
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.out_channels, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, img, meta, batch_idx):
        """Return class logits for a batch.

        Args:
            img: Image batch; concatenated with a 224x224 metadata map, so it
                must have matching spatial size.
            meta: Metadata batch with one row per sample.
            batch_idx: Graph assignment of each metadata row, passed to
                ``knn_graph`` as ``batch``.
        """
        B = meta.size(0)

        # Dynamic kNN graph over the metadata rows, restricted by batch_idx.
        # NOTE(review): callers pass batch_idx = arange(B), which presumably
        # puts every sample in its own graph so no edges are produced —
        # confirm this is intended.
        edge_index = knn_graph(meta, k=self.k, batch=batch_idx)

        # ----- GCN with residual -----
        x1 = F.relu(self.gcn1(meta, edge_index))
        x2 = F.relu(self.gcn2(x1, edge_index) + self.res_proj(x1))
        x_meta = x2  # [B, 32]

        # ----- metadata -> pseudo-image, upsampled to the image resolution -----
        meta_img = self.meta_to_image(x_meta).view(B, 1, 56, 56)
        meta_img = F.interpolate(meta_img, size=(224, 224), mode='bilinear', align_corners=False)

        # ----- early fusion: metadata map becomes an extra channel -----
        x = torch.cat([img, meta_img], dim=1)  # [B, 4, 224, 224]

        # ----- MobileViT forward -----
        features = self.mobilevit(x)          # [B, C, H, W] (no global_pool)
        features = self.pool(features).view(B, -1)

        return self.classifier(features)


# Metadata feature width and number of output classes used to build the models.
input_dim_meta = 59
num_classes = 6

# Use CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size=16
# One instance per supported backbone, plus random inputs of matching shape.
model_xxs = EarlyFusionWithDynamicGCN(input_dim_meta, num_classes, backbone="xxs").to(device)
model_s = EarlyFusionWithDynamicGCN(input_dim_meta, num_classes, backbone="s").to(device)
dummy_img = torch.randn(batch_size, 3, 224, 224).to(device)
dummy_meta = torch.randn(batch_size, input_dim_meta).to(device)
dummy_batch_idx = torch.arange(batch_size, device=device)  # one node per graph

In [None]:
import torch
import torch.nn as nn
import time
import numpy as np
from ptflops import get_model_complexity_info

# Benchmark on CUDA when it is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Alias of the metadata width for the benchmark helpers in this cell.
meta_dim = input_dim_meta
# Shape of the dummy image used by the FLOP and latency helpers below
# (presumably channels, height, width — confirm against the model input).
image_size = (3, 224, 224)

# List for the "s" backbone benchmark results.
results_s_benchmark = []

def count_parameters(model, trainable_only=False):
    """Return the number of scalar parameters in ``model``.

    Args:
        model: Object exposing ``parameters()``.
        trainable_only: When true, only parameters with ``requires_grad`` set
            are counted.
    """
    sizes = (
        tensor.numel()
        for tensor in model.parameters()
        if tensor.requires_grad or not trainable_only
    )
    return sum(sizes)

class FusionWrapper(nn.Module):
    """
    Single-input adapter around an (img, meta) model.

    ptflops drives a model with one image tensor; this wrapper fabricates a
    random metadata batch of width ``meta_dim`` on every call and forwards
    both tensors to the wrapped model.
    """
    def __init__(self, model, meta_dim):
        super().__init__()
        self.meta_dim = meta_dim
        self.model = model

    def forward(self, x):
        # One metadata row per image, created on CPU and then moved to x's device.
        n_rows = x.size(0)
        random_meta = torch.randn(n_rows, self.meta_dim)
        random_meta = random_meta.to(x.device)
        return self.model(x, random_meta)

class GCN2InputWrapper(nn.Module):
    """
    Exposes a three-argument GCN fusion model as forward(img, meta).

    The third argument (per-sample batch/graph indices) is generated here as
    0..B-1, so every sample receives its own graph id.
    """
    def __init__(self, gcn_model, meta_dim):
        super().__init__()
        self.meta_dim = meta_dim
        self.gcn_model = gcn_model

    def forward(self, img, meta):
        # Distinct id per row, allocated directly on the metadata's device.
        graph_ids = torch.arange(meta.size(0), device=meta.device)
        return self.gcn_model(img, meta, graph_ids)

def compute_flops(model, meta_dim):
    """Return the ptflops complexity figure of a 2-input model, divided by 1e9.

    The model is wrapped in ``FusionWrapper`` so that ptflops can drive it
    with a single tensor of shape ``image_size``; the wrapper is moved to the
    module-level ``device`` first.

    NOTE(review): ptflops documents its first return value as
    multiply-accumulate operations - confirm whether this figure should be
    reported as GMACs rather than GFLOPs.
    """
    single_input_model = FusionWrapper(model, meta_dim)
    single_input_model = single_input_model.to(device)

    # Numeric output, no per-layer table, no verbose logging.
    options = dict(
        as_strings=False,
        print_per_layer_stat=False,
        verbose=False,
    )
    with torch.no_grad():
        total_flops, _unused_params = get_model_complexity_info(
            single_input_model, image_size, **options
        )

    giga = total_flops / 1e9
    return float(giga)

def measure_gpu_latency(model, meta_dim, runs=200, warmup=30):
    """Time single-sample inference of a 2-input model on the GPU.

    Args:
        model: module called as ``model(img, meta)``. It is switched to eval
            mode and moved to the module-level ``device`` in place.
        meta_dim: width of the random metadata row fed to the model.
        runs: number of timed forward passes.
        warmup: number of untimed forward passes executed first.

    Returns:
        ``(mean_ms, std_ms, fps)`` as floats, or ``(None, None, None)`` when
        CUDA is not available.
    """
    if not torch.cuda.is_available():
        return None, None, None

    torch.cuda.empty_cache()
    # Peak-memory statistics are reset here; this function does not read them back.
    torch.cuda.reset_peak_memory_stats()
    model.eval().to(device)

    dummy_img = torch.randn(1, *image_size, device=device)
    dummy_meta = torch.randn(1, meta_dim, device=device)

    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)
    times = []

    # Inference-only measurement: with autograd enabled every pass would also
    # record a graph and keep activations alive, inflating the latency.
    with torch.no_grad():
        for _ in range(warmup):
            model(dummy_img, dummy_meta)
        torch.cuda.synchronize()

        for _ in range(runs):
            start_evt.record()
            model(dummy_img, dummy_meta)
            end_evt.record()
            # Wait for the kernels to finish before reading the event pair.
            torch.cuda.synchronize()
            times.append(start_evt.elapsed_time(end_evt))  # ms

    times = np.array(times)
    mean = float(times.mean())
    std = float(times.std())
    fps = float(1000.0 / mean)
    return mean, std, fps

def measure_cpu_latency(model, meta_dim, runs=100, warmup=20):
    """Time single-sample inference of a 2-input model on the CPU.

    Args:
        model: module called as ``model(img, meta)``. ``model.cpu()`` moves it
            in place, so the caller's model stays on the CPU afterwards, in
            eval mode.
        meta_dim: width of the random metadata row fed to the model.
        runs: number of timed forward passes.
        warmup: number of untimed forward passes executed first.

    Returns:
        ``(mean_ms, std_ms, fps)`` as floats, measured with
        ``time.perf_counter``.
    """
    model_cpu = model.cpu()
    model_cpu.eval()

    dummy_img = torch.randn(1, *image_size)
    dummy_meta = torch.randn(1, meta_dim)

    times = []
    # Inference-only measurement: disable autograd so graph recording is not
    # counted as part of the forward-pass latency.
    with torch.no_grad():
        for _ in range(warmup):
            model_cpu(dummy_img, dummy_meta)

        for _ in range(runs):
            start = time.perf_counter()
            model_cpu(dummy_img, dummy_meta)
            end = time.perf_counter()
            times.append((end - start) * 1000.0)  # seconds -> ms

    times = np.array(times)
    mean = float(times.mean())
    std = float(times.std())
    fps = float(1000.0 / mean)
    return mean, std, fps

def benchmark_model(model, name, meta_dim, count_all_params=False):
    """
    Given a ready-to-use 2-input (img, meta) model, compute all stats.

    Args:
        model: module called as ``model(img, meta)``.
        name: label printed in the banner and stored under ``"model"``.
        meta_dim: metadata width forwarded to the FLOP and latency helpers.
        count_all_params: when false (default) only parameters with
            ``requires_grad`` set are counted.

    Returns:
        dict with keys ``model``, ``params_M``, ``flops_G``,
        ``gpu_latency_mean_ms``, ``gpu_latency_std_ms``, ``gpu_fps``,
        ``cpu_latency_mean_ms``, ``cpu_latency_std_ms`` and ``cpu_fps``.
        The three GPU entries are ``None`` when ``measure_gpu_latency``
        returns no measurement.
    """
    print(f"\n======= Benchmarking: {name} =======")

    trainable_only = not count_all_params
    params_m = count_parameters(model, trainable_only=trainable_only) / 1e6
    print(f"Params: {params_m:.3f} M")

    flops_g = compute_flops(model, meta_dim)
    print(f"FLOPs: {flops_g:.3f} G")

    gpu_mean, gpu_std, gpu_fps = measure_gpu_latency(model, meta_dim)
    if gpu_mean is not None:
        # The separator is a plain plus-minus sign (it was mis-encoded before).
        print(f"GPU Latency: {gpu_mean:.3f} ± {gpu_std:.3f} ms  |  FPS: {gpu_fps:.1f}")
    else:
        print("GPU Latency: N/A")

    # CPU timing runs last.
    cpu_mean, cpu_std, cpu_fps = measure_cpu_latency(model, meta_dim)
    print(f"CPU Latency: {cpu_mean:.3f} ± {cpu_std:.3f} ms  |  FPS: {cpu_fps:.1f}")

    return {
        "model": name,
        "params_M": params_m,
        "flops_G": flops_g,
        "gpu_latency_mean_ms": gpu_mean,
        "gpu_latency_std_ms": gpu_std,
        "gpu_fps": gpu_fps,
        "cpu_latency_mean_ms": cpu_mean,
        "cpu_latency_std_ms": cpu_std,
        "cpu_fps": cpu_fps,
    }

# Benchmark the "s"-backbone student with its trained weights.
student_raw_s = EarlyFusionWithDynamicGCN(input_dim_meta, num_classes, backbone="s")
# NOTE(review): torch.load unpickles the file - only load checkpoints from a
# trusted source. The state dict is mapped to CPU before being applied.
student_state_s = torch.load("Lightstudent_s.pth", map_location="cpu")
student_raw_s.load_state_dict(student_state_s)
# Adapt the 3-argument model to the (img, meta) call used by the benchmark helpers.
student_model_s_wrapped = GCN2InputWrapper(student_raw_s, input_dim_meta).to(device).eval()
results_s_benchmark.append(benchmark_model(student_model_s_wrapped, "TabFusion (GCN Student - s backbone)", input_dim_meta))

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_cluster import knn_graph
import timm
from tqdm import tqdm

# Prefer the GPU when one is available; models and dummy inputs below are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# =========================================================
# IMPROVED MODEL (A FIXED TO MATCH B'S BEHAVIOR)
# =========================================================
class EarlyFusionWithGCN(nn.Module):
    """Early-fusion classifier that feeds GCN-encoded metadata in as a 4th image channel.

    Steps performed by ``forward``:
      1. build a kNN graph over the metadata rows of the batch,
      2. encode the metadata with two GCN layers plus a linear residual,
      3. project the 32-d encoding to a 56x56 map and upsample it to 224x224,
      4. concatenate that map to the image along the channel axis,
      5. run the MobileViT stem and its first four stages, a 1x1 conv block,
         global average pooling and an MLP head.

    Args:
        input_dim_meta: number of metadata features per sample.
        num_classes: number of output logits.
        k: neighbour count passed to ``knn_graph``.
        pretrained: forwarded to ``timm.create_model``. Pass ``False`` to skip
            fetching pretrained weights, e.g. when a checkpoint is loaded
            immediately after construction.
    """

    def __init__(self, input_dim_meta, num_classes, k=8, pretrained=True):
        super().__init__()
        self.k = k

        # --- GCN layers (64 -> 32) with a linear projection for the residual ---
        self.gcn1 = GCNConv(input_dim_meta, 64)
        self.gcn2 = GCNConv(64, 32)
        self.res_proj = nn.Linear(64, 32)

        # --- metadata -> pseudo image (flattened 56x56 map) ---
        self.meta_to_image = nn.Sequential(
            nn.Linear(32, 56 * 56),
            nn.ReLU(),
            nn.BatchNorm1d(56 * 56),
            nn.Dropout(0.3),
        )

        # --- MobileViT backbone without a classification head ---
        self.mobilevit = timm.create_model(
            "mobilevit_s.cvnets_in1k",
            pretrained=pretrained,
            num_classes=0,
        )

        # --- Replace the stem conv with a 4-channel version ---
        self.mobilevit.stem.conv = self._four_channel_stem(self.mobilevit.stem.conv)

        # Keep only the first 4 stages; the remaining tail modules become no-ops.
        self.mobilevit.stages = nn.Sequential(
            *list(self.mobilevit.stages.children())[:4]
        )
        self.mobilevit.final_conv = nn.Identity()
        self.mobilevit.head = nn.Identity()

        # --- Post conv: 1x1 projection 128 -> 160 channels ---
        # (128 is presumably the channel count leaving the last retained
        # stage - confirm if the backbone variant changes.)
        self.post_conv = nn.Sequential(
            nn.Conv2d(128, 160, kernel_size=1, bias=False),
            nn.BatchNorm2d(160),
            nn.ReLU(inplace=True),
        )

        # --- Classifier ---
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(160, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    @staticmethod
    def _four_channel_stem(stem_conv):
        """Return a 4-input-channel copy of ``stem_conv`` seeded from its weights."""
        new_conv = nn.Conv2d(
            4, stem_conv.out_channels,
            kernel_size=stem_conv.kernel_size,
            stride=stem_conv.stride,
            padding=stem_conv.padding,
            bias=stem_conv.bias is not None,
        )

        with torch.no_grad():
            # Channels 0-2 reuse the RGB filters unchanged.
            new_conv.weight[:, :3] = stem_conv.weight
            # The metadata channel starts as a down-scaled mean of the RGB filters.
            new_conv.weight[:, 3:] = stem_conv.weight.mean(dim=1, keepdim=True) * 0.1
            if stem_conv.bias is not None:
                # Copy in place: ``bias`` is a registered nn.Parameter, and
                # assigning a plain tensor to it raises TypeError in nn.Module.
                new_conv.bias.copy_(stem_conv.bias)

        return new_conv

    def forward(self, img, meta, batch_idx):
        """Return class logits for a batch.

        Args:
            img: image batch; it is concatenated with a (B, 1, 224, 224) map
                along dim 1, so it must be 224x224 spatially (3 channels for
                the 4-channel stem).
            meta: metadata batch with ``B = meta.size(0)`` rows.
            batch_idx: per-row graph assignment handed to ``knn_graph``.
        """
        B = meta.size(0)

        # Dynamic kNN graph over the metadata rows, restricted by batch_idx.
        edge_index = knn_graph(meta, k=self.k, batch=batch_idx)

        # GCN + residual
        x1 = F.relu(self.gcn1(meta, edge_index))
        x2 = F.relu(self.gcn2(x1, edge_index) + self.res_proj(x1))

        # Metadata -> pseudo-image, upsampled to the image resolution
        meta_img = self.meta_to_image(x2).view(B, 1, 56, 56)
        meta_img = F.interpolate(meta_img, size=(224, 224), mode="bilinear", align_corners=False)

        # Early fusion (4 channels)
        x = torch.cat([img, meta_img], dim=1)

        # CNN forward
        feats = self.mobilevit.stem(x)
        feats = self.mobilevit.stages(feats)
        feats = self.post_conv(feats)
        feats = self.pool(feats).view(B, -1)

        return self.classifier(feats)
    

from torchinfo import summary


# Print a torchinfo layer summary of EarlyFusionWithGCN using random inputs.
batch_size = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = EarlyFusionWithGCN(input_dim_meta, num_classes).to(device)

# Random stand-in batch: images, metadata rows and one graph id per sample.
dummy_img = torch.randn(batch_size, 3, 224, 224).to(device)
dummy_meta = torch.randn(batch_size, input_dim_meta).to(device)
dummy_batch_idx = torch.arange(batch_size).to(device)

# The three tensors are passed in the order of forward(img, meta, batch_idx).
summary(
    model,
    input_data=[dummy_img, dummy_meta, dummy_batch_idx],
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    depth=3
)


In [None]:
import torch
import torch.nn as nn
import time
import numpy as np
from ptflops import get_model_complexity_info

# Device used by compute_flops and measure_gpu_latency below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Alias of the metadata width; the functions in this cell take `meta_dim` as
# a parameter, so this global is not what they read.
meta_dim = input_dim_meta
# (channels, height, width) of the dummy image used for FLOP and latency runs.
image_size = (3, 224, 224)

# Collects the dicts returned by benchmark_model. Re-running this cell
# rebinds the name to a fresh empty list.
results_s_benchmark = []

def count_parameters(model, trainable_only=False):
    """Return the number of scalar parameters held by ``model``.

    When ``trainable_only`` is true, parameters whose ``requires_grad`` flag
    is off are left out of the total.
    """
    total = 0
    for param in model.parameters():
        if trainable_only and not param.requires_grad:
            continue
        total += param.numel()
    return total

class FusionWrapper(nn.Module):
    """
    Single-input adapter around an (img, meta) model.

    ptflops drives a model with one image tensor; this wrapper fabricates a
    random metadata batch of width ``meta_dim`` on every call and forwards
    both tensors to the wrapped model.
    """
    def __init__(self, model, meta_dim):
        super().__init__()
        self.meta_dim = meta_dim
        self.model = model

    def forward(self, x):
        # One metadata row per image, created on CPU and then moved to x's device.
        n_rows = x.size(0)
        random_meta = torch.randn(n_rows, self.meta_dim)
        random_meta = random_meta.to(x.device)
        return self.model(x, random_meta)

class GCN2InputWrapper(nn.Module):
    """
    Exposes a three-argument GCN fusion model as forward(img, meta).

    The third argument (per-sample batch/graph indices) is generated here as
    0..B-1, so every sample receives its own graph id.
    """
    def __init__(self, gcn_model, meta_dim):
        super().__init__()
        self.meta_dim = meta_dim
        self.gcn_model = gcn_model

    def forward(self, img, meta):
        # Distinct id per row, allocated directly on the metadata's device.
        graph_ids = torch.arange(meta.size(0), device=meta.device)
        return self.gcn_model(img, meta, graph_ids)

def compute_flops(model, meta_dim):
    """Return the ptflops complexity figure of a 2-input model, divided by 1e9.

    The model is wrapped in ``FusionWrapper`` so that ptflops can drive it
    with a single tensor of shape ``image_size``; the wrapper is moved to the
    module-level ``device`` first.

    NOTE(review): ptflops documents its first return value as
    multiply-accumulate operations - confirm whether this figure should be
    reported as GMACs rather than GFLOPs.
    """
    single_input_model = FusionWrapper(model, meta_dim)
    single_input_model = single_input_model.to(device)

    # Numeric output, no per-layer table, no verbose logging.
    options = dict(
        as_strings=False,
        print_per_layer_stat=False,
        verbose=False,
    )
    with torch.no_grad():
        total_flops, _unused_params = get_model_complexity_info(
            single_input_model, image_size, **options
        )

    giga = total_flops / 1e9
    return float(giga)

def measure_gpu_latency(model, meta_dim, runs=200, warmup=30):
    """Time single-sample inference of a 2-input model on the GPU.

    Args:
        model: module called as ``model(img, meta)``. It is switched to eval
            mode and moved to the module-level ``device`` in place.
        meta_dim: width of the random metadata row fed to the model.
        runs: number of timed forward passes.
        warmup: number of untimed forward passes executed first.

    Returns:
        ``(mean_ms, std_ms, fps)`` as floats, or ``(None, None, None)`` when
        CUDA is not available.
    """
    if not torch.cuda.is_available():
        return None, None, None

    torch.cuda.empty_cache()
    # Peak-memory statistics are reset here; this function does not read them back.
    torch.cuda.reset_peak_memory_stats()
    model.eval().to(device)

    dummy_img = torch.randn(1, *image_size, device=device)
    dummy_meta = torch.randn(1, meta_dim, device=device)

    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)
    times = []

    # Inference-only measurement: with autograd enabled every pass would also
    # record a graph and keep activations alive, inflating the latency.
    with torch.no_grad():
        for _ in range(warmup):
            model(dummy_img, dummy_meta)
        torch.cuda.synchronize()

        for _ in range(runs):
            start_evt.record()
            model(dummy_img, dummy_meta)
            end_evt.record()
            # Wait for the kernels to finish before reading the event pair.
            torch.cuda.synchronize()
            times.append(start_evt.elapsed_time(end_evt))  # ms

    times = np.array(times)
    mean = float(times.mean())
    std = float(times.std())
    fps = float(1000.0 / mean)
    return mean, std, fps

def measure_cpu_latency(model, meta_dim, runs=100, warmup=20):
    """Time single-sample inference of a 2-input model on the CPU.

    Args:
        model: module called as ``model(img, meta)``. ``model.cpu()`` moves it
            in place, so the caller's model stays on the CPU afterwards, in
            eval mode.
        meta_dim: width of the random metadata row fed to the model.
        runs: number of timed forward passes.
        warmup: number of untimed forward passes executed first.

    Returns:
        ``(mean_ms, std_ms, fps)`` as floats, measured with
        ``time.perf_counter``.
    """
    model_cpu = model.cpu()
    model_cpu.eval()

    dummy_img = torch.randn(1, *image_size)
    dummy_meta = torch.randn(1, meta_dim)

    times = []
    # Inference-only measurement: disable autograd so graph recording is not
    # counted as part of the forward-pass latency.
    with torch.no_grad():
        for _ in range(warmup):
            model_cpu(dummy_img, dummy_meta)

        for _ in range(runs):
            start = time.perf_counter()
            model_cpu(dummy_img, dummy_meta)
            end = time.perf_counter()
            times.append((end - start) * 1000.0)  # seconds -> ms

    times = np.array(times)
    mean = float(times.mean())
    std = float(times.std())
    fps = float(1000.0 / mean)
    return mean, std, fps

def benchmark_model(model, name, meta_dim, count_all_params=False):
    """
    Given a ready-to-use 2-input (img, meta) model, compute all stats.

    Args:
        model: module called as ``model(img, meta)``.
        name: label printed in the banner and stored under ``"model"``.
        meta_dim: metadata width forwarded to the FLOP and latency helpers.
        count_all_params: when false (default) only parameters with
            ``requires_grad`` set are counted.

    Returns:
        dict with keys ``model``, ``params_M``, ``flops_G``,
        ``gpu_latency_mean_ms``, ``gpu_latency_std_ms``, ``gpu_fps``,
        ``cpu_latency_mean_ms``, ``cpu_latency_std_ms`` and ``cpu_fps``.
        The three GPU entries are ``None`` when ``measure_gpu_latency``
        returns no measurement.
    """
    print(f"\n======= Benchmarking: {name} =======")

    trainable_only = not count_all_params
    params_m = count_parameters(model, trainable_only=trainable_only) / 1e6
    print(f"Params: {params_m:.3f} M")

    flops_g = compute_flops(model, meta_dim)
    print(f"FLOPs: {flops_g:.3f} G")

    gpu_mean, gpu_std, gpu_fps = measure_gpu_latency(model, meta_dim)
    if gpu_mean is not None:
        # The separator is a plain plus-minus sign (it was mis-encoded before).
        print(f"GPU Latency: {gpu_mean:.3f} ± {gpu_std:.3f} ms  |  FPS: {gpu_fps:.1f}")
    else:
        print("GPU Latency: N/A")

    # CPU timing runs last.
    cpu_mean, cpu_std, cpu_fps = measure_cpu_latency(model, meta_dim)
    print(f"CPU Latency: {cpu_mean:.3f} ± {cpu_std:.3f} ms  |  FPS: {cpu_fps:.1f}")

    return {
        "model": name,
        "params_M": params_m,
        "flops_G": flops_g,
        "gpu_latency_mean_ms": gpu_mean,
        "gpu_latency_std_ms": gpu_std,
        "gpu_fps": gpu_fps,
        "cpu_latency_mean_ms": cpu_mean,
        "cpu_latency_std_ms": cpu_std,
        "cpu_fps": cpu_fps,
    }

# Benchmark the "s"-backbone student. No checkpoint is loaded in this cell,
# so the model is benchmarked with the weights it was constructed with.
student_raw_s = EarlyFusionWithDynamicGCN(input_dim_meta, num_classes, backbone="s")
# Adapt the 3-argument model to the (img, meta) call used by the benchmark helpers.
student_model_s_wrapped = GCN2InputWrapper(student_raw_s, input_dim_meta).to(device).eval()
results_s_benchmark.append(benchmark_model(student_model_s_wrapped, "TabFusion (GCN Student - s backbone)", input_dim_meta))


# Same procedure for the "xxs"-backbone student.
student_raw_xxs = EarlyFusionWithDynamicGCN(input_dim_meta, num_classes, backbone="xxs")
student_model_xxs_wrapped = GCN2InputWrapper(student_raw_xxs, input_dim_meta).to(device).eval()
results_s_benchmark.append(benchmark_model(student_model_xxs_wrapped, "TabFusion (GCN Student - xxs backbone)", input_dim_meta))

In [None]:

# Benchmark EarlyFusionWithGCN with the "dermpGCN.pth" weights. The
# `student_*` names are reused from the earlier cells and overwrite them.
# NOTE(review): the class count is hard-coded to 5 here while `num_classes`
# is 6 elsewhere in the notebook - presumably this matches the checkpoint's
# head; confirm.
student_raw_s = EarlyFusionWithGCN(input_dim_meta, 5)
# NOTE(review): torch.load unpickles the file - only load checkpoints from a
# trusted source. The state dict is mapped to CPU before being applied.
student_state_s = torch.load("dermpGCN.pth", map_location="cpu")
student_raw_s.load_state_dict(student_state_s)
# Adapt the 3-argument model to the (img, meta) call used by the benchmark helpers.
student_model_s_wrapped = GCN2InputWrapper(student_raw_s, input_dim_meta).to(device).eval()
results_s_benchmark.append(benchmark_model(student_model_s_wrapped, "TabFusion", input_dim_meta))