In [1]:
import nibabel as nib
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torch.nn.functional as F
import cv2
import random
import matplotlib.pyplot as plt
from torchvision import models
from torch.utils.data import Dataset, DataLoader
from scipy.ndimage import zoom
from glob import glob
from PIL import Image
from tqdm import tqdm
from torch.utils.data import random_split
from torch.utils.data import Subset, DataLoader
from torchsummary import summary
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, confusion_matrix

torch.cuda.manual_seed(42)
np.random.seed(42)
random.seed(42)

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
torch.backends.cudnn.deterministic = True

def fix_seed(seed):
    '''
    Args : 
        seed : fix the seed
    Function which allows to fix all the seed and get reproducible results
    '''
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)
    os.environ['OMP_NUM_THREADS'] = '1'
    os.environ['MKL_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

fix_seed(42)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

## Data Processing and Exploration

### Choosing the right data:

The full dataset has been processed from it's original 350GB+ MRI dataset. The original csv file is filtered to include only 1.5T or 3.0T scans as they contain the most comprehensive sets of MRI scans of differing types (T1-weighted, T2, Bold, etc). 

The clinical dementia rating (cdr) with values (0, 0.5, 1, 2) represent 0 = absent; 0.5 = questionable; 1= present, but mild; 2 = moderate (reference: https://www.sciencedirect.com/topics/neuroscience/clinical-dementia-rating). They have been remapped from (0, 0.5, 1, 2) to (0, 1, 2, 3) due to issues processing it with Torch dataloader, leading to a missing class for 0.5 cdr.

Based on each MRI session value (Label), if there are multiple values for the cdr then the scans are excluded.

In [2]:
diagnostic_file = glob('diagnosis.csv')[0]

diagnostic_df = pd.read_csv(diagnostic_file)

diagnostic_df = diagnostic_df[(diagnostic_df.Scanner == '3.0T') | (diagnostic_df.Scanner == '1.5T')]

multiple_values = diagnostic_df.groupby('Label').filter(lambda group: group['cdr'].nunique() > 1)['Label'].unique()

cdr_map = {0.0: 0, 0.5: 1, 1.0: 2, 2.0: 3}

# Filter out rows with these values in column A
filtered_diagnostic_df = diagnostic_df[~diagnostic_df['Label'].isin(multiple_values)]
filtered_diagnostic_df = filtered_diagnostic_df.drop_duplicates(subset='Label')
# filtered_diagnostic_df['cdr'] = filtered_diagnostic_df['cdr'].map(cdr_map)
filtered_diagnostic_df['file_tag'] = filtered_diagnostic_df.Label.apply(lambda file_name: file_name[-5:])
print(len(filtered_diagnostic_df))

valid_mr_scans_ls = filtered_diagnostic_df.file_tag.to_list()
print(valid_mr_scans_ls[:3])
filtered_diagnostic_df.file_tag

# filtered_diagnostic_df.to_csv('filtered_data.csv')

1379
['d0129', 'd2430', 'd3132']


0       d0129
2       d2430
3       d3132
12      d0371
15      d2340
        ...  
6193    d0148
6194    d2526
6195    d1566
6216    d1717
6217    d0407
Name: file_tag, Length: 1379, dtype: object

In [3]:
data_path = 'bids/' ## Change as needed

def recursive_glob_with_filter(directory, filter_list):
    # Use recursive glob pattern to find all files
    all_files = glob(directory + '/**/*T1*nii.gz', recursive=True)
    
    # Filter files based on the presence of strings from filter_list in the file path
    filtered_files = [file for file in all_files if any(filter_str in file for filter_str in filter_list)]
    
    return filtered_files

file_ls = recursive_glob_with_filter(data_path, valid_mr_scans_ls)
print(file_ls[:2])

['bids\\sub-OAS30001\\ses-d0129\\anat\\sub-OAS30001_ses-d0129_run-01_T1w.nii.gz', 'bids\\sub-OAS30001\\ses-d0129\\anat\\sub-OAS30001_ses-d0129_run-02_T1w.nii.gz']


In [4]:
# Function to extract tag from file path
def extract_tag_from_file_path(file_path):
    file_path = file_path.replace(data_path[:-1], '')
    return file_path.split('\\')[2][-5:]

try:
    data_loader_df = pd.read_csv('data_loader_df.csv')
except:
    # Create a new DataFrame to store the results
    data_loader_df = pd.DataFrame(columns=['file_path', 'label'])

    # Iterate over the file list and find the corresponding CDR value
    for file_path in file_ls:
        tag = extract_tag_from_file_path(file_path)
        cdr_value = filtered_diagnostic_df.loc[filtered_diagnostic_df['file_tag'] == tag, 'cdr'].values
        if cdr_value.size > 0:  # Check if cdr_value is not empty
            new_row = pd.DataFrame({'file_path': [file_path], 'label': [cdr_value[0]]})
            data_loader_df = pd.concat([data_loader_df, new_row], ignore_index=True)
            data_loader_df.to_csv('data_loader_df.csv')

print(data_loader_df)

      Unnamed: 0                                          file_path  label
0              0  D:/DL/oasis-scripts/download_scans/bids\sub-OA...    0.0
1              1  D:/DL/oasis-scripts/download_scans/bids\sub-OA...    0.0
2              2  D:/DL/oasis-scripts/download_scans/bids\sub-OA...    0.0
3              3  D:/DL/oasis-scripts/download_scans/bids\sub-OA...    0.0
4              4  D:/DL/oasis-scripts/download_scans/bids\sub-OA...    0.0
...          ...                                                ...    ...
2991        2991  D:/DL/oasis-scripts/download_scans/bids\sub-OA...    0.0
2992        2992  D:/DL/oasis-scripts/download_scans/bids\sub-OA...    1.0
2993        2993  D:/DL/oasis-scripts/download_scans/bids\sub-OA...    0.0
2994        2994  D:/DL/oasis-scripts/download_scans/bids\sub-OA...    0.5
2995        2995  D:/DL/oasis-scripts/download_scans/bids\sub-OA...    0.0

[2996 rows x 3 columns]


This section is to extract 2D slices from 3D volumetric MRI scans. It iterates through the T1-weighted scans and slices between indexes 100-160, choosing a slice for every 3 slices. If the scan has less then 160 slices, then the middle slice will be taken of which there is only 1. Eventually there should be a dataset of 58,376. 

JPG is the saved file format due to its smaller size. NPY was tried but was 840GB and hence rejected.

**Run only once (roughly 7-20 mins depending on CPU).**

Due to an extremely imbalanced dataset (Label 0.0: 50216; Label 1.0: 5748; Label 2.0: 1971; Label 3.0: 441), weights will be used. 

In [5]:
from collections import defaultdict
import os
import pandas as pd
import torch

save_dir = data_path[:-5] + "preprocessed_images"
# Dictionary to store the count of each label
class_counts = defaultdict(int)

# List to store file paths and labels
file_paths_labels = []

# Iterate over all files in the save directory
for filename in os.listdir(save_dir):
    if filename.startswith("label_") and filename.endswith(".jpg"):
        # Extract the label from the filename and convert it to a decimal (float)
        label = float(filename.split("_")[1])
        # Increment the count for this label
        class_counts[label] += 1
        # Append the full file path and label to the list
        full_path = os.path.join(save_dir, filename)
        file_paths_labels.append((full_path, label))

dementia_count = 0
healthy_count = 0
# Print the count of each label
for label, count in class_counts.items():
    print(f"Label {label}: {count}")
    if label >0:
        dementia_count+=count
    else:
        healthy_count+=count
# Create a DataFrame from the list of file paths and labels
jpg_data_loader_df = pd.DataFrame(file_paths_labels, columns=['file_path', 'label'])

## Print the first few rows of the DataFrame
# print(jpg_data_loader_df.head())

# Calculate weights: Inverse of the frequency seems like a simple choice
total_count = sum(class_counts.values())
weights = {k: total_count / v for k, v in class_counts.items()}
print(weights)
# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# Convert weights to a tensor, ensuring the labels are in the correct order
weights_tensor = torch.tensor([weights[0], weights[1], weights[2], weights[3]], dtype=torch.float32)

print(weights_tensor)
weights_tensor = weights_tensor.to(device)
jpg_data_loader_df.to_csv('jpg_data_loader_df.csv')

all_dementia_weight = (healthy_count+dementia_count)/dementia_count
# all_dementia_weight = 3
print(all_dementia_weight)


Label 0.0: 50216
Label 1.0: 5748
Label 2.0: 1971
Label 3.0: 441
{0.0: 1.1624980086028358, 1.0: 10.155880306193458, 2.0: 29.617453069507864, 3.0: 132.3718820861678}
cuda
tensor([  1.1625,  10.1559,  29.6175, 132.3719])
7.153921568627451


## Functions

Functions for training the model. Includes NiftiDataset, HierarchicalCrossEntropyLoss, load_model, train_model and evaluate_model.

In [6]:
from torch.utils.data import DataLoader, random_split

class NiftiDataset(Dataset):
    def __init__(self, dataframe, preprocessed_dir):
        self.dataframe = dataframe
        self.preprocessed_dir = preprocessed_dir

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        label = row['label']
        img_path = row['file_path']  # Use the exact path from the dataframe
        scan = cv2.imread(img_path)
        # scan = cv2.resize(scan, (224,224), interpolation=cv2.INTER_LINEAR)
        scan = scan / 255.0  # Normalize the image to [0, 1]
        scan_tensor = torch.from_numpy(scan).float()
        
        # If the image is grayscale, we use unsqueeze to add the channel dimension
        if len(scan.shape) == 2:
            scan_tensor = scan_tensor.unsqueeze(0)  # Add channel dimension for grayscale image
        else:
            scan_tensor = scan_tensor.permute(2, 0, 1)  # Rearrange dimensions for color image
        
        label_tensor = torch.tensor(label).long()

        return scan_tensor, label_tensor

dataset = NiftiDataset(jpg_data_loader_df, preprocessed_dir=save_dir)

# Create a list of labels for stratified splitting
labels = jpg_data_loader_df.iloc[:, 1].values
myset = set(labels)
print(myset)

# Define the ratios for splitting
train_ratio = 0.7
val_ratio = 0.2
test_ratio = 0.1

# Calculate the sizes for each split
total_size = len(dataset)
train_size = int(total_size * train_ratio)
val_size = int(total_size * val_ratio)
test_size = total_size - train_size - val_size

# Split the dataset into train, validation, and test sets
train_dataset, val_test_dataset = random_split(dataset, [train_size, val_size + test_size])
val_dataset, test_dataset = random_split(val_test_dataset, [val_size, test_size])

# Create dataloaders for each split
batch_size = 32  # Adjust the batch size as needed
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Function to count labels in a dataset
def count_labels(dataset):
    label_counts = {}
    for _, label_tensor in dataset:
        label = label_tensor.item()
        if label in label_counts:
            label_counts[label] += 1
        else:
            label_counts[label] = 1
    return label_counts


{0.0, 1.0, 2.0, 3.0}


### Ensemble learning

In [7]:
import os
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score, precision_score, recall_score

def evaluate_ensemble(models, test_dataloader, device, method='averaging', weights=None):
    predictions = []
    all_labels = []

    if method in ['weighted_ensemble', 'weighted_max_voting']:
        if weights is None:
            # Initialize weights with equal values
            weights = [1.0 / len(models)] * len(models)
        elif len(weights) != len(models):
            raise ValueError("Number of weights must be equal to the number of models.")

    for inputs, labels in test_dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        if method == 'averaging':
            ensemble_outputs = []
            for i, model in enumerate(models):
                model.cuda()
                model.eval()
                outputs = model(inputs)

                if hasattr(outputs, 'aux_logits'):
                    outputs = outputs.logits

                ensemble_outputs.append(outputs)

            ensemble_outputs = torch.stack(ensemble_outputs, dim=0)
            ensemble_mean = torch.mean(ensemble_outputs, dim=0)
            _, predicted = torch.max(ensemble_mean, dim=1)
        elif method == 'max_voting':
            ensemble_votes = torch.zeros((len(models), len(inputs)), dtype=torch.long).to(device)
            for i, model in enumerate(models):
                model.cuda()
                model.eval()
                outputs = model(inputs)

                if hasattr(outputs, 'aux_logits'):
                    outputs = outputs.logits

                _, predicted = torch.max(outputs, dim=1)
                ensemble_votes[i] = predicted

            predicted, _ = torch.mode(ensemble_votes, dim=0)
        elif method == 'weighted_averaging':
            ensemble_outputs = []
            for i, model in enumerate(models):
                model.cuda()
                model.eval()
                outputs = model(inputs)

                if hasattr(outputs, 'aux_logits'):
                    outputs = outputs.logits

                outputs *= weights[i]
                ensemble_outputs.append(outputs)

            ensemble_mean = torch.mean(torch.stack(ensemble_outputs, dim=0), dim=0)
            _, predicted = torch.max(ensemble_mean, dim=1)
            

        predictions.extend(predicted.tolist())
        all_labels.extend(labels.tolist())

    cm = confusion_matrix(all_labels, predictions)
    print(cm)
    class_accuracies = cm.diagonal() / cm.sum(axis=1)
    print(class_accuracies)
    
    accuracy = torch.sum(torch.tensor(predictions) == torch.tensor(all_labels)).item() / len(predictions)
    f1 = f1_score(all_labels, predictions, average='macro')
    precision = precision_score(all_labels, predictions, average='macro')
    recall = recall_score(all_labels, predictions, average='macro')
    
    print(f'Test Accuracy: {accuracy:.4f}, Test F1: {f1:.4f}, Test Precision: {precision:.4f}, Test Recall: {recall:.4f}')



#### Load models from checkpoints

In [8]:
resnet = models.resnet18()
num_features = resnet.fc.in_features
resnet.fc = nn.Linear(num_features, 4)

checkpoint = torch.load("checkpoints\\best_resnet18_w_crossentropy.pth", map_location=torch.device('cuda'))
resnet.load_state_dict(checkpoint)

<All keys matched successfully>

In [9]:
inception = models.inception_v3()
num_features = inception.fc.in_features
inception.fc = nn.Linear(num_features,4)

num_features_aux = inception.AuxLogits.fc.in_features
inception.AuxLogits.fc = nn.Linear(num_features_aux, 4)

checkpoint = torch.load("checkpoints\\best_inception_v3_w_crossentropy.pth", map_location=torch.device('cuda'))
inception.load_state_dict(checkpoint)



<All keys matched successfully>

In [10]:
import timm
vit = timm.create_model('vit_small_patch16_224', pretrained=False)

# Modify the final layer to match the number of classes (4 in this case)
num_features = vit.head.in_features
vit.head = nn.Linear(num_features, 4)
checkpoint = torch.load("checkpoints\\best_vit_small_patch16_224_w_crossentropy.pth", map_location=torch.device('cuda'))
vit.load_state_dict(checkpoint)

  from .autonotebook import tqdm as notebook_tqdm


<All keys matched successfully>

###  Ensemble Learning

In [13]:
import torch
import torch.nn as nn
import torchvision.models as models

class ResNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResNetBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

import torch
import torch.nn as nn
import torchvision.models as models

class ResNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResNetBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


class UNet_with_ResNet_Blocks(nn.Module):
    def __init__(self, num_classes):
        super(UNet_with_ResNet_Blocks, self).__init__()
        self.encoder1 = ResNetBlock(3, 64)
        self.pool = nn.MaxPool2d(2, 2)
        self.encoder2 = ResNetBlock(64, 128)
        self.bottleneck = ResNetBlock(128, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = ResNetBlock(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = ResNetBlock(128, 64)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        # Encoder
        enc1 = self.encoder1(x)
        x = self.pool(enc1)
        enc2 = self.encoder2(x)
        x = self.pool(enc2)
        x = self.bottleneck(x)
        x = self.upconv2(x)
        x = torch.cat((x, enc2), dim=1)
        x = self.decoder2(x)
        x = self.upconv1(x)
        x = torch.cat((x, enc1), dim=1)
        x = self.decoder1(x)

        # Classification
        x = self.avg_pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
    
class AttentionLayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(AttentionLayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

class VGG_Attention(nn.Module):
    def __init__(self, num_classes=4):
        super(VGG_Attention, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            AttentionLayer(64),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            AttentionLayer(128),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            AttentionLayer(256),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            AttentionLayer(512),

            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            AttentionLayer(512),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
    
import torch
import torch.nn as nn
import torchvision.models as models

class ChannelAttention(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction_ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_channels // reduction_ratio, in_channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.conv(out)
        return self.sigmoid(out)

class CBAM(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(in_channels, reduction_ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        x = x * self.channel_attention(x)
        x = x * self.spatial_attention(x)
        return x

class CBAMResNet(nn.Module):
    def __init__(self, model=models.resnet50(), num_classes=4, in_channels_ls = [256, 512, 1024, 2048]):
        super(CBAMResNet, self).__init__()
        self.base_model = model
        self.cbam1 = CBAM(in_channels=in_channels_ls[0])
        self.cbam2 = CBAM(in_channels=in_channels_ls[1])
        self.cbam3 = CBAM(in_channels=in_channels_ls[2])
        self.cbam4 = CBAM(in_channels=in_channels_ls[3])

        num_features = self.base_model.fc.in_features
        self.base_model.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):
        x = self.base_model.conv1(x)
        x = self.base_model.bn1(x)
        x = self.base_model.relu(x)
        x = self.base_model.maxpool(x)

        x = self.base_model.layer1(x)
        x = self.cbam1(x)

        x = self.base_model.layer2(x)
        x = self.cbam2(x)

        x = self.base_model.layer3(x)
        x = self.cbam3(x)

        x = self.base_model.layer4(x)
        x = self.cbam4(x)

        x = self.base_model.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.base_model.fc(x)

        return x
    
class SeparableConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(SeparableConv2d, self).__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, groups=in_channels)
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x

class XceptionInceptionModule(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(XceptionInceptionModule, self).__init__()
        self.branch1 = SeparableConv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels // 2, kernel_size=1),
            nn.ReLU(),
            SeparableConv2d(out_channels // 2, out_channels, kernel_size=3, padding=1)
        )
        # Removed one branch for simplification
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
        )

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch4 = self.branch4(x)
        return torch.cat([branch1, branch2, branch4], 1)

class XceptionInceptionNet(nn.Module):
    def __init__(self, num_classes=4):
        super(XceptionInceptionNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.xception_inception = XceptionInceptionModule(32, 64)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64 * 3, num_classes)  # Adjusted for the reduced number of branches

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.xception_inception(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

In [14]:
vgg_attention = VGG_Attention(num_classes=4)
checkpoint = torch.load("checkpoints\\best_VGG_Attention.pth", map_location=torch.device('cuda'))
vgg_attention.load_state_dict(checkpoint)

cbam18 = CBAMResNet(num_classes=4, model=models.resnet18(), in_channels_ls = [64, 128, 256, 512])
checkpoint = torch.load("checkpoints\\best_CBAMResNet.pth", map_location=torch.device('cuda'))
cbam18.load_state_dict(checkpoint)

unet_resnet = UNet_with_ResNet_Blocks(num_classes=4)
checkpoint = torch.load("checkpoints\\best_UNet_with_ResNet_Blocks_w_crossentropy.pth", map_location=torch.device('cuda'))
unet_resnet.load_state_dict(checkpoint)

xception = XceptionInceptionNet(num_classes=4)
checkpoint = torch.load("checkpoints\\best_XceptionInceptionNet.pth", map_location=torch.device('cuda'))
xception.load_state_dict(checkpoint)

<All keys matched successfully>

In [18]:
models = [vgg_attention, cbam18]
device = torch.device("cuda")
weights = [0.23, 0,77]
evaluate_ensemble(models, test_dataloader, device, method='averaging')
evaluate_ensemble(models, test_dataloader, device, method='weighted_averaging', weights=weights)

[[4271  675   69   13]
 [ 117  453    1    0]
 [  32    6  150    1]
 [   2    0    0   48]]
[0.84944312 0.79334501 0.79365079 0.96      ]
Test Accuracy: 0.8431, Test F1: 0.7565, Test Precision: 0.7053, Test Recall: 0.8491
[[5028    0    0    0]
 [ 571    0    0    0]
 [ 189    0    0    0]
 [  50    0    0    0]]
[1. 0. 0. 0.]
Test Accuracy: 0.8613, Test F1: 0.2314, Test Precision: 0.2153, Test Recall: 0.2500


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [15]:
models = [vgg_attention, unet_resnet]
device = torch.device("cuda")
weights = [0.37, 0.63]
evaluate_ensemble(models, test_dataloader, device, method='averaging')
evaluate_ensemble(models, test_dataloader, device, method='weighted_averaging', weights=weights)

[[3937 1066   18    7]
 [ 322  249    0    0]
 [ 119   63    3    4]
 [  29    7    4   10]]
[0.78301512 0.43607706 0.01587302 0.2       ]
Test Accuracy: 0.7193, Test F1: 0.3497, Test Precision: 0.4173, Test Recall: 0.3587
[[3821 1150   44   13]
 [ 312  254    5    0]
 [ 108   67    8    6]
 [  25    8    4   13]]
[0.75994431 0.44483363 0.04232804 0.26      ]
Test Accuracy: 0.7016, Test F1: 0.3628, Test Precision: 0.4012, Test Recall: 0.3768


In [16]:
models = [xception, cbam18]
weights = [0.23, 0.77]
device = torch.device("cuda")
evaluate_ensemble(models, test_dataloader, device, method='averaging')
evaluate_ensemble(models, test_dataloader, device, method='weighted_averaging', weights=weights)

[[4268  670   76   14]
 [ 117  452    2    0]
 [  31    6  151    1]
 [   1    0    0   49]]
[0.84884646 0.7915937  0.7989418  0.98      ]
Test Accuracy: 0.8428, Test F1: 0.7545, Test Precision: 0.6980, Test Recall: 0.8548
[[4223  699   89   17]
 [ 114  454    3    0]
 [  26    6  156    1]
 [   1    0    0   49]]
[0.83989658 0.79509632 0.82539683 0.98      ]
Test Accuracy: 0.8362, Test F1: 0.7439, Test Precision: 0.6799, Test Recall: 0.8601


In [17]:
models = [unet_resnet, xception]
device = torch.device("cuda")
weights = [0.38, 0.62]
evaluate_ensemble(models, test_dataloader, device, method='averaging')
evaluate_ensemble(models, test_dataloader, device, method='weighted_averaging', weights=weights)

[[4172  806   42    8]
 [ 373  194    4    0]
 [ 131   44    9    5]
 [  31    4    4   11]]
[0.82975338 0.33975482 0.04761905 0.22      ]
Test Accuracy: 0.7513, Test F1: 0.3667, Test Precision: 0.4206, Test Recall: 0.3593
[[4355  650   18    5]
 [ 412  158    1    0]
 [ 150   35    3    1]
 [  34    4    4    8]]
[0.86614956 0.27670753 0.01587302 0.16      ]
Test Accuracy: 0.7749, Test F1: 0.3434, Test Precision: 0.4382, Test Recall: 0.3297


In [18]:
models = [vgg_attention, unet_resnet, cbam18]
device = torch.device("cuda")
weights = [0.159, 0.539, 0.302]
evaluate_ensemble(models, test_dataloader, device, method='averaging')
evaluate_ensemble(models, test_dataloader, device, method='max_voting')
evaluate_ensemble(models, test_dataloader, device, method='weighted_averaging', weights=weights)

[[4278  699   49    2]
 [ 122  449    0    0]
 [  38    7  143    1]
 [   4    0    0   46]]
[0.85083532 0.78633975 0.75661376 0.92      ]
Test Accuracy: 0.8421, Test F1: 0.7759, Test Precision: 0.7588, Test Recall: 0.8284
[[4646  373    9    0]
 [ 335  236    0    0]
 [ 168    4   17    0]
 [  33    0    0   17]]
[0.92402546 0.41330998 0.08994709 0.34      ]
Test Accuracy: 0.8421, Test F1: 0.4936, Test Precision: 0.7339, Test Recall: 0.4418
[[4259  712   57    0]
 [ 129  442    0    0]
 [  43    8  137    1]
 [   3    0    0   47]]
[0.84705648 0.77408056 0.72486772 0.94      ]
Test Accuracy: 0.8368, Test F1: 0.7712, Test Precision: 0.7566, Test Recall: 0.8215


In [19]:
models = [xception, unet_resnet, cbam18 ]
device = torch.device("cuda")
weights = [0.167, 0.560, 0.273]
evaluate_ensemble(models, test_dataloader, device, method='averaging')
evaluate_ensemble(models, test_dataloader, device, method='max_voting')
evaluate_ensemble(models, test_dataloader, device, method='weighted_averaging', weights=weights)

[[4278  690   58    2]
 [ 126  444    1    0]
 [  37    6  145    1]
 [   3    0    0   47]]
[0.85083532 0.77758319 0.76719577 0.94      ]
Test Accuracy: 0.8417, Test F1: 0.7751, Test Precision: 0.7507, Test Recall: 0.8339
[[4484  535    9    0]
 [ 323  248    0    0]
 [ 156   16   17    0]
 [  33    0    0   17]]
[0.89180589 0.43432574 0.08994709 0.34      ]
Test Accuracy: 0.8164, Test F1: 0.4806, Test Precision: 0.7154, Test Recall: 0.4390
[[4268  704   56    0]
 [ 133  438    0    0]
 [  46    8  134    1]
 [   4    0    0   46]]
[0.84884646 0.76707531 0.70899471 0.92      ]
Test Accuracy: 0.8369, Test F1: 0.7663, Test Precision: 0.7559, Test Recall: 0.8112
