# In [1]:
import torch
from torch import nn
import torch.nn.functional as F
torch.backends.cuda.enable_flash_sdp(True)

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from dataclasses import dataclass
from typing import Tuple

# helpers

def pair(t):
    """Return ``t`` as-is when it is already a tuple, else ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)

def triple(t):
    """Return ``t`` as-is when it is already a tuple, else ``(t, t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t,) * 3

# Configuration classes

@dataclass
class ModelConfig:
    """Hyper-parameters of the 3D ViT together with its derived patch-grid sizes.

    The grid properties use floor division, so they are only meaningful when the
    image/frame sizes are multiples of the corresponding patch sizes.
    """
    # Input geometry: square frames of `image_size`, `frames` slices deep.
    image_size: int = 64
    image_patch_size: int = 16
    frames: int = 16
    frame_patch_size: int = 4
    # Width of the per-token output produced by the head.
    num_classes: int = 2
    # Transformer sizing.
    dim: int = 384
    depth: int = 6
    heads: int = 6
    mlp_dim: int = 1024
    channels: int = 1
    dim_head: int = 64
    dropout: float = 0.1
    emb_dropout: float = 0.1
    # Attention backend: 'pytorch' (SDPA) or 'flash_attn' (external package).
    flash_attn_type: str = 'pytorch'

    @property
    def num_patches_h(self):
        """Patch count along the height axis."""
        return self.image_size // self.image_patch_size

    @property
    def num_patches_w(self):
        """Patch count along the width axis (frames are square, so same as height)."""
        return self.num_patches_h

    @property
    def num_patches_f(self):
        """Patch count along the frame (depth) axis."""
        return self.frames // self.frame_patch_size

    @property
    def total_patches(self):
        """Total number of patch tokens per volume."""
        per_frame_slab = self.num_patches_h * self.num_patches_w
        return per_frame_slab * self.num_patches_f

    @property
    def expected_output_shape(self):
        """Callable mapping a batch size to the expected model output shape."""
        def shape_for(batch_size):
            grid = (self.num_patches_f, self.num_patches_h, self.num_patches_w)
            return (batch_size, self.num_classes, *grid)
        return shape_for

@dataclass
class TestConfig:
    """Settings for smoke tests and benchmarks of the model."""
    batch_size: int = 2
    benchmark_iterations: int = 20
    warmup_iterations: int = 5
    benchmark_seq_length: int = 1024
    # Settings for the short training test
    train_steps: int = 5
    learning_rate: float = 1e-4
    use_mixed_precision: bool = True

    def get_test_input_shape(self, config: ModelConfig):
        """Return the ``(batch, channels, frames, height, width)`` shape of a test input."""
        side = config.image_size  # height and width are both image_size
        return (self.batch_size, config.channels, config.frames, side, side)

class FeedForward(nn.Module):
    """Pre-norm MLP: LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout.

    Input and output both have ``dim`` features on the last axis; the residual
    connection is applied by the caller.
    """
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # Order matters: it fixes the state_dict keys (net.0, net.1, net.4).
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        out = self.net(x)
        return out

class FlashAttention(nn.Module):
    """Pre-norm multi-head self-attention with a selectable fused backend.

    ``flash_attn_type='flash_attn'`` uses ``flash_attn.flash_attn_func`` when the
    package can be imported; otherwise (and for any other value) PyTorch's
    ``F.scaled_dot_product_attention`` is used. Attention is non-causal and
    attention dropout is only active in training mode.
    """
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., flash_attn_type='pytorch'):
        super().__init__()
        inner_dim = heads * dim_head
        # A single head that is already `dim` wide needs no output projection.
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.flash_attn_type = flash_attn_type
        self.dropout = dropout

        self.norm = nn.LayerNorm(dim)

        # One fused projection yields q, k and v stacked on the feature axis.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        if needs_projection:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )
        else:
            self.to_out = nn.Identity()

        # Resolve the optional flash_attn backend once, at construction time.
        if flash_attn_type == 'flash_attn':
            try:
                from flash_attn import flash_attn_func
            except ImportError:
                print("flash_attn package not found, falling back to PyTorch SDPA")
                self.flash_attn_type = 'pytorch'
            else:
                self.flash_attn_func = flash_attn_func
                print("Using flash_attn package for attention")

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (batch, tokens, dim)."""
        normed = self.norm(x)
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h = self.heads)
            for part in self.to_qkv(normed).chunk(3, dim = -1)
        )

        # Attention dropout must be disabled outside of training.
        dropout_p = 0.0 if not self.training else self.dropout

        use_flash_pkg = self.flash_attn_type == 'flash_attn' and hasattr(self, 'flash_attn_func')
        if not use_flash_pkg:
            # PyTorch SDPA consumes (batch, heads, tokens, head_dim).
            attended = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=dropout_p,
                scale=self.scale,
                is_causal=False
            )
            merged = rearrange(attended, 'b h n d -> b n (h d)')
            return self.to_out(merged)

        # flash_attn consumes (batch, tokens, heads, head_dim).
        q, k, v = (rearrange(t, 'b h n d -> b n h d') for t in (q, k, v))
        attended = self.flash_attn_func(
            q, k, v,
            dropout_p=dropout_p,
            softmax_scale=self.scale,
            causal=False
        )
        merged = rearrange(attended, 'b n h d -> b n (h d)')
        return self.to_out(merged)

class Transformer(nn.Module):
    """Stack of ``config.depth`` residual (attention, feed-forward) layer pairs."""
    def __init__(self, config: ModelConfig):
        super().__init__()

        def build_layer():
            # Attention is created before the MLP so parameter order stays fixed.
            attention = FlashAttention(
                dim=config.dim,
                heads=config.heads,
                dim_head=config.dim_head,
                dropout=config.dropout,
                flash_attn_type=config.flash_attn_type
            )
            feed_forward = FeedForward(config.dim, config.mlp_dim, config.dropout)
            return nn.ModuleList([attention, feed_forward])

        self.layers = nn.ModuleList([build_layer() for _ in range(config.depth)])

    def forward(self, x):
        """Run ``x`` through every layer; each sub-block adds its output to its input."""
        for attn, ff in self.layers:
            attn_in = x
            x = attn(attn_in) + attn_in
            ff_in = x
            x = ff(ff_in) + ff_in
        return x

class ViT3DSegmentation(nn.Module):
    """3D Vision Transformer producing one ``num_classes`` vector per patch.

    The input volume is split into non-overlapping patches, every patch is
    linearly embedded, a learned positional embedding is added and the tokens
    are run through the transformer. The head output is laid back out on the
    patch grid as ``(batch, num_classes, f, h, w)``.
    """
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # Height and width share image_size / image_patch_size, so one check covers both.
        assert config.image_size % config.image_patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        assert config.frames % config.frame_patch_size == 0, 'Frames must be divisible by frame patch size'

        spatial_patch = config.image_patch_size
        depth_patch = config.frame_patch_size
        # Number of raw values flattened into a single patch token.
        patch_dim = config.channels * spatial_patch * spatial_patch * depth_patch

        patchify = Rearrange(
            'b c (f pf) (h p1) (w p2) -> b (f h w) (p1 p2 pf c)',
            p1 = spatial_patch, p2 = spatial_patch, pf = depth_patch,
        )
        self.to_patch_embedding = nn.Sequential(
            patchify,
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, config.dim),
            nn.LayerNorm(config.dim),
        )

        # Segmentation needs no cls token: one positional embedding per patch.
        self.pos_embedding = nn.Parameter(torch.randn(1, config.total_patches, config.dim))
        self.dropout = nn.Dropout(config.emb_dropout)

        self.transformer = Transformer(config)

        # Head applied independently to every patch token.
        self.segmentation_head = nn.Sequential(
            nn.LayerNorm(config.dim),
            nn.Linear(config.dim, config.num_classes)
        )

    def forward(self, video):
        """Map ``video`` (batch, channels, frames, height, width) to per-patch outputs."""
        tokens = self.get_patch_embeddings(video)

        # (batch, num_patches, dim) -> (batch, num_patches, num_classes)
        per_token = self.segmentation_head(tokens)

        # Unflatten the token axis back onto the (f, h, w) patch grid.
        cfg = self.config
        return rearrange(
            per_token, 'b (f h w) c -> b c f h w',
            f=cfg.num_patches_f, h=cfg.num_patches_h, w=cfg.num_patches_w,
        )

    def get_patch_embeddings(self, video):
        """
        Return the transformer output tokens, i.e. the features fed to the
        segmentation head. Useful for analysis or feature extraction.
        """
        x = self.to_patch_embedding(video)
        _, n, _ = x.shape
        x += self.pos_embedding[:, :n]
        x = self.dropout(x)
        return self.transformer(x)

# In [2]:


# import torch
from torch import nn
import torch.nn.functional as F
torch.backends.cuda.enable_flash_sdp(True)

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from dataclasses import dataclass
from typing import Tuple

# helpers

def pair(t):
    """Return ``t`` as-is when it is already a tuple, else ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)

def triple(t):
    """Return ``t`` as-is when it is already a tuple, else ``(t, t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t,) * 3

# Configuration classes

@dataclass
class ModelConfig:
    """Hyper-parameters of the 3D ViT together with its derived patch-grid sizes.

    The grid properties use floor division, so they are only meaningful when the
    image/frame sizes are multiples of the corresponding patch sizes.
    """
    # Input geometry: square frames of `image_size`, `frames` slices deep.
    image_size: int = 64
    image_patch_size: int = 16
    frames: int = 16
    frame_patch_size: int = 4
    # Width of the per-token output produced by the head.
    num_classes: int = 2
    # Transformer sizing.
    dim: int = 384
    depth: int = 6
    heads: int = 6
    mlp_dim: int = 1024
    channels: int = 1
    dim_head: int = 64
    dropout: float = 0.1
    emb_dropout: float = 0.1
    # Attention backend: 'pytorch' (SDPA) or 'flash_attn' (external package).
    flash_attn_type: str = 'pytorch'

    @property
    def num_patches_h(self):
        """Patch count along the height axis."""
        return self.image_size // self.image_patch_size

    @property
    def num_patches_w(self):
        """Patch count along the width axis (frames are square, so same as height)."""
        return self.num_patches_h

    @property
    def num_patches_f(self):
        """Patch count along the frame (depth) axis."""
        return self.frames // self.frame_patch_size

    @property
    def total_patches(self):
        """Total number of patch tokens per volume."""
        per_frame_slab = self.num_patches_h * self.num_patches_w
        return per_frame_slab * self.num_patches_f

    @property
    def expected_output_shape(self):
        """Callable mapping a batch size to the expected model output shape."""
        def shape_for(batch_size):
            grid = (self.num_patches_f, self.num_patches_h, self.num_patches_w)
            return (batch_size, self.num_classes, *grid)
        return shape_for

@dataclass
class TestConfig:
    """Settings for smoke tests and benchmarks of the model."""
    batch_size: int = 2
    benchmark_iterations: int = 20
    warmup_iterations: int = 5
    benchmark_seq_length: int = 1024
    # Settings for the short training test
    train_steps: int = 5
    learning_rate: float = 1e-4
    use_mixed_precision: bool = True

    def get_test_input_shape(self, config: ModelConfig):
        """Return the ``(batch, channels, frames, height, width)`` shape of a test input."""
        side = config.image_size  # height and width are both image_size
        return (self.batch_size, config.channels, config.frames, side, side)


class FeedForward(nn.Module):
    """Pre-norm MLP: LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout.

    Input and output both have ``dim`` features on the last axis; the residual
    connection is applied by the caller.
    """
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # Order matters: it fixes the state_dict keys (net.0, net.1, net.4).
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        out = self.net(x)
        return out

class FlashAttention(nn.Module):
    """Pre-norm multi-head self-attention with a selectable fused backend.

    ``flash_attn_type='flash_attn'`` uses ``flash_attn.flash_attn_func`` when the
    package can be imported; otherwise (and for any other value) PyTorch's
    ``F.scaled_dot_product_attention`` is used. Attention is non-causal and
    attention dropout is only active in training mode.
    """
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., flash_attn_type='pytorch'):
        super().__init__()
        inner_dim = heads * dim_head
        # A single head that is already `dim` wide needs no output projection.
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.flash_attn_type = flash_attn_type
        self.dropout = dropout

        self.norm = nn.LayerNorm(dim)

        # One fused projection yields q, k and v stacked on the feature axis.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        if needs_projection:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )
        else:
            self.to_out = nn.Identity()

        # Resolve the optional flash_attn backend once, at construction time.
        if flash_attn_type == 'flash_attn':
            try:
                from flash_attn import flash_attn_func
            except ImportError:
                print("flash_attn package not found, falling back to PyTorch SDPA")
                self.flash_attn_type = 'pytorch'
            else:
                self.flash_attn_func = flash_attn_func
                print("Using flash_attn package for attention")

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (batch, tokens, dim)."""
        normed = self.norm(x)
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h = self.heads)
            for part in self.to_qkv(normed).chunk(3, dim = -1)
        )

        # Attention dropout must be disabled outside of training.
        dropout_p = 0.0 if not self.training else self.dropout

        use_flash_pkg = self.flash_attn_type == 'flash_attn' and hasattr(self, 'flash_attn_func')
        if not use_flash_pkg:
            # PyTorch SDPA consumes (batch, heads, tokens, head_dim).
            attended = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=dropout_p,
                scale=self.scale,
                is_causal=False
            )
            merged = rearrange(attended, 'b h n d -> b n (h d)')
            return self.to_out(merged)

        # flash_attn consumes (batch, tokens, heads, head_dim).
        q, k, v = (rearrange(t, 'b h n d -> b n h d') for t in (q, k, v))
        attended = self.flash_attn_func(
            q, k, v,
            dropout_p=dropout_p,
            softmax_scale=self.scale,
            causal=False
        )
        merged = rearrange(attended, 'b n h d -> b n (h d)')
        return self.to_out(merged)

class Transformer(nn.Module):
    """Stack of ``config.depth`` residual (attention, feed-forward) layer pairs."""
    def __init__(self, config: ModelConfig):
        super().__init__()

        def build_layer():
            # Attention is created before the MLP so parameter order stays fixed.
            attention = FlashAttention(
                dim=config.dim,
                heads=config.heads,
                dim_head=config.dim_head,
                dropout=config.dropout,
                flash_attn_type=config.flash_attn_type
            )
            feed_forward = FeedForward(config.dim, config.mlp_dim, config.dropout)
            return nn.ModuleList([attention, feed_forward])

        self.layers = nn.ModuleList([build_layer() for _ in range(config.depth)])

    def forward(self, x):
        """Run ``x`` through every layer; each sub-block adds its output to its input."""
        for attn, ff in self.layers:
            attn_in = x
            x = attn(attn_in) + attn_in
            ff_in = x
            x = ff(ff_in) + ff_in
        return x

class ViT3DSegmentation(nn.Module):
    """3D Vision Transformer producing one ``num_classes`` vector per patch.

    The input volume is split into non-overlapping patches, every patch is
    linearly embedded, a learned positional embedding is added and the tokens
    are run through the transformer. The head output is laid back out on the
    patch grid as ``(batch, num_classes, f, h, w)``.
    """
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # Height and width share image_size / image_patch_size, so one check covers both.
        assert config.image_size % config.image_patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        assert config.frames % config.frame_patch_size == 0, 'Frames must be divisible by frame patch size'

        spatial_patch = config.image_patch_size
        depth_patch = config.frame_patch_size
        # Number of raw values flattened into a single patch token.
        patch_dim = config.channels * spatial_patch * spatial_patch * depth_patch

        patchify = Rearrange(
            'b c (f pf) (h p1) (w p2) -> b (f h w) (p1 p2 pf c)',
            p1 = spatial_patch, p2 = spatial_patch, pf = depth_patch,
        )
        self.to_patch_embedding = nn.Sequential(
            patchify,
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, config.dim),
            nn.LayerNorm(config.dim),
        )

        # Segmentation needs no cls token: one positional embedding per patch.
        self.pos_embedding = nn.Parameter(torch.randn(1, config.total_patches, config.dim))
        self.dropout = nn.Dropout(config.emb_dropout)

        self.transformer = Transformer(config)

        # Head applied independently to every patch token.
        self.segmentation_head = nn.Sequential(
            nn.LayerNorm(config.dim),
            nn.Linear(config.dim, config.num_classes)
        )

    def forward(self, video):
        """Map ``video`` (batch, channels, frames, height, width) to per-patch outputs."""
        tokens = self.get_patch_embeddings(video)

        # (batch, num_patches, dim) -> (batch, num_patches, num_classes)
        per_token = self.segmentation_head(tokens)

        # Unflatten the token axis back onto the (f, h, w) patch grid.
        cfg = self.config
        return rearrange(
            per_token, 'b (f h w) c -> b c f h w',
            f=cfg.num_patches_f, h=cfg.num_patches_h, w=cfg.num_patches_w,
        )

    def get_patch_embeddings(self, video):
        """
        Return the transformer output tokens, i.e. the features fed to the
        segmentation head. Useful for analysis or feature extraction.
        """
        x = self.to_patch_embedding(video)
        _, n, _ = x.shape
        x += self.pos_embedding[:, :n]
        x = self.dropout(x)
        return self.transformer(x)
import os
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
import numpy as np
import tifffile
import wandb
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
import segmentation_models_pytorch as smp
from tqdm import tqdm
from pathlib import Path
from volumentations import Compose, Rotate, RandomCropFromBorders, ElasticTransform, Resize, Flip, RandomRotate90, GaussianNoise, RandomGamma,RandomBrightnessContrast,GridDistortion
from scipy import ndimage as ndi


class VesuviusFullVolumeDataset(Dataset):
    """
    Dataset for full 3D volumes from Vesuvius TIF files.
    Each image file with a matching label file becomes a single sample.

    Images are read from ``<dataset_folder>/imagesTr/*_0000.tif`` and labels from
    ``<dataset_folder>/labelsTr/<name>.tif`` where ``<name>`` is the image file
    name without its trailing ``_0000``.
    """
    # Trailing marker that image file names carry and label file names do not.
    _IMAGE_SUFFIX = "_0000"

    def __init__(self, dataset_folder, target_size=320, augment=False):
        """
        Args:
            dataset_folder: Path to dataset folder containing imagesTr and labelsTr
            target_size: Maximum edge length; larger volumes are cropped to it
            augment: Whether to apply data augmentation
        """
        self.dataset_folder = Path(dataset_folder)
        self.target_size = target_size
        self.augment = augment

        # The pipeline is only built when needed; _augment_volume checks the flag first.
        if self.augment:
            self.aug_pipeline = self._get_augmentation(target_size)

        self.images_dir = self.dataset_folder / "imagesTr"
        self.labels_dir = self.dataset_folder / "labelsTr"

        # Sorted so the sample order does not depend on filesystem enumeration order.
        image_files = sorted(self.images_dir.glob(f"*{self._IMAGE_SUFFIX}.tif"))
        print(f"Found {len(image_files)} image files")

        self.samples = []
        for img_path in image_files:
            # Drop only the trailing suffix: str.replace would also remove an
            # "_0000" that happens to appear earlier in the file name.
            base_name = img_path.stem[:-len(self._IMAGE_SUFFIX)]
            label_path = self.labels_dir / f"{base_name}.tif"

            if label_path.exists():
                self.samples.append((img_path, label_path))
            else:
                print(f"Warning: Label file not found for {img_path}")

        print(f"Total valid samples: {len(self.samples)}")

    def _get_augmentation(self, patch_size):
        """Create the volumentations augmentation pipeline.

        ``patch_size`` is currently unused; it is kept for a possible Resize step.
        """
        return Compose([
            Rotate((-45, 45), (-45, 45), (-45, 45), p=0.1),
            Flip(0, p=0.25),
            Flip(1, p=0.25),
            Flip(2, p=0.25),
            RandomRotate90(p=0.25),
            GaussianNoise(var_limit=(0, 5), p=0.2),
            GridDistortion(num_steps=3, p=.1),
        ], p=1.0)

    def _load_and_preprocess_volume(self, img_path, label_path):
        """Load one image/label pair, crop it, scale the image to [0, 1] and binarize the label.

        Returns:
            (image_vol, label_vol) as float32 arrays of identical shape.
        """
        image_vol = tifffile.imread(str(img_path))
        # Remember the on-disk dtype: it decides the normalization constant below.
        source_dtype = image_vol.dtype
        image_vol = image_vol.astype(np.float32)
        label_vol = tifffile.imread(str(label_path)).astype(np.float32)

        assert image_vol.shape == label_vol.shape, f"Shape mismatch: {image_vol.shape} vs {label_vol.shape}"

        # Crop (not resize) every axis to target_size. Slicing is a no-op on axes
        # that already fit, so all axes are handled, not just the first one.
        # Volumes smaller than target_size are returned at their original size.
        crop = (slice(0, self.target_size),) * 3
        image_vol = image_vol[crop]
        label_vol = label_vol[crop]

        # 16-bit data is scaled by its full range; everything else is treated as 8-bit.
        if source_dtype == np.uint16:
            image_vol = image_vol / 65535
        else:
            image_vol = image_vol / 255

        # Any non-zero label value counts as foreground.
        label_vol = (label_vol > 0).astype(np.float32)

        return image_vol, label_vol

    def _augment_volume(self, image_vol, label_vol):
        """Apply 3D augmentations using volumentations (no-op when augment is False)."""
        if not self.augment:
            return image_vol, label_vol

        # The pipeline is fed 8-bit data, so the [0, 1] image is quantized to 0-255.
        image_vol_uint8 = (image_vol * 255).astype(np.uint8)
        label_vol_uint8 = label_vol.astype(np.uint8)

        aug_data = self.aug_pipeline(image=image_vol_uint8, mask=label_vol_uint8)

        # Back to float32, with the image rescaled to [0, 1].
        image_vol = aug_data['image'].astype(np.float32) / 255.0
        label_vol = aug_data['mask'].astype(np.float32)

        return image_vol, label_vol

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` tensors; the image gets a leading channel axis."""
        img_path, label_path = self.samples[idx]

        image_vol, label_vol = self._load_and_preprocess_volume(img_path, label_path)
        image_vol, label_vol = self._augment_volume(image_vol, label_vol)

        image_tensor = torch.FloatTensor(image_vol).unsqueeze(0)  # Add channel dimension
        label_tensor = torch.FloatTensor(label_vol)

        return image_tensor, label_tensor


import matplotlib.pyplot as plt
import tempfile
from pathlib import Path
from dataclasses import dataclass
@dataclass
class ModelConfig:
    """Hyper-parameters of the 3D ViT together with its derived patch-grid sizes.

    The grid properties use floor division, so they are only meaningful when the
    image/frame sizes are multiples of the corresponding patch sizes.
    """
    # Input geometry: square frames of `image_size`, `frames` slices deep.
    image_size: int = 256
    image_patch_size: int = 16
    frames: int = 256
    frame_patch_size: int = 16
    # Width of the per-token output produced by the head.
    num_classes: int = 2
    # Transformer sizing.
    dim: int = 512
    depth: int = 16
    heads: int = 16
    mlp_dim: int = 1024
    channels: int = 1
    dim_head: int = 64
    dropout: float = 0.1
    emb_dropout: float = 0.1
    # Attention backend: 'pytorch' (SDPA) or 'flash_attn' (external package).
    flash_attn_type: str = 'pytorch'

    @property
    def num_patches_h(self):
        """Patch count along the height axis."""
        return self.image_size // self.image_patch_size

    @property
    def num_patches_w(self):
        """Patch count along the width axis (frames are square, so same as height)."""
        return self.num_patches_h

    @property
    def num_patches_f(self):
        """Patch count along the frame (depth) axis."""
        return self.frames // self.frame_patch_size

    @property
    def total_patches(self):
        """Total number of patch tokens per volume."""
        per_frame_slab = self.num_patches_h * self.num_patches_w
        return per_frame_slab * self.num_patches_f

    @property
    def expected_output_shape(self):
        """Callable mapping a batch size to the expected model output shape."""
        def shape_for(batch_size):
            grid = (self.num_patches_f, self.num_patches_h, self.num_patches_w)
            return (batch_size, self.num_classes, *grid)
        return shape_for
def load_mae_encoder_to_vit3d(mae_checkpoint_path, vit3d_model, device='cuda'):
    """Copy matching encoder weights from an MAE checkpoint into ``vit3d_model``.

    Checkpoint entries are considered when their key starts with ``model.`` and
    does not contain ``decoder`` or ``mask_token``. An entry is copied when the
    key with ``model.`` removed exists in ``vit3d_model`` with the same shape;
    same-key entries with a different shape are reported and skipped.

    Args:
        mae_checkpoint_path: Checkpoint file containing a ``'state_dict'`` entry.
        vit3d_model: Module that receives the weights (modified in place).
        device: ``map_location`` used when reading the checkpoint.

    Returns:
        List of ``vit3d_model`` state-dict keys that were overwritten.
    """
    # NOTE(security): torch.load unpickles the file - only load trusted checkpoints.
    checkpoint = torch.load(mae_checkpoint_path, map_location=device)
    mae_state_dict = checkpoint['state_dict']

    vit3d_state_dict = vit3d_model.state_dict()

    skipped_markers = ('decoder', 'mask_token')
    max_reported = 5
    loaded_keys = []
    shape_mismatches = []

    for mae_key, mae_tensor in mae_state_dict.items():
        if not mae_key.startswith('model.'):
            continue
        if any(marker in mae_key for marker in skipped_markers):
            continue

        # NOTE(review): this strips every 'model.' occurrence, not only the
        # leading prefix - kept as-is because existing checkpoints may rely on it.
        vit3d_key = mae_key.replace('model.', '')
        if vit3d_key not in vit3d_state_dict:
            continue

        target_shape = vit3d_state_dict[vit3d_key].shape
        if mae_tensor.shape == target_shape:
            vit3d_state_dict[vit3d_key] = mae_tensor
            loaded_keys.append(vit3d_key)
        else:
            shape_mismatches.append((vit3d_key, mae_tensor.shape, target_shape))

    if shape_mismatches:
        print("\nShape mismatches found:")
        for key, mae_shape, vit3d_shape in shape_mismatches[:max_reported]:
            print(f"{key}: MAE {mae_shape} != ViT3D {vit3d_shape}")
        # Only mention a remainder when there is one (previously printed e.g. "-3 more").
        remaining = len(shape_mismatches) - max_reported
        if remaining > 0:
            print(f"... and {remaining} more")

    vit3d_model.load_state_dict(vit3d_state_dict, strict=False)
    return loaded_keys
class VesuviusViT3DPLModel(pl.LightningModule):
    """Lightning module: MAE-initialised 3D ViT encoder plus a light voxel decoder.

    Pipeline of ``forward``: the ViT produces one feature vector per patch token,
    a linear layer expands every token into a ``voxel_size**3`` block, the blocks
    are tiled into a single-channel volume, a small 3D conv stack maps it to
    ``num_classes`` channels, and trilinear interpolation restores the input size.

    Args:
        input_size: Edge length of the cubic input volume.
        patch_size: Edge length of the cubic ViT patches.
        num_classes: Output channels of the conv decoder.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
    """

    # Pre-trained MAE checkpoint whose encoder weights initialise the ViT.
    MAE_CHECKPOINT_PATH = 'outputs/vesuvius_mae_vit3d/vesuvius_mae_vit3d_epoch=193_val_mae_loss=0.0000.ckpt'

    def __init__(self, input_size=320, patch_size=16, num_classes=512, lr=1e-4, weight_decay=1e-4):
        super().__init__()
        self.save_hyperparameters()

        # First validation sample of the current epoch, kept for visualization.
        self.first_val_sample = None

        self.patch_size = patch_size
        self.input_size = input_size
        # Tokens per axis, and the edge of the tiled volume fed to the conv decoder.
        self.output_size = input_size // patch_size
        self.voxel_size = 4
        self.intermediate_size = self.output_size * self.voxel_size
        dim = 512

        # The ViT head is used as a token feature projector (num_classes=dim).
        self.model = ViT3DSegmentation(
            ModelConfig(
                image_size=input_size,
                image_patch_size=patch_size,
                frames=input_size,
                frame_patch_size=patch_size,
                num_classes=dim,
                dim=dim,
                depth=16,
                heads=16,
                mlp_dim=1024,
                flash_attn_type='flash_attn'
            )
        )

        # Read the checkpoint onto the CPU: the freshly built model lives there and
        # load_state_dict copies values anyway, so no GPU is required at construction.
        loaded = load_mae_encoder_to_vit3d(self.MAE_CHECKPOINT_PATH, self.model, device='cpu')
        print(f"Loaded {len(loaded)} encoder weights from MAE")

        # Expands each token into a voxel_size^3 block of the intermediate volume.
        self.token_decoder = nn.Linear(dim, self.voxel_size**3)

        self.conv_decoder = nn.Sequential(
            nn.Conv3d(1, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(64, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(32, num_classes, 3, padding=1)
        )

        self.loss_func1 = smp.losses.DiceLoss(mode='binary')
        self.loss_func2 = smp.losses.SoftBCEWithLogitsLoss(smooth_factor=0.15)

        print("Model initialized:")
        print(f"  Input size: {input_size}x{input_size}x{input_size}")
        print(f"  Patch size: {patch_size}x{patch_size}x{patch_size}")
        print(f"  Token grid: {self.output_size}x{self.output_size}x{self.output_size}")
        print(f"  Intermediate size: {self.intermediate_size}x{self.intermediate_size}x{self.intermediate_size}")
        print(f"  Voxel size per token: {self.voxel_size}x{self.voxel_size}x{self.voxel_size}")

    def loss_func(self, x, y):
        """Equal-weight sum of Dice loss and label-smoothed BCE-with-logits.

        Defined as a method (not a lambda attribute) so the module stays picklable.
        """
        return 0.5 * self.loss_func1(x, y) + 0.5 * self.loss_func2(x, y)

    def forward(self, x):
        """Map ``x`` (batch, 1, D, H, W) to logits (batch, num_classes, D, H, W).

        Raises:
            ValueError: If the ViT output is neither 3D nor 5D, or if the number
                of tokens differs from ``output_size ** 3``.
        """
        batch_size = x.shape[0]
        embeddings = self.model(x)

        if embeddings.dim() == 5:
            # (B, C, f, h, w) -> (B, f*h*w, C)
            embeddings = embeddings.flatten(2).transpose(1, 2)
        elif embeddings.dim() != 3:
            raise ValueError(f"Unexpected ViT output shape: {embeddings.shape}")

        num_tokens = embeddings.shape[1]
        expected_tokens = self.output_size ** 3
        if num_tokens != expected_tokens:
            # The reshape below cannot succeed with a different token count.
            raise ValueError(f"Expected {expected_tokens} tokens, got {num_tokens}")

        decoded_tokens = self.token_decoder(embeddings)

        # (B, tokens, v^3) -> (B, gz, gy, gx, vz, vy, vx)
        decoded_tokens = decoded_tokens.view(
            batch_size,
            self.output_size, self.output_size, self.output_size,
            self.voxel_size, self.voxel_size, self.voxel_size
        )

        # Interleave grid and voxel axes so each token's block lands at its grid cell.
        intermediate_vol = decoded_tokens.permute(0, 1, 4, 2, 5, 3, 6).contiguous()
        intermediate_vol = intermediate_vol.view(
            batch_size, 1,
            self.intermediate_size,
            self.intermediate_size,
            self.intermediate_size
        )

        decoded_vol = self.conv_decoder(intermediate_vol)

        output = F.interpolate(
            decoded_vol,
            size=(self.input_size, self.input_size, self.input_size),
            mode='trilinear',
            align_corners=False
        )

        return output

    def _shared_step(self, batch):
        """Run the model on ``batch`` and compute loss, accuracy and IoU.

        Returns:
            ``(x, y, outputs, loss, accuracy, iou)`` with ``y`` carrying a channel axis.
        """
        x, y = batch
        outputs = self(x)

        # Labels arrive without a channel axis; the losses expect one.
        if y.dim() == 4:
            y = y.unsqueeze(1)

        loss = self.loss_func(outputs, y)

        pred_mask = torch.sigmoid(outputs) > 0.5
        target_mask = y > 0.5
        accuracy = (pred_mask == target_mask).float().mean()

        intersection = (pred_mask & target_mask).sum()
        union = (pred_mask | target_mask).sum()
        iou = intersection / (union + 1e-8)

        return x, y, outputs, loss, accuracy, iou

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss/metrics; returns the loss."""
        _, _, _, loss, accuracy, iou = self._shared_step(batch)

        self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train/accuracy", accuracy, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train/iou", iou, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log validation loss/metrics and keep the first sample for plotting."""
        x, y, outputs, loss, accuracy, iou = self._shared_step(batch)

        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val/accuracy", accuracy, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val/iou", iou, on_step=False, on_epoch=True, prog_bar=True)

        # Store first validation sample for visualization
        if batch_idx == 0 and self.first_val_sample is None:
            self.first_val_sample = {
                'image': x[0:1].cpu(),
                'label': y[0:1].cpu(),
                'prediction': outputs[0:1].detach().cpu()
            }

        return loss

    def visualize_3d_slices(self, image_vol, label_vol, pred_logits, pred_probs, num_slices=3):
        """Plot evenly spaced Z slices of image, label, logits and probabilities.

        All four inputs are indexed as ``vol[z]`` and shown with ``imshow``.

        Returns:
            Path of a temporary PNG file; the caller is responsible for deleting it.
        """
        d, h, w = image_vol.shape
        slice_indices = np.linspace(0, d-1, num_slices, dtype=int)

        fig, axes = plt.subplots(4, num_slices, figsize=(12, 8))
        if num_slices == 1:
            axes = axes.reshape(-1, 1)

        rows = (
            (image_vol, 'Image', {'cmap': 'gray'}),
            (label_vol, 'Ground Truth', {'cmap': 'Reds', 'vmin': 0, 'vmax': 1}),
            (pred_logits, 'Pred Logits', {'cmap': 'RdYlBu_r'}),
            (pred_probs, 'Pred Probs', {'cmap': 'Reds', 'vmin': 0, 'vmax': 1}),
        )
        for col, slice_idx in enumerate(slice_indices):
            for row, (volume, title, imshow_kwargs) in enumerate(rows):
                ax = axes[row, col]
                ax.imshow(volume[slice_idx], **imshow_kwargs)
                ax.set_title(f'{title} Z={slice_idx}')
                ax.axis('off')

        fig.tight_layout()

        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
            tmp_path = tmp.name
        try:
            fig.savefig(tmp_path, dpi=100, bbox_inches='tight')
        finally:
            # Close this figure explicitly rather than whichever one is "current".
            plt.close(fig)

        return tmp_path

    def on_validation_epoch_end(self):
        """Log a slice visualization of the stored validation sample, then clear it."""
        sample = self.first_val_sample
        self.first_val_sample = None
        if sample is None:
            return

        # Index batch 0 / channel 0 explicitly and cast to float32 so reduced
        # precision outputs can be converted to NumPy.
        image = sample['image'][0, 0].float().numpy()
        label = sample['label'][0, 0].float().numpy()
        logits = sample['prediction'][0, 0].float()
        pred_logits = logits.numpy()
        pred_probs = torch.sigmoid(logits).numpy()

        viz_path = self.visualize_3d_slices(image, label, pred_logits, pred_probs)
        try:
            if self.logger and hasattr(self.logger, 'experiment'):
                self.logger.experiment.log({
                    "val/slice_visualization": wandb.Image(viz_path)
                })
        finally:
            # Remove the temporary image even if logging fails.
            os.unlink(viz_path)

    def configure_optimizers(self):
        """AdamW over all parameters with cosine annealing (T_max=35, eta_min=1e-6)."""
        optimizer = AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=35, eta_min=1e-6)
        return [optimizer], [scheduler]


class CFG:
    """Static settings for the Vesuvius 3D ViT training run.

    Every option is a plain class attribute. Instantiating the class has one
    side effect: it makes sure ``model_dir`` exists on disk.
    """

    # ---- Dataset ----
    # Root folder of the dataset; change this to your dataset path.
    dataset_folder = './Dataset090_CEDCrops/'

    # ---- Model ----
    # Size of input volumes.
    input_size = 256
    # Patch size for the ViT.
    patch_size = 16
    # Single output channel: binary segmentation.
    num_classes = 1

    # ---- Training ----
    # Kept small because the volumes are large.
    batch_size = 2
    # Number of DataLoader worker processes.
    num_workers = 32
    lr = 3e-4
    weight_decay = 2e-6
    epochs = 500

    # ---- Data split ----
    # Fraction of the samples used for training; the rest is validation.
    train_ratio = 0.99

    # ---- Experiment ----
    exp_name = 'vesuvius_vit3d_full_volume'

    # ---- Paths ----
    # Directory that receives model checkpoints.
    model_dir = './outputs/vesuvius_vit3d/'

    def __init__(self):
        """Create ``model_dir`` (including parents) when it does not exist yet."""
        os.makedirs(self.model_dir, exist_ok=True)


def split_dataset(dataset, train_ratio=0.8, seed=None):
    """Randomly split a dataset into train and validation subsets.

    Args:
        dataset: Dataset to split; must support ``len()``.
        train_ratio: Fraction of the samples assigned to the training subset.
            The training count is ``int(len(dataset) * train_ratio)``; all
            remaining samples go to validation (which may therefore be empty).
        seed: Optional integer seed. When given, the permutation is drawn from
            a dedicated generator so the same split is produced on every call;
            when ``None`` the global torch RNG is used.

    Returns:
        Tuple ``(train_dataset, val_dataset)`` of ``torch.utils.data.Subset``
        objects over ``dataset``.
    """
    num_samples = len(dataset)
    num_train = int(num_samples * train_ratio)

    if seed is None:
        permutation = torch.randperm(num_samples)
    else:
        generator = torch.Generator().manual_seed(seed)
        permutation = torch.randperm(num_samples, generator=generator)

    # Hand Subset plain Python ints rather than 0-dim tensors, so the wrapped
    # dataset's __getitem__ always receives an ordinary integer index.
    indices = permutation.tolist()
    train_indices = indices[:num_train]
    val_indices = indices[num_train:]

    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    val_dataset = torch.utils.data.Subset(dataset, val_indices)

    return train_dataset, val_dataset


def count_dataset_stats(dataset_folder, max_samples=3):
    """Print basic statistics about the dataset.

    Counts the ``*_0000.tif`` files in ``<dataset_folder>/imagesTr`` and, for
    the first ``max_samples`` of them (in sorted order), reads the image and
    its label from ``<dataset_folder>/labelsTr`` and prints their shapes, the
    image value range and the fraction of label voxels greater than zero.

    This is a best-effort diagnostic: unreadable or missing files are reported
    on stdout and never raise.

    Args:
        dataset_folder: Dataset root (``str`` or ``Path``).
        max_samples: How many image files to inspect; ``None`` inspects all.
    """
    suffix = "_0000"
    dataset_folder = Path(dataset_folder)
    images_dir = dataset_folder / "imagesTr"
    labels_dir = dataset_folder / "labelsTr"

    # glob() yields files in filesystem order; sort so that the inspected
    # samples are the same on every run and machine.
    image_files = sorted(images_dir.glob(f"*{suffix}.tif"))
    print(f"Found {len(image_files)} image files")

    # Sample a few files to check shapes and statistics
    for img_path in image_files[:max_samples]:
        # The glob guarantees the stem ends with the suffix; strip only that
        # trailing occurrence so names containing "_0000" elsewhere survive.
        base_name = img_path.stem[:-len(suffix)]
        label_path = labels_dir / f"{base_name}.tif"

        if not label_path.exists():
            print(f"  {base_name}: label file not found at {label_path}")
            continue

        try:
            image_vol = tifffile.imread(str(img_path))
            label_vol = tifffile.imread(str(label_path))

            positive_ratio = (label_vol > 0).mean()
            print(f"  {base_name}:")
            print(f"    Image shape: {image_vol.shape}")
            print(f"    Label shape: {label_vol.shape}")
            print(f"    Image range: [{image_vol.min()}, {image_vol.max()}]")
            print(f"    Positive ratio: {positive_ratio:.4f}")

        except Exception as e:
            # Deliberately broad: this helper only reports, and one corrupt
            # file must not abort the run that called it.
            print(f"    Error reading {img_path}: {e}")


def main():
    """Run the full training pipeline.

    Steps: build the configuration, verify the dataset folder exists (prints
    an error and returns early otherwise), print dataset statistics, build the
    dataset and its train/validation split, create the data loaders, the
    model, the W&B logger, the checkpoint callback and the Lightning trainer,
    then fit the model and close the W&B run.
    """
    cfg = CFG()

    # Defined once so the Trainer setting and the reported effective batch
    # size cannot drift apart.
    accumulate_grad_batches = 12

    print("=" * 60)
    print("Vesuvius 3D ViT Training Pipeline - Full Volumes")
    print("=" * 60)

    # Check dataset
    if not os.path.exists(cfg.dataset_folder):
        print(f"Error: Dataset folder not found at {cfg.dataset_folder}")
        print("Please update CFG.dataset_folder to point to your dataset")
        return

    # Count dataset statistics
    print("Dataset statistics:")
    count_dataset_stats(cfg.dataset_folder)

    # Create full dataset
    print(f"\nCreating dataset with target size {cfg.input_size}...")
    full_dataset = VesuviusFullVolumeDataset(
        dataset_folder=cfg.dataset_folder,
        target_size=cfg.input_size,
        augment=True
    )

    # Split into train and validation
    train_dataset, val_dataset = split_dataset(full_dataset, cfg.train_ratio)
    # NOTE(review): both subsets wrap the same dataset object created with
    # augment=True, so validation samples presumably go through augmentation
    # too. Setting `val_dataset.augment = False` would only set an attribute
    # on the Subset wrapper — confirm whether a separate augment=False dataset
    # is wanted for validation.
    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=True,
        drop_last=True
    )

    # No validation loader at all when the split left validation empty.
    val_loader = DataLoader(
        val_dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True,
        drop_last=False
    ) if len(val_dataset) > 0 else None

    # Create model
    # vesuvius_vit3d_12l_8h_512_flashattn_epoch=250_val_loss=0.0000.ckpt
    model = VesuviusViT3DPLModel(
        input_size=cfg.input_size,
        patch_size=cfg.patch_size,
        num_classes=cfg.num_classes,
        lr=cfg.lr,
        weight_decay=cfg.weight_decay
    )

    # Print model info
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\nModel parameters: {total_params:,}")

    # Memory estimate: size of one single-channel float32 input batch only
    # (activations, gradients and optimizer state are not included).
    batch_memory_gb = (cfg.batch_size * 1 * cfg.input_size**3 * 4) / (1024**3)  # 4 bytes per float32
    print(f"Estimated memory per batch: {batch_memory_gb:.2f} GB")

    # Setup logging
    wandb_logger = WandbLogger(
        project="vesuvius_vit3d",
        name=f"{cfg.exp_name}_size_{cfg.input_size}_patch_{cfg.patch_size}"
    )

    # Setup callbacks
    monitor_metric = 'val/loss' if val_loader is not None else 'train/loss'
    # The filename placeholder must be the exact metric key ("val/loss"); the
    # previous "{val_loss}" placeholder matched no metric, so checkpoints were
    # always named "...val_loss=0.0000". The "name=" parts are written out by
    # hand and auto_insert_metric_name is turned off, because auto-inserting a
    # key that contains "/" would create sub-folders.
    metric_label = monitor_metric.replace('/', '_')
    checkpoint_filename = (
        'vesuvius_vit3d_pretrainedMAE_surface_flashattn_'
        'epoch={epoch:02d}_' + metric_label + '={' + monitor_metric + ':.4f}'
    )
    checkpoint_callback = ModelCheckpoint(
        dirpath=cfg.model_dir,
        filename=checkpoint_filename,
        monitor=monitor_metric,
        mode='min',
        save_top_k=8,
        save_last=True,
        auto_insert_metric_name=False
    )

    # Create trainer
    trainer = pl.Trainer(
        max_epochs=cfg.epochs,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
        logger=wandb_logger,
        callbacks=[checkpoint_callback],
        precision='16-mixed' if torch.cuda.is_available() else 32,
        gradient_clip_val=5.0,
        log_every_n_steps=10,
        accumulate_grad_batches=accumulate_grad_batches,  # Gradient accumulation for larger effective batch size
        # check_val_every_n_epoch=3   # Validate less frequently
    )

    # Train
    print(f"\nStarting training for {cfg.epochs} epochs...")
    # Due to gradient accumulation
    print(f"Effective batch size: {cfg.batch_size * accumulate_grad_batches}")

    trainer.fit(model, train_loader, val_loader)

    print("Training completed!")
    wandb.finish()


# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

Loading CUDA kernels...


If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


✓ CUDA kernels loaded!
