<a href="https://colab.research.google.com/github/alim98/MPI/blob/main/Vit_MAE/ViT_MAE_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Download the dataset archive from Google Drive (about 1.2 GB according to the wget log below).
!wget -O downloaded_file.zip "https://drive.usercontent.google.com/download?id=1iHPBdBOPEagvPTHZmrN__LD49emXwReY&export=download&authuser=0&confirm=t&uuid=631d60dd-569c-4bb1-a9e8-d681f0ed3d43&at=APvzH3r4me8x_LwP3n8O7lgPo8oK%3A1733988188000"

# Install the Python packages this notebook relies on (quiet mode, versions unpinned).
!pip -q install transformers scikit-learn matplotlib seaborn torch torchvision umap-learn openpyxl imageio

# Extract the archive into the current working directory.
!unzip -q downloaded_file.zip


--2025-01-14 17:40:14--  https://drive.usercontent.google.com/download?id=1iHPBdBOPEagvPTHZmrN__LD49emXwReY&export=download&authuser=0&confirm=t&uuid=631d60dd-569c-4bb1-a9e8-d681f0ed3d43&at=APvzH3r4me8x_LwP3n8O7lgPo8oK%3A1733988188000
Resolving drive.usercontent.google.com (drive.usercontent.google.com)... 142.250.101.132, 2607:f8b0:4023:c06::84
Connecting to drive.usercontent.google.com (drive.usercontent.google.com)|142.250.101.132|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1264688649 (1.2G) [application/octet-stream]
Saving to: ‘downloaded_file.zip’


2025-01-14 17:40:36 (61.0 MB/s) - ‘downloaded_file.zip’ saved [1264688649/1264688649]

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.8/88.8 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.9/56.9 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# mae3d_reconstruction.py

# -----------------------------
# 1. Import Necessary Libraries
# -----------------------------
import os
import glob
import numpy as np
import imageio.v2 as imageio
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from tqdm import tqdm
from sklearn.model_selection import train_test_split

# -----------------------------
# 2. Data Loading Functions
# -----------------------------

def _slice_sort_key(path):
    """Natural-order sort key for slice files: digit runs in the file name compare as integers."""
    import re  # local import keeps this block self-contained

    name = os.path.basename(path)
    # re.split with a capturing group alternates text / digits / text ..., so the
    # resulting lists always compare str-with-str and int-with-int.
    return [int(part) if part.isdigit() else part for part in re.split(r'(\d+)', name)]


def load_volumes(bbox_name, raw_base_dir, seg_base_dir):
    """
    Load raw volume and segmentation volume for a bounding box.

    Slice files are matched by the pattern 'slice_*.tif' and stacked in natural
    (numeric) order, so 'slice_2.tif' comes before 'slice_10.tif' even when the
    indices are not zero-padded.

    Args:
        bbox_name (str): Name of the bounding box directory.
        raw_base_dir (str): Base directory for raw data.
        seg_base_dir (str): Base directory for segmentation data.

    Returns:
        tuple: (raw_vol, seg_vol) each as np.ndarray stacked along axis 0, with
               seg_vol cast to uint32, or (None, None) if loading fails
               (no files, slice-count mismatch, or a read error). Failures are
               reported with print() rather than raised.
    """
    raw_dir = os.path.join(raw_base_dir, bbox_name)
    seg_dir = os.path.join(seg_base_dir, bbox_name)

    raw_tif_files = sorted(glob.glob(os.path.join(raw_dir, 'slice_*.tif')), key=_slice_sort_key)
    seg_tif_files = sorted(glob.glob(os.path.join(seg_dir, 'slice_*.tif')), key=_slice_sort_key)

    if len(raw_tif_files) == 0:
        print(f"No raw files found for {bbox_name} in {raw_dir}")
        return None, None

    if len(seg_tif_files) == 0:
        print(f"No segmentation files found for {bbox_name} in {seg_dir}")
        return None, None

    if len(raw_tif_files) != len(seg_tif_files):
        print(f"Mismatch in number of raw vs seg slices for {bbox_name}. Skipping.")
        return None, None

    try:
        raw_vol = np.stack([imageio.imread(f) for f in raw_tif_files], axis=0)  # shape: (Z, Y, X)
        seg_vol = np.stack([imageio.imread(f).astype(np.uint32) for f in seg_tif_files], axis=0)
        return raw_vol, seg_vol
    except Exception as e:
        # Deliberate best-effort: callers check for (None, None).
        print(f"Error loading volumes for {bbox_name}: {e}")
        return None, None

def create_segmented_cube(
    raw_vol,
    seg_vol,
    central_coord,
    side1_coord,
    side2_coord,
    subvolume_size=80,
    alpha=0.3
):
    """
    Build a 3-channel cube of side `subvolume_size` centred on a synapse.

    The cube is cut from `raw_vol` / `seg_vol` around `central_coord`. Where the
    requested window leaves the volume it is zero-padded on the side that was
    clipped, so `central_coord` always lands at index `subvolume_size // 2`
    along every axis.

    Args:
        raw_vol (np.ndarray): Raw volume indexed as (Z, Y, X).
        seg_vol (np.ndarray): Segmentation label volume indexed as (Z, Y, X);
            label 0 is treated as "no segment".
        central_coord (tuple): (x, y, z) centre of the cube.
        side1_coord (tuple): (x, y, z) voxel whose label defines mask 1.
        side2_coord (tuple): (x, y, z) voxel whose label defines mask 2.
        subvolume_size (int): Edge length of the output cube (default: 80).
        alpha (float): Unused; kept only so existing callers keep working.

    Returns:
        np.ndarray: float32 array of shape
            (3, subvolume_size, subvolume_size, subvolume_size).
            Channels: [Mask1, Raw, Mask2]. The masks are 0/1; the raw channel is
            min-max normalised to [0, 1] over the (padded) cube.

    Raises:
        ValueError: If `side1_coord` or `side2_coord` lies outside `seg_vol`.
    """

    def _segment_id_at(coord, label):
        # Explicit bounds check: negative indices would otherwise wrap around silently.
        x, y, z = coord
        in_bounds = (0 <= z < seg_vol.shape[0] and
                     0 <= y < seg_vol.shape[1] and
                     0 <= x < seg_vol.shape[2])
        if not in_bounds:
            raise ValueError(f"{label} coordinates are out of bounds.")
        return seg_vol[z, y, x]

    half_size = subvolume_size // 2

    def _crop_plan(center, dim):
        # Requested window along one axis, then the part of it that actually
        # exists in the volume and the zero padding needed on each side.
        req_start = center - half_size
        req_end = req_start + subvolume_size
        start = min(max(req_start, 0), dim)
        end = min(max(req_end, 0), dim)
        pad_before = min(max(-req_start, 0), subvolume_size)
        pad_after = subvolume_size - pad_before - (end - start)
        return slice(start, end), (pad_before, pad_after)

    seg_id_1 = _segment_id_at(side1_coord, "Side1")
    seg_id_2 = _segment_id_at(side2_coord, "Side2")

    cx, cy, cz = central_coord
    z_slice, z_pad = _crop_plan(cz, raw_vol.shape[0])
    y_slice, y_pad = _crop_plan(cy, raw_vol.shape[1])
    x_slice, x_pad = _crop_plan(cx, raw_vol.shape[2])
    window = (z_slice, y_slice, x_slice)
    padding = (z_pad, y_pad, x_pad)

    sub_raw = np.pad(raw_vol[window], padding, mode='constant', constant_values=0)
    # Crop the labels first and compare afterwards: same result as masking the
    # whole volume, but only touches the voxels that end up in the cube.
    sub_seg = seg_vol[window]

    def _padded_mask(seg_id):
        if seg_id == 0:
            # Label 0 means there is no segment at the side coordinate.
            mask = np.zeros(sub_seg.shape, dtype=bool)
        else:
            mask = sub_seg == seg_id
        return np.pad(mask, padding, mode='constant', constant_values=False)

    sub_mask_1 = _padded_mask(seg_id_1)
    sub_mask_2 = _padded_mask(seg_id_2)

    # Normalize raw channel to [0, 1]; a constant cube becomes all zeros.
    raw_normalized = sub_raw.astype(np.float32)
    raw_min, raw_max = raw_normalized.min(), raw_normalized.max()
    if raw_max != raw_min:
        raw_normalized = (raw_normalized - raw_min) / (raw_max - raw_min)
    else:
        raw_normalized = raw_normalized - raw_min

    # Channels: [Mask1, Raw, Mask2]
    overlaid_cube = np.zeros((3, subvolume_size, subvolume_size, subvolume_size), dtype=np.float32)
    overlaid_cube[0] = sub_mask_1.astype(np.float32)
    overlaid_cube[1] = raw_normalized
    overlaid_cube[2] = sub_mask_2.astype(np.float32)

    return overlaid_cube

# -----------------------------
# 3. Dataset Definition
# -----------------------------

class ElectronMicroscopyDataset(Dataset):
    """
    PyTorch Dataset for loading 3D Electron Microscopy data using custom data loader.
    Handles multiple Excel files, each corresponding to a different bounding box.

    Each item is one synapse row from the spreadsheets, returned as a pair
    (cube, cube) because the masked-autoencoder target equals its input.
    """

    # Columns every synapse spreadsheet must provide.
    REQUIRED_COLUMNS = (
        'central_coord_1', 'central_coord_2', 'central_coord_3',
        'side_1_coord_1', 'side_1_coord_2', 'side_1_coord_3',
        'side_2_coord_1', 'side_2_coord_2', 'side_2_coord_3',
    )

    def __init__(self, bbox_names, raw_base_dir, seg_base_dir, excel_dir, subvolume_size=80, transform=None,
                 cache_volumes=True):
        """
        Args:
            bbox_names (list): List of bounding box names to include in the dataset.
            raw_base_dir (str): Base directory containing raw data.
            seg_base_dir (str): Base directory containing segmentation data.
            excel_dir (str): Directory containing Excel files for each bounding box.
                             Each Excel file should be named as '<bbox_name>.xlsx'.
            subvolume_size (int): Size of the subvolume to extract (default: 80).
            transform (callable, optional): Optional transform to be applied on a sample.
            cache_volumes (bool): Keep each bounding box's volumes in memory after the
                first load instead of re-reading every slice file for every sample.
                Costs memory (one copy per DataLoader worker process); set to False
                to fall back to loading from disk on each access.

        Raises:
            ValueError: If no spreadsheet could be loaded for any bounding box.
        """
        self.bbox_names = bbox_names
        self.raw_base_dir = raw_base_dir
        self.seg_base_dir = seg_base_dir
        self.excel_dir = excel_dir
        self.subvolume_size = subvolume_size
        self.transform = transform
        self.cache_volumes = cache_volumes
        # bbox_name -> (raw_vol, seg_vol); filled lazily by _get_volumes.
        self._volume_cache = {}

        # Initialize an empty list to hold dataframes
        self.synapse_data_list = []

        # Load synapse data from each Excel file corresponding to the bbox_names
        for bbox_name in self.bbox_names:
            excel_file = os.path.join(self.excel_dir, f"{bbox_name}.xlsx")
            if not os.path.exists(excel_file):
                print(f"Excel file {excel_file} not found for {bbox_name}. Skipping.")
                continue

            try:
                syn_df = pd.read_excel(excel_file)
            except Exception as e:
                print(f"Error reading Excel file '{excel_file}': {e}. Skipping.")
                continue

            # Ensure required columns exist
            missing = [col for col in self.REQUIRED_COLUMNS if col not in syn_df.columns]
            if missing:
                print(f"Missing required column '{missing[0]}' in Excel file '{excel_file}'. Skipping.")
                continue

            # Add a column for bbox_name to keep track
            syn_df['bbox_name'] = bbox_name
            self.synapse_data_list.append(syn_df)

        if not self.synapse_data_list:
            raise ValueError("No valid synapse data loaded. Please check your Excel files and directories.")

        # Concatenate all synapse data into a single dataframe
        self.synapse_data = pd.concat(self.synapse_data_list, ignore_index=True)
        print(f"Total synapses loaded: {len(self.synapse_data)}")

    def __len__(self):
        return len(self.synapse_data)

    def _get_volumes(self, bbox_name):
        """Return (raw_vol, seg_vol) for a bounding box, served from the cache when enabled."""
        if bbox_name in self._volume_cache:
            return self._volume_cache[bbox_name]

        raw_vol, seg_vol = load_volumes(bbox_name, self.raw_base_dir, self.seg_base_dir)
        if raw_vol is None or seg_vol is None:
            # Fail loudly here instead of letting a None volume crash further down.
            raise RuntimeError(f"Could not load raw/segmentation volumes for bounding box '{bbox_name}'.")

        if self.cache_volumes:
            self._volume_cache[bbox_name] = (raw_vol, seg_vol)
        return raw_vol, seg_vol

    def __getitem__(self, idx):
        """
        Build the 3-channel cube for synapse `idx`.

        Returns:
            tuple: (cube, cube), both the same float32 tensor of shape
                   (3, subvolume_size, subvolume_size, subvolume_size).

        Raises:
            RuntimeError: If the volumes for the synapse's bounding box cannot be loaded.
            KeyError: If a coordinate column is missing.
            ValueError: If a coordinate value cannot be converted to int, or a side
                        coordinate lies outside the segmentation volume.
        """
        syn_info = self.synapse_data.iloc[idx]
        bbox_name = syn_info['bbox_name']

        raw_vol, seg_vol = self._get_volumes(bbox_name)

        def _coord(prefix):
            # Columns are named '<prefix>_coord_1..3'.
            return tuple(int(syn_info[f'{prefix}_coord_{axis}']) for axis in (1, 2, 3))

        # Extract coordinates
        try:
            central_coord = _coord('central')
            side1_coord = _coord('side_1')
            side2_coord = _coord('side_2')
        except KeyError as e:
            raise KeyError(f"Missing coordinate column: {e}")
        except ValueError as e:
            raise ValueError(f"Invalid coordinate value: {e}")

        # Create the overlaid segmented cube
        overlaid_cube = create_segmented_cube(
            raw_vol=raw_vol,
            seg_vol=seg_vol,
            central_coord=central_coord,
            side1_coord=side1_coord,
            side2_coord=side2_coord,
            subvolume_size=self.subvolume_size,
            alpha=0.3
        )

        # Apply optional transforms
        if self.transform:
            overlaid_cube = self.transform(overlaid_cube)

        # as_tensor accepts both ndarrays and tensors, so transforms may return either.
        overlaid_cube = torch.as_tensor(overlaid_cube).float()

        # For MAE, the target is the same as input
        return overlaid_cube, overlaid_cube

# -----------------------------
# 4. Model Components
# -----------------------------

class PatchEmbed3D(nn.Module):
    """
    Tokenizer for 3D volumes.

    A strided Conv3d whose kernel and stride both equal the patch size turns
    every non-overlapping patch into one embedding vector.
    """
    def __init__(self, in_channels=3, patch_size=4, embed_dim=768):
        """
        Args:
            in_channels (int): Channels of the input volume (default: 3).
            patch_size (int): Edge length of each cubic patch (default: 4).
            embed_dim (int): Length of each patch embedding (default: 768).
        """
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv3d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Volume batch of shape (B, C, D, H, W)
        Returns:
            torch.Tensor: Patch tokens of shape (B, N, E)
        """
        grid = self.proj(x)                        # (B, E, D/P, H/P, W/P)
        tokens = torch.flatten(grid, start_dim=2)  # (B, E, N)
        return torch.transpose(tokens, 1, 2)       # (B, N, E)

class PositionalEncoding3D(nn.Module):
    """
    Learnable positional embedding: one vector per patch position, added to the tokens.
    """
    def __init__(self, embed_dim, num_patches):
        super().__init__()
        table = torch.zeros(1, num_patches, embed_dim)
        self.pos_embed = nn.Parameter(table)
        # Truncated-normal initialisation with std 0.02.
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        # pos_embed has a leading dimension of 1, so it broadcasts over the batch.
        positioned = x + self.pos_embed
        return positioned

class TransformerEncoderBlock(nn.Module):
    """
    Single Transformer Encoder Block (pre-norm).

    Applies LayerNorm -> multi-head self-attention -> residual add, followed by
    LayerNorm -> MLP -> residual add. Input and output are (B, N, E).
    """
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, drop=0.0, attn_drop=0.0):
        """
        Args:
            embed_dim (int): Token embedding size.
            num_heads (int): Number of attention heads.
            mlp_ratio (float): Hidden width of the MLP as a multiple of embed_dim.
            drop (float): Dropout applied after the MLP's second linear layer.
            attn_drop (float): Dropout applied inside the attention module.
        """
        super(TransformerEncoderBlock, self).__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=attn_drop)
        self.drop_path = nn.Identity()  # Can implement stochastic depth if needed
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
            nn.Dropout(drop)
        )

    def forward(self, x):
        # Self-Attention. The attention module is sequence-first, so (B, N, E) -> (N, B, E).
        h = self.norm1(x).transpose(0, 1)
        # The attention map is never used; need_weights=False avoids building the
        # averaged (B, N, N) weight tensor and lets PyTorch pick its optimized kernel.
        attn_output, _ = self.attn(h, h, h, need_weights=False)
        x = x + self.drop_path(attn_output.transpose(0, 1))
        # MLP
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

class TransformerEncoder(nn.Module):
    """
    Stack of `depth` encoder blocks followed by a final LayerNorm.
    """
    def __init__(self, depth, embed_dim, num_heads, mlp_ratio=4.0, drop=0.0, attn_drop=0.0):
        super().__init__()
        blocks = nn.ModuleList()
        for _ in range(depth):
            blocks.append(
                TransformerEncoderBlock(embed_dim, num_heads, mlp_ratio, drop, attn_drop)
            )
        self.layers = blocks
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        hidden = x
        # Run the blocks in order, then normalise the final hidden state.
        for block in self.layers:
            hidden = block(hidden)
        return self.norm(hidden)

class MaskToken(nn.Module):
    """
    Holds the single learnable embedding that stands in for every masked patch.
    """
    def __init__(self, embed_dim):
        super().__init__()
        initial = torch.zeros(1, 1, embed_dim)
        self.mask_token = nn.Parameter(initial)
        nn.init.trunc_normal_(self.mask_token, std=0.02)

    def forward(self, B, num_mask, embed_dim):
        """
        Broadcast the mask token over a batch.

        Args:
            B (int): Batch size.
            num_mask (int): Number of masked patches per sample.
            embed_dim (int): Embedding dimension.

        Returns:
            torch.Tensor: A fresh (non-view) tensor of shape (B, num_mask, embed_dim)
                          in which every position holds the mask token.
        """
        target_shape = (B, num_mask, embed_dim)
        broadcast = self.mask_token.expand(*target_shape)
        # clone() materialises the broadcast view into its own storage.
        return torch.clone(broadcast)

class DecoderBlock(nn.Module):
    """
    Single Decoder Block (pre-norm).

    Applies LayerNorm -> multi-head self-attention -> residual add, followed by
    LayerNorm -> MLP -> residual add. Input and output are (B, N, E).
    """
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, drop=0.0, attn_drop=0.0):
        """
        Args:
            embed_dim (int): Token embedding size.
            num_heads (int): Number of attention heads.
            mlp_ratio (float): Hidden width of the MLP as a multiple of embed_dim.
            drop (float): Dropout applied after the MLP's second linear layer.
            attn_drop (float): Dropout applied inside the attention module.
        """
        super(DecoderBlock, self).__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=attn_drop)
        self.drop_path = nn.Identity()
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
            nn.Dropout(drop)
        )

    def forward(self, x):
        # Self-Attention. The attention module is sequence-first, so (B, N, E) -> (N, B, E).
        h = self.norm1(x).transpose(0, 1)
        # The attention map is never used; need_weights=False avoids building the
        # averaged (B, N, N) weight tensor and lets PyTorch pick its optimized kernel.
        attn_output, _ = self.attn(h, h, h, need_weights=False)
        x = x + self.drop_path(attn_output.transpose(0, 1))
        # MLP
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

class TransformerDecoder(nn.Module):
    """
    Stack of `depth` decoder blocks followed by a final LayerNorm.
    """
    def __init__(self, depth, embed_dim, num_heads, mlp_ratio=4.0, drop=0.0, attn_drop=0.0):
        super().__init__()
        blocks = nn.ModuleList()
        for _ in range(depth):
            blocks.append(
                DecoderBlock(embed_dim, num_heads, mlp_ratio, drop, attn_drop)
            )
        self.layers = blocks
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        hidden = x
        # Run the blocks in order, then normalise the final hidden state.
        for block in self.layers:
            hidden = block(hidden)
        return self.norm(hidden)

class MAE3D(nn.Module):
    """
    3D Masked AutoEncoder with Vision Transformer backbone.

    The volume is split into cubic patches, a random subset of patches is hidden,
    the encoder sees only the visible patches, and the decoder reconstructs the
    voxels of the hidden patches from the encoded visible tokens plus a shared
    learnable mask token.
    """
    def __init__(
        self,
        in_channels=3,          # [Mask1, Raw, Mask2]
        patch_size=4,
        embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        decoder_depth=8,
        decoder_num_heads=12,
        mlp_ratio=4.0,
        mask_ratio=0.75,
        img_size=80
    ):
        """
        Args:
            in_channels (int): Channels of the input volume.
            patch_size (int): Edge length of each cubic patch.
            embed_dim (int): Token embedding size (shared by encoder and decoder).
            encoder_depth (int): Number of encoder blocks.
            encoder_num_heads (int): Attention heads per encoder block.
            decoder_depth (int): Number of decoder blocks.
            decoder_num_heads (int): Attention heads per decoder block.
            mlp_ratio (float): MLP hidden width as a multiple of embed_dim.
            mask_ratio (float): Fraction of patches hidden from the encoder.
            img_size (int): Edge length of the (cubic) input volume; fixes the
                number of positional embeddings (default: 80).
        """
        super(MAE3D, self).__init__()
        self.in_channels = in_channels
        self.patch_embed = PatchEmbed3D(in_channels, patch_size, embed_dim)
        num_patches = (img_size // patch_size) ** 3
        self.encoder_pos_embed = PositionalEncoding3D(embed_dim, num_patches)

        self.encoder = TransformerEncoder(
            depth=encoder_depth,
            embed_dim=embed_dim,
            num_heads=encoder_num_heads,
            mlp_ratio=mlp_ratio
        )

        self.mask_ratio = mask_ratio
        self.num_patches = num_patches
        self.patch_size = patch_size

        # Initialize Mask Token
        self.mask_token = MaskToken(embed_dim)

        # Decoder
        self.decoder_embed = nn.Linear(embed_dim, embed_dim, bias=True)
        self.decoder_pos_embed = PositionalEncoding3D(embed_dim, num_patches)

        self.decoder = TransformerDecoder(
            depth=decoder_depth,
            embed_dim=embed_dim,
            num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio
        )

        # Output projection: one token -> all voxels of one patch
        self.decoder_pred = nn.Linear(embed_dim, in_channels * patch_size ** 3, bias=True)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor of shape (B, C, D, H, W)
        Returns:
            tuple: (pred_masked, mask)
                - pred_masked (torch.Tensor): Reconstructed masked patches (B, num_mask, C, P, P, P),
                  ordered per sample by increasing patch index.
                - mask (torch.Tensor): Boolean mask indicating masked patches (B, N)
        """
        B = x.shape[0]
        P = self.patch_size

        # Patch embedding + encoder positional encoding
        tokens = self.encoder_pos_embed(self.patch_embed(x))  # (B, N, E)
        N, E = tokens.shape[1], tokens.shape[2]
        num_mask = int(self.mask_ratio * N)

        # Random mask with exactly 'num_mask' patches hidden per sample:
        # the positions of the num_mask largest random scores are masked.
        rand = torch.rand(B, N, device=tokens.device)
        _, topk_indices = rand.topk(num_mask, dim=1)
        mask = torch.zeros(B, N, dtype=torch.bool, device=tokens.device)
        mask.scatter_(1, topk_indices, True)  # (B, N)
        visible = ~mask                       # (B, N)

        # Every row has the same number of visible patches, so the flat boolean
        # selection reshapes cleanly; within a sample the original order is kept.
        x_visible = tokens[visible].view(B, N - num_mask, E)

        # Encoder sees only visible tokens
        encoded = self.encoder(x_visible)       # (B, N - num_mask, E)
        encoded = self.decoder_embed(encoded)   # project into decoder space

        # Rebuild the full-length sequence in ORIGINAL patch order: start from
        # mask tokens everywhere, then write the encoded tokens back to the
        # positions they came from. This keeps the decoder positional embeddings
        # and the `pred[mask]` selection below aligned with real patch positions.
        x_decoder = self.mask_token(B, N, E)    # (B, N, E)
        x_decoder[visible] = encoded.reshape(-1, E).to(x_decoder.dtype)

        # Add positional encoding and decode
        x_decoder = self.decoder_pos_embed(x_decoder)
        decoded = self.decoder(x_decoder)       # (B, N, E)

        # Predictor
        pred = self.decoder_pred(decoded)       # (B, N, C*P^3)

        # Only the masked patches are returned (loss is computed on them)
        pred_masked = pred[mask].view(B, num_mask, self.in_channels, P, P, P)

        return pred_masked, mask

# -----------------------------
# 5. Utility Functions
# -----------------------------

def patchify(x, patch_size):
    """
    Cut a batch of volumes into non-overlapping cubic patches.

    Patches are enumerated depth-first, then height, then width.

    Args:
        x (torch.Tensor): Input tensor of shape (B, C, D, H, W)
        patch_size (int): Size of each patch

    Returns:
        torch.Tensor: Patches of shape (B, N, C, P, P, P)
    """
    batch, channels, _, _, _ = x.shape
    windows = x
    # Window each spatial axis in turn; every unfold appends one patch-local axis.
    for axis in (2, 3, 4):
        windows = windows.unfold(axis, patch_size, patch_size)
    # (B, C, D/P, H/P, W/P, P, P, P) -> (B, C, N, P, P, P)
    stacked = windows.contiguous().view(batch, channels, -1, patch_size, patch_size, patch_size)
    # Swap the channel and patch axes -> (B, N, C, P, P, P)
    return stacked.transpose(1, 2)

def initialize_dataloaders(
    raw_base_dir,
    seg_base_dir,
    excel_dir,
    bbox_names,
    batch_size=2,
    subvolume_size=80,
    test_size=0.2,
    random_state=42,
    num_workers=2  # Reduced to 2 to avoid warnings
):
    """
    Build the synapse dataset once and wrap a random train/validation split of it
    in two DataLoaders.

    Args:
        raw_base_dir (str): Base directory containing raw data.
        seg_base_dir (str): Base directory containing segmentation data.
        excel_dir (str): Directory containing Excel files for each bounding box.
        bbox_names (list): List of bounding box names.
        batch_size (int): Batch size used by both loaders.
        subvolume_size (int): Size of the subvolume.
        test_size (float): Fraction of samples held out for validation.
        random_state (int): Seed for the split, for reproducibility.
        num_workers (int): Number of subprocesses for data loading.

    Returns:
        tuple: (train_loader, val_loader); only the training loader shuffles.
    """
    dataset = ElectronMicroscopyDataset(
        bbox_names=bbox_names,
        raw_base_dir=raw_base_dir,
        seg_base_dir=seg_base_dir,
        excel_dir=excel_dir,
        subvolume_size=subvolume_size
    )

    # Split sample indices rather than the data itself.
    all_indices = list(range(len(dataset)))
    train_indices, val_indices = train_test_split(
        all_indices,
        test_size=test_size,
        random_state=random_state
    )

    # Settings common to both loaders.
    shared_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)

    train_loader = DataLoader(Subset(dataset, train_indices), shuffle=True, **shared_kwargs)
    val_loader = DataLoader(Subset(dataset, val_indices), shuffle=False, **shared_kwargs)

    return train_loader, val_loader

# -----------------------------
# 6. Training Loop
# -----------------------------

def _masked_patch_loss(model, inputs, criterion):
    """Run the model on `inputs` and score its predictions against the true masked patches."""
    preds, mask = model(inputs)  # preds: (B, num_mask, C, P, P, P), mask: (B, N)

    # Ground-truth patches at the masked positions: (B*num_mask, C, P, P, P)
    target_masked = patchify(inputs, model.patch_size)[mask]

    # Merge batch and patch axes; the channel count comes from the prediction itself.
    preds = preds.reshape(-1, *preds.shape[2:])

    return criterion(preds, target_masked)


def train_mae3d(
    model,
    train_loader,
    val_loader,
    optimizer,
    criterion,
    device,
    num_epochs=100,
    scheduler=None
):
    """
    Trains the MAE3D model.

    Each epoch runs one pass over `train_loader`, prints the mean training loss,
    then (if a validation loader is given) one gradient-free pass over
    `val_loader` and prints the mean validation loss. The model draws a fresh
    random mask on every forward pass, so validation loss varies slightly
    between runs.

    Args:
        model (nn.Module): The MAE3D model; must expose `patch_size` and return
            (pred_masked, mask) from forward.
        train_loader (DataLoader): Training dataloader yielding (inputs, target).
        val_loader (DataLoader or None): Validation dataloader; pass None to skip validation.
        optimizer (torch.optim.Optimizer): Optimizer.
        criterion (nn.Module): Loss function.
        device (torch.device): Device to train on.
        num_epochs (int): Number of training epochs.
        scheduler (torch.optim.lr_scheduler, optional): Learning rate scheduler,
            stepped once per epoch.
    """
    for epoch in range(num_epochs):
        # Re-enter train mode every epoch because validation switches to eval.
        model.train()
        epoch_loss = 0.0
        num_batches = 0
        progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for inputs, _ in progress:
            inputs = inputs.to(device)   # (B, C, D, H, W)

            optimizer.zero_grad()
            loss = _masked_patch_loss(model, inputs, criterion)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1
            progress.set_postfix({'Loss': loss.item()})

        # max(..., 1) keeps an empty loader from dividing by zero.
        avg_loss = epoch_loss / max(num_batches, 1)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

        if val_loader is not None:
            model.eval()
            val_loss = 0.0
            val_batches = 0
            with torch.no_grad():
                for inputs, _ in val_loader:
                    inputs = inputs.to(device)
                    val_loss += _masked_patch_loss(model, inputs, criterion).item()
                    val_batches += 1
            model.train()
            if val_batches > 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Val Loss: {val_loss / val_batches:.4f}")

        if scheduler:
            scheduler.step()

    print("Training completed.")

# -----------------------------
# 7. Main Function
# -----------------------------

def main():
    """
    End-to-end training entry point: build the data loaders, construct a small
    MAE3D model, train it and write the weights to 'mae3d_model.pth'.
    """
    # ---- Run configuration -------------------------------------------------
    data_config = dict(
        raw_base_dir="raw",
        seg_base_dir="seg",
        excel_dir="",
        bbox_names=["bbox1", "bbox2", "bbox3"],
        batch_size=1,
        subvolume_size=80,
        test_size=0.2,
        random_state=42,
        num_workers=2,
    )
    model_config = dict(
        in_channels=3,         # 3 channels: Mask1, Raw, Mask2
        patch_size=4,
        embed_dim=768,
        encoder_depth=4,
        encoder_num_heads=4,
        decoder_depth=4,
        decoder_num_heads=4,
        mlp_ratio=4.0,
        mask_ratio=0.75,
    )
    num_epochs = 100
    learning_rate = 1e-4
    checkpoint_path = "mae3d_model.pth"

    # Use the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # ---- Data --------------------------------------------------------------
    train_loader, val_loader = initialize_dataloaders(**data_config)

    # ---- Model, loss, optimiser, schedule ------------------------------------
    model = MAE3D(**model_config).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.05)
    # Divide the learning rate by 10 every 30 epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

    # ---- Training ------------------------------------------------------------
    train_mae3d(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        criterion=criterion,
        device=device,
        num_epochs=num_epochs,
        scheduler=scheduler
    )

    # ---- Persist the trained weights -----------------------------------------
    torch.save(model.state_dict(), checkpoint_path)
    print("Model saved as mae3d_model.pth")


if __name__ == "__main__":
    main()


cuda
Total synapses loaded: 220


Epoch 1/100: 100%|██████████| 176/176 [17:41<00:00,  6.03s/it, Loss=0.175]


Epoch [1/100], Loss: 0.1574


Epoch 2/100: 100%|██████████| 176/176 [17:50<00:00,  6.08s/it, Loss=0.0844]


Epoch [2/100], Loss: 0.1271


Epoch 3/100:  38%|███▊      | 67/176 [06:47<10:57,  6.03s/it, Loss=0.127]