<a href="https://colab.research.google.com/github/alim98/Thesis/blob/main/EnhancedVMAEPro.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install necessary libraries
# NOTE(review): these are IPython shell escapes (`!`), so this cell only runs in a
# notebook. The Colab runtime presumably already ships torch/torchvision/torchaudio;
# reinstalling them may change their versions — consider pinning versions that match
# the runtime. Confirm before relying on this.
!pip -q install torch torchvision torchaudio
!pip -q install einops
!pip -q install monai
!pip -q install torchmetrics
!pip -q install timm
!pip -q install pytorch-lightning



[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m26.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m927.3/927.3 kB[0m [31m21.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m819.3/819.3 kB[0m [31m16.0 MB/s[0m eta [36m0:00:00[0m
[?25h

Note: This is a simplified decoder. It is a stack of transformer blocks followed by a linear projection, and it returns zero-filled placeholders for the warped source and the deformation field `phi`. Adding generative or diffusion-based enhancements such as Latent Diffusion Models (LDMs) requires substantial additional implementation. For demonstration purposes we keep it basic. For LDM support, consider integrating a library such as Hugging Face Diffusers or writing a custom implementation.

In [5]:
# import pytorch_lightning as pl
# from pytorch_lightning import Trainer
# from pytorch_lightning.callbacks import ModelSummary

In [2]:

# Import libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision
from einops import rearrange
import numpy as np
import matplotlib.pyplot as plt
from monai.transforms import (
     Compose, LoadImaged, ScaleIntensityd, EnsureTyped, Resized
)
from monai.data import CacheDataset, load_decathlon_datalist
from torchmetrics import Dice
import timm
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelSummary

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelSummary
from functools import partial

class TransformerBlock(nn.Module):
    """Post-norm transformer block: self-attention then MLP, each with a residual.

    Inputs are sequence-first, as expected by ``nn.MultiheadAttention`` with its
    default ``batch_first=False``.

    Args:
        embed_dim: Token embedding dimension E.
        num_heads: Number of attention heads (must divide ``embed_dim``).
        mlp_ratio: Hidden width of the MLP as a multiple of ``embed_dim``.
        dropout: Dropout probability used in attention and in the MLP.
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Apply the block.

        Args:
            x: Tensor of shape [L, N, E].

        Returns:
            Tensor of shape [L, N, E].
        """
        # Residual attention, normalized after the addition (post-norm).
        attended = self.attn(x, x, x)[0]
        x = self.norm1(x + attended)
        # Residual MLP, again normalized after the addition.
        return self.norm2(x + self.mlp(x))

class HierarchicalViTEncoder(nn.Module):
    """Multi-stage ViT encoder over a 3-D grid of patch tokens.

    Each stage runs ``num_layers`` TransformerBlocks over all tokens of the
    current grid and records the stage output. Except after the last stage, it
    then halves every spatial dimension with a stride-2 ``Conv3d`` that keeps the
    embedding dimension constant.

    Args:
        embed_dim: Token embedding dimension E.
        num_heads: Attention heads per TransformerBlock.
        num_layers: TransformerBlocks per stage.
        mlp_ratio: Hidden-size multiplier of each block's MLP.
        dropout: Dropout probability inside each block.
        num_stages: Number of stages (``num_stages - 1`` downsamplings).
        patch_size: Stored as an attribute only; not used by this module.
    """

    def __init__(
        self,
        embed_dim=768,
        num_heads=12,
        num_layers=6,
        mlp_ratio=4.0,
        dropout=0.1,
        num_stages=3,
        patch_size=(4, 16, 16)  # Adjusted patch size
    ):
        super(HierarchicalViTEncoder, self).__init__()
        self.num_stages = num_stages
        self.patch_size = patch_size
        self.embed_dim = embed_dim

        # Flat list of num_layers blocks per stage, in stage-major order.
        self.transformer_layers = nn.ModuleList()
        for stage in range(num_stages):
            for _ in range(num_layers):
                self.transformer_layers.append(
                    TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
                )

        # Downsampling layers between stages; embed_dim is kept constant.
        self.downsamples = nn.ModuleList()
        for stage in range(num_stages - 1):
            self.downsamples.append(
                nn.Conv3d(embed_dim, embed_dim, kernel_size=2, stride=2)
            )

        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x, grid_size=(4, 4, 4)):
        """Encode flattened patch tokens.

        Args:
            x: Tensor of shape [B * p, E]. Rows are ordered sample-major: all p
                tokens of sample 0, then sample 1, and so on. Each sample's
                tokens are the C-order flattening of a ``grid_size`` grid.
            grid_size: (D, H, W) of the token grid, with D * H * W == p.
                Defaults to (4, 4, 4).

        Returns:
            features: list with one [B, E, D_s, H_s, W_s] tensor per stage.
            x: Tensor of shape [B * p_final, E] holding the layer-normalized
                final tokens, with rows ordered the same way as the input.

        Raises:
            ValueError: if the row count is not a multiple of p, or if a grid
                dimension is too small to downsample.
        """
        D, H, W = grid_size
        p = D * H * W
        B_p, E = x.shape
        if B_p % p != 0:
            raise ValueError(f"Batch size {B_p} is not divisible by patch count {p}")
        B = B_p // p

        # [B*p, E] -> [B, p, E] -> [B, E, p] -> [B, E, D, H, W]. A bare .view to
        # [B, E, D, H, W] would interleave the embedding and token axes.
        x = x.reshape(B, p, E).transpose(1, 2).reshape(B, E, D, H, W)

        layers_per_stage = len(self.transformer_layers) // self.num_stages
        features = []
        for stage in range(self.num_stages):
            # Index this stage's blocks rather than popping them: popping would
            # remove the layers from the model after the first forward pass.
            first = stage * layers_per_stage
            tokens = x.flatten(2).permute(2, 0, 1)  # [p_s, B, E], sequence-first
            for layer in self.transformer_layers[first:first + layers_per_stage]:
                tokens = layer(tokens)
            x = tokens.permute(1, 2, 0).reshape(B, E, D, H, W)  # back to [B, E, D, H, W]
            features.append(x)

            if stage < self.num_stages - 1:
                # kernel=2/stride=2 needs at least 2 voxels along each axis.
                if min(D, H, W) < 2:
                    raise ValueError(
                        f"Spatial dimensions {(D, H, W)} are too small to downsample "
                        f"after stage {stage}"
                    )
                x = self.downsamples[stage](x)  # [B, E, D//2, H//2, W//2]
                D, H, W = x.shape[2:]

        # LayerNorm over the embedding axis of each final token.
        tokens = self.norm(x.flatten(2).transpose(1, 2))  # [B, p_final, E]
        return features, tokens.reshape(B * D * H * W, E)

class PatchEmbedding(nn.Module):
    """Split a volume into non-overlapping 3-D patches and embed each one.

    A Conv3d whose kernel equals its stride projects every patch to an
    ``embed_dim`` vector, and the resulting tokens are layer-normalized.

    Args:
        in_channels: Channels C of the input volume.
        embed_dim: Output embedding dimension E.
        patch_size: (pd, ph, pw) patch extent, used as both kernel and stride.
    """

    def __init__(self, in_channels=1, embed_dim=768, patch_size=(4, 16, 16)):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Embed patches.

        Args:
            x: Tensor of shape [B, C, D, H, W].

        Returns:
            Tensor of shape [B * p, E], p = D' * H' * W'. Rows are ordered
            sample-major, and each sample's tokens are the C-order flattening
            of the [D', H', W'] grid.
        """
        grid = self.proj(x)  # [B, E, D', H', W']
        batch, embed = grid.shape[0], grid.shape[1]
        tokens = self.norm(grid.flatten(2).transpose(1, 2))  # [B, p, E]
        return tokens.reshape(batch * tokens.shape[1], embed)

class ViT_MAE_Decoder(nn.Module):
    """Transformer decoder that maps encoder tokens to reconstructed patches.

    Args:
        embed_dim: Token embedding dimension E.
        num_heads: Attention heads per TransformerBlock.
        num_layers: Number of TransformerBlocks.
        mlp_ratio: Hidden-size multiplier of each block's MLP.
        dropout: Dropout probability inside each block.
    """

    def __init__(self, embed_dim=768, num_heads=12, num_layers=6, mlp_ratio=4.0, dropout=0.1):
        super(ViT_MAE_Decoder, self).__init__()
        self.transformer_layers = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.output_layer = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, multi_scale_features, source, target):
        """Decode tokens.

        Args:
            x: Tensor of shape [B * p, E], rows ordered sample-major.
            multi_scale_features: list of [B, ...] tensors from the encoder. Only
                the batch size of the last entry is read; it tells how to split
                ``x`` into per-sample sequences. If it is empty or None, all rows
                are treated as one sequence.
            source, target: Currently unused.

        Returns:
            reconstructed_patches: Tensor of shape [B * p, E].
            warped_src, phi: zero-filled placeholders shaped like
                ``reconstructed_patches``.

        Raises:
            ValueError: if the row count is not a multiple of the batch size.
        """
        num_rows, embed = x.shape
        has_features = multi_scale_features is not None and len(multi_scale_features) > 0
        batch = multi_scale_features[-1].shape[0] if has_features else 1
        if num_rows % batch != 0:
            raise ValueError(f"Token count {num_rows} is not divisible by batch size {batch}")
        per_sample = num_rows // batch

        # A 2-D input would make nn.MultiheadAttention treat all B*p rows as a
        # single unbatched sequence, mixing samples. Use sequence-first [p, B, E].
        tokens = x.reshape(batch, per_sample, embed).transpose(0, 1)
        for layer in self.transformer_layers:
            tokens = layer(tokens)
        x = tokens.transpose(0, 1).reshape(num_rows, embed)

        x = self.norm(x)
        reconstructed_patches = self.output_layer(x)  # [B * p, E]

        # Placeholders for warped_src and phi
        warped_src = torch.zeros_like(reconstructed_patches)
        phi = torch.zeros_like(reconstructed_patches)

        return reconstructed_patches, warped_src, phi

class RegistrationHead(nn.Module):
    """Small conv head predicting an ``out_channels``-channel field (phi).

    Two 3x3x3 conv + ReLU stages keep the channel count and spatial size
    (padding=1). A final 1x1x1 conv projects to ``out_channels``.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the predicted field (default 3).
    """

    def __init__(self, in_channels=768, out_channels=3):
        super().__init__()
        self.conv1 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Map [B, in_channels, D, H, W] to [B, out_channels, D, H, W]."""
        for stage in (self.conv1, self.relu1, self.conv2, self.relu2):
            x = stage(x)
        return self.conv3(x)
class SpatialTransformer3D(nn.Module):
    """3-D convolution followed by LayerNorm over the channel axis.

    Despite the name, this module performs no spatial warping or resampling.
    It is a conv layer plus a channel-wise LayerNorm.

    Args:
        in_channels: Input channels.
        out_channels: Output channels (also the LayerNorm size).
        kernel_size: Conv kernel size.
        padding: Conv padding.
    """

    def __init__(self, in_channels=768, out_channels=768, kernel_size=3, padding=1):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding)
        self.norm = nn.LayerNorm(out_channels)

    def forward(self, x):
        """Map [B, in_channels, D, H, W] to [B, out_channels, D', H', W']."""
        # LayerNorm normalizes the trailing axis, so move channels last and back.
        channels_last = self.conv(x).movedim(1, -1)  # [B, D', H', W', C]
        return self.norm(channels_last).movedim(-1, 1)
class VMAEProLoss(nn.Module):
    """Mean-squared-error reconstruction loss (``nn.MSELoss``, mean reduction)."""

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, reconstructed_patches, target):
        """Return the scalar MSE between two same-shaped tensors.

        Args:
            reconstructed_patches: Tensor of shape [B * p, E].
            target: Tensor of shape [B * p, E].
        """
        return self.mse(reconstructed_patches, target)

class VMAEProModel(pl.LightningModule):
    """LightningModule wiring patch embedding, hierarchical encoder and decoder.

    The training objective is the MSE between the decoder's reconstructed
    patches and ``target`` (see ``training_step``).

    ``registration_head``, ``spatial_transformer`` and ``teacher_encoder`` are
    constructed but not used by ``forward`` or ``training_step`` in this class.

    NOTE(review): those unused submodules are still part of ``self.parameters()``.
    They are therefore passed to the optimizer and counted in the model summary,
    even though this loss gives them no gradients. Confirm whether they should be
    frozen, used, or removed.
    """

    def __init__(
        self,
        in_channels=1,
        embed_dim=768,
        num_heads=12,
        num_encoder_layers=6,
        num_decoder_layers=6,
        mlp_ratio=4.0,
        dropout=0.1,
        num_stages=3,
        patch_size=(4, 16, 16),
        lr=1e-4
    ):
        super(VMAEProModel, self).__init__()
        # Lightning stores the __init__ arguments in self.hparams.
        self.save_hyperparameters()

        # Patch Embedding: volume [B, C, D, H, W] -> tokens [B * p, E]
        self.patch_embedding = PatchEmbedding(
            in_channels=in_channels,
            embed_dim=embed_dim,
            patch_size=patch_size
        )

        # Encoder
        self.encoder = HierarchicalViTEncoder(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=num_encoder_layers,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
            num_stages=num_stages,
            patch_size=patch_size
        )

        # Decoder
        self.decoder = ViT_MAE_Decoder(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=num_decoder_layers,
            mlp_ratio=mlp_ratio,
            dropout=dropout
        )

        # Registration Head (not used by forward/training_step below)
        self.registration_head = RegistrationHead(in_channels=embed_dim)

        # Spatial Transformer (not used by forward/training_step below)
        self.spatial_transformer = SpatialTransformer3D(in_channels=embed_dim)

        # Loss
        self.criterion = VMAEProLoss()

        # Teacher Encoder (for knowledge distillation or similar purposes).
        # NOTE(review): never called in this class; presumably a placeholder.
        self.teacher_encoder = HierarchicalViTEncoder(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=num_encoder_layers,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
            num_stages=num_stages,
            patch_size=patch_size
        )

        # Optimizer Learning Rate
        self.lr = lr

    def forward(self, visible_patches, masked_indices, source, target):
        """Run embedding -> encoder -> decoder.

        Args:
            visible_patches: Tensor of shape [B, C, D, H, W]
            masked_indices: currently unused by this method
            source, target: passed through to the decoder unchanged
        Returns:
            reconstructed_patches, warped_src, phi (as returned by the decoder)
        """
        # Patch Embedding
        embeddings = self.patch_embedding(visible_patches)  # [B * p, E]
        # print(f"Embeddings Shape: {embeddings.shape}")  # Debug

        # Encoder
        multi_scale_features, encoder_output = self.encoder(embeddings)  # [features list], [B * p_new, E_new]

        # Decoder
        reconstructed_patches, warped_src, phi = self.decoder(
            encoder_output, multi_scale_features, source, target
        )

        return reconstructed_patches, warped_src, phi

    def training_step(self, batch, batch_idx):
        """Compute and log the reconstruction loss for one batch.

        Args:
            batch: Tuple containing (visible_patches, masked_indices, source, target)
        Returns:
            loss: MSE between reconstructed_patches and target. The two must
                have matching shapes — presumably [B * p_new, E] and [B, E]
                coincide only when p_new == 1; confirm for other configs.
        """
        visible_patches, masked_indices, source, target = batch

        # Forward pass
        reconstructed_patches, warped_src, phi = self.forward(
            visible_patches, masked_indices, source, target
        )

        # Compute loss (warped_src and phi are not used)
        loss = self.criterion(reconstructed_patches, target)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        # AdamW over every parameter of the module, including the unused submodules.
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer
class Dummy3DDataset(Dataset):
    """Random 3-D volumes and targets for smoke-testing the training loop.

    Every ``__getitem__`` call draws fresh random tensors, so there is no fixed
    mapping from index to sample.

    Args:
        num_samples: Dataset length.
        in_channels, depth, height, width: Shape of each generated volume.
        patch_size: Stored for reference; not used to generate data.
        embed_dim: Length of the generated ``target`` vector. The default of 768
            matches a model with embed_dim=768 whose encoder reduces the token
            grid to a single token.
    """

    def __init__(
        self,
        num_samples=1000,
        in_channels=1,
        depth=16,
        height=64,
        width=64,
        patch_size=(4, 16, 16),
        embed_dim=768
    ):
        super(Dummy3DDataset, self).__init__()
        self.num_samples = num_samples
        self.in_channels = in_channels
        self.depth = depth
        self.height = height
        self.width = width
        self.patch_size = patch_size
        self.embed_dim = embed_dim

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Return (visible_patches, masked_indices, source, target).

        Raises:
            IndexError: for indices outside the dataset. Plain ``for`` iteration
                over the dataset relies on this to terminate.
        """
        if not -self.num_samples <= idx < self.num_samples:
            raise IndexError(f"Index {idx} out of range for dataset of size {self.num_samples}")

        # Random volume of shape [C, D, H, W]
        visible_patches = torch.randn(self.in_channels, self.depth, self.height, self.width)
        # Placeholder mask and source; the model does not use them yet.
        masked_indices = torch.randint(0, 2, (1,))
        source = torch.randn(1)
        # Per-sample target of shape [embed_dim]; the DataLoader adds the batch dim.
        target = torch.randn(self.embed_dim)
        return visible_patches, masked_indices, source, target
# Instantiate the dataset
dataset = Dummy3DDataset(
    num_samples=100,  # Reduced number for quick testing
    in_channels=1,
    depth=16,
    height=64,
    width=64,
    patch_size=(4, 16, 16)
)

# Check CUDA availability once; reused by the DataLoader and the Trainer.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(f"Using device: {device}")

# num_workers=0 keeps data loading in the main process, which is easier to debug.
dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    num_workers=0,
    pin_memory=use_cuda,  # pinned memory only helps host-to-GPU copies
)

# Instantiate the model
model = VMAEProModel(
    in_channels=1,
    embed_dim=768,
    num_heads=12,
    num_encoder_layers=6,
    num_decoder_layers=6,
    mlp_ratio=4.0,
    dropout=0.1,
    num_stages=3,
    patch_size=(4, 16, 16),
    lr=1e-4
)

# Lightning rejects devices=None on the CPU accelerator; 'auto' lets it choose.
trainer = Trainer(
    max_epochs=10,
    accelerator='gpu' if use_cuda else 'cpu',
    devices=1 if use_cuda else 'auto',
    callbacks=[ModelSummary(max_depth=3)],
    log_every_n_steps=20,  # log metrics every 20 training steps
)

# Start training
trainer.fit(model, dataloader)


INFO:pytorch_lightning.utilities.rank_zero:Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Using device: cuda
Using device: cuda


INFO:pytorch_lightning.callbacks.model_summary:
   | Name                                  | Type                   | Params | Mode 
------------------------------------------------------------------------------------------
0  | patch_embedding                       | PatchEmbedding         | 788 K  | train
1  | patch_embedding.proj                  | Conv3d                 | 787 K  | train
2  | patch_embedding.norm                  | LayerNorm              | 1.5 K  | train
3  | encoder                               | HierarchicalViTEncoder | 137 M  | train
4  | encoder.transformer_layers            | ModuleList             | 127 M  | train
5  | encoder.transformer_layers.0          | TransformerBlock       | 7.1 M  | train
6  | encoder.transformer_layers.1          | TransformerBlock       | 7.1 M  | train
7  | encoder.transformer_layers.2          | TransformerBlock       | 7.1 M  | train
8  | encoder.transformer_layers.3          | TransformerBlock       | 7.1 M  | train
9  | encode

Training: |          | 0/? [00:00<?, ?it/s]