In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchmetrics.functional import structural_similarity_index_measure, peak_signal_noise_ratio
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import math
import torchvision
from torchvision import datasets, transforms
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
import random
import pytorch_lightning as pl
from torch.nn import Parameter
from typing import Dict, Any
from ipywidgets import IntSlider, interact

# Resolve the compute device once so later cells can refer to `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Authenticate with Weights & Biases before any logger is created.
import wandb
wandb.login()

from pytorch_lightning.loggers import WandbLogger

cpu


In [11]:
class Dataset3D(Dataset):
    """
    Sliding-window dataset over volumetric time series stored as ``.npy`` files.

    Each file is loaded with ``np.load``; its first axis is treated as time
    (presumably ``(T_total, D, H, W)`` -- TODO confirm against the data on disk).
    Every run of ``context_window`` consecutive volumes becomes an input and the
    following ``prediction_window`` volumes become its target. A singleton channel
    axis is inserted at dim 1, so one sample is::

        input:  (context_window,    1, D, H, W)   # [T, C, D, H, W]
        target: (prediction_window, 1, D, H, W)   # [T, C, D, H, W]

    NOTE(review): a batch of these samples is (B, T, C, D, H, W), whereas
    UNet3DPlusTemporal documents its input as (B, C, T, D, H, W).

    All windows are built eagerly in ``__init__`` as slices of the loaded tensors,
    so every file stays in memory for the lifetime of the dataset.

    Args:
        data_paths: iterable of paths to ``.npy`` files.
        context_window: number of input time steps per sample.
        prediction_window: number of target time steps per sample.
        transform: optional callable ``transform(input, target) -> (input, target)``
            applied in ``__getitem__``. It receives both tensors at once so random
            augmentations can be kept consistent between input and target.
    """
    def __init__(self, data_paths, context_window=4, prediction_window=1, transform=None):
        self.data_paths = data_paths
        self.context_window = context_window
        self.prediction_window = prediction_window
        self.transform = transform
        self.samples = []

        window = context_window + prediction_window
        for data_path in self.data_paths:
            volume_seq = torch.from_numpy(np.load(data_path))
            # Number of complete windows; 0 (no samples) if the file is too short.
            for t in range(volume_seq.shape[0] - window + 1):
                context = volume_seq[t:t + context_window].unsqueeze(1)
                future = volume_seq[t + context_window:t + window].unsqueeze(1)
                self.samples.append([context, future])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        if self.transform is None:
            return sample
        # Previously `transform` was stored but never used; apply it here.
        inputs, targets = self.transform(sample[0], sample[1])
        return [inputs, targets]

In [14]:
def get_data_loaders_3D_pred(root='../Data',
                            batch_size=4,
                            sequence_length=4,
                            prediction_length=1,
                            num_workers=0,
                            pin_memory=False,
                            drop_last=False,
                            train_split=0.8,
                            val_split=0.1):
    """
    Build train/val/test DataLoaders of 3D volume sequences for frame prediction.

    Every ``.npy`` file directly inside ``root`` is one volume sequence. Whole files
    (not individual windows) are split into train/val/test, so windows from one
    sequence never land in two splits. The file list is sorted first because
    ``os.listdir`` order is arbitrary, which would otherwise make the split differ
    between machines.

    Args:
        root: directory containing the ``.npy`` sequences.
        batch_size: samples per batch for all three loaders.
        sequence_length: input frames per sample (Dataset3D ``context_window``).
        prediction_length: target frames per sample (Dataset3D ``prediction_window``).
        num_workers: DataLoader worker processes, forwarded to every loader.
        pin_memory: DataLoader ``pin_memory`` flag, forwarded to every loader.
        drop_last: drop the last incomplete batch in every loader.
        train_split: fraction of files used for training (rounded down).
        val_split: fraction of files used for validation (rounded down); the
            remaining files form the test split.

    Returns:
        (train_loader, val_loader, test_loader); only the train loader shuffles.
    """
    # Only .npy entries: np.load on directories or other files would fail later.
    data_paths = sorted(
        os.path.join(root, name)
        for name in os.listdir(root)
        if name.endswith('.npy')
    )

    n_files = len(data_paths)
    n_train = int(n_files * train_split)
    n_val = int(n_files * val_split)
    train_paths = data_paths[:n_train]
    val_paths = data_paths[n_train:n_train + n_val]
    test_paths = data_paths[n_train + n_val:]

    train_dataset = Dataset3D(train_paths, sequence_length, prediction_length)
    val_dataset = Dataset3D(val_paths, sequence_length, prediction_length)
    test_dataset = Dataset3D(test_paths, sequence_length, prediction_length)

    # num_workers / pin_memory used to be accepted but silently dropped.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, **loader_kwargs)
    test_loader = DataLoader(test_dataset, **loader_kwargs)

    return train_loader, val_loader, test_loader

In [15]:

# Build loaders with 8 input frames and 1 predicted frame per sample.
(train_loader,
 val_loader,
 test_loader) = get_data_loaders_3D_pred(batch_size=40,
                                         sequence_length=8,
                                         prediction_length=1)

In [16]:
# Peek at the first training batch to check tensor layouts.
# Named batch_inputs/batch_targets so the built-in `input` is not shadowed
# at notebook scope.
for i, (batch_inputs, batch_targets) in enumerate(train_loader):
    print(i, batch_inputs.shape, batch_targets.shape)
    break


0 torch.Size([40, 8, 1, 16, 256, 256]) torch.Size([40, 1, 1, 16, 256, 256])


## (3+1)D U-Net: per-frame 3D convolutions + 1D temporal convolutions

In [19]:
class DoubleConv3D(nn.Module):
    """
    Two stacked ``Conv3d(k=3, padding=1) -> BatchNorm3d -> ReLU`` stages.

    The padding keeps the spatial size, so only the channel count changes:
    (B, in_channels, D, H, W) -> (B, out_channels, D, H, W).

    Args:
        in_channels: channels of the input volume.
        out_channels: channels produced by the second convolution.
        mid_channels: width between the two convolutions; defaults to
            ``out_channels`` when None.
    """
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = out_channels if mid_channels is None else mid_channels
        stages = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            stages += [
                nn.Conv3d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm3d(c_out),
                nn.ReLU(inplace=True),
            ]
        # One nn.Sequential named `double_conv` keeps the state_dict keys
        # (double_conv.0.weight, double_conv.1.weight, ...) checkpoint-compatible.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)
    
    
class AdvancedTemporalBlock3D(nn.Module):
    """
    AdvancedTemporalBlock3D
    -------------------------
    Residual stack of dilated 1D temporal convolutions, applied independently
    to the time series of every spatial voxel of a volumetric feature sequence.

    Layer ``i`` is ``Conv1d(dilation=dilation_base ** i) -> BatchNorm1d -> ReLU``,
    followed by dropout. Its result is added to the layer input (residual), so the
    receptive field in time grows with depth while the shape stays fixed.

    Input shape:  (B, C, T, D, H, W)
    Output shape: (B, C, T, D, H, W)

    ``kernel_size`` must be odd. The symmetric padding
    ``(kernel_size - 1) * dilation // 2`` preserves T only for odd kernels, and the
    residual addition needs T preserved. Even kernels previously failed later
    with an opaque shape-mismatch error in the addition.

    Args:
        channels: feature channels C (unchanged by the block).
        kernel_size: temporal kernel size (positive, odd).
        num_layers: number of residual temporal layers.
        dropout: dropout probability applied to each layer's output.
        dilation_base: base of the exponential dilation schedule.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """
    def __init__(self, channels, kernel_size, num_layers=2, dropout=0.2, dilation_base=2):
        super().__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer to preserve T, got {kernel_size}"
            )
        self.num_layers = num_layers
        self.layers = nn.ModuleList()
        self.dropout = nn.Dropout(dropout)

        for i in range(num_layers):
            dilation = dilation_base ** i  # widen the temporal receptive field per layer
            # Exact "same" padding because (kernel_size - 1) is even.
            padding = (kernel_size - 1) * dilation // 2
            self.layers.append(
                nn.Sequential(
                    nn.Conv1d(channels, channels, kernel_size=kernel_size, dilation=dilation, padding=padding),
                    nn.BatchNorm1d(channels),
                    nn.ReLU(inplace=True)
                )
            )

    def forward(self, x):
        B, C, T, D, H, W = x.shape

        # Fold every voxel into the batch axis so Conv1d runs over time:
        # (B, C, T, D, H, W) -> (B, D, H, W, C, T) -> (B*D*H*W, C, T)
        seq = x.permute(0, 3, 4, 5, 1, 2).reshape(-1, C, T)

        for layer in self.layers:
            seq = seq + self.dropout(layer(seq))  # residual connection

        # Undo the folding: (B*D*H*W, C, T) -> (B, D, H, W, C, T) -> (B, C, T, D, H, W)
        return seq.view(B, D, H, W, C, T).permute(0, 4, 5, 1, 2, 3).contiguous()
    
    
class UNet3DPlusTemporal(nn.Module):
    """
    UNet3DPlusTemporal with Advanced Temporal Blocks
    --------------------------------------------------
    U-Net for volumetric (3D) data that takes a temporal sequence of input
    volumes and predicts a single output volume.

    Expected input shape: (B, C, T, D, H, W)
      B = batch size
      C = number of input channels (``in_channels``, e.g. 1 for grayscale)
      T = number of time frames (must equal ``input_frames``)
      D, H, W = spatial dimensions, each divisible by 8 (three 2x poolings)

    Output shape: (B, 1, D, H, W)

    Each time step is encoded independently by the same 3D-convolution encoder.
    AdvancedTemporalBlock3D then mixes information across time at the enc3
    level and at the bottleneck. The decoder uses only the LAST time step:
    - the last temporal state of the bottleneck and of the enc3 features;
    - the last frame's enc1/enc2 skip features.

    NOTE(review): a batch from Dataset3D is laid out as (B, T, C, D, H, W), not
    (B, C, T, D, H, W); a permute is needed before calling this model.
    """
    def __init__(self, input_frames=8, base_filters=32, in_channels=1):
        super().__init__()
        self.input_frames = input_frames

        # ---------------------
        # Encoder Path (applied to each time step's volume independently)
        # ---------------------
        self.enc1 = DoubleConv3D(in_channels, base_filters)
        # After enc1: (B, base_filters, D, H, W)
        self.pool1 = nn.MaxPool3d(2)  # -> (B, base_filters, D/2, H/2, W/2)

        self.enc2 = DoubleConv3D(base_filters, base_filters * 2)
        # After enc2: (B, base_filters*2, D/2, H/2, W/2)
        self.pool2 = nn.MaxPool3d(2)  # -> (B, base_filters*2, D/4, H/4, W/4)

        self.enc3 = DoubleConv3D(base_filters * 2, base_filters * 4)
        # After enc3: (B, base_filters*4, D/4, H/4, W/4)
        # Temporal mixing of enc3 features before the last spatial pooling.
        self.temporal_enc = AdvancedTemporalBlock3D(channels=base_filters * 4, kernel_size=3, num_layers=2, dropout=0.2, dilation_base=2)
        self.pool3 = nn.MaxPool3d(2)  # -> (B, base_filters*4, D/8, H/8, W/8)

        # ---------------------
        # Bottleneck
        # ---------------------
        self.bottleneck_spatial = DoubleConv3D(base_filters * 4, base_filters * 8)
        # Bottleneck features: (B, base_filters*8, D/8, H/8, W/8)
        self.temporal_bottleneck = AdvancedTemporalBlock3D(channels=base_filters * 8, kernel_size=3, num_layers=2, dropout=0.2, dilation_base=2)

        # ---------------------
        # Decoder Path
        # ---------------------
        self.upconv3 = nn.ConvTranspose3d(base_filters * 8, base_filters * 4, kernel_size=2, stride=2)
        self.dec3 = DoubleConv3D(base_filters * 8, base_filters * 4)

        self.upconv2 = nn.ConvTranspose3d(base_filters * 4, base_filters * 2, kernel_size=2, stride=2)
        self.dec2 = DoubleConv3D(base_filters * 4, base_filters * 2)

        self.upconv1 = nn.ConvTranspose3d(base_filters * 2, base_filters, kernel_size=2, stride=2)
        self.dec1 = DoubleConv3D(base_filters * 2, base_filters)

        self.final_conv = nn.Conv3d(base_filters, 1, kernel_size=1)

    def forward(self, x):
        """
        Forward pass of UNet3DPlusTemporal.

        Args:
            x: input tensor of shape (B, C, T, D, H, W).

        Returns:
            Tensor of shape (B, 1, D, H, W).

        Raises:
            AssertionError: if T != input_frames.
            ValueError: if D, H or W is not divisible by 8 (the decoder's skip
                concatenations would otherwise fail with a size mismatch).
        """
        B, C, T, D, H, W = x.shape
        assert T == self.input_frames, f"Expected {self.input_frames} frames, got {T}"
        if D % 8 or H % 8 or W % 8:
            raise ValueError(
                f"D, H and W must be divisible by 8 (three 2x poolings), got {(D, H, W)}"
            )

        enc3_features = []  # enc3 output of every time step, for temporal processing
        # The decoder only consumes the most recent frame's skip features, so keep
        # just that pair instead of all T pairs of full-resolution maps (most
        # noticeable under torch.no_grad, where autograd does not retain them).
        e1_last = e2_last = None

        for i in range(T):
            curr_vol = x[:, :, i]                  # (B, C, D, H, W)
            e1 = self.enc1(curr_vol)               # (B, bf, D, H, W)
            e2 = self.enc2(self.pool1(e1))         # (B, bf*2, D/2, H/2, W/2)
            e3 = self.enc3(self.pool2(e2))         # (B, bf*4, D/4, H/4, W/4)
            enc3_features.append(e3)
            e1_last, e2_last = e1, e2

        enc3_features = torch.stack(enc3_features, dim=2)   # (B, bf*4, T, D/4, H/4, W/4)
        enc3_processed = self.temporal_enc(enc3_features)   # same shape

        # Spatially pool every (batch, channel, time) plane. This view merges axes
        # in memory order, so channel and time indices get interleaved across the
        # first two dims. MaxPool3d acts on each plane independently, and the
        # inverse view below restores the (B, C3, T) indexing, so the result is
        # correct. `view` needs enc3_processed to be contiguous and raises otherwise.
        B, C3, T, D_enc, H_enc, W_enc = enc3_processed.shape
        enc3_pooled = enc3_processed.view(B * T, C3, D_enc, H_enc, W_enc)
        enc3_pooled = self.pool3(enc3_pooled)               # (B*T, C3, D/8, H/8, W/8)
        _, _, D_pool, H_pool, W_pool = enc3_pooled.shape
        enc3_pooled = enc3_pooled.view(B, C3, T, D_pool, H_pool, W_pool)

        # Bottleneck: shared spatial DoubleConv per time step, then temporal mixing.
        bottle_features = torch.stack(
            [self.bottleneck_spatial(enc3_pooled[:, :, i]) for i in range(T)],
            dim=2,
        )                                                   # (B, bf*8, T, D/8, H/8, W/8)
        bottle_processed = self.temporal_bottleneck(bottle_features)

        # Decode from the final temporal state.
        bottle_final = bottle_processed[:, :, -1]           # (B, bf*8, D/8, H/8, W/8)
        e3_last = enc3_processed[:, :, -1]                  # (B, bf*4, D/4, H/4, W/4)

        d3 = self.upconv3(bottle_final)                     # (B, bf*4, D/4, H/4, W/4)
        d3 = self.dec3(torch.cat([d3, e3_last], dim=1))     # (B, bf*4, D/4, H/4, W/4)

        d2 = self.upconv2(d3)                               # (B, bf*2, D/2, H/2, W/2)
        d2 = self.dec2(torch.cat([d2, e2_last], dim=1))     # (B, bf*2, D/2, H/2, W/2)

        d1 = self.upconv1(d2)                               # (B, bf, D, H, W)
        d1 = self.dec1(torch.cat([d1, e1_last], dim=1))     # (B, bf, D, H, W)

        return self.final_conv(d1)                          # (B, 1, D, H, W)

In [None]:
# Smoke-test the model on a small random input in its documented layout
# (B, C, T, D, H, W), using the data's depth of 16. A full 8 x 256^3 input is far
# too large: enc1 alone is 32 x 256^3 floats (~2 GB) per frame, before autograd.
m = UNet3DPlusTemporal()
m.eval()
with torch.no_grad():
    out = m(torch.randn(1, 1, 8, 16, 64, 64))
assert out.shape == (1, 1, 16, 64, 64), out.shape
out.shape

: 

In [None]:
def compute_rmse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Root mean squared error between a prediction and its ground truth.

    Args:
        pred (torch.Tensor): model prediction.
        target (torch.Tensor): ground truth, expected to match ``pred``'s shape.

    Returns:
        torch.Tensor: 0-dim tensor equal to ``sqrt(mean((pred - target) ** 2))``.
    """
    return F.mse_loss(pred, target).sqrt()

In [11]:

class Pl_Model(pl.LightningModule):
    """
    LightningModule wrapping an arbitrary frame-prediction network.

    The training objective is ``MSE + 0.5 * Huber(delta=1.0)`` between the wrapped
    model's output and the target. All three terms are logged per mode
    (``train``/``val``/``test``).

    Args:
        passed_model: network called as ``passed_model(inputs)``.
        config: hyperparameter dict; ``config['learning_rate']`` is read by
            ``configure_optimizers``.
        test: unused. Kept as an optional argument (it used to be required,
            which made ``Pl_Model(passed_model=..., config=...)`` raise TypeError).
    """
    def __init__(
        self,
        passed_model: nn.Module,
        config: Dict[str, Any],
        test=None,
    ):
        super().__init__()
        self.passed_model = passed_model
        self.config = config

        # Store all constructor arguments in self.hparams (and in checkpoints).
        self.save_hyperparameters()

        # Loss components.
        self.mse_criterion = nn.MSELoss()
        self.huber_criterion = nn.HuberLoss(delta=1.0)

    def forward(self, x):
        return self.passed_model(x)

    def configure_optimizers(self):
        """Adam over all parameters with ``config['learning_rate']``."""
        optimizer = optim.Adam(
            self.parameters(),
            lr=self.hparams.config['learning_rate'],
        )
        return [optimizer]

    def _calculate_loss(self, batch, mode="train"):
        """
        Compute and log MSE, Huber and total loss for one batch.

        Returns:
            (total_loss, mse_loss, huber_loss) as tensors.
        """
        # Lightning has already moved the batch to self.device.
        inputs, targets = batch

        outputs = self.forward(inputs)

        # NOTE(review): outputs and targets are assumed to have identical shapes.
        # For the 3D pipeline, UNet3DPlusTemporal returns (B, 1, D, H, W) while a
        # Dataset3D target batch is (B, P, 1, D, H, W) -- confirm before training.
        mse_loss = self.mse_criterion(outputs, targets)
        huber_loss = self.huber_criterion(outputs, targets)
        total_loss = mse_loss + 0.5 * huber_loss

        self.log(f"{mode}_mse_loss", mse_loss)
        self.log(f"{mode}_huber_loss", huber_loss)
        self.log(f"{mode}_total_loss", total_loss)

        return total_loss, mse_loss, huber_loss

    def training_step(self, batch, batch_idx):
        loss, _, _ = self._calculate_loss(batch, mode="train")
        return loss

    def validation_step(self, batch, batch_idx):
        _ = self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        _ = self._calculate_loss(batch, mode="test")

    def check_losses(self, loader, mode, use_wandb=False):
        """
        Average losses of a "repeat the last input frame" baseline over ``loader``.

        Each batch's last input frame (``inputs[:, -1]`` with a new axis at dim 1)
        is used as the prediction. Losses are averaged over batches, as a mean of
        per-batch means.

        Args:
            loader: iterable of ``(inputs, targets)`` batches with time at dim 1.
            mode: label used in the logged metric names.
            use_wandb: if True, log the three averages via ``self.log``.

        Returns:
            (mse_loss, huber_loss, total_loss) as Python floats.

        Raises:
            ValueError: if ``loader`` has no batches.
        """
        n_batches = len(loader)
        if n_batches == 0:
            raise ValueError(f"check_losses got an empty '{mode}' loader")

        mse_sum = huber_sum = total_sum = 0.0
        with torch.no_grad():
            for inputs, targets in loader:
                # `inputs[:, -1]` works for both 4D (2D frames) and 6D (3D volumes)
                # batches; the old `inputs[:, -1, :, :, :]` raised IndexError on 4D.
                baseline = inputs[:, -1].unsqueeze(1)
                mse_ = self.mse_criterion(baseline, targets)
                huber_ = self.huber_criterion(baseline, targets)
                mse_sum += mse_.item()
                huber_sum += huber_.item()
                total_sum += (mse_ + 0.5 * huber_).item()

        mse_loss = mse_sum / n_batches
        huber_loss = huber_sum / n_batches
        total_loss = total_sum / n_batches

        if use_wandb:
            # Each metric gets its own key (all three were logged as *_mse_loss).
            self.log(f"Checked_{mode}_mse_loss", mse_loss)
            self.log(f"Checked_{mode}_huber_loss", huber_loss)
            self.log(f"Checked_{mode}_total_loss", total_loss)

        return mse_loss, huber_loss, total_loss

    @staticmethod
    def _to_2d(t):
        """Reduce a tensor to its last two axes by taking the middle index of each leading axis."""
        while t.dim() > 2:
            t = t[t.shape[0] // 2]
        return t

    def log_predictions(self, loader=None):
        """
        Log one example (input frames, target, prediction) to wandb as 2D images.

        Takes the first batch of ``loader``, falling back to a ``self.val_loader``
        attribute if one has been set. It runs the wrapped model on that batch in
        eval mode without gradients, then restores the previous train/eval mode.
        Sample 0 is logged. Volumes are reduced to 2D via ``_to_2d``.

        TODO: call this from a Lightning hook (e.g. every ``config['viz_interval']``
        epochs); the previous version tried to do that inside itself, recursing
        and referencing the undefined names ``epoch``, ``self.model`` and
        ``self.val_loader``.

        Raises:
            ValueError: if no loader is given and ``self.val_loader`` is not set.
        """
        if loader is None:
            loader = getattr(self, "val_loader", None)
        if loader is None:
            raise ValueError(
                "log_predictions needs a DataLoader: pass `loader=` or set `self.val_loader`"
            )

        was_training = self.passed_model.training
        self.passed_model.eval()
        try:
            with torch.no_grad():
                data, target = next(iter(loader))
                data = data.to(self.device)
                target = target.to(self.device)
                output = self(data)
        finally:
            self.passed_model.train(was_training)

        wandb.log({
            "predictions": wandb.Image(self._to_2d(output[0]).cpu()),
            "targets": wandb.Image(self._to_2d(target[0]).cpu()),
            "input_sequence": [wandb.Image(self._to_2d(data[0, i]).cpu()) for i in range(data.shape[1])],
        })

In [15]:
config = {
    # Data loaders
    'batch_size': 40,
    'input_frames': 8,
    'prediction_length': 1,
    'root': '../Data',
    'learning_rate': 0.0005,
    "num_workers": 10,  # set to 0 when the GPU is not used
    "pin_memory": True,  # set to False when the GPU is not used
    "drop_last": False,
    'epochs': 40,

    # Not used yet:
    #'log_interval': 20,
    #'viz_interval': 1,
    'run_name': '3D-UNet+_temp_v1',
    'base_filters': 32,
    'seed': 42,
    'gpu_index': 1,  # CUDA device index to train on when a GPU is available
}

# Seed the RNGs so weight initialisation and shuffling are repeatable across runs.
pl.seed_everything(config['seed'])

# Initialize model
model = UNet3DPlusTemporal(
    input_frames=config['input_frames'],
    base_filters=config['base_filters'],
)

# Get data loaders (every relevant config key is passed explicitly).
train_loader, val_loader, test_loader = get_data_loaders_3D_pred(
    root=config['root'],
    batch_size=config['batch_size'],
    sequence_length=config['input_frames'],
    prediction_length=config['prediction_length'],
    num_workers=config["num_workers"],
    pin_memory=config["pin_memory"],
    drop_last=config["drop_last"],
)

wandb_logger = WandbLogger(project="perfusion-ct-prediction", name=config["run_name"])

# Initialize pl_model
pl_model = Pl_Model(
    passed_model=model,
    config=config,
)

# A hardcoded accelerator="gpu" fails on CPU-only machines; choose by availability.
use_gpu = torch.cuda.is_available()
trainer = pl.Trainer(
    logger=wandb_logger,
    accelerator="gpu" if use_gpu else "cpu",
    devices=[config['gpu_index']] if use_gpu else "auto",
    max_epochs=config["epochs"],
)

wandb_logger.watch(pl_model)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


In [16]:

# Train for config["epochs"] epochs, validating on val_loader during training.
trainer.fit(pl_model, train_dataloaders=train_loader, val_dataloaders=val_loader)


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name            | Type               | Params | Mode 
---------------------------------------------------------------
0 | passed_model    | UNet2DPlusTemporal | 2.0 M  | train
1 | mse_criterion   | MSELoss            | 0      | train
2 | huber_criterion | HuberLoss          | 0      | train
---------------------------------------------------------------
2.0 M     Trainable params
0         Non-trainable params
2.0 M     Total params
8.047     Total estimated model params size (MB)
82        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …

Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=40` reached.


IndexError: too many indices for tensor of dimension 4

In [18]:
# Grab one training batch to inspect its shapes.
[inputs, targets] = next(iter(train_loader))

In [19]:
inputs.size()

torch.Size([40, 8, 256, 256])

In [22]:
# Evaluate the best checkpoint from the fit above. Passing ckpt_path="best"
# explicitly makes the choice visible instead of relying on Lightning's
# implicit default when no model is passed.
val_results = trainer.validate(dataloaders=val_loader, ckpt_path="best")
test_results = trainer.test(dataloaders=test_loader, ckpt_path="best")

Restoring states from the checkpoint path at ./perfusion-ct-prediction/90959bem/checkpoints/epoch=39-step=1280.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
Loaded model weights from the checkpoint at ./perfusion-ct-prediction/90959bem/checkpoints/epoch=39-step=1280.ckpt


Validation: |                                                                                                 …

Restoring states from the checkpoint path at ./perfusion-ct-prediction/90959bem/checkpoints/epoch=39-step=1280.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
Loaded model weights from the checkpoint at ./perfusion-ct-prediction/90959bem/checkpoints/epoch=39-step=1280.ckpt


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     val_huber_loss        0.0055703893303871155
      val_mse_loss         0.011146889999508858
     val_total_loss        0.013932084664702415
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=255` in the `DataLoader` to improve performance.


Testing: |                                                                                                    …

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_huber_loss       0.006914335303008556
      test_mse_loss        0.013844689354300499
     test_total_loss       0.017301857471466064
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


In [21]:
# Persist the final trainer state as ../ModelWeights/<run_name>.ckpt.
save_load_path = "../ModelWeights/{}.ckpt".format(config['run_name'])
trainer.save_checkpoint(save_load_path)