In [5]:
# Imports
import torch
import torch.nn as nn
from models import SiameseNetwork, PairedMedicalDataset
import shap
import glob
import os
import pandas as pd
import numpy as np
from monai.transforms import (
    Resize,
    ScaleIntensity,
    Compose
)
import nibabel as nib

# Model evaluation

In [9]:
# Fetch the pretrained MedicalNet ResNet-10 via torch.hub; it serves as the
# backbone from which the Siamese encoder is built.
resnet_model = torch.hub.load('Warvito/MedicalNet-models', 'medicalnet_resnet10')

# Keep every top-level child except the last one (per the original author's
# note, the last child is the classification head `fc`), leaving an encoder.
backbone_layers = list(resnet_model.children())
encoder = nn.Sequential(*backbone_layers[:-1])

# Read the trained weights onto the CPU, build the Siamese model around the
# encoder and restore the weights into it.
trained_state = torch.load("best_model.pth", map_location='cpu')
model = SiameseNetwork(encoder)
model.load_state_dict(trained_state)

# Switch to inference mode; as the last expression this also displays the model.
model.eval()

Using cache found in C:\Users\P095550/.cache\torch\hub\Warvito_MedicalNet-models_main


SiameseNetwork(
  (base_model): Sequential(
    (0): Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(2, 2, 2), padding=(3, 3, 3), bias=False)
    (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool3d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (5): Sequential(
      (0): BasicBlock(
        (conv1): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
        (bn1): Batch

In [23]:
class SiameseWrapper(nn.Module):
    """Adapter exposing a three-input Siamese model through a single module.

    The wrapped model is always invoked as ``model(image1, image2, metadata)``.
    ``forward`` accepts those inputs either packed into one 3-item sequence
    (the original calling convention) or as three separate positional
    arguments, which is how tools that call ``model(*inputs)`` supply them.
    """

    def __init__(self, model):
        """Keep a reference to ``model``, a callable taking three inputs."""
        super().__init__()
        self.model = model

    def forward(self, input, *extra_inputs):
        """Run the wrapped model on one ``(image1, image2, metadata)`` triple.

        Args:
            input: Either a 3-item sequence ``(image1, image2, metadata)``, or
                ``image1`` alone when the two remaining inputs follow it.
            *extra_inputs: ``image2, metadata`` when the inputs are passed
                unpacked; omitted when ``input`` already holds all three.

        Returns:
            Whatever the wrapped model returns for the three inputs.

        Raises:
            TypeError: If unpacked inputs are supplied but there are not
                exactly three of them in total.
        """
        if extra_inputs:
            # Unpacked form: forward(image1, image2, metadata).
            if len(extra_inputs) != 2:
                raise TypeError(
                    "SiameseWrapper.forward expects either one "
                    "(image1, image2, metadata) sequence or exactly three "
                    f"positional inputs, got {1 + len(extra_inputs)}"
                )
            image1 = input
            image2, metadata = extra_inputs
        else:
            # Packed form: forward((image1, image2, metadata)).
            image1, image2, metadata = input
        return self.model(image1, image2, metadata)

In [19]:
# Pick the compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Folder holding the paired NIfTI scans, and folder holding the clinical CSVs.
data_dir = "L:/Basic/divi/jstoker/slicer_pdac/Master Students WS 24/Martijn/data/Training/paired_scans"
clinical_data_dir = "C:/Users/P095550/OneDrive - Amsterdam UMC/Documenten/GitHub/CRLM-morph-features"

# Gather every *.nii.gz file in the scan folder, in sorted order.
nifti_images = glob.glob(os.path.join(data_dir, "*.nii.gz"))
nifti_images.sort()

# Pair consecutive files after sorting: (0, 1), (2, 3), ...
# A trailing file without a partner is left out.
image_pairs = list(zip(nifti_images[0::2], nifti_images[1::2]))

# Clinical input table, converted row by row into a tensor.
metadata_csv = os.path.join(clinical_data_dir, "training_input.csv")
pd_metadata = pd.read_csv(metadata_csv)
all_metadata = torch.tensor(pd_metadata.values.tolist())

# Label table, converted the same way.
# Original author's note: fill in the correct path; columns are response, PFS and OS.
labels_csv = os.path.join(clinical_data_dir, "training_labels.csv")
pd_labels = pd.read_csv(labels_csv)
all_labels = torch.tensor(pd_labels.values.tolist())

Using device: cpu


In [20]:


# Preprocessing handed to the dataset: intensity scaling, then a trilinear
# resize of every volume to 64 x 256 x 256.
preprocessing_steps = [
    ScaleIntensity(),
    Resize((64, 256, 256), mode="trilinear"),
]

# Dataset of image pairs together with their metadata and labels.
train_dataset = PairedMedicalDataset(
    image_pairs,
    all_metadata,
    all_labels,
    transform=preprocessing_steps,
)

# Loader yielding shuffled batches of 11 samples.
data_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=11,
    shuffle=True,
)


In [None]:
# Draw one batch from the loader and split it into the sample to explain and
# the background set used by the SHAP explainer.

# Create an iterator for the DataLoader and extract a single batch.
data_iter = iter(data_loader)
batch = next(data_iter)

# Each batch unpacks into (first image, second image, metadata, labels).
img1, img2, metadata, labels = batch

# Sample to explain: the first item of the batch.
# Slice with [:1] instead of indexing with [0] so the leading batch dimension
# is kept; the encoder's BatchNorm3d layers and SHAP both treat dimension 0 as
# the sample axis. Tensors go to `device` (selected earlier in the notebook)
# rather than a hard-coded "cuda", so the cell also runs on CPU-only machines.
image1_test = img1[:1].to(device)
image2_test = img2[:1].to(device)
metadata_test = metadata[:1].to(device)

# Background for SHAP: every remaining item of the batch.
image1_bg = img1[1:].to(device)
image2_bg = img2[1:].to(device)
metadata_bg = metadata[1:].to(device)


In [41]:
# Sanity check: show which tensor class the test image is stored as.
image2_test_type = type(image2_test)
print(image2_test_type)

<class 'monai.data.meta_tensor.MetaTensor'>


In [None]:
# Wrap the trained model and move it (including the wrapped network) to the
# device selected earlier in the notebook instead of a hard-coded "cuda".
wrapped_model = SiameseWrapper(model).to(device)

# Gradient-based SHAP explainer (shap.GradientExplainer, not DeepExplainer).
# SHAP recognises a multi-input PyTorch model only when the inputs are given
# as a *list* of tensors, and then calls the model as model(*inputs). That
# convention matches the wrapped network's (image1, image2, metadata)
# signature, so the explainer receives `wrapped_model.model`.
explainer = shap.GradientExplainer(
    wrapped_model.model,
    data=[image1_bg, image2_bg, metadata_bg],
)

# The inputs to explain must be a list in the same order as the background
# data; SHAP treats the first dimension of every tensor as the sample axis.
shap_values = explainer.shap_values([image1_test, image2_test, metadata_test])



KeyboardInterrupt

