In [None]:
#imports
import logging
import os
import sys
import shutil
import tempfile
from monai.data import Dataset

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
#from torch.utils.tensorboard import SummaryWriter
import numpy as np

import monai
from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import DataLoader, ImageDataset
from monai.transforms import (
    EnsureChannelFirst,
    Compose,
    Resize,
    ScaleIntensity,
)

import glob
import nibabel as nib
  



# Pre-training

### Loading the data

In [5]:
# Runtime configuration: reproducibility and accelerator selection.
SEED = 42  # change for a different (but still reproducible) run
# torch.manual_seed seeds the default generator on all devices, which drives
# weight initialisation and the DataLoader's shuffle order.
torch.manual_seed(SEED)

# Pin host memory only when a CUDA device will consume the batches.
pin_memory = torch.cuda.is_available()
device = torch.device("cuda" if pin_memory else "cpu")

In [None]:
# Directory holding the pretraining NIfTI volumes. Set the PRETRAIN_DATA_DIR
# environment variable to point elsewhere without editing the notebook.
DEFAULT_PRETRAIN_DATA_DIR = "L:/Basic/divi/jstoker/slicer_pdac/Master Students WS 24/Martijn/data/Training/paired_scans"
data_dir = os.environ.get("PRETRAIN_DATA_DIR", DEFAULT_PRETRAIN_DATA_DIR)

# sorted() makes the file order deterministic; the pairing step pairs
# consecutive entries of this list.
nifti_images = sorted(glob.glob(os.path.join(data_dir, "*.nii.gz")))
if not nifti_images:
    # glob returns [] silently for a missing/mistyped directory.
    logging.warning("No *.nii.gz files found in %r", data_dir)

In [None]:
class PairedMedicalDataset(Dataset):
    """Dataset yielding pairs of NIfTI volumes with optional metadata and labels.

    Each item is ``(img1, img2, metadata, label)``. Both images are float32
    tensors with a leading channel axis of size 1 (added here with
    ``np.expand_dims(..., axis=0)``) followed by the volume's own axes,
    after ``transform`` has been applied.

    Args:
        image_pairs: Sequence of ``(path1, path2)`` tuples of files readable
            by ``nibabel.load``.
        metadata: Optional per-pair metadata, indexed in parallel with
            ``image_pairs``. When ``None``, each item carries an empty dict so
            PyTorch's default collate function can still batch it.
        labels: Optional per-pair targets (anything ``torch.as_tensor``
            accepts), indexed in parallel with ``image_pairs``. When ``None``,
            each item carries an empty float32 tensor as a "no label"
            placeholder.
        transform: ``None``, a single callable, or a list/tuple of callables
            applied in order to each channel-first image (e.g. MONAI array
            transforms).

    Raises:
        ValueError: If ``metadata`` or ``labels`` is given with a length that
            differs from ``image_pairs``.
    """

    def __init__(self, image_pairs, metadata=None, labels=None, transform=None):
        # NOTE: the base-class __init__ is not called; this class overrides
        # both __len__ and __getitem__ and keeps its own state.
        self.image_pairs = image_pairs
        self.metadata = metadata
        self.labels = labels
        self.transform = transform
        for name, values in (("metadata", metadata), ("labels", labels)):
            if values is not None and len(values) != len(image_pairs):
                raise ValueError(
                    f"{name} has {len(values)} entries but there are "
                    f"{len(image_pairs)} image pairs"
                )

    def __len__(self):
        return len(self.image_pairs)

    @staticmethod
    def _load_volume(path):
        """Read a NIfTI file and return its voxel data with a leading channel axis."""
        return np.expand_dims(nib.load(path).get_fdata(), axis=0)

    def _apply_transform(self, img):
        """Apply ``self.transform``; a list/tuple of callables is applied in order."""
        if self.transform is None:
            return img
        if isinstance(self.transform, (list, tuple)):
            for step in self.transform:
                img = step(img)
            return img
        return self.transform(img)

    def __getitem__(self, idx):
        img1_path, img2_path = self.image_pairs[idx]
        img1 = self._apply_transform(self._load_volume(img1_path))
        img2 = self._apply_transform(self._load_volume(img2_path))

        metadata = {} if self.metadata is None else self.metadata[idx]
        if self.labels is None:
            label = torch.empty(0, dtype=torch.float32)
        else:
            label = torch.as_tensor(self.labels[idx], dtype=torch.float32)

        # torch.as_tensor accepts both numpy arrays and the tensors MONAI
        # transforms return, without torch.tensor's copy-construct warning.
        img1 = torch.as_tensor(img1, dtype=torch.float32)
        img2 = torch.as_tensor(img2, dtype=torch.float32)

        return img1, img2, metadata, label

In [19]:
# Pair consecutive files in sorted order: (file 0, file 1), (file 2, file 3), ...
# NOTE(review): assumes the sorted filenames interleave the two scans of each
# pair — confirm against the naming scheme in data_dir.
if len(nifti_images) % 2:
    logging.warning(
        "Odd number of NIfTI files (%d); %s has no partner and is skipped.",
        len(nifti_images), nifti_images[-1],
    )
image_pairs = list(zip(nifti_images[0::2], nifti_images[1::2]))
if not image_pairs:
    raise ValueError(f"No image pairs found in {data_dir!r}; check data_dir.")

labels = None  # TODO: fill in per-pair targets (response, PFS, and OS)

# TODO: choose the final Resize target for the pretraining volumes.
RESIZE_SHAPE = (96, 96, 96)
BATCH_SIZE = 4

# PairedMedicalDataset already adds the channel axis with np.expand_dims, so no
# EnsureChannelFirst step is used: EnsureChannelFirst() cannot infer the channel
# axis of a plain array that carries no metadata.
train_transforms = Compose([ScaleIntensity(), Resize(RESIZE_SHAPE)])

train_dataset = PairedMedicalDataset(
    image_pairs, metadata=None, labels=labels, transform=train_transforms
)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=pin_memory
)

### Initialize model

In [None]:
class SiameseNetwork(nn.Module):
    """Siamese head that scores the similarity of two inputs with a shared encoder.

    Both inputs pass through the same ``base_model`` (shared weights). A linear
    layer maps the element-wise absolute difference of the two embeddings to
    one logit, which a sigmoid squashes to a score between 0 and 1. Because of
    the absolute difference, ``forward(a, b)`` equals ``forward(b, a)``.

    Args:
        base_model: Encoder applied to each input. The last dimension of its
            output must equal ``embedding_dim`` (e.g. ``(batch, embedding_dim)``).
        embedding_dim: Number of encoder output features fed to the scoring
            layer. Defaults to 512, the previously hard-coded size.

    Note:
        ``forward`` returns probabilities. If training with a BCE objective,
        returning the logit and using ``BCEWithLogitsLoss`` is numerically
        more stable.
    """

    def __init__(self, base_model, embedding_dim=512):
        super().__init__()
        self.base_model = base_model
        self.fc = nn.Linear(embedding_dim, 1)  # one similarity logit per pair

    def forward(self, input1, input2):
        """Return similarity scores with the embedding's last dim replaced by 1."""
        # Shared weights: the same encoder embeds both inputs.
        emb1 = self.base_model(input1)
        emb2 = self.base_model(input2)

        # |e1 - e2| makes the score symmetric in its two arguments.
        diff = torch.abs(emb1 - emb2)

        return torch.sigmoid(self.fc(diff))


### Training