#### Imports

In [1]:
import math
import os
import random
from typing import Any, Dict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from ipywidgets import IntSlider, interact
from sklearn.metrics import classification_report, confusion_matrix
from torch.nn import Parameter
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import datasets, transforms
from torchmetrics.image import StructuralSimilarityIndexMeasure, PeakSignalNoiseRatio
import gc

# Pick the compute device once for the notebook: CUDA when PyTorch reports it
# as available, otherwise the CPU. The choice is printed for the cell output.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

import wandb

# Authenticate this session with Weights & Biases before any logger is created.
wandb.login()

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import RichProgressBar, ModelCheckpoint
from pytorch_lightning.tuner import Tuner

import kornia

cuda


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mnbennewiz[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


#### Dataset

In [2]:
class Dataset3D(Dataset):
    """Sliding-window dataset over volume sequences stored as ``.npy`` files.

    Every file is expected to contain a single array whose first axis is time.
    One sample is a ``(context, target)`` pair cut from one file:

    * ``context`` -- ``context_window`` consecutive time steps starting at ``t``
    * ``target``  -- the ``prediction_window`` time steps that follow

    Both tensors get a singleton axis inserted at position 1.

    Args:
        data_paths: Iterable collection of file paths (a list or a
            ``torch.utils.data.Subset`` of paths both work).
        context_window: Number of time steps returned as model input.
        prediction_window: Number of time steps returned as target.
        transform: Stored on the instance only; this class never applies it.
    """

    def __init__(self, data_paths, context_window=4, prediction_window=1, transform=None):
        self.data_paths = data_paths
        self.context_window = context_window
        self.prediction_window = prediction_window
        self.transform = transform

        window = self.context_window + self.prediction_window

        # One entry per valid window start: [file_path, t].
        self.data_attributes = []
        for data_path in self.data_paths:
            # Measure each file on its own instead of assuming that every file
            # has the length of the first one. Memory-mapping reads the array
            # header without pulling the whole volume sequence into RAM.
            num_steps = np.load(data_path, mmap_mode="r").shape[0]
            # Files shorter than one full window contribute no samples.
            self.data_attributes.extend(
                [data_path, t] for t in range(num_steps - window + 1)
            )

    def __len__(self):
        """Return the total number of windows across all files."""
        return len(self.data_attributes)

    def __getitem__(self, idx):
        """Return the ``(context, target)`` tensors for window ``idx``."""
        data_path, t = self.data_attributes[idx]
        split = self.context_window
        stop = t + split + self.prediction_window

        # Copy only the requested window out of the memory-mapped file; the
        # copy is a regular writable ndarray that torch can wrap directly.
        window = np.array(np.load(data_path, mmap_mode="r")[t:stop])
        volume_seq = torch.from_numpy(window)
        return (
            volume_seq[:split].unsqueeze(1),
            volume_seq[split:].unsqueeze(1),
        )

In [3]:
def seed_worker(worker_id):
    """Seed NumPy and ``random`` from the current PyTorch initial seed.

    Intended as a ``DataLoader`` ``worker_init_fn``: the value returned by
    ``torch.initial_seed()`` is reduced to 32 bits and handed to both
    ``np.random.seed`` and ``random.seed``. ``worker_id`` is accepted for the
    callback signature but not used.
    """
    derived_seed = torch.initial_seed() % (2 ** 32)
    for apply_seed in (np.random.seed, random.seed):
        apply_seed(derived_seed)

class VolumeDataModule(pl.LightningDataModule):
    """Lightning data module that splits a folder of volume files into train/val/test.

    The split is done at file level, so all windows cut from one file end up
    in the same subset. Each subset is wrapped in a ``Dataset3D``.

    Args:
        root: Directory whose entries are the data files.
        batch_size: Batch size for all three loaders.
        sequence_length: Passed to ``Dataset3D`` as ``context_window``.
        prediction_length: Passed to ``Dataset3D`` as ``prediction_window``.
        num_workers: ``DataLoader`` worker count.
        drop_last: Forwarded to every ``DataLoader``.
        pin_memory: Forwarded to every ``DataLoader``.
        train_split: Fraction of files used for training.
        val_split: Fraction of files used for validation.
        test_split: Fraction of files used for testing.

    The three fractions are rescaled to sum to 1 if they do not already.
    """

    def __init__(self, root, batch_size=4, sequence_length=4, prediction_length=1, num_workers=0, drop_last=False, pin_memory=False, train_split=0.8, val_split=0.1, test_split=0.1):
        super().__init__()
        self.root = root
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.prediction_length = prediction_length
        self.num_workers = num_workers
        self.drop_last = drop_last
        self.pin_memory = pin_memory
        self.train_split = train_split
        self.val_split = val_split
        self.test_split = test_split

    def setup(self, stage=None):
        """Split the files under ``root`` and build the three datasets.

        Raises:
            ValueError: If any of the three subsets would be empty.
        """
        # os.listdir returns entries in arbitrary order; sorting makes the
        # seeded random_split below pick the same files on every machine/run.
        data_paths = sorted(os.path.join(self.root, name) for name in os.listdir(self.root))
        total_size = len(data_paths)

        # Normalize splits if they don't sum to 1 (tolerating float rounding).
        split_sum = self.train_split + self.val_split + self.test_split
        if not math.isclose(split_sum, 1.0):
            self.train_split /= split_sum
            self.val_split /= split_sum
            self.test_split /= split_sum
            print(f"Normalized splits to: train={self.train_split:.2f}, val={self.val_split:.2f}, test={self.test_split:.2f}")

        # Compute subset sizes; the test subset takes the remainder so that
        # every file is used.
        train_size = int(total_size * self.train_split)
        val_size = int(total_size * self.val_split)
        test_size = total_size - train_size - val_size

        if train_size <= 0 or val_size <= 0 or test_size <= 0:
            raise ValueError(f"Invalid dataset splits: train={train_size}, val={val_size}, test={test_size}. Check your split values.")

        # Fixed generator seed -> identical file assignment across runs.
        self.train_paths, self.val_paths, self.test_paths = random_split(data_paths, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(42))

        self.train_dataset = Dataset3D(self.train_paths, self.sequence_length, self.prediction_length)
        self.val_dataset = Dataset3D(self.val_paths, self.sequence_length, self.prediction_length)
        self.test_dataset = Dataset3D(self.test_paths, self.sequence_length, self.prediction_length)

    def _make_loader(self, dataset, shuffle=False):
        """Build a DataLoader with the module-wide settings for ``dataset``."""
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle, num_workers=self.num_workers, drop_last=self.drop_last, pin_memory=self.pin_memory, worker_init_fn=seed_worker)

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the (unshuffled) validation loader."""
        return self._make_loader(self.val_dataset)

    def test_dataloader(self):
        """Return the (unshuffled) test loader."""
        return self._make_loader(self.test_dataset)

    def teardown(self, stage=None):
        """Drop the dataset references, whatever the stage.

        Attributes that are already gone (teardown called twice, or before
        ``setup``) are skipped instead of raising ``AttributeError``.
        """
        for attr in ("train_dataset", "val_dataset", "test_dataset"):
            if hasattr(self, attr):
                delattr(self, attr)

#### Model

In [4]:
class DoubleConv3D(nn.Module):
    """Two stacked ``Conv3d -> BatchNorm3d -> ReLU`` stages.

    Both convolutions use a 3x3x3 kernel with padding 1, so the spatial size
    is preserved.

    Input shape:  (B, in_channels, D, H, W)
    Output shape: (B, out_channels, D, H, W)

    ``mid_channels`` is the width between the two stages and defaults to
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = out_channels if mid_channels is None else mid_channels

        # Build the stages in order so that the Sequential indices stay
        # 0..5: conv, norm, relu, conv, norm, relu.
        stages = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            stages.append(nn.Conv3d(c_in, c_out, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm3d(c_out))
            stages.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both convolution stages."""
        return self.double_conv(x)
    
    
class AdvancedTemporalBlock3D(nn.Module):
    """Residual stack of dilated 1D convolutions along the time axis.

    Every voxel position is treated as an independent sequence: the spatial
    axes are folded into the batch axis, a chain of
    ``Conv1d -> BatchNorm1d -> ReLU`` stages (followed by dropout) is applied
    over time, and each stage's output is added back to its input.

    The dilation of stage ``i`` is ``dilation_base ** i`` and the padding is
    ``(kernel_size - 1) * dilation // 2``.

    Input shape:  (B, C, T, D, H, W)
    Output shape: (B, C, T, D, H, W)
    """

    def __init__(self, channels, kernel_size, num_layers=2, dropout=0.2, dilation_base=2):
        super().__init__()
        self.num_layers = num_layers

        stages = []
        for level in range(num_layers):
            rate = dilation_base ** level
            pad = ((kernel_size - 1) * rate) // 2
            temporal_conv = nn.Conv1d(channels, channels, kernel_size=kernel_size, dilation=rate, padding=pad)
            stages.append(
                nn.Sequential(temporal_conv, nn.BatchNorm1d(channels), nn.ReLU(inplace=True))
            )
        self.layers = nn.ModuleList(stages)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the temporal stack to ``x`` of shape (B, C, T, D, H, W)."""
        B, C, T, D, H, W = x.shape

        # (B, C, T, D, H, W) -> (B, D, H, W, C, T) -> (B*D*H*W, C, T)
        seq = x.permute(0, 3, 4, 5, 1, 2).reshape(-1, C, T)

        # Residual update per stage: input + dropout(stage(input)).
        for stage in self.layers:
            seq = seq + self.dropout(stage(seq))

        # Unfold the spatial axes and move channels/time back to the front.
        restored = seq.view(B, D, H, W, C, T).permute(0, 4, 5, 1, 2, 3)
        return restored.contiguous()
    
    
class UNet3DPlusTemporal(nn.Module):
    """
    UNet3DPlusTemporal with Advanced Temporal Blocks
    --------------------------------------------------
    3D U-Net applied to every frame of a sequence, with temporal mixing at the
    third encoder level and at the bottleneck.

    Expected input shape (this is how ``forward`` unpacks it): (B, T, C, D, H, W)
      B = batch size
      T = number of time frames (must equal input_frames)
      C = number of input channels (must equal in_channels)
      D, H, W = spatial dimensions; they are halved three times by MaxPool3d(2)
                and doubled three times by the transposed convolutions, so they
                need to be divisible by 8 for the skip concatenations to match.

    Output shape: (B, T, 1, D, H, W) -- one single-channel volume per input frame.

    Spatial processing is performed using 3D convolutions on each time step independently.
    Temporal processing is applied on the stacked features using AdvancedTemporalBlock3D.
    """
    def __init__(self, input_frames=8, base_filters=32, in_channels=1):
        super().__init__()
        self.input_frames = input_frames
        
        # ---------------------
        # Encoder Path
        # ---------------------
        # Each encoder stage is applied to one 3D volume (one time step) at a time.
        self.enc1 = DoubleConv3D(in_channels, base_filters)  
        # After enc1: (B, base_filters, D, H, W)
        self.pool1 = nn.MaxPool3d(2)  # -> (B, base_filters, D/2, H/2, W/2)
        
        self.enc2 = DoubleConv3D(base_filters, base_filters * 2)
        # After enc2: (B, base_filters*2, D/2, H/2, W/2)
        self.pool2 = nn.MaxPool3d(2)  # -> (B, base_filters*2, D/4, H/4, W/4)
        
        self.enc3 = DoubleConv3D(base_filters * 2, base_filters * 4)
        # After enc3: (B, base_filters*4, D/4, H/4, W/4)
        # Temporal block applied to the stacked enc3 features, before the last pooling.
        self.temporal_enc = AdvancedTemporalBlock3D(channels=base_filters * 4, kernel_size=3, num_layers=2, dropout=0.2, dilation_base=2)
        self.pool3 = nn.MaxPool3d(2)  # -> (B, base_filters*4, D/8, H/8, W/8)
        
        # ---------------------
        # Bottleneck
        # ---------------------
        self.bottleneck_spatial = DoubleConv3D(base_filters * 4, base_filters * 8)
        # Bottleneck features: (B, base_filters*8, D/8, H/8, W/8)
        self.temporal_bottleneck = AdvancedTemporalBlock3D(channels=base_filters * 8, kernel_size=3, num_layers=2, dropout=0.2, dilation_base=2)
        
        # ---------------------
        # Decoder Path
        # ---------------------
        # Every dec* block takes twice the up-convolved width because the
        # matching encoder features are concatenated along the channel axis.
        self.upconv3 = nn.ConvTranspose3d(base_filters * 8, base_filters * 4, kernel_size=2, stride=2)
        self.dec3 = DoubleConv3D(base_filters * 8, base_filters * 4)
        
        self.upconv2 = nn.ConvTranspose3d(base_filters * 4, base_filters * 2, kernel_size=2, stride=2)
        self.dec2 = DoubleConv3D(base_filters * 4, base_filters * 2)
        
        self.upconv1 = nn.ConvTranspose3d(base_filters * 2, base_filters, kernel_size=2, stride=2)
        self.dec1 = DoubleConv3D(base_filters * 2, base_filters)
        
        # 1x1x1 projection to a single output channel (independent of in_channels).
        self.final_conv = nn.Conv3d(base_filters, 1, kernel_size=1)
    
    def forward(self, x):
        """
        Forward pass of UNet3DPlusTemporal.
        
        x: Input tensor of shape (B, T, C, D, H, W) with T == input_frames.

        Returns a tensor of shape (B, T, 1, D, H, W): one decoded volume per
        input time step, stacked along dim 1.
        """
        B, T, C, D, H, W = x.shape
        assert T == self.input_frames, f"Expected {self.input_frames} frames, got {T}"
        
        encoder_features = []  # Per time step: (e1, e2) skip features for the decoder.
        enc3_features = []     # Per time step: enc3 output, stacked below for temporal processing.
        
        # Process each time step independently through the encoder.
        for i in range(T):
            # Extract the 3D volume for time step i: (B, C, D, H, W)
            curr_vol = x[:, i]
            
            # Encoder Stage 1
            e1 = self.enc1(curr_vol)         # -> (B, base_filters, D, H, W)
            p1 = self.pool1(e1)              # -> (B, base_filters, D/2, H/2, W/2)
            
            # Encoder Stage 2
            e2 = self.enc2(p1)               # -> (B, base_filters*2, D/2, H/2, W/2)
            p2 = self.pool2(e2)              # -> (B, base_filters*2, D/4, H/4, W/4)
            
            encoder_features.append((e1, e2))  # Save skip connection features for the decoder.
            
            # Encoder Stage 3
            e3 = self.enc3(p2)               # -> (B, base_filters*4, D/4, H/4, W/4)
            enc3_features.append(e3)
        
        # Stack enc3 features along a new temporal axis.
        # New shape: (B, base_filters*4, T, D/4, H/4, W/4)
        enc3_features = torch.stack(enc3_features, dim=2)
        
        # Temporal processing over the stacked features.
        enc3_processed = self.temporal_enc(enc3_features)  # -> (B, base_filters*4, T, D/4, H/4, W/4)
        
        # Spatial pooling of the temporally processed features.
        # NOTE: the view below does not permute C3 and T; it only regroups the
        # contiguous (C3, T) axes into leading (B*T, C3) axes so that MaxPool3d
        # sees a 5D tensor. MaxPool3d pools each leading slice independently and
        # the second view restores the original grouping, so every (b, c, t)
        # volume is pooled on its own and ends up at the same (b, c, t) index.
        B, C3, T, D_enc, H_enc, W_enc = enc3_processed.shape
        enc3_pooled = enc3_processed.view(B * T, C3, D_enc, H_enc, W_enc)
        enc3_pooled = self.pool3(enc3_pooled)   # spatial dims -> D/8, H/8, W/8
        _, _, D_pool, H_pool, W_pool = enc3_pooled.shape
        enc3_pooled = enc3_pooled.view(B, C3, T, D_pool, H_pool, W_pool)
        
        # Bottleneck: spatial double convolution per time step, then temporal
        # processing over the stacked result.
        bottle_features = []
        for i in range(T):
            feat = enc3_pooled[:, :, i]  # -> (B, C3, D/8, H/8, W/8)
            bottle_feat = self.bottleneck_spatial(feat)  # -> (B, base_filters*8, D/8, H/8, W/8)
            bottle_features.append(bottle_feat)
        # Stack to get shape: (B, base_filters*8, T, D/8, H/8, W/8)
        bottle_features = torch.stack(bottle_features, dim=2)
        bottle_processed = self.temporal_bottleneck(bottle_features)  # -> (B, base_filters*8, T, D/8, H/8, W/8)
        
        # ---------------------
        # Decoder Path
        # ---------------------
        # The decoder runs once per time step t, using the bottleneck state and
        # the skip features of that same time step. (The *_last / *_final names
        # are historical; they hold the features of step t, not only of the
        # last step.)
        decoder_features = []
        for t in range(T):
            e1_last, e2_last = encoder_features[t]
            bottle_final = bottle_processed[:, :, t]
            e3_last = enc3_processed[:, :, t]
            
            d3 = self.upconv3(bottle_final)   # -> (B, base_filters*4, D/4, H/4, W/4)
            d3 = torch.cat([d3, e3_last], dim=1)  # Concatenated channels: (B, base_filters*8, D/4, H/4, W/4)
            d3 = self.dec3(d3)  # -> (B, base_filters*4, D/4, H/4, W/4)
            
            d2 = self.upconv2(d3)             # -> (B, base_filters*2, D/2, H/2, W/2)
            d2 = torch.cat([d2, e2_last], dim=1)   # -> (B, base_filters*4, D/2, H/2, W/2)
            d2 = self.dec2(d2)  # -> (B, base_filters*2, D/2, H/2, W/2)
            
            d1 = self.upconv1(d2)             # -> (B, base_filters, D, H, W)
            d1 = torch.cat([d1, e1_last], dim=1)   # -> (B, base_filters*2, D, H, W)
            d1 = self.dec1(d1)  # -> (B, base_filters, D, H, W)
            
            d1 = self.final_conv(d1)         # -> (B, 1, D, H, W)
            decoder_features.append(d1)
        # Stack the per-step outputs along dim 1 -> (B, T, 1, D, H, W)
        decoder_features = torch.stack(decoder_features, dim=1)
        return decoder_features

#### Pl Model

In [5]:
class HuberSSIMLoss2D(nn.Module):
    """Weighted sum of Huber, SSIM and temporal-smoothness losses for 2D frame sequences.

    ``total = alpha * huber + (1 - alpha) * ssim + temporal_weight * temporal``

    ``forward`` unpacks its prediction as ``(B, T, C, H, W)``.

    Args:
        alpha: Weight of the Huber term; the SSIM term gets ``1 - alpha``.
            Larger values emphasise per-pixel accuracy, smaller values
            emphasise structural similarity. Author's tuning notes:
            0.6-0.9 for noisy data, 0.4-0.7 for sharp structures, 0.3-0.5
            when predictions come out blurry. Default 0.7.
        delta: Threshold passed to ``F.huber_loss``. Errors below ``delta``
            are penalised quadratically (MSE-like), errors above it linearly
            (MAE-like), so a smaller ``delta`` is more robust to outliers.
            It is read from ``self.delta`` on every call, so it may be
            reassigned between calls. Default 0.05.
        window_size: Forwarded to ``kornia.losses.SSIMLoss``.
        temporal_weight: Weight of the temporal-smoothness term. Larger values
            penalise frame-to-frame changes more strongly. Author's tuning
            notes: > 0.2 to suppress flicker, 0.05-0.1 for smoothly changing
            sequences, < 0.05 if the model is too rigid. Default 0.1.
    """

    def __init__(self, alpha=0.7, delta=0.05, window_size=5, temporal_weight=0.1):
        super().__init__()
        self.alpha = alpha
        self.delta = delta
        self.temporal_weight = temporal_weight
        self.ssim_module = kornia.losses.SSIMLoss(window_size=window_size, reduction="mean")

    def temporal_smoothness_loss(self, y_pred):
        """Return the mean absolute difference between consecutive frames (dim 1).

        With fewer than two frames there is no consecutive pair; the mean over
        the resulting empty tensor would be NaN and poison the total loss, so
        a zero scalar (same dtype/device as ``y_pred``) is returned instead.
        """
        if y_pred.shape[1] < 2:
            return y_pred.new_zeros(())
        return torch.mean(torch.abs(y_pred[:, 1:] - y_pred[:, :-1]))

    def forward(self, y_pred, y_true):
        """Compute the combined loss.

        Returns:
            Tuple ``(total_loss, huber_loss, ssim_loss, temporal_loss)``.
        """
        B, T, C, H, W = y_pred.shape

        # Fold the time axis into the channel axis: (B, T, C, H, W) -> (B, C*T, H, W),
        # so the 2D SSIM is evaluated on every frame as its own channel.
        y_pred_restructured = y_pred.permute(0, 2, 1, 3, 4).reshape(B, C * T, H, W)
        y_true_restructured = y_true.permute(0, 2, 1, 3, 4).reshape(B, C * T, H, W)
        ssim_loss = self.ssim_module(y_pred_restructured, y_true_restructured)

        # Huber loss with the current delta.
        huber_loss = F.huber_loss(y_pred, y_true, delta=self.delta, reduction="mean")

        # Frame-to-frame smoothness of the prediction only.
        temporal_loss = self.temporal_smoothness_loss(y_pred)

        # Weighted combination.
        total_loss = self.alpha * huber_loss + (1 - self.alpha) * ssim_loss + self.temporal_weight * temporal_loss
        return total_loss, huber_loss, ssim_loss, temporal_loss

class HuberSSIMLoss3D(nn.Module):
    """Weighted sum of Huber, 3D SSIM and temporal-smoothness losses for volume sequences.

    ``total = alpha * huber + (1 - alpha) * ssim + temporal_weight * temporal``

    ``forward`` unpacks its prediction as ``(B, T, C, D, H, W)``.

    Args:
        alpha: Weight of the Huber term; the SSIM term gets ``1 - alpha``.
            Larger values emphasise per-voxel accuracy, smaller values
            emphasise structural similarity. Author's tuning notes:
            0.6-0.9 for noisy data, 0.4-0.7 for sharp structures, 0.3-0.5
            when predictions come out blurry. Default 0.7.
        delta: Threshold passed to ``F.huber_loss``. Errors below ``delta``
            are penalised quadratically (MSE-like), errors above it linearly
            (MAE-like), so a smaller ``delta`` is more robust to outliers.
            It is read from ``self.delta`` on every call, so it may be
            reassigned between calls. Default 0.05.
        window_size: Forwarded to ``kornia.losses.SSIM3DLoss``.
        temporal_weight: Weight of the temporal-smoothness term. Larger values
            penalise frame-to-frame changes more strongly. Author's tuning
            notes: > 0.2 to suppress flicker, 0.05-0.1 for smoothly changing
            sequences, < 0.05 if the model is too rigid. Default 0.1.
    """

    def __init__(self, alpha=0.7, delta=0.05, window_size=5, temporal_weight=0.1):
        super().__init__()
        self.alpha = alpha
        self.delta = delta
        self.temporal_weight = temporal_weight
        self.ssim_module = kornia.losses.SSIM3DLoss(window_size=window_size, reduction="mean")

    def temporal_smoothness_loss(self, y_pred):
        """Return the mean absolute difference between consecutive frames (dim 1).

        With fewer than two frames there is no consecutive pair; the mean over
        the resulting empty tensor would be NaN and poison the total loss, so
        a zero scalar (same dtype/device as ``y_pred``) is returned instead.
        """
        if y_pred.shape[1] < 2:
            return y_pred.new_zeros(())
        return torch.mean(torch.abs(y_pred[:, 1:] - y_pred[:, :-1]))

    def forward(self, y_pred, y_true):
        """Compute the combined loss.

        Returns:
            Tuple ``(total_loss, huber_loss, ssim_loss, temporal_loss)``.
        """
        B, T, C, D, H, W = y_pred.shape

        # Fold the time axis into the channel axis:
        # (B, T, C, D, H, W) -> (B, C*T, D, H, W), so the 3D SSIM is evaluated
        # on every frame as its own channel.
        y_pred_restructured = y_pred.permute(0, 2, 1, 3, 4, 5).reshape(B, C * T, D, H, W)
        y_true_restructured = y_true.permute(0, 2, 1, 3, 4, 5).reshape(B, C * T, D, H, W)
        ssim_loss = self.ssim_module(y_pred_restructured, y_true_restructured)

        # Huber loss with the current delta.
        huber_loss = F.huber_loss(y_pred, y_true, delta=self.delta, reduction="mean")

        # Frame-to-frame smoothness of the prediction only.
        temporal_loss = self.temporal_smoothness_loss(y_pred)

        # Weighted combination.
        total_loss = self.alpha * huber_loss + (1 - self.alpha) * ssim_loss + self.temporal_weight * temporal_loss
        return total_loss, huber_loss, ssim_loss, temporal_loss

In [6]:
class Pl_Model(pl.LightningModule):
    def __init__(
        self,
        passed_model: nn.Module,
        config: Dict[str, Any],
    ):
        """Wrap ``passed_model`` for training with the Huber/SSIM/temporal loss.

        Args:
            passed_model: Network whose raw output is squashed in ``forward``.
            config: Run configuration; other methods of this class read the
                keys ``"pred_frames"`` and ``"pred_n_frames_per_step"`` from it.
        """
        super(Pl_Model, self).__init__()
        self.passed_model = passed_model
        self.config = config
        # ------------------------------------------------------------------------------
        # DELTA FACTOR
        # Multiplier applied to the standard deviation of the absolute
        # validation errors when on_train_epoch_end re-estimates the Huber
        # delta. A larger factor yields a larger proposed delta.
        # Author's suggested range: 1.2 - 1.5; default 1.2.
        # ------------------------------------------------------------------------------
        self.delta_factor = 1.2
        # Initial Huber delta; previous_delta carries the blended value from
        # one epoch to the next.
        self.delta = 0.05
        self.previous_delta = self.delta
        
        # Save all init parameters as hyperparameters (except the model itself).
        # NOTE(review): save_hyperparameters inspects the calling frame --
        # confirm that `config` is really captured with the explicit
        # super(Pl_Model, self) form used above.
        self.save_hyperparameters(ignore=["passed_model"])

        # Setup training components
        self.mse_criterion = nn.MSELoss()
        self.psnr_criterion = PeakSignalNoiseRatio()
        # Built with its default delta (0.05), which matches self.delta above.
        self.huberssim3d_criterion = HuberSSIMLoss3D()
        self.huber_criterion = nn.HuberLoss()
        

    def forward(self, x):
        """Run the wrapped model and squash its output into the range [0, 1].

        The raw output is mapped through ``0.5 * (tanh(z) + 1)``, a rescaled
        tanh, rather than through a plain sigmoid.
        """
        raw = self.passed_model(x)
        return (torch.tanh(raw) + 1) * 0.5

    def configure_optimizers(self):
        """Create the Adam optimizer for all module parameters.

        The learning rate is read from ``self.config["learning_rate"]`` -- the
        configuration handed to this instance -- instead of from a global
        ``config`` variable that only exists in the notebook namespace.

        Returns:
            A one-element list holding the optimizer.
        """
        optimizer = optim.Adam(
            self.parameters(),
            lr=self.config["learning_rate"],
        )
        return [optimizer]

    def _calculate_loss(self, batch, mode="train"):
        """Roll the model out autoregressively over one batch and log the summed losses.

        ``config["pred_frames"]`` frames are predicted in chunks of
        ``config["pred_n_frames_per_step"]`` (the last chunk may be shorter).
        After every chunk the kept predictions are appended to the input
        window and the oldest frames are dropped. Each metric is summed over
        the chunks and logged as ``{mode}_<name>_loss``.

        Returns:
            The summed Huber/SSIM/temporal loss, used as the optimisation target.
        """
        inputs, targets = batch
        horizon = self.config["pred_frames"]
        stride = self.config["pred_n_frames_per_step"]

        # Running sums over the rollout chunks.
        sum_mse = sum_huber = sum_rmse = sum_ssim = 0.0
        sum_huberssim = sum_temporal = sum_psnr = sum_total = 0.0

        for t in range(0, horizon, stride):
            # The last chunk is truncated to the remaining horizon.
            frames_this_step = min(stride, horizon - t)

            # Keep only the leading frames of the model output for this chunk.
            outputs = self.forward(inputs)[:, :frames_this_step]
            step_targets = targets[:, t:t + frames_this_step]

            step_mse = self.mse_criterion(outputs, step_targets)
            step_psnr = self.psnr_criterion(outputs, step_targets)
            step_huberssim, step_huber, step_ssim, step_temporal = self.huberssim3d_criterion(outputs, step_targets)

            sum_mse = sum_mse + step_mse
            sum_huber = sum_huber + step_huber
            sum_rmse = sum_rmse + torch.sqrt(step_mse)
            sum_ssim = sum_ssim + step_ssim
            sum_huberssim = sum_huberssim + step_huberssim
            sum_temporal = sum_temporal + step_temporal
            sum_psnr = sum_psnr + step_psnr
            # The combined Huber/SSIM/temporal loss is the training objective.
            sum_total = sum_total + step_huberssim

            # Slide the input window: drop the oldest frames, append the predictions.
            inputs = torch.cat([inputs[:, stride:], outputs], dim=1)

        # Logging
        for key, value in (
            (f"{mode}_mse_loss", sum_mse),
            (f"{mode}_huber_loss", sum_huber),
            (f"{mode}_rmse_loss", sum_rmse),
            (f"{mode}_ssim_loss", sum_ssim),
            (f"{mode}_huberssim_loss", sum_huberssim),
            (f"{mode}_temporal_loss", sum_temporal),
            (f"{mode}_psnr_loss", sum_psnr),
        ):
            self.log(key, value)
        self.log(f"{mode}_total_loss", sum_total, prog_bar=True)

        return sum_total

    def training_step(self, batch, batch_idx):
        """Return the rollout loss of ``batch``, logged under the ``train`` prefix."""
        return self._calculate_loss(batch, mode="train")

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Log per-chunk validation losses and losses over the full rollout.

        First ``_calculate_loss`` logs the chunk-summed metrics under the
        ``val`` prefix. Then the autoregressive rollout is repeated, all
        predicted chunks are concatenated along time, and the metrics of the
        complete prediction against the complete target are logged under the
        ``overall_val`` prefix.
        """
        self._calculate_loss(batch, mode="val")

        window, targets = batch
        horizon = self.config["pred_frames"]
        stride = self.config["pred_n_frames_per_step"]

        rollout = []
        for t in range(0, horizon, stride):
            # The last chunk is truncated to the remaining horizon.
            keep = min(stride, horizon - t)
            # Keep only the leading frames of the model output for this chunk.
            step_pred = self.forward(window)[:, :keep]
            rollout.append(step_pred)
            # Slide the input window: drop the oldest frames, append the predictions.
            window = torch.cat([window[:, stride:], step_pred], dim=1)

        # Join all chunks along the time axis.
        outputs = torch.cat(rollout, dim=1)

        # Metrics over the whole predicted sequence.
        mse_loss = self.mse_criterion(outputs, targets)
        rmse_loss = torch.sqrt(self.mse_criterion(outputs, targets))
        huberssim_loss, huber_loss, ssim_loss, temporal_loss = self.huberssim3d_criterion(outputs, targets)
        psnr_loss = self.psnr_criterion(outputs, targets)

        # Logging
        for key, value in (
            ("overall_val_mse_loss", mse_loss),
            ("overall_val_huber_loss", huber_loss),
            ("overall_val_rmse_loss", rmse_loss),
            ("overall_val_ssim_loss", ssim_loss),
            ("overall_val_huberssim_loss", huberssim_loss),
            ("overall_val_temporal_loss", temporal_loss),
            ("overall_val_psnr_loss", psnr_loss),
            ("overall_val_total_loss", huberssim_loss),
        ):
            self.log(key, value)

    @torch.no_grad()
    def test_step(self, batch, batch_idx):
        """Log the rollout losses of ``batch`` under the ``test`` prefix."""
        self._calculate_loss(batch, mode="test")

    def on_train_epoch_end(self):
        """Re-estimate the Huber ``delta`` from validation errors after every training epoch.

        The absolute errors of a single forward pass over the validation
        loader are collected; ``delta_factor * std(errors)`` is blended with
        the previous delta (80 % old, 20 % new), clamped to [0.02, 0.5],
        stored on the loss criterion and logged as ``"delta"``.
        """
        val_loader = self.trainer.datamodule.val_dataloader()
        all_errors = []

        # This hook runs as part of the training loop, where the module is
        # normally in train mode. Switch to eval for the measurement so that
        # BatchNorm running statistics are not updated with validation data
        # and dropout does not perturb the errors; restore the mode afterwards.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for x, y in val_loader:
                    y_pred = self(x.to(self.device))
                    # NOTE(review): y_pred holds one frame per input frame while
                    # y holds the target frames; the subtraction relies on
                    # broadcasting along dim 1 -- confirm this is intended
                    # (the loss code compares only the leading output frames).
                    error = torch.abs(y.to(self.device) - y_pred)
                    all_errors.append(error.reshape(-1))
        finally:
            self.train(was_training)

        # Nothing to estimate from (empty validation loader): keep the current delta.
        if not all_errors:
            return

        all_errors = torch.cat(all_errors)
        new_delta = self.delta_factor * torch.std(all_errors).item()

        # Blend previous and new delta for smoother updates, then clamp to
        # [0.02, 0.5] so that very noisy data cannot push the Huber loss into
        # its purely quadratic (MSE-like) regime.
        new_delta = min(0.5, max(0.02, 0.8 * self.previous_delta + 0.2 * new_delta))
        self.previous_delta = new_delta

        # Update the criterion used by the loss computation.
        self.huberssim3d_criterion.delta = new_delta

        # Logging
        self.log("delta", new_delta)
        
    @torch.no_grad()
    def check_losses(self, loader, mode, use_wandb=False):
        """Compute the losses "to beat" of a repeat-last-frame baseline.

        For every batch the last input frame is compared against each of the
        first ``self.config["pred_frames"]`` target frames with the module's
        criteria. Every metric is averaged over all (batch, frame)
        evaluations, so the result does not scale with the number of
        predicted frames.

        Parameters:
        - loader: iterable of ``(inputs, targets)`` batches
        - mode: split name, only used in the wandb keys
          (``Checked_{mode}_<metric>_loss``)
        - use_wandb: when True, log all averages to wandb in a single call

        Returns:
        - tuple ``(mse, huber, ssim, huberssim, temporal, rmse, psnr, total)``
          of Python floats

        Raises:
        - ValueError: if no (batch, frame) pair was evaluated
        """
        names = ("mse", "huber", "rmse", "ssim", "huberssim", "temporal", "psnr", "total")
        sums = dict.fromkeys(names, 0.0)
        n_evals = 0

        for inputs, targets in loader:
            # Keep the data on the same device as the module and its criteria.
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            # The baseline prediction is identical for every target frame.
            baseline = inputs[:, -1].unsqueeze(1)
            for t in range(self.config["pred_frames"]):
                target_t = targets[:, t].unsqueeze(1)

                mse = self.mse_criterion(baseline, target_t)
                huberssim, huber, ssim, temporal = self.huberssim3d_criterion(baseline, target_t)
                psnr = self.psnr_criterion(baseline, target_t)

                sums["mse"] += mse.item()
                sums["huber"] += huber.item()
                sums["rmse"] += torch.sqrt(mse).item()
                sums["ssim"] += ssim.item()
                sums["huberssim"] += huberssim.item()
                sums["temporal"] += temporal.item()
                sums["psnr"] += psnr.item()
                # The total loss is the combined Huber/SSIM loss.
                sums["total"] += huberssim.item()
                n_evals += 1

        if n_evals == 0:
            raise ValueError("check_losses needs a non-empty loader and pred_frames > 0")

        averages = {name: value / n_evals for name, value in sums.items()}

        if use_wandb:
            # One call keeps all baseline metrics in the same wandb step.
            wandb.log({f"Checked_{mode}_{name}_loss": value for name, value in averages.items()})

        return (
            averages["mse"],
            averages["huber"],
            averages["ssim"],
            averages["huberssim"],
            averages["temporal"],
            averages["rmse"],
            averages["psnr"],
            averages["total"],
        )
        
    def log_predictions(self):
        """Log one validation example (prediction, target, inputs) to wandb.

        One batch is taken from the validation loader of the attached
        datamodule and passed through the module in eval mode without
        gradients. For the first sample, the central depth slice of the first
        predicted frame, of the first target frame and of every input frame is
        logged as a wandb image.

        NOTE(review): assumes batches are laid out as
        (batch, time, channel, depth, height, width), as the 6-D slicing in
        the rollout code suggests -- confirm before relying on the images.
        """
        # TODO: call this from a per-epoch hook, guarded by something like
        #   if self.current_epoch % self.config['viz_interval'] == 0:
        #       self.log_predictions()
        # ('viz_interval' has to be enabled in the config first). That guard
        # used to sit at the top of this method, where it referenced an
        # undefined `epoch` and recursed into this method.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # Get a batch of validation data
                data, target = next(iter(self.trainer.datamodule.val_dataloader()))
                data = data.to(self.device)
                target = target.to(self.device)

                # Generate predictions
                output = self(data)
        finally:
            self.train(was_training)

        def _as_image(frame):
            # frame is presumably (channel, depth, H, W); log the middle slice
            # of the first channel as a 2-D image.
            return wandb.Image(frame[0, frame.shape[1] // 2].cpu())

        # Log images
        wandb.log({
            "predictions": _as_image(output[0, 0]),
            "targets": _as_image(target[0, 0]),
            "input_sequence": [_as_image(data[0, i]) for i in range(data.shape[1])]
        })

#### Config

In [7]:
# Experiment configuration shared by the datamodule, the model and the trainer.
config = {
    # for the dataloaders
    "root": "../NormalizedQualityFiltered",
    "batch_size": 1,
    "learning_rate": 1e-4,
    "num_workers": 10,  # use 0 when the GPU is not used
    "pin_memory": torch.cuda.is_available(),  # False when the GPU is not used
    "drop_last": False,
    "epochs": 40,
    # "log_interval": 20,
    # "viz_interval": 1,
    "run_name": "3D-UNet+_temp",
    "input_frames": 9,
    "pred_frames": 9,
    "pred_n_frames_per_step": 9,
    "base_filters": 16,
    "train_split": 0.7,
    "val_split": 0.15,
    "test_split": 0.15,
}

# Encode the rollout scheme in the run name: all frames in one step (NAR),
# one frame per step (FAR) or several frames per step (PAR_<n>).
config["run_name"] += f"_{config['pred_frames']}"
if config["pred_frames"] == config["pred_n_frames_per_step"]:
    config["run_name"] += "_NAR"
elif config["pred_n_frames_per_step"] == 1:
    config["run_name"] += "_FAR"
else:
    config["run_name"] += f"_PAR_{config['pred_n_frames_per_step']}"


# Get data loaders
# (former call, kept for reference)
# train_loader, val_loader, test_loader = get_data_loaders(
#     batch_size=config['batch_size'],
#     num_workers=config["num_workers"],
#     pin_memory=config["pin_memory"],
#     drop_last=config["drop_last"],
#     sequence_length=config["input_frames"],
#     prediction_length=config["pred_frames"],
# )
dm = VolumeDataModule(
    root=config["root"],
    batch_size=config["batch_size"],
    num_workers=config["num_workers"],
    pin_memory=config["pin_memory"],
    drop_last=config["drop_last"],
    sequence_length=config["input_frames"],
    prediction_length=config["pred_frames"],
    train_split=config["train_split"],
    val_split=config["val_split"],
    test_split=config["test_split"],
)
wandb_logger = WandbLogger(entity="ChadCTP", project="perfusion-ct-prediction", name=config["run_name"])

# Initialize model for tuning
model = UNet3DPlusTemporal(
    input_frames=config["input_frames"],
    base_filters=config["base_filters"],
)

# Initialize pl_model for tuning
pl_model = Pl_Model(
    passed_model=model,
    config=config,
)

checkpoint_callback = ModelCheckpoint(
    monitor="val_total_loss",
    mode="min",
    save_top_k=1,
    filename="best-checkpoint",
    verbose=True,
)

# Pick accelerator and devices together so the cell also runs without CUDA:
# the second GPU when several are present (as before), the only GPU when
# there is exactly one, otherwise the CPU.
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    accelerator, devices = "gpu", [1]
elif gpu_count == 1:
    accelerator, devices = "gpu", [0]
else:
    accelerator, devices = "cpu", "auto"

# Initialize trainer for tuning
trainer = pl.Trainer(
    logger=wandb_logger,
    accelerator=accelerator,
    devices=devices,
    max_epochs=config["epochs"],
    callbacks=[RichProgressBar(), checkpoint_callback],
    check_val_every_n_epoch=5,
)

#wandb_logger.watch(pl_model)

#tuning
#tuner = Tuner(trainer)
#tuner.scale_batch_size(pl_model, datamodule=dm, mode="binsearch")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
# Train the Lightning module with the data provided by the datamodule.
trainer.fit(pl_model, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]


Output()

Epoch 4, global step 415: 'val_total_loss' reached 0.24342 (best 0.24342), saving model to './perfusion-ct-prediction/uouhw2tu/checkpoints/best-checkpoint.ckpt' as top 1


In [None]:
# Compute and log the baseline losses "to beat" for each split.
dm.setup()
pl_model.check_losses(loader=dm.train_dataloader(), mode="train", use_wandb=True)
pl_model.check_losses(loader=dm.val_dataloader(), mode="val", use_wandb=True)
pl_model.check_losses(loader=dm.test_dataloader(), mode="test", use_wandb=True)

In [None]:
# Run the validation loop once and keep the reported metrics.
val_results = trainer.validate(model=pl_model, datamodule=dm)

In [None]:
# Run the test loop once and keep the reported metrics.
test_results = trainer.test(model=pl_model, datamodule=dm)

In [None]:
# Save a checkpoint of the current trainer state, named after the run.
checkpoint_name = f"{config['run_name']}.ckpt"
save_load_path = f"../ModelWeights/{checkpoint_name}"
trainer.save_checkpoint(filepath=save_load_path)

In [8]:


# Testing: reload the best checkpoint and run a single-shot prediction.
pl_model = Pl_Model.load_from_checkpoint(
    "./perfusion-ct-prediction/uouhw2tu/checkpoints/best-checkpoint.ckpt",
    passed_model=model,
)

dm.setup()
test_dataloader = dm.test_dataloader()

print(len(test_dataloader))
pl_model.to(device)
# Inference only: make sure dropout / batch-statistics updates are switched off.
pl_model.eval()

# method 1: take the first batch of the test loader
#inputs, targets = next(iter(test_dataloader))

# method 2
# only for single shot: the first `input_frames` frames of one volume are the
# context, every remaining frame is the target.
n_input_frames = config["input_frames"]
vol = torch.tensor(np.load("../NormalizedQualityFiltered/MOL-061.npy")).unsqueeze(1)
inputs = vol[:n_input_frames].unsqueeze(0)
targets = vol[n_input_frames:].unsqueeze(0)

# No autograd graph is needed for a prediction.
with torch.no_grad():
    outputs = pl_model.forward(inputs.to(device))
outputs = outputs.detach().cpu()

# Prepend the context frames and drop the batch and channel axes for saving.
outputs = torch.concat([inputs, outputs], dim=1).squeeze(0, 2).numpy()
targets = torch.concat([inputs, targets], dim=1).squeeze(0, 2).numpy()
print(outputs.shape)
print(targets.shape)

# Single quotes inside the f-strings keep the cell valid before Python 3.12.
np.save(f"outputs_{config['run_name']}.npy", outputs)
np.save(f"targets_{config['run_name']}.npy", targets)

19
(18, 16, 256, 256)
(18, 16, 256, 256)


In [None]:
@torch.no_grad()
def overall_loss(model, loader, device):
    """Roll the model out over a loader and average the losses per batch.

    For every batch, ``model.config["pred_frames"]`` frames are predicted in
    chunks of at most ``model.config["pred_n_frames_per_step"]`` frames; the
    predicted chunk is appended to the input window for the next step. The
    concatenated prediction is then scored against the targets.

    Parameters:
    - model: module exposing ``config``, ``forward`` and the criteria
      ``mse_criterion``, ``huber_criterion`` and ``psnr_criterion``
    - loader: sized iterable of ``(inputs, targets)`` batches
    - device: device the model and the batches are moved to

    Returns:
    - tuple ``(outputs, mse, huber, rmse, psnr, total)`` where ``outputs`` is
      the prediction for the last batch and the remaining entries are floats
      averaged over the batches; ``total`` is ``mse + 0.5 * huber`` per batch

    Raises:
    - ValueError: if the loader is empty
    """
    if len(loader) == 0:
        raise ValueError("overall_loss requires a non-empty loader")

    pred_frames = model.config["pred_frames"]
    frames_per_step = model.config["pred_n_frames_per_step"]

    model = model.to(device)
    # Evaluate in eval mode and hand the model back in its previous mode.
    was_training = model.training
    model.eval()

    mse_loss = 0.0
    huber_loss = 0.0
    rmse_loss = 0.0
    #ssim_loss = 0.0
    psnr_loss = 0.0
    total_loss = 0.0
    try:
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            step_outputs = []
            for t in range(0, pred_frames, frames_per_step):
                # The last step may need fewer frames than a full chunk.
                frames_this_step = min(frames_per_step, pred_frames - t)
                outputs_t = model.forward(inputs)[:, :frames_this_step]
                step_outputs.append(outputs_t)

                # Slide the input window forward over the new predictions.
                inputs = torch.cat([inputs[:, frames_per_step:], outputs_t], dim=1)

            # concat time
            outputs = torch.concat(step_outputs, dim=1)

            # calculate losses
            # NOTE(review): sibling code uses `huberssim3d_criterion`; confirm
            # that the model really exposes `huber_criterion`.
            batch_mse = model.mse_criterion(outputs, targets)
            batch_huber = model.huber_criterion(outputs, targets).item()

            mse_loss += batch_mse.item()
            huber_loss += batch_huber
            rmse_loss += torch.sqrt(batch_mse).item()
            #ssim_loss += model.ssim_criterion(outputs, targets).item()
            psnr_loss += model.psnr_criterion(outputs, targets).item()
            # Use this batch's values, not the running sums.
            total_loss += batch_mse.item() + 0.5 * batch_huber
    finally:
        model.train(was_training)

    n_batches = len(loader)
    mse_loss = mse_loss / n_batches
    huber_loss = huber_loss / n_batches
    rmse_loss = rmse_loss / n_batches
    #ssim_loss = ssim_loss / n_batches
    psnr_loss = psnr_loss / n_batches
    total_loss = total_loss / n_batches

    return outputs, mse_loss, huber_loss, rmse_loss, psnr_loss, total_loss

# Evaluate pl_model on the test split and display the averaged metrics.
dm.setup()
test_metrics = overall_loss(model=pl_model, loader=dm.test_dataloader(), device=device)
_, mse_loss, huber_loss, rmse_loss, psnr_loss, total_loss = test_metrics
(mse_loss, huber_loss, rmse_loss, psnr_loss, total_loss)

In [None]:
def multi_vol_seq_interactive(volume_seqs, titles=None):
    """
    Interactive plot of multiple volume sequences using ipywidgets

    The sequences are laid out on a near-square grid. A time slider and a
    slice slider select the 2D image shown for every sequence; indices beyond
    the end of a shorter sequence are clamped to its last frame / slice.

    Parameters:
    - volume_seqs: List of 4D volume sequences to display
    - titles: Optional list of titles for each sequence

    Raises:
    - ValueError: if volume_seqs is empty
    """
    num_volumes = len(volume_seqs)
    if num_volumes == 0:
        raise ValueError("volume_seqs must contain at least one volume sequence")

    if titles is None:
        titles = [f"Volume {i+1}" for i in range(num_volumes)]

    # Pair every sequence with its title once, outside the redraw callback.
    panels = list(zip(volume_seqs, titles))

    nrows = int(num_volumes ** 0.5)
    ncols = (num_volumes + nrows - 1) // nrows

    def plot_volumes(time_idx, slice_idx):
        # squeeze=False always yields a 2D array of axes, whatever the grid.
        fig, axes = plt.subplots(nrows, ncols,
                                 figsize=(5*ncols, 5*nrows),
                                 squeeze=False)
        flat_axes = axes.ravel()

        for ax, (volume_seq, title) in zip(flat_axes, panels):
            t = min(time_idx, len(volume_seq) - 1)
            s = min(slice_idx, len(volume_seq[t]) - 1)

            im = ax.imshow(volume_seq[t][s], cmap='magma')
            ax.set_title(title)
            plt.colorbar(im, ax=ax)

        # Grid cells without a sequence would otherwise show empty frames.
        for ax in flat_axes[len(panels):]:
            ax.axis("off")

        plt.tight_layout()
        plt.show(block=True)

    max_time = max(len(vol) for vol in volume_seqs) - 1
    max_slice = max(len(vol[0]) for vol in volume_seqs) - 1

    interact(
        plot_volumes,
        time_idx=IntSlider(min=0, max=max_time, step=1, value=0, description='Time:'),
        slice_idx=IntSlider(min=0, max=max_slice, step=1, value=0, description='Slice:')
    )

# Compare the predicted sequence with the ground truth side by side.
multi_vol_seq_interactive(volume_seqs=[outputs, targets])