In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchmetrics.functional import structural_similarity_index_measure, peak_signal_noise_ratio
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import math
import torchvision
from torchvision import datasets, transforms
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
import random
import pytorch_lightning as pl
from torch.nn import Parameter
from typing import Dict, Any
from ipywidgets import IntSlider, interact

device = torch.device('mps') or torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

import wandb
wandb.login()

from pytorch_lightning.loggers import WandbLogger

mps


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33msimon-ma[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
class Dataset2D(Dataset):
    def __init__(self, data_paths, context_window=4, prediction_window=1, transform=None):
        self.data_paths = data_paths
        self.context_window = context_window
        self.prediction_window = prediction_window
        self.transform = transform
        self.samples = []
        # For every path to a volume sequence in .npy
        for data_path in self.data_paths:
            volume_seq = np.load(data_path)
            # Convert to tensor
            volume_seq = torch.from_numpy(volume_seq)
            for h in range(volume_seq.shape[1]):
                # Generate samples
                for t in range(len(volume_seq) - self.context_window - self.prediction_window + 1):
                    # Input shape: [T, C, H, W] --> [context_window, 1, 256, 256]
                    # Target shape: [T, C, H, W] --> [prediction_window, 1, 256, 256]
                    self.samples.append((volume_seq[t:t+context_window, h].unsqueeze(1), 
                                         volume_seq[t+context_window:t+context_window+prediction_window, h].unsqueeze(1)))
        
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):

        return self.samples[idx]
    
class Dataset3D(Dataset):
    def __init__(self, data_paths, context_window=4, prediction_window=1, transform=None):
        self.data_paths = data_paths
        self.context_window = context_window
        self.prediction_window = prediction_window
        self.transform = transform
        self.samples = []
        # For every path to a volume sequence in .npy
        for data_path in self.data_paths:
            volume_seq = np.load(data_path)
            # Convert to tensor
            volume_seq = torch.from_numpy(volume_seq)

            for t in range(volume_seq.shape[0]-context_window-prediction_window+1):
                self.samples.append([volume_seq[t:t+context_window].unsqueeze(1), 
                                     volume_seq[t+context_window:t+context_window+prediction_window].unsqueeze(1)])
                # Input-shape: [T, C, D, H, W]
                # Target-shape: [T, C, D, H, W]
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        return self.samples[idx]
    
def get_data_loaders_2D_pred(root='../Data',
                            batch_size=4, 
                            sequence_length=4, 
                            prediction_length=1, 
                            num_workers=0, 
                            pin_memory=False, 
                            drop_last=False,
                            train_split=0.8,
                            val_split=0.1):
    """
    Train dataset is 2D, Validation and Test datasets are 3D for unified validation and testing with 3D models
    """
    # Load all folder paths
    data_paths = [os.path.join(root, path) for path in os.listdir(root)]
    # Split into train/val/test
    n_train = int(len(data_paths) * train_split)
    n_val = int(len(data_paths) * val_split)
    train_paths = data_paths[:n_train]
    val_paths = data_paths[n_train:n_train+n_val]
    test_paths = data_paths[n_train+n_val:]
    
    # Create datasets
    train_dataset = Dataset2D(train_paths, sequence_length, prediction_length)
    val_dataset = Dataset3D(val_paths, sequence_length, prediction_length)
    test_dataset = Dataset3D(test_paths, sequence_length, prediction_length)
    
    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, drop_last=drop_last)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, drop_last=drop_last)
    
    return train_loader, val_loader, test_loader

2Plus1D Unet

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv3D(nn.Module):
    """
    This block applies two consecutive 3D convolutions (with 3×3×3 kernels and
    padding=1 so that spatial dimensions are preserved) followed by ReLU activations.
    """
    def __init__(self, in_channels, mid_channels, out_channels):
        super(DoubleConv3D, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        return self.double_conv(x)


class UNet3D(nn.Module):
    """
    3D U-Net model with a four-layer encoder–decoder architecture.
    
    Adapted from Ichikawa, Shota, Makoto Ozaki, Hideki Itadani, Hiroyuki Sugimori, and Yohan Kondo, 
    ‘Deep Learning-Based Correction for Time Truncation in Cerebral Computed Tomography Perfusion’, 
    Radiological Physics and Technology, 17.3 (2024), pp. 666–78, doi:10.1007/s12194-024-00818-6

    Changes: omitted sigmoid activation function after final conv since data is standardized, not normalized
    (likely a mistake in the original paper)

    Expected input shape: (B, T, C, H, W) where C=1.

    Internally, we permute the input to (B, C, T, H, W) so that 3D operations work as intended.
    The encoder path consists of four blocks:
      - Block 1: Two convs from (in_channels → 32 → 64) followed by a max pool (1,2,2).
      - Block 2: Two convs (64 → 64 → 128) with max pool.
      - Block 3: Two convs (128 → 128 → 256) with max pool.
      - Block 4: Two convs (256 → 256 → 512) at the bottleneck.

    The decoder upsamples features using transposed convolutions (with kernel (1,2,2)),
    concatenates skip connections, and applies additional convs. Finally a 1×1×1 convolution
    maps to the desired number of output channels.

    The final output is permuted back to keep the temporal dimension as the second dimension.
    """
    def __init__(self, in_channels=1, n_classes=1):
        super(UNet3D, self).__init__()
        # Encoder
        self.enc1 = DoubleConv3D(in_channels, 32, 64)
        self.pool1 = nn.MaxPool3d(kernel_size=(1,2,2))
        
        self.enc2 = DoubleConv3D(64, 64, 128)
        self.pool2 = nn.MaxPool3d(kernel_size=(1,2,2))
        
        self.enc3 = DoubleConv3D(128, 128, 256)
        self.pool3 = nn.MaxPool3d(kernel_size=(1,2,2))
        
        self.enc4 = DoubleConv3D(256, 256, 512)
        
        # Decoder
        self.up3 = nn.ConvTranspose3d(512, 256, kernel_size=(1,2,2), stride=(1,2,2))
        # After upsampling: concatenation of upsampled features and corresponding encoder feature (256+256)
        self.dec3 = DoubleConv3D(256 + 256, 256, 256)
        
        self.up2 = nn.ConvTranspose3d(256, 128, kernel_size=(1,2,2), stride=(1,2,2))
        self.dec2 = DoubleConv3D(128 + 128, 128, 128)
        
        self.up1 = nn.ConvTranspose3d(128, 64, kernel_size=(1,2,2), stride=(1,2,2))
        self.dec1 = DoubleConv3D(64 + 64, 64, 64)
        
        self.out_conv = nn.Conv3d(64, n_classes, kernel_size=1)

    def forward(self, x):
        """
        Forward pass of the network.

        Parameters:
            x: tensor of shape (B, input_frames, C, H, W)
        
        Returns:
            Output tensor of shape (B, T, output_frames, H, W)
        """
        # Permute input from (B, T, C, H, W) to (B, C, T, H, W) for Conv3d
        x = x.permute(0, 2, 1, 3, 4)
        
        # Encoder path
        enc1 = self.enc1(x)        # => (B, 64, T, H, W)
        x = self.pool1(enc1)       # => (B, 64, T, H/2, W/2)
        
        enc2 = self.enc2(x)        # => (B, 128, T, H/2, W/2)
        x = self.pool2(enc2)       # => (B, 128, T, H/4, W/4)
        
        enc3 = self.enc3(x)        # => (B, 256, T, H/4, W/4)
        x = self.pool3(enc3)       # => (B, 256, T, H/8, W/8)
        
        x = self.enc4(x)           # => (B, 512, T, H/8, W/8)
        
        # Decoder path
        x = self.up3(x)            # => (B, 256, T, H/4, W/4)
        x = torch.cat([x, enc3], dim=1)  # Concatenate along channels: (B, 512, T, H/4, W/4)
        x = self.dec3(x)           # => (B, 256, T, H/4, W/4)
        
        x = self.up2(x)            # => (B, 128, T, H/2, W/2)
        x = torch.cat([x, enc2], dim=1)  # => (B, 256, T, H/2, W/2)
        x = self.dec2(x)           # => (B, 128, T, H/2, W/2)
        
        x = self.up1(x)            # => (B, 64, T, H, W)
        x = torch.cat([x, enc1], dim=1)  # => (B, 128, T, H, W)
        x = self.dec1(x)           # => (B, 64, T, H, W)
        
        out = self.out_conv(x)     # => (B, n_classes, T, H, W)
        
        # Permute output back to (B, T, n_classes, H, W) to match input convention
        out = out.permute(0, 2, 1, 3, 4)
        return out

# Simple test run
# model = UNet3D(in_channels=1, n_classes=1)
# x = torch.randn(1, 8, 1, 256, 256)
# output = model(x)
# print("Output shape:", output.shape)  # Expected: (1, 16, 1, 256, 256)

In [None]:
def compute_rmse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Compute Root Mean Squared Error (RMSE) between prediction and target.
    """
    mse = F.mse_loss(pred, target)
    rmse = torch.sqrt(mse)
    return rmse

In [6]:

class Pl_Model(pl.LightningModule):
    def __init__(
        self,
        passed_model: nn.Module,
        config: Dict[str, Any],
    ):
        super(Pl_Model, self).__init__()
        self.passed_model = passed_model
        self.config = config

        #speicher alle parameter ab
        self.save_hyperparameters(ignore=["passed_model"])

        
        # Setup training components
        self.mse_criterion = nn.MSELoss()
        self.huber_criterion = nn.HuberLoss(delta=1.0)
        

    def forward(self, x):
        x = self.passed_model(x)
        return x

    def configure_optimizers(self):
        """Sets the Optimizer for the Model"""
        optimizer = optim.Adam(
            self.parameters(), 
            lr=self.config['learning_rate'],
        )
        return [optimizer]

    def _calculate_loss(self, batch, mode="train"):
        """Calculates the loss for a batch in different modes (training, validation, testing)"""
        inputs, targets = batch
        #to device
        #inputs = inputs.to(device)
        #targets = targets.to(device)
        if mode=="val" or mode=="test":
            B, T, C, D, H, W = targets.shape
            outputs = torch.zeros_like(targets)
            # account for added depth adimension in val and test sets
            for h in range(D):
                outputs[:, :, :, h] = self.forward(input[:, :, :, h]).unsqueeze(2)
            if mode=="test":
                rmse = compute_rmse(outputs, targets)
                ssim = structural_similarity_index_measure(outputs, targets)
                psnr = peak_signal_noise_ratio(outputs, targets)
                return rmse, ssim, psnr
        else:
            outputs = self.forward(inputs)
        mse_loss = self.mse_criterion(outputs, targets)
        huber_loss = self.huber_criterion(outputs, targets)
        total_loss = mse_loss + 0.5 * huber_loss
        #forward pass
        # mse_loss = 0.0
        # huber_loss = 0.0
        # total_loss = 0.0

        # for t in range(self.config["output_frames"]):
        #     outputs = self.forward(inputs)
            
        #     #calcualte losses
        #     mse_loss_ = self.mse_criterion(outputs, targets[:, t, :, :].unsqueeze(1))
        #     huber_loss_ = self.huber_criterion(outputs, targets[:, t, :, :].unsqueeze(1))
        #     total_loss += mse_loss_ + 0.5 * huber_loss_
        #     mse_loss += mse_loss_.item()
        #     huber_loss += huber_loss_.item()
            
        #     inputs = torch.cat([inputs[:, 1:, :, :], outputs], dim=1)       

        #logging
        self.log(f"{mode}_mse_loss", mse_loss)
        self.log(f"{mode}_huber_loss", huber_loss)
        self.log(f"{mode}_total_loss", total_loss)

        return total_loss, mse_loss, huber_loss

    def training_step(self, batch, batch_idx):
        loss, _, _ = self._calculate_loss(batch, mode="train")
        return loss

    def validation_step(self, batch, batch_idx):
        _ = self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        _ = self._calculate_loss(batch, mode="test")

    # def check_losses(self, loader, mode, use_wandb=False):
    #     mse_loss = 0.0
    #     huber_loss = 0.0
    #     total_loss = 0.0
    #     for inputs, targets in loader:
    #         for t in range(self.config["output_frames"]):
    #             mse_loss_ = self.mse_criterion(inputs[:, -1, :, :, :].unsqueeze(1), targets[:, t, :, :, :].unsqueeze(1))
    #             huber_loss_ = self.huber_criterion(inputs[:, -1, :, :, :].unsqueeze(1), targets[:, t, :, :, :].unsqueeze(1))
    #             total_loss_ = mse_loss_ + 0.5 * huber_loss_   
        
    #             mse_loss += mse_loss_.item()
    #             huber_loss += huber_loss_.item()
    #             total_loss += total_loss_.item()
    #     mse_loss = mse_loss/len(loader)
    #     huber_loss = huber_loss/len(loader)
    #     total_loss = total_loss/len(loader)
    
    #     if use_wandb:
    #         self.log(f"Checked_{mode}_mse_loss", mse_loss)
    #         self.log(f"Checked_{mode}_mse_loss", huber_loss)
    #         self.log(f"Checked_{mode}_mse_loss", total_loss)
        
    #     return mse_loss, huber_loss, total_loss
        
    def log_predictions(self):
        """Log example predictions to wandb"""
        #needs to be added to other method
        if epoch % self.config['viz_interval'] == 0:
                self.log_predictions()
        #but this whole method needs to be rewritten
        self.model.eval()
        with torch.no_grad():
            # Get a batch of validation data
            data, target = next(iter(self.val_loader))
            data = data.to(self.device)
            target = target.to(self.device)
            
            # Generate predictions
            output = self.model(data)
            
            # Log images
            wandb.log({
                "predictions": wandb.Image(output[0, 0].cpu()),
                "targets": wandb.Image(target[0, 0].cpu()),
                "input_sequence": [wandb.Image(data[0, i].cpu()) for i in range(data.shape[1])]
            })

In [8]:
config = {
    #for the dataloaders
    'batch_size': 1,
    'learning_rate': 1e-4,
    "num_workers": 4,#0, wenn die gpu nicht benutzt wird
    "pin_memory": True,#False, wenn die gpu nicht benutzt wird
    "drop_last": False,
    'epochs': 30,
    #'log_interval': 20,
    #'viz_interval': 1,
    'run_name': '2D-UNet+_temp_v3',
    'input_frames': 8,
    "output_frames": 8,
    'base_filters': 32
}

# Initialize model
model = UNet3D(
    input_frames=config['input_frames'],
    base_filters=config['base_filters'],
)

# Get data loaders
train_loader, val_loader, test_loader = get_data_loaders_2D_pred(
    batch_size=config['batch_size'],
    num_workers=config["num_workers"],
    pin_memory=config["pin_memory"],
    drop_last=config["drop_last"],
    sequence_length=config["input_frames"], 
    prediction_length=config["output_frames"],
)

wandb_logger = WandbLogger(project="perfusion-ct-prediction", name=config["run_name"])

# Initialize pl_model
pl_model = Pl_Model(
    passed_model=model,
    config=config,
)

# Initialize trainer
trainer = pl.Trainer(
    logger=wandb_logger,
    accelerator="gpu",
    devices= [0] if torch.cuda.is_available() else None,
    max_epochs=config["epochs"],
)

wandb_logger.watch(pl_model)


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/simonma/miniconda3/envs/ct/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
/Users/simonma/miniconda3/envs/ct/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

RuntimeError: slow_conv2d_forward_mps: input(device='cpu') and weight(device=mps:0')  must be on the same device

In [62]:

trainer.fit(
    pl_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

#check the losses "to beat"
pl_model.check_losses(train_loader, mode="train", use_wandb=True)
pl_model.check_losses(val_loader, mode="val", use_wandb=True)
pl_model.check_losses(test_loader, mode="test", use_wandb=True)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type               | Params | Mode 
------------------------------------------------------------
0 | passed_model | UNet2DPlusTemporal | 2.0 M  | train
1 | criterion    | MSELoss            | 0      | train
2 | huber_loss   | HuberLoss          | 0      | train
------------------------------------------------------------
2.0 M     Trainable params
0         Non-trainable params
2.0 M     Total params
8.047     Total estimated model params size (MB)
82        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=255` in the `DataLoader` to improve performance.


Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=30` reached.


In [63]:
val_results = trainer.validate(dataloaders=val_loader)
test_results = trainer.test(dataloaders=test_loader)

Restoring states from the checkpoint path at ./perfusion-ct-prediction/mpw4rbc7/checkpoints/epoch=29-step=2880.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at ./perfusion-ct-prediction/mpw4rbc7/checkpoints/epoch=29-step=2880.ckpt


Validation: |                                                                                                 …

Restoring states from the checkpoint path at ./perfusion-ct-prediction/mpw4rbc7/checkpoints/epoch=29-step=2880.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at ./perfusion-ct-prediction/mpw4rbc7/checkpoints/epoch=29-step=2880.ckpt


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     val_huber_loss         0.02376500517129898
      val_mse_loss          0.04780226945877075
     val_total_loss         0.05968477204442024
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=255` in the `DataLoader` to improve performance.


Testing: |                                                                                                    …

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_huber_loss       0.011898771859705448
      test_mse_loss        0.023859838023781776
     test_total_loss       0.029809223487973213
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


In [None]:
save_load_path = f"../ModelWeights/{config["run_name"]}.ckpt"
trainer.save_checkpoint(save_load_path)