In [None]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from PIL import Image
from sklearn.metrics import classification_report

import random
import numpy as np
import torch

# Set random seeds for reproducibility: Python, NumPy, torch (CPU) and every
# CUDA device. torch.cuda.manual_seed_all is a no-op when CUDA is unavailable.
seed = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)

# Disable cuDNN autotuning and request deterministic kernels, trading some
# speed for run-to-run consistency.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Dataset locations.
IMAGE_DIR = 'D:\\PAD-UFES\\images'
METADATA_PATH = 'D:\\PAD-UFES\\metadata.csv'

# Raw metadata table; each row carries an `img_id` resolved to a file later.
metadata = pd.read_csv(METADATA_PATH)

def preprocess_metadata(metadata):
    """Convert raw PAD-UFES metadata into an all-numeric feature table.

    The result has a fresh 0..n-1 index, is row-aligned with *metadata*, and
    its columns come in this order:

    * ``age``, ``fitspatrick``, ``diameter_1``, ``diameter_2``: coerced to
      numbers, NaNs replaced by the column mean, then standardised.
    * One-hot columns for the categorical fields (``pd.get_dummies``).
    * Yes/no fields mapped to 1 ('true'), 0 ('false') or -1 ('unk'); any
      other value becomes NaN.

    NOTE(review): ``diagnostic`` (the prediction target) is one-hot encoded
    here as well, so callers must drop the ``diagnostic_*`` columns before
    training. The scaler and the fill means are fitted on every row, i.e.
    before any train/val/test split.

    Args:
        metadata: Raw metadata DataFrame. It is not modified.

    Returns:
        The processed feature DataFrame.
    """
    boolean_cols = [
        'smoke', 'drink', 'pesticide', 'skin_cancer_history', 'cancer_history',
        'has_piped_water', 'has_sewage_system', 'itch', 'grew', 'hurt',
        'changed', 'bleed', 'elevation', 'biopsed',
    ]
    categorical_cols = [
        'background_father', 'background_mother', 'gender', 'region', 'diagnostic',
    ]
    numerical_cols = ['age', 'fitspatrick', 'diameter_1', 'diameter_2']
    boolean_mapping = {'true': 1, 'false': 0, 'unk': -1}

    # Every missing cell becomes the literal 'UNK'. Numeric columns are
    # coerced back to numbers (and 'UNK' back to NaN) further down.
    df = metadata.fillna('UNK')

    # Yes/no flags: lowercase text first, then map to 1 / 0 / -1.
    for col in boolean_cols:
        df[col] = df[col].astype(str).str.lower().map(boolean_mapping)

    # Categoricals: force to text, then one-hot encode.
    one_hot = pd.get_dummies(df[categorical_cols].astype(str))

    # Numerics: coerce, mean-impute, standardise.
    numeric = pd.DataFrame(
        {col: pd.to_numeric(df[col], errors='coerce') for col in numerical_cols}
    )
    numeric = numeric.fillna(numeric.mean())
    scaled = pd.DataFrame(StandardScaler().fit_transform(numeric), columns=numerical_cols)

    return pd.concat(
        [part.reset_index(drop=True) for part in (scaled, one_hot, df[boolean_cols])],
        axis=1,
    )

# Preprocess metadata
# Build the numeric feature table from the raw rows. Its rows stay in the same
# order as `metadata`; the missing-image filtering further down relies on
# that alignment when it selects rows by index.
metadata_processed = preprocess_metadata(metadata)

def get_image_paths(metadata, image_dir):
    """Attach an ``ImagePath`` column that resolves each ``img_id`` to a file.

    IDs that already end in ``.jpg``, ``.jpeg`` or ``.png`` (compared
    case-insensitively, so ``X.PNG`` is not re-suffixed) are joined to
    *image_dir* directly. Bare IDs are tried with ``.jpg``, ``.jpeg`` and
    ``.png``, in that order. When no file exists, a message is printed and
    ``None`` is stored.

    Args:
        metadata: DataFrame with an ``img_id`` column. It is modified in place.
        image_dir: Directory that contains the image files.

    Returns:
        The same *metadata* object, now carrying the ``ImagePath`` column.
    """
    extensions = ('.jpg', '.jpeg', '.png')

    def resolve(img_id):
        filename = str(img_id)
        if filename.lower().endswith(extensions):
            filepath = os.path.join(image_dir, filename)
            if os.path.isfile(filepath):
                return filepath
            print(f"Image file not found: {filepath}")
            return None
        for ext in extensions:
            filepath = os.path.join(image_dir, filename + ext)
            if os.path.isfile(filepath):
                return filepath
        print(f"Image file not found for ID: {filename}")
        return None

    # Iterate the column directly. iterrows() builds a Series per row, which
    # can upcast the id (e.g. int -> float) when other columns are numeric.
    metadata['ImagePath'] = [resolve(img_id) for img_id in metadata['img_id']]
    return metadata

metadata = get_image_paths(metadata, IMAGE_DIR)

# Remove entries with missing images. `metadata_processed` was built
# row-for-row from `metadata` (both indexed 0..n-1 at this point), so
# selecting the surviving index keeps features and rows aligned.
metadata = metadata[metadata['ImagePath'].notnull()]
metadata_processed = metadata_processed.loc[metadata.index].reset_index(drop=True)
metadata = metadata.reset_index(drop=True)

# Drop the one-hot encoded label so the target cannot leak into the features.
# Selecting by prefix instead of a fixed list of six classes means an extra
# class cannot slip through, and a missing one cannot raise a KeyError.
diagnostic_cols = [col for col in metadata_processed.columns if col.startswith('diagnostic_')]
metadata_processed = metadata_processed.drop(columns=diagnostic_cols)

# Encode the diagnosis strings as integers 0..num_classes-1.
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(metadata['diagnostic'])
num_classes = len(label_encoder.classes_)

# Split data into features and labels
X_meta = metadata_processed.reset_index(drop=True)
X_img_paths = metadata['ImagePath'].reset_index(drop=True)
y = pd.Series(y_encoded)

# Stratified 80/10/10 split: hold out 20%, then halve it into val and test.
X_train_meta, X_temp_meta, X_train_img_paths, X_temp_img_paths, y_train, y_temp = train_test_split(
    X_meta,
    X_img_paths,
    y,
    test_size=0.2,
    random_state=42,
    stratify=y
)

X_val_meta, X_test_meta, X_val_img_paths, X_test_img_paths, y_val, y_test = train_test_split(
    X_temp_meta,
    X_temp_img_paths,
    y_temp,
    test_size=0.5,
    random_state=42,
    stratify=y_temp
)

# Load augmented metadata + image paths
aug_meta_df   = pd.read_csv("D:/PAD-UFES/augmented_metadata.csv")
aug_labels_df = pd.read_csv("D:/PAD-UFES/augmented_labels.csv")

# The augmented rows are stacked under the real training rows, so both files
# must agree with each other and with the training feature columns. Otherwise
# an extra column would change the metadata vector width, a missing one would
# silently become NaN, and a row-count mismatch would misalign features and
# labels.
if len(aug_meta_df) != len(aug_labels_df):
    raise ValueError(
        f"Augmented metadata has {len(aug_meta_df)} rows but augmented labels have {len(aug_labels_df)}"
    )
mismatched_cols = set(X_train_meta.columns).symmetric_difference(aug_meta_df.columns)
if mismatched_cols:
    raise ValueError(
        f"Augmented metadata columns differ from the training features: {sorted(mismatched_cols)}"
    )
aug_meta_df = aug_meta_df[X_train_meta.columns]

# Combine augmented samples with training set.
# NOTE(review): assumes aug_labels_df['Label'] already uses label_encoder's
# integer coding; confirm against the script that produced the file.
X_train_meta_final = pd.concat([X_train_meta, aug_meta_df], ignore_index=True)
X_train_img_paths_final = pd.concat([X_train_img_paths.reset_index(drop=True),
                                     aug_labels_df['ImagePath']], ignore_index=True)
y_train_final = pd.concat([y_train.reset_index(drop=True),
                           aug_labels_df['Label']], ignore_index=True)

class PADUFESDataset(Dataset):
    """Map-style dataset yielding ``(image, metadata_vector, label)`` triples.

    Args:
        img_paths: pandas Series of image file paths. Its index is discarded.
        meta_data: DataFrame of numeric metadata features, row-aligned with
            *img_paths*.
        labels: Integer class labels, row-aligned with *img_paths*.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, img_paths, meta_data, labels, transform=None):
        self.img_paths = img_paths.reset_index(drop=True)
        self.meta_data = meta_data.reset_index(drop=True)
        self.labels = pd.Series(labels).reset_index(drop=True)
        self.transform = transform
        # Convert the metadata to float32 once. Pulling a row out of a
        # mixed-dtype (float/bool/int) DataFrame with .iloc builds an object
        # Series on every __getitem__ call.
        self._meta_array = self.meta_data.to_numpy(dtype=np.float32)

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        # Open the image in a context manager so its file handle is released
        # as soon as the pixels have been converted.
        with Image.open(self.img_paths[idx]) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        meta = torch.tensor(self._meta_array[idx])  # copies, so the cache is never aliased
        label = torch.tensor(self.labels.iloc[idx], dtype=torch.long)
        return image, meta, label

# Standard ImageNet channel mean / std, shared by both pipelines.
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]

# Training: fixed resize, then random flip, rotation (up to +/-70 degrees) and
# colour jitter, followed by tensor conversion and normalisation.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(70),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(_imagenet_mean, _imagenet_std),
])

# Validation / test: deterministic resize + normalisation only.
val_test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_imagenet_mean, _imagenet_std),
])

# Create datasets. The training split includes the augmented samples.
train_dataset = PADUFESDataset(X_train_img_paths_final, X_train_meta_final, y_train_final, transform=train_transform)
val_dataset, test_dataset = (
    PADUFESDataset(paths, meta, labels, transform=val_test_transform)
    for paths, meta, labels in (
        (X_val_img_paths, X_val_meta, y_val),
        (X_test_img_paths, X_test_meta, y_test),
    )
)

_batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=_batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=_batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=_batch_size, shuffle=False)

print(f"\n✅ Loaded Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

In [None]:
import torch
import torch.nn as nn
import timm  
import torch.nn.functional as F

class EarlyFusionModelMobileViT(nn.Module):
    """Early-fusion classifier: metadata becomes a 4th input channel for MobileViT-S.

    The metadata vector is embedded into a 64x64 map, resized to the image's
    spatial size and concatenated to the RGB channels before the stem.

    Args:
        input_dim_meta: Number of metadata features per sample.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim_meta, num_classes):
        super().__init__()

        # Project metadata to a single 64x64 map.
        self.meta_embed = nn.Sequential(
            nn.Linear(input_dim_meta, 64 * 64),
            nn.ReLU(),
            nn.BatchNorm1d(64 * 64),
            nn.Dropout(0.3)
        )

        # Load MobileViT model
        self.mobilevit = timm.create_model("mobilevit_s.cvnets_in1k", pretrained=True, num_classes=num_classes)

        # Widen the stem conv from 3 to 4 input channels, keeping its other
        # hyper-parameters. Conv2d's `bias` is a bool flag: passing the bias
        # Parameter itself only worked because this stem happens to have none.
        first_conv = self.mobilevit.stem.conv
        new_conv = nn.Conv2d(4, first_conv.out_channels,
                             kernel_size=first_conv.kernel_size,
                             stride=first_conv.stride,
                             padding=first_conv.padding,
                             bias=first_conv.bias is not None)
        with torch.no_grad():
            new_conv.weight[:, :3] = first_conv.weight
            # Small init for the metadata channel so it does not dominate the pretrained RGB filters.
            new_conv.weight[:, 3:] = first_conv.weight.mean(dim=1, keepdim=True) * 0.1
            if first_conv.bias is not None:
                new_conv.bias.copy_(first_conv.bias)
        self.mobilevit.stem.conv = new_conv

    def forward(self, img, meta):
        """Return logits for images ``img`` (B, 3, H, W) and metadata ``meta`` (B, input_dim_meta)."""
        batch_size = img.shape[0]
        meta_map = self.meta_embed(meta).view(batch_size, 1, 64, 64)

        # Resize to the image's own spatial size (224x224 in this pipeline).
        meta_map = F.interpolate(meta_map, size=img.shape[-2:], mode='bilinear', align_corners=False)

        # Early fusion
        combined_input = torch.cat([img, meta_map], dim=1)
        return self.mobilevit(combined_input)



class EarlyFusionModelPvtV2(nn.Module):
    """Early-fusion classifier: metadata becomes a 4th input channel for PVT-v2-b1.

    Args:
        input_dim_meta: Number of metadata features per sample.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim_meta, num_classes):
        super().__init__()

        # Project metadata to a single 56x56 map.
        self.meta_embed = nn.Sequential(
            nn.Linear(input_dim_meta, 56 * 56),
            nn.ReLU(),
            nn.BatchNorm1d(56 * 56),
            nn.Dropout(0.3)
        )

        # Load PVT v2 model
        self.pvt = timm.create_model("pvt_v2_b1", pretrained=True, num_classes=num_classes)

        # Widen the patch-embedding projection to 4 input channels (RGB + metadata).
        first_conv = self.pvt.patch_embed.proj
        new_conv = nn.Conv2d(4, first_conv.out_channels,
                             kernel_size=first_conv.kernel_size,
                             stride=first_conv.stride,
                             padding=first_conv.padding,
                             bias=first_conv.bias is not None)
        with torch.no_grad():
            new_conv.weight[:, :3] = first_conv.weight
            new_conv.weight[:, 3:] = first_conv.weight.mean(dim=1, keepdim=True) * 0.1
            # Keep the pretrained bias. Without this copy the new layer keeps its random init.
            if first_conv.bias is not None:
                new_conv.bias.copy_(first_conv.bias)
        self.pvt.patch_embed.proj = new_conv

    def forward(self, img, meta):
        """Return logits for images ``img`` (B, 3, H, W) and metadata ``meta`` (B, input_dim_meta)."""
        batch_size = img.shape[0]
        meta_map = self.meta_embed(meta).view(batch_size, 1, 56, 56)

        # Resize to the image's own spatial size (224x224 in this pipeline).
        meta_map = F.interpolate(meta_map, size=img.shape[-2:], mode='bilinear', align_corners=False)

        # Early fusion
        combined_input = torch.cat([img, meta_map], dim=1)
        return self.pvt(combined_input)

# Model sizes derived from the prepared data (first cell).
input_dim_meta = X_train_meta.shape[1]
# Take the class count from the fitted label encoder instead of hard-coding 6,
# so the models always match the labels they are trained on.
num_classes = len(label_encoder.classes_)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EarlyFusionMobileNetV3(nn.Module):
    """Early-fusion classifier: metadata becomes a 4th input channel for MobileNetV3.

    Args:
        input_dim_meta: Number of metadata features per sample.
        num_classes: Number of output logits.
        model_size: ``'large'`` or ``'small'``. Selects timm's width-1.0
            ``mobilenetv3_<size>_100`` variant.
    """

    def __init__(self, input_dim_meta, num_classes, model_size='large'):
        super().__init__()

        # Project metadata to a single 56x56 map.
        self.meta_embed = nn.Sequential(
            nn.Linear(input_dim_meta, 56 * 56),
            nn.ReLU(),
            nn.BatchNorm1d(56 * 56),
            nn.Dropout(0.3)
        )

        # Honour `model_size`. Previously the large model was always loaded.
        model_name = f'mobilenetv3_{model_size}_100'
        self.mobilenet = timm.create_model(model_name, pretrained=True, num_classes=num_classes)

        # Widen the stem conv to 4 input channels (RGB + metadata).
        first_conv = self.mobilenet.conv_stem
        new_conv = nn.Conv2d(
            4,
            first_conv.out_channels,
            kernel_size=first_conv.kernel_size,
            stride=first_conv.stride,
            padding=first_conv.padding,
            bias=first_conv.bias is not None
        )
        with torch.no_grad():
            # Copy weights for RGB channels
            new_conv.weight[:, :3] = first_conv.weight
            # Metadata channel: mean of the RGB filters, scaled down so it does not dominate.
            new_conv.weight[:, 3:] = first_conv.weight.mean(dim=1, keepdim=True) * 0.1
            if first_conv.bias is not None:
                new_conv.bias.copy_(first_conv.bias)
        self.mobilenet.conv_stem = new_conv

        # Replace the head with dropout + linear.
        in_features = self.mobilenet.classifier.in_features
        self.mobilenet.classifier = nn.Sequential(
            nn.Dropout(p=0.3),
            nn.Linear(in_features, num_classes)
        )

    def forward(self, img, meta):
        """Return logits for images ``img`` (B, 3, H, W) and metadata ``meta`` (B, input_dim_meta)."""
        batch_size = img.shape[0]
        meta_map = self.meta_embed(meta).view(batch_size, 1, 56, 56)

        # Resize to the image's own spatial size (224x224 in this pipeline).
        meta_map = F.interpolate(meta_map, size=img.shape[-2:], mode='bilinear', align_corners=False)

        # Concatenate along channel dimension
        combined_input = torch.cat([img, meta_map], dim=1)
        return self.mobilenet(combined_input)


class EarlyFusionModelCoatNet(nn.Module):
    """Early-fusion classifier: metadata becomes a 4th input channel for CoAtNet-2.

    Args:
        input_dim_meta: Number of metadata features per sample.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim_meta, num_classes):
        super().__init__()

        # Project metadata to a single 56x56 map.
        self.meta_embed = nn.Sequential(
            nn.Linear(input_dim_meta, 56 * 56),
            nn.ReLU(),
            nn.BatchNorm1d(56 * 56),
            nn.Dropout(0.3)
        )

        # Load the CoAtNet model
        self.coatnet = timm.create_model('coatnet_2_rw_224.sw_in12k_ft_in1k', pretrained=True, num_classes=num_classes)

        # Widen the stem's first conv to 4 input channels. Copy its
        # hyper-parameters (and bias, if any) from the pretrained layer instead
        # of hard-coding them, so the pretrained filters keep their geometry.
        # NOTE(review): assumes the stem uses static padding (a plain
        # nn.Conv2d); a dynamic 'same'-padding conv would need the same subclass.
        first_conv = self.coatnet.stem.conv1
        new_conv = nn.Conv2d(4, first_conv.out_channels,
                             kernel_size=first_conv.kernel_size,
                             stride=first_conv.stride,
                             padding=first_conv.padding,
                             bias=first_conv.bias is not None)
        with torch.no_grad():
            new_conv.weight[:, :3] = first_conv.weight
            # Initialize the new channel with smaller weights
            new_conv.weight[:, 3:] = first_conv.weight.mean(dim=1, keepdim=True) * 0.1
            if first_conv.bias is not None:
                new_conv.bias.copy_(first_conv.bias)
        self.coatnet.stem.conv1 = new_conv

    def forward(self, img, meta):
        """Return logits for images ``img`` (B, 3, H, W) and metadata ``meta`` (B, input_dim_meta)."""
        batch_size = img.shape[0]
        meta_map = self.meta_embed(meta).view(batch_size, 1, 56, 56)

        # Resize to the image's own spatial size (224x224 in this pipeline).
        meta_map = F.interpolate(meta_map, size=img.shape[-2:], mode='bilinear', align_corners=False)

        # Early fusion: image channels + metadata channel
        combined_input = torch.cat([img, meta_map], dim=1)
        return self.coatnet(combined_input)
       

class EarlyFusionModelXception(nn.Module):
    """Early-fusion classifier: metadata becomes a 4th input channel for Xception.

    Args:
        input_dim_meta: Number of metadata features per sample.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim_meta, num_classes):
        super().__init__()

        # Project metadata to a single 56x56 map.
        self.meta_embed = nn.Sequential(
            nn.Linear(input_dim_meta, 56 * 56),
            nn.ReLU(),
            nn.BatchNorm1d(56 * 56),
            nn.Dropout(0.3)
        )

        # Modified Xception
        self.xception = timm.create_model('xception', pretrained=True)

        # Widen conv1 to 4 input channels. Copy padding/stride/kernel from the
        # pretrained layer: the previous hard-coded padding=1 did not
        # necessarily match the padding the pretrained filters were trained with.
        first_conv = self.xception.conv1
        new_conv = nn.Conv2d(4, first_conv.out_channels,
                             kernel_size=first_conv.kernel_size,
                             stride=first_conv.stride,
                             padding=first_conv.padding,
                             bias=first_conv.bias is not None)
        with torch.no_grad():
            new_conv.weight[:, :3] = first_conv.weight
            # Initialize the new channel with smaller weights to prevent dominating
            new_conv.weight[:, 3:] = first_conv.weight.mean(dim=1, keepdim=True) * 0.1
            if first_conv.bias is not None:
                new_conv.bias.copy_(first_conv.bias)
        self.xception.conv1 = new_conv

        # Replace the final classification layer with dropout + linear.
        in_features = self.xception.fc.in_features
        self.xception.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(in_features, num_classes)
        )

    def forward(self, img, meta):
        """Return logits for images ``img`` (B, 3, H, W) and metadata ``meta`` (B, input_dim_meta)."""
        batch_size = img.shape[0]
        meta_map = self.meta_embed(meta).view(batch_size, 1, 56, 56)

        # Resize to the image's own spatial size (224x224 in this pipeline).
        meta_map = F.interpolate(meta_map, size=img.shape[-2:], mode='bilinear', align_corners=False)

        # Early fusion
        combined_input = torch.cat([img, meta_map], dim=1)
        return self.xception(combined_input)
        

class EarlyFusionModelVgg16(nn.Module):
    """Early-fusion classifier: metadata becomes a 4th input channel for timm's VGG16.

    Args:
        input_dim_meta: Number of metadata features per sample.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim_meta, num_classes):
        super().__init__()

        # Project metadata to a single 56x56 map.
        self.meta_embed = nn.Sequential(
            nn.Linear(input_dim_meta, 56 * 56),
            nn.ReLU(),
            nn.BatchNorm1d(56 * 56),
            nn.Dropout(0.3)
        )

        # timm builds the classification head for `num_classes` itself.
        self.vgg16 = timm.create_model('vgg16', pretrained=True, num_classes=num_classes)

        # Widen the first conv to 4 input channels. Keep the layer's bias and
        # its pretrained values instead of hard-coding bias=False.
        first_conv = self.vgg16.features[0]
        new_conv = nn.Conv2d(4, first_conv.out_channels,
                             kernel_size=first_conv.kernel_size,
                             stride=first_conv.stride,
                             padding=first_conv.padding,
                             bias=first_conv.bias is not None)
        with torch.no_grad():
            new_conv.weight[:, :3] = first_conv.weight
            # Initialize the new channel with smaller weights
            new_conv.weight[:, 3:] = first_conv.weight.mean(dim=1, keepdim=True) * 0.1
            if first_conv.bias is not None:
                new_conv.bias.copy_(first_conv.bias)
        self.vgg16.features[0] = new_conv

        # No `classifier` attribute is attached. timm's VGG forward runs
        # `pre_logits` and `head` (see `get_classifier()`) and never reads a
        # `classifier` attribute, so such a module would only add ~33.6M unused
        # parameters. Checkpoints saved with that attribute need
        # `load_state_dict(..., strict=False)`.

    def forward(self, img, meta):
        """Return logits for images ``img`` (B, 3, H, W) and metadata ``meta`` (B, input_dim_meta)."""
        batch_size = img.shape[0]
        meta_map = self.meta_embed(meta).view(batch_size, 1, 56, 56)

        # Resize to the image's own spatial size (224x224 in this pipeline).
        meta_map = F.interpolate(meta_map, size=img.shape[-2:], mode='bilinear', align_corners=False)

        # Early fusion
        combined_input = torch.cat([img, meta_map], dim=1)
        return self.vgg16(combined_input)


from transformers import AutoModelForImageClassification, AutoConfig
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForImageClassification, AutoImageProcessor

class EarlyFusionModelSwinTiny(nn.Module):
    """Early-fusion classifier: metadata becomes a 4th input channel for Swin-Tiny (Hugging Face).

    Args:
        input_dim_meta: Number of metadata features per sample.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim_meta, num_classes):
        super().__init__()

        # Project metadata to a single 56x56 map.
        self.meta_embed = nn.Sequential(
            nn.Linear(input_dim_meta, 56 * 56),
            nn.ReLU(),
            nn.BatchNorm1d(56 * 56),
            nn.Dropout(0.3)
        )

        # Load Swin Transformer from Hugging Face
        self.swin_transformer = AutoModelForImageClassification.from_pretrained(
            "microsoft/swin-tiny-patch4-window7-224",
            ignore_mismatched_sizes=True
        )

        # Replace the pretrained classifier head with one for `num_classes`.
        in_features = self.swin_transformer.classifier.in_features
        self.swin_transformer.classifier = nn.Linear(in_features, num_classes)

        # Widen the patch-embedding projection to 4 input channels.
        patch_embeddings = self.swin_transformer.swin.embeddings.patch_embeddings
        old_proj = patch_embeddings.projection
        new_proj = nn.Conv2d(
            in_channels=4,
            out_channels=old_proj.out_channels,
            kernel_size=old_proj.kernel_size,
            stride=old_proj.stride,
            padding=old_proj.padding,
            bias=old_proj.bias is not None,
        )
        with torch.no_grad():
            new_proj.weight[:, :3] = old_proj.weight
            # Scaled-down init for the metadata channel, as in the other fusion models here.
            new_proj.weight[:, 3:] = old_proj.weight.mean(dim=1, keepdim=True) * 0.1
            # Keep the pretrained bias rather than the new layer's random init.
            if old_proj.bias is not None:
                new_proj.bias.copy_(old_proj.bias)
        patch_embeddings.projection = new_proj

        # Hugging Face's Swin patch embedding compares the input channel count
        # with `num_channels` and raises on a mismatch, so keep it in sync.
        patch_embeddings.num_channels = 4
        self.swin_transformer.config.num_channels = 4

    def forward(self, img, meta):
        """Return logits for images ``img`` (B, 3, H, W) and metadata ``meta`` (B, input_dim_meta)."""
        batch_size = img.size(0)
        meta_map = self.meta_embed(meta).view(batch_size, 1, 56, 56)

        # Resize to the image's own spatial size (224x224 in this pipeline).
        meta_map = F.interpolate(meta_map, size=img.shape[-2:], mode="bilinear", align_corners=False)

        # Concatenate the image and metadata
        combined_input = torch.cat([img, meta_map], dim=1)

        outputs = self.swin_transformer(pixel_values=combined_input)
        return outputs.logits


class EarlyFusionModelEfficientViT(nn.Module):
    """Early-fusion classifier: metadata becomes a 4th input channel for EfficientViT-B0.

    Args:
        input_dim_meta: Number of metadata features per sample.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim_meta, num_classes):
        super().__init__()

        # Project metadata to a single 56x56 map.
        self.meta_embed = nn.Sequential(
            nn.Linear(input_dim_meta, 56 * 56),
            nn.ReLU(),
            nn.BatchNorm1d(56 * 56),
            nn.Dropout(0.3)
        )

        # Use the caller's `num_classes` instead of a hard-coded 6.
        self.efficientvit = timm.create_model("efficientvit_b0.r224_in1k", pretrained=True, num_classes=num_classes)

        # Widen the stem conv to 4 input channels, copying its hyper-parameters
        # from the pretrained layer instead of hard-coding them.
        first_conv = self.efficientvit.stem.in_conv.conv
        new_conv = nn.Conv2d(4, first_conv.out_channels,
                             kernel_size=first_conv.kernel_size,
                             stride=first_conv.stride,
                             padding=first_conv.padding,
                             bias=first_conv.bias is not None)
        with torch.no_grad():
            new_conv.weight[:, :3] = first_conv.weight
            # Initialize the new channel with smaller weights to prevent dominating
            new_conv.weight[:, 3:] = first_conv.weight.mean(dim=1, keepdim=True) * 0.1
            if first_conv.bias is not None:
                new_conv.bias.copy_(first_conv.bias)
        self.efficientvit.stem.in_conv.conv = new_conv

    def forward(self, img, meta):
        """Return logits for images ``img`` (B, 3, H, W) and metadata ``meta`` (B, input_dim_meta)."""
        batch_size = img.shape[0]
        meta_map = self.meta_embed(meta).view(batch_size, 1, 56, 56)

        # Resize to the image's own spatial size (224x224 in this pipeline).
        meta_map = F.interpolate(meta_map, size=img.shape[-2:], mode='bilinear', align_corners=False)

        # Early fusion
        combined_input = torch.cat([img, meta_map], dim=1)
        return self.efficientvit(combined_input)
        


class EarlyFusionModelDenseNet(nn.Module):
    """Early-fusion classifier: metadata becomes a 4th input channel for DenseNet-121.

    Args:
        input_dim_meta: Number of metadata features per sample.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim_meta, num_classes):
        super().__init__()

        # Project metadata to a single 56x56 map.
        self.meta_embed = nn.Sequential(
            nn.Linear(input_dim_meta, 56 * 56),
            nn.ReLU(),
            nn.BatchNorm1d(56 * 56),
            nn.Dropout(0.3)
        )

        # Modified DenseNet121
        self.densenet = timm.create_model('densenet121', pretrained=True)

        # Widen conv0 to 4 input channels, copying its hyper-parameters from
        # the pretrained layer instead of hard-coding them.
        first_conv = self.densenet.features.conv0
        new_conv = nn.Conv2d(4, first_conv.out_channels,
                             kernel_size=first_conv.kernel_size,
                             stride=first_conv.stride,
                             padding=first_conv.padding,
                             bias=first_conv.bias is not None)
        with torch.no_grad():
            new_conv.weight[:, :3] = first_conv.weight
            new_conv.weight[:, 3:] = first_conv.weight.mean(dim=1, keepdim=True) * 0.1
            if first_conv.bias is not None:
                new_conv.bias.copy_(first_conv.bias)
        self.densenet.features.conv0 = new_conv

        # Replace the final classification layer with dropout + linear.
        in_features = self.densenet.classifier.in_features
        self.densenet.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(in_features, num_classes)
        )

    def forward(self, img, meta):
        """Return logits for images ``img`` (B, 3, H, W) and metadata ``meta`` (B, input_dim_meta)."""
        batch_size = img.shape[0]
        meta_map = self.meta_embed(meta).view(batch_size, 1, 56, 56)

        # Resize to the image's own spatial size (224x224 in this pipeline).
        meta_map = F.interpolate(meta_map, size=img.shape[-2:], mode='bilinear', align_corners=False)

        # Early fusion
        combined_input = torch.cat([img, meta_map], dim=1)
        return self.densenet(combined_input)
        

class EarlyFusionModelInceptionResnetv2(nn.Module):
    """Early-fusion classifier: metadata becomes a 4th input channel for Inception-ResNet-v2.

    Args:
        input_dim_meta: Number of metadata features per sample.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim_meta, num_classes):
        super().__init__()

        # Project metadata to a single 56x56 map.
        self.meta_embed = nn.Sequential(
            nn.Linear(input_dim_meta, 56 * 56),
            nn.ReLU(),
            nn.BatchNorm1d(56 * 56),
            nn.Dropout(0.3)
        )

        # Load the InceptionResNetV2 model from timm
        self.inception_resnet = timm.create_model('inception_resnet_v2', pretrained=True)

        # Widen conv2d_1a's conv to 4 input channels, copying its
        # hyper-parameters from the pretrained layer instead of hard-coding them.
        first_conv = self.inception_resnet.conv2d_1a.conv
        new_conv = nn.Conv2d(4, first_conv.out_channels,
                             kernel_size=first_conv.kernel_size,
                             stride=first_conv.stride,
                             padding=first_conv.padding,
                             bias=first_conv.bias is not None)
        with torch.no_grad():
            new_conv.weight[:, :3] = first_conv.weight
            # Metadata channel: mean of the RGB filters, scaled down by 0.1.
            new_conv.weight[:, 3:] = first_conv.weight.mean(dim=1, keepdim=True) * 0.1
            if first_conv.bias is not None:
                new_conv.bias.copy_(first_conv.bias)
        self.inception_resnet.conv2d_1a.conv = new_conv

        # Replace the final classification head with dropout + linear.
        in_features = self.inception_resnet.classif.in_features
        self.inception_resnet.classif = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(in_features, num_classes)
        )

    def forward(self, img, meta):
        """Return logits for images ``img`` (B, 3, H, W) and metadata ``meta`` (B, input_dim_meta)."""
        batch_size = img.shape[0]
        meta_map = self.meta_embed(meta).view(batch_size, 1, 56, 56)

        # Resize to the image's own spatial size (224x224 in this pipeline).
        meta_map = F.interpolate(meta_map, size=img.shape[-2:], mode='bilinear', align_corners=False)

        # Concatenate image and metadata along the channel dimension
        combined_input = torch.cat([img, meta_map], dim=1)
        return self.inception_resnet(combined_input)
        

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_cluster import knn_graph
import timm


class EarlyFusionWithDynamicGCN(nn.Module):
    """Early fusion of GCN-refined metadata with a truncated MobileViT-S backbone.

    Pipeline:

    1. Build a kNN graph over the batch's metadata vectors.
    2. Refine the vectors with two GCN layers plus a residual projection.
    3. Map the result to a 56x56 map and append it to the image as a 4th channel.
    4. Run the MobileViT stem and its first four stages.
    5. Apply a 1x1 conv, global average pooling and an MLP classifier.

    Args:
        input_dim_meta: Number of metadata features per sample.
        num_classes: Number of output logits.
        k: Neighbours per node in the kNN graph.
    """

    def __init__(self, input_dim_meta, num_classes, k=8):
        super().__init__()
        self.k = k

        # === GCN with residual ===
        self.gcn1 = GCNConv(input_dim_meta, 64)
        self.gcn2 = GCNConv(64, 32)
        self.res_proj = nn.Linear(64, 32)  # projects gcn1's output for the residual add

        # === Metadata to image (56×56 = 3136) ===
        self.meta_to_image = nn.Sequential(
            nn.Linear(32, 56 * 56),
            nn.ReLU(),
            nn.BatchNorm1d(56 * 56),
            nn.Dropout(0.3)
        )

        # === MobileViT Backbone Modification ===
        self.mobilevit = timm.create_model("mobilevit_s.cvnets_in1k", pretrained=True, num_classes=0)

        # Widen the stem conv to 4 input channels (RGB + metadata map).
        stem_conv = self.mobilevit.stem.conv
        new_conv = nn.Conv2d(4, stem_conv.out_channels,
                             kernel_size=stem_conv.kernel_size,
                             stride=stem_conv.stride,
                             padding=stem_conv.padding,
                             bias=stem_conv.bias is not None)
        with torch.no_grad():
            new_conv.weight[:, :3] = stem_conv.weight
            new_conv.weight[:, 3:] = stem_conv.weight.mean(dim=1, keepdim=True) * 0.1
            if stem_conv.bias is not None:
                new_conv.bias.copy_(stem_conv.bias)
        self.mobilevit.stem.conv = new_conv

        # Keep only the first four stages (post_conv below expects their
        # 128-channel output). final_conv/head are replaced and never called in forward.
        self.mobilevit.stages = nn.Sequential(*list(self.mobilevit.stages.children())[:4])
        self.mobilevit.final_conv = nn.Identity()
        self.mobilevit.head = nn.Identity()

        # === Post Conv ===
        self.post_conv = nn.Sequential(
            nn.Conv2d(128, 160, kernel_size=1, bias=False),
            nn.BatchNorm2d(160),
            nn.ReLU(inplace=True)
        )

        # === Classifier ===
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(160, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, img, meta, batch_idx):
        """Return class logits.

        Args:
            img: Image batch ``(B, 3, H, W)``; the metadata map is appended on dim 1.
            meta: Metadata ``(B, input_dim_meta)``; also the kNN graph's node features.
            batch_idx: Per-sample graph id, forwarded as ``batch`` to ``knn_graph``.
                NOTE(review): neighbours are presumably searched only among
                samples that share an id; confirm against torch_cluster.
        """
        B = meta.size(0)

        # Dynamic KNN graph
        edge_index = knn_graph(meta, k=self.k, batch=batch_idx)

        # GCN + residual -> [B, 32]
        x1 = F.relu(self.gcn1(meta, edge_index))
        x_meta = F.relu(self.gcn2(x1, edge_index) + self.res_proj(x1))

        # Metadata -> single-channel map, resized to the image's spatial size.
        meta_img = self.meta_to_image(x_meta).view(B, 1, 56, 56)
        meta_img = F.interpolate(meta_img, size=img.shape[-2:], mode='bilinear', align_corners=False)

        # Early fusion
        x = torch.cat([img, meta_img], dim=1)  # [B, 4, H, W]

        # CNN
        x_cnn = self.mobilevit.stem(x)
        x_cnn = self.mobilevit.stages(x_cnn)
        x_cnn = self.post_conv(x_cnn)
        x_cnn = self.pool(x_cnn).view(B, -1)  # [B, 160]

        return self.classifier(x_cnn)


In [None]:
import torch
import torch.nn as nn
import time
import numpy as np
from ptflops import get_model_complexity_info

# ============================================================
# SETTINGS
# ============================================================

# Run on the first CUDA device when present, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Length of the metadata vector each model takes, and its number of classes.
meta_dim, num_classes = 59, 6

# (C, H, W) of one dummy input image for ptflops and the latency probes.
image_size = (3, 224, 224)

PAD_ROOT = r"D:\PAD-UFES"
STUDENT_CKPT = r"C:\Users\User\MDY Research\With Augmentation\Concatenation\Adasyn\best_student_model_final2.pth"

# ============================================================
# UTILS
# ============================================================

def count_parameters(model, trainable_only=False):
    """Return the total number of scalar parameters in ``model``.

    When ``trainable_only`` is True, only parameters with ``requires_grad``
    set contribute to the total.
    """
    params = model.parameters()
    if trainable_only:
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)



class FusionWrapper(nn.Module):
    """
    Wraps 2-input (img, meta) models so ptflops sees a single image input.

    On every call a random metadata tensor of shape (batch, meta_dim) is
    created on the image's device with the image's dtype, and passed as the
    second argument to the wrapped model.
    """
    def __init__(self, model, meta_dim):
        super().__init__()
        self.model = model
        self.meta_dim = meta_dim

    def forward(self, x):
        # Allocate directly on x's device (no host->device copy) and match its
        # dtype so e.g. half-precision inputs don't get float32 metadata.
        dummy_meta = torch.randn(x.size(0), self.meta_dim, device=x.device, dtype=x.dtype)
        return self.model(x, dummy_meta)


class GCN2InputWrapper(nn.Module):
    """
    Wraps EarlyFusionWithDynamicGCN so it behaves like forward(img, meta).

    The wrapped model also takes a graph-membership vector for its KNN graph.
    Every sample of a call is assigned to graph 0, so the KNN graph is built
    across the whole batch. A ``torch.arange(B)`` vector would instead put
    each sample in its own single-node graph, leaving the KNN graph without
    edges whenever B > 1. For B == 1 the two choices are identical.
    NOTE(review): confirm this matches the batch vector used in training.
    """
    def __init__(self, gcn_model, meta_dim):
        super().__init__()
        self.gcn_model = gcn_model
        self.meta_dim = meta_dim  # kept for interface parity; unused in forward

    def forward(self, img, meta):
        B = meta.size(0)
        batch_idx = torch.zeros(B, dtype=torch.long, device=meta.device)
        return self.gcn_model(img, meta, batch_idx)


def compute_flops(model, meta_dim):
    """Return the model's ptflops complexity count, in units of 1e9 (G).

    ``get_model_complexity_info`` reports multiply-accumulate operations
    (MACs). The value returned here is therefore GMACs, which this script
    labels "FLOPs"; a strict FLOP count is roughly twice as large.

    The 2-input model is wrapped in ``FusionWrapper`` so ptflops can drive it
    with a single image of shape ``image_size``; the metadata is random. The
    wrapper (and thus ``model``) is moved to ``device`` in place.

    Raises:
        RuntimeError: if ptflops could not produce a count (it returns None).
    """
    wrapper = FusionWrapper(model, meta_dim).to(device)
    with torch.no_grad():
        macs, _ = get_model_complexity_info(
            wrapper,
            image_size,
            as_strings=False,
            print_per_layer_stat=False,
            verbose=False
        )
    if macs is None:
        # ptflops swallows the forward-pass exception and returns None;
        # fail loudly here instead of with a TypeError on the division.
        raise RuntimeError("ptflops failed to compute complexity for this model")
    return float(macs / 1e9)  # G(MACs)


def measure_gpu_latency(model, meta_dim, runs=200, warmup=30, input_size=None):
    """
    Time single-sample (batch size 1) GPU inference using CUDA events.

    Args:
        model: callable as model(img, meta); put in eval mode and moved to
            ``device`` in place.
        meta_dim: width of the random metadata vector.
        runs: number of timed forward passes.
        warmup: number of untimed forward passes run first.
        input_size: (C, H, W) of the dummy image; defaults to ``image_size``.

    Returns:
        (mean_ms, std_ms, fps) over the timed runs (population std), or
        (None, None, None) when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return None, None, None

    if input_size is None:
        input_size = image_size

    torch.cuda.empty_cache()
    # Peak-memory counters are reset here but not read by this function.
    torch.cuda.reset_peak_memory_stats()
    model.eval().to(device)

    dummy_img = torch.randn(1, *input_size, device=device)
    dummy_meta = torch.randn(1, meta_dim, device=device)

    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)
    times = []

    # Inference only: with autograd enabled every forward would also record a
    # graph, inflating latency and memory relative to real inference.
    with torch.no_grad():
        for _ in range(warmup):
            model(dummy_img, dummy_meta)
        torch.cuda.synchronize()

        for _ in range(runs):
            start_evt.record()
            model(dummy_img, dummy_meta)
            end_evt.record()
            torch.cuda.synchronize()
            times.append(start_evt.elapsed_time(end_evt))  # ms

    times = np.array(times)
    mean = float(times.mean())
    std = float(times.std())
    fps = float(1000.0 / mean)
    return mean, std, fps


def measure_cpu_latency(model, meta_dim, runs=100, warmup=20, input_size=None):
    """
    Time single-sample (batch size 1) CPU inference with ``time.perf_counter``.

    Args:
        model: callable as model(img, meta). ``nn.Module.cpu`` works in place,
            so the model is left on the CPU (and in eval mode) afterwards.
        meta_dim: width of the random metadata vector.
        runs: number of timed forward passes.
        warmup: number of untimed forward passes run first.
        input_size: (C, H, W) of the dummy image; defaults to ``image_size``.

    Returns:
        (mean_ms, std_ms, fps) over the timed runs (population std).
    """
    if input_size is None:
        input_size = image_size

    model_cpu = model.cpu()
    model_cpu.eval()

    dummy_img = torch.randn(1, *input_size)
    dummy_meta = torch.randn(1, meta_dim)

    times = []
    # Inference only: with autograd enabled every forward would also record a
    # graph, inflating the measured latency.
    with torch.no_grad():
        for _ in range(warmup):
            model_cpu(dummy_img, dummy_meta)

        for _ in range(runs):
            start = time.perf_counter()
            model_cpu(dummy_img, dummy_meta)
            end = time.perf_counter()
            times.append((end - start) * 1000.0)

    times = np.array(times)
    mean = float(times.mean())
    std = float(times.std())
    fps = float(1000.0 / mean)
    return mean, std, fps


def load_model(model_class, ckpt_path, name, wrap_gcn=False):
    """
    Build a model, load its checkpoint, and return it ready for benchmarking.

    Args:
        model_class: constructor called positionally as
            ``model_class(meta_dim, num_classes)`` with the module-level settings.
        ckpt_path: path to a state dict for that class (loaded onto the CPU).
        name: display name, used only in the log lines.
        wrap_gcn: if True, wrap the model in ``GCN2InputWrapper`` so it can be
            called as model(img, meta).

    Returns:
        The (possibly wrapped) model on ``device``, in eval mode.
    """
    print(f"\n[LOAD] {name} from {ckpt_path}")
    if torch.cuda.is_available():
        # Release cached CUDA blocks and reset peak-memory counters first.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    net = model_class(meta_dim, num_classes)
    net.load_state_dict(torch.load(ckpt_path, map_location="cpu"))

    if wrap_gcn:
        net = GCN2InputWrapper(net, meta_dim)

    net = net.to(device)
    net.eval()
    print("[OK] Model loaded.")
    return net


def benchmark_model(model, name, meta_dim, count_all_params=False):
    """
    Collect size, complexity and latency statistics for a 2-input model.

    ``model`` must be callable as model(img, meta). By default only trainable
    parameters are counted; pass ``count_all_params=True`` for models whose
    parameters are frozen (e.g. the teacher ensemble), which would otherwise
    count as zero.

    Returns a dict holding the name, parameter count (millions), ptflops
    complexity (G), and GPU/CPU latency mean/std (ms) plus FPS. The GPU
    entries are None when CUDA is unavailable. The CPU step leaves ``model``
    on the CPU.
    """
    print(f"\n======= Benchmarking: {name} =======")

    report = {"model": name}

    report["params_M"] = count_parameters(model, trainable_only=not count_all_params) / 1e6
    print(f"Params: {report['params_M']:.3f} M")

    report["flops_G"] = compute_flops(model, meta_dim)
    print(f"FLOPs: {report['flops_G']:.3f} G")

    gpu_mean, gpu_std, gpu_fps = measure_gpu_latency(model, meta_dim)
    if gpu_mean is None:
        print("GPU Latency: N/A")
    else:
        print(f"GPU Latency: {gpu_mean:.3f} ± {gpu_std:.3f} ms  |  FPS: {gpu_fps:.1f}")
    report["gpu_latency_mean_ms"] = gpu_mean
    report["gpu_latency_std_ms"] = gpu_std
    report["gpu_fps"] = gpu_fps

    cpu_mean, cpu_std, cpu_fps = measure_cpu_latency(model, meta_dim)
    print(f"CPU Latency: {cpu_mean:.3f} ± {cpu_std:.3f} ms  |  FPS: {cpu_fps:.1f}")
    report["cpu_latency_mean_ms"] = cpu_mean
    report["cpu_latency_std_ms"] = cpu_std
    report["cpu_fps"] = cpu_fps

    return report


# ============================================================
# TEACHER ENSEMBLE MODEL
# ============================================================

class TeacherModel(nn.Module):
    """Frozen ensemble of 2-input (img, meta) models.

    Every member is switched to eval mode and has ``requires_grad`` disabled
    when the ensemble is built.

    With ``ensemble_method="mean"``, ``forward`` returns the element-wise mean
    of the members' outputs. With any other value, it returns, per sample, the
    most frequent argmax class across members (a majority vote over class
    indices, not logits).
    """

    def __init__(self, models, ensemble_method="mean"):
        super().__init__()
        self.models = nn.ModuleList(models)
        self.ensemble_method = ensemble_method
        for member in self.models:
            member.eval()
            member.requires_grad_(False)

    def forward(self, img, meta):
        with torch.no_grad():
            # Assumes each member returns [B, C] -> stacked is [n_models, B, C].
            stacked = torch.stack([member(img, meta) for member in self.models], dim=0)

        if self.ensemble_method != "mean":
            # Per-member predicted class, then the most common one per sample.
            votes = stacked.max(dim=2).indices
            return votes.mode(dim=0).values

        return stacked.mean(dim=0)


# ============================================================
# BENCHMARK ALL MODELS
# (Assumes all EarlyFusion* classes and EarlyFusionWithDynamicGCN
#  are already defined in earlier cells of this notebook.)
# ============================================================

results = []

# Each section loads one checkpoint and benchmarks it as a 2-input
# (img, meta) model. The model variables stay at module level so they can
# be reused below (teacher ensemble) and in later cells.

# 1) MobileViT + MLP
mv_ckpt = rf"{PAD_ROOT}\best_early_fusion_mobilevitDA.pth"
mv_model = load_model(EarlyFusionModelMobileViT, mv_ckpt, "MobileViT + MLP")
results.append(benchmark_model(mv_model, "MobileViT + MLP", meta_dim))

# 2) PVTv2 + MLP
pvt_ckpt = rf"{PAD_ROOT}\best_early_fusion_pvtv2smoteDA.pth"
pvt_model = load_model(EarlyFusionModelPvtV2, pvt_ckpt, "PVTv2 + MLP")
results.append(benchmark_model(pvt_model, "PVTv2 + MLP", meta_dim))

# 3) MobileNetV3 + MLP
mn_ckpt = rf"{PAD_ROOT}\best_early_fusion_mobilenetv3smoteDA.pth"
mn_model = load_model(EarlyFusionMobileNetV3, mn_ckpt, "MobileNetV3 + MLP")
results.append(benchmark_model(mn_model, "MobileNetV3 + MLP", meta_dim))

# 4) DenseNet121 + MLP
dn_ckpt = rf"{PAD_ROOT}\best_early_fusion_densenetsmoteDA.pth"
dn_model = load_model(EarlyFusionModelDenseNet, dn_ckpt, "DenseNet121 + MLP")
results.append(benchmark_model(dn_model, "DenseNet121 + MLP", meta_dim))

# 5) EfficientViT + MLP
ev_ckpt = rf"{PAD_ROOT}\best_early_fusion_efficientvitDA.pth"
ev_model = load_model(EarlyFusionModelEfficientViT, ev_ckpt, "EfficientViT + MLP")
results.append(benchmark_model(ev_model, "EfficientViT + MLP", meta_dim))

# 6) InceptionResNetV2 + MLP
ir_ckpt = rf"{PAD_ROOT}\best_early_fusion_inceptionresnetv2smoteDA.pth"
ir_model = load_model(EarlyFusionModelInceptionResnetv2, ir_ckpt, "InceptionResNetV2 + MLP")
results.append(benchmark_model(ir_model, "InceptionResNetV2 + MLP", meta_dim))

# 7) VGG16 + MLP
vgg_ckpt = rf"{PAD_ROOT}\best_early_fusion_vgg16smoteDA.pth"
vgg_model = load_model(EarlyFusionModelVgg16, vgg_ckpt, "VGG16 + MLP")
results.append(benchmark_model(vgg_model, "VGG16 + MLP", meta_dim))

# 8) Xception + MLP
xc_ckpt = rf"{PAD_ROOT}\best_early_fusion_xceptionsmoteDA.pth"
xc_model = load_model(EarlyFusionModelXception, xc_ckpt, "Xception + MLP")
results.append(benchmark_model(xc_model, "Xception + MLP", meta_dim))

# 9) Swin-Tiny + MLP
swin_ckpt = rf"{PAD_ROOT}\best_early_fusion_swintinysmoteDA.pth"
swin_model = load_model(EarlyFusionModelSwinTiny, swin_ckpt, "Swin-Tiny + MLP")
results.append(benchmark_model(swin_model, "Swin-Tiny + MLP", meta_dim))

# 10) CoAtNet-2 + MLP
coat_ckpt = rf"{PAD_ROOT}\best_early_fusion_coatnetsmoteDA.pth"
coat_model = load_model(EarlyFusionModelCoatNet, coat_ckpt, "CoAtNet-2 + MLP")
results.append(benchmark_model(coat_model, "CoAtNet-2 + MLP", meta_dim))

# 11) TabFusion (GCN Student): loaded through the same helper as the
# baselines; wrap_gcn=True adapts it to the (img, meta) call signature.
student_model = load_model(
    EarlyFusionWithDynamicGCN, STUDENT_CKPT, "TabFusion (GCN Student)", wrap_gcn=True
)
student_raw = student_model.gcn_model  # unwrapped 3-input model
results.append(benchmark_model(student_model, "TabFusion (GCN Student)", meta_dim))

# 12) Teacher Ensemble (MobileViT + PVTv2)
# Reuses the models loaded above. The CPU benchmark left them on the CPU;
# .to(device) moves them back. TeacherModel freezes them in place, so a
# trainable-only parameter count of mv_model / pvt_model reads 0 from here on.
teacher_models = [mv_model, pvt_model]
teacher = TeacherModel(teacher_models, ensemble_method="mean").to(device)
results.append(
    benchmark_model(
        teacher,
        name="Teacher Ensemble (MobileViT + PVTv2)",
        meta_dim=meta_dim,       # same setting as every other model
        count_all_params=True,   # members are frozen: count every parameter
    )
)

# ============================================================
# SUMMARY
# ============================================================

# Dump each per-model stats dict, separated by a blank line.
print("\n================ FINAL SUMMARY ================\n")
for r in results:
    print(r, end="\n\n")
