In [7]:
#imports
import logging
import os
import sys
import shutil
import tempfile
from monai.data import Dataset

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
#from torch.utils.tensorboard import SummaryWriter
import numpy as np

import monai
from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import DataLoader, ImageDataset
from monai.transforms import (
    EnsureChannelFirst,
    Compose,
    Resize,
    ScaleIntensity,
)

import glob
import nibabel as nib
from sklearn.model_selection import train_test_split
  


ModuleNotFoundError: No module named 'sklearn'

# Training

### Loading the data

In [4]:
pin_memory = torch.cuda.is_available()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [None]:
data_dir = "L:/Basic/divi/jstoker/slicer_pdac/Master Students WS 24/Martijn/data/Training/paired_scans" #fill in training datapath
nifti_images = sorted(glob.glob(os.path.join(data_dir, "*.nii.gz")))      

In [None]:
class PairedMedicalDataset(Dataset):
    def __init__(self, image_pairs, metadata, labels, transform=None):
        self.image_pairs = image_pairs
        self.metadata = metadata
        self.labels = labels
        self.transform = transform
    
    def __len__(self):
        return len(self.image_pairs)
    
    def __getitem__(self, idx):
        img1_path, img2_path = self.image_pairs[idx]
        
        # Load images using nibabel (for NIfTI)
        img1 = nib.load(img1_path).get_fdata()
        img2 = nib.load(img2_path).get_fdata()

        # Add channel dimension for CNN input (C, H, W, D)
        img1 = np.expand_dims(img1, axis=0)
        img2 = np.expand_dims(img2, axis=0)
        
        metadata = self.metadata[idx]
        label = self.labels[idx]
        
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        # Convert to tensor
        img1 = torch.tensor(img1, dtype=torch.float32)
        img2 = torch.tensor(img2, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.float32)
        
        return img1, img2, metadata, label

In [None]:
# Create pairs (e.g., first and second file are paired)
image_pairs = [(nifti_images[i], nifti_images[i + 1]) for i in range(0, len(nifti_images) - 1, 2)]
labels = None #Fill in correct path. response, PFS, and OS



# Split the data into training and validation sets
train_image_pairs, val_image_pairs, train_labels, val_labels = train_test_split(
    image_pairs, labels, test_size=0.2, random_state=42  # 20% for validation
)

# Create training dataset
train_dataset = PairedMedicalDataset(
    train_image_pairs, train_labels, transform=[ScaleIntensity(), EnsureChannelFirst(), Resize((64, 256, 256))]
)

# Create validation dataset
val_dataset = PairedMedicalDataset(
    val_image_pairs, val_labels, transform=[ScaleIntensity(), EnsureChannelFirst(), Resize((64, 256, 256))]
)

# DataLoaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=False)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=2, shuffle=False)

### Initialize model

In [3]:
class SiameseNetwork(nn.Module):
    def __init__(self, base_model):
        super(SiameseNetwork, self).__init__()
        self.base_model = base_model
        self.adaptive_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(2053,128),
            nn.ReLU(),
            nn.Linear(128,6),      #8 classes (Path. resp. 2, PFS 3, OS 3)
            nn.Sigmoid()
        )

    
    def forward(self, image1, image2, metadata):
        # Pass both inputs through the shared model
        output1 = self.base_model(image1)
        output2 = self.base_model(image2)

        # Apply adaptive average pooling to both outputs
        output1 = self.adaptive_pool(output1).view(output1.size(0), -1)  # Flatten after pooling
        output2 = self.adaptive_pool(output2).view(output2.size(0), -1)  # Flatten after pooling

        combined_embeddings = torch.cat((output1, output2, metadata), dim=1)
        output3 = self.classifier(combined_embeddings)
        return output3


In [4]:
resnet_model = torch.hub.load('Warvito/MedicalNet-models', 'medicalnet_resnet50')

# Remove the final classification layer (fc) to keep only the encoder part
encoder = nn.Sequential(*list(resnet_model.children())[:-1])

Using cache found in C:\Users\marti/.cache\torch\hub\Warvito_MedicalNet-models_main


In [5]:
model = SiameseNetwork(encoder)




In [6]:
x = torch.randn(1, 1, 64, 256, 256)
model(x,x, torch.Tensor([[0,1,0,1,0]]))


tensor([[0.4802, 0.4945, 0.4337, 0.5213, 0.5400, 0.5161]],
       grad_fn=<SigmoidBackward0>)

In [45]:
# Create example input tensor
x = torch.randn(1, 1, 64, 256, 256)  # (batch_size, channels, depth, height, width)

# Pass through the encoder
encoder.eval()
with torch.no_grad():
    features = encoder(x)

print(features.shape)

adaptive_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
print(len(adaptive_pool(features)))



torch.Size([1, 1024, 8, 32, 32])
1


### Training

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm

def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=10, device='cuda'):
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        print('-' * 20)

        # ---------------------------
        # TRAINING PHASE
        # ---------------------------
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for inputs, labels in tqdm(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # Track metrics
            running_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
        
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = correct / total
        
        print(f'Train Loss: {epoch_loss:.4f} | Train Acc: {epoch_acc:.4f}')

        # ---------------------------
        # VALIDATION PHASE
        # ---------------------------
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                
                val_loss += loss.item()
                _, preds = torch.max(outputs, 1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        
        val_loss /= len(val_loader)
        val_acc = correct / total
        print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}')
        
        # ---------------------------
        # LEARNING RATE SCHEDULER STEP
        # ---------------------------
        scheduler.step(val_loss)

        # ---------------------------
        # SAVE BEST MODEL
        # ---------------------------
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')
            print("🔥 Best model saved!")
    
    print("\nTraining complete. Best Val Loss: {:.4f}".format(best_val_loss))



In [None]:
# Model
model = model.to('cuda')

# Loss function (for classification)
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Learning rate scheduler
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)

# Start training!
train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=20)
