In [None]:
import numpy as np
import pandas as pd
import os
import cv2
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple

In [40]:

class MultiScaleAttention(nn.Module):
    """Self-attention over the spatial positions of a convolutional feature map.

    The (batch, in_channels, H, W) input is flattened into H*W tokens of width
    ``in_channels``, projected to queries/keys/values, passed through
    ``nn.MultiheadAttention``, layer-normalised, and reshaped back to
    (batch, out_channels, H, W).

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: embedding width of the attention output (must be divisible
            by ``num_heads``; ``nn.MultiheadAttention`` enforces this).
        num_heads: number of attention heads.
    """

    def __init__(self, in_channels: int, out_channels: int, num_heads: int):
        super(MultiScaleAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = out_channels // num_heads
        self.query = nn.Linear(in_channels, out_channels)
        self.key = nn.Linear(in_channels, out_channels)
        self.value = nn.Linear(in_channels, out_channels)
        # Tokens are laid out (batch, seq, embed) in forward(). With the default
        # batch_first=False the module would read the *batch* axis as the sequence
        # and attend across different samples of the mini-batch.
        self.attention = nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
        self.layer_norm = nn.LayerNorm(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, _, height, width = x.shape
        tokens = x.flatten(2).transpose(1, 2)  # (batch, H*W, in_channels)
        query = self.query(tokens)
        key = self.key(tokens)
        value = self.value(tokens)
        attention_output, _ = self.attention(query, key, value)
        attention_output = self.layer_norm(attention_output)
        # Back to channels-first: (batch, out_channels, H, W).
        return attention_output.transpose(1, 2).reshape(batch_size, -1, height, width)

class HybridCNNViT(nn.Module):
    """Image classifier: strided CNN stem followed by spatial self-attention.

    Pipeline: four conv-BN-ReLU stages (each conv has stride 2) with two stride-2
    max-pools, a ``MultiScaleAttention`` block added back through a 1x1
    ``residual`` conv, dropout, a 1x1 ``classifier_conv`` to ``num_classes``
    channels, then global average pooling to (batch, num_classes) logits.
    Each conv/pool halves H and W, so e.g. a 128x128 input reaches the attention
    block as a 2x2 map (4 tokens).

    Args:
        in_channels: number of channels in the input images.
        num_classes: number of output logits.
    """

    def __init__(self, in_channels: int, num_classes: int):
        super(HybridCNNViT, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(256)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn4 = nn.BatchNorm2d(512)
        self.multi_scale_attention = MultiScaleAttention(512, 512, 8)
        self.residual = nn.Conv2d(512, 512, kernel_size=1, bias=False)
        # Regularise the 512-channel features that feed the classifier. Dropout on
        # the pooled logits would instead randomly zero class scores in training.
        # Dropout and the pooling head hold no parameters, so checkpoints only
        # contain the conv / BN / attention weights.
        self.dropout = nn.Dropout(0.5)
        self.classifier_conv = nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0, bias=False)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.maxpool1(self.relu(self.bn1(self.conv1(x))))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.maxpool2(self.relu(self.bn3(self.conv3(x))))
        x = self.relu(self.bn4(self.conv4(x)))
        x = x + self.residual(self.multi_scale_attention(x))
        x = self.classifier_conv(self.dropout(x))
        return self.classifier(x)


In [41]:
def load_and_pad_images(folder_path, img_size=(128, 128), max_images=None, rng=None):
    """Load and resize every readable image in a folder, optionally padding by resampling.

    Args:
        folder_path: directory to scan (non-recursive). Entries that ``cv2.imread``
            cannot decode (returns ``None``) are skipped.
        img_size: ``dsize`` passed to ``cv2.resize``.
        max_images: if given and the folder yields fewer images, randomly chosen
            images (with replacement) are appended until there are exactly
            ``max_images``. Folders with more images are not truncated.
        rng: optional ``numpy.random.Generator`` used to choose the padding indices,
            for reproducible padding. Defaults to the global ``np.random`` state.

    Returns:
        ``np.ndarray`` of the stacked images (empty array if nothing was readable
        and no padding was requested).

    Raises:
        ValueError: padding was requested but the folder has no readable images.
    """
    images = []
    # os.listdir order is arbitrary; sort so the load order is deterministic.
    for filename in sorted(os.listdir(folder_path)):
        img = cv2.imread(os.path.join(folder_path, filename))
        if img is not None:
            images.append(cv2.resize(img, img_size))

    images = np.array(images)

    if max_images is not None and len(images) < max_images:
        if len(images) == 0:
            raise ValueError(
                f"no readable images in {folder_path!r}; cannot pad to {max_images}"
            )
        num_to_pad = max_images - len(images)
        if rng is None:
            pad_indices = np.random.randint(0, len(images), num_to_pad)
        else:
            pad_indices = rng.integers(0, len(images), num_to_pad)
        images = np.concatenate([images, images[pad_indices]], axis=0)

    return images

In [None]:
# Root of the Kaggle MS dataset: one subfolder per class/plane combination
# (per the folder names below).
DATASET_ROOT = "/kaggle/input/msdataset/Multiple Sclerosis"
axial_control_folder = f"{DATASET_ROOT}/Control-Axial"
sagittal_control_folder = f"{DATASET_ROOT}/Control-Sagittal"
axial_ms_folder = f"{DATASET_ROOT}/MS-Axial"
sagittal_ms_folder = f"{DATASET_ROOT}/MS-Sagittal"

_image_folders = (
    axial_control_folder,
    sagittal_control_folder,
    axial_ms_folder,
    sagittal_ms_folder,
)

# Pad every group up to the largest folder's entry count so the four arrays
# come out the same length.
max_samples = max(len(os.listdir(folder)) for folder in _image_folders)

(axial_control_images,
 sagittal_control_images,
 axial_ms_images,
 sagittal_ms_images) = [
    load_and_pad_images(folder, max_images=max_samples) for folder in _image_folders
]

In [None]:

import numpy as np
from scipy.fftpack import dct, idct
from PIL import Image, ImageEnhance
import random

def apply_dct(image):
    """Orthonormal 2-D DCT-II of every channel of an image.

    Args:
        image: array of shape (H, W) or (H, W, C), any numeric dtype.

    Returns:
        float32 array of the same shape holding each channel's DCT coefficients.
    """
    pixels = np.asarray(image)
    # A DCT along rows, then along columns, is the separable 2-D DCT. Any
    # trailing channel axis is left alone, so every channel is transformed
    # (not just the first three) and grayscale 2-D input also works.
    coeffs = dct(dct(pixels, axis=0, norm='ortho'), axis=1, norm='ortho')
    return coeffs.astype(np.float32)

def augment_image(image, rng=None):
    """Randomly jitter brightness, contrast and sharpness, then DCT-transform the image.

    Each enhancement factor is drawn independently from uniform(0.8, 1.2), in the
    order brightness, contrast, sharpness.

    Args:
        image: pixel array; cast with ``np.uint8`` before conversion to a PIL image.
        rng: optional object with a ``uniform(a, b)`` method (e.g. ``random.Random(seed)``)
            for reproducible factors. Defaults to the module-level ``random``.

    Returns:
        float32 DCT coefficients produced by ``apply_dct``.
    """
    # NOTE(review): the arrays come from cv2.imread (BGR channel order) but PIL treats
    # them as RGB, so Contrast's grayscale mean uses swapped channel weights — confirm.
    source = random if rng is None else rng
    pil_image = Image.fromarray(np.uint8(image))
    for enhancer in (ImageEnhance.Brightness, ImageEnhance.Contrast, ImageEnhance.Sharpness):
        pil_image = enhancer(pil_image).enhance(source.uniform(0.8, 1.2))
    return apply_dct(np.array(pil_image))

def augment_dataset(dataset):
    """Run ``augment_image`` on each image of ``dataset`` and stack the results into one array."""
    augmented = []
    for sample in dataset:
        augmented.append(augment_image(sample))
    return np.array(augmented)

# NOTE: this overwrites each raw image array with its augmented DCT coefficients
# under the same name, so re-running the cell transforms already-transformed data.
(axial_control_images,
 sagittal_control_images,
 axial_ms_images,
 sagittal_ms_images) = [
    augment_dataset(group)
    for group in (
        axial_control_images,
        sagittal_control_images,
        axial_ms_images,
        sagittal_ms_images,
    )
]

In [None]:

# One label per image (control = 0, MS = 1), sized from the arrays that are
# concatenated into the dataset rather than from an assumed max_samples * 2,
# so labels cannot drift out of sync with the data.
control_labels = np.zeros(len(axial_control_images) + len(sagittal_control_images), dtype=np.int64)
ms_labels = np.ones(len(axial_ms_images) + len(sagittal_ms_images), dtype=np.int64)


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data = np.concatenate([
    axial_control_images, sagittal_control_images,
    axial_ms_images, sagittal_ms_images,
], axis=0)
labels = np.concatenate([control_labels, ms_labels], axis=0)
assert len(data) == len(labels), "every image needs exactly one label"

# NOTE(review): images are padded by duplication *before* this split, so the same
# image can land in both the train and test sets, which inflates test metrics.
data, labels = shuffle(data, labels, random_state=42)
train_data, test_data, train_labels, test_labels = train_test_split(
    data, labels, test_size=0.2, random_state=42
)

# Inverse-frequency ("balanced") class weights from the training labels:
# n_samples / (n_classes * count_c). Padding makes the classes (near-)equal in
# size, so hard-coded weights would not track the actual class balance.
NUM_CLASSES = 2
class_counts = np.bincount(train_labels, minlength=NUM_CLASSES)
class_weights = torch.tensor(
    len(train_labels) / (NUM_CLASSES * np.maximum(class_counts, 1)),
    dtype=torch.float32,
).to(device)

# Images are stored (N, H, W, C); PyTorch convolutions expect (N, C, H, W).
train_data = torch.tensor(train_data, dtype=torch.float32).permute(0, 3, 1, 2)
test_data = torch.tensor(test_data, dtype=torch.float32).permute(0, 3, 1, 2)
train_labels = torch.tensor(train_labels, dtype=torch.long)
test_labels = torch.tensor(test_labels, dtype=torch.long)


In [None]:
def calculate_metrics(y_true, y_pred, n_classes):
    """Per-class sensitivity, specificity and F1 score, plus overall accuracy.

    The confusion matrix (rows = true class, columns = predicted class, classes
    ``0 .. n_classes-1``) is built with NumPy. The function therefore does not depend
    on an sklearn import from another cell. Samples whose true or predicted label
    lies outside that range are ignored.

    Args:
        y_true: sequence of integer ground-truth labels.
        y_pred: sequence of integer predicted labels, same length as ``y_true``.
        n_classes: number of classes to report on.

    Returns:
        dict with keys ``'Sensitivity'``, ``'Specificity'``, ``'F1 Score'`` (each a
        ``{class_index: float}`` mapping) and ``'Accuracy'`` (float in [0, 1]).
        Any ratio with a zero denominator, including accuracy on empty input, is 0.0.
    """
    y_true = np.asarray(y_true, dtype=np.int64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.int64).ravel()
    in_range = (y_true >= 0) & (y_true < n_classes) & (y_pred >= 0) & (y_pred < n_classes)
    # Encode each (true, pred) pair as a flat index so one bincount fills the matrix.
    cm = np.bincount(
        n_classes * y_true[in_range] + y_pred[in_range],
        minlength=n_classes * n_classes,
    ).reshape(n_classes, n_classes)

    def _ratio(num, den):
        return float(num / den) if den > 0 else 0.0

    sensitivity = {}
    specificity = {}
    f1_scores = {}
    total = cm.sum()
    for i in range(n_classes):
        tp = cm[i, i]
        fp = cm[:, i].sum() - tp
        fn = cm[i, :].sum() - tp
        tn = total - (tp + fp + fn)

        sensitivity[i] = _ratio(tp, tp + fn)
        specificity[i] = _ratio(tn, tn + fp)
        precision = _ratio(tp, tp + fp)
        f1_scores[i] = _ratio(2 * precision * sensitivity[i], precision + sensitivity[i])

    return {
        'Sensitivity': sensitivity,
        'Specificity': specificity,
        'F1 Score': f1_scores,
        'Accuracy': _ratio(np.trace(cm), total),
    }

In [None]:
import torch
import torch.optim as optim
def train_model(train_loader, num_epochs=100, lr=0.001, weight_decay=1e-4):
    """Train a fresh 2-class HybridCNNViT and return it with its optimizer and scheduler.

    Uses the notebook globals ``device`` and ``class_weights``. The model is wrapped
    in ``nn.DataParallel``, so its state_dict keys carry a ``module.`` prefix.
    Optimisation is AdamW with cosine annealing; ``T_max`` equals ``num_epochs``
    so the learning rate anneals exactly once over the run.

    Args:
        train_loader: iterable of (inputs, labels) batches with ``len()``; must be non-empty.
        num_epochs: number of passes over ``train_loader``.
        lr: initial AdamW learning rate.
        weight_decay: AdamW weight decay.

    Returns:
        ``(model, optimizer, scheduler)``.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    # Move to the device before building the optimizer so it is created over the
    # parameters that will actually be trained.
    model = nn.DataParallel(HybridCNNViT(in_channels=3, num_classes=2).to(device))
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        correct = 0
        total = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        if total == 0:
            raise ValueError("train_loader yielded no samples")

        train_accuracy = 100 * correct / total
        scheduler.step()
        print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader)}, Accuracy: {train_accuracy}%")

    return model, optimizer, scheduler

In [None]:
import torch
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay


def test_eval(model, test_loader, n_classes=None, class_names=None):
    """Evaluate a classifier on ``test_loader``, print metrics and show a confusion matrix.

    Uses the notebook globals ``device`` and ``class_weights`` (for the weighted loss).

    Args:
        model: module returning logits of shape (batch, n_outputs).
        test_loader: iterable of (inputs, labels) batches with ``len()``; must be non-empty.
        n_classes: number of classes for the metrics. ``None`` (default) uses the
            width of the model's output logits.
        class_names: display names for the confusion-matrix axes, one per class.
            ``None`` gives ``["Control", "MS"]`` for 2 classes, else ``"Class i"``.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    import matplotlib.pyplot as plt

    model.eval()
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    correct = 0
    total = 0
    test_loss = 0
    n_outputs = None
    predictions = []
    true_labels = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            n_outputs = outputs.shape[1]
            test_loss += criterion(outputs, labels).item()

            _, predicted = torch.max(outputs, 1)
            predictions.extend(predicted.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("test_loader yielded no samples")
    if n_classes is None:
        n_classes = n_outputs
    if class_names is None:
        class_names = ["Control", "MS"] if n_classes == 2 else [f"Class {i}" for i in range(n_classes)]

    metrics = calculate_metrics(true_labels, predictions, n_classes)

    accuracy = 100 * correct / total
    average_loss = test_loss / len(test_loader)

    print(f"Test Loss: {average_loss:.4f}, Test Accuracy: {accuracy:.2f}%")
    print(f"Sensitivity: {metrics['Sensitivity']}")
    print(f"Specificity: {metrics['Specificity']}")
    print(f"F1 Score: {metrics['F1 Score']}")
    print(f"Accuracy: {metrics['Accuracy']:.2f}")

    # labels= pins the matrix to every class even if one is absent from this split.
    ConfusionMatrixDisplay.from_predictions(
        true_labels,
        predictions,
        labels=list(range(n_classes)),
        display_labels=class_names,
    )
    plt.show()



In [None]:
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize
import numpy as np


def plot_roc_curve(model, test_loader, device):
    """Plot the class-0 ROC curve of a binary classifier and return test accuracy (%).

    The curve is built from the softmax probability of class 0 over all batches.
    Hard argmax predictions would give only a single operating point.

    NOTE(review): a later cell in this notebook defines another ``plot_roc_curve``
    with a different signature, which shadows this one once it runs.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()

    correct = 0
    total = 0
    class0_scores = []
    true_labels = []

    with torch.no_grad():
        for data, labels in test_loader:
            data, labels = data.to(device), labels.to(device)
            outputs = model(data)

            # Accumulate probabilities from every batch, not just the last one.
            class0_scores.append(torch.softmax(outputs, dim=1)[:, 0].cpu().numpy())
            true_labels.append(labels.cpu().numpy())

            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("test_loader yielded no samples")

    y_true = np.concatenate(true_labels)
    y_score = np.concatenate(class0_scores)
    # Class 0 is the positive class here, matching the score column above.
    fpr, tpr, _ = roc_curve((y_true == 0).astype(int), y_score)
    roc_auc = auc(fpr, tpr)

    plt.figure(figsize=(10, 8))
    plt.plot(fpr, tpr, lw=2, label=f'Class 0 (AUC = {roc_auc:.2f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve for Binary Classification')
    plt.legend(loc='lower right')
    plt.grid(True)
    plt.show()

    accuracy = 100 * correct / total
    print(f'Accuracy: {accuracy:.2f}%')

    return accuracy

In [None]:
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
from sklearn.preprocessing import label_binarize
import numpy as np

def plot_roc_pr_curve(model, test_loader, device):
    """Plot one-vs-rest ROC and precision-recall curves for every output class.

    The number of classes is taken from the model's output width, and softmax
    probabilities are accumulated over all batches.

    Returns:
        Test accuracy in percent.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    batch_probs = []
    true_labels = []
    with torch.no_grad():
        for data, labels in test_loader:
            data, labels = data.to(device), labels.to(device)
            outputs = model(data)

            batch_probs.append(torch.softmax(outputs, dim=1).cpu().numpy())
            true_labels.append(labels.cpu().numpy())

            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("test_loader yielded no samples")

    probs = np.concatenate(batch_probs, axis=0)
    y_true = np.concatenate(true_labels)
    n_classes = probs.shape[1]
    # Explicit one-hot targets (one column per class). label_binarize collapses
    # two classes into a single column, which would break the per-class loop.
    true_labels_bin = (y_true[:, None] == np.arange(n_classes)[None, :]).astype(int)

    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    precision = dict()
    recall = dict()
    pr_auc = dict()

    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(true_labels_bin[:, i], probs[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

        precision[i], recall[i], _ = precision_recall_curve(true_labels_bin[:, i], probs[:, i])
        pr_auc[i] = average_precision_score(true_labels_bin[:, i], probs[:, i])

    plt.figure(figsize=(14, 10))
    plt.subplot(1, 2, 1)
    for i in range(n_classes):
        plt.plot(fpr[i], tpr[i], lw=2, label=f'Class {i} (AUC = {roc_auc[i]:.2f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'ROC Curve for {n_classes}-Class Classification')
    plt.legend(loc='lower right')
    plt.grid(True)

    plt.subplot(1, 2, 2)
    for i in range(n_classes):
        plt.plot(recall[i], precision[i], lw=2, label=f'Class {i} (AP = {pr_auc[i]:.2f})')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title(f'Precision-Recall Curve for {n_classes}-Class Classification')
    plt.legend(loc='lower left')
    plt.grid(True)

    plt.tight_layout()
    plt.show()

    accuracy = 100 * correct / total
    print(f'Accuracy: {accuracy:.2f}%')

    return accuracy


In [None]:
def measure_inference_time(model, test_loader, device):
    """Wall-clock seconds for one forward pass over the first batch of ``test_loader``.

    The batch is fetched and moved to ``device`` before the timer starts, so only
    the forward pass is measured. On CUDA the device is synchronised before and
    after the call, because kernels launch asynchronously.

    Args:
        model: module to time (put into eval mode).
        test_loader: iterable of (inputs, labels) batches.
        device: ``torch.device`` or device string.

    Returns:
        Elapsed seconds as a float; 0.0 if the loader is empty.
    """
    import time

    model.eval()
    first_batch = next(iter(test_loader), None)
    if first_batch is None:
        return 0.0
    inputs = first_batch[0].to(device)
    on_cuda = torch.device(device).type == "cuda"

    with torch.no_grad():
        if on_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        model(inputs)
        if on_cuda:
            torch.cuda.synchronize()
        end_time = time.perf_counter()

    return end_time - start_time

In [None]:
from torch.utils.data import Dataset, DataLoader

class MultipleSclerosisDataset(Dataset):
    """Map-style dataset over pre-built image and label tensors.

    Samples are returned exactly as stored; the training and evaluation loops
    move each collated batch to the device. Moving individual samples to the GPU
    inside ``__getitem__`` costs one tiny host-to-device copy per sample. It also
    ties the dataset to a global ``device`` and is problematic with DataLoader
    worker processes.

    Args:
        data: indexable collection of image tensors.
        labels: indexable collection of label tensors, same length as ``data``.

    Raises:
        ValueError: if ``data`` and ``labels`` differ in length.
    """

    def __init__(self, data, labels):
        if len(data) != len(labels):
            raise ValueError(
                f"data and labels differ in length: {len(data)} != {len(labels)}"
            )
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

train_dataset = MultipleSclerosisDataset(train_data, train_labels)
test_dataset = MultipleSclerosisDataset(test_data, test_labels)

# Train, evaluate and checkpoint one model per batch size (training order kept as listed).
BATCH_SIZES = [8, 16, 32, 64, 40, 52, 100, 128, 256, 512]
models = []
batch_sizes = []
inference_times = []
for batch_size in BATCH_SIZES:
    print(f"Batch Size : {batch_size}")
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # Evaluation results do not depend on sample order, so don't shuffle the test set.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    new_model, optimizer, scheduler = train_model(train_loader)
    models.append(new_model)
    test_eval(new_model, test_loader)
    plot_roc_curve(new_model, test_loader, device)
    inference_times.append(measure_inference_time(new_model, test_loader, device))
    batch_sizes.append(batch_size)
    torch.save({
        'batch_size': batch_size,
        'model_state_dict': new_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }, f'checkpoint{batch_size}.pth')

    print("Model saved successfully!")

# The sweep order is not monotonic; sort by batch size so the line plot doesn't zig-zag.
order = np.argsort(batch_sizes)
plt.figure(figsize=(10, 6))
plt.plot(np.asarray(batch_sizes)[order], np.asarray(inference_times)[order],
         marker='o', linestyle='-', color='b')
plt.title('Model Inference Time vs. Batch Size')
plt.xlabel('Batch Size')
plt.ylabel('Inference Time (seconds)')
plt.grid(True)
plt.show()

In [None]:
# Display the list of trained models from the batch-size sweep (each wrapped in nn.DataParallel by train_model).
models

In [None]:
# Re-evaluate the first sweep model (batch size 8) on the test set and keep the
# per-sample softmax probabilities for the ROC / histogram cells that follow.
main_model = models[0]
# map_location lets a checkpoint written on GPU be reloaded on a CPU-only kernel;
# the relative filename matches where the batch-size sweep saved its checkpoints.
checkpoint = torch.load('checkpoint8.pth', map_location=device)
main_model.load_state_dict(checkpoint['model_state_dict'])
# Optimizer/scheduler state is not needed for evaluation (and the `optimizer` /
# `scheduler` globals at this point belong to the last sweep run, not this model).
main_model.eval()

test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
criterion = nn.CrossEntropyLoss(weight=class_weights)
correct = 0
total = 0
test_loss = 0
predictions = []
true_labels = []
predicted_probs = []

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = main_model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()

        probs = torch.softmax(outputs, dim=1)
        predicted_probs.extend(probs.cpu().numpy())

        _, predicted = torch.max(outputs, 1)
        predictions.extend(predicted.cpu().numpy())
        true_labels.extend(labels.cpu().numpy())

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

metrics = calculate_metrics(true_labels, predictions, 2)

accuracy = 100 * correct / total
average_loss = test_loss / len(test_loader)

print(f"Test Loss: {average_loss:.4f}, Test Accuracy: {accuracy:.2f}%")
print(f"Sensitivity: {metrics['Sensitivity']}")
print(f"Specificity: {metrics['Specificity']}")
print(f"F1 Score: {metrics['F1 Score']}")
print(f"Accuracy: {metrics['Accuracy']:.2f}")

ConfusionMatrixDisplay.from_predictions(
    true_labels, predictions, labels=[0, 1], display_labels=["Control", "MS"]
)
plt.show()



In [None]:
from sklearn.metrics import roc_curve, auc

def plot_roc_curve(y_true, y_scores):
    """Draw a single ROC curve (with chance diagonal) from binary labels and positive-class scores.

    NOTE(review): this redefines the earlier model-based
    ``plot_roc_curve(model, test_loader, device)``; once this cell runs, that
    version is shadowed.
    """
    false_pos_rate, true_pos_rate, _thresholds = roc_curve(y_true, y_scores)
    area = auc(false_pos_rate, true_pos_rate)

    plt.figure()
    plt.plot(false_pos_rate, true_pos_rate, color='darkorange', lw=2,
             label='ROC curve (area = %0.2f)' % area)
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic')
    plt.legend(loc="lower right")
    plt.show()

In [None]:
# Stack the per-sample softmax rows into an (N, n_classes) array and score the ROC
# on the probability of class 1 (the MS label).
predicted_probs = np.array(predicted_probs)
positive_class_probs = predicted_probs[..., 1]
plot_roc_curve(y_true=true_labels, y_scores=positive_class_probs)

In [None]:

def plot_feature_maps(model, input_tensor):
    """Show every ``conv1`` activation map for the first sample of ``input_tensor``.

    Accepts a bare ``HybridCNNViT`` or one wrapped in ``nn.DataParallel``, as
    returned by ``train_model``. The wrapper has no ``conv1`` attribute of its own.
    The input is moved to the device of ``conv1``'s weights, and the maps are laid
    out on a near-square grid rather than one very wide row.

    Args:
        model: HybridCNNViT, optionally DataParallel-wrapped.
        input_tensor: batch of images in (N, C, H, W) layout; only sample 0 is plotted.
    """
    import math

    model.eval()
    network = getattr(model, "module", model)
    first_conv = network.conv1
    with torch.no_grad():
        inputs = input_tensor.to(first_conv.weight.device)
        activations = first_conv(inputs).cpu().numpy()

    num_filters = activations.shape[1]
    n_cols = math.ceil(math.sqrt(num_filters))
    n_rows = math.ceil(num_filters / n_cols)
    # squeeze=False keeps `axes` 2-D even for a single filter.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 15), squeeze=False)
    for idx, ax in enumerate(axes.flat):
        if idx < num_filters:
            ax.imshow(activations[0, idx, :, :], cmap='viridis')
        ax.axis('off')
    plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# `predicted_probs` holds one softmax row per test sample. For a 2-class model the
# two columns sum to 1, so flattening every row into one sequence would mix
# p(Control) and p(MS) into a single mirrored histogram. Plot only the class-1 (MS)
# probability, and leave `predicted_probs` itself unmodified so this cell can be re-run.
_prob_matrix = np.asarray(predicted_probs)
if _prob_matrix.ndim == 2 and _prob_matrix.shape[1] > 1:
    all_predicted_probs = _prob_matrix[:, 1]
else:
    all_predicted_probs = _prob_matrix.ravel()

def plot_histogram(data, title, xlabel, ylabel):
    """Plot ``data`` as a 30-bin histogram with a KDE overlay and the given title/axis labels."""
    plt.figure(figsize=(8, 6))
    sns.histplot(data, bins=30, kde=True, color='blue', edgecolor='black')
    plt.title(title, fontsize=16)
    for set_axis_label, text in ((plt.xlabel, xlabel), (plt.ylabel, ylabel)):
        set_axis_label(text, fontsize=14)
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.show()

# Histogram of the model's predicted probabilities on the test set.
histogram_text = dict(
    title="Histogram of Predicted Probabilities",
    xlabel="Probability",
    ylabel="Frequency",
)
plot_histogram(all_predicted_probs, **histogram_text)
