In [1]:
import numpy as np

def load_data(path: str):
    """Read the NumPy file at ``path`` with ``np.load`` and return the result."""
    loaded = np.load(path)
    return loaded

# Load a single volume file to inspect before building the datasets.
v1 = load_data('../selected_volumes/MOL-001.npy')

In [2]:
# Check the dimensions of the loaded array.
v1.shape

(18, 16, 256, 256)

In [3]:
import matplotlib.pyplot as plt
from ipywidgets import IntSlider, interact
def multi_vol_seq_interactive(volume_seqs, titles=None):
    """
    Interactive plot of multiple volume sequences using ipywidgets.

    Each sequence is indexed as ``volume_seq[time][slice]`` and the selected
    2D image is drawn in its own subplot. Two sliders choose the time and
    slice index; an index beyond a shorter sequence is clamped to that
    sequence's last valid entry.

    Parameters:
    - volume_seqs: Non-empty list of 4D volume sequences to display
    - titles: Optional list of titles for each sequence

    Raises:
    - ValueError: if ``volume_seqs`` is empty.
    """
    num_volumes = len(volume_seqs)
    if num_volumes == 0:
        # Without this guard nrows becomes 0 and the grid computation divides by zero.
        raise ValueError("volume_seqs must contain at least one volume sequence")

    if titles is None:
        titles = [f"Volume {i+1}" for i in range(num_volumes)]

    # Near-square grid: floor(sqrt(n)) rows and enough columns to hold every volume.
    nrows = int(num_volumes ** 0.5)
    ncols = (num_volumes + nrows - 1) // nrows
    # zip() below stops at the shorter of the two lists.
    num_shown = min(num_volumes, len(titles))

    def plot_volumes(time_idx, slice_idx):
        # squeeze=False always returns a 2D array of axes, so no special
        # casing is needed for single-row or single-cell grids.
        fig, axes = plt.subplots(nrows, ncols,
                                 figsize=(5*ncols, 5*nrows),
                                 squeeze=False)
        flat_axes = axes.ravel()  # row-major: index i <-> (i // ncols, i % ncols)

        for ax, volume_seq, title in zip(flat_axes, volume_seqs, titles):
            t = min(time_idx, len(volume_seq) - 1)
            s = min(slice_idx, len(volume_seq[t]) - 1)

            im = ax.imshow(volume_seq[t][s], cmap='magma')
            ax.set_title(title)
            plt.colorbar(im, ax=ax)

        # Hide grid cells that have no volume assigned.
        for ax in flat_axes[num_shown:]:
            ax.axis('off')

        plt.tight_layout()
        plt.show(block=True)

    # Slider ranges cover the longest sequence along each axis.
    max_time = max(len(vol) for vol in volume_seqs) - 1
    max_slice = max(len(vol[0]) for vol in volume_seqs) - 1

    interact(
        plot_volumes,
        time_idx=IntSlider(min=0, max=max_time, step=1, value=0, description='Time:'),
        slice_idx=IntSlider(min=0, max=max_slice, step=1, value=0, description='Slice:')
    )



In [12]:
# Browse the loaded volume with the time and slice sliders.
multi_vol_seq_interactive([v1])

1


interactive(children=(IntSlider(value=0, description='Time:', max=17), IntSlider(value=0, description='Slice:'…

In [13]:
# Check the dimensions of the loaded array.
v1.shape

(18, 16, 256, 256)

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import math
import torchvision
from torchvision import datasets, transforms
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
import random
import pytorch_lightning as pl
from torch.nn import Parameter

In [5]:
# Pick the best available backend: MPS, then CUDA, then CPU.
# (`torch.device('mps') or ...` always evaluated to the MPS device because a
# device object is truthy, so the CUDA/CPU fallback was never reached.)
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='mps')

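## Dataset 1:
- Load each `.npy` volume sequence and cut it into samples: `context_window` consecutive frames of one slice as the input, and the frame that follows as the target.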

In [6]:
class Dataset2D(Dataset):
    """Next-frame prediction samples cut from volume sequences stored as ``.npy``.

    Every file in ``data_paths`` is loaded with ``np.load``. Axis 0 of the
    array is treated as the sequence (time) axis and axis 1 as the slice
    index. For each slice index and each valid start position one sample is
    stored:

    - input:  ``context_window`` consecutive frames of that slice,
      shape ``(context_window, H, W)``
    - target: the single frame that follows, shape ``(1, H, W)``

    All samples are built at construction time; they are slices of the
    tensor created by ``torch.from_numpy`` for each file.

    Parameters
    ----------
    data_paths : list of str
        Paths to the ``.npy`` volume sequences.
    context_window : int
        Number of consecutive input frames per sample.
    transform : callable, optional
        Applied separately to the input window and to the target when a
        sample is fetched. Previously this argument was stored but never used.
    """

    def __init__(self, data_paths, context_window=4, transform=None):
        self.data_paths = data_paths
        self.context_window = context_window
        self.transform = transform
        self.samples = []
        # For every path to a volume sequence in .npy
        for data_path in self.data_paths:
            volume_seq = torch.from_numpy(np.load(data_path))
            num_starts = len(volume_seq) - self.context_window
            for h in range(volume_seq.shape[1]):
                # Slide the context window along the sequence axis for this slice.
                for t in range(num_starts):
                    stop = t + self.context_window
                    # Input: (context_window, H, W); target: (1, H, W)
                    self.samples.append((volume_seq[t:stop, h],
                                         volume_seq[stop:stop + 1, h]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        inputs, target = self.samples[idx]
        if self.transform is not None:
            inputs, target = self.transform(inputs), self.transform(target)
        return inputs, target

In [7]:
root = '../selected_volumes'
# os.listdir returns entries in arbitrary order and may contain files that are
# not volumes; keep only .npy files and sort them so the list is reproducible.
data_paths = sorted(
    os.path.join(root, name) for name in os.listdir(root) if name.endswith('.npy')
)
d = Dataset2D(data_paths)
len(d)

2240

In [50]:
root = '../selected_volumes'
# Pick the best available backend: MPS, then CUDA, then CPU.
# (`torch.device('mps') or ...` always evaluated to the MPS device because a
# device object is truthy, so the fallback was never reached.)
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Training parameters
batch_size = 4
sequence_length = 8
learning_rate = 1e-4
num_epochs = 100
# Data parameters
train_split = 0.8
val_split = 0.1


# os.listdir order is arbitrary; sorting makes the file-level train/val/test
# split below reproducible, and the suffix filter skips non-volume files.
data_paths = sorted(
    os.path.join(root, name) for name in os.listdir(root) if name.endswith('.npy')
)
dataset = Dataset2D(data_paths, context_window=4)
print(len(data_paths))
def get_data_loaders(batch_size=4, sequence_length=4):
    """Split the module-level ``data_paths`` by file and wrap each part in a DataLoader.

    The first ``train_split`` fraction of the paths goes to training, the next
    ``val_split`` fraction to validation, and whatever remains to testing
    (``train_split`` and ``val_split`` are read from module scope).
    ``sequence_length`` is passed to ``Dataset2D`` as its context window.

    Returns ``(train_loader, val_loader, test_loader)``; only the training
    loader shuffles.
    """
    total = len(data_paths)
    train_end = int(total * train_split)
    val_end = train_end + int(total * val_split)

    # Contiguous, non-overlapping partitions of the path list.
    partitions = (
        data_paths[:train_end],
        data_paths[train_end:val_end],
        data_paths[val_end:],
    )
    train_set, val_set, test_set = (
        Dataset2D(paths, sequence_length) for paths in partitions
    )

    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True),
        DataLoader(val_set, batch_size=batch_size),
        DataLoader(test_set, batch_size=batch_size),
    )

10


In [51]:

# Build the three loaders from the batch size and sequence length set in the config cell.
train_loader, val_loader, test_loader = get_data_loaders(batch_size=batch_size, sequence_length=sequence_length)

In [52]:
# Print the shapes of the first training batch only.
# The loop variable is named `inputs` so the builtin `input` is not shadowed
# at notebook scope.
for i, (inputs, target) in enumerate(train_loader):
    print(i, inputs.shape, target.shape)
    break


0 torch.Size([4, 8, 256, 256]) torch.Size([4, 1, 256, 256])


(2+1)D U-Net: per-frame 2D spatial convolutions combined with 1D temporal convolutions

In [53]:
class DoubleConv2D(nn.Module):
    """Two 3x3 Conv2d -> BatchNorm2d -> ReLU stages.

    ``mid_channels`` sets the width between the two stages and falls back to
    ``out_channels`` when it is not given. If ``in_channels == out_channels``
    the block input is added to the block output.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()

        hidden = mid_channels if mid_channels else out_channels

        def stage(c_in, c_out):
            # The convolution carries no bias because a batch norm follows it.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        self.double_conv = nn.Sequential(
            *stage(in_channels, hidden),
            *stage(hidden, out_channels),
        )

        # Identity skip is only possible when the channel counts agree.
        self.use_residual = in_channels == out_channels

    def forward(self, x):
        out = self.double_conv(x)
        return out + x if self.use_residual else out
    
class TemporalBlock(nn.Module):
    """1D convolution along the time axis, applied independently at every pixel.

    A depthwise Conv1d (one filter per channel, "same" padding for odd kernel
    sizes) is followed by a pointwise Conv1d; each is followed by
    BatchNorm1d and ReLU. Input and output are both ``[B, C, T, H, W]``.
    """

    def __init__(self, channels, temporal_kernel_size=3):
        super().__init__()

        half_kernel = temporal_kernel_size // 2

        layers = [
            # Depthwise: groups == channels, so channels are filtered separately.
            nn.Conv1d(channels, channels,
                      kernel_size=temporal_kernel_size,
                      padding=half_kernel,
                      groups=channels),
            nn.BatchNorm1d(channels),
            nn.ReLU(inplace=True),
            # Pointwise: mixes channels without touching the time axis.
            nn.Conv1d(channels, channels, kernel_size=1),
            nn.BatchNorm1d(channels),
            nn.ReLU(inplace=True),
        ]
        self.temporal_conv = nn.Sequential(*layers)

    def forward(self, x):
        n, ch, frames, rows, cols = x.shape

        # [B, C, T, H, W] -> [B, H, W, C, T] -> [B*H*W, C, T]: one sequence per pixel.
        seq = x.contiguous().permute(0, 3, 4, 1, 2).reshape(n * rows * cols, ch, frames)

        seq = self.temporal_conv(seq)

        # [B*H*W, C, T] -> [B, H, W, C, T] -> [B, C, T, H, W]
        return seq.reshape(n, rows, cols, ch, frames).permute(0, 3, 4, 1, 2)


class UNet2DPlusTemporal(nn.Module):
    """2D U-Net with 1D temporal convolutions at the two deepest levels.

    Every input frame passes through the same 2D encoder (enc1 -> enc2 -> enc3).
    The enc3 features of all frames are stacked along a time axis and mixed
    by a TemporalBlock; the same happens again in the bottleneck. Only the
    last time step is decoded, using the last input frame's skip connections.

    Parameters
    ----------
    input_frames : int or None
        Expected number of input frames. ``None`` disables the check in
        ``forward``.
    output_frames : int
        Number of channels produced by the final 1x1 convolution.
    in_channels : int
        Channels per input frame.
    base_filters : int
        Width of the first encoder level; deeper levels use 2x, 4x and 8x.
    """

    def __init__(self, input_frames=None, output_frames=8, in_channels=1, base_filters=32):
        super(UNet2DPlusTemporal, self).__init__()

        self.input_frames = input_frames

        # Spatial encoder, applied per frame
        self.enc1 = DoubleConv2D(in_channels, base_filters)
        self.pool1 = nn.MaxPool2d(2)

        self.enc2 = DoubleConv2D(base_filters, base_filters*2)
        self.pool2 = nn.MaxPool2d(2)

        # Last encoder with temporal processing
        self.enc3 = DoubleConv2D(base_filters*2, base_filters*4)
        self.temporal_enc = TemporalBlock(
            channels=base_filters*4,
            temporal_kernel_size=3
        )
        self.pool3 = nn.MaxPool2d(2)

        # Bottleneck with temporal processing
        self.bottleneck_spatial = DoubleConv2D(base_filters*4, base_filters*8)
        self.temporal_bottleneck = TemporalBlock(
            channels=base_filters*8,
            temporal_kernel_size=3
        )

        # Decoder path; each dec* block receives upsampled + skip channels
        self.upconv3 = nn.ConvTranspose2d(base_filters*8, base_filters*4,
                                         kernel_size=2, stride=2)
        self.dec3 = DoubleConv2D(base_filters*8, base_filters*4)

        self.upconv2 = nn.ConvTranspose2d(base_filters*4, base_filters*2,
                                         kernel_size=2, stride=2)
        self.dec2 = DoubleConv2D(base_filters*4, base_filters*2)

        self.upconv1 = nn.ConvTranspose2d(base_filters*2, base_filters,
                                         kernel_size=2, stride=2)
        self.dec1 = DoubleConv2D(base_filters*2, base_filters)

        self.final_conv = nn.Conv2d(base_filters, output_frames, kernel_size=1)

    def forward(self, x):
        """Predict ``output_frames`` maps from a frame sequence.

        ``x`` is ``[B, T, C, H, W]``. A 4D ``[B, T, H, W]`` tensor is also
        accepted and treated as one channel per frame. Returns
        ``[B, output_frames, H, W]``.

        Raises ``ValueError`` for inputs that are not 4D/5D or have no frames,
        and ``AssertionError`` when ``input_frames`` is set and T differs.
        """
        if x.dim() == 4:
            # [B, T, H, W] -> [B, T, 1, H, W]
            x = x.unsqueeze(2)
        if x.dim() != 5:
            raise ValueError(
                f"Expected a 4D [B, T, H, W] or 5D [B, T, C, H, W] input, got {x.dim()}D"
            )
        b, t, c, h, w = x.shape
        if t == 0:
            raise ValueError("Input sequence contains no frames")
        if self.input_frames is not None:
            assert t == self.input_frames, f"Expected {self.input_frames} frames, got {t}"

        # Encode every frame with the shared spatial encoder.
        enc3_features = []
        for i in range(t):
            curr_frame = x[:, i]  # [B, C, H, W]

            e1 = self.enc1(curr_frame)
            e2 = self.enc2(self.pool1(e1))
            e3 = self.enc3(self.pool2(e2))
            enc3_features.append(e3)

        # The decoder only uses the last frame's skips, which are what e1/e2
        # hold once the loop has finished.
        skip1, skip2 = e1, e2

        # Process enc3 features temporally
        enc3_features = torch.stack(enc3_features, dim=2)  # [B, C, T, H, W]
        enc3_processed = self.temporal_enc(enc3_features)

        # Pool spatially after temporal processing. MaxPool2d treats every
        # H x W plane independently, so folding T into the channel axis pools
        # each (channel, time) plane without mixing them.
        b, c, t, h, w = enc3_processed.shape
        pooled = self.pool3(enc3_processed.reshape(b, c * t, h, w))
        enc3_pooled = pooled.reshape(b, c, t, pooled.shape[-2], pooled.shape[-1])

        # Bottleneck processing, again frame by frame
        bottle_features = []
        for i in range(t):
            curr_feat = enc3_pooled[:, :, i]  # [B, C, H, W]
            bottle_features.append(self.bottleneck_spatial(curr_feat))

        # Stack and apply temporal processing in bottleneck
        bottle_features = torch.stack(bottle_features, dim=2)  # [B, C, T, H, W]
        bottle_processed = self.temporal_bottleneck(bottle_features)

        # Take last temporal state for decoder
        bottle_final = bottle_processed[:, :, -1]  # [B, C, H, W]
        skip3 = enc3_processed[:, :, -1]

        d3 = self.upconv3(bottle_final)
        d3 = self.dec3(torch.cat([d3, skip3], dim=1))

        d2 = self.upconv2(d3)
        d2 = self.dec2(torch.cat([d2, skip2], dim=1))

        d1 = self.upconv1(d2)
        d1 = self.dec1(torch.cat([d1, skip1], dim=1))

        return self.final_conv(d1)  # [B, output_frames, H, W]

Using device: mps

Testing with T=8
Input shape: torch.Size([2, 8, 256, 256])
Output shape: torch.Size([2, 1, 256, 256])
Test passed!

Testing with T=16
Input shape: torch.Size([2, 16, 256, 256])
Output shape: torch.Size([2, 1, 256, 256])
Test passed!


In [32]:
import wandb
# Authenticate with Weights & Biases before any run is started.
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/simonma/.netrc


True

In [55]:
from typing import Dict, Any


class WandbTrainer:
    """Training loop that logs metrics, checkpoints and sample images to wandb.

    A wandb run is started in ``__init__``. The training loss is
    ``MSE + 0.5 * Huber``; validation uses MSE only. The model is moved to
    MPS, CUDA or CPU, in that order of preference.

    ``config`` must provide ``learning_rate``, ``epochs``, ``log_interval``
    and ``viz_interval``; ``run_name`` is optional.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: Dict[str, Any],
        project_name: str = "perfusion-ct-prediction"
    ):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config

        # Initialize wandb
        wandb.init(
            project=project_name,
            config=config,
            name=config.get('run_name', None)
        )

        # Setup training components
        self.criterion = nn.MSELoss()
        self.huber_loss = nn.HuberLoss(delta=1.0)
        self.optimizer = torch.optim.Adam(
            model.parameters(),
            lr=config['learning_rate']
        )

        # Move model to device
        self.device = torch.device('mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.model.to(self.device)

    def train_epoch(self, epoch: int):
        """Run one pass over ``train_loader`` and return the mean batch loss.

        Raises ``ValueError`` if the loader yields no batches (instead of a
        ZeroDivisionError when averaging).
        """
        num_batches = len(self.train_loader)
        if num_batches == 0:
            raise ValueError("train_loader yields no batches")

        self.model.train()
        total_loss = 0

        for batch_idx, (data, target) in enumerate(self.train_loader):
            # Move to device
            data = data.to(self.device)
            target = target.to(self.device)

            # Forward pass
            self.optimizer.zero_grad()
            output = self.model(data)

            # Calculate losses
            mse_loss = self.criterion(output, target)
            huber_loss = self.huber_loss(output, target)
            loss = mse_loss + 0.5 * huber_loss

            # Backward pass
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()

            # Log batch metrics
            if batch_idx % self.config['log_interval'] == 0:
                wandb.log({
                    'batch': batch_idx,
                    'batch_loss': loss.item(),
                    'batch_mse': mse_loss.item(),
                    'batch_huber': huber_loss.item()
                })

        # Log epoch metrics
        avg_loss = total_loss / num_batches
        wandb.log({
            'epoch': epoch,
            'train_loss': avg_loss
        })

        return avg_loss

    @torch.no_grad()
    def validate(self, epoch: int):
        """Return the mean MSE over ``val_loader``.

        If the loader yields no batches, nothing is logged and ``nan`` is
        returned, so the best-model comparison in ``train`` never succeeds.
        """
        num_batches = len(self.val_loader)
        if num_batches == 0:
            return float('nan')

        self.model.eval()
        val_loss = 0

        for data, target in self.val_loader:
            data = data.to(self.device)
            target = target.to(self.device)

            output = self.model(data)
            val_loss += self.criterion(output, target).item()

        val_loss /= num_batches

        # Log validation metrics
        wandb.log({
            'epoch': epoch,
            'val_loss': val_loss
        })

        return val_loss

    def train(self):
        """Train for ``config['epochs']`` epochs and return the best validation loss.

        The state dict is saved to ``best_model.pt`` in the wandb run
        directory whenever the validation loss improves.
        """
        best_val_loss = float('inf')
        checkpoint_path = os.path.join(wandb.run.dir, "best_model.pt")

        for epoch in range(self.config['epochs']):
            # Training
            self.train_epoch(epoch)

            # Validation
            val_loss = self.validate(epoch)

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(self.model.state_dict(), checkpoint_path)
                wandb.save(checkpoint_path)

            # Log example predictions
            if epoch % self.config['viz_interval'] == 0:
                self.log_predictions()

        return best_val_loss

    def log_predictions(self):
        """Log example predictions to wandb (first sample of the first validation batch)."""
        # An empty validation loader has nothing to visualise.
        first_batch = next(iter(self.val_loader), None)
        if first_batch is None:
            return

        self.model.eval()
        with torch.no_grad():
            data, target = first_batch
            data = data.to(self.device)
            target = target.to(self.device)

            # Generate predictions
            output = self.model(data)

            # Log images
            wandb.log({
                "predictions": wandb.Image(output[0, 0].cpu()),
                "targets": wandb.Image(target[0, 0].cpu()),
                "input_sequence": [wandb.Image(data[0, i].cpu()) for i in range(data.shape[1])]
            })


# Usage example:
# Usage example:
def train_with_wandb():
    """Build the model and loaders from one config dict, train, and close the wandb run."""
    # Configuration
    config = {
        'batch_size': 4,
        'learning_rate': 1e-4,
        'epochs': 10,
        'log_interval': 20,
        'viz_interval': 1,
        'run_name': 'unet2Dplus_temporal',
        'model_type': 'UNet2DPlusTemporal',
        'input_frames': 8,
        'output_frames': 1,
        'base_filters': 32
    }

    # Initialize model. output_frames must be passed explicitly: the model's
    # own default is 8, while the loaders provide a single target frame.
    model = UNet2DPlusTemporal(
        input_frames=config['input_frames'],
        output_frames=config['output_frames'],
        base_filters=config['base_filters']
    )

    # Get data loaders
    train_loader, val_loader, test_loader = get_data_loaders(
        batch_size=config['batch_size'],
        sequence_length=config['input_frames']
    )

    try:
        # Initialize trainer (this starts the wandb run)
        trainer = WandbTrainer(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            config=config,
            project_name="perfusion-ct-prediction"
        )

        # Train
        trainer.train()
    finally:
        # Close the wandb run even when training fails or is interrupted.
        wandb.finish()


# NOTE(review): in a notebook kernel __name__ is normally "__main__", so
# executing this cell starts a full training run.
if __name__ == "__main__":
    train_with_wandb()

VBox(children=(Label(value='0.011 MB of 0.011 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
batch,▁▂▄▅▇█
batch_huber,█▂▁▁▁▁
batch_loss,█▂▁▁▁▁
batch_mse,█▂▁▁▁▁

0,1
batch,100.0
batch_huber,0.04715
batch_loss,0.12015
batch_mse,0.09657




KeyboardInterrupt: 

The training run above is working (noted 04.01.25).

In [20]:
class DoubleConv3D(nn.Module):
    """Two stacked 3x3x3 Conv3d -> BatchNorm3d -> ReLU stages.

    Padding of 1 keeps the depth, height and width unchanged; the
    convolutions have no bias term.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        stages = []
        # First stage maps in -> out channels, second stage out -> out.
        for width_in in (in_channels, out_channels):
            stages += [
                nn.Conv3d(width_in, out_channels,
                          kernel_size=3, padding=1, bias=False),
                nn.BatchNorm3d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        features = self.double_conv(x)
        return features

class UNet3DTemporal(nn.Module):
    """3D U-Net over ``[T, H, W]`` that collapses the time axis to one output frame.

    Three encoder levels halve T, H and W each; the decoder mirrors them with
    transposed convolutions and skip connections. A final convolution whose
    kernel spans all ``input_frames`` time steps reduces T to 1.

    Parameters
    ----------
    in_channels : int
        Currently unused: ``forward`` always adds a single channel dimension
        and the first encoder is built for one input channel.
    base_filters : int
        Width of the first encoder level; deeper levels use 2x, 4x and 8x.
    input_frames : int
        Number of input frames T. Required, and must be a positive multiple
        of 8 because the time axis is halved three times and doubled three
        times before the skip connections are concatenated.

    Raises
    ------
    ValueError
        If ``input_frames`` is missing or not a positive multiple of 8.
    """

    def __init__(self, in_channels=1, base_filters=32, input_frames=None):
        super().__init__()

        # Fail early with a clear message; otherwise None breaks the temporal
        # kernel size below and other values only fail later at torch.cat.
        if input_frames is None or input_frames <= 0 or input_frames % 8 != 0:
            raise ValueError(
                f"input_frames must be a positive multiple of 8, got {input_frames}"
            )

        self.input_frames = input_frames

        # Encoder Path
        self.enc1 = DoubleConv3D(1, base_filters)
        # [B, 1, T, 256, 256] -> [B, 32, T, 256, 256]

        self.pool1 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
        # [B, 32, T, 256, 256] -> [B, 32, T/2, 128, 128]

        self.enc2 = DoubleConv3D(base_filters, base_filters*2)
        # [B, 32, T/2, 128, 128] -> [B, 64, T/2, 128, 128]

        self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
        # [B, 64, T/2, 128, 128] -> [B, 64, T/4, 64, 64]

        self.enc3 = DoubleConv3D(base_filters*2, base_filters*4)
        # [B, 64, T/4, 64, 64] -> [B, 128, T/4, 64, 64]

        self.pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
        # [B, 128, T/4, 64, 64] -> [B, 128, T/8, 32, 32]

        # Bottleneck
        self.bottleneck = DoubleConv3D(base_filters*4, base_filters*8)
        # [B, 128, T/8, 32, 32] -> [B, 256, T/8, 32, 32]

        # Decoder Path
        self.upconv3 = nn.ConvTranspose3d(
            base_filters*8, base_filters*4,
            kernel_size=2, stride=2
        )
        # [B, 256, T/8, 32, 32] -> [B, 128, T/4, 64, 64]

        self.dec3 = DoubleConv3D(base_filters*8, base_filters*4)

        self.upconv2 = nn.ConvTranspose3d(
            base_filters*4, base_filters*2,
            kernel_size=2, stride=2
        )
        # [B, 128, T/4, 64, 64] -> [B, 64, T/2, 128, 128]

        self.dec2 = DoubleConv3D(base_filters*4, base_filters*2)

        self.upconv1 = nn.ConvTranspose3d(
            base_filters*2, base_filters,
            kernel_size=2, stride=2
        )
        # [B, 64, T/2, 128, 128] -> [B, 32, T, 256, 256]

        self.dec1 = DoubleConv3D(base_filters*2, base_filters)

        # Final temporal reduction: an unpadded kernel spanning all T frames
        # leaves a time dimension of size 1.
        self.final_temporal_conv = nn.Conv3d(
            base_filters, base_filters,
            kernel_size=(input_frames, 1, 1),
            stride=(1, 1, 1)
        )

        self.final_conv = nn.Conv3d(base_filters, 1, kernel_size=1)

        self.instance_norm = nn.InstanceNorm3d(1)

    def forward(self, x):
        """Map ``[B, T, H, W]`` to ``[B, 1, H, W]``; T must equal ``input_frames``.

        The shape comments below use H = W = 256 and base_filters = 32 as an example.
        """
        # Input: [B, T, H, W]
        b, t, h, w = x.shape
        assert t == self.input_frames, f"Expected {self.input_frames} frames, got {t}"

        # Add channel dimension and normalize
        x = x.unsqueeze(1)  # [B, 1, T, 256, 256]
        x = self.instance_norm(x)

        # Encoder Path with skip connections
        e1 = self.enc1(x)         # [B, 32, T, 256, 256]
        p1 = self.pool1(e1)       # [B, 32, T/2, 128, 128]

        e2 = self.enc2(p1)        # [B, 64, T/2, 128, 128]
        p2 = self.pool2(e2)       # [B, 64, T/4, 64, 64]

        e3 = self.enc3(p2)        # [B, 128, T/4, 64, 64]
        p3 = self.pool3(e3)       # [B, 128, T/8, 32, 32]

        # Bottleneck
        bottle = self.bottleneck(p3)  # [B, 256, T/8, 32, 32]

        # Decoder Path
        d3 = self.upconv3(bottle)     # [B, 128, T/4, 64, 64]
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.dec3(d3)

        d2 = self.upconv2(d3)         # [B, 64, T/2, 128, 128]
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)

        d1 = self.upconv1(d2)         # [B, 32, T, 256, 256]
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)

        # Final convolutions
        out = self.final_temporal_conv(d1)  # [B, 32, 1, 256, 256]
        out = self.final_conv(out)          # [B, 1, 1, 256, 256]

        return out.squeeze(2)  # [B, 1, 256, 256]


def test_model():
    """Smoke-test UNet3DTemporal on random input for several frame counts.

    For each T a fresh model is built on the module-level ``device`` and fed
    a ``[2, T, 256, 256]`` tensor; the input/output shapes or the error
    message are printed.
    """
    # Test with different temporal dimensions
    temporal_sizes = [8, 16, 32]  # Must be multiples of 8 due to 3 pooling layers

    for T in temporal_sizes:
        print(f"\nTesting with T={T}")
        model = UNet3DTemporal(input_frames=T)
        model.to(device)
        # Create the input on the same device as the model; a CPU tensor fed
        # to a model on another device raises a device-mismatch error.
        x = torch.randn(2, T, 256, 256, device=device)  # batch_size=2

        try:
            # Only shapes are inspected, so no autograd graph is needed.
            with torch.no_grad():
                out = model(x)
            print(f"Input shape: {x.shape}")
            print(f"Output shape: {out.shape}")
            print("Test passed!")
        except Exception as e:
            print(f"Error with T={T}: {str(e)}")

# Run the shape smoke test for the 3D U-Net.
test_model()


Testing with T=8

Testing with T=16

Testing with T=32


In [37]:
class DoubleConv2D(nn.Module):
    """(Conv2d 3x3 -> BatchNorm2d -> ReLU) applied twice, with no residual path.

    NOTE: this rebinds the name ``DoubleConv2D`` that is defined earlier in
    this notebook with a different signature (``mid_channels``, residual
    add); code run after this cell picks up this simpler variant.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        modules = []
        width_in = in_channels
        for _ in range(2):
            modules.extend([
                nn.Conv2d(width_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ])
            # The second stage keeps the channel count.
            width_in = out_channels
        self.double_conv = nn.Sequential(*modules)

    def forward(self, x):
        features = self.double_conv(x)
        return features
    
class TemporalAttentionBlock(nn.Module):
    """Self-attention over the time axis, applied independently at every pixel.

    The input ``[B, T, C, H, W]`` is rearranged into ``B*H*W`` sequences of
    length T with C features each, layer-normalised over the features, and
    passed through multi-head self-attention. The output has the input's shape.

    Parameters
    ----------
    dim : int
        Number of channels C; must be divisible by ``num_heads``.
    num_heads : int
        Number of attention heads.
    sequence_length : int
        Unused; kept so existing callers keep working.
    """

    def __init__(self, dim, num_heads=8, sequence_length=9):
        super().__init__()

        # Normalise over the channel (last) dimension of the [N, T, C]
        # sequences built in forward. The previous
        # nn.LayerNorm([dim, None, None]) could not be constructed, because a
        # normalized_shape entry of None is not a valid size.
        self.norm = nn.LayerNorm(dim)
        self.attention = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            batch_first=True
        )

    def forward(self, x):
        # x shape: [batch, time, channels, height, width]
        b, t, c, h, w = x.shape

        # Reshape for attention
        x = x.permute(0, 3, 4, 1, 2)  # [B, H, W, T, C]
        x = x.reshape(b*h*w, t, c)     # [B*H*W, T, C]

        # Apply attention
        x = self.norm(x)
        attn_out, _ = self.attention(x, x, x)

        # Reshape back
        x = attn_out.reshape(b, h, w, t, c)
        x = x.permute(0, 3, 4, 1, 2)   # [B, T, C, H, W]

        return x

    
class PerfusionCTPredictor(nn.Module):
    """Two-level 2D U-Net with temporal attention across frames in the bottleneck.

    Each input frame is instance-normalised and encoded separately. The
    per-frame bottleneck features are stacked along a time axis and passed
    through a TemporalAttentionBlock. The last time step is then decoded
    with the last frame's skip connections into a single-channel prediction.
    """

    def __init__(self, input_frames=4, base_filters=64):
        super().__init__()

        self.input_frames = input_frames
        width1, width2, width4 = base_filters, base_filters*2, base_filters*4

        # Encoder: one input channel per frame
        self.enc1 = DoubleConv2D(1, width1)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = DoubleConv2D(width1, width2)
        self.pool2 = nn.MaxPool2d(2)

        # Bottleneck convolution followed by attention over the time axis
        self.bottleneck_conv = DoubleConv2D(width2, width4)
        self.temporal_attention = TemporalAttentionBlock(
            dim=width4,
            num_heads=4,
            sequence_length=input_frames
        )

        # Decoder: each dec* block takes upsampled + skip channels
        self.upconv2 = nn.ConvTranspose2d(width4, width2, kernel_size=2, stride=2)
        self.dec2 = DoubleConv2D(width4, width2)
        self.upconv1 = nn.ConvTranspose2d(width2, width1, kernel_size=2, stride=2)
        self.dec1 = DoubleConv2D(width2, width1)

        # 1x1 convolution producing the single predicted frame
        self.final_conv = nn.Conv2d(width1, 1, kernel_size=1)

        # Per-frame intensity normalisation with learnable affine parameters
        self.instance_norm = nn.InstanceNorm2d(1, affine=True)

    def forward(self, x):
        # x: [batch, frames, height, width]
        batch, frames, height, width = x.shape
        assert frames == self.input_frames, f"Expected {self.input_frames} input frames, got {frames}"

        # Fold frames into the batch axis so every frame is normalised on its own.
        normed = self.instance_norm(x.view(batch*frames, 1, height, width))
        normed = normed.view(batch, frames, height, width)

        skips = []
        per_frame_bottleneck = []
        for idx in range(frames):
            frame = normed[:, idx].unsqueeze(1)  # [B, 1, H, W]
            skip1 = self.enc1(frame)
            skip2 = self.enc2(self.pool1(skip1))
            skips.append((skip1, skip2))
            per_frame_bottleneck.append(self.bottleneck_conv(self.pool2(skip2)))

        # [B, T, C, H', W'] -> attention over T
        attended = self.temporal_attention(torch.stack(per_frame_bottleneck, dim=1))

        # Decode only the last time step, with the last frame's skips.
        last_skip1, last_skip2 = skips[-1]

        up2 = self.upconv2(attended[:, -1])
        up2 = self.dec2(torch.cat([up2, last_skip2], dim=1))

        up1 = self.upconv1(up2)
        up1 = self.dec1(torch.cat([up1, last_skip1], dim=1))

        return self.final_conv(up1)


# Training setup
# Training setup
def train_perfusion_predictor():
    """Build a PerfusionCTPredictor together with a single-step training function.

    Returns
    -------
    (model, train_step)
        ``model`` is the freshly constructed predictor (4 input frames).
        ``train_step(x, y)`` runs one optimisation step on a batch and
        returns the loss as a float. Previously nothing was returned, which
        left ``train_step`` unreachable.
    """
    model = PerfusionCTPredictor(input_frames=4)
    criterion = nn.MSELoss()  # or nn.L1Loss()

    # Huber term added to the MSE for robustness to outliers
    huber_loss = nn.HuberLoss(delta=1.0)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    def train_step(x, y):
        # x: [batch, 4, height, width]
        # y: [batch, 1, height, width] (next frame)
        model.train()
        pred = model(x)

        # Combine losses
        loss_mse = criterion(pred, y)
        loss_huber = huber_loss(pred, y)
        loss = loss_mse + 0.5 * loss_huber

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()

    return model, train_step

2D U-Net with a temporal attention block in the bottleneck