In [None]:
# =========================================================
# DERM7PT DATA LOADER FOR TRAINING NOTEBOOKS
# =========================================================
# Run this code in your training notebooks to load the preprocessed Derm7pt dataset
# Make sure the preprocessing pipeline above has been executed first!

import os
import json
import joblib
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

# =========================================================
# PATHS TO PREPROCESSED DATA
# =========================================================
# All inputs are resolved relative to PREPROCESSED_DIR, which is itself a
# relative path (resolved against the current working directory).
PREPROCESSED_DIR = r"augmented"

TRAIN_CSV = os.path.join(PREPROCESSED_DIR, "train_metadata_final.csv")
VAL_CSV   = os.path.join(PREPROCESSED_DIR, "val_metadata_final.csv")
TEST_CSV  = os.path.join(PREPROCESSED_DIR, "test_metadata_final.csv")

INFO_PATH = os.path.join(PREPROCESSED_DIR, "preprocessing_info.json")

# =========================================================
# LOAD PREPROCESSED DATA
# =========================================================
print("Loading preprocessed Derm7pt data...")

# Load the three split CSVs into dataframes
train_df = pd.read_csv(TRAIN_CSV)
val_df   = pd.read_csv(VAL_CSV)
test_df  = pd.read_csv(TEST_CSV)

# Load the JSON side-car written by the preprocessing step
with open(INFO_PATH, "r") as f:
    preprocessing_info = json.load(f)

# categorical_cols is read here but not used anywhere else in this cell.
categorical_cols = preprocessing_info["categorical_cols"]
# label_mapping is used below as {class name: label index}.
label_mapping = preprocessing_info["label_mapping"]

# Class names ordered by their label index, so class_names[i] is the name of label i
# (assumes the indices in label_mapping are 0..N-1 with no gaps — TODO confirm).
class_names = [k for k, v in sorted(label_mapping.items(), key=lambda x: x[1])]

print(f"\n‚úÖ Training samples:   {len(train_df)}")
print(f"‚úÖ Validation samples: {len(val_df)}")
print(f"‚úÖ Test samples:       {len(test_df)}")
print(f"\nLabel mapping: {label_mapping}")
print(f"Class names: {class_names}")

# =========================================================
# EXTRACT FEATURES AND LABELS
# =========================================================
def extract_features(df):
    """Split a dataframe into image paths, a metadata matrix and labels.

    Returns a tuple ``(img_paths, metadata, labels)`` of NumPy arrays.
    ``metadata`` holds every column except ``ImagePath`` and ``label``,
    kept in the dataframe's own column order.
    """
    reserved = ("ImagePath", "label")

    # Everything that is neither the image path nor the target is a feature.
    feature_names = []
    for name in df.columns:
        if name not in reserved:
            feature_names.append(name)

    paths = df["ImagePath"].values
    targets = df["label"].values
    features = df[feature_names].values

    return paths, features, targets

# Split each dataframe into (image paths, metadata matrix, labels) arrays.
X_train_img, X_train_meta, y_train = extract_features(train_df)
X_val_img, X_val_meta, y_val       = extract_features(val_df)
X_test_img, X_test_meta, y_test    = extract_features(test_df)

# The class count is taken from the saved label mapping, not from the data.
num_classes = len(label_mapping)
print(f"\nNumber of classes: {num_classes}")

# =========================================================
# PYTORCH DATASET CLASS
# =========================================================
class Derm7ptDataset(Dataset):
    """
    Custom Dataset for Derm7pt with images + metadata.

    Each item is a tuple ``(image, metadata, label)``:

    * ``image``    - the RGB image at ``img_paths[idx]`` after ``transform``
                     (the raw PIL image when ``transform`` is None);
    * ``metadata`` - ``metadata[idx]`` cast to ``np.float32``;
    * ``label``    - ``labels[idx]`` as a Python ``int``.

    Args:
        img_paths: Indexable collection of image file paths.
        metadata:  Indexable collection of per-sample NumPy feature rows.
        labels:    Indexable collection of integer class labels.
        transform: Optional callable applied to the PIL image.
    """
    def __init__(self, img_paths, metadata, labels, transform=None):
        self.img_paths = img_paths
        self.metadata = metadata
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Load image
        img_path = self.img_paths[idx]
        try:
            # The context manager closes the underlying file as soon as the
            # pixel data has been converted to an in-memory RGB image.
            with Image.open(img_path) as raw:
                image = raw.convert("RGB")
        except Exception as e:
            # Deliberate best-effort: one unreadable file must not abort a
            # whole epoch, so substitute a black placeholder. The cause is
            # reported so that broken paths do not go unnoticed.
            print(f"Warning: Failed to load {img_path} ({type(e).__name__}: {e}), using placeholder")
            image = Image.new("RGB", (224, 224), color="black")

        # Compare against None instead of relying on truthiness: a transform
        # object that defines __len__ (e.g. an empty nn.Sequential) is falsy.
        if self.transform is not None:
            image = self.transform(image)

        # Get metadata and label
        metadata = self.metadata[idx].astype(np.float32)
        label = int(self.labels[idx])

        return image, metadata, label

# =========================================================
# DATA TRANSFORMS
# =========================================================
# Both pipelines resize to 224x224, convert to a tensor and normalise with the
# same per-channel mean/std (presumably the ImageNet statistics expected by the
# pretrained backbones — TODO confirm).

# Training transforms (with augmentation): random flips, rotation of up to
# 30 degrees and colour jitter are applied before tensor conversion.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(30),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Validation/Test transforms (no augmentation): deterministic resize + normalise only
val_test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# =========================================================
# CREATE DATASETS
# =========================================================
# Only the training split uses the augmenting transform; validation and test
# share the deterministic one.
train_dataset = Derm7ptDataset(X_train_img, X_train_meta, y_train, transform=train_transform)
val_dataset   = Derm7ptDataset(X_val_img, X_val_meta, y_val, transform=val_test_transform)
test_dataset  = Derm7ptDataset(X_test_img, X_test_meta, y_test, transform=val_test_transform)

print(f"\n‚úÖ Created PyTorch Datasets")
print(f"   - Train: {len(train_dataset)} samples")
print(f"   - Val:   {len(val_dataset)} samples")
print(f"   - Test:  {len(test_dataset)} samples")

# =========================================================
# CREATE DATALOADERS (EXAMPLE - ADJUST BATCH SIZE AS NEEDED)
# =========================================================
BATCH_SIZE = 32

# Only the training loader shuffles; validation and test keep dataset order.
# None of the loaders sets drop_last, so the final batch of a split may be
# smaller than BATCH_SIZE.
# NOTE(review): pin_memory=True is set unconditionally; it presumably only
# helps when batches are copied to a CUDA device — confirm on CPU-only machines.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,  # Set to 0 for Windows, increase for Linux/Mac
    pin_memory=True)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)

print(f"\n‚úÖ Created DataLoaders (batch_size={BATCH_SIZE})")
print(f"   - Train batches: {len(train_loader)}")
print(f"   - Val batches:   {len(val_loader)}")
print(f"   - Test batches:  {len(test_loader)}")

# =========================================================
# EXAMPLE: TEST LOADING A BATCH
# =========================================================
print("\nüîç Testing batch loading...")
# Smoke test: pull a single batch from the training loader, print the shape of
# each of its three components, then stop (the break exits after one batch).
for images, metadata, labels in train_loader:
    print(f"   - Image batch shape:    {images.shape}")
    print(f"   - Metadata batch shape: {metadata.shape}")
    print(f"   - Labels batch shape:   {labels.shape}")
    break

print("\n‚úÖ Derm7pt data loading complete! Ready for training.")
print("\nüí° Usage in your model:")
print("   for images, metadata, labels in train_loader:")
print("       # images: torch.Tensor of shape (batch_size, 3, 224, 224)")
print("       # metadata: torch.Tensor of shape (batch_size, num_metadata_features)")
print("       # labels: torch.Tensor of shape (batch_size,)")
print("       # Your training code here...")

In [None]:
# Re-derive the class count from the labels actually present in the training split.
# NOTE(review): this rebinds num_classes, which the loader cell set to
# len(label_mapping); the two differ if a class is absent from train_df —
# confirm which value the saved checkpoints were trained with.
num_classes = train_df['label'].nunique()
print("Number of classes:", num_classes)
non_feature_cols = ["label", "ImagePath"]
# NOTE(review): this rebinds X_train_meta from the array returned by
# extract_features to a DataFrame; re-running the dataset-creation cell after
# this one would hand the DataFrame to Derm7ptDataset — verify that is intended.
X_train_meta = train_df.drop(columns=non_feature_cols)
# Number of metadata features, used as the input width of the metadata branch.
input_dim_meta = X_train_meta.shape[1]
print(f"Metadata input dimension: {input_dim_meta}")

In [None]:
from tqdm.auto import tqdm
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

def test(model, loader, device, desc="Testing"):
    """Run ``model`` over ``loader`` in eval mode and collect its predictions.

    Args:
        model:  Module called as ``model(imgs, metas)`` that returns per-class
                scores with the class dimension at index 1.
        loader: Iterable of ``(imgs, metas, labels)`` batches; must support ``len``.
        device: Device that every batch is moved to before the forward pass.
        desc:   Label shown on the progress bar.

    Returns:
        ``(all_labels, all_preds)``: two flat lists with the true and the
        predicted class index of every sample, in loader order.
    """
    # Local import: this cell never imports torch itself, so without it the
    # function only works once some other cell has imported torch.
    import torch

    model.eval()
    all_preds = []
    all_labels = []
    running_correct = 0
    running_total = 0

    with torch.no_grad():
        for imgs, metas, labels in tqdm(loader, total=len(loader), desc=desc, unit="batch"):
            imgs, metas, labels = imgs.to(device), metas.to(device), labels.to(device)
            outputs = model(imgs, metas)
            _, predicted = torch.max(outputs, 1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            running_correct += (predicted == labels).sum().item()
            running_total += labels.size(0)
            # The ratio is cumulative over every batch seen so far, so it is
            # labelled as a running accuracy rather than a per-batch one.
            tqdm.write(f"Running acc: {running_correct / running_total:.4f}")

    return all_labels, all_preds


<h1>MobileViT</h1>

In [None]:
import torch
import torch.nn as nn
import timm  
import torch.nn.functional as F

class EarlyFusionModel(nn.Module):
    """Early-fusion classifier built on a 4-channel MobileViT.

    The metadata vector is embedded into a 64x64 single-channel map, upsampled
    to the spatial size of the image and concatenated to the RGB image as a
    fourth channel; the stem convolution of the backbone is widened to accept it.

    Args:
        input_dim_meta: Number of metadata features per sample.
        num_classes:    Number of output classes of the backbone head.
    """
    def __init__(self, input_dim_meta, num_classes):
        super().__init__()

        # Embed metadata into a flat 64*64 vector that forward() reshapes into a map
        self.meta_embed = nn.Sequential(
            nn.Linear(input_dim_meta, 64 * 64),
            nn.ReLU(),
            nn.BatchNorm1d(64 * 64),
            nn.Dropout(0.3)
        )

        # Load MobileViT model
        self.mobilevit = timm.create_model("mobilevit_s.cvnets_in1k", pretrained=True, num_classes=num_classes)

        # Replace the stem convolution with a 4-input-channel copy
        first_conv = self.mobilevit.stem.conv
        new_conv = nn.Conv2d(4, first_conv.out_channels,
                             kernel_size=first_conv.kernel_size,
                             stride=first_conv.stride,
                             padding=first_conv.padding,
                             # nn.Conv2d expects a bool here, not the bias tensor itself
                             bias=first_conv.bias is not None)

        with torch.no_grad():
            # Keep the pretrained RGB filters unchanged
            new_conv.weight[:, :3] = first_conv.weight
            # Start the metadata channel from a scaled-down mean of the RGB
            # filters so it does not dominate the pretrained features
            new_conv.weight[:, 3:] = first_conv.weight.mean(dim=1, keepdim=True) * 0.1
            # Carry over the pretrained bias when the original layer has one
            if first_conv.bias is not None:
                new_conv.bias.copy_(first_conv.bias)

        self.mobilevit.stem.conv = new_conv

    def forward(self, img, meta):
        """Return class scores for an ``(img, meta)`` batch.

        ``img`` is a batch of 3-channel images and ``meta`` the matching batch
        of metadata vectors.
        """
        # Reshape metadata to image-like format
        batch_size = img.shape[0]
        meta_reshaped = self.meta_embed(meta).view(batch_size, 1, 64, 64)

        # Upsample to the spatial size of the incoming images (224x224 with the
        # transforms used in this notebook) so the channels can be concatenated
        meta_upsampled = F.interpolate(meta_reshaped,
                                       size=tuple(img.shape[-2:]),
                                       mode='bilinear',
                                       align_corners=False)

        # Early fusion: metadata becomes the 4th input channel
        combined_input = torch.cat([img, meta_upsampled], dim=1)

        # Process through modified MobileViT
        out = self.mobilevit(combined_input)
        return out

# Assumes input_dim_meta, num_classes, class_names, test_loader and test()
# were defined by the earlier cells.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Throw-away instance used only to print the architecture summary
model = EarlyFusionModel(input_dim_meta, num_classes).to(device)

from torchinfo import summary
summary(model=model,
        input_size=[(16, 3, 224, 224), (16, input_dim_meta)],
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

# Instance that receives the trained weights
mobilevit_model = EarlyFusionModel(input_dim_meta=input_dim_meta, num_classes=num_classes).to(device)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved from a GPU run can still be loaded on a CPU-only machine.
mobilevit_model.load_state_dict(
    torch.load('D:\\Dermp7\\best_early_fusion_mobilevitsmoteDA.pth', map_location=device)
)

true_labels, pred_labels = test(mobilevit_model, test_loader, device)

report = classification_report(true_labels, pred_labels, digits=4,target_names=class_names)
print("Classification Report:")
print(report)
# The confusion matrix is computed here but not displayed in this cell
cm = confusion_matrix(true_labels, pred_labels)

<h1>PvtV2</h1>

In [None]:
import torch
import torch.nn as nn
import timm
import torch.nn.functional as F

class EarlyFusionModel(nn.Module):
    """Early-fusion classifier built on a 4-channel PVT v2 (b1).

    The metadata vector is embedded into a 56x56 single-channel map, upsampled
    to the spatial size of the image and concatenated to the RGB image as a
    fourth channel; the patch-embedding convolution of the backbone is widened
    to accept it.

    Args:
        input_dim_meta: Number of metadata features per sample.
        num_classes:    Number of output classes of the backbone head.
    """
    def __init__(self, input_dim_meta, num_classes):
        super().__init__()

        # Embed metadata into a flat 56*56 vector that forward() reshapes into a map
        self.meta_embed = nn.Sequential(
            nn.Linear(input_dim_meta, 56 * 56),
            nn.ReLU(),
            nn.BatchNorm1d(56 * 56),
            nn.Dropout(0.3)
        )

        # Load PVT v2 model
        self.pvt = timm.create_model("pvt_v2_b1", pretrained=True, num_classes=num_classes)

        # Replace the first convolution with a 4-input-channel copy
        first_conv = self.pvt.patch_embed.proj
        new_conv = nn.Conv2d(4, first_conv.out_channels,
                             kernel_size=first_conv.kernel_size,
                             stride=first_conv.stride,
                             padding=first_conv.padding,
                             bias=first_conv.bias is not None)

        with torch.no_grad():
            # Keep the pretrained RGB filters unchanged
            new_conv.weight[:, :3] = first_conv.weight
            # Start the metadata channel from a scaled-down mean of the RGB filters
            new_conv.weight[:, 3:] = first_conv.weight.mean(dim=1, keepdim=True) * 0.1
            # Without this copy the new layer would keep a freshly initialised
            # bias and silently drop the pretrained one
            if first_conv.bias is not None:
                new_conv.bias.copy_(first_conv.bias)

        self.pvt.patch_embed.proj = new_conv

    def forward(self, img, meta):
        """Return class scores for an ``(img, meta)`` batch.

        ``img`` is a batch of 3-channel images and ``meta`` the matching batch
        of metadata vectors.
        """
        # Reshape metadata to image-like format
        batch_size = img.shape[0]
        meta_reshaped = self.meta_embed(meta).view(batch_size, 1, 56, 56)

        # Upsample to the spatial size of the incoming images (224x224 with the
        # transforms used in this notebook) so the channels can be concatenated
        meta_upsampled = F.interpolate(meta_reshaped,
                                       size=tuple(img.shape[-2:]),
                                       mode='bilinear',
                                       align_corners=False)

        # Early fusion: metadata becomes the 4th input channel
        combined_input = torch.cat([img, meta_upsampled], dim=1)

        # Process through modified PVT
        out = self.pvt(combined_input)
        return out

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Throw-away instance used only to print the architecture summary
model = EarlyFusionModel(input_dim_meta, num_classes).to(device)

from torchinfo import summary
summary(model=model,
        input_size=[(16, 3, 224, 224), (16, input_dim_meta)],
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

# Instance that receives the trained weights
pv2_model = EarlyFusionModel(input_dim_meta=input_dim_meta, num_classes=num_classes).to(device)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved from a GPU run can still be loaded on a CPU-only machine.
pv2_model.load_state_dict(
    torch.load('D:\\Dermp7\\best_early_fusion_pvtv2smoteDA.pth', map_location=device)
)

true_labels, pred_labels = test(pv2_model, test_loader, device)

report = classification_report(true_labels, pred_labels, digits=4,target_names=class_names)
print("Classification Report:")
print(report)
# The confusion matrix is computed here but not displayed in this cell
cm = confusion_matrix(true_labels, pred_labels)

<h1>Teacher Model (Mean Averaging)</h1>

In [None]:
import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

class TeacherModel(nn.Module):
    """Frozen ensemble of trained models, used as a distillation teacher."""

    _ENSEMBLE_METHODS = ("mean", "vote")

    def __init__(self, models, ensemble_method="mean"):
        """
        Teacher Model using Ensemble Learning.

        Args:
            models (list): Non-empty list of trained ``nn.Module`` models, each
                called as ``model(img, meta)``. They are switched to eval mode
                and their parameters are frozen in place.
            ensemble_method (str): "mean" for averaging logits, "vote" for majority voting.

        Raises:
            ValueError: If ``models`` is empty or ``ensemble_method`` is unknown.
        """
        super(TeacherModel, self).__init__()
        if ensemble_method not in self._ENSEMBLE_METHODS:
            raise ValueError(
                f"ensemble_method must be one of {self._ENSEMBLE_METHODS}, got {ensemble_method!r}"
            )
        models = list(models)
        if not models:
            raise ValueError("TeacherModel needs at least one model")

        # A ModuleList registers the members as sub-modules, so .to(device),
        # .eval() and .parameters() on the teacher actually reach them; a plain
        # Python list is invisible to nn.Module.
        self.models = nn.ModuleList(models)
        self.ensemble_method = ensemble_method

        # Ensure all models are in eval mode and no gradients are computed
        for model in self.models:
            model.eval()
            for param in model.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Set the training flag on the wrapper but keep every member in eval mode.

        The members are frozen teachers: their dropout and batch-norm layers
        must stay in inference behaviour even if ``.train()`` is called.
        """
        super().train(mode)
        for model in self.models:
            model.eval()
        return self

    def forward(self, img, meta):
        """
        Forward pass through the ensemble teacher model.

        Args:
            img (torch.Tensor): Batch of images.
            meta (torch.Tensor): Batch of metadata.

        Returns:
            torch.Tensor: For "mean", the member outputs averaged over the
            models (same shape as a single member's output; no softmax is
            applied here). For "vote", the most common arg-max class index per
            sample, shape ``[batch_size]``.
        """
        with torch.no_grad():  # Disable gradient computation for teacher
            # Shape [num_models, batch_size, num_classes]
            model_outputs = torch.stack([model(img, meta) for model in self.models], dim=0)

        if self.ensemble_method == "mean":
            # Soft-label generation: averaging logits
            return model_outputs.mean(dim=0)
        if self.ensemble_method == "vote":
            # Majority voting: most common predicted class across the models
            _, predictions = torch.max(model_outputs, dim=2)
            return predictions.mode(dim=0).values

        # Only reachable if the attribute was changed after construction
        raise ValueError(f"Unknown ensemble_method: {self.ensemble_method!r}")

# Build the teacher as a mean ensemble of the two fine-tuned early-fusion models.
teacher_model = TeacherModel(models=[mobilevit_model, pv2_model], ensemble_method="mean")

# Move to the correct device (CPU/GPU)
teacher_model = teacher_model.to(device)

# Evaluate the ensemble on the test split with the same helper used for the
# individual models.
true_labels, pred_labels = test(teacher_model, test_loader, device)

report = classification_report(true_labels, pred_labels, digits=4,target_names=class_names)
print("Classification Report:")
print(report)
# The confusion matrix is computed here but not displayed in this cell
cm = confusion_matrix(true_labels, pred_labels)

<h1>Knowledge Distillation on Student Model</h1>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_cluster import knn_graph
import timm
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# =========================================================
# IMPROVED MODEL (A FIXED TO MATCH B'S BEHAVIOR)
# =========================================================
class EarlyFusionWithGCN(nn.Module):
    """Student network: GCN-processed metadata fused with the image as a 4th channel.

    Metadata rows pass through two GCN layers (with a projected residual) over
    a k-nearest-neighbour graph built on the fly, are turned into a 56x56
    pseudo-image, upsampled to the image size and concatenated to the RGB
    input. The fused tensor goes through the MobileViT stem and its first four
    stages, a 1x1 convolution, global average pooling and an MLP classifier.

    Args:
        input_dim_meta: Number of metadata features per sample.
        num_classes:    Number of output classes.
        k:              Number of neighbours requested from ``knn_graph``.
    """
    def __init__(self, input_dim_meta, num_classes, k=8):
        super().__init__()
        self.k = k

        # --- GCN Layers ---
        self.gcn1 = GCNConv(input_dim_meta, 64)
        self.gcn2 = GCNConv(64, 32)
        # Projects the first GCN output to 32 features for the residual sum
        self.res_proj = nn.Linear(64, 32)

        # --- metadata -> pseudo image ---
        self.meta_to_image = nn.Sequential(
            nn.Linear(32, 56 * 56),
            nn.ReLU(),
            nn.BatchNorm1d(56 * 56),
            nn.Dropout(0.3)
        )

        # --- MobileViT backbone ---
        self.mobilevit = timm.create_model(
            "mobilevit_s.cvnets_in1k",
            pretrained=True,
            num_classes=0
        )

        # --- Modify first conv to accept 4 channels ---
        stem_conv = self.mobilevit.stem.conv
        new_conv = nn.Conv2d(
            4, stem_conv.out_channels,
            kernel_size=stem_conv.kernel_size,
            stride=stem_conv.stride,
            padding=stem_conv.padding,
            bias=stem_conv.bias is not None
        )

        with torch.no_grad():
            # copy RGB weights
            new_conv.weight[:, :3] = stem_conv.weight
            # tiny weight for metadata channel
            new_conv.weight[:, 3:] = stem_conv.weight.mean(dim=1, keepdim=True) * 0.1
            # copy bias if exists; copy_ writes into the registered parameter,
            # whereas assigning a plain tensor to .bias is rejected by nn.Module
            if stem_conv.bias is not None:
                new_conv.bias.copy_(stem_conv.bias)

        self.mobilevit.stem.conv = new_conv

        # keep only first 4 stages
        self.mobilevit.stages = nn.Sequential(
            *list(self.mobilevit.stages.children())[:4]
        )
        self.mobilevit.final_conv = nn.Identity()
        self.mobilevit.head = nn.Identity()

        # --- Post Conv ---
        # Expects 128 input channels from the truncated backbone
        self.post_conv = nn.Sequential(
            nn.Conv2d(128, 160, kernel_size=1, bias=False),
            nn.BatchNorm2d(160),
            nn.ReLU(inplace=True)
        )

        # --- Classifier ---
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(160, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, img, meta, batch_idx):
        """Return class scores for a batch.

        Args:
            img:       Batch of 3-channel images.
            meta:      Batch of metadata rows, one per image.
            batch_idx: Passed to ``knn_graph`` as its ``batch`` vector.
        """
        B = meta.size(0)

        # Dynamic kNN graph over the metadata rows.
        # NOTE(review): the callers in this file pass torch.arange(B) as
        # batch_idx. If knn_graph only links rows that share a batch id (as its
        # documentation suggests), that puts every sample in its own graph and
        # yields no edges — confirm this is intended.
        edge_index = knn_graph(meta, k=self.k, batch=batch_idx)

        # GCN + residual
        x1 = F.relu(self.gcn1(meta, edge_index))
        x2 = F.relu(self.gcn2(x1, edge_index) + self.res_proj(x1))

        # Metadata -> pseudo-image, upsampled to the spatial size of the images
        meta_img = self.meta_to_image(x2).view(B, 1, 56, 56)
        meta_img = F.interpolate(meta_img, size=tuple(img.shape[-2:]), mode="bilinear", align_corners=False)

        # Early fusion (4 channels)
        x = torch.cat([img, meta_img], dim=1)

        # CNN forward
        feats = self.mobilevit.stem(x)
        feats = self.mobilevit.stages(feats)
        feats = self.post_conv(feats)
        feats = self.pool(feats).view(B, -1)

        return self.classifier(feats)
    

from torchinfo import summary


# Print an architecture summary of the student by tracing one dummy batch.
batch_size = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = EarlyFusionWithGCN(input_dim_meta, num_classes).to(device)

# Random inputs with the shapes the model is called with elsewhere in this file
dummy_img = torch.randn(batch_size, 3, 224, 224).to(device)
dummy_meta = torch.randn(batch_size, input_dim_meta).to(device)
# One distinct id per sample, mirroring how the training and evaluation code
# build the batch vector
dummy_batch_idx = torch.arange(batch_size).to(device)

summary(
    model,
    input_data=[dummy_img, dummy_meta, dummy_batch_idx],
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    depth=3
)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score


# =========================================================
# Teacher Model (Assumed Already Defined and Loaded)
# =========================================================
# Rebuilds the mean-ensemble teacher from the two models loaded in earlier
# cells (mobilevit_model and pv2_model must already exist) and moves it to device.
teacher_model = TeacherModel(models=[mobilevit_model, pv2_model], ensemble_method="mean").to(device)


# =========================================================
# Utility Functions
# =========================================================
def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        seed: Integer seed applied to every generator.
    """
    # Local import keeps the function self-contained; the stdlib generator was
    # previously left unseeded, so any code using ``random`` was not repeatable.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def evaluate_student(model, test_loader, device):
    """Evaluate the student on ``test_loader`` and print/return macro metrics.

    Args:
        model:       Student called as ``model(images, metas, batch_indices)``.
        test_loader: Iterable of ``(images, metas, labels)`` batches.
        device:      Device that the inputs are moved to.

    Returns:
        ``(accuracy, f1, precision, recall)``; f1, precision and recall are
        macro-averaged.
    """
    model.eval()
    all_labels, all_preds = [], []

    with torch.no_grad():
        for images, metas, labels in test_loader:
            # Labels are only compared on the host, so they never need to be
            # moved to the device.
            images, metas = images.to(device), metas.to(device)
            # One graph id per sample, created directly on the target device
            batch_indices = torch.arange(metas.size(0), device=device)
            outputs = model(images, metas, batch_indices)
            preds = outputs.argmax(dim=1)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)

    accuracy = accuracy_score(all_labels, all_preds)
    # zero_division=0 on all three scores: a class that is never predicted
    # contributes 0 without emitting a warning, consistently across metrics.
    f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    precision = precision_score(all_labels, all_preds, average='macro', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='macro', zero_division=0)

    print(f"Test Accuracy: {accuracy:.4f}")
    print(f"Test F1 Score: {f1:.4f}")
    print(f"Test Precision: {precision:.4f}")
    print(f"Test Recall: {recall:.4f}")

    return accuracy, f1, precision, recall


# =========================================================
# NEW: Averaged Training Curve Plotting (FIXED)
# =========================================================
def plot_average_training_curves(all_histories, save_path="averaged_training_curves.png"):
    """Plot mean +/- std training curves across runs, save and show the figure.

    Args:
        all_histories: Dict with keys "train_loss", "val_loss", "train_acc"
            and "val_acc"; each value is a list with one per-epoch sequence
            per run. Runs that stopped early are padded with their last value
            up to the longest "train_loss" run.
        save_path: File the figure is written to (at 650 dpi).

    Raises:
        ValueError: If any run history is empty, since it cannot be padded.
    """
    import numpy as np
    import matplotlib.pyplot as plt

    # Bold fonts and tick sizes are applied through rc_context below, so they
    # affect only this figure instead of permanently changing global rcParams.
    style = {
        'font.weight': 'bold',
        'axes.labelweight': 'bold',
        'axes.titleweight': 'bold',
        'xtick.labelsize': 12,
        'ytick.labelsize': 12,
    }

    max_epochs = max(len(h) for h in all_histories["train_loss"])

    def pad_history(history_list, max_len):
        """Pad each history to max_len using the last value"""
        padded = []
        for hist in history_list:
            hist = list(hist)  # also accepts arrays/tuples, where "+" would not concatenate
            if not hist:
                raise ValueError("Cannot pad an empty training history")
            padded.append(hist + [hist[-1]] * (max_len - len(hist)))
        return np.array(padded)

    epochs = np.arange(1, max_epochs + 1)

    def plot_mean_band(key, label):
        """Draw the across-run mean of one metric with a +/- 1 std band."""
        curves = pad_history(all_histories[key], max_epochs)
        mean = curves.mean(axis=0)
        std = curves.std(axis=0)
        plt.plot(epochs, mean, label=label)
        plt.fill_between(epochs, mean - std, mean + std, alpha=0.25)

    with plt.rc_context(style):
        plt.figure(figsize=(12, 5))

        # -------- Loss subplot --------
        plt.subplot(1, 2, 1)
        plot_mean_band("train_loss", "Train Loss")
        plot_mean_band("val_loss", "Validation Loss")
        plt.xlabel("Epochs", fontweight="bold")
        plt.ylabel("Loss", fontweight="bold")
        plt.title("Training and Validation Loss (Averaged Across Runs)", fontweight="bold")
        plt.legend()

        # -------- Accuracy subplot --------
        plt.subplot(1, 2, 2)
        plot_mean_band("train_acc", "Train Accuracy")
        plot_mean_band("val_acc", "Validation Accuracy")
        plt.xlabel("Epochs", fontweight="bold")
        plt.ylabel("Accuracy", fontweight="bold")
        plt.title("Training and Validation Accuracy (Averaged Across Runs)", fontweight="bold")
        plt.legend()

        plt.tight_layout()

        # ---- SAVE AT 650 DPI ----
        plt.savefig(save_path, dpi=650, bbox_inches="tight")

        plt.show()


# =========================================================
# Training with Knowledge Distillation
# =========================================================
def train_student_model_kd(student_model, teacher_model, train_loader, val_loader, test_loader,
                           device, alpha=0.5, temperature=3.0, epochs=100, patience=10):
    """Train the student with knowledge distillation and early stopping.

    The loss per batch is ``(1 - alpha) * CE(student, labels) +
    alpha * T^2 * KL(softmax(teacher / T) || softmax(student / T))``.

    Args:
        student_model: Model called as ``student(images, metas, batch_indices)``.
        teacher_model: Model called as ``teacher(images, metas)``; only used
            under ``no_grad``.
        train_loader, val_loader: Iterables of ``(images, metas, labels)`` batches.
        test_loader: Accepted for interface compatibility; not used here.
        device: Device used for training.
        alpha: Weight of the distillation (soft) term.
        temperature: Softmax temperature ``T``.
        epochs: Maximum number of epochs.
        patience: Epochs without validation-accuracy improvement before stopping.

    Returns:
        ``(best_state, history, best_val_accuracy)``: a snapshot of the state
        dict at the best validation accuracy (``None`` only when no epoch
        ran), the per-epoch history (losses are per-sample means) and the best
        validation accuracy.
    """
    import copy

    student_model.to(device)
    teacher_model.eval()
    criterion = nn.CrossEntropyLoss()
    kl_div_loss = nn.KLDivLoss(reduction='batchmean')
    optimizer = torch.optim.Adam(student_model.parameters(), lr=0.001)
    # The deprecated ``verbose`` flag is not used; learning-rate reductions are
    # reported explicitly after scheduler.step() below.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=5)

    best_val_accuracy = 0.0
    best_val_model_state = None
    patience_counter = 0

    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}

    for epoch in range(epochs):
        student_model.train()
        train_loss_sum, correct, total = 0.0, 0, 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

        for images, metas, labels in pbar:
            images, metas, labels = images.to(device), metas.to(device), labels.to(device)
            # NOTE(review): one distinct graph id per sample; if knn_graph only
            # links rows sharing an id, the student sees no edges — confirm.
            batch_indices = torch.arange(metas.size(0), device=device)

            optimizer.zero_grad()

            student_outputs = student_model(images, metas, batch_indices)
            with torch.no_grad():
                teacher_outputs = teacher_model(images, metas)

            # --- Hard loss ---
            loss_hard = criterion(student_outputs, labels)

            # --- Soft loss (KD) ---
            loss_soft = kl_div_loss(
                F.log_softmax(student_outputs / temperature, dim=1),
                F.softmax(teacher_outputs / temperature, dim=1)
            )

            loss = (1 - alpha) * loss_hard + alpha * (temperature ** 2) * loss_soft
            loss.backward()
            optimizer.step()

            n_samples = labels.size(0)
            # Weight by batch size so the epoch value is a per-sample mean
            train_loss_sum += loss.item() * n_samples
            _, predicted = student_outputs.max(1)
            total += n_samples
            correct += predicted.eq(labels).sum().item()

        # Per-sample means keep train and validation losses on the same scale
        # regardless of how many batches each loader yields.
        train_loss = train_loss_sum / total
        train_accuracy = correct / total
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_accuracy)

        # ---- Validation ----
        student_model.eval()
        val_loss_sum, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for images, metas, labels in val_loader:
                images, metas, labels = images.to(device), metas.to(device), labels.to(device)
                batch_indices = torch.arange(metas.size(0), device=device)
                outputs = student_model(images, metas, batch_indices)
                loss = criterion(outputs, labels)
                n_samples = labels.size(0)
                val_loss_sum += loss.item() * n_samples
                _, predicted = outputs.max(1)
                correct += predicted.eq(labels).sum().item()
                total += n_samples

        val_loss = val_loss_sum / total
        val_accuracy = correct / total
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_accuracy)

        print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Train Acc={train_accuracy:.4f}")
        print(f"Epoch {epoch+1}: Val Loss={val_loss:.4f}, Val Acc={val_accuracy:.4f}")

        # ---- Early Stopping ----
        # The first epoch always records a state, so the caller never receives
        # None after at least one epoch of training.
        if best_val_model_state is None or val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            # state_dict() returns references to the live tensors; without a
            # deep copy the "best" state would silently follow later updates.
            best_val_model_state = copy.deepcopy(student_model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print("Early stopping triggered.")
            break

        lr_before = optimizer.param_groups[0]["lr"]
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]["lr"]
        if lr_after != lr_before:
            print(f"Learning rate reduced from {lr_before:.2e} to {lr_after:.2e}")

    return best_val_model_state, history, best_val_accuracy


# =========================================================
# Main Experiment Loop (MULTI-SEED RUNS)
# =========================================================
# Trains one fresh student per seed, evaluates each on the test split, keeps
# the student with the highest validation accuracy, then reports mean/std of
# the test metrics and plots the averaged training curves.

seeds = [42, 123, 569]
best_overall_model = None
best_overall_accuracy = 0.0

# Per-seed test metrics
results = {"accuracy": [], "f1": [], "precision": [], "recall": []}

# NEW: store histories for averaging
all_histories = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}

for seed in seeds:
    print(f"\n--- Training with Seed {seed} ---")
    set_seed(seed)

    # A new student is built after seeding, so its initialisation depends on the seed
    student_model = EarlyFusionWithGCN(input_dim_meta, num_classes).to(device)

    best_model_state, history, val_acc = train_student_model_kd(
        student_model=student_model,
        teacher_model=teacher_model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        device=device,
        alpha=0.2,
        temperature=9.0,
        epochs=100,
        patience=10
    )

    # Store history for averaged curves
    all_histories["train_loss"].append(history["train_loss"])
    all_histories["val_loss"].append(history["val_loss"])
    all_histories["train_acc"].append(history["train_acc"])
    all_histories["val_acc"].append(history["val_acc"])

    # ---- Final Test ----
    # Restore the state returned by training before measuring test metrics
    student_model.load_state_dict(best_model_state)
    acc, f1, prec, recall = evaluate_student(student_model, test_loader, device)

    results["accuracy"].append(acc)
    results["f1"].append(f1)
    results["precision"].append(prec)
    results["recall"].append(recall)

    # Model selection across seeds uses validation accuracy, not test accuracy
    if val_acc > best_overall_accuracy:
        best_overall_accuracy = val_acc
        best_overall_model = student_model

# ---- Save Best Model ----
# NOTE(review): best_overall_model stays None if no seed reaches a validation
# accuracy above 0.0, in which case this line raises AttributeError.
torch.save(best_overall_model.state_dict(), "dermpGCN.pth")
print(f"\nBest Val Accuracy Model Saved (Acc={best_overall_accuracy:.4f})")

# ---- Summary ----
print("\n--- Final Evaluation Across Seeds ---")
for metric in results:
    print(f"{metric.capitalize()}: {np.mean(results[metric]):.4f} ¬± {np.std(results[metric]):.4f}")

# ---- PLOT AVERAGED TRAINING CURVES ----
plot_average_training_curves(all_histories)


In [None]:
def plot_average_training_curves(all_histories, save_path="averaged_training_curves.png"):
    """Plot mean ± std training/validation curves across runs and save the figure.

    Shorter runs (e.g. ones that stopped early) are extended to the longest
    run by repeating their final value, so every epoch on the x-axis is
    averaged over all runs.

    Args:
        all_histories: Mapping with the keys "train_loss", "val_loss",
            "train_acc" and "val_acc"; each value is a sequence of runs, and
            each run is a sequence of per-epoch numbers.
        save_path: Output image path. The figure is written at 650 DPI.

    Raises:
        ValueError: If any curve has no runs, or any run has no epochs.

    Note:
        Updates ``plt.rcParams`` globally (bold fonts, tick label size), which
        also affects figures drawn afterwards in the same session.
    """
    curve_keys = ("train_loss", "val_loss", "train_acc", "val_acc")

    # Copy each run into a plain list so padding is a concatenation even when a
    # run arrives as a NumPy array (where "+" would add element-wise instead).
    runs_by_curve = {
        key: [list(run) for run in all_histories[key]] for key in curve_keys
    }
    if any(not runs for runs in runs_by_curve.values()):
        raise ValueError("all_histories must contain at least one run for every curve")
    if any(not run for runs in runs_by_curve.values() for run in runs):
        raise ValueError("every run history must contain at least one epoch")

    # Imported after validation so malformed input fails fast.
    import numpy as np
    import matplotlib.pyplot as plt

    # ---- Set global bold font ----
    plt.rcParams.update({
        'font.weight': 'bold',
        'axes.labelweight': 'bold',
        'axes.titleweight': 'bold',
        'xtick.labelsize': 12,
        'ytick.labelsize': 12,
    })

    # Longest run across all four curves, so no padded matrix can end up ragged.
    max_epochs = max(len(run) for runs in runs_by_curve.values() for run in runs)

    def pad_history(history_list, max_len):
        """Stack runs into an (n_runs, max_len) array, repeating each last value."""
        return np.array(
            [run + [run[-1]] * (max_len - len(run)) for run in history_list],
            dtype=float,
        )

    train_loss = pad_history(runs_by_curve["train_loss"], max_epochs)
    val_loss = pad_history(runs_by_curve["val_loss"], max_epochs)
    train_acc = pad_history(runs_by_curve["train_acc"], max_epochs)
    val_acc = pad_history(runs_by_curve["val_acc"], max_epochs)

    epochs = np.arange(1, max_epochs + 1)

    def plot_mean_band(curves, label):
        """Draw the across-run mean with a ±1 std band on the current axes."""
        mean = curves.mean(axis=0)
        spread = curves.std(axis=0)
        plt.plot(epochs, mean, label=label)
        plt.fill_between(epochs, mean - spread, mean + spread, alpha=0.25)

    plt.figure(figsize=(12, 5))

    # -------- Loss subplot --------
    plt.subplot(1, 2, 1)
    plot_mean_band(train_loss, "Train Loss")
    plot_mean_band(val_loss, "Validation Loss")
    plt.xlabel("Epochs", fontweight="bold")
    plt.ylabel("Loss", fontweight="bold")
    plt.title("Training and Validation Loss (Averaged Across Runs)", fontweight="bold")
    plt.legend()

    # -------- Accuracy subplot --------
    plt.subplot(1, 2, 2)
    plot_mean_band(train_acc, "Train Accuracy")
    plot_mean_band(val_acc, "Validation Accuracy")
    plt.xlabel("Epochs", fontweight="bold")
    plt.ylabel("Accuracy", fontweight="bold")
    plt.title("Training and Validation Accuracy (Averaged Across Runs)", fontweight="bold")
    plt.legend()

    plt.tight_layout()

    # ---- SAVE AT 650 DPI ----
    plt.savefig(save_path, dpi=650, bbox_inches="tight")

    plt.show()

# Draw and save the averaged curves for the runs collected in `all_histories`
# (written to the function's default save_path).
plot_average_training_curves(all_histories)

In [None]:
import re
import numpy as np
import matplotlib.pyplot as plt

# ------------------------------------------------------------
# 1. FUNCTION TO PARSE ONE LOG FILE
# ------------------------------------------------------------
def parse_log(path):
    """Parse per-epoch train/validation metrics from a training log file.

    Two line shapes are recognised anywhere within a line:

        Epoch <n>: Train Loss: <x>, Train Acc: <y>
        Epoch <n>: Val Loss: <x>, Val Acc: <y>

    Numbers may be plain decimals or scientific notation (e.g. ``9.5e-05``).
    All other lines are ignored.

    Args:
        path: Path of the log file to read.

    Returns:
        Dict of lists with keys "epoch", "train_loss", "train_acc",
        "val_loss" and "val_acc". "epoch" is filled from train lines only, so
        the validation lists can be shorter than it if the log is incomplete.
    """
    # Decimal with optional sign and exponent. The previous "[0-9.]+" skipped
    # values such as "9.5e-05" and swallowed a sentence-ending period.
    number = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"
    train_pattern = re.compile(
        r"Epoch\s+(\d+):\s+Train Loss:\s+(" + number + r"),\s+Train Acc:\s+(" + number + r")"
    )
    val_pattern = re.compile(
        r"Epoch\s+(\d+):\s+Val Loss:\s+(" + number + r"),\s+Val Acc:\s+(" + number + r")"
    )

    metrics = {
        "epoch": [],
        "train_loss": [],
        "train_acc": [],
        "val_loss": [],
        "val_acc": [],
    }

    with open(path, "r") as f:
        for line in f:
            line = line.strip()

            # Train loss + acc
            m_train = train_pattern.search(line)
            if m_train:
                metrics["epoch"].append(int(m_train.group(1)))
                metrics["train_loss"].append(float(m_train.group(2)))
                metrics["train_acc"].append(float(m_train.group(3)))
                continue

            # Val loss + acc (the epoch number is taken from the train line only)
            m_val = val_pattern.search(line)
            if m_val:
                metrics["val_loss"].append(float(m_val.group(2)))
                metrics["val_acc"].append(float(m_val.group(3)))

    return metrics


# ------------------------------------------------------------
# 2. LOAD BOTH SEEDS
# ------------------------------------------------------------
# Seed id -> path of that run's training log.
# NOTE(review): only seeds 42 and 569 are listed, while the training cell in
# this notebook iterates over seeds [42, 123, 569] — confirm 123 is omitted on purpose.
log_files = {
    "42": r"C:\Users\User\MDY Research\With Augmentation\Concatenation\Adasyn\seed_42.txt",
    "569": r"C:\Users\User\MDY Research\With Augmentation\Concatenation\Adasyn\seed_569.txt",
}

histories = {}

for seed, path in log_files.items():
    histories[seed] = parse_log(path)
    print(f"Seed {seed}: {len(histories[seed]['epoch'])} epochs loaded")


# ------------------------------------------------------------
# 3. ALIGN EPOCHS — truncate to minimum run length
# ------------------------------------------------------------
# Take the minimum over every series, not just the train-line count: a log cut
# off between a train line and its validation line would otherwise leave the
# validation series one entry short and break the array fill below.
curve_keys = ("train_loss", "val_loss", "train_acc", "val_acc")
min_epochs = min(len(h[key]) for h in histories.values() for key in curve_keys)
print(f"\nUsing {min_epochs} epochs (minimum across seeds)")

if min_epochs == 0:
    # Fail loudly instead of saving an empty figure.
    raise ValueError(
        "No complete epochs were parsed from the log files; "
        "check the paths and the log line format."
    )

# Seed ids are strings, so this is a lexicographic sort.
seed_ids = sorted(histories.keys())
num_seeds = len(seed_ids)

# One row per seed, one column per epoch.
train_loss = np.array([histories[sid]["train_loss"][:min_epochs] for sid in seed_ids], dtype=float)
val_loss   = np.array([histories[sid]["val_loss"][:min_epochs] for sid in seed_ids], dtype=float)
train_acc  = np.array([histories[sid]["train_acc"][:min_epochs] for sid in seed_ids], dtype=float)
val_acc    = np.array([histories[sid]["val_acc"][:min_epochs] for sid in seed_ids], dtype=float)


# ------------------------------------------------------------
# 4. COMPUTE MEANS & STANDARD DEVIATIONS
# ------------------------------------------------------------
# Statistics are taken across seeds (axis 0), giving one value per epoch.
tl_mean, tl_std = train_loss.mean(0), train_loss.std(0)
vl_mean, vl_std = val_loss.mean(0),   val_loss.std(0)

ta_mean, ta_std = train_acc.mean(0),  train_acc.std(0)
va_mean, va_std = val_acc.mean(0),    val_acc.std(0)

epochs = np.arange(1, min_epochs + 1)


# ------------------------------------------------------------
# 5. PLOT COMBINED FIGURE (LOSS + ACCURACY)
# ------------------------------------------------------------
# These rcParams changes are global and persist for later figures in the session.
plt.rcParams['font.weight'] = 'bold'
plt.rcParams['axes.labelweight'] = 'bold'
plt.rcParams['axes.titleweight'] = 'bold'
plt.rcParams['xtick.labelsize'] = 12
plt.rcParams['ytick.labelsize'] = 12

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Loss subplot: mean line with a ±1 std band per series
ax1.plot(epochs, tl_mean, label="Train Loss", linewidth=2)
ax1.fill_between(epochs, tl_mean - tl_std, tl_mean + tl_std, alpha=0.25)

ax1.plot(epochs, vl_mean, label="Validation Loss", linewidth=2)
ax1.fill_between(epochs, vl_mean - vl_std, vl_mean + vl_std, alpha=0.25)

ax1.set_xlabel("Epochs", fontweight="bold")
ax1.set_ylabel("Loss", fontweight="bold")
ax1.set_title("Training and Validation Loss (Averaged Across Seeds)", fontweight="bold")
ax1.legend()
ax1.grid(True, alpha=0.3)

# Accuracy subplot: mean line with a ±1 std band per series
ax2.plot(epochs, ta_mean, label="Train Accuracy", linewidth=2)
ax2.fill_between(epochs, ta_mean - ta_std, ta_mean + ta_std, alpha=0.25)

ax2.plot(epochs, va_mean, label="Validation Accuracy", linewidth=2)
ax2.fill_between(epochs, va_mean - va_std, va_mean + va_std, alpha=0.25)

ax2.set_xlabel("Epochs", fontweight="bold")
ax2.set_ylabel("Accuracy", fontweight="bold")
ax2.set_title("Training and Validation Accuracy (Averaged Across Seeds)", fontweight="bold")
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig("training_curves_from_logs.png", dpi=650, bbox_inches="tight")
plt.show()

print("\n✅ Saved: training_curves_from_logs.png")
print(f"📊 Plotted {min_epochs} epochs from seeds: {', '.join(seed_ids)}")

# Parse Log Files and Create Training Curves

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc, precision_recall_curve
import torch.nn.functional as F

# Evaluate the teacher model on the test set: classification report,
# row-normalized confusion matrix, per-class ROC and precision-recall curves.

# Ensure the teacher model is in eval mode
teacher_model.eval()

# Collect true labels, predicted labels and per-class scores over all batches
all_labels = []
all_preds = []
all_probs = []

with torch.no_grad():
    for images, metas, labels in test_loader:
        images, metas, labels = images.to(device), metas.to(device), labels.to(device)

        outputs = teacher_model(images, metas)
        # Softmax is applied only to 2-D outputs; anything else is used as-is
        # (presumably already probabilities — TODO confirm).
        probs = F.softmax(outputs, dim=1) if outputs.dim() == 2 else outputs
        preds = probs.argmax(dim=1)

        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(preds.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())

all_labels = np.array(all_labels)
all_preds = np.array(all_preds)
all_probs = np.array(all_probs)

# Compute classification report.
# `labels` is passed explicitly so the report always covers every entry of
# class_names; without it, a class absent from both the test labels and the
# predictions makes scikit-learn reject target_names for a length mismatch.
class_report = classification_report(
    all_labels,
    all_preds,
    labels=np.arange(len(class_names)),
    target_names=class_names,
    digits=4,
)

# Compute normalized confusion matrix (each row sums over the true class).
# Explicit `labels` keeps the matrix aligned with the class_names tick labels.
conf_matrix = confusion_matrix(
    all_labels,
    all_preds,
    labels=np.arange(len(class_names)),
    normalize="true",
)

# Display classification report
print("\nClassification Report:\n")
print(class_report)

# Display confusion matrix (Blues colormap)
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, cmap="Blues", fmt=".2f", xticklabels=class_names, yticklabels=class_names, cbar=True)
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Normalized Confusion Matrix")
plt.show()

# Compute and plot ROC-AUC curve for each class (one-vs-rest: class i against all others)
plt.figure(figsize=(10, 6))

for i, class_name in enumerate(class_names):
    fpr, tpr, _ = roc_curve((all_labels == i).astype(int), all_probs[:, i])
    roc_auc = auc(fpr, tpr)
    plt.plot(fpr, tpr, label=f"{class_name} (AUC = {roc_auc:.2f})")

plt.plot([0, 1], [0, 1], "k--")  # Diagonal line for reference
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC-AUC Curve")
plt.legend(loc="lower right")
plt.show()

# Compute and plot Precision-Recall Curve (one-vs-rest per class)
plt.figure(figsize=(10, 6))

for i, class_name in enumerate(class_names):
    precision, recall, _ = precision_recall_curve((all_labels == i).astype(int), all_probs[:, i])
    plt.plot(recall, precision, label=f"{class_name}")

plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Precision-Recall Curve")
plt.legend()
plt.show()


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc, precision_recall_curve
import torch.nn.functional as F

# Evaluate the saved best student checkpoint on the test set: classification
# report, row-normalized confusion matrix, per-class ROC and precision-recall curves.

# Load the best model
best_model = EarlyFusionWithGCN(input_dim_meta, num_classes).to(device)
# map_location lets a checkpoint saved on one device (e.g. GPU) load on the current one.
best_model.load_state_dict(torch.load("dermpGCN.pth", map_location=device))
best_model.eval()

all_labels = []
all_preds = []
all_probs = []

with torch.no_grad():
    for images, metas, labels in test_loader:
        images, metas, labels = images.to(device), metas.to(device), labels.to(device)
        # One index per sample in the batch, passed as the model's third argument.
        batch_indices = torch.arange(metas.size(0)).to(device).long()
        # Run the loaded checkpoint. The earlier code called `student_model`
        # here, i.e. whichever model the training loop left behind, so the
        # checkpoint loaded above was never actually evaluated.
        outputs = best_model(images, metas, batch_indices)
        probs = F.softmax(outputs, dim=1)
        preds = probs.argmax(dim=1)

        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(preds.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())

all_labels = np.array(all_labels)
all_preds = np.array(all_preds)
all_probs = np.array(all_probs)

# Compute classification report.
# `labels` is passed explicitly so the report always covers every entry of
# class_names; without it, a class absent from both the test labels and the
# predictions makes scikit-learn reject target_names for a length mismatch.
class_report = classification_report(
    all_labels,
    all_preds,
    labels=np.arange(len(class_names)),
    target_names=class_names,
    digits=4,
)

# Compute normalized confusion matrix (each row sums over the true class).
# Explicit `labels` keeps the matrix aligned with the class_names tick labels.
conf_matrix = confusion_matrix(
    all_labels,
    all_preds,
    labels=np.arange(len(class_names)),
    normalize="true",
)

# Display classification report
print("\nClassification Report:\n")
print(class_report)

# Display confusion matrix (black and white)
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, cmap="gray", fmt=".2f", xticklabels=class_names, yticklabels=class_names, cbar=True)
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Normalized Confusion Matrix")
plt.show()

# Compute and plot ROC-AUC curve for each class (one-vs-rest: class i against all others)
plt.figure(figsize=(10, 6))

for i, class_name in enumerate(class_names):
    fpr, tpr, _ = roc_curve((all_labels == i).astype(int), all_probs[:, i])
    roc_auc = auc(fpr, tpr)
    plt.plot(fpr, tpr, label=f"{class_name} (AUC = {roc_auc:.2f})")

plt.plot([0, 1], [0, 1], "k--")  # Diagonal line for reference
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC-AUC Curve")
plt.legend(loc="lower right")
plt.show()

# Optional: Compute and plot Precision-Recall Curve (one-vs-rest per class)
plt.figure(figsize=(10, 6))

for i, class_name in enumerate(class_names):
    precision, recall, _ = precision_recall_curve((all_labels == i).astype(int), all_probs[:, i])
    plt.plot(recall, precision, label=f"{class_name}")

plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Precision-Recall Curve")
plt.legend()
plt.show()