<a href="https://colab.research.google.com/github/alim98/Thesis/blob/main/VMAEPro.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
# Install necessary libraries
!pip -q install torch torchvision torchaudio
!pip -q install einops
!pip -q install monai
!pip -q install torchmetrics
!pip -q install timm
!pip -q install pytorch-lightning



[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m25.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m927.3/927.3 kB[0m [31m21.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m819.3/819.3 kB[0m [31m17.0 MB/s[0m eta [36m0:00:00[0m
[?25h

Note: This is a simplified version of the decoder. Incorporating generative and diffusion-based enhancements like Latent Diffusion Models (LDM) requires additional implementation, which can be complex. For demonstration purposes, we'll keep it basic, but you should consider integrating libraries like Diffusers or custom implementations for LDM.

In [5]:
# import pytorch_lightning as pl
# from pytorch_lightning import Trainer
# from pytorch_lightning.callbacks import ModelSummary

In [4]:

# Import libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision
from einops import rearrange
import numpy as np
import matplotlib.pyplot as plt
from monai.transforms import (
     Compose, LoadImaged, ScaleIntensityd, EnsureTyped, Resized
)
from monai.data import CacheDataset, load_decathlon_datalist
from torchmetrics import Dice
import timm
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelSummary

In [6]:
# ============================
# 3. Dummy Dataset
# ============================

class DummyMedicalImageRegistrationDataset(Dataset):
    def __init__(self, num_samples=100, volume_size=(64, 64, 64), patch_size=(16,16,16), mask_ratio=0.75):
        """
        Args:
            num_samples (int): Number of samples in the dataset.
            volume_size (tuple): Size of the 3D volume (D, H, W).
            patch_size (tuple): Size of each 3D patch (P_D, P_H, P_W).
            mask_ratio (float): Ratio of patches to mask.
        """
        self.num_samples = num_samples
        self.volume_size = volume_size
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio

        # Calculate number of patches per volume
        self.num_patches = (volume_size[0] // patch_size[0]) * \
                           (volume_size[1] // patch_size[1]) * \
                           (volume_size[2] // patch_size[2])

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Generate random source and target volumes
        source = np.random.rand(*self.volume_size).astype(np.float32)
        target = np.random.rand(*self.volume_size).astype(np.float32)

        # Concatenate along channel dimension
        concatenated = np.stack([source, target], axis=0)  # Shape: (2, D, H, W)

        # Divide into patches
        patches = self.divide_into_patches(concatenated, self.patch_size)  # Shape: (N_patches, 2, P_D, P_H, P_W)

        # Masking
        N_visible = int((1 - self.mask_ratio) * self.num_patches)
        indices = np.random.permutation(self.num_patches)
        visible_indices = indices[:N_visible]
        masked_indices = indices[N_visible:]

        visible_patches = patches[visible_indices]  # Shape: (N_visible, 2, P_D, P_H, P_W)
        masked_patches = patches[masked_indices]    # Shape: (N_masked, 2, P_D, P_H, P_W)

        sample = {
            'visible_patches': torch.tensor(visible_patches, dtype=torch.float32),  # (N_visible, 2, P_D, P_H, P_W)
            'masked_patches': torch.tensor(masked_patches, dtype=torch.float32),    # (N_masked, 2, P_D, P_H, P_W)
            'masked_indices': torch.tensor(masked_indices, dtype=torch.long),      # (N_masked,)
            'source': torch.tensor(source, dtype=torch.float32),                    # (D, H, W)
            'target': torch.tensor(target, dtype=torch.float32)                     # (D, H, W)
        }

        return sample

    def divide_into_patches(self, volume, patch_size):
        """
        Divide a 3D volume into non-overlapping patches.
        Args:
            volume (np.array): 3D volume with shape (C, D, H, W)
            patch_size (tuple): Size of each patch (P_D, P_H, P_W)
        Returns:
            patches (np.array): Array of patches with shape (N_patches, C, P_D, P_H, P_W)
        """
        C, D, H, W = volume.shape
        P_D, P_H, P_W = patch_size
        assert D % P_D == 0 and H % P_H == 0 and W % P_W == 0, "Volume dimensions must be divisible by patch size."
        patches = volume.reshape(
            C,
            D//P_D, P_D,
            H//P_H, P_H,
            W//P_W, P_W
        )
        patches = patches.transpose(1,2,3,4,5,6,0)
        patches = patches.reshape(-1, C, P_D, P_H, P_W)
        return patches
# ============================
# 4. Initialize Dataset and DataLoader
# ============================

# Initialize dataset and dataloader
num_samples = 100
volume_size = (48, 48, 48)       # Depth, Height, Width
patch_size = (16, 16, 16)        # Patch size
mask_ratio = 19/27               # 75% patches masked

dataset = DummyMedicalImageRegistrationDataset(
    num_samples=num_samples,
    volume_size=volume_size,
    patch_size=patch_size,
    mask_ratio=mask_ratio
)

batch_size = 1  # Set to 1 to ensure N_patches_new is a perfect cube (64)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
# ============================
# 5. Patch Embedding
# ============================

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=2, embed_dim=768, patch_size=(16,16,16)):
        super(PatchEmbedding, self).__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """
        Args:
            x: (B * N_visible, C, P_D, P_H, P_W)
        Returns:
            embeddings: (B * N_visible, embed_dim)
        """
        B_N, C, P_D, P_H, P_W = x.shape
        x = self.proj(x)  # (B * N_visible, embed_dim, 1, 1, 1)
        x = x.view(B_N, self.embed_dim)  # (B * N_visible, embed_dim)
        return x
# ============================
# 6. Hierarchical Vision Transformer (H-ViT) Encoder
# ============================

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio=4., dropout=0.):
        super(TransformerBlock, self).__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # x: (S, B, embed_dim)
        x2 = self.norm1(x)
        attn_output, _ = self.attn(x2, x2, x2)  # Self-attention
        x = x + self.dropout1(attn_output)
        x2 = self.norm2(x)
        x = x + self.mlp(x2)
        return x

class HierarchicalViTEncoder(nn.Module):
    def __init__(self, embed_dim=768, num_heads=12, num_layers=6, mlp_ratio=4., dropout=0.,
                 num_stages=1, patch_size=(16,16,16)):
        super(HierarchicalViTEncoder, self).__init__()
        self.num_stages = num_stages
        self.patch_size = patch_size
        self.embed_dim = embed_dim

        # Define transformer layers for each stage
        self.transformer_layers = nn.ModuleList()
        for stage in range(num_stages):
            for _ in range(num_layers):
                self.transformer_layers.append(
                    TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
                )

        # Downsampling layers between stages (if any)
        self.downsamples = nn.ModuleList()
        for stage in range(num_stages - 1):
            self.downsamples.append(
                nn.Conv3d(embed_dim, embed_dim, kernel_size=2, stride=2)
            )

        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """
        Args:
            x: (B, N_patches, embed_dim)
        Returns:
            features: list of feature maps at different scales
            x: (B * N_patches_new, embed_dim)
        """
        features = []
        B, N_patches, E = x.shape
        x = x.view(B * N_patches, E)  # Flatten patches across batch

        for stage in range(self.num_stages):
            # Transformer layers
            for _ in range(len(self.transformer_layers) // self.num_stages):
                layer = self.transformer_layers.pop(0)
                # Prepare input for MultiheadAttention: (S, B, E)
                # Here, treat each patch as a sequence element
                x = x.unsqueeze(0)  # (1, B * N_patches, E)
                x = layer(x)         # (1, B * N_patches, E)
                x = x.squeeze(0)     # (B * N_patches, E)

            # Collect features
            features.append(x.clone())

            if stage < self.num_stages -1:
                # Compute D, H, W
                N_patches_new = B * N_patches
                D = H = W = int(np.ceil(N_patches_new ** (1/3)))
                padding = D * H * W - N_patches_new
                if padding >0:
                    pad_tensor = torch.zeros(B * padding, E).to(x.device)
                    x = torch.cat([x, pad_tensor], dim=0)  # (B * N_patches_new + padding, E)

                # Reshape to (B, embed_dim, D, H, W) using einops
                x = rearrange(x, '(b p) e -> b e d h w', b=B, p=D*H*W//(D*H*W))  # Correct pattern

                # Downsample using the corresponding downsampling layer
                x_downsampled = self.downsamples[stage](x)    # (B, embed_dim, D/2, H/2, W/2)

                # Flatten back to patches using einops
                x = rearrange(x_downsampled, 'b c d h w -> b (d h w) c')  # (B, N_patches_new, embed_dim)
                x = x.view(B * x.shape[1], self.embed_dim)  # (B * N_patches_new, embed_dim)

        x = self.norm(x)
        return features, x  # Return all intermediate features and final output
# ============================
# 7. Vision Transformer Masked Autoencoder (ViT-MAE) Decoder
# ============================

class ViT_MAE_Decoder(nn.Module):
    def __init__(self, embed_dim=768, num_heads=12, num_layers=6, mlp_ratio=4., dropout=0.,
                 patch_size=(16,16,16)):
        super(ViT_MAE_Decoder, self).__init__()
        self.embed_dim = embed_dim
        self.patch_size = patch_size

        self.transformer_layers = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout) for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.output_layer = nn.Linear(embed_dim, 2 * np.prod(patch_size))  # Reconstruct 2 channels

    def forward(self, x, masked_indices):
        """
        Args:
            x: (B, N_patches_new, embed_dim)
            masked_indices: (B, N_masked)
        Returns:
            reconstructed: (B, N_masked, 2, P_D, P_H, P_W)
        """
        B, N_patches_new, E = x.shape
        N_masked = masked_indices.shape[1]

        # Reshape x to (B, N_patches_new, embed_dim)
        x = x  # (B, N_patches_new, embed_dim)

        # Pass through decoder transformer layers
        for layer in self.transformer_layers:
            # Prepare input for MultiheadAttention: (S, B, E)
            x = x.transpose(0,1)  # (N_patches_new, B, E)
            x = layer(x)           # (N_patches_new, B, E)
            x = x.transpose(0,1)  # (B, N_patches_new, E)

        x = self.norm(x)  # (B, N_patches_new, E)

        # Generate mask tokens (learnable or random)
        mask_tokens = torch.randn(B, N_masked, E).to(x.device)  # (B, N_masked, E)

        # Concatenate visible and mask tokens
        x = torch.cat([x, mask_tokens], dim=1)  # (B, N_patches_new + N_masked, E)

        # Pass through decoder transformer layers again
        for layer in self.transformer_layers:
            x = x.transpose(0,1)  # (N_total, B, E)
            x = layer(x)           # (N_total, B, E)
            x = x.transpose(0,1)  # (B, N_total, E)

        x = self.norm(x)  # (B, N_total, E)

        # Split visible and masked tokens
        visible = x[:, :N_patches_new, :]  # (B, N_patches_new, E)
        masked = x[:, N_patches_new:, :]   # (B, N_masked, E)

        # Reconstruct masked patches
        reconstructed = self.output_layer(masked)  # (B, N_masked, 2 * P_D * P_H * P_W)
        reconstructed = reconstructed.view(B, N_masked, 2, *self.patch_size)  # (B, N_masked, 2, P_D, P_H, P_W)

        return reconstructed
# ============================
# 8. Registration Head
# ============================

class RegistrationHead(nn.Module):
    def __init__(self, in_channels=768, base_channels=512):
        super(RegistrationHead, self).__init__()
        self.conv1 = nn.Conv3d(in_channels, base_channels, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(base_channels, base_channels//2, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv3d(base_channels//2, 3, kernel_size=3, padding=1)  # 3D displacement vectors
        self.tanh = nn.Tanh()  # To constrain displacement values, adjust as needed

    def forward(self, x):
        """
        Args:
            x: (B, C, D, H, W)
        Returns:
            phi: (B, 3, D, H, W)
        """
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        phi = self.conv3(x)
        phi = self.tanh(phi)
        return phi
# ============================
# 9. Spatial Transformer
# ============================
class SpatialTransformer3D(nn.Module):
    def __init__(self):
        super(SpatialTransformer3D, self).__init__()

    def forward(self, src, phi):
        """
        Args:
            src: (B, D_src, H_src, W_src)
            phi: (B, 3, D_phi, H_phi, W_phi) displacement vectors
        Returns:
            warped_src: (B, D_src, H_src, W_src)
        """
        B, D_src, H_src, W_src = src.shape
        device = src.device

        # Create grid with source's spatial dimensions
        grid_d, grid_h, grid_w = torch.meshgrid(
            torch.linspace(-1, 1, D_src, device=device),
            torch.linspace(-1, 1, H_src, device=device),
            torch.linspace(-1, 1, W_src, device=device),
            indexing='ij'  # Specify indexing to avoid future warnings
        )

        grid = torch.stack((grid_w, grid_h, grid_d), dim=-1)  # (D_src, H_src, W_src, 3)
        grid = grid.unsqueeze(0).repeat(B, 1, 1, 1, 1)      # (B, D_src, H_src, W_src, 3)

        # Normalize displacement to [-1, 1]
        # Rearrange phi to (B, C, D_phi, H_phi, W_phi)
        phi_norm = rearrange(phi, 'b c d h w -> b c d h w') * 0.1  # (B, 3, D_phi, H_phi, W_phi)

        # Upsample phi_norm to match source's spatial dimensions
        phi_norm_upsampled = F.interpolate(
            phi_norm,
            size=(D_src, H_src, W_src),
            mode='trilinear',
            align_corners=True
        )  # (B, 3, D_src, H_src, W_src)

        # Rearrange to (B, D_src, H_src, W_src, 3)
        phi_norm_upsampled = rearrange(phi_norm_upsampled, 'b c d h w -> b d h w c')  # (B, D_src, H_src, W_src, 3)

        # Add displacement to grid
        grid = grid + phi_norm_upsampled  # (B, D_src, H_src, W_src, 3)

        # Sample using grid_sample
        src = src.unsqueeze(1)  # (B, 1, D_src, H_src, W_src)
        warped_src = F.grid_sample(
            src,
            grid,
            mode='bilinear',
            padding_mode='border',
            align_corners=True
        )  # (B, 1, D_src, H_src, W_src)

        warped_src = warped_src.squeeze(1)  # (B, D_src, H_src, W_src)

        return warped_src

# ============================
# 10. Loss Function
# ============================

class VMAEProLoss(nn.Module):
    def __init__(self, lambda_recon=1.0, lambda_reg=1.0, lambda_smooth=0.1):
        super(VMAEProLoss, self).__init__()
        self.lambda_recon = lambda_recon
        self.lambda_reg = lambda_reg
        self.lambda_smooth = lambda_smooth
        self.mse = nn.MSELoss()

    def forward(self, reconstructed, masked_patches, warped_src, target, phi):
        """
        Args:
            reconstructed: (B, N_masked, 2, P_D, P_H, P_W)
            masked_patches: (B, N_masked, 2, P_D, P_H, P_W)
            warped_src: (B, D, H, W)
            target: (B, D, H, W)
            phi: (B, 3, D, H, W)
        Returns:
            total_loss: scalar
        """
        # Reconstruction Loss
        recon_loss = self.mse(reconstructed, masked_patches)

        # Registration Loss
        reg_loss = self.mse(warped_src, target)

        # Smoothness Loss (encouraging small displacements)
        smooth_loss = self.mse(phi, torch.zeros_like(phi))

        # Total Loss
        total_loss = self.lambda_recon * recon_loss + \
                     self.lambda_reg * reg_loss + \
                     self.lambda_smooth * smooth_loss
        return total_loss
# ============================
# 11. Complete Model
# ============================


class VMAEProModel(nn.Module):
    def __init__(self, embed_dim=768, num_heads=12, num_layers=6, mlp_ratio=4., dropout=0.,
                 patch_size=(16,16,16), mask_ratio=0.75, base_channels=512):
        super(VMAEProModel, self).__init__()
        self.patch_embedding = PatchEmbedding(in_channels=2, embed_dim=embed_dim, patch_size=patch_size)
        self.encoder = HierarchicalViTEncoder(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
            num_stages=1,  # Set to 1 to avoid reshaping issues
            patch_size=patch_size
        )
        self.decoder = ViT_MAE_Decoder(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
            patch_size=patch_size
        )
        self.registration_head = RegistrationHead(in_channels=embed_dim, base_channels=base_channels)
        self.spatial_transformer = SpatialTransformer3D()

    def forward(self, visible_patches, masked_indices, source, target):
        B, N_visible, C, P_D, P_H, P_W = visible_patches.shape

        # Reshape visible_patches to (B * N_visible, C, P_D, P_H, P_W)
        visible_patches = visible_patches.view(B * N_visible, C, P_D, P_H, P_W)

        # Patch Embedding
        embeddings = self.patch_embedding(visible_patches)  # (B * N_visible, embed_dim)

        # Reshape embeddings back to (B, N_visible, embed_dim) using einops
        embeddings = rearrange(embeddings, '(b n) e -> b n e', b=B, n=N_visible)  # (B, N_visible, embed_dim)

        # Encoder
        features, encoder_output = self.encoder(embeddings)  # features: list, encoder_output: (B * N_patches_new, embed_dim)

        # Reshape encoder_output to (B, N_patches_new, embed_dim) using einops
        N_patches_new = encoder_output.shape[0] // B
        encoder_output = rearrange(encoder_output, '(b p) e -> b p e', b=B, p=N_patches_new)  # (B, N_patches_new, embed_dim)

        # Decoder (Reconstruction)
        reconstructed = self.decoder(encoder_output, masked_indices)  # (B, N_masked, 2, P_D, P_H, P_W)

        # Registration Head
        # Reshape encoder_output to (B, embed_dim, D, H, W) using einops
        D = H = W = int(round(N_patches_new ** (1/3)))  # For N_patches_new=8, D=2
        assert D * H * W == N_patches_new, f"D*H*W={D*H*W} does not equal N_patches_new={N_patches_new}"

        # Corrected rearrange pattern
        registration_input = rearrange(encoder_output, 'b (d h w) e -> b e d h w', d=D, h=H, w=W)  # (B, embed_dim, D, H, W)
        phi = self.registration_head(registration_input)  # (B, 3, D, H, W)

        # Spatial Transformer
        warped_src = self.spatial_transformer(source, phi)  # (B, D, H, W)

        return reconstructed, warped_src, phi

class VMAEProLightningModule(pl.LightningModule):
    def __init__(self, embed_dim=768, num_heads=12, num_layers=6, mlp_ratio=4., dropout=0.,
                 patch_size=(16,16,16), mask_ratio=0.75, base_channels=512, learning_rate=1e-4):
        super(VMAEProLightningModule, self).__init__()
        self.save_hyperparameters()

        # Initialize your model components
        self.patch_embedding = PatchEmbedding(in_channels=2, embed_dim=embed_dim, patch_size=patch_size)
        self.encoder = HierarchicalViTEncoder(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
            num_stages=1,  # Set to 1 to avoid reshaping issues
            patch_size=patch_size
        )
        self.decoder = ViT_MAE_Decoder(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
            patch_size=patch_size
        )
        self.registration_head = RegistrationHead(in_channels=embed_dim, base_channels=base_channels)
        self.spatial_transformer = SpatialTransformer3D()
        self.criterion = VMAEProLoss(
            lambda_recon=1.0,
            lambda_reg=1.0,
            lambda_smooth=0.1  # Adjust as needed
        )

    def forward(self, visible_patches, masked_indices, source, target):
        B, N_visible, C, P_D, P_H, P_W = visible_patches.shape

        # Reshape visible_patches to (B * N_visible, C, P_D, P_H, P_W)
        visible_patches = visible_patches.view(B * N_visible, C, P_D, P_H, P_W)

        # Patch Embedding
        embeddings = self.patch_embedding(visible_patches)  # (B * N_visible, embed_dim)

        # Reshape embeddings back to (B, N_visible, embed_dim) using einops
        embeddings = rearrange(embeddings, '(b n) e -> b n e', b=B, n=N_visible)  # (B, N_visible, embed_dim)

        # Encoder
        features, encoder_output = self.encoder(embeddings)  # features: list, encoder_output: (B * N_patches_new, embed_dim)

        # Reshape encoder_output to (B, N_patches_new, embed_dim) using einops
        N_patches_new = encoder_output.shape[0] // B
        encoder_output = rearrange(encoder_output, '(b p) e -> b p e', b=B, p=N_patches_new)  # (B, N_patches_new, embed_dim)

        # Decoder (Reconstruction)
        reconstructed = self.decoder(encoder_output, masked_indices)  # (B, N_masked, 2, P_D, P_H, P_W)

        # Registration Head
        # Reshape encoder_output to (B, embed_dim, D, H, W) using einops
        D = H = W = int(round(N_patches_new ** (1/3)))  # For N_patches_new=8, D=2
        assert D * H * W == N_patches_new, f"D*H*W={D*H*W} does not equal N_patches_new={N_patches_new}"

        # Corrected rearrange pattern
        registration_input = rearrange(encoder_output, 'b (d h w) e -> b e d h w', d=D, h=H, w=W)  # (B, embed_dim, D, H, W)
        phi = self.registration_head(registration_input)  # (B, 3, D, H, W)

        # Spatial Transformer
        warped_src = self.spatial_transformer(source, phi)  # (B, D_src, H_src, W_src)

        return reconstructed, warped_src, phi

    def training_step(self, batch, batch_idx):
        visible_patches = batch['visible_patches']  # (B, N_visible, 2, P_D, P_H, P_W)
        masked_patches = batch['masked_patches']    # (B, N_masked, 2, P_D, P_H, P_W)
        masked_indices = batch['masked_indices']    # (B, N_masked)
        source = batch['source']                    # (B, D_src, H_src, W_src)
        target = batch['target']                    # (B, D_src, H_src, W_src)

        # Forward pass
        reconstructed, warped_src, phi = self.forward(visible_patches, masked_indices, source, target)

        # Compute loss
        loss = self.criterion(reconstructed, masked_patches, warped_src, target, phi)

        # Log loss
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
        return [optimizer], [scheduler]
# 12. Training Loop
# ============================

def train_one_step(model, dataloader, optimizer, criterion, device, num_epochs=10, scheduler=None):
    model.to(device)
    model.train()

    for epoch in range(1, num_epochs + 1):
        epoch_loss = 0.0
        for batch_idx, batch in enumerate(dataloader):
            visible_patches = batch['visible_patches'].to(device)  # (B, N_visible, 2, P_D, P_H, P_W)
            masked_patches = batch['masked_patches'].to(device)    # (B, N_masked, 2, P_D, P_H, P_W)
            masked_indices = batch['masked_indices'].to(device)    # (B, N_masked)
            source = batch['source'].to(device)                    # (B, D, H, W)
            target = batch['target'].to(device)                    # (B, D, H, W)

            optimizer.zero_grad()

            # Forward pass
            reconstructed, warped_src, phi = model(visible_patches, masked_indices, source, target)

            # Compute loss
            loss = criterion(reconstructed, masked_patches, warped_src, target, phi)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch [{epoch}/{num_epochs}], Loss: {avg_loss:.4f}")

        if scheduler:
            scheduler.step()
# ============================
# 13. Execution
# ============================

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Initialize the model
# embed_dim = 768
# num_heads = 12
# num_layers = 6
# mlp_ratio = 4.
# dropout = 0.
# patch_size = (16,16,16)
# mask_ratio = 0.75
# base_channels = 512

# model = VMAEProModel(
#     embed_dim=embed_dim,
#     num_heads=num_heads,
#     num_layers=num_layers,
#     mlp_ratio=mlp_ratio,
#     dropout=dropout,
#     patch_size=patch_size,
#     mask_ratio=mask_ratio,
#     base_channels=base_channels
# )

# Define optimizer and scheduler
# learning_rate = 1e-4
# optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)  # Decays LR by 0.1 every 5 epochs

# # Initialize loss function
# criterion = VMAEProLoss(
#     lambda_recon=1.0,
#     lambda_reg=1.0,
#     lambda_smooth=0.1  # Lower weight for smoothness to prevent overpowering other losses
# )

# Start training
# num_epochs = 10
# train_one_step(
#     model=model,
#     dataloader=dataloader,
#     optimizer=optimizer,
#     criterion=criterion,
#     device=device,
#     num_epochs=num_epochs,
#     scheduler=scheduler
# )

model = VMAEProLightningModule(
    embed_dim=768,
    num_heads=12,
    num_layers=6,
    mlp_ratio=4.,
    dropout=0.,
    patch_size=(16,16,16),
    mask_ratio=19/27,  # Approximately 0.7037 to get N_visible=8 patches
    base_channels=512,
    learning_rate=1e-4
)

# Define callbacks
model_summary_callback = ModelSummary(max_depth=-1)  # max_depth=-1 for full summary

# Initialize the Trainer
trainer = Trainer(
    max_epochs=10,
    callbacks=[model_summary_callback],
    accelerator='cpu',  # Change to 'gpu' if using CUDA
    devices=1,          # Number of GPUs or CPUs
    log_every_n_steps=1
)
# Start training
trainer.fit(model, dataloader)

Using device: cuda


INFO:pytorch_lightning.utilities.rank_zero:Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO:pytorch_lightning.callbacks.model_summary:
    | Name                                       | Type                            | Params | Mode 
---------------------------------------------------------------------------------------------------------
0   | patch_embedding                            | PatchEmbedding                  | 6.3 M  | train
1   | patch_embedd

Training: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:
Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

In [None]:


# class VMAEProLightningModule(pl.LightningModule):
#     def __init__(self, embed_dim=768, num_heads=12, num_layers=6, mlp_ratio=4., dropout=0.,
#                  patch_size=(16,16,16), mask_ratio=0.75, base_channels=512, learning_rate=1e-4):
#         super(VMAEProLightningModule, self).__init__()
#         self.save_hyperparameters()

#         # Initialize your model components
#         self.patch_embedding = PatchEmbedding(in_channels=2, embed_dim=embed_dim, patch_size=patch_size)
#         self.encoder = HierarchicalViTEncoder(
#             embed_dim=embed_dim,
#             num_heads=num_heads,
#             num_layers=num_layers,
#             mlp_ratio=mlp_ratio,
#             dropout=dropout,
#             num_stages=1,  # Set to 1 to avoid reshaping issues
#             patch_size=patch_size
#         )
#         self.decoder = ViT_MAE_Decoder(
#             embed_dim=embed_dim,
#             num_heads=num_heads,
#             num_layers=num_layers,
#             mlp_ratio=mlp_ratio,
#             dropout=dropout,
#             patch_size=patch_size
#         )
#         self.registration_head = RegistrationHead(in_channels=embed_dim, base_channels=base_channels)
#         self.spatial_transformer = SpatialTransformer3D()
#         self.criterion = VMAEProLoss(
#             lambda_recon=1.0,
#             lambda_reg=1.0,
#             lambda_smooth=0.1  # Adjust as needed
#         )

#     def forward(self, visible_patches, masked_indices, source, target):
#         B, N_visible, C, P_D, P_H, P_W = visible_patches.shape

#         # Reshape visible_patches to (B * N_visible, C, P_D, P_H, P_W)
#         visible_patches = visible_patches.view(B * N_visible, C, P_D, P_H, P_W)

#         # Patch Embedding
#         embeddings = self.patch_embedding(visible_patches)  # (B * N_visible, embed_dim)

#         # Reshape embeddings back to (B, N_visible, embed_dim) using einops
#         embeddings = rearrange(embeddings, '(b n) e -> b n e', b=B, n=N_visible)  # (B, N_visible, embed_dim)

#         # Encoder
#         features, encoder_output = self.encoder(embeddings)  # features: list, encoder_output: (B * N_patches_new, embed_dim)

#         # Reshape encoder_output to (B, N_patches_new, embed_dim) using einops
#         N_patches_new = encoder_output.shape[0] // B
#         encoder_output = rearrange(encoder_output, '(b p) e -> b p e', b=B, p=N_patches_new)  # (B, N_patches_new, embed_dim)

#         # Decoder (Reconstruction)
#         reconstructed = self.decoder(encoder_output, masked_indices)  # (B, N_masked, 2, P_D, P_H, P_W)

#         # Registration Head
#         # Reshape encoder_output to (B, embed_dim, D, H, W) using einops
#         D = H = W = int(round(N_patches_new ** (1/3)))  # For N_patches_new=8, D=2
#         assert D * H * W == N_patches_new, f"D*H*W={D*H*W} does not equal N_patches_new={N_patches_new}"

#         # Corrected rearrange pattern
#         registration_input = rearrange(encoder_output, 'b (d h w) e -> b e d h w', d=D, h=H, w=W)  # (B, embed_dim, D, H, W)
#         phi = self.registration_head(registration_input)  # (B, 3, D, H, W)

#         # Spatial Transformer
#         warped_src = self.spatial_transformer(source, phi)  # (B, D_src, H_src, W_src)

#         return reconstructed, warped_src, phi

#     def training_step(self, batch, batch_idx):
#         visible_patches = batch['visible_patches']  # (B, N_visible, 2, P_D, P_H, P_W)
#         masked_patches = batch['masked_patches']    # (B, N_masked, 2, P_D, P_H, P_W)
#         masked_indices = batch['masked_indices']    # (B, N_masked)
#         source = batch['source']                    # (B, D_src, H_src, W_src)
#         target = batch['target']                    # (B, D_src, H_src, W_src)

#         # Forward pass
#         reconstructed, warped_src, phi = self.forward(visible_patches, masked_indices, source, target)

#         # Compute loss
#         loss = self.criterion(reconstructed, masked_patches, warped_src, target, phi)

#         # Log loss
#         self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

#         return loss

#     def configure_optimizers(self):
#         optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)
#         scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
#         return [optimizer], [scheduler]
# # Initialize the model
# model = VMAEProLightningModule(
#     embed_dim=768,
#     num_heads=12,
#     num_layers=6,
#     mlp_ratio=4.,
#     dropout=0.,
#     patch_size=(16,16,16),
#     mask_ratio=19/27,  # Approximately 0.7037 to get N_visible=8 patches
#     base_channels=512,
#     learning_rate=1e-4
# )

# # Define callbacks
# model_summary_callback = ModelSummary(max_depth=-1)  # max_depth=-1 for full summary

# # Initialize the Trainer
# trainer = Trainer(
#     max_epochs=10,
#     callbacks=[model_summary_callback],
#     accelerator='cpu',  # Change to 'gpu' if using CUDA
#     devices=1,          # Number of GPUs or CPUs
#     log_every_n_steps=1
# )
# # Start training
# trainer.fit(model, dataloader)
