## Import necessary libraries

In [None]:
import glob
import os
import random
from collections import Counter

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix, classification_report
from sklearn.model_selection import train_test_split
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from torchvision import transforms
from torchvision.models import densenet121
from tqdm import tqdm
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import BertTokenizer, BertForSequenceClassification, DeiTForImageClassification

## Seeding

In [None]:
def set_seed(seed):
    """Seed every random number generator used by the notebook.

    Seeds PyTorch, NumPy and Python's ``random`` module. On machines with CUDA
    it additionally seeds all GPUs and switches cuDNN to deterministic,
    non-autotuned kernels (slower, but reproducible).
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Base seed value.
seed = 42

## Links and device

In [None]:
# Data locations on Kaggle.
# Image directory (note the trailing slash: Info.txt is addressed as url + 'Info.txt').
url = "/kaggle/input/mias-mammography/all-mias/"
# CSV with the per-image text descriptions (column 'Generated Sentence').
csv_file_path = "/kaggle/input/mias-text-without-co-ordinates/MIAS text.csv"


In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Image data preprocessing

In [None]:
# Function to isolate breast region
def preprocess_image_2(img):
    """Isolate and crop the breast region of a grayscale mammogram.

    Steps: Otsu threshold -> morphological closing with a 15x15 ellipse ->
    keep the largest external contour -> zero every pixel outside it ->
    crop to that contour's bounding box.

    Args:
        img: single-channel 8-bit image (cv2's Otsu thresholding requires
            8-bit single-channel input).

    Returns:
        The masked, cropped image. If thresholding yields no contour at all
        (e.g. an all-black image), ``img`` is returned unchanged instead of
        failing with ``ValueError`` from ``max()`` on an empty sequence.
    """
    _, binary_img = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
    # Closing fills small gaps so the breast forms one connected blob.
    morphed_img = cv2.morphologyEx(binary_img, cv2.MORPH_CLOSE, kernel)

    # [-2] selects the contour list under both OpenCV 3 (image, contours,
    # hierarchy) and OpenCV 4 (contours, hierarchy).
    contours = cv2.findContours(morphed_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if len(contours) == 0:
        return img

    # The largest blob is taken to be the breast; smaller ones (labels, noise) are dropped.
    largest_contour = max(contours, key=cv2.contourArea)
    mask = np.zeros_like(binary_img)
    cv2.drawContours(mask, [largest_contour], -1, 255, thickness=cv2.FILLED)
    isolated_breast = cv2.bitwise_and(img, img, mask=mask)

    x, y, w, h = cv2.boundingRect(largest_contour)
    return isolated_breast[y:y + h, x:x + w]

In [None]:
# Function to apply bicubic interpolation
def apply_bicubic_interpolation(image, scale_factor=2, output_size=(224, 224)):
    """Upscale ``image`` bicubically, then resize it to a fixed output size.

    Args:
        image: image array, expected to be 2-D grayscale (the function adds a
            trailing channel axis on return).
        scale_factor: factor for the initial bicubic upscaling.
        output_size: final ``(width, height)`` passed to ``cv2.resize``;
            defaults to (224, 224) as before.

    Returns:
        For a 2-D input, an array of shape ``(output_size[1], output_size[0], 1)``.

    Note:
        The final resize uses cv2's default interpolation (bilinear). Because
        the output size is fixed, the intermediate upscale does not change the
        final resolution; it only affects how the pixels are resampled.
    """
    height, width = image.shape[:2]
    new_height, new_width = int(height * scale_factor), int(width * scale_factor)
    high_res_image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_CUBIC)

    # Bring every image to the same spatial size so they can be stacked.
    high_res_image = cv2.resize(high_res_image, tuple(output_size))
    return np.expand_dims(high_res_image, axis=-1)

In [None]:
# Function to read and preprocess images
def read_image(num_images=322, image_dir=None):
    """Read the MIAS mammograms ``mdb001.pgm`` ... ``mdbNNN.pgm`` as grayscale arrays.

    No preprocessing or augmentation is applied here.

    Args:
        num_images: number of consecutive ids to try, starting at ``mdb001``
            (defaults to 322, the size of the MIAS set).
        image_dir: directory holding the ``.pgm`` files; defaults to the
            notebook-level ``url``.

    Returns:
        dict mapping image id (e.g. ``'mdb001'``) to the array returned by
        ``cv2.imread(..., cv2.IMREAD_GRAYSCALE)``. Missing files are skipped
        with a warning.
    """
    if image_dir is None:
        image_dir = url

    print("Reading images")
    info = {}  # image id -> image array

    for number in range(1, num_images + 1):
        image_name = f"mdb{number:03d}"  # zero-padded to three digits: mdb001, mdb010, mdb100
        image_address = os.path.join(image_dir, f"{image_name}.pgm")
        img = cv2.imread(image_address, cv2.IMREAD_GRAYSCALE)

        # cv2.imread returns None instead of raising when the file is missing/unreadable.
        if img is None:
            print(f"Warning: Image {image_name} not found.")
            continue
        info[image_name] = img

    print(f"Total images read: {len(info)}")
    return info


In [None]:
# Function to read labels from file
def read_label(filename=None):
    """Parse the MIAS ``Info.txt`` file into a ``{image_id: label}`` dict.

    A line is used only when it has more than three space-separated fields
    and its fourth field is ``'B'`` (label 0, benign) or ``'M'`` (label 1,
    malignant). All other lines are ignored. If an id appears on several
    lines, the last B/M entry wins.

    Args:
        filename: path to the label file; defaults to ``url + 'Info.txt'``.

    Returns:
        dict mapping the first field of each matching line to 0 or 1.
    """
    print("Reading labels")
    if filename is None:
        filename = url + 'Info.txt'

    severity_to_label = {'B': 0, 'M': 1}
    info = {}  # image id -> label

    # `with` guarantees the file handle is closed.
    with open(filename) as label_file:
        lines = label_file.read().split('\n')

    for line in lines:
        # Split on single spaces, as before; field indices depend on this.
        words = line.split(' ')
        if len(words) > 3 and words[3] in severity_to_label:
            info[words[0]] = severity_to_label[words[3]]

    return info

In [None]:
# Load label data
label_info = read_label()
image_info = read_image()

# Drop the 'Truth-Data:' entry if the label parser picked one up.
label_info.pop('Truth-Data:', None)

# Print the number of labels
print(f"Total number of labels: {len(label_info)}")

# Prepare X and Y arrays
X, Y = [], []

# Image ids that have no B/M entry in label_info
missing_labels = []

for image_id, raw_image in image_info.items():
    # Every image gets the same preprocessing: crop the breast region, then
    # upscale and resize to a fixed size.
    preprocessed_image = preprocess_image_2(raw_image)
    high_res_image = apply_bicubic_interpolation(preprocessed_image, scale_factor=2)
    X.append(high_res_image)

    if image_id in label_info:
        Y.append(label_info[image_id])
    else:
        # No benign/malignant severity recorded for this image: it is treated
        # as class 0 (non-malignant).
        missing_labels.append(image_id)
        Y.append(0)

X = np.array(X)
Y = np.array(Y)

# Print dataset size to check if everything is correct
print(f"X shape: {X.shape}")
print(f"Y shape: {Y.shape}")

# Print missing labels (if any)
if missing_labels:
    print(f"Images without corresponding labels: {missing_labels}")


## Transformation

In [None]:
# ImageNet channel statistics shared by both pipelines.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: fixed resize, random affine jitter, random contrast,
# then tensor conversion and ImageNet normalization.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomAffine(
        degrees=60,            # rotation sampled from (-60, 60) degrees
        translate=(0.1, 0.1),  # shift up to 10% of width / height
        scale=(0.8, 1.2),      # zoom factor between 0.8 and 1.2
        shear=20,              # x-axis shear sampled from (-20, 20) degrees
    ),
    # contrast=1.0 samples a contrast factor from [0, 2]: it can lower as well
    # as raise contrast (a factor of 0 yields a flat gray image).
    transforms.ColorJitter(contrast=1.0),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])

# Validation/test pipeline: deterministic resize, tensor conversion, normalization.
val_test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])

In [None]:
# Prepare text data (CSV file): one generated sentence per row.
# NOTE(review): rows are assumed to be in the same order as the images in X
# (mdb001, mdb002, ...); only the lengths are checked below.
df = pd.read_csv(csv_file_path)

texts = df['Generated Sentence'].values

# Ensure data lengths match
assert len(X) == len(Y) == len(texts), "Mismatch between X, Y, and texts lengths before augmentation!"

# Step 1: Find the maximum class count
class_counts = Counter(Y)
max_count = max(class_counts.values())

# Step 2: Convert grayscale images to RGB (uint8, H x W x 3)
X_rgb = np.stack([cv2.cvtColor(img.squeeze(), cv2.COLOR_GRAY2RGB) for img in X])
print(f"Converted X_rgb shape: {X_rgb.shape}")

# Augmentation pipeline = train_transform without ToTensor/Normalize, so each
# augmented sample stays a uint8 H x W x 3 array like the originals. Storing
# the normalized float tensor instead would upcast the whole concatenated
# X_rgb to float, and the dataset transforms would normalize those samples a
# second time at load.
augment_only = transforms.Compose([
    t for t in train_transform.transforms
    if not isinstance(t, (transforms.ToTensor, transforms.Normalize))
])

# Step 3: Initialize augmented images, labels, and texts
augmented_images = []
augmented_labels = []
augmented_texts = []
texts_array = np.array(texts)

# Step 4: Augment the minority classes until all are balanced.
# NOTE(review): this runs before the train/val/test split, so augmented copies
# of an image can land in a different split than the original (data leakage).
# Consider augmenting only the training split.
for label, count in class_counts.items():
    if count >= max_count:
        continue

    minority_class_indices = np.where(Y == label)[0]
    images_to_augment = X_rgb[minority_class_indices]
    texts_to_augment = texts_array[minority_class_indices]

    # Cycle through the minority images until the class reaches max_count.
    while count < max_count:
        for img, text in zip(images_to_augment, texts_to_augment):
            if count >= max_count:
                break
            augmented_img = augment_only(Image.fromarray(img.astype(np.uint8)))
            augmented_images.append(np.asarray(augmented_img, dtype=np.uint8))
            augmented_labels.append(label)
            augmented_texts.append(text)  # augmented copy keeps its source text
            count += 1

# Texts aligned with X_rgb / Y; stays the original texts if no augmentation was needed.
A_texts = texts

# Step 5: Convert augmented data to numpy arrays
if augmented_images:
    augmented_images = np.array(augmented_images)
    augmented_labels = np.array(augmented_labels)
    augmented_texts = np.array(augmented_texts)

    # Check augmentation results
    print(f"Number of augmented images: {len(augmented_images)}")
    print(f"Number of augmented labels: {len(augmented_labels)}")
    print(f"Number of augmented texts: {len(augmented_texts)}")

    # Step 6: Concatenate augmented data with the original data
    X_rgb = np.concatenate((X_rgb, augmented_images), axis=0)
    Y = np.concatenate((Y, augmented_labels), axis=0)
    A_texts = np.concatenate((texts, augmented_texts), axis=0)

# Final Check: Validate shapes
assert X_rgb.shape[0] == Y.shape[0] == len(A_texts), (
    f"Mismatch after augmentation: {X_rgb.shape[0]} vs {Y.shape[0]} vs {len(A_texts)}"
)
print(f"Shapes after augmentation:\nX_rgb shape: {X_rgb.shape}\nY shape: {Y.shape}\nTexts count: {len(A_texts)}")

# Step 7: Check the new class distribution
new_class_counts = Counter(Y)
print("Class distribution after augmentation:", new_class_counts)

# Step 8: Plot the class distribution.
# Label 0 also covers images without a B/M label, hence "Non-malignant"
# (the same names the evaluation code uses).
class_names = ['Non-malignant', 'Malignant']
sorted_classes = sorted(new_class_counts.keys())
counts = [new_class_counts[c] for c in sorted_classes]
plt.figure(figsize=(8, 5))
plt.bar([class_names[c] for c in sorted_classes], counts, color=['blue', 'red'][:len(counts)])
plt.title("Class Distribution After Augmentation")
plt.xlabel("Classes")
plt.ylabel("Number of Samples")
plt.show()


## Bidirectional Gated Cross-Attention

In [None]:
class GatedCrossAttention(nn.Module):
    """Fuse a query vector with a context vector via attention plus a learned gate.

    Both inputs hold one vector per sample: ``query`` is ``[B, query_dim]`` and
    ``context`` is ``[B, context_dim]``. Each is treated as a length-1 sequence.
    The output is ``[B, hidden_dim]``.

    Note:
        With exactly one context token the softmax over attention scores is
        always 1.0, so the attended vector equals the value projection of
        ``context``. The per-unit sigmoid gate then mixes the projected query
        with that projected context.
    """

    def __init__(self, query_dim, context_dim, hidden_dim):
        super().__init__()
        self.query_proj = nn.Linear(query_dim, hidden_dim)
        self.key_proj = nn.Linear(context_dim, hidden_dim)
        self.value_proj = nn.Linear(context_dim, hidden_dim)

        # Gating mechanism: sees the raw query and the attended context.
        self.gate_fc = nn.Linear(query_dim + hidden_dim, hidden_dim)
        self.sigmoid = nn.Sigmoid()

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query, context):
        # Project the query once; it is used both for attention and for fusion.
        query_proj = self.query_proj(query)             # [B, H]

        Q = query_proj.unsqueeze(1)                     # [B, 1, H]
        K = self.key_proj(context).unsqueeze(1)         # [B, 1, H]
        V = self.value_proj(context).unsqueeze(1)       # [B, 1, H]

        attn_scores = torch.bmm(Q, K.transpose(1, 2))   # [B, 1, 1]
        attn_weights = self.softmax(attn_scores)        # [B, 1, 1], always 1.0
        attended = torch.bmm(attn_weights, V).squeeze(1)  # [B, H]

        # Gate computation
        gate_input = torch.cat([query, attended], dim=1)  # [B, Q+H]
        gate = self.sigmoid(self.gate_fc(gate_input))     # [B, H]

        # Gated fusion
        return gate * query_proj + (1 - gate) * attended  # [B, H]

## Multimodal model

In [None]:
class MultiModalModel(nn.Module):
    """Image + text classifier built from backbones and bidirectional gated cross-attention.

    Args:
        efficientvit_model: image backbone expected to return ``[B, 1024]``
            features (this notebook passes DenseNet121 with its classifier
            replaced by an identity).
        deit_model: DeiT model whose ``.logits`` output is used as ``[B, 768]``
            image features (its classifier is replaced by an identity).
        text_model: BERT-style model called with ``output_hidden_states=True``;
            the last layer's first-token hidden state is the text feature.
        fc_network: classification head; it receives
            ``2 * hidden_dim + image_feature_dim`` (= 2560) features.
    """

    def __init__(self, efficientvit_model, deit_model, text_model, fc_network):
        super().__init__()
        self.efficientvit_model = efficientvit_model
        self.deit_model = deit_model
        self.text_model = text_model
        self.fc_network = fc_network

        self.image_feature_dim = 1024
        self.deit_feature_dim = 768
        self.text_feature_dim = 768
        self.hidden_dim = 768

        vision_dim = self.image_feature_dim + self.deit_feature_dim
        # Text attends to vision, and vision attends to text.
        self.text_to_vision = GatedCrossAttention(self.text_feature_dim, vision_dim, self.hidden_dim)
        self.vision_to_text = GatedCrossAttention(vision_dim, self.text_feature_dim, self.hidden_dim)

    def train(self, mode=True):
        """Set training mode, but keep fully frozen backbones in eval mode.

        ``nn.Module.train()`` would otherwise switch frozen backbones to train
        mode as well: their BatchNorm layers would keep updating running
        statistics and their dropout would be active, even though none of
        their parameters are optimized.
        """
        super().train(mode)
        for backbone in (self.efficientvit_model, self.deit_model, self.text_model):
            if not any(p.requires_grad for p in backbone.parameters()):
                backbone.eval()
        return self

    def forward(self, image_input, input_ids, attention_mask):
        # Extract image features from both vision backbones.
        image_features = self.efficientvit_model(image_input)  # expected [B, 1024]
        deit_features = self.deit_model(image_input).logits    # expected [B, 768]
        vision_features = torch.cat((image_features, deit_features), dim=1)  # [B, 1792]

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )
        # First-token ([CLS]) hidden state of the last layer: [B, text_feature_dim].
        text_features = text_outputs.hidden_states[-1][:, 0, :]

        enhanced_text = self.text_to_vision(text_features, vision_features)    # [B, hidden_dim]
        enhanced_vision = self.vision_to_text(vision_features, text_features)  # [B, hidden_dim]

        # Also pass the raw CNN image features (the first image_feature_dim columns) to the head.
        cnn_features = vision_features[:, :self.image_feature_dim]
        fused_features = torch.cat([enhanced_text, enhanced_vision, cnn_features], dim=1)

        return self.fc_network(fused_features)


## Fully connected network

In [None]:
class FullyConnectedNetwork(nn.Module):
    """Three-layer MLP head: Linear -> ReLU -> Linear -> BatchNorm -> ReLU -> Linear.

    Args:
        input_dim: width of the incoming feature vector.
        hidden_dim: width of the first hidden layer; the second hidden layer
            has ``hidden_dim // 2`` units.
        output_dim: number of output logits.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        half_width = hidden_dim // 2

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, half_width)
        self.fc3 = nn.Linear(half_width, output_dim)

        self.batch_norm = nn.BatchNorm1d(half_width)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.batch_norm(self.fc2(hidden)))
        return self.fc3(hidden)

In [None]:
# Feature widths used to size the classification head.
image_feature_dim = 1024  # image-backbone features (DenseNet121 with classifier removed)
deit_feature_dim = 768    # DeiT features taken from `.logits` with an identity classifier
text_feature_dim = 768    # BERT first-token hidden size (NOT the number of labels)
hidden_dim = 512          # width of the head's first hidden layer
output_dim = 2            # number of classes

# NOTE(review): MultiModalModel actually feeds the head
#   enhanced_text (768) + enhanced_vision (768) + image features (1024) = 2560,
# where both 768s are its internal cross-attention hidden size. That equals
# image_feature_dim + deit_feature_dim + text_feature_dim only because all the
# 768-wide sizes coincide; revisit input_dim if any of them changes.
fc_network = FullyConnectedNetwork(
    input_dim=image_feature_dim + deit_feature_dim + text_feature_dim,
    hidden_dim=hidden_dim,
    output_dim=output_dim,
)

## Multimodal dataset

In [None]:
class MultiModalDataset(torch.utils.data.Dataset):
    """Pair images, text descriptions and labels for the multimodal model.

    Args:
        images: indexable collection of images. NumPy arrays are converted to
            uint8 PIL images before ``transform`` is applied; other objects
            are passed to ``transform`` as-is.
        texts: indexable collection of strings aligned with ``images``.
        labels: indexable collection of class labels aligned with ``images``.
        tokenizer: callable invoked as ``tokenizer(text, truncation=True,
            padding='max_length', max_length=..., return_tensors='pt')`` that
            returns a mapping with ``input_ids`` and ``attention_mask``.
        transform: optional callable applied to the image.
        max_text_length: length every text is padded/truncated to.
    """

    def __init__(self, images, texts, labels, tokenizer, transform=None, max_text_length=512):
        self.images = images
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.transform = transform
        self.max_text_length = max_text_length

    @staticmethod
    def _to_uint8(image):
        """Convert a NumPy image to uint8 without wrapping pixel values.

        uint8 input is returned unchanged. Float input whose maximum is <= 1
        is treated as [0, 1] data and scaled by 255. Everything else (e.g.
        floats already in [0, 255], wider integer types) is only clipped to
        [0, 255] before the cast, instead of being multiplied by 255 and
        overflowing.
        """
        if image.dtype == np.uint8:
            return image
        if np.issubdtype(image.dtype, np.floating) and image.size and image.max() <= 1.0:
            image = image * 255
        return np.clip(image, 0, 255).astype(np.uint8)

    def __getitem__(self, idx):
        image = self.images[idx]

        if isinstance(image, np.ndarray):
            image = self._to_uint8(image)
            if image.ndim == 3 and image.shape[-1] == 1:
                # PIL cannot build an image from an (H, W, 1) array.
                image = image[..., 0]
            image = Image.fromarray(image)

        if self.transform:
            image = self.transform(image)

        encoding = self.tokenizer(
            self.texts[idx],
            truncation=True,
            padding='max_length',
            max_length=self.max_text_length,
            return_tensors='pt'
        )

        return {
            "image": image,
            # squeeze(0) drops the batch dimension added by return_tensors='pt'.
            "input_ids": encoding['input_ids'].squeeze(0),
            "attention_mask": encoding['attention_mask'].squeeze(0),
            "label": self.labels[idx]
        }

    def __len__(self):
        return len(self.images)

## Class weights

In [None]:
# Balanced class weights for the weighted cross-entropy loss, computed from Y,
# the label array aligned with X_rgb and A_texts.
# NOTE(review): these use the full, already re-balanced dataset rather than
# the training split, so they will be (close to) 1.0 for every class.
class_weights = compute_class_weight(
    class_weight="balanced",
    classes=np.unique(Y),
    y=Y,
)

# Convert to a float tensor on the same device the model will use.
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)

print(f"Class Weights on Device: {class_weights_tensor}")


## Train loop

In [None]:
def train_model(model, train_loader, criterion, optimizer, scheduler, device):
    """Run one training epoch, then advance the learning-rate scheduler once.

    Each batch is a dict with ``image``, ``input_ids``, ``attention_mask`` and
    ``label``; the model is called as ``model(images, input_ids, attention_mask)``.

    Returns:
        ``(mean batch loss, accuracy over all samples seen this epoch)``.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for batch in tqdm(train_loader, desc="Training"):
        images, input_ids, attention_mask, labels = (
            batch[key].to(device) for key in ("image", "input_ids", "attention_mask", "label")
        )

        optimizer.zero_grad()
        logits = model(images, input_ids, attention_mask)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        predicted = torch.max(logits, dim=1)[1]
        n_correct += (predicted == labels).sum().item()
        n_seen += labels.size(0)

    # Epoch-based schedulers (StepLR etc.) are stepped once per epoch here.
    # A ReduceLROnPlateau scheduler would instead need scheduler.step(<val loss>).
    scheduler.step()

    return loss_sum / len(train_loader), n_correct / n_seen


## Evaluation loop

In [None]:
def evaluate_model(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Returns:
        ``(mean batch loss, accuracy, report)`` where ``report`` is the
        ``classification_report`` dict for labels 0 ('Non-malignant') and
        1 ('Malignant').

    The class names are fixed and passed together with ``labels=[0, 1]``, so
    the report keys always line up with the label indices. Deriving them
    from ``df['Class'].unique()`` gave an arbitrary order and count, which
    could mislabel classes or crash when the count differed from 2.
    """
    model.eval()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    predictions = []
    true_labels = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            images = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            outputs = model(images, input_ids, attention_mask)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, preds = torch.max(outputs, dim=1)
            correct_predictions += (preds == labels).sum().item()
            total_samples += labels.size(0)

            predictions.append(preds.cpu())
            true_labels.append(labels.cpu())

    predictions = torch.cat(predictions)
    true_labels = torch.cat(true_labels)

    epoch_loss = running_loss / len(val_loader)
    epoch_accuracy = correct_predictions / total_samples

    report = classification_report(
        true_labels,
        predictions,
        labels=[0, 1],
        target_names=['Non-malignant', 'Malignant'],
        output_dict=True,
        zero_division=0,
    )
    return epoch_loss, epoch_accuracy, report

## Confusion matrix

In [None]:
def generate_confusion_matrix(true_labels, predictions, save_cm_path, class_names=('Non-malignant', 'Malignant')):
    """Save a confusion-matrix heatmap as PDF and return the classification report.

    Args:
        true_labels: ground-truth integer labels (0 .. len(class_names) - 1).
        predictions: predicted integer labels.
        save_cm_path: output path for the PDF. Its parent directory is created
            if needed; a bare filename is written to the current directory.
        class_names: display names, indexed by label value.

    Returns:
        ``classification_report(..., output_dict=True)`` for the same labels.
    """
    class_names = list(class_names)
    label_ids = list(range(len(class_names)))

    # os.makedirs('') raises, so only create a directory when the path has one.
    out_dir = os.path.dirname(save_cm_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Explicit labels keep the matrix square even if a class is absent.
    cm = confusion_matrix(true_labels, predictions, labels=label_ids)

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names,
                yticklabels=class_names, linewidths=2, cbar=False, square=True,
                annot_kws={"size": 14}, ax=ax)

    ax.set_xlabel('Predicted Labels', fontsize=12)
    ax.set_ylabel('True Labels', fontsize=12)
    ax.set_title('Confusion Matrix', fontsize=14)

    fig.savefig(save_cm_path, bbox_inches="tight", format="pdf")
    plt.close(fig)
    print(f"Confusion matrix saved as {save_cm_path}")

    return classification_report(true_labels, predictions, labels=label_ids,
                                 target_names=class_names, output_dict=True)


## Test loop

In [None]:
def test_model(model, val_loader, criterion, device, seed=None, report_save_path=None,
               cm_save_path='/kaggle/working/confusion_matrix.pdf'):
    """Evaluate ``model`` on a held-out loader and report binary classification metrics.

    Computes accuracy, macro precision/recall/F1 and, when both classes are
    present, AUC-ROC from the class-1 softmax probability. It prints them,
    saves and shows a confusion-matrix figure, and builds a classification
    report.

    Args:
        model: called as ``model(images, input_ids, attention_mask)``.
        val_loader: loader of dicts with ``image``, ``input_ids``,
            ``attention_mask`` and ``label`` (the experiment loop passes the
            test split here).
        criterion: accepted for API compatibility; the loss is not reported.
        device: device the batches are moved to.
        seed: recorded in the returned summary row (``'N/A'`` when None).
        report_save_path: accepted for API compatibility; currently unused.
        cm_save_path: where the confusion-matrix figure is written.

    Returns:
        ``(results_df, class_report)``. ``results_df`` is a one-row DataFrame
        with Seed, Accuracy, Precision, Recall, F1-Score and AUC-ROC (None when
        it cannot be computed). ``class_report`` is the
        ``classification_report`` dict.
    """
    class_names = ['Non-malignant', 'Malignant']
    model.eval()
    predictions, true_labels, probabilities = [], [], []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Evaluation"):
            images = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            outputs = model(images, input_ids, attention_mask)
            _, preds = torch.max(outputs, dim=1)

            predictions.append(preds.cpu())
            true_labels.append(labels.cpu())
            probabilities.append(torch.softmax(outputs, dim=1).cpu())

    predictions = torch.cat(predictions)
    true_labels = torch.cat(true_labels)
    probabilities = torch.cat(probabilities)

    # Metrics
    accuracy = accuracy_score(true_labels, predictions)
    precision = precision_score(true_labels, predictions, average='macro', zero_division=0)
    recall = recall_score(true_labels, predictions, average='macro', zero_division=0)
    f1 = f1_score(true_labels, predictions, average='macro', zero_division=0)

    # roc_auc_score raises ValueError when only one class is present.
    auc_roc = None
    if probabilities.shape[1] > 1 and torch.unique(true_labels).numel() > 1:
        auc_roc = roc_auc_score(true_labels, probabilities[:, 1])

    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    if auc_roc is not None:
        print(f"AUC-ROC: {auc_roc:.4f}")

    # Confusion matrix (explicit labels keep it 2x2 even if a class is absent)
    cm = confusion_matrix(true_labels, predictions, labels=[0, 1])
    cm_dir = os.path.dirname(cm_save_path)
    if cm_dir:
        os.makedirs(cm_dir, exist_ok=True)
    fig = plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names,
                yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.savefig(cm_save_path)
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across seeds

    # Classification Report
    class_report = classification_report(
        true_labels,
        predictions,
        labels=[0, 1],
        target_names=class_names,
        output_dict=True,
        zero_division=0,
    )

    results_df = pd.DataFrame([{
        'Seed': seed if seed is not None else 'N/A',
        'Accuracy': accuracy,
        'Precision': precision,
        'Recall': recall,
        'F1-Score': f1,
        'AUC-ROC': auc_roc
    }])

    return results_df, class_report

## Multimodal learning (DenseNet121 + DeiT + BERT)

In [None]:
# Load DenseNet121
def load_densenet_model(weight_path):
    """Build a frozen DenseNet121 feature extractor from a saved state_dict.

    The classifier is replaced by an identity, and any ``classifier.*``
    entries in the checkpoint are dropped before a non-strict load. Other
    missing or unexpected keys are reported instead of being silently ignored.
    The model is returned in eval mode with all parameters frozen.

    Args:
        weight_path: path to a ``torch.save``-d state_dict.
    """
    # weights=None is the non-deprecated spelling of pretrained=False.
    densenet_model = densenet121(weights=None)
    densenet_model.classifier = torch.nn.Identity()

    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines
    # (consistent with the DeiT and BERT loaders).
    state_dict = torch.load(weight_path, map_location=torch.device('cpu'))

    # The classifier head is not used; drop its weights.
    state_dict = {k: v for k, v in state_dict.items() if not k.startswith("classifier.")}

    result = densenet_model.load_state_dict(state_dict, strict=False)
    if result.missing_keys or result.unexpected_keys:
        print(
            f"Warning: DenseNet checkpoint mismatch - "
            f"{len(result.missing_keys)} missing (e.g. {result.missing_keys[:5]}), "
            f"{len(result.unexpected_keys)} unexpected (e.g. {result.unexpected_keys[:5]})"
        )

    densenet_model.eval()
    for param in densenet_model.parameters():
        param.requires_grad = False  # feature extractor only

    return densenet_model

# Load DeiT model
def load_deit_model(weight_path, num_classes=2):
    """Return a frozen DeiT backbone with its classifier head removed.

    Starts from the ``facebook/deit-base-distilled-patch16-224`` weights and
    swaps ``classifier`` for an identity, so ``.logits`` carries the pre-head
    features. The checkpoint at ``weight_path`` is then overlaid non-strictly.
    ``num_classes`` is accepted but not used.
    """
    model = DeiTForImageClassification.from_pretrained("facebook/deit-base-distilled-patch16-224")
    model.classifier = nn.Identity()

    checkpoint = torch.load(weight_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint, strict=False)

    # Inference-only feature extractor: eval mode and no gradients.
    model.eval()
    model.requires_grad_(False)
    return model

# Load the BERT model
def load_bert_model(weight_path, bert_model_name="Charangan/MedBERT"):
    """Return ``(frozen BERT classifier, tokenizer)`` for ``bert_model_name``.

    The two-label sequence-classification model is created first, then the
    fine-tuned state_dict at ``weight_path`` is loaded non-strictly
    (mismatched keys are ignored). The model is put in eval mode with every
    parameter frozen, so it serves only as a feature extractor.
    """
    tokenizer = BertTokenizer.from_pretrained(bert_model_name)
    model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=2)

    checkpoint = torch.load(weight_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint, strict=False)

    model.eval()
    model.requires_grad_(False)
    return model, tokenizer

# Checkpoint files in the Kaggle input directory.
_WEIGHTS_DIR = '/kaggle/input/best-models-with-seeds/pytorch/default/1'
densenet_weight_path = f'{_WEIGHTS_DIR}/DenseNet_best_model.pt'
deit_weight_path = f'{_WEIGHTS_DIR}/deit_best_model.bin'
bert_weight_path = f'{_WEIGHTS_DIR}/best_BERT_model_state.bin'

# Load the frozen backbones once at notebook level (DenseNet, then DeiT, then BERT).
densenet_model = load_densenet_model(densenet_weight_path)
deit_model = load_deit_model(deit_weight_path)
text_model, bert_tokenizer = load_bert_model(bert_weight_path)

print("Models loaded and ready for feature extraction.")


## Experiments across 10 different seeds

In [None]:
def run_experiments_over_seeds(seed_list, images=None, texts=None, labels=None):
    """Train and test the multimodal model once per seed and collect the results.

    For every seed the function:
      1. reseeds all RNGs and reloads the frozen backbones;
      2. builds a *fresh* classification head;
      3. makes a stratified split: 15% test, then 17.6% of the remainder as
         validation (roughly 70/15/15);
      4. trains for up to 100 epochs with early stopping on validation
         macro-F1 (patience 10);
      5. restores the best checkpoint and evaluates it on validation and test.

    Args:
        seed_list: iterable of integer seeds.
        images: image array aligned with ``texts`` and ``labels``; defaults to
            the augmented notebook global ``X_rgb``.
        texts: text array; defaults to ``A_texts``.
        labels: label array; defaults to ``Y``.

    Returns:
        ``(all_val_reports, all_test_reports, all_test_auc_roc_scores)``: the
        validation and test classification-report dicts of the best
        checkpoint per seed, and each seed's test AUC-ROC (None when unavailable).
    """
    # Default to the aligned arrays produced by the augmentation cell.
    if images is None:
        images = X_rgb
    if texts is None:
        texts = A_texts
    if labels is None:
        labels = Y

    all_val_reports = []
    all_test_reports = []
    all_test_auc_roc_scores = []

    def seed_worker(worker_id):
        # PyTorch's recipe: derive each DataLoader worker's NumPy / `random`
        # seed from that worker's torch seed so workers do not share streams.
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    for seed in seed_list:
        print(f"\n====== Running for Seed: {seed} ======")

        set_seed(seed)

        # Reload all backbones fresh for this seed
        densenet_model = load_densenet_model(densenet_weight_path)
        deit_model = load_deit_model(deit_weight_path)
        text_model, bert_tokenizer = load_bert_model(bert_weight_path)

        # === Data split ===
        (texts_train_val, texts_test, images_train_val, images_test,
         labels_train_val, labels_test) = train_test_split(
            texts, images, labels, test_size=0.15, stratify=labels, random_state=seed
        )
        (texts_train, texts_val, images_train, images_val,
         labels_train, labels_val) = train_test_split(
            texts_train_val, images_train_val, labels_train_val,
            test_size=0.176, stratify=labels_train_val, random_state=seed
        )

        # === Datasets and dataloaders ===
        train_dataset = MultiModalDataset(images_train, texts_train, labels_train, bert_tokenizer, train_transform, 512)
        val_dataset = MultiModalDataset(images_val, texts_val, labels_val, bert_tokenizer, val_test_transform, 512)
        test_dataset = MultiModalDataset(images_test, texts_test, labels_test, bert_tokenizer, val_test_transform, 512)

        # Dedicated generator so the shuffle order depends only on the seed.
        loader_generator = torch.Generator().manual_seed(seed)
        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,
                                  worker_init_fn=seed_worker, generator=loader_generator)
        val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, worker_init_fn=seed_worker)
        test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, worker_init_fn=seed_worker)

        # === Model and training tools ===
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # A new head per seed. Re-using one module-level head would carry the
        # weights trained under one seed into the next run.
        fusion_head = FullyConnectedNetwork(
            input_dim=image_feature_dim + deit_feature_dim + text_feature_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
        )
        model = MultiModalModel(densenet_model, deit_model, text_model, fusion_head).to(device)
        criterion = nn.CrossEntropyLoss(weight=class_weights_tensor, label_smoothing=0.1)

        # Only the trainable parameters (cross-attention + head) are optimized.
        optimizer = Adam(
            [p for p in model.parameters() if p.requires_grad],
            lr=1e-5, weight_decay=1e-4
        )
        scheduler = StepLR(optimizer, step_size=6, gamma=0.1)

        best_model_path = f"best_model_seed_{seed}.bin"
        best_val_f1 = float("-inf")  # guarantees the first epoch is checkpointed
        patience_counter = 0

        mm_train_losses = []
        mm_val_losses = []

        for epoch in range(100):
            train_loss, train_acc = train_model(model, train_loader, criterion, optimizer, scheduler, device)
            val_loss, val_acc, val_report = evaluate_model(model, val_loader, criterion, device)
            val_f1 = val_report['macro avg']['f1-score']

            mm_train_losses.append(train_loss)
            mm_val_losses.append(val_loss)

            print(f"Epoch {epoch+1:03d} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f} | Val F1: {val_f1:.4f}")

            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                torch.save(model.state_dict(), best_model_path)
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= 10:
                    print("Early stopping.")
                    break

        # Restore the best checkpoint.
        model.load_state_dict(torch.load(best_model_path, map_location=device))

        # Re-score validation with the restored checkpoint so the stored report
        # belongs to the same model that is evaluated on the test split.
        _, _, best_val_report = evaluate_model(model, val_loader, criterion, device)
        results_df_test, test_report = test_model(model, test_loader, criterion, device, seed=seed)

        all_val_reports.append(best_val_report)
        all_test_reports.append(test_report)

        if not results_df_test.empty and 'AUC-ROC' in results_df_test.columns:
            all_test_auc_roc_scores.append(results_df_test['AUC-ROC'].iloc[0])
        else:
            print(f"Warning: AUC-ROC score not found in results_df_test for seed {seed}")
            all_test_auc_roc_scores.append(None)

        # Training and validation loss curves for this seed
        plt.figure(figsize=(10, 5))
        plt.plot(range(1, len(mm_train_losses) + 1), mm_train_losses, label='Training Loss')
        plt.plot(range(1, len(mm_val_losses) + 1), mm_val_losses, label='Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title(f'Training and Validation Loss (Seed {seed})')
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.show()

    return all_val_reports, all_test_reports, all_test_auc_roc_scores

In [None]:
# Ten fixed seeds; each gives its own data split and training run.
seed_list = [42, 77, 7, 101, 314, 2024, 123, 88, 11, 999]
val_results, test_results, test_auc_roc = run_experiments_over_seeds(seed_list=seed_list)

## Results

In [None]:
# Collect the per-seed test metrics from the classification-report dicts.
macro_f1_scores = [report["macro avg"]["f1-score"] for report in test_results]
accuracies = [report["accuracy"] for report in test_results]
macro_precisions = [report["macro avg"]["precision"] for report in test_results]
macro_recalls = [report["macro avg"]["recall"] for report in test_results]

# One AUC-ROC per seed run. None (not computable for that seed) becomes NaN,
# so the formatting below and the pandas statistics later do not fail.
auc_roc_scores = [np.nan if score is None else score for score in test_auc_roc]

print("===== Individual Seed Run Results =====")
for i, (acc, f1, prec, rec, roc_auc) in enumerate(zip(accuracies, macro_f1_scores, macro_precisions, macro_recalls, auc_roc_scores), 1):
    print(f"Seed Run {i}:")
    print(f"  Accuracy        : {acc:.4f}")
    print(f"  Macro F1-score  : {f1:.4f}")
    print(f"  Macro Precision : {prec:.4f}")
    print(f"  Macro Recall    : {rec:.4f}")
    roc_text = "n/a" if np.isnan(roc_auc) else f"{roc_auc:.4f}"
    print(f"  AUC-ROC         : {roc_text}")
    print()

# --- Bootstrap Confidence Interval Calculation ---

# One row per seed run, one column per metric (macro averages for P/R/F1).
metrics_df = pd.DataFrame({
    'Accuracy': accuracies,
    'Precision': macro_precisions,
    'Recall': macro_recalls,
    'F1-Score': macro_f1_scores,
    'AUC-ROC': auc_roc_scores
})

# Metrics to bootstrap (must match the DataFrame column names)
metrics_to_bootstrap = ['Accuracy', 'Precision', 'Recall', 'F1-Score', 'AUC-ROC']

# Bootstrapped CI function
def bootstrap_ci(data, n_bootstrap=10000, ci=95, seed=None):
    """Percentile-bootstrap confidence interval for the mean of ``data``.

    Args:
        data: 1-D sequence of numbers (NaNs should be removed beforehand).
        n_bootstrap: number of resamples, each of size ``len(data)`` drawn
            with replacement.
        ci: confidence level in percent (e.g. 95 -> 2.5th/97.5th percentiles).
        seed: optional seed for a private ``np.random.default_rng``. When
            None, the global ``np.random`` state is used, as before.

    Returns:
        ``(mean of the bootstrap means, lower bound, upper bound)``; all NaN
        when ``data`` is empty.
    """
    data = np.array(data)
    n = len(data)
    if n == 0:
        return np.nan, np.nan, np.nan

    rng = np.random if seed is None else np.random.default_rng(seed)
    means = np.array([
        rng.choice(data, size=n, replace=True).mean() for _ in range(n_bootstrap)
    ])

    # Symmetric tails: e.g. ci=95 -> 2.5 and 97.5, ci=90 -> 5 and 95.
    tail = (100 - ci) / 2
    lower = np.percentile(means, tail)
    upper = np.percentile(means, 100 - tail)
    return np.mean(means), lower, upper

# Summarise every metric across seeds: plain mean / sample std plus a
# bootstrap 95% confidence interval of the mean.
summary_rows = []
for metric in metrics_to_bootstrap:
    column = metrics_df[metric]
    # NaNs (seeds where a metric was unavailable) are dropped before resampling.
    boot_mean, ci_lower, ci_upper = bootstrap_ci(column.dropna().values)
    summary_rows.append({
        'Metric': metric,
        'Mean': column.mean(),
        'Std Dev': column.std(),
        'Boot Mean': boot_mean,
        '95% CI Lower': ci_lower,
        '95% CI Upper': ci_upper,
    })

summary_df = pd.DataFrame(summary_rows)

print("---")
print("===== Aggregated Results (Mean ± Std & Bootstrap 95% CI) =====")

for row in summary_df.to_dict("records"):
    print(f"{row['Metric']}:")
    print(f"  Mean ± Std       : {row['Mean']:.4f} ± {row['Std Dev']:.4f}")
    print(f"  95% CI (Bootstrap): [{row['95% CI Lower']:.4f}, {row['95% CI Upper']:.4f}]")
    print()