In [1]:
import numpy as np

def load_data(path: str):
    """Read the array stored at ``path`` with ``np.load`` and return it."""
    loaded = np.load(path)
    return loaded

# Load one example sequence; it is visualised and its shape inspected in later cells.
v1 = load_data('../selected_volumes/MOL-001.npy')

In [2]:
import matplotlib.pyplot as plt
from ipywidgets import IntSlider, interact
def multi_vol_seq_interactive(volume_seqs, titles=None):
    """
    Interactive plot of multiple volume sequences using ipywidgets.

    Every sequence gets its own panel on a near-square grid. Two sliders pick
    the time index and the slice index; for a sequence that is shorter than
    the slider range the index is clamped to that sequence's last valid entry.

    Parameters:
    - volume_seqs: Non-empty list of 4D volume sequences to display, indexed
      as ``seq[time][slice]`` -> 2D image
    - titles: Optional list of titles for each sequence. Sequences without a
      title fall back to "Volume N".

    Raises:
    - ValueError: if ``volume_seqs`` is empty (there is nothing to lay out).
    """
    volume_seqs = list(volume_seqs)
    num_volumes = len(volume_seqs)
    if num_volumes == 0:
        raise ValueError("volume_seqs must contain at least one volume sequence")

    default_titles = [f"Volume {i+1}" for i in range(num_volumes)]
    if titles is None:
        titles = default_titles
    else:
        # Pad a short title list so zip() below cannot silently drop panels.
        titles = list(titles)
        titles = titles + default_titles[len(titles):]

    # Near-square grid: floor(sqrt(n)) rows and enough columns to hold all panels.
    nrows = int(num_volumes ** 0.5)
    ncols = (num_volumes + nrows - 1) // nrows

    def plot_volumes(time_idx, slice_idx):
        # squeeze=False always yields a 2D axes array, whatever the grid shape.
        fig, axes = plt.subplots(nrows, ncols,
                                 figsize=(5*ncols, 5*nrows),
                                 squeeze=False)
        flat_axes = axes.ravel()  # row-major, same order as i // ncols, i % ncols

        for ax, volume_seq, title in zip(flat_axes, volume_seqs, titles):
            # Clamp so sequences of different lengths can share the sliders.
            t = min(time_idx, len(volume_seq) - 1)
            s = min(slice_idx, len(volume_seq[t]) - 1)

            im = ax.imshow(volume_seq[t][s], cmap='magma')
            ax.set_title(title)
            fig.colorbar(im, ax=ax)

        # Hide grid cells that have no volume assigned to them.
        for ax in flat_axes[num_volumes:]:
            ax.axis('off')

        fig.tight_layout()
        plt.show(block=True)

    # Slider ranges cover the longest sequence / deepest volume.
    max_time = max(len(vol) for vol in volume_seqs) - 1
    max_slice = max(len(vol[0]) for vol in volume_seqs) - 1

    interact(
        plot_volumes,
        time_idx=IntSlider(min=0, max=max_time, step=1, value=0, description='Time:'),
        slice_idx=IntSlider(min=0, max=max_slice, step=1, value=0, description='Slice:')
    )



In [23]:
# Browse the example sequence; the sliders select the time index and the slice index.
multi_vol_seq_interactive([v1])

interactive(children=(IntSlider(value=0, description='Time:', max=17), IntSlider(value=0, description='Slice:'…

In [4]:
# Inspect the dimensions of the loaded sequence.
v1.shape

(18, 16, 256, 256)

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import math
import torchvision
from torchvision import datasets, transforms
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
import random
import pytorch_lightning as pl
from torch.nn import Parameter

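## Dataset 1:
- load the data as sliding windows: each sample pairs `context_window` consecutive volumes (input) with the volume that follows them (target)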

In [12]:
class Dataset3D(Dataset):
    """Sliding-window next-volume dataset built from ``.npy`` volume sequences.

    Every file in ``data_paths`` is loaded with ``np.load``, converted to a
    tensor and windowed along its first axis. A sample is the pair
    ``(inputs, target)`` where ``inputs`` holds ``context_window`` consecutive
    entries of the sequence and ``target`` is the entry right after them.
    All sequences are loaded eagerly in ``__init__``.

    Args:
        data_paths: Iterable of paths to ``.npy`` files.
        context_window: Number of consecutive volumes in each input window.
        transform: Optional callable. When given, it is applied separately to
            the input window and to the target in ``__getitem__``. With the
            default ``None`` the stored tensors are returned unchanged.
    """

    def __init__(self, data_paths, context_window=4, transform=None):
        self.data_paths = data_paths
        self.context_window = context_window
        self.transform = transform
        self.samples = []
        # For every path to a volume sequence in .npy
        for data_path in self.data_paths:
            volume_seq = torch.from_numpy(np.load(data_path))
            # A sequence no longer than the window produces no samples
            # because the range below is empty.
            for start in range(len(volume_seq) - self.context_window):
                stop = start + self.context_window
                # Slices are views on the loaded tensor, so windows share storage.
                self.samples.append((volume_seq[start:stop], volume_seq[stop]))

    def __len__(self):
        """Return the number of (inputs, target) windows over all files."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(inputs, target)`` pair, transformed if requested."""
        inputs, target = self.samples[idx]
        if self.transform is not None:
            inputs, target = self.transform(inputs), self.transform(target)
        return inputs, target

In [16]:
root = '../selected_volumes'
# Pick the first backend that is actually available. A torch.device object is
# always truthy, so `torch.device('mps') or ...` selected MPS unconditionally.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Training parameters
batch_size = 4
sequence_length = 4
learning_rate = 1e-4
num_epochs = 100
# Data parameters
train_split = 0.8
val_split = 0.1


# os.listdir returns entries in arbitrary order and includes non-data files;
# keep only .npy files and sort them so path-based splits are reproducible.
data_paths = sorted(
    os.path.join(root, name) for name in os.listdir(root) if name.endswith('.npy')
)
dataset = Dataset3D(data_paths, context_window=sequence_length)

def get_data_loaders(batch_size=4, sequence_length=4):
    """Split the module-level ``data_paths`` by file and wrap each part in a loader.

    The first ``train_split`` fraction of the paths is used for training, the
    next ``val_split`` fraction for validation and whatever is left for
    testing. Only the training loader shuffles.

    Args:
        batch_size: Batch size for all three loaders.
        sequence_length: Passed to ``Dataset3D`` as its context window.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    total = len(data_paths)
    train_end = int(total * train_split)
    val_end = train_end + int(total * val_split)

    subsets = (
        data_paths[:train_end],
        data_paths[train_end:val_end],
        data_paths[val_end:],
    )
    train_ds, val_ds, test_ds = (
        Dataset3D(subset, sequence_length) for subset in subsets
    )

    return (
        DataLoader(train_ds, batch_size=batch_size, shuffle=True),
        DataLoader(val_ds, batch_size=batch_size),
        DataLoader(test_ds, batch_size=batch_size),
    )

In [17]:

# Build the three loaders from the batch size and sequence length set in the configuration cell.
train_loader, val_loader, test_loader = get_data_loaders(batch_size=batch_size, sequence_length=sequence_length)

In [20]:
# Sanity-check the batch shapes coming out of the training loader.
# The window is bound to `inputs` so the built-in `input()` is not shadowed
# in the kernel namespace after this cell runs.
for i, (inputs, target) in enumerate(train_loader):
    print(i, inputs.shape, target.shape)


0 torch.Size([4, 4, 16, 256, 256]) torch.Size([4, 16, 256, 256])
1 torch.Size([4, 4, 16, 256, 256]) torch.Size([4, 16, 256, 256])
2 torch.Size([4, 4, 16, 256, 256]) torch.Size([4, 16, 256, 256])
3 torch.Size([4, 4, 16, 256, 256]) torch.Size([4, 16, 256, 256])
4 torch.Size([4, 4, 16, 256, 256]) torch.Size([4, 16, 256, 256])
5 torch.Size([4, 4, 16, 256, 256]) torch.Size([4, 16, 256, 256])
6 torch.Size([4, 4, 16, 256, 256]) torch.Size([4, 16, 256, 256])
7 torch.Size([4, 4, 16, 256, 256]) torch.Size([4, 16, 256, 256])
8 torch.Size([4, 4, 16, 256, 256]) torch.Size([4, 16, 256, 256])
9 torch.Size([4, 4, 16, 256, 256]) torch.Size([4, 16, 256, 256])
10 torch.Size([4, 4, 16, 256, 256]) torch.Size([4, 16, 256, 256])
11 torch.Size([4, 4, 16, 256, 256]) torch.Size([4, 16, 256, 256])
12 torch.Size([4, 4, 16, 256, 256]) torch.Size([4, 16, 256, 256])
13 torch.Size([4, 4, 16, 256, 256]) torch.Size([4, 16, 256, 256])
14 torch.Size([4, 4, 16, 256, 256]) torch.Size([4, 16, 256, 256])
15 torch.Size([4, 4,

(2+1)D CNN

In [None]:
class DoubleConv2D(nn.Module):
    """Two stacked 3x3 convolutions, each followed by batch norm and ReLU.

    ``padding=1`` with a 3x3 kernel keeps the spatial size unchanged; only the
    channel count goes from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU())
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv/norm/ReLU stages."""
        out = self.double_conv(x)
        return out
    
class TemporalAttentionBlock(nn.Module):
    """Self-attention along the time axis, applied at every spatial location.

    The input ``[B, T, C, H, W]`` is rearranged into ``B*H*W`` independent
    sequences of length ``T`` with ``C`` features each. Each sequence is
    layer-normalised over its channels and passed through multi-head
    self-attention, and the result is rearranged back to ``[B, T, C, H, W]``.

    Args:
        dim: Channel count ``C``; used as the attention embedding size.
        num_heads: Number of attention heads.
        sequence_length: Not used by this block; kept so existing callers
            that pass it continue to work.
    """

    def __init__(self, dim, num_heads=8, sequence_length=9):
        super().__init__()

        # forward() normalises a [B*H*W, T, C] tensor, so the norm runs over
        # the channel dimension only. (A normalized_shape containing None,
        # as in [dim, None, None], is rejected when the layer is constructed.)
        self.norm = nn.LayerNorm(dim)
        self.attention = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            batch_first=True
        )

    def forward(self, x):
        """Attend over time for each pixel; input and output are ``[B, T, C, H, W]``."""
        b, t, c, h, w = x.shape

        # Reshape for attention: one length-T sequence per spatial location.
        x = x.permute(0, 3, 4, 1, 2)  # [B, H, W, T, C]
        x = x.reshape(b*h*w, t, c)     # [B*H*W, T, C]

        # Apply attention
        x = self.norm(x)
        attn_out, _ = self.attention(x, x, x)

        # Reshape back
        x = attn_out.reshape(b, h, w, t, c)
        x = x.permute(0, 3, 4, 1, 2)   # [B, T, C, H, W]

        return x

    
class PerfusionCTPredictor(nn.Module):
    """2D U-Net style next-frame predictor with temporal attention at the bottleneck.

    Each of the ``input_frames`` frames is encoded independently by a shared
    two-level 2D encoder. The per-frame bottleneck features are stacked along
    time and mixed by ``TemporalAttentionBlock``. Only the last time step is
    decoded, using skip connections from the last input frame.

    Args:
        input_frames: Number of frames expected along the time axis.
        base_filters: Channel width of the first encoder level; deeper levels
            use 2x and 4x this value.

    Input:  ``[batch, input_frames, height, width]``
    Output: ``[batch, 1, height, width]``

    Height and width should be divisible by 4: the encoder pools by 2 twice and
    the decoder upsamples by 2 twice before concatenating with the skips.
    """

    def __init__(self, input_frames=4, base_filters=64):
        super().__init__()

        self.input_frames = input_frames

        # Encoder (single channel for CT scans)
        self.enc1 = DoubleConv2D(1, base_filters)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = DoubleConv2D(base_filters, base_filters*2)
        self.pool2 = nn.MaxPool2d(2)

        # Bottleneck with temporal attention
        self.bottleneck_conv = DoubleConv2D(base_filters*2, base_filters*4)
        self.temporal_attention = TemporalAttentionBlock(
            dim=base_filters*4,
            num_heads=4,  # Reduced for sequence length of 4
            sequence_length=input_frames
        )

        # Decoder; dec2/dec1 take twice the up-conv width because the skip
        # features are concatenated on the channel axis.
        self.upconv2 = nn.ConvTranspose2d(base_filters*4, base_filters*2, kernel_size=2, stride=2)
        self.dec2 = DoubleConv2D(base_filters*4, base_filters*2)
        self.upconv1 = nn.ConvTranspose2d(base_filters*2, base_filters, kernel_size=2, stride=2)
        self.dec1 = DoubleConv2D(base_filters*2, base_filters)

        # Final prediction head
        self.final_conv = nn.Conv2d(base_filters, 1, kernel_size=1)

        # Optional: Intensity normalization
        self.instance_norm = nn.InstanceNorm2d(1, affine=True)

    def forward(self, x):
        """Predict the next frame from ``x`` of shape ``[batch, time, height, width]``.

        Raises:
            ValueError: if ``x`` is not 4D, e.g. a 5D
                ``[batch, time, depth, height, width]`` batch of whole volumes.
            AssertionError: if the time dimension differs from ``input_frames``.
        """
        if x.dim() != 4:
            raise ValueError(
                "Expected a 4D input [batch, time, height, width], "
                f"got shape {tuple(x.shape)}"
            )
        b, t, h, w = x.shape
        assert t == self.input_frames, f"Expected {self.input_frames} input frames, got {t}"

        # Optional: Normalize intensities, one frame at a time.
        # reshape (not view) so non-contiguous inputs are accepted too.
        x = x.reshape(b*t, 1, h, w)
        x = self.instance_norm(x)
        x = x.reshape(b, t, h, w)

        # Process each time step through the shared encoder
        bottleneck_features = []
        last_skips = None

        for i in range(t):
            # Encoder path
            x1 = self.enc1(x[:, i].unsqueeze(1))  # Add channel dimension
            x2 = self.enc2(self.pool1(x1))

            # Only the last frame's skip features are used by the decoder.
            last_skips = (x1, x2)

            # Bottleneck
            bottle = self.bottleneck_conv(self.pool2(x2))
            bottleneck_features.append(bottle)

        # Stack and apply temporal attention at bottleneck
        bottleneck_features = torch.stack(bottleneck_features, dim=1)
        bottleneck_features = self.temporal_attention(bottleneck_features)

        # Use the last temporal feature for prediction
        bottle = bottleneck_features[:, -1]  # Take last temporal state

        # Decoder path (single time step)
        x1, x2 = last_skips  # Use features from last input frame

        d2 = self.upconv2(bottle)
        d2 = torch.cat([d2, x2], dim=1)
        d2 = self.dec2(d2)

        d1 = self.upconv1(d2)
        d1 = torch.cat([d1, x1], dim=1)
        d1 = self.dec1(d1)

        # Predict next frame
        pred = self.final_conv(d1)

        return pred


# Training setup
def train_perfusion_predictor():
    """Build the model, losses and optimizer, and return a ready-to-use train step.

    Returns:
        Tuple ``(model, train_step)``. ``train_step(x, y)`` runs one
        optimisation step on a batch and returns the scalar loss as a float.
        Previously nothing was returned, so the model and the step function
        were unreachable after the call.
    """
    model = PerfusionCTPredictor(input_frames=4)
    criterion = nn.MSELoss()  # or nn.L1Loss()

    # Optional: Add Huber loss for robustness to outliers
    huber_loss = nn.HuberLoss(delta=1.0)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    def train_step(x, y):
        """Run one optimisation step.

        x: [batch, 4, height, width]
        y: [batch, 1, height, width] (next frame)
        """
        model.train()
        pred = model(x)

        # The losses would silently broadcast mismatched shapes; fail loudly instead.
        if pred.shape != y.shape:
            raise ValueError(
                f"Target shape {tuple(y.shape)} does not match "
                f"prediction shape {tuple(pred.shape)}"
            )

        # Combine losses
        loss_mse = criterion(pred, y)
        loss_huber = huber_loss(pred, y)
        loss = loss_mse + 0.5 * loss_huber

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()

    return model, train_step

2D UNet with temporal Attention Block in Bottleneck