# Installing and Importing Libraries

In [1]:
!pip install monai seaborn google-cloud-storage



In [2]:
import os
import h5py
import zipfile
import monai
import torch
import time
import io
import cv2

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

from torch.optim.lr_scheduler import StepLR, CosineAnnealingWarmRestarts
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from typing import cast
from tqdm import tqdm
from pathlib import Path

from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score

from monai.networks.utils import one_hot
from monai.networks.nets import resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200
from monai.visualize import GradCAMpp, GradCAM
from monai.transforms import (
    LoadImaged, EnsureChannelFirstd, Orientationd, Spacingd, ScaleIntensityRanged, Resized,
    RandCropByPosNegLabeld, RandFlipd, RandRotate90d, RandShiftIntensityd,
    RandAffined, RandZoomd, RandGaussianNoised, Compose, EnsureTyped, ToTensord, ScaleIntensity
)
from google.cloud import storage

In [None]:
!gsutil -m cp -r gs://oai_dataset/* .

[1;30;43mПоказано результат, скорочений до останніх рядків (5000).[0m
Copying gs://oai_dataset/FEMAI-Knee/OAI_00m_SAG_IW_TSE_LEFT/SAG_IW_TSE_LEFT/00m/9522950_00m_SAG_IW_TSE_LEFT.hdf5...
Copying gs://oai_dataset/FEMAI-Knee/OAI_00m_SAG_IW_TSE_LEFT/SAG_IW_TSE_LEFT/00m/9523022_00m_SAG_IW_TSE_LEFT.hdf5...
Copying gs://oai_dataset/FEMAI-Knee/OAI_00m_SAG_IW_TSE_LEFT/SAG_IW_TSE_LEFT/00m/9523138_00m_SAG_IW_TSE_LEFT.hdf5...
Copying gs://oai_dataset/FEMAI-Knee/OAI_00m_SAG_IW_TSE_LEFT/SAG_IW_TSE_LEFT/00m/9523523_00m_SAG_IW_TSE_LEFT.hdf5...
Copying gs://oai_dataset/FEMAI-Knee/OAI_00m_SAG_IW_TSE_LEFT/SAG_IW_TSE_LEFT/00m/9523641_00m_SAG_IW_TSE_LEFT.hdf5...
Copying gs://oai_dataset/FEMAI-Knee/OAI_00m_SAG_IW_TSE_LEFT/SAG_IW_TSE_LEFT/00m/9523742_00m_SAG_IW_TSE_LEFT.hdf5...
Copying gs://oai_dataset/FEMAI-Knee/OAI_00m_SAG_IW_TSE_LEFT/SAG_IW_TSE_LEFT/00m/9524398_00m_SAG_IW_TSE_LEFT.hdf5...
Copying gs://oai_dataset/FEMAI-Knee/OAI_00m_SAG_IW_TSE_LEFT/SAG_IW_TSE_LEFT/00m/9524666_00m_SAG_IW_TSE_LEFT.hdf5...


In [None]:
# Mount Google Drive: checkpoints, metrics and Grad-CAM figures are mirrored there, and the
# pretrained backbone weights are read from it.
from google.colab import drive
drive.mount('/content/drive/')
# Copy the pretrained-weights folder from Drive to the local disk.
# NOTE(review): MonaiResNet3DWrapper loads weights from /content/drive/MyDrive/pretrained/, not from
# this /content/pretrained copy (nor from pretrained_RAI) — confirm which location is intended.
!cp -r "/content/drive/MyDrive/pretrained_RAI" "/content/pretrained"

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


# Data Preprocessing

In [3]:
# Hyperparams
# Use the GPU when one is available so the notebook still runs (slowly) on CPU-only runtimes.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42                    # seeds torch / numpy RNGs (new-layer init, DataLoader shuffling)
EXP_NAME = ''                # suffix of the Drive output folder name
ARCHITECTURE = 'resnet_34'   # key into the ResNet factory mapping and GRADCAM_LAYERS
MULTIMODAL_P = True          # True: fuse clinical-MLP features with the MRI features
NUM_CLASSES = 2
MRI_HEIGHT = 128             # network input grid produced by `Resized`
MRI_WIDTH = 128
MRI_DEPTH = 35
MLP_HIDDEN = 8
MLP_OUTPUT = 4
RESNET_OUTPUT = 124          # width of the ResNet feature vector fed to the fusion head
GRADCAM_IDX = 15             # depth index of the slice shown in Grad-CAM figures
ORIG_HEIGHT = 444            # display size Grad-CAM slices are resized to (presumably the raw slice size)
ORIG_WIDTH = 448
NUM_EPOCHS = 40
RESNET_LR = 1e-4             # smaller LR for the pretrained backbone ...
MLP_LR = 1e-3                # ... than for the freshly initialised heads
FC_LR = 1e-3
BS = 5
# Layer (inside MultiModalModel) whose activations/gradients Grad-CAM++ uses; presumably the last
# convolution of layer4 for each ResNet depth.
GRADCAM_LAYERS = {
    'resnet_10': 'mri_resnet.resnet.layer4.0.conv2',
    'resnet_18': 'mri_resnet.resnet.layer4.1.conv2',
    'resnet_34': 'mri_resnet.resnet.layer4.2.conv2',
    'resnet_50': 'mri_resnet.resnet.layer4.2.conv3',
    'resnet_101': 'mri_resnet.resnet.layer4.2.conv3',
    'resnet_152': 'mri_resnet.resnet.layer4.2.conv3',
    'resnet_200': 'mri_resnet.resnet.layer4.2.conv3'
}

torch.manual_seed(SEED)
np.random.seed(SEED)
# NOTE(review): MONAI random transforms presumably keep their own RandomState; for fully reproducible
# augmentation consider monai.utils.set_determinism(seed=SEED) (it also forces deterministic cuDNN).

MRI_LEFT_PATH = "/content/FEMAI-Knee/OAI_00m_SAG_IW_TSE_LEFT/SAG_IW_TSE_LEFT/00m/"
MRI_RIGHT_PATH = "/content/FEMAI-Knee/OAI_00m_SAG_IW_TSE_RIGHT/SAG_IW_TSE_RIGHT/00m/"
CLINICAL_CSV_PATH = "/content/FEMAI-Knee/enrollee_info.csv"
LABELS_CSV_PATH = "/content/FEMAI-Knee/kl_baseline.csv"

# Per-epoch metrics history (rewritten every epoch by the training loop).
metrics_csv_path = "/content/training_metrics_resnet.csv"

folder_gradcam = Path("outputs/gradcam")
folder_metrics = Path("outputs/metrics")
folder_model = Path("outputs/model_weights")

folder_metrics.mkdir(parents=True, exist_ok=True)
folder_gradcam.mkdir(parents=True, exist_ok=True)
folder_model.mkdir(parents=True, exist_ok=True)

In [4]:
# Google Drive mirror of the output folders, namespaced by architecture and experiment name.
_drive_outputs_root = f"/content/drive/MyDrive/models_{ARCHITECTURE}_{EXP_NAME}/outputs"

folder_gradcam_drive = Path(f"{_drive_outputs_root}/gradcam")
folder_metrics_drive = Path(f"{_drive_outputs_root}/metrics")
folder_model_drive = Path(f"{_drive_outputs_root}/model_weights")
metrics_csv_path_drive = f"{_drive_outputs_root}/metrics/training_metrics_resnet.csv"

for _drive_folder in (folder_metrics_drive, folder_gradcam_drive, folder_model_drive):
    _drive_folder.mkdir(parents=True, exist_ok=True)

In [5]:
def unzip_archive(zip_path):
    """Extract every member of the ZIP archive at ``zip_path`` into the current working directory."""
    target_dir = os.getcwd()
    with zipfile.ZipFile(zip_path, mode='r') as archive:
        archive.extractall(path=target_dir)


def preprocess_labels(data):
    data = data[['ID', 'SIDE', 'V00XRKL']]
    data = data.dropna(subset=['V00XRKL'])
    return data.map(lambda x: int(x))


def preprocess_clinical_data(data):
    """Select and decode the clinical covariates used by the model.

    ``P02RACE`` and ``P02SEX`` hold coded strings such as ``"1: ..."``; the integer before the
    first ``:`` is kept. Rows with a missing ``ID`` or an undecodable code are dropped.

    Args:
        data: clinical DataFrame containing at least ``ID``, ``P02RACE`` and ``P02SEX``.

    Returns:
        A new DataFrame with columns ``ID``, ``P02RACE``, ``P02SEX`` (codes cast to ``int``).
        The input frame is not modified.
    """
    def extract_code(value):
        # Non-strings (e.g. NaN) raise AttributeError; non-numeric prefixes raise ValueError.
        try:
            return int(value.split(':')[0].strip())
        except (ValueError, AttributeError):
            return np.nan

    # Work on an explicit copy: assigning into a column slice of `data` triggers pandas'
    # SettingWithCopyWarning and may write through to the caller's frame.
    decoded = data[['ID', 'P02RACE', 'P02SEX']].copy()
    decoded['P02RACE'] = decoded['P02RACE'].apply(extract_code)
    decoded['P02SEX'] = decoded['P02SEX'].apply(extract_code)
    decoded = decoded.dropna()
    return decoded.astype({'P02RACE': int, 'P02SEX': int})

In [6]:
# Deterministic preprocessing shared by every split. Steps, in order: channel-first layout,
# RAS orientation, resampling to unit voxel spacing, resize to the network grid
# (MRI_DEPTH x MRI_HEIGHT x MRI_WIDTH), then clip intensities to [0, 2126] and rescale to [0, 1].
# NOTE(review): the arrays fed in carry no affine, so Orientationd / Spacingd presumably act on an
# identity affine and may be no-ops — confirm. 2126 looks like a dataset-specific intensity ceiling.
pre_rescale_transforms = Compose(
    [
        EnsureChannelFirstd(keys="image", channel_dim=0),
        Orientationd(keys="image", axcodes="RAS"),
        Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0), mode="bilinear"),
        Resized(keys=["image"], spatial_size=(MRI_DEPTH, MRI_HEIGHT, MRI_WIDTH), mode="trilinear"),
        ScaleIntensityRanged(keys=["image"], a_min=0, a_max=2126, b_min=0.0, b_max=1.0, clip=True),
    ]
)


def _tensor_conversion_steps():
    """Fresh EnsureTyped + ToTensord pair that closes every pipeline."""
    return [EnsureTyped(keys="image"), ToTensord(keys="image")]


# Training adds light intensity augmentation (Gaussian noise, intensity shift) after rescaling.
train_transforms = Compose(
    [
        pre_rescale_transforms,
        RandGaussianNoised(keys="image", prob=0.2, mean=0.0, std=0.1),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        *_tensor_conversion_steps(),
    ]
)

# Validation / test: preprocessing only, no augmentation.
val_test_transforms = Compose([pre_rescale_transforms, *_tensor_conversion_steps()])

In [7]:
class MultiModalDataset(Dataset):
    """Knee MRI volumes paired with clinical covariates and a binary knee-OA label.

    One row is built per (patient, knee) for which both an MRI file and a baseline KL grade
    exist. ``KNEE_OA`` is 1 when the KL grade is >= 2. A fixed 50% random subset of the
    ``KNEE_OA == 0`` rows is kept, presumably to rebalance the classes. The result is shuffled
    with a fixed seed, and the contiguous slice for ``split`` is exposed: the first
    ``train_frac`` of rows for train, the next ``val_frac`` for val, and the remainder for test.

    Args:
        mri_left_path: directory of left-knee ``<ID>_*.hdf5`` files (matched to ``SIDE == 2``).
        mri_right_path: directory of right-knee ``<ID>_*.hdf5`` files (matched to ``SIDE == 1``).
        clinical_csv: CSV with ``ID``, ``P02RACE`` and ``P02SEX`` columns.
        labels_csv: CSV with ``ID``, ``SIDE`` and ``V00XRKL`` columns.
        transforms: optional callable applied to ``{"image": array}`` that returns a dict with ``"image"``.
        split: one of ``"train"``, ``"val"``, ``"test"``.
        train_frac, val_frac: split fractions; the test split gets the remainder.

    Raises:
        ValueError: if ``split`` is not a supported name.

    NOTE(review): the split is per row, and a patient can contribute both a left-knee and a
    right-knee row. The same patient can therefore land in more than one split (subject-level
    leakage); consider a group-aware split on ``ID``.
    """

    _SPLITS = ("train", "val", "test")

    def __init__(self, mri_left_path, mri_right_path, clinical_csv, labels_csv, transforms=None, split="train", train_frac=0.7, val_frac=0.15):
        # Fail fast on a bad split name, before any file-system or CSV work.
        if split not in self._SPLITS:
            raise ValueError("Invalid split. Choose from ['train', 'val', 'test']")

        left_files = self._index_hdf5_files(mri_left_path)
        right_files = self._index_hdf5_files(mri_right_path)

        clinical_data = preprocess_clinical_data(pd.read_csv(clinical_csv))
        labels = preprocess_labels(pd.read_csv(labels_csv))
        labels["KNEE_OA"] = labels['V00XRKL'].apply(lambda x: 1 if x >= 2 else 0)

        # SIDE == 2 rows are paired with left-knee scans, SIDE == 1 rows with right-knee scans.
        left_labels = labels[(labels['SIDE'] == 2) & (labels['ID'].isin(left_files.keys()))]
        right_labels = labels[(labels['SIDE'] == 1) & (labels['ID'].isin(right_files.keys()))]

        left_data = pd.merge(clinical_data, left_labels, on='ID')
        right_data = pd.merge(clinical_data, right_labels, on='ID')

        left_data['file_name'] = left_data['ID'].map(left_files)
        right_data['file_name'] = right_data['ID'].map(right_files)

        # Sort by ID before the seeded sampling below.
        all_data = pd.concat([left_data, right_data]).reset_index(drop=True)
        all_data = all_data.sort_values(by='ID').reset_index(drop=True)

        data_knee_oa_0 = all_data[all_data["KNEE_OA"] == 0]
        data_knee_oa_1 = all_data[all_data["KNEE_OA"] == 1]

        # Keep a fixed random half of the KNEE_OA == 0 rows, then shuffle everything with a fixed
        # seed so every split instance sees exactly the same ordering.
        data_knee_oa_0_sampled = data_knee_oa_0.sample(frac=0.5, random_state=42)
        filtered_data = pd.concat([data_knee_oa_0_sampled, data_knee_oa_1]).reset_index(drop=True)
        filtered_data = filtered_data.sample(frac=1, random_state=42).reset_index(drop=True)

        total_data = len(filtered_data)
        train_idx = int(total_data * train_frac)
        val_idx = train_idx + int(total_data * val_frac)
        bounds = {"train": (0, train_idx), "val": (train_idx, val_idx), "test": (val_idx, total_data)}
        start, stop = bounds[split]
        self.all_data = filtered_data.iloc[start:stop]

        self.transforms = transforms
        self.mri_left_path = mri_left_path
        self.mri_right_path = mri_right_path

    @staticmethod
    def _index_hdf5_files(directory):
        """Map patient ID (leading ``_``-separated token of the file name) to ``.hdf5`` file name."""
        return {
            int(os.path.basename(f).split('_')[0]): f
            for f in os.listdir(directory) if f.endswith(".hdf5")
        }

    def __len__(self):
        return len(self.all_data)

    def __getitem__(self, idx):
        """Return ``(mri_image, clinical_data, label)`` for row ``idx``, all moved to ``DEVICE``.

        ``clinical_data`` is a float32 tensor ``[P02RACE, P02SEX]``; ``label`` is the ``KNEE_OA`` value.
        """
        row = self.all_data.iloc[idx]
        file_name = row['file_name']
        mri_path = os.path.join(
            self.mri_left_path if "LEFT" in file_name else self.mri_right_path,
            file_name
        )

        # Read at most the first 37 entries along the first axis of the "data" dataset
        # (presumably the slice axis); `Resized` later maps depth to MRI_DEPTH.
        with h5py.File(mri_path, "r") as h5_file:
            mri_image = np.array(h5_file["data"][:37])
        mri_image = np.expand_dims(mri_image, axis=0)  # add a leading channel axis

        if self.transforms:
            mri_image = self.transforms({"image": mri_image})["image"]
        # Without transforms the image is still a numpy array; `.to(DEVICE)` below needs a tensor.
        if not isinstance(mri_image, torch.Tensor):
            mri_image = torch.as_tensor(mri_image)

        clinical_data = torch.tensor(row[['P02RACE', 'P02SEX']].values.astype(np.float32), dtype=torch.float32)
        label = torch.tensor(row['KNEE_OA']).squeeze()

        return mri_image.to(DEVICE), clinical_data.to(DEVICE), label.to(DEVICE)

# Positional arguments shared by all three splits; only the transforms and split name differ.
_shared_dataset_args = (MRI_LEFT_PATH, MRI_RIGHT_PATH, CLINICAL_CSV_PATH, LABELS_CSV_PATH)

train_dataset = MultiModalDataset(*_shared_dataset_args, transforms=train_transforms, split="train")
val_dataset = MultiModalDataset(*_shared_dataset_args, transforms=val_test_transforms, split="val")
test_dataset = MultiModalDataset(*_shared_dataset_args, transforms=val_test_transforms, split="test")

# Modelling

In [8]:
class MonaiResNet3DWrapper(nn.Module):
    """Single-channel 3D MONAI ResNet producing a ``RESNET_OUTPUT``-wide feature vector.

    Args:
        model: architecture key, e.g. ``'resnet_18'`` (see ``mapping`` below).
        pretrained: if True, initialise every non-``fc`` tensor from a checkpoint on Google Drive.
        datasets: pick ``{model}_23dataset.pth`` (True) or ``{model}.pth`` (False) as the checkpoint.

    Raises:
        NotImplementedError: for an unknown ``model`` key.
    """

    def __init__(self, model='resnet_18', pretrained=True, datasets=True):
        super(MonaiResNet3DWrapper, self).__init__()
        mapping = {
            'resnet_10': resnet10,
            'resnet_18': resnet18,
            'resnet_34': resnet34,
            'resnet_50': resnet50,
            'resnet_101': resnet101,
            'resnet_152': resnet152,
            'resnet_200': resnet200
        }
        if model not in mapping:
            raise NotImplementedError(f"Unsupported architecture {model!r}; expected one of {sorted(mapping)}")

        self.resnet = mapping[model](spatial_dims=3, n_input_channels=1, num_classes=RESNET_OUTPUT)

        if pretrained:
            checkpoint_name = f"{model}_23dataset.pth" if datasets else f"{model}.pth"
            self._load_backbone_weights(f"/content/drive/MyDrive/pretrained/{checkpoint_name}")

    def _load_backbone_weights(self, weights_path):
        """Load every non-``fc`` tensor from ``weights_path`` into ``self.resnet`` (non-strict).

        Accepts either a bare state dict or one nested under a ``"state_dict"`` key. A leading
        ``"module."`` prefix (added when saving from ``nn.DataParallel``) is stripped so keys line
        up with the MONAI ResNet. A warning is emitted when no checkpoint tensor matches the model,
        because ``strict=False`` would otherwise silently keep the random initialisation.
        """
        import warnings

        # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted location.
        checkpoint = torch.load(weights_path, map_location=DEVICE)
        state_dict = checkpoint.get("state_dict", checkpoint) if isinstance(checkpoint, dict) else checkpoint

        prefix = "module."
        backbone_weights = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
            # The classification head is replaced by a RESNET_OUTPUT-wide layer, so skip it.
            if "fc" not in key
        }

        matched = set(self.resnet.state_dict().keys()).intersection(backbone_weights)
        if not matched:
            warnings.warn(
                f"No tensors in {weights_path} match the ResNet backbone; weights stay randomly initialised."
            )
        self.resnet.load_state_dict(backbone_weights, strict=False)

    def forward(self, x):
        return self.resnet(x)


class ClinicalMLP(nn.Module):
    """Two-layer perceptron (Linear -> ReLU -> Linear) embedding the clinical covariates.

    Args:
        input_size: number of clinical input features.
        hidden_size: width of the hidden layer.
        out_features: width of the produced embedding.
    """

    def __init__(self, input_size, hidden_size=MLP_HIDDEN, out_features=MLP_OUTPUT):
        super().__init__()
        # Registration order (fc1, fc2, relu) is kept so state-dict keys and init order are unchanged.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, out_features)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

class MultiModalModel(nn.Module):
    """MRI ResNet features, optionally concatenated with clinical-MLP features, fed to a linear classifier.

    Args:
        multimodal: if True, a ``ClinicalMLP`` embeds the 2 clinical inputs and its output is
            concatenated with the MRI features before the final ``fc`` layer.
        pretrained: forwarded to ``MonaiResNet3DWrapper``.
    """

    def __init__(self, multimodal, pretrained):
        super().__init__()
        self.mri_resnet = MonaiResNet3DWrapper(model=ARCHITECTURE, pretrained=pretrained)
        self.multimodal = multimodal

        fused_width = RESNET_OUTPUT
        if multimodal:
            self.clinical_mlp = ClinicalMLP(input_size=2)
            fused_width += MLP_OUTPUT

        self.fc = nn.Linear(fused_width, NUM_CLASSES)

    def forward(self, data):
        """``data`` is a sequence ``[mri_img, clinical_data]``; returns ``NUM_CLASSES`` logits per sample."""
        mri_img, clinical_data = data[0], data[1]

        features = self.mri_resnet(mri_img)
        if self.multimodal:
            features = torch.cat((features, self.clinical_mlp(clinical_data)), dim=1)

        return self.fc(features)

In [9]:
def custom_normalizer(x):
    """
    Min-max rescale each item of a batch of activation maps to [0, 1].

    Every item along the first (batch) axis is scaled independently with
    ``ScaleIntensity(minv=0.0, maxv=1.0)``, so its minimum maps to 0 and its maximum to 1.
    Magnitudes are preserved: larger activations stay larger (no flip).

    If the input is a PyTorch Tensor, the output is a Tensor on the same device;
    otherwise the output is a numpy array.
    """

    def _compute(data: np.ndarray) -> np.ndarray:
        # One scaler, applied per batch item so each map is normalised on its own min/max.
        scaler = ScaleIntensity(minv=0.0, maxv=1.0)
        return np.stack([scaler(i) for i in data], axis=0)

    if isinstance(x, torch.Tensor):
        return torch.as_tensor(_compute(x.detach().cpu().numpy()), device=x.device)  # type: ignore

    return _compute(x)  # type: ignore


class CustomGradCAMpp(GradCAM):
    """Grad-CAM++ for a model whose input is a ``[image, clinical]`` sequence.

    Subclasses MONAI's ``GradCAM``, installs ``custom_normalizer`` as the post-processing step
    (presumably the reason for not subclassing ``GradCAMpp`` directly), and overrides
    ``compute_map`` with the Grad-CAM++ channel weighting. Upsampling targets the spatial shape
    of ``x[0]``, so the model input must be a sequence whose first element is the image batch.
    """

    def __init__(self, nn_module, target_layers):
        # Normalise maps to [0, 1] per sample (without flipping) instead of MONAI's default post-processing.
        super().__init__(nn_module=nn_module, target_layers=target_layers, postprocessing=custom_normalizer)

    def _upsample_and_post_process(self, acti_map, x):
        """Resize ``acti_map`` to the spatial size of the image batch ``x[0]`` and post-process it."""
        # upsampling and postprocessing
        img_spatial = x[0].shape[2:]  # x is [image_batch, clinical_batch]; skip batch & channel dims
        acti_map = self.upsampler(img_spatial)(acti_map)
        return self.postprocessing(acti_map)

    def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1, **kwargs):
        """Return the ReLU'd Grad-CAM++ map of the selected target layer, shape ``(b, 1, *spatial)``."""
        # Forward/backward through the hooked model; activations and gradients are indexed per target layer.
        _, acti, grad = self.nn_module(x, class_idx=class_idx, retain_graph=retain_graph, **kwargs)

        # Score cached on the hook wrapper by the call above (presumably the logit of `class_idx`).
        score = self.nn_module.score

        acti, grad = acti[layer_idx], grad[layer_idx]

        b, c, *spatial = grad.shape
        # Grad-CAM++ pixel weights: alpha = g^2 / (2 g^2 + sum_spatial(A * g^3)),
        # with zero denominators replaced by 1 and a small epsilon for stability.
        alpha_nr = grad.pow(2)
        alpha_dr = alpha_nr.mul(2) + acti.mul(grad.pow(3)).view(b, c, -1).sum(-1).view(b, c, *[1] * len(spatial))
        alpha_dr = torch.where(alpha_dr != 0.0, alpha_dr, torch.ones_like(alpha_dr))
        alpha = alpha_nr.div(alpha_dr + 1e-7)
        # Gradient of exp(score) w.r.t. the activations; only positive contributions are kept.
        relu_grad = F.relu(cast(torch.Tensor, score).exp() * grad)
        # Per-channel weights: spatial sum of alpha * relu(grad), broadcastable over spatial dims.
        weights = (alpha * relu_grad).view(b, c, -1).sum(-1).view(b, c, *[1] * len(spatial))
        acti_map = (weights * acti).sum(1, keepdim=True)
        return F.relu(acti_map)

    def __call__(self, x, class_idx=None, layer_idx=-1, retain_graph=False, **kwargs):  # type: ignore[override]
        """
        Compute the activation map with upsampling and postprocessing.

        Args:
            x: model input; a sequence whose first element is the image batch (e.g. ``[mri, clinical]``).
            class_idx: index of the class to be visualized. Default to argmax(logits)
            layer_idx: index of the target layer if there are multiple target layers. Defaults to -1.
            retain_graph: whether to retain_graph for torch module backward call.
            kwargs: any extra arguments to be passed on to the module as part of its `__call__`.

        Returns:
            activation maps, upsampled to the spatial size of ``x[0]`` and normalised to [0, 1].
        """
        acti_map = self.compute_map(x, class_idx=class_idx, retain_graph=retain_graph, layer_idx=layer_idx, **kwargs)
        return self._upsample_and_post_process(acti_map, x)


In [10]:
model = MultiModalModel(multimodal=MULTIMODAL_P, pretrained=True).to(DEVICE)
print(model)
criterion = nn.CrossEntropyLoss()

# One AdamW parameter group per sub-network, so the backbone (RESNET_LR) and the heads
# (MLP_LR / FC_LR) can use different learning rates.
optimization_params = [{'params': model.mri_resnet.parameters(), 'lr': RESNET_LR}]
if MULTIMODAL_P:
    optimization_params.append({'params': model.clinical_mlp.parameters(), 'lr': MLP_LR})
optimization_params.append({'params': model.fc.parameters(), 'lr': FC_LR})

optimizer = optim.AdamW(optimization_params)

# Cosine annealing that restarts every 3 scheduler steps (stepped once per epoch in the training loop).
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=3, T_mult=1)

# Only the training loader shuffles.
train_loader, val_loader, test_loader = (
    DataLoader(split_dataset, batch_size=BS, shuffle=is_train)
    for split_dataset, is_train in ((train_dataset, True), (val_dataset, False), (test_dataset, False))
)

grad_cam = CustomGradCAMpp(nn_module=model, target_layers=GRADCAM_LAYERS[ARCHITECTURE])

# Per-epoch metric histories, appended to by the training loop.
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []
train_f1_scores, val_f1_scores = [], []

  pretrained_weights = torch.load(f"/content/drive/MyDrive/pretrained/{model}_23dataset.pth", map_location=DEVICE)


MultiModalModel(
  (mri_resnet): MonaiResNet3DWrapper(
    (resnet): ResNet(
      (conv1): Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(3, 3, 3), bias=False)
      (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU(inplace=True)
      (maxpool): MaxPool3d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): ResNetBlock(
          (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): ReLU(inplace=True)
          (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): ResNetBlock(
          (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1)

# Training

In [None]:
# Training and validation loop.
# Each epoch: train, validate, persist metrics + checkpoint, then save Grad-CAM++ figures for the
# last sample of the last validation batch (locally and on Drive).


def _save_gradcam_figure(file_name):
    """Write the current figure to the local and Drive Grad-CAM folders, then close it."""
    plt.savefig(folder_gradcam / file_name)
    plt.savefig(folder_gradcam_drive / file_name)
    plt.close()


def _save_overlay(background, heatmap, epoch_number, file_name):
    """Save ``heatmap`` (jet, alpha 0.5) over a grayscale ``background``, with a colorbar."""
    plt.figure(figsize=(8, 8))
    plt.imshow(background, cmap='gray')
    plt.imshow(heatmap, cmap='jet', alpha=0.5)
    plt.colorbar()
    plt.title(f"Grad-CAM++ Epoch {epoch_number}")
    plt.axis('off')
    _save_gradcam_figure(file_name)


def _save_single_image(image, cmap, title, file_name):
    """Save one image drawn with ``cmap`` under ``title``."""
    plt.figure(figsize=(8, 8))
    plt.title(title)
    plt.axis('off')
    plt.imshow(image, cmap=cmap)
    _save_gradcam_figure(file_name)


for epoch in range(NUM_EPOCHS):
    epoch_number = epoch + 1

    # ---- training ----
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    all_preds, all_labels = [], []

    for mri, clinical, labels in tqdm(train_loader):
        optimizer.zero_grad()
        outputs = model([mri, clinical])
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * mri.size(0)
        predicted = torch.argmax(outputs, dim=1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
        all_preds.extend(predicted.detach().cpu().numpy())
        all_labels.extend(labels.detach().cpu().numpy())

    train_loss = running_loss / total
    train_accuracy = correct / total
    train_f1 = f1_score(all_labels, all_preds, average='macro')

    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)
    train_f1_scores.append(train_f1)

    # ---- validation ----
    model.eval()
    val_loss, val_correct, val_total = 0.0, 0, 0
    val_preds, val_labels = [], []

    with torch.no_grad():
        for mri, clinical, labels in tqdm(val_loader):
            outputs = model([mri, clinical])
            loss = criterion(outputs, labels)
            val_loss += loss.item() * mri.size(0)

            predicted = torch.argmax(outputs, dim=1)
            val_correct += (predicted == labels).sum().item()
            val_total += labels.size(0)
            val_preds.extend(predicted.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())
            # The final batch survives the loop and is used for the Grad-CAM++ figures below.
            last_data = [mri, clinical, labels, predicted]

    validation_loss = val_loss / val_total
    val_accuracy = val_correct / val_total
    val_f1 = f1_score(val_labels, val_preds, average='macro')

    val_losses.append(validation_loss)
    val_accuracies.append(val_accuracy)
    val_f1_scores.append(val_f1)

    # Rewrite the full metric history every epoch so an interrupted run keeps its curves.
    data = dict(
        train_loss=train_losses,
        val_loss=val_losses,
        train_accuracy=train_accuracies,
        val_accuracy=val_accuracies,
        train_f1_score=train_f1_scores,
        val_f1_score=val_f1_scores,
    )
    df = pd.DataFrame(data)
    for csv_path in (metrics_csv_path, metrics_csv_path_drive):
        df.to_csv(csv_path, index=False)

    scheduler.step()

    checkpoint = dict(
        epoch=epoch,
        model_state=model.state_dict(),
        optimizer_state=optimizer.state_dict(),
        scheduler_state=scheduler.state_dict(),
    )
    torch.save(checkpoint, folder_model_drive / f"model_checkpoints_{epoch_number}.pt")

    print(f"Epoch {epoch_number}/{NUM_EPOCHS}, "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}, Train F1: {train_f1:.4f}, "
          f"Val Loss: {validation_loss:.4f}, Val Acc: {val_accuracy:.4f}, Val F1: {val_f1:.4f}")

    # ---- Grad-CAM++ for the last validation sample (explained w.r.t. its true label) ----
    model.eval()

    sample_mri = last_data[0][-1]
    grad_cam_output = grad_cam([sample_mri.unsqueeze(0), last_data[1][-1].unsqueeze(0)], class_idx=last_data[2][-1].item())
    grad_cam_heatmap = grad_cam_output.cpu().numpy()

    heatmap_slice = grad_cam_heatmap.squeeze().squeeze()[GRADCAM_IDX]
    mri_slice = sample_mri.squeeze(0)[GRADCAM_IDX].cpu()

    _save_overlay(mri_slice, heatmap_slice, epoch_number, f"gradcam_epoch_{epoch_number}.png")
    _save_single_image(heatmap_slice, 'jet', f"Grad-CAM++ outputs Epoch {epoch_number}", f"gradcam_pure_epoch_{epoch_number}.png")
    _save_single_image(mri_slice, 'gray', f"MRI input Epoch {epoch_number}", f"mri_epoch_{epoch_number}.png")

    # Same three figures after upsampling the slice to ORIG_WIDTH x ORIG_HEIGHT.
    gradcam_out = cv2.resize(heatmap_slice, [ORIG_WIDTH, ORIG_HEIGHT], interpolation=cv2.INTER_CUBIC)
    mri_orig = cv2.resize(mri_slice.numpy(), [ORIG_WIDTH, ORIG_HEIGHT], interpolation=cv2.INTER_CUBIC)

    _save_overlay(mri_orig, gradcam_out, epoch_number, f"gradcam_epoch_{epoch_number}_resized.png")
    _save_single_image(gradcam_out, 'jet', f"Grad-CAM++ outputs Epoch {epoch_number}", f"gradcam_pure_epoch_{epoch_number}_resized.png")
    _save_single_image(mri_orig, 'gray', f"MRI input Epoch {epoch_number}", f"mri_epoch_{epoch_number}_resized.png")


100%|██████████| 833/833 [30:44<00:00,  2.21s/it]
100%|██████████| 179/179 [03:50<00:00,  1.29s/it]


Epoch 1/40, Train Loss: 0.6903, Train Acc: 0.5669, Train F1: 0.5655, Val Loss: 0.6074, Val Acc: 0.6513, Val F1: 0.6512


100%|██████████| 833/833 [30:09<00:00,  2.17s/it]
100%|██████████| 179/179 [03:43<00:00,  1.25s/it]


Epoch 2/40, Train Loss: 0.6326, Train Acc: 0.6402, Train F1: 0.6391, Val Loss: 0.6138, Val Acc: 0.6715, Val F1: 0.6359


100%|██████████| 833/833 [30:15<00:00,  2.18s/it]
100%|██████████| 179/179 [03:52<00:00,  1.30s/it]


Epoch 3/40, Train Loss: 0.5861, Train Acc: 0.6952, Train F1: 0.6942, Val Loss: 0.5615, Val Acc: 0.6973, Val F1: 0.6971


 22%|██▏       | 182/833 [06:42<24:51,  2.29s/it]

In [None]:
# torch.save(model.state_dict(), folder_model/"model_weights.pt")

# Model Evaluation

In [None]:
# model_weights_path = "/content/outputs/model_weights/model_weights_11.pt"
# model.load_state_dict(torch.load(model_weights_path))

In [None]:
# Evaluate on the held-out test split using the same convention as training/validation:
# NUM_CLASSES logits per sample, integer class labels, CrossEntropy loss, argmax prediction.
model.eval()
test_loss, test_correct, test_total = 0.0, 0, 0
test_preds, test_labels = [], []
with torch.no_grad():
    for mri, clinical, labels in tqdm(test_loader):
        outputs = model([mri, clinical])

        loss = criterion(outputs, labels)
        test_loss += loss.item() * mri.size(0)

        predicted = torch.argmax(outputs, dim=1)

        test_correct += (predicted == labels).sum().item()
        test_total += labels.size(0)
        test_preds.extend(predicted.cpu().numpy())
        test_labels.extend(labels.cpu().numpy())
        last_data = [mri, clinical, labels]

test_loss = test_loss / test_total
test_accuracy = test_correct / test_total
test_f1 = f1_score(test_labels, test_preds, average='macro')

precision = precision_score(test_labels, test_preds, average='macro')
recall = recall_score(test_labels, test_preds, average='macro')
precision_over_recall = precision / recall if recall > 0 else float('inf')

# No scheduler.step() here: evaluation must not advance the learning-rate schedule.
print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_accuracy:.4f}, Test F1: {test_f1:.4f}")
print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, Precision/Recall: {precision_over_recall:.4f}")

In [None]:
# Per-class breakdown of the test predictions (1 = knee OA, 0 = no OA).
total_oa, correct_oa, not_correct_oa, correct_no_oa, not_correct_no_oa = 0, 0, 0, 0, 0
for true_label, pred_label in zip(test_labels, test_preds):
    total_oa += int(true_label)
    if true_label == 1 and pred_label == 1:
        correct_oa += 1          # true positive
    elif true_label == 0 and pred_label == 1:
        not_correct_oa += 1      # false positive
    elif true_label == 0 and pred_label == 0:
        correct_no_oa += 1       # true negative
    elif true_label == 1 and pred_label == 0:
        not_correct_no_oa += 1   # false negative
    else:
        # Unexpected label/prediction values are reported rather than silently ignored.
        print(true_label, pred_label)

# Guard against a class that is absent from the test split (would divide by zero).
total_no_oa = len(test_labels) - total_oa
class1_acc = correct_oa / total_oa if total_oa else float('nan')
class0_acc = correct_no_oa / total_no_oa if total_no_oa else float('nan')

print(f"Total OA: {total_oa}, Total NoOA {total_no_oa}")
print(f"Correct Predict OA: {correct_oa}, Incorrect Predict OA: {not_correct_no_oa}, CLASS1_ACC: = {class1_acc}")
print(f"Correct Predict NoOA: {correct_no_oa}, Incorrect Predict NoOA: {not_correct_oa}, CLASS0_ACC: {class0_acc}")

In [None]:
# Prepend the metrics of an earlier run (read from its CSV) to the current run's history so the
# curves below cover every epoch.
# NOTE: not idempotent — re-running this cell prepends the earlier run a second time.
# A separate variable holds this path so `metrics_csv_path`, which the training loop writes every
# epoch, is not redirected to (and does not overwrite) the earlier run's file.
previous_metrics_csv_path = "/content/training_metrics_resnet_oa.csv"
metrics_df = pd.read_csv(previous_metrics_csv_path)

train_losses = metrics_df["train_loss"].tolist() + train_losses
val_losses = metrics_df["val_loss"].tolist() + val_losses
train_accuracies = metrics_df["train_accuracy"].tolist() + train_accuracies
val_accuracies = metrics_df["val_accuracy"].tolist() + val_accuracies
train_f1_scores = metrics_df["train_f1_score"].tolist() + train_f1_scores
val_f1_scores = metrics_df["val_f1_score"].tolist() + val_f1_scores

In [None]:
# Plot training and validation loss, accuracy and macro-F1 per epoch.
# The x-axis comes from the recorded history rather than a hard-coded epoch count. Train and val
# are plotted against their own lengths because an interrupted epoch can append a train value
# without its validation counterpart.
metric_panels = [
    ("Loss", "Training and Validation Loss", train_losses, val_losses, "Train Loss", "Val Loss"),
    ("Accuracy", "Training and Validation Accuracy", train_accuracies, val_accuracies, "Train Accuracy", "Val Accuracy"),
    ("F1 Score", "Training and Validation F1 Score", train_f1_scores, val_f1_scores, "Train F1 Score", "Val F1 Score"),
]

fig, axes = plt.subplots(3, 1, figsize=(12, 8))
for ax, (ylabel, title, train_values, val_values, train_label, val_label) in zip(axes, metric_panels):
    ax.plot(range(1, len(train_values) + 1), train_values, label=train_label)
    ax.plot(range(1, len(val_values) + 1), val_values, label=val_label)
    ax.set_xlabel("Epochs")
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.set_title(title)

fig.tight_layout()
fig.savefig(folder_metrics / "training_metrics.png")
plt.show()

In [None]:
!zip -r outputs_mlp_oa_eq.zip outputs

In [None]:
# Persist the (possibly merged) metric history to a CSV named after the architecture,
# e.g. "training_metrics_resnet34.csv" when ARCHITECTURE == 'resnet_34'.
data = {
    "train_loss": train_losses,
    "val_loss": val_losses,
    "train_accuracy": train_accuracies,
    "val_accuracy": val_accuracies,
    "train_f1_score": train_f1_scores,
    "val_f1_score": val_f1_scores,
}

df = pd.DataFrame(data)

csv_file = f"training_metrics_{ARCHITECTURE.replace('_', '')}.csv"
df.to_csv(csv_file, index=False)

# Data Split Distributions

In [None]:
# Covariate / label distributions of each split. The frames are taken from the dataset objects,
# not rebuilt by duplicating their split logic, so the plots always describe exactly the rows the
# model trains, validates and tests on.
train = train_dataset.all_data
val = val_dataset.all_data
test = test_dataset.all_data


def plot_dist(dataset, name):
    """Plot histograms (with KDE) of race, sex, knee side and OA label for one split.

    The combined figure is saved once, after layout, as ``{name}_dist.png`` in the working
    directory, and then shown.
    """
    columns = ['P02RACE', 'P02SEX', 'SIDE', 'KNEE_OA']
    fig, axes = plt.subplots(1, len(columns), figsize=(16, 3))
    for ax, col in zip(axes, columns):
        sns.histplot(dataset[col], kde=True, bins=10, ax=ax)
        ax.set_title(f"Distribution of {col} in {name}")
    fig.tight_layout()
    fig.savefig(f'{name}_dist.png')
    plt.show()


for split_frame, split_name in [(train, 'train'), (val, 'val'), (test, 'test')]:
    plot_dist(split_frame, split_name)
