In [3]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from functools import partial
from dataclasses import dataclass
from typing import Tuple
import numpy as np
from pytorch_optimizer import load_optimizer

# Enable Flash Attention optimizations
torch.backends.cuda.enable_flash_sdp(True)

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

def triple(t):
    return t if isinstance(t, tuple) else (t, t, t)

# Configuration classes

@dataclass
class MAEConfig:
    """Centralized configuration for ViT3D MAE with ViT3DSegmentation components"""
    # Model architecture
    image_size: int = 256
    image_patch_size: int = 16
    frames: int = 256
    frame_patch_size: int = 16
    channels: int = 1
    
    # Encoder (matching ViT3DSegmentation)
    dim: int = 1024
    depth: int = 12
    heads: int = 8
    dim_head: int = 64
    mlp_dim: int = 1024
    dropout: float = 0.1
    emb_dropout: float = 0.1
    flash_attn_type: str = 'flash_attn'  # 'pytorch' or 'flash_attn'
    
    # Decoder
    decoder_dim: int = 512
    decoder_depth: int = 4
    decoder_heads: int = 8
    
    # Training
    mask_ratio: float = 0.75
    loss_type: str = 'mse'
    norm_pix_loss: bool = False
    
    # ViT3DSegmentation decoder parameters
    voxel_size: int = 4
    
    @property
    def num_patches_h(self):
        return self.image_size // self.image_patch_size
    
    @property
    def num_patches_w(self):
        return self.image_size // self.image_patch_size
    
    @property
    def num_patches_f(self):
        return self.frames // self.frame_patch_size
    
    @property
    def total_patches(self):
        return self.num_patches_h * self.num_patches_w * self.num_patches_f
    
    @property
    def patch_dim(self):
        return self.channels * self.image_patch_size * self.image_patch_size * self.frame_patch_size
    
    @property
    def output_size(self):
        return self.image_size // self.image_patch_size
    
    @property
    def intermediate_size(self):
        return self.output_size * self.voxel_size

# Flash Attention Module (same as ViT3DSegmentation)
class FlashAttention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0., flash_attn_type='pytorch'):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.flash_attn_type = flash_attn_type
        self.dropout = dropout

        self.norm = nn.LayerNorm(dim)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

        # Try to import flash_attn if using that backend
        if flash_attn_type == 'flash_attn':
            try:
                from flash_attn import flash_attn_func
                self.flash_attn_func = flash_attn_func
                print("Using flash_attn package for attention")
            except ImportError:
                print("flash_attn package not found, falling back to PyTorch SDPA")
                self.flash_attn_type = 'pytorch'

    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)

        # Determine dropout probability based on training mode
        dropout_p = self.dropout if self.training else 0.0

        if self.flash_attn_type == 'flash_attn' and hasattr(self, 'flash_attn_func'):
            # Use dedicated flash_attn package
            q = rearrange(q, 'b h n d -> b n h d')
            k = rearrange(k, 'b h n d -> b n h d')
            v = rearrange(v, 'b h n d -> b n h d')
            
            out = self.flash_attn_func(
                q, k, v,
                dropout_p=dropout_p,
                softmax_scale=self.scale,
                causal=False
            )
            
            out = rearrange(out, 'b n h d -> b n (h d)')
            
        else:
            # Use PyTorch's scaled_dot_product_attention (includes Flash Attention optimizations)
            out = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=dropout_p,
                scale=self.scale,
                is_causal=False
            )
            out = rearrange(out, 'b h n d -> b n (h d)')

        return self.to_out(out)

# Feed Forward Module (same as ViT3DSegmentation)
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    
    def forward(self, x):
        return self.net(x)

# Transformer Block (same as ViT3DSegmentation)
class Block(nn.Module):
    def __init__(self, config: MAEConfig):
        super().__init__()
        self.attention = FlashAttention(
            dim=config.dim,
            heads=config.heads,
            dim_head=config.dim_head,
            dropout=config.dropout,
            flash_attn_type=config.flash_attn_type
        )
        self.feed_forward = FeedForward(config.dim, config.mlp_dim, config.dropout)
    
    def forward(self, x):
        x = self.attention(x) + x
        x = self.feed_forward(x) + x
        return x

# Transformer (same as ViT3DSegmentation)
class Transformer(nn.Module):
    def __init__(self, config: MAEConfig):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(config.depth):
            self.layers.append(nn.ModuleList([
                FlashAttention(
                    dim=config.dim, 
                    heads=config.heads, 
                    dim_head=config.dim_head, 
                    dropout=config.dropout, 
                    flash_attn_type=config.flash_attn_type
                ),
                FeedForward(config.dim, config.mlp_dim, config.dropout)
            ]))
    
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

# Main 3D MAE Model with ViT3DSegmentation Encoder and Decoder
class ViT3DMAE_WithViTDecoder(nn.Module):
    """ 3D Masked Autoencoder using ViT3DSegmentation encoder and decoder """
    
    def __init__(self, config: MAEConfig):
        super().__init__()
        self.config = config

        # --------------------------------------------------------------------------
        # ENCODER: EXACT MATCH WITH ViT3DSegmentation
        # --------------------------------------------------------------------------
        
        # Calculate patch dimensions
        patch_dim = config.channels * config.image_patch_size * config.image_patch_size * config.frame_patch_size
        
        # Same patch embedding as ViT3DSegmentation
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (p1 p2 pf c)', 
                     p1 = config.image_patch_size, p2 = config.image_patch_size, pf = config.frame_patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, config.dim),
            nn.LayerNorm(config.dim),
        )

        # Same positional embedding as ViT3DSegmentation (no cls token for segmentation)
        self.pos_embedding = nn.Parameter(torch.randn(1, config.total_patches, config.dim))
        self.dropout = nn.Dropout(config.emb_dropout)

        # Same transformer as ViT3DSegmentation
        self.transformer = Transformer(config)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE DECODER: Attention layers + ViT3DSegmentation decoder
        # --------------------------------------------------------------------------
        
        # Decoder embedding to project encoder output to decoder dimension
        self.decoder_embed = nn.Linear(config.dim, config.decoder_dim, bias=True)

        # Mask token for missing patches
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_dim))

        # Decoder positional embeddings
        self.decoder_pos_embed = nn.Parameter(torch.randn(1, config.total_patches, config.decoder_dim))
        # print(self.pos_embedding.requires_grad,self.mask_token.requires_grad)
        # Create decoder attention blocks
        decoder_config = MAEConfig()
        decoder_config.dim = config.decoder_dim
        decoder_config.depth = config.decoder_depth
        decoder_config.heads = config.decoder_heads
        decoder_config.dim_head = config.dim_head
        decoder_config.mlp_dim = config.decoder_dim * 4
        decoder_config.dropout = config.dropout
        decoder_config.flash_attn_type = config.flash_attn_type

        self.decoder_blocks = nn.ModuleList([])
        for _ in range(config.decoder_depth):
            self.decoder_blocks.append(nn.ModuleList([
                FlashAttention(
                    dim=config.decoder_dim, 
                    heads=config.decoder_heads, 
                    dim_head=config.dim_head, 
                    dropout=config.dropout, 
                    flash_attn_type=config.flash_attn_type
                ),
                FeedForward(config.decoder_dim, config.decoder_dim * 4, config.dropout)
            ]))

        self.decoder_norm = nn.LayerNorm(config.decoder_dim)
        # --------------------------------------------------------------------------
        # ViT3DSegmentation DECODER: Exact match
        # --------------------------------------------------------------------------
        # Token decoder: converts tokens to voxel representations
        self.token_decoder = nn.Linear(config.decoder_dim, config.voxel_size**3)
        # 
        # Conv decoder: processes voxel volume to final output
        self.conv_decoder = nn.Sequential(
            nn.Conv3d(1, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(64, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(32, 1, 3, padding=1)  # Output 1 channel for reconstruction
        )
        # --------------------------------------------------------------------------

        self.norm_pix_loss = config.norm_pix_loss

        self.initialize_weights()

    def initialize_weights(self):
        # Initialize tokens with small random values
        torch.nn.init.normal_(self.mask_token, std=.02)
        
        # Initialize positional embeddings with small random values
        torch.nn.init.normal_(self.pos_embedding, std=.02)
        torch.nn.init.normal_(self.decoder_pos_embed, std=.02)

        # Initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        """
        N, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))
        
        noise = torch.rand(N, L, device=x.device)
        
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        """Forward pass through encoder - EXACT MATCH with ViT3DSegmentation"""
        # Patch embedding (same as ViT3DSegmentation)
        x = self.to_patch_embedding(x)
        b, n, _ = x.shape

        # Add positional embeddings (no cls token, same as ViT3DSegmentation)
        x = x + self.pos_embedding[:, :n]
        x = self.dropout(x)

        # Apply masking
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # Apply transformer (same as ViT3DSegmentation)
        for attn, ff in self.transformer.layers:
            x = attn(x) + x
            x = ff(x) + x

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        """Forward pass through decoder - Attention layers + ViT3DSegmentation decoder"""
        
        # Embed tokens to decoder dimension
        x = self.decoder_embed(x)

        # Append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
        x_ = torch.cat([x, mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle

        # Add positional embeddings
        x = x_ + self.pos_embedding

        # Apply decoder attention blocks
        for attn, ff in self.decoder_blocks:
            x = attn(x) + x
            x = ff(x) + x
        x = self.decoder_norm(x)

        # --------------------------------------------------------------------------
        # ViT3DSegmentation Decoder Components
        # --------------------------------------------------------------------------
        
        batch_size = x.shape[0]
        
        # Apply token decoder: convert each patch token to voxel representation
        decoded_tokens = self.token_decoder(x)  # (B, num_patches, voxel_size^3)
        
        # Reshape tokens to spatial 3D arrangement
        decoded_tokens = decoded_tokens.view(
            batch_size, 
            self.config.output_size, self.config.output_size, self.config.output_size,
            self.config.voxel_size, self.config.voxel_size, self.config.voxel_size
        )
        
        # Rearrange to create intermediate volume
        intermediate_vol = decoded_tokens.permute(0, 1, 4, 2, 5, 3, 6).contiguous()
        intermediate_vol = intermediate_vol.view(
            batch_size, 1, 
            self.config.intermediate_size, 
            self.config.intermediate_size, 
            self.config.intermediate_size
        )
        
        # Apply conv decoder to process the intermediate volume
        decoded_vol = self.conv_decoder(intermediate_vol)
        
        # Interpolate back to input size
        output = F.interpolate(
            decoded_vol, 
            size=(self.config.frames, self.config.image_size, self.config.image_size),
            mode='trilinear', 
            align_corners=False
        )
        
        return output

    def patch_mask_to_volume_mask(self, patch_mask):
        """Convert patch-level mask to volume-level mask"""
        B, L = patch_mask.shape
        
        # Reshape patch mask to spatial dimensions
        mask_3d = patch_mask.view(B, self.config.num_patches_f, self.config.num_patches_h, self.config.num_patches_w)
        
        # Expand each patch to its corresponding voxel region
        mask_vol = mask_3d.unsqueeze(1)  # Add channel dimension
        mask_vol = F.interpolate(
            mask_vol.float(), 
            size=(self.config.frames, self.config.image_size, self.config.image_size),
            mode='nearest'
        )
        
        return mask_vol

    def compute_loss(self, video, pred=None, mask=None, mask_ratio=0.75, 
                    loss_type='huber', norm_pix_loss=False, return_components=False):
        """
        Compute reconstruction loss on full volume (not patches)
        """
        # Run forward pass if predictions not provided
        if pred is None or mask is None:
            _, pred, mask = self.forward(video, mask_ratio)
        
        # Target is the original video
        target = video
        
        # Convert patch mask to volume mask for loss computation
        volume_mask = self.patch_mask_to_volume_mask(mask)
        
        # Normalize target if requested
        if norm_pix_loss:
            mean = target.mean(dim=[2, 3, 4], keepdim=True)
            var = target.var(dim=[2, 3, 4], keepdim=True)
            target = (target - mean) / (var + 1e-6)**0.5
        
        # Compute loss based on type
        if loss_type == 'mse':
            loss = (pred - target) ** 2
        elif loss_type == 'l1':
            loss = torch.abs(pred - target)
        elif loss_type == 'smooth_l1':
            loss = torch.nn.functional.smooth_l1_loss(pred, target, reduction='none')
        elif loss_type == 'huber':
            loss = torch.nn.functional.huber_loss(pred, target, reduction='none', delta=1.0)
        else:
            raise ValueError(f"Unknown loss type: {loss_type}")
        
        # Apply mask - only compute loss on masked regions
        masked_loss = (loss * volume_mask).sum() / (volume_mask.sum() + 1e-8)
        
        # Also compute loss on visible regions for monitoring
        visible_loss = (loss * (1 - volume_mask)).sum() / ((1 - volume_mask).sum() + 1e-8)
        
        if return_components:
            components = {
                'total_loss': masked_loss,
                'masked_loss': masked_loss,
                'visible_loss': visible_loss,
                'mask_ratio': mask.mean(),
                'pred_std': pred.std(),
                'target_std': target.std()
            }
            return masked_loss, components
        
        return masked_loss

    def forward(self, video, mask_ratio=0.75, loss_type='mse', norm_pix_loss=False):
        """MAE forward pass with ViT3DSegmentation components"""
        latent, mask, ids_restore = self.forward_encoder(video, mask_ratio)
        pred = self.forward_decoder(latent, ids_restore)
        loss = self.compute_loss(video, pred, mask, mask_ratio, loss_type, norm_pix_loss=False)
        return loss, pred, mask

    def forward_encoder_only(self, video, mask_ratio=0.0):
        """For feature extraction without reconstruction"""
        return self.forward_encoder(video, mask_ratio)

    def get_patch_embeddings(self, video):
        """
        Return the patch embeddings before masking (same as ViT3DSegmentation)
        Useful for analysis or feature extraction
        """
        x = self.to_patch_embedding(video)
        b, n, _ = x.shape
        x = x + self.pos_embedding[:, :n]
        x = self.dropout(x)
        
        # Apply transformer (same as ViT3DSegmentation)
        for attn, ff in self.transformer.layers:
            x = attn(x) + x
            x = ff(x) + x
            
        return x


# Create a convenience function to match the original MAE config creation
def create_config_vesuvius(input_size=256, patch_size=16, mask_ratio=0.8):
    """Create MAE config matching ViT3DSegmentation parameters"""
    return MAEConfig(
        image_size=input_size,
        image_patch_size=patch_size,
        frames=input_size,
        frame_patch_size=patch_size,
        channels=1,
        dim=512,
        depth=16,
        heads=16,
        dim_head=64,
        mlp_dim=1024,
        dropout=0.1,
        emb_dropout=0.1,
        flash_attn_type='flash_attn',
        decoder_dim=512,
        decoder_depth=6,
        decoder_heads=12,
        mask_ratio=mask_ratio,
        voxel_size=8
    )



In [4]:

import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
import numpy as np
import tifffile
import wandb
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
import segmentation_models_pytorch as smp
from tqdm import tqdm
from pathlib import Path
from volumentations import Compose, Rotate, RandomCropFromBorders, ElasticTransform, Resize, Flip, RandomRotate90, GaussianNoise, RandomGamma,RandomBrightnessContrast,GridDistortion
from scipy import ndimage as ndi
import matplotlib.pyplot as plt
import tempfile
import os

# Import the modified MAE model
# from modified_mae_vit3d import ViT3DMAE_WithViTDecoder, create_config_vesuvius


class VesuviusFullVolumeDataset(Dataset):
    """
    Dataset for full 3D volumes from Vesuvius TIF files.
    For MAE pretraining, we only need images (no labels required).
    """
    def __init__(self, dataset_folder, target_size=256, augment=False, mae_pretraining=True):
        """
        Args:
            dataset_folder: Path to dataset folder containing imagesTr and labelsTr
            target_size: Target size for volumes (will resize if needed)
            augment: Whether to apply data augmentation
            mae_pretraining: If True, only load images (no labels needed)
        """
        self.dataset_folder = Path(dataset_folder)
        self.target_size = target_size
        self.augment = augment
        self.mae_pretraining = mae_pretraining
        
        # Initialize augmentation pipeline
        if self.augment:
            self.aug_pipeline = self._get_augmentation(target_size)
        
        # Get image paths
        self.images_dir = self.dataset_folder / "imagesTr"
        
        # Find all image files
        self.samples = []
        image_files = list(self.images_dir.glob("*.tif"))
        
        print(f"Found {len(image_files)} image files")
        
        for img_path in image_files:
            self.samples.append(img_path)
        
        print(f"Total samples for MAE pretraining: {len(self.samples)}")
    
    def _get_augmentation(self, patch_size):
        """Create volumentations augmentation pipeline"""
        return Compose([
            Rotate((-45, 45), (-45, 45), (-45, 45), p=0.1),
            Flip(0, p=0.25),
            Flip(1, p=0.25),
            Flip(2, p=0.25),
            RandomRotate90(p=0.25),
            GaussianNoise(var_limit=(0, 5), p=0.2),
            GridDistortion(num_steps=3, p=.1),
        ], p=1.0)
    
    def _load_and_preprocess_volume(self, img_path):
        """Load and preprocess a single volume"""
        # Load 3D volume
        image_vol = tifffile.imread(str(img_path))
        image_vol_dtype = image_vol.dtype
        image_vol = image_vol.astype(np.float32)

        # Resize if necessary
        if image_vol.shape[0] != self.target_size:
            image_vol = image_vol[:self.target_size, :self.target_size, :self.target_size]
        
        # Normalize image
        if image_vol_dtype == np.uint16:
            image_vol = image_vol / 65535
        else:
            image_vol = image_vol / 255
        
        return image_vol
    
    def _augment_volume(self, image_vol):
        """Apply 3D augmentations using volumentations"""
        if not self.augment:
            return image_vol
        
        # Convert to uint8 for volumentations (expects 0-255 range)
        image_vol_uint8 = (image_vol * 255).astype(np.uint8)
        
        # Apply augmentation (only to image for MAE)
        data = {'image': image_vol_uint8}
        aug_data = self.aug_pipeline(**data)
        
        # Convert back to float32 and normalize
        image_vol = aug_data['image'].astype(np.float32) / 255.0
        
        return image_vol
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        img_path = self.samples[idx]
        
        # Load and preprocess
        image_vol = self._load_and_preprocess_volume(img_path)
        
        # Apply augmentation
        image_vol = self._augment_volume(image_vol)
        
        # Convert to tensor
        image_tensor = torch.FloatTensor(image_vol).unsqueeze(0)  # Add channel dimension
        
        # For MAE, we return the same image as both input and target
        return image_tensor, image_tensor


class VesuviusMAE_ViT3D_PLModel(pl.LightningModule):
    """MAE Pretraining Model using ViT3DSegmentation encoder and decoder"""
    def __init__(self, input_size=256, patch_size=16, lr=1e-4, weight_decay=1e-4, mask_ratio=0.75):
        super(VesuviusMAE_ViT3D_PLModel, self).__init__()
        self.save_hyperparameters()
        
        # Add storage for first validation sample
        self.first_val_sample = None
        
        # Create MAE configuration matching ViT3DSegmentation
        self.mae_config = create_config_vesuvius(
            input_size=input_size,
            patch_size=patch_size,
            mask_ratio=mask_ratio
        )
        
        # Create modified MAE model with ViT3DSegmentation components
        self.model = ViT3DMAE_WithViTDecoder(self.mae_config)
        
        print(f"Modified MAE Model initialized:")
        print(f"  Input size: {input_size}x{input_size}x{input_size}")
        print(f"  Patch size: {patch_size}x{patch_size}x{patch_size}")
        print(f"  Mask ratio: {mask_ratio}")
        print(f"  Encoder dim: {self.mae_config.dim}")
        print(f"  Decoder dim: {self.mae_config.decoder_dim}")
        print(f"  Using ViT3DSegmentation encoder and decoder components")
    
    def forward(self, x, mask_ratio=None):
        """Forward pass through modified MAE"""
        if mask_ratio is None:
            mask_ratio = self.mae_config.mask_ratio
        
        loss, pred, mask = self.model(x, mask_ratio=mask_ratio)
        
        return loss, pred, mask
    
    def training_step(self, batch, batch_idx):
        x, _ = batch  # Ignore labels for MAE pretraining
        
        # MAE forward pass
        
        loss, pred, mask = self(x)
        # print(pred.max().item(),pred.min().item(),x.max().item(),x.min().item())
        # Calculate reconstruction metrics
        with torch.no_grad():
            # Calculate reconstruction quality metrics
            pred_error = ((pred - x) ** 2).mean()
            
            # Mask statistics
            mask_ratio_actual = mask.float().mean()
            
            # Volume-level metrics
            volume_mask = self.model.patch_mask_to_volume_mask(mask)
            masked_region_error = ((pred - x) ** 2 * volume_mask).sum() / (volume_mask.sum() + 1e-8)
            visible_region_error = ((pred - x) ** 2 * (1 - volume_mask)).sum() / ((1 - volume_mask).sum() + 1e-8)
        
        # Log metrics
        self.log("train/mae_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train/recon_error", pred_error, on_step=True, on_epoch=True)
        self.log("train/masked_error", masked_region_error, on_step=True, on_epoch=True)
        self.log("train/visible_error", visible_region_error, on_step=True, on_epoch=True)
        self.log("train/mask_ratio", mask_ratio_actual, on_step=True, on_epoch=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, _ = batch  # Ignore labels for MAE pretraining
        
        # MAE forward pass
        loss, pred, mask = self(x)
        
        # Calculate reconstruction metrics
        with torch.no_grad():
            pred_error = ((pred - x) ** 2).mean()
            mask_ratio_actual = mask.float().mean()
            
            # Volume-level metrics
            volume_mask = self.model.patch_mask_to_volume_mask(mask)
            masked_region_error = ((pred - x) ** 2 * volume_mask).sum() / (volume_mask.sum() + 1e-8)
            visible_region_error = ((pred - x) ** 2 * (1 - volume_mask)).sum() / ((1 - volume_mask).sum() + 1e-8)
            
            # Reconstruction quality score (higher is better)
            ssim_score = 1 - masked_region_error.item()  # Simplified SSIM-like metric
        
        # Log metrics
        self.log("val/mae_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val/recon_error", pred_error, on_step=False, on_epoch=True)
        self.log("val/masked_error", masked_region_error, on_step=False, on_epoch=True)
        self.log("val/visible_error", visible_region_error, on_step=False, on_epoch=True)
        self.log("val/mask_ratio", mask_ratio_actual, on_step=False, on_epoch=True)
        self.log("val/ssim_score", ssim_score, on_step=False, on_epoch=True)
        
        # Store first validation sample for visualization
        if batch_idx == 0 and self.first_val_sample is None:
            self.first_val_sample = {
                'original': x[0:1].cpu(),
                'reconstructed': pred[0:1].cpu(),
                'mask': mask[0:1].cpu(),
                'volume_mask': volume_mask[0:1].cpu()
            }
        
        return loss
    
    def visualize_mae_reconstruction_3d(self, original, reconstructed, volume_mask, num_slices=3):
        """Visualize MAE reconstruction results in 3D slices"""
        # Remove batch and channel dimensions
        original = original.squeeze().numpy()
        reconstructed = reconstructed.squeeze().numpy()
        volume_mask = volume_mask.squeeze().numpy()
        
        d, h, w = original.shape
        slice_indices = np.linspace(0, d-1, num_slices, dtype=int)
        
        fig, axes = plt.subplots(4, num_slices, figsize=(15, 10))
        if num_slices == 1:
            axes = axes.reshape(-1, 1)
        
        for i, slice_idx in enumerate(slice_indices):
            # Original
            axes[0, i].imshow(original[slice_idx], cmap='gray', vmin=0, vmax=1)
            axes[0, i].set_title(f'Original Z={slice_idx}')
            axes[0, i].axis('off')
            
            # Reconstructed
            axes[1, i].imshow(reconstructed[slice_idx], cmap='gray', vmin=0, vmax=1)
            axes[1, i].set_title(f'Reconstructed Z={slice_idx}')
            axes[1, i].axis('off')
            
            # Volume mask (masked regions in red)
            axes[2, i].imshow(volume_mask[slice_idx], cmap='gray', vmin=0, vmax=1)
            axes[2, i].set_title(f'Mask Z={slice_idx}')
            axes[2, i].axis('off')
            
            # Difference (error visualization)
            diff = np.abs(original[slice_idx] - reconstructed[slice_idx])
            axes[3, i].imshow(diff, cmap='hot', vmin=0, vmax=0.5)
            axes[3, i].set_title(f'Abs Difference Z={slice_idx}')
            axes[3, i].axis('off')
        
        plt.suptitle('MAE Reconstruction with ViT3D Components', fontsize=14)
        plt.tight_layout()
        
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
            plt.savefig(tmp.name, dpi=100, bbox_inches='tight')
            tmp_path = tmp.name
        plt.close()
        
        return tmp_path
    
    def on_validation_epoch_end(self):
        if self.first_val_sample is not None:
            original = self.first_val_sample['original']
            reconstructed = self.first_val_sample['reconstructed']
            mask = self.first_val_sample['mask']
            volume_mask = self.first_val_sample['volume_mask']
            
            viz_path = self.visualize_mae_reconstruction_3d(original, reconstructed, volume_mask)
            
            if self.logger and hasattr(self.logger, 'experiment'):
                self.logger.experiment.log({
                    "val/reconstruction_viz": wandb.Image(viz_path),
                    "val/mask_ratio_viz": mask.float().mean().item(),
                    "val/volume_mask_ratio": volume_mask.float().mean().item()
                })
            
            os.unlink(viz_path)
            
        self.first_val_sample = None
    
    def configure_optimizers(self):
        # Use AdamW with cosine annealing for MAE pretraining
        optimizer = AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        
        def lr_lambda(epoch):
            warmup_epochs = 1
            if epoch < warmup_epochs:
                return epoch / warmup_epochs
            return 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / (self.trainer.max_epochs - warmup_epochs)))
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)        
        return [optimizer], [scheduler]


class CFG:
    """Configuration class for modified MAE pretraining"""
    # Dataset
    dataset_folder = './Dataset091'  # Change this to your dataset path
    
    # Model
    input_size = 256      # Size of input volumes
    patch_size = 16       # Patch size for ViT
    mask_ratio = 0.75     # Masking ratio for MAE
    
    # Training
    batch_size = 8        # Small batch size due to large volumes
    num_workers = 32      # Reduced workers for large volumes
    lr = 3e-4             # Learning rate for MAE
    weight_decay = 0.00005   # Weight decay for MAE
    epochs = 500          # More epochs for pretraining
    
    # Data split
    train_ratio = 0.99    # Use more data for pretraining
    
    # Experiment
    exp_name = 'vesuvius_mae_vit3d_hybrid'
    
    # Paths
    model_dir = './outputs/vesuvius_mae_vit3d/'
    
    def __init__(self):
        os.makedirs(self.model_dir, exist_ok=True)


def split_dataset(dataset, train_ratio=0.8):
    """Split dataset into train and validation"""
    num_samples = len(dataset)
    num_train = int(num_samples * train_ratio)
    
    indices = torch.randperm(num_samples)
    train_indices = indices[:num_train]
    val_indices = indices[num_train:]
    
    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    val_dataset = torch.utils.data.Subset(dataset, val_indices)
    
    return train_dataset, val_dataset


def count_dataset_stats(dataset_folder):
    """Count basic statistics about the dataset"""
    dataset_folder = Path(dataset_folder)
    images_dir = dataset_folder / "imagesTr"
    
    image_files = list(images_dir.glob("*.tif"))
    print(f"Found {len(image_files)} image files")
    
    # Sample a few files to check shapes and statistics
    for i, img_path in enumerate(image_files[:3]):
        try:
            image_vol = tifffile.imread(str(img_path))
            print(f"  {img_path.stem}:")
            print(f"    Image shape: {image_vol.shape}")
            print(f"    Image range: [{image_vol.min()}, {image_vol.max()}]")
        except Exception as e:
            print(f"    Error reading {img_path}: {e}")


def test_model_compatibility():
    """Test that the modified MAE model works correctly"""
    print("Testing modified MAE model...")
    
    # Create config and model
    config = create_config_vesuvius(input_size=256, patch_size=16, mask_ratio=0.75)
    model = ViT3DMAE_WithViTDecoder(config)
    
    # Test with dummy data
    batch_size = 1
    dummy_input = torch.randn(batch_size, 1, 256, 256, 256)
    
    print(f"Testing with input shape: {dummy_input.shape}")
    
    # Test forward pass
    with torch.no_grad():
        loss, pred, mask = model(dummy_input)
        print(f"✓ Forward pass successful")
        print(f"  Loss: {loss.item():.4f}")
        print(f"  Output shape: {pred.shape}")
        print(f"  Mask ratio: {mask.float().mean().item():.3f}")
        
        # Test encoder only
        encoded, mask_enc, ids_restore = model.forward_encoder_only(dummy_input, mask_ratio=0.75)
        print(f"✓ Encoder-only pass successful")
        print(f"  Encoded shape: {encoded.shape}")
        
        # Test patch embeddings
        patch_emb = model.get_patch_embeddings(dummy_input)
        print(f"✓ Patch embeddings successful")
        print(f"  Patch embeddings shape: {patch_emb.shape}")
    
    total_params = sum(p.numel() for p in model.parameters())
    print(f"✓ Total parameters: {total_params:,}")
    print("Model compatibility test passed!")
    return True


def main():
    """Main MAE pretraining function with ViT3D components"""
    cfg = CFG()
    
    # Count dataset statistics
    print("\nDataset statistics:")
    count_dataset_stats(cfg.dataset_folder)
    
    # Create full dataset for MAE pretraining
    print(f"\nCreating MAE dataset with target size {cfg.input_size}...")
    full_dataset = VesuviusFullVolumeDataset(
        dataset_folder=cfg.dataset_folder,
        target_size=cfg.input_size,
        augment=True,
        mae_pretraining=True  # Only load images, no labels
    )
    
    # Split into train and validation
    train_dataset, val_dataset = split_dataset(full_dataset, cfg.train_ratio)
    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    
    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=True,
        drop_last=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True,
        drop_last=False
    ) if len(val_dataset) > 0 else None
    
    # Create modified MAE model
    model = VesuviusMAE_ViT3D_PLModel(
        input_size=cfg.input_size,
        patch_size=cfg.patch_size,
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
        mask_ratio=cfg.mask_ratio
    )
    
    # Print model info
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\nModified MAE Model parameters: {total_params:,}")
    
    # Memory estimate
    batch_memory_gb = (cfg.batch_size * 1 * cfg.input_size**3 * 4) / (1024**3)
    print(f"Estimated memory per batch: {batch_memory_gb:.2f} GB")
    
    # Setup logging
    wandb_logger = WandbLogger(
        project="vesuvius_mae_vit3d_hybrid",
        name=f"{cfg.exp_name}_size_{cfg.input_size}_patch_{cfg.patch_size}_mask_{cfg.mask_ratio}"
    )
    
    # Setup callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath=cfg.model_dir,
        filename=f'vesuvius_mae_vit3d_{{epoch:02d}}_{{val_mae_loss:.4f}}',
        monitor='val/mae_loss' if val_loader else 'train/mae_loss',
        mode='min',
        save_top_k=5,
        save_last=True
    )
    
    # Create trainer
    trainer = pl.Trainer(
        max_epochs=cfg.epochs,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
        logger=wandb_logger,
        callbacks=[checkpoint_callback],
        precision='16-mixed' if torch.cuda.is_available() else 32,
        gradient_clip_val=1.0,
        log_every_n_steps=10,
        accumulate_grad_batches=32,  # Gradient accumulation
    )
    
    # Train
    print(f"\nStarting modified MAE pretraining for {cfg.epochs} epochs...")
    print(f"Mask ratio: {cfg.mask_ratio}")
    print(f"Using ViT3DSegmentation encoder and decoder components")
    print(f"Effective batch size: {cfg.batch_size * 8}")  # Due to gradient accumulation
    
    trainer.fit(model, train_loader, val_loader)
    
    print("Modified MAE pretraining completed!")
    print(f"Saved model checkpoints in: {cfg.model_dir}")
    wandb.finish()


if __name__ == "__main__":
    main()


Dataset statistics:
Found 47896 image files
  chunk_43_17_12:
    Image shape: (256, 256, 256)
    Image range: [0, 255]
  chunk_2_19_21:
    Image shape: (256, 256, 256)
    Image range: [0, 255]
  1451_chunk_14_10_14:
    Image shape: (256, 256, 256)
    Image range: [0, 255]

Creating MAE dataset with target size 256...
Found 47896 image files
Total samples for MAE pretraining: 47896
Training samples: 47417
Validation samples: 479
Using flash_attn package for attention
Using flash_attn package for attention
Using flash_attn package for attention
Using flash_attn package for attention
Using flash_attn package for attention
Using flash_attn package for attention
Using flash_attn package for attention
Using flash_attn package for attention
Using flash_attn package for attention
Using flash_attn package for attention
Using flash_attn package for attention
Using flash_attn package for attention
Using flash_attn package for attention
Using flash_attn package for attention
Using flash_att

Using 16bit Automatic Mixed Precision (AMP)
  scaler = torch.cuda.amp.GradScaler()
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                    | Params
--------------------------------------------------
0 | model | ViT3DMAE_WithViTDecoder | 79.3 M
--------------------------------------------------
79.3 M    Trainable params
0         Non-trainable params
79.3 M    Total params
317.336   Total estimated model params size (MB)



Starting modified MAE pretraining for 500 epochs...
Mask ratio: 0.75
Using ViT3DSegmentation encoder and decoder components
Effective batch size: 64


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Modified MAE pretraining completed!
Saved model checkpoints in: ./outputs/vesuvius_mae_vit3d/


Exception in thread Thread-37 (_pin_memory_loop):
Traceback (most recent call last):
  File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3/dist-packages/torch/utils/data/_utils/pin_memory.py", line 59, in _pin_memory_loop
    do_one_step()
  File "/usr/lib/python3/dist-packages/torch/utils/data/_utils/pin_memory.py", line 35, in do_one_step
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
  File "/usr/lib/python3.10/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
  File "/usr/lib/python3/dist-packages/torch/multiprocessing/reductions.py", line 541, in rebuild_storage_fd
    fd = df.detach()
  File "/usr/lib/python3.10/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/usr/lib/python3.10/multiprocessing/resourc

0,1
epoch,▁▁▁▁▁▁▁▁▂▂▂▂▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▅▅▆▆▆▆▇▇▇███
train/mae_loss_epoch,█▄▃▃▂▂▁▁
train/mae_loss_step,█▅▇▇▄▄▂▃▅▅▂▃▁▃▃▃▂▂▃▃▂▄▁▂▂▂▂▂▁▂▃▂▂▂▂▃▃▁▂▂
train/mask_ratio_epoch,▁▁▁▁▁▁▁▁
train/mask_ratio_step,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/masked_error_epoch,█▄▃▃▂▂▁▁
train/masked_error_step,█▆▇▄▄▃▁▄▄▁▃▂▃▂▁▂▃▄▂▃▂▃▃▂▁▁▂▂▃▁▂▃▂▂▃▂▂▂▂▃
train/recon_error_epoch,█▄▃▂▂▂▁▁
train/recon_error_step,▅▄█▄▃▃▃▃▃▂▄▃▂▃▃▂▂▂▃▂▂▂▃▃▂▂▁▁▂▁▂▁▂▁▁▂▂▂▂▃
train/visible_error_epoch,█▄▃▂▂▁▁▁

0,1
epoch,8.0
train/mae_loss_epoch,0.01113
train/mae_loss_step,0.01104
train/mask_ratio_epoch,0.75
train/mask_ratio_step,0.75
train/masked_error_epoch,0.02225
train/masked_error_step,0.02207
train/recon_error_epoch,0.02158
train/recon_error_step,0.02143
train/visible_error_epoch,0.01958
