<a href="https://colab.research.google.com/github/alim98/MPI/blob/main/MPI_video_MAE_f.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Doc

https://github.com/MCG-NJU/VideoMAE

colab for finetuning Docs:
https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/video_classification.ipynb


# Essential downloads

## Args

In [1]:
import argparse

def parse_args():
    """
    Parse command-line arguments for configurable paths and training parameters.
    """
    parser = argparse.ArgumentParser(description="VideoMAE Pre-training Script with Segmented Videos")
    parser.add_argument('--raw_base_dir', type=str, default='./raw', help='Path to raw data directory')
    parser.add_argument('--seg_base_dir', type=str, default='./seg', help='Path to segmentation data directory')
    parser.add_argument('--csv_output_dir', type=str, default='csv_outputs', help='Directory to save CSV outputs')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints', help='Directory to save model checkpoints')
    parser.add_argument('--log_dir', type=str, default='logs', help='Directory for TensorBoard logs')
    parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training')
    parser.add_argument('--num_epochs', type=int, default=100, help='Number of training epochs')
    parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate for optimizer')
    parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay for optimizer')
    parser.add_argument('--subvol_size', type=int, default=80, help='Size of the sub-volume to extract')
    parser.add_argument('--num_frames', type=int, default=16, help='Number of frames per video clip')
    parser.add_argument('--mask_ratio', type=float, default=0.50, help='Mask ratio for VideoMAE')
    # Removed the patience argument as it's related to early stopping
    # parser.add_argument('--patience', type=int, default=3, help='Patience for early stopping')
    parser.add_argument('--resume_checkpoint', type=str, default=None, help='Path to resume checkpoint')

    # Optionally add WandB specific arguments
    parser.add_argument('--wandb_project', type=str, default='VideoMAE_PreTraining', help='WandB project name')
    parser.add_argument('--wandb_entity', type=str, default=None, help='WandB entity/team name')
    parser.add_argument('--wandb_run_name', type=str, default=None, help='WandB run name')

    args, _ = parser.parse_known_args()
    return args


In [2]:
!wget -O downloaded_file.zip "https://drive.usercontent.google.com/download?id=1iHPBdBOPEagvPTHZmrN__LD49emXwReY&export=download&authuser=0&confirm=t&uuid=631d60dd-569c-4bb1-a9e8-d681f0ed3d43&at=APvzH3r4me8x_LwP3n8O7lgPo8oK%3A1733988188000"

!pip -q install transformers scikit-learn matplotlib seaborn torch torchvision umap-learn openpyxl imageio

!unzip -q downloaded_file.zip

--2025-01-14 12:57:43--  https://drive.usercontent.google.com/download?id=1iHPBdBOPEagvPTHZmrN__LD49emXwReY&export=download&authuser=0&confirm=t&uuid=631d60dd-569c-4bb1-a9e8-d681f0ed3d43&at=APvzH3r4me8x_LwP3n8O7lgPo8oK%3A1733988188000
Resolving drive.usercontent.google.com (drive.usercontent.google.com)... 142.251.107.132, 2607:f8b0:400c:c32::84
Connecting to drive.usercontent.google.com (drive.usercontent.google.com)|142.251.107.132|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1264688649 (1.2G) [application/octet-stream]
Saving to: ‘downloaded_file.zip’


2025-01-14 12:57:57 (89.8 MB/s) - ‘downloaded_file.zip’ saved [1264688649/1264688649]

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.8/88.8 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.9/56.9 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[?25h

#2

In [None]:
import os
import glob
import argparse
import math
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import imageio.v2 as iio
from transformers import (
    VideoMAEForPreTraining,
    VideoMAEImageProcessor,
)
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
import wandb
from tqdm.auto import tqdm
import multiprocessing
import warnings
from sklearn.model_selection import train_test_split

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")


def parse_args():
    """
    Parse command-line arguments for configurable paths and training parameters.
    """
    parser = argparse.ArgumentParser(description="VideoMAE Pre-training Script with Segmented Videos")
    parser.add_argument('--raw_base_dir', type=str, default='./raw', help='Path to raw data directory')
    parser.add_argument('--seg_base_dir', type=str, default='./seg', help='Path to segmentation data directory')
    parser.add_argument('--csv_output_dir', type=str, default='csv_outputs', help='Directory to save CSV outputs')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints', help='Directory to save model checkpoints')
    parser.add_argument('--log_dir', type=str, default='logs', help='Directory for TensorBoard logs')
    parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training')
    parser.add_argument('--num_epochs', type=int, default=100, help='Number of training epochs')
    parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate for optimizer')
    parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay for optimizer')
    parser.add_argument('--subvol_size', type=int, default=224, help='Spatial size of the sub-volume to extract (aligned with model\'s image_size)')
    parser.add_argument('--num_frames', type=int, default=16, help='Number of frames per video clip')
    parser.add_argument('--mask_ratio', type=float, default=0.75, help='Mask ratio for VideoMAE')
    parser.add_argument('--warmup_epochs', type=int, default=10, help='Number of warmup epochs for learning rate scheduling')
    parser.add_argument('--resume_checkpoint', type=str, default=None, help='Path to resume checkpoint')

    # Optional WandB arguments
    parser.add_argument('--wandb_project', type=str, default='VideoMAE_finetuning', help='WandB project name')
    parser.add_argument('--wandb_entity', type=str, default=None, help='WandB entity/team name')
    parser.add_argument('--wandb_run_name', type=str, default=None, help='WandB run name')

    # Verbosity flag
    parser.add_argument('--verbose', action='store_true', help='Enable verbose logging for debugging')

    args, _ = parser.parse_known_args()
    return args


def load_volumes(bbox_name, raw_base_dir, seg_base_dir):
    """
    Load raw volume and segmentation volume for a bounding box.

    Args:
        bbox_name (str): Name of the bounding box directory.
        raw_base_dir (str): Base directory for raw data.
        seg_base_dir (str): Base directory for segmentation data.

    Returns:
        tuple: (raw_vol, seg_vol) each as np.ndarray
    """
    raw_dir = os.path.join(raw_base_dir, bbox_name)
    seg_dir = os.path.join(seg_base_dir, bbox_name)

    raw_tif_files = sorted(glob.glob(os.path.join(raw_dir, 'slice_*.tif')))
    seg_tif_files = sorted(glob.glob(os.path.join(seg_dir, 'slice_*.tif')))

    if len(raw_tif_files) == 0:
        print(f"No raw files found for {bbox_name} in {raw_dir}")
        return None, None

    if len(seg_tif_files) == 0:
        print(f"No segmentation files found for {bbox_name} in {seg_dir}")
        return None, None

    if len(raw_tif_files) != len(seg_tif_files):
        print(f"Mismatch in number of raw vs seg slices for {bbox_name}. Skipping.")
        return None, None

    try:
        raw_vol = np.stack([iio.imread(f) for f in raw_tif_files], axis=0)  # shape: (Z, Y, X)
        seg_vol = np.stack([iio.imread(f).astype(np.uint32) for f in seg_tif_files], axis=0)
        return raw_vol, seg_vol
    except Exception as e:
        print(f"Error loading volumes for {bbox_name}: {e}")
        return None, None


def create_segmented_cube(
    raw_vol,
    seg_vol,
    central_coord,
    side1_coord,
    side2_coord,
    subvolume_size=224,
    num_frames=16,
    alpha=0.3,
    verbose=False  # Added verbosity control
):
    """
    Constructs a segmented 3D video clip around the specified synapse coordinates
    and overlays both segmentation masks (side1_coord, side2_coord) on the raw data
    with specified transparency for each frame.

    Args:
        raw_vol (np.ndarray): Raw volume data of shape (Z, Y, X).
        seg_vol (np.ndarray): Segmentation volume data of shape (Z, Y, X).
        central_coord (tuple): (x, y, z) coordinates for the central synapse.
        side1_coord (tuple): (x, y, z) coordinates for the first side synapse.
        side2_coord (tuple): (x, y, z) coordinates for the second side synapse.
        subvolume_size (int): Spatial size (height and width) of the sub-volume to extract.
        num_frames (int): Number of frames to extract for the video clip.
        alpha (float): Blending alpha for segmentation masks.
        verbose (bool): Flag to control verbosity.

    Returns:
        np.ndarray: Overlaid video clip of shape (num_frames, subvolume_size, subvolume_size, 3).
    """

    def create_segment_masks(segmentation_volume, s1_coord, s2_coord):
        x1, y1, z1 = s1_coord
        x2, y2, z2 = s2_coord
        # Validate within volume
        if not (0 <= z1 < segmentation_volume.shape[0] and
                0 <= y1 < segmentation_volume.shape[1] and
                0 <= x1 < segmentation_volume.shape[2]):
            raise ValueError("Side1 coordinates are out of bounds.")

        if not (0 <= z2 < segmentation_volume.shape[0] and
                0 <= y2 < segmentation_volume.shape[1] and
                0 <= x2 < segmentation_volume.shape[2]):
            raise ValueError("Side2 coordinates are out of bounds.")

        seg_id_1 = segmentation_volume[z1, y1, x1]
        seg_id_2 = segmentation_volume[z2, y2, x2]

        # If seg_id == 0, it means no segment at that voxel
        if seg_id_1 == 0:
            mask_1 = np.zeros_like(segmentation_volume, dtype=bool)
        else:
            mask_1 = (segmentation_volume == seg_id_1)

        if seg_id_2 == 0:
            mask_2 = np.zeros_like(segmentation_volume, dtype=bool)
        else:
            mask_2 = (segmentation_volume == seg_id_2)

        return mask_1, mask_2

    # Build masks
    mask_1_full, mask_2_full = create_segment_masks(seg_vol, side1_coord, side2_coord)

    # Define subvolume bounds
    half_size = subvolume_size // 2
    cx, cy, cz = central_coord

    x_start, x_end = max(cx - half_size, 0), min(cx + half_size, raw_vol.shape[2])
    y_start, y_end = max(cy - half_size, 0), min(cy + half_size, raw_vol.shape[1])
    z_start, z_end = max(cz - half_size, 0), min(cz + half_size, raw_vol.shape[0])

    # Extract subvolumes
    sub_raw = raw_vol[z_start:z_end, y_start:y_end, x_start:x_end]
    sub_mask_1 = mask_1_full[z_start:z_end, y_start:y_end, x_start:x_end]
    sub_mask_2 = mask_2_full[z_start:z_end, y_start:y_end, x_start:x_end]

    # Debug: Print shapes after extraction
    if verbose:
        print(f"Extracted sub_raw shape: {sub_raw.shape}")
        print(f"Extracted sub_mask_1 shape: {sub_mask_1.shape}")
        print(f"Extracted sub_mask_2 shape: {sub_mask_2.shape}")

    # Calculate required spatial padding
    pad_y = subvolume_size - sub_raw.shape[1]
    pad_x = subvolume_size - sub_raw.shape[2]

    if pad_y > 0 or pad_x > 0:
        pad_y_before = pad_y // 2
        pad_y_after = pad_y - pad_y_before
        pad_x_before = pad_x // 2
        pad_x_after = pad_x - pad_x_before

        if verbose:
            print(f"Applying padding: pad_y_before={pad_y_before}, pad_y_after={pad_y_after}, "
                  f"pad_x_before={pad_x_before}, pad_x_after={pad_x_after}")

        sub_raw = np.pad(
            sub_raw,
            ((0, 0), (pad_y_before, pad_y_after), (pad_x_before, pad_x_after)),
            mode='constant',
            constant_values=0
        )
        sub_mask_1 = np.pad(
            sub_mask_1,
            ((0, 0), (pad_y_before, pad_y_after), (pad_x_before, pad_x_after)),
            mode='constant',
            constant_values=False
        )
        sub_mask_2 = np.pad(
            sub_mask_2,
            ((0, 0), (pad_y_before, pad_y_after), (pad_x_before, pad_x_after)),
            mode='constant',
            constant_values=False
        )

        # Debug: Print shapes after padding
        if verbose:
            print(f"Padded sub_raw shape: {sub_raw.shape}")
            print(f"Padded sub_mask_1 shape: {sub_mask_1.shape}")
            print(f"Padded sub_mask_2 shape: {sub_mask_2.shape}")

    # Verify spatial dimensions after padding
    assert sub_raw.shape[1] == subvolume_size and sub_raw.shape[2] == subvolume_size, \
        f"After padding, spatial dimensions are {sub_raw.shape[1:3]}, expected {(subvolume_size, subvolume_size)}."

    # Handle temporal (frame) padding
    total_available_frames = sub_raw.shape[0]
    if total_available_frames < num_frames:
        # Pad frames by repeating the last frame
        pad_frames = num_frames - total_available_frames
        if verbose:
            print(f"Padding frames: {pad_frames} frames")
        sub_raw = np.pad(
            sub_raw,
            ((0, pad_frames), (0, 0), (0, 0)),
            mode='edge'
        )
        sub_mask_1 = np.pad(
            sub_mask_1,
            ((0, pad_frames), (0, 0), (0, 0)),
            mode='constant',
            constant_values=False
        )
        sub_mask_2 = np.pad(
            sub_mask_2,
            ((0, pad_frames), (0, 0), (0, 0)),
            mode='constant',
            constant_values=False
        )
    else:
        # Trim excess frames
        sub_raw = sub_raw[:num_frames, :, :]
        sub_mask_1 = sub_mask_1[:num_frames, :, :]
        sub_mask_2 = sub_mask_2[:num_frames, :, :]

    # Initialize an array for RGB images
    # Dimensions: (num_frames, subvolume_size, subvolume_size, 3)
    overlaid_video = np.zeros((num_frames, subvolume_size, subvolume_size, 3), dtype=np.uint8)

    for z in range(num_frames):
        # Normalize the raw slice to [0, 1]
        raw_slice = sub_raw[z].astype(np.float32)
        mn, mx = raw_slice.min(), raw_slice.max()
        if mx > mn:
            raw_normalized = (raw_slice - mn) / (mx - mn)
        else:
            raw_normalized = raw_slice - mn  # If all values are equal

        # Scale to 0-255
        raw_scaled = (raw_normalized * 255).astype(np.uint8)

        # Middle channel (G) for the raw image
        overlaid_video[z, :, :, 1] = raw_scaled

        # Extract masks
        mask1 = sub_mask_1[z].astype(np.uint8)
        mask2 = sub_mask_2[z].astype(np.uint8)

        # Overlay segmentation masks with alpha blending
        # Red Channel (R) for mask1
        overlaid_video[z, :, :, 0] = (
            (raw_scaled * mask1 * alpha) + (overlaid_video[z, :, :, 0] * (1 - alpha))
        ).astype(np.uint8)

        # Blue Channel (B) for mask2
        overlaid_video[z, :, :, 2] = (
            (raw_scaled * mask2 * alpha) + (overlaid_video[z, :, :, 2] * (1 - alpha))
        ).astype(np.uint8)

    return overlaid_video  # Shape: (num_frames, H, W, 3)


class VideoMAEDataset(Dataset):
    """
    Dataset class that uses segmented volumes (side1 & side2) for VideoMAE pre-training.
    """
    def __init__(self, vol_data_list, synapse_df, processor, subvol_size=224, num_frames=16, alpha=0.3, verbose=False):
        """
        Args:
            vol_data_list (List[Tuple[np.ndarray, np.ndarray]]): List of (raw_vol, seg_vol).
            synapse_df (pd.DataFrame): DataFrame with synapse coordinates (central, side1, side2).
            processor (VideoMAEImageProcessor): Processor for VideoMAE.
            subvol_size (int): Spatial size of the sub-volume to extract.
            num_frames (int): Number of frames for the model.
            alpha (float): Blending alpha for segmentation.
            verbose (bool): Flag to control verbosity.
        """
        self.vol_data_list = vol_data_list
        self.synapse_df = synapse_df.reset_index(drop=True)
        self.processor = processor
        self.subvol_size = subvol_size
        self.num_frames = num_frames
        self.alpha = alpha
        self.verbose = verbose  # Store verbosity flag

    def __len__(self):
        return len(self.synapse_df)

    def __getitem__(self, idx):
        syn_info = self.synapse_df.iloc[idx]
        bbox_index = syn_info['bbox_index']

        raw_vol, seg_vol = self.vol_data_list[bbox_index]
        if raw_vol is None or seg_vol is None:
            # Return dummy data if volumes not found
            pixel_values = torch.zeros((self.num_frames, 3, self.subvol_size, self.subvol_size), dtype=torch.float32)
            return pixel_values, pixel_values

        # Coordinates
        central_coord = (
            int(syn_info['central_coord_1']),
            int(syn_info['central_coord_2']),
            int(syn_info['central_coord_3'])
        )
        side1_coord = (
            int(syn_info['side_1_coord_1']),
            int(syn_info['side_1_coord_2']),
            int(syn_info['side_1_coord_3'])
        )
        side2_coord = (
            int(syn_info['side_2_coord_1']),
            int(syn_info['side_2_coord_2']),
            int(syn_info['side_2_coord_3'])
        )

        # Create the overlaid segmented video
        overlaid_video = create_segmented_cube(
            raw_vol=raw_vol,
            seg_vol=seg_vol,
            central_coord=central_coord,
            side1_coord=side1_coord,
            side2_coord=side2_coord,
            subvolume_size=self.subvol_size,
            num_frames=self.num_frames,
            alpha=self.alpha,
            verbose=self.verbose  # Pass verbosity flag
        )  # shape: (num_frames, H, W, 3)

        # Process using the VideoMAEImageProcessor
        # The processor expects a list of frames in [H, W, C] format
        frames = [overlaid_video[f] for f in range(self.num_frames)]

        inputs = self.processor(frames, return_tensors="pt")
        pixel_values = inputs["pixel_values"].squeeze(0)  # [num_frames, 3, H, W]
        pixel_values = pixel_values.float()

        # For MAE, target is the same
        return pixel_values, pixel_values


def generate_masked_positions(batch_size, sequence_length, mask_ratio=0.75, device='cuda'):
    """
    Generate a boolean mask indicating which positions are masked and return their indices.

    Returns:
        masks (torch.Tensor): Boolean tensor of shape [batch_size, sequence_length]
        masked_indices (List[torch.Tensor]): List of masked indices for each sample in the batch
    """
    masks = torch.zeros(batch_size, sequence_length, dtype=torch.bool, device=device)
    masked_indices = []
    num_mask = int(mask_ratio * sequence_length)
    for i in range(batch_size):
        mask_idx = torch.randperm(sequence_length, device=device)[:num_mask]
        masks[i, mask_idx] = True
        masked_indices.append(mask_idx)
    return masks, masked_indices


class VideoMAEFineTuner:
    def __init__(
        self,
        model_name="MCG-NJU/videomae-base",
        learning_rate=1e-4,
        weight_decay=1e-2,
        num_epochs=100,
        warmup_epochs=10,
        mask_ratio=0.75,
        device='cuda' if torch.cuda.is_available() else 'cpu'
    ):
        self.device = device
        self.num_epochs = num_epochs
        self.warmup_epochs = warmup_epochs
        self.mask_ratio = mask_ratio

        # Initialize model
        self.model = VideoMAEForPreTraining.from_pretrained(model_name)
        self.model.to(self.device)

        # Print model configuration for debugging
        print(f"Model configuration:\n"
              f"Image Size: {self.model.config.image_size}\n"
              f"Patch Size: {self.model.config.patch_size}\n"
              f"Tubelet Size: {self.model.config.tubelet_size}\n"
              f"Num Frames: {self.model.config.num_frames}\n")

        # Initialize optimizer
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(0.9, 0.95)
        )

        # Initialize schedulers with warmup
        self.total_steps = None  # To be set later when training starts
        self.warmup_steps = None  # To be set later

        self.scheduler = None  # To be set in train()

    def generate_masks(self, batch_size, sequence_length):
        """Generate boolean masks for MAE training"""
        masks, _ = generate_masked_positions(
            batch_size=batch_size,
            sequence_length=sequence_length,
            mask_ratio=self.mask_ratio,
            device=self.device
        )
        return masks

    def train_epoch(self, dataloader, epoch):
        self.model.train()
        total_loss = 0

        progress_bar = tqdm(dataloader, desc=f"Training Epoch {epoch+1}/{self.num_epochs}")
        for batch_idx, (pixel_values, targets) in enumerate(progress_bar):
            pixel_values = pixel_values.to(self.device)
            targets = targets.to(self.device)

            # Generate masks for this batch
            batch_size = pixel_values.size(0)
            sequence_length = (self.model.config.num_frames // self.model.config.tubelet_size) * \
                              (self.model.config.image_size // self.model.config.patch_size) ** 2
            bool_masked_pos, _ = generate_masked_positions(batch_size, sequence_length, self.mask_ratio, self.device)

            # Forward pass
            outputs = self.model(
                pixel_values=pixel_values,
                bool_masked_pos=bool_masked_pos
            )

            loss = outputs.loss

            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if self.scheduler is not None:
                self.scheduler.step()

            # Update progress
            total_loss += loss.item()
            progress_bar.set_postfix({'loss': loss.item()})

            # Log to wandb
            wandb.log({
                'train_batch_loss': loss.item(),
                'learning_rate': self.optimizer.param_groups[0]['lr']
            })

        epoch_loss = total_loss / len(dataloader)
        return epoch_loss

    def validate(self, dataloader, epoch):
        self.model.eval()
        total_loss = 0

        with torch.no_grad():
            progress_bar = tqdm(dataloader, desc="Validation")
            for batch_idx, (pixel_values, targets) in enumerate(progress_bar):
                pixel_values = pixel_values.to(self.device)
                targets = targets.to(self.device)

                # Generate masks
                batch_size = pixel_values.size(0)
                sequence_length = (self.model.config.num_frames // self.model.config.tubelet_size) * \
                                  (self.model.config.image_size // self.model.config.patch_size) ** 2
                bool_masked_pos, _ = generate_masked_positions(batch_size, sequence_length, self.mask_ratio, self.device)

                outputs = self.model(
                    pixel_values=pixel_values,
                    bool_masked_pos=bool_masked_pos
                )

                loss = outputs.loss
                total_loss += loss.item()
                progress_bar.set_postfix({'val_loss': loss.item()})

        val_loss = total_loss / len(dataloader)
        return val_loss

    def train(self, train_dataloader, val_dataloader):
        # Initialize the scheduler now that we know the number of steps
        self.total_steps = len(train_dataloader) * self.num_epochs
        self.warmup_steps = len(train_dataloader) * self.warmup_epochs

        # Define warmup scheduler
        warmup_scheduler = LinearLR(
            self.optimizer,
            start_factor=1e-6,
            end_factor=1.0,
            total_iters=self.warmup_steps
        )

        # Define cosine annealing scheduler
        cosine_scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=self.total_steps - self.warmup_steps,
            eta_min=0
        )

        # Combine schedulers
        self.scheduler = SequentialLR(
            self.optimizer,
            schedulers=[warmup_scheduler, cosine_scheduler],
            milestones=[self.warmup_steps]
        )

        best_val_loss = float('inf')

        for epoch in range(self.num_epochs):
            print(f"\nEpoch {epoch+1}/{self.num_epochs}")

            # Training phase
            train_loss = self.train_epoch(train_dataloader, epoch)

            # Validation phase
            val_loss = self.validate(val_dataloader, epoch)

            # Log metrics
            wandb.log({
                'epoch': epoch + 1,
                'train_epoch_loss': train_loss,
                'val_loss': val_loss
            })

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.model.save_pretrained(os.path.join('best_videomae_model'))
                print(f"Saved best model with validation loss: {val_loss:.4f}")
                wandb.run.summary["best_val_loss"] = best_val_loss
                wandb.save(os.path.join('best_videomae_model', '*'))

            print(f"Epoch {epoch+1} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")


def main():
    args = parse_args()

    wandb.init(
        project=args.wandb_project,
        entity=args.wandb_entity,
        name=args.wandb_run_name,
        config={
            "raw_base_dir": args.raw_base_dir,
            "seg_base_dir": args.seg_base_dir,
            "csv_output_dir": args.csv_output_dir,
            "checkpoint_dir": args.checkpoint_dir,
            "log_dir": args.log_dir,
            "batch_size": args.batch_size,
            "num_epochs": args.num_epochs,
            "learning_rate": args.learning_rate,
            "weight_decay": args.weight_decay,
            "subvol_size": args.subvol_size,
            "num_frames": args.num_frames,
            "mask_ratio": args.mask_ratio,
            "warmup_epochs": args.warmup_epochs,
            "resume_checkpoint": args.resume_checkpoint,
        },
        save_code=True,
    )

    # Initialize processor
    processor_videomae = VideoMAEImageProcessor.from_pretrained("MCG-NJU/videomae-base")

    # Define bounding box names (for demo, using one bbox)
    bbox_names = [f'bbox{i}' for i in range(1, 2)]  # Example: ['bbox1']

    all_vol_data = []
    all_syn_df = []

    for bbox_index, bbox_name in enumerate(bbox_names):
        print(f"Loading data for {bbox_name}...")
        raw_vol, seg_vol = load_volumes(bbox_name, args.raw_base_dir, args.seg_base_dir)
        if raw_vol is None or seg_vol is None:
            print(f"Skipping {bbox_name} due to loading errors.")
            continue

        # Suppose we have an Excel file: bbox1.xlsx, bbox2.xlsx, etc.
        excel_file = f"{bbox_name}.xlsx"
        if not os.path.exists(excel_file):
            print(f"Excel file {excel_file} not found. Skipping {bbox_name}.")
            continue

        syn_df = pd.read_excel(excel_file)

        # Ensure syn_df has the required columns
        required_columns = [
            'central_coord_1', 'central_coord_2', 'central_coord_3',
            'side_1_coord_1', 'side_1_coord_2', 'side_1_coord_3',
            'side_2_coord_1', 'side_2_coord_2', 'side_2_coord_3'
        ]
        if not all(col in syn_df.columns for col in required_columns):
            print(f"Excel file {excel_file} is missing required columns. Skipping {bbox_name}.")
            continue

        # Add the bbox_index to the DataFrame
        syn_df['bbox_index'] = bbox_index

        all_vol_data.append((raw_vol, seg_vol))
        all_syn_df.append(syn_df)

    if not all_syn_df:
        print("No synapse data loaded. Exiting.")
        wandb.finish()
        return

    combined_syn_df = pd.concat(all_syn_df, ignore_index=True)
    print(f"Total synapses loaded: {len(combined_syn_df)}")

    # Split into train/val
    train_syn_df, val_syn_df = train_test_split(combined_syn_df, test_size=0.2, random_state=42)
    print(f"Training synapses: {len(train_syn_df)}, Validation synapses: {len(val_syn_df)}")

    # Build Datasets
    dataset_videomae_train = VideoMAEDataset(
        vol_data_list=all_vol_data,
        synapse_df=train_syn_df,
        processor=processor_videomae,
        subvol_size=args.subvol_size,  # Should align with model's image_size
        num_frames=args.num_frames,
        alpha=0.3,
        verbose=args.verbose  # Set based on the verbosity flag
    )
    dataset_videomae_val = VideoMAEDataset(
        vol_data_list=all_vol_data,
        synapse_df=val_syn_df,
        processor=processor_videomae,
        subvol_size=args.subvol_size,
        num_frames=args.num_frames,
        alpha=0.3,
        verbose=args.verbose  # Set based on the verbosity flag
    )

    # Dynamically set the number of workers based on CPU cores
    num_cpus = multiprocessing.cpu_count()
    num_workers = min(4, num_cpus)  # Adjust as needed
    print(f"Using {num_workers} workers for DataLoader.")

    dataloader_videomae_train = DataLoader(
        dataset_videomae_train,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=num_workers,  # Increased number of workers
        pin_memory=True,
        persistent_workers=True  # Keeps workers alive between epochs for efficiency
    )
    dataloader_videomae_val = DataLoader(
        dataset_videomae_val,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=num_workers,  # Increased number of workers
        pin_memory=True,
        persistent_workers=True  # Keeps workers alive between epochs for efficiency
    )

    # Initialize trainer
    trainer = VideoMAEFineTuner(
        model_name="MCG-NJU/videomae-base",
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        num_epochs=args.num_epochs,
        warmup_epochs=args.warmup_epochs,
        mask_ratio=args.mask_ratio
    )

    # Ensure that dataloaders are defined
    if dataloader_videomae_train is None or dataloader_videomae_val is None:
        raise ValueError("Define your training and validation dataloaders before training.")

    # Train the model
    trainer.train(dataloader_videomae_train, dataloader_videomae_val)

    # Close wandb
    wandb.finish()


if __name__ == "__main__":
    main()


Loading data for bbox1...
Total synapses loaded: 58
Training synapses: 46, Validation synapses: 12
Using 2 workers for DataLoader.
Model configuration:
Image Size: 224
Patch Size: 16
Tubelet Size: 2
Num Frames: 16


Epoch 1/100


Training Epoch 1/100:   0%|          | 0/23 [00:00<?, ?it/s]

Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Saved best model with validation loss: 0.2562
Epoch 1 - Train Loss: 0.2833, Val Loss: 0.2562

Epoch 2/100


Training Epoch 2/100:   0%|          | 0/23 [00:00<?, ?it/s]

Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Saved best model with validation loss: 0.2480
Epoch 2 - Train Loss: 0.2583, Val Loss: 0.2480

Epoch 3/100


Training Epoch 3/100:   0%|          | 0/23 [00:00<?, ?it/s]

# New

# Data

In [None]:
# def load_volumes(bbox_name, raw_base_dir, seg_base_dir):
#     """
#     Load raw volume and segmentation volume for a bounding box.

#     Args:
#         bbox_name (str): Name of the bounding box directory.
#         raw_base_dir (str): Base directory for raw data.
#         seg_base_dir (str): Base directory for segmentation data.

#     Returns:
#         tuple: (raw_vol, seg_vol) each as np.ndarray
#     """
#     raw_dir = os.path.join(raw_base_dir, bbox_name)
#     seg_dir = os.path.join(seg_base_dir, bbox_name)

#     raw_tif_files = sorted(glob.glob(os.path.join(raw_dir, 'slice_*.tif')))
#     seg_tif_files = sorted(glob.glob(os.path.join(seg_dir, 'slice_*.tif')))

#     if len(raw_tif_files) == 0:
#         print(f"No raw files found for {bbox_name} in {raw_dir}")
#         return None, None

#     if len(seg_tif_files) == 0:
#         print(f"No segmentation files found for {bbox_name} in {seg_dir}")
#         return None, None

#     if len(raw_tif_files) != len(seg_tif_files):
#         print(f"Mismatch in number of raw vs seg slices for {bbox_name}. Skipping.")
#         return None, None

#     try:
#         raw_vol = np.stack([iio.imread(f) for f in raw_tif_files], axis=0)  # shape: (Z, Y, X)
#         seg_vol = np.stack([iio.imread(f).astype(np.uint32) for f in seg_tif_files], axis=0)
#         return raw_vol, seg_vol
#     except Exception as e:
#         print(f"Error loading volumes for {bbox_name}: {e}")
#         return None, None


# def create_segmented_cube(
#     raw_vol,
#     seg_vol,
#     central_coord,
#     side1_coord,
#     side2_coord,
#     subvolume_size=80,
#     alpha=0.3
# ):
#     """
#     Constructs an 80x80x80 segmented 3D cube around the specified synapse coordinates
#     and overlays both segmentation masks (side1_coord, side2_coord) on the raw data
#     with specified transparency for each slice.

#     Returns:
#         np.ndarray: Overlaid cube of shape (height, width, 3, depth),
#                     i.e., (80, 80, 3, 80) if subvolume_size=80.
#     """

#     def create_segment_masks(segmentation_volume, s1_coord, s2_coord):
#         x1, y1, z1 = s1_coord
#         x2, y2, z2 = s2_coord
#         # Validate within volume
#         if not (0 <= z1 < segmentation_volume.shape[0] and
#                 0 <= y1 < segmentation_volume.shape[1] and
#                 0 <= x1 < segmentation_volume.shape[2]):
#             raise ValueError("Side1 coordinates are out of bounds.")

#         if not (0 <= z2 < segmentation_volume.shape[0] and
#                 0 <= y2 < segmentation_volume.shape[1] and
#                 0 <= x2 < segmentation_volume.shape[2]):
#             raise ValueError("Side2 coordinates are out of bounds.")

#         seg_id_1 = segmentation_volume[z1, y1, x1]
#         seg_id_2 = segmentation_volume[z2, y2, x2]

#         # If seg_id == 0, it means no segment at that voxel
#         if seg_id_1 == 0:
#             mask_1 = np.zeros_like(segmentation_volume, dtype=bool)
#         else:
#             mask_1 = (segmentation_volume == seg_id_1)

#         if seg_id_2 == 0:
#             mask_2 = np.zeros_like(segmentation_volume, dtype=bool)
#         else:
#             mask_2 = (segmentation_volume == seg_id_2)

#         return mask_1, mask_2

#     # Build masks
#     mask_1_full, mask_2_full = create_segment_masks(seg_vol, side1_coord, side2_coord)

#     # Define subvolume bounds
#     half_size = subvolume_size // 2
#     cx, cy, cz = central_coord

#     x_start, x_end = max(cx - half_size, 0), min(cx + half_size, raw_vol.shape[2])
#     y_start, y_end = max(cy - half_size, 0), min(cy + half_size, raw_vol.shape[1])
#     z_start, z_end = max(cz - half_size, 0), min(cz + half_size, raw_vol.shape[0])

#     # Extract subvolumes
#     sub_raw = raw_vol[z_start:z_end, y_start:y_end, x_start:x_end]
#     sub_mask_1 = mask_1_full[z_start:z_end, y_start:y_end, x_start:x_end]
#     sub_mask_2 = mask_2_full[z_start:z_end, y_start:y_end, x_start:x_end]

#     # Pad if smaller than subvolume_size
#     pad_z = subvolume_size - sub_raw.shape[0]
#     pad_y = subvolume_size - sub_raw.shape[1]
#     pad_x = subvolume_size - sub_raw.shape[2]

#     if pad_z > 0 or pad_y > 0 or pad_x > 0:
#         sub_raw = np.pad(sub_raw, ((0, pad_z), (0, pad_y), (0, pad_x)),
#                          mode='constant', constant_values=0)
#         sub_mask_1 = np.pad(sub_mask_1, ((0, pad_z), (0, pad_y), (0, pad_x)),
#                             mode='constant', constant_values=False)
#         sub_mask_2 = np.pad(sub_mask_2, ((0, pad_z), (0, pad_y), (0, pad_x)),
#                             mode='constant', constant_values=False)

#     # Slice to exact shape
#     sub_raw = sub_raw[:subvolume_size, :subvolume_size, :subvolume_size]
#     sub_mask_1 = sub_mask_1[:subvolume_size, :subvolume_size, :subvolume_size]
#     sub_mask_2 = sub_mask_2[:subvolume_size, :subvolume_size, :subvolume_size]

#     # ایجاد آرایه برای تصاویر سه‌کاناله
#     # ابعاد: (ارتفاع، عرض، کانال‌ها، عمق)
#     overlaid_cube = np.zeros((subvolume_size, subvolume_size, 3, subvolume_size), dtype=np.uint8)

#     for z in range(subvolume_size):
#         # نرمال‌سازی برش اصلی به محدوده [0, 1]
#         raw_slice = sub_raw[z].astype(np.float32)
#         mn, mx = raw_slice.min(), raw_slice.max()
#         if mx > mn:
#             raw_normalized = (raw_slice - mn) / (mx - mn)
#         else:
#             raw_normalized = raw_slice - mn  # در صورتی که همه مقادیر برابر باشند

#         # تبدیل به مقیاس 0-255
#         raw_scaled = (raw_normalized * 255).astype(np.uint8)

#         # کانال میانی (کانال 1) عکس اصلی
#         overlaid_cube[:, :, 1, z] = raw_scaled

#         # استخراج ماسک‌ها
#         mask1 = sub_mask_1[z].astype(np.uint8)
#         mask2 = sub_mask_2[z].astype(np.uint8)

#         # کانال 0: فقط قسمت سگمنت شده با mask_1
#         overlaid_cube[:, :, 0, z] = raw_scaled * mask1

#         # کانال 2: فقط قسمت سگمنت شده با mask_2
#         overlaid_cube[:, :, 2, z] = raw_scaled * mask2
#     return overlaid_cube


# class VideoMAEDataset(Dataset):
#     """
#     Dataset class that uses segmented volumes (side1 & side2) for VideoMAE pre-training.
#     """
#     def __init__(self, vol_data_list, synapse_df, processor, subvol_size=80, num_frames=16, alpha=0.3):
#         """
#         Args:
#             vol_data_list (List[Tuple[np.ndarray, np.ndarray]]): List of (raw_vol, seg_vol).
#             synapse_df (pd.DataFrame): DataFrame with synapse coordinates (central, side1, side2).
#             processor (VideoMAEImageProcessor): Processor for VideoMAE.
#             subvol_size (int): Size of the sub-volume to extract.
#             num_frames (int): Number of frames for the model.
#             alpha (float): Blending alpha for segmentation.
#         """
#         self.vol_data_list = vol_data_list
#         self.synapse_df = synapse_df.reset_index(drop=True)
#         self.processor = processor
#         self.subvol_size = subvol_size
#         self.num_frames = num_frames
#         self.alpha = alpha

#     def __len__(self):
#         return len(self.synapse_df)

#     def __getitem__(self, idx):
#         syn_info = self.synapse_df.iloc[idx]
#         bbox_index = syn_info['bbox_index']

#         raw_vol, seg_vol = self.vol_data_list[bbox_index]
#         if raw_vol is None or seg_vol is None:
#             # Return dummy data if volumes not found
#             pixel_values = torch.zeros((self.num_frames, 3, self.subvol_size, self.subvol_size), dtype=torch.float32)
#             return pixel_values, pixel_values

#         # Coordinates
#         central_coord = (
#             int(syn_info['central_coord_1']),
#             int(syn_info['central_coord_2']),
#             int(syn_info['central_coord_3'])
#         )
#         side1_coord = (
#             int(syn_info['side_1_coord_1']),
#             int(syn_info['side_1_coord_2']),
#             int(syn_info['side_1_coord_3'])
#         )
#         side2_coord = (
#             int(syn_info['side_2_coord_1']),
#             int(syn_info['side_2_coord_2']),
#             int(syn_info['side_2_coord_3'])
#         )

#         # Create the overlaid segmented cube
#         overlaid_cube = create_segmented_cube(
#             raw_vol=raw_vol,
#             seg_vol=seg_vol,
#             central_coord=central_coord,
#             side1_coord=side1_coord,
#             side2_coord=side2_coord,
#             subvolume_size=self.subvol_size,
#             alpha=self.alpha
#         )  # shape: (80, 80, 3, 80)

#         # We interpret the last dimension (depth) as frames
#         frames = []
#         for z in range(overlaid_cube.shape[3]):  # 80 slices
#             frame_rgb = overlaid_cube[..., z]  # (80, 80, 3)
#             frames.append(frame_rgb)

#         # Now reduce or expand to self.num_frames
#         total_slices = len(frames)  # 80
#         if total_slices < self.num_frames:
#             while len(frames) < self.num_frames:
#                 frames.append(frames[-1])
#         elif total_slices > self.num_frames:
#             indices = np.linspace(0, total_slices - 1, self.num_frames, dtype=int)
#             frames = [frames[i] for i in indices]

#         # Process using the VideoMAEImageProcessor
#         inputs = self.processor(frames, return_tensors="pt")
#         pixel_values = inputs["pixel_values"].squeeze(0)  # (num_frames, 3, H, W)
#         pixel_values = pixel_values.float()

#         # For MAE, target is the same
#         return pixel_values, pixel_values


# def generate_masked_positions(batch_size, sequence_length, mask_ratio=0.75, device='cuda'):
#     """
#     Generate a boolean mask indicating which positions are masked.
#     """
#     masks = torch.zeros(batch_size, sequence_length, dtype=torch.bool, device=device)
#     num_mask = int(mask_ratio * sequence_length)
#     for i in range(batch_size):
#         mask_indices = torch.randperm(sequence_length, device=device)[:num_mask]
#         masks[i, mask_indices] = True
#     return masks

# # For demonstration, let's assume we have bounding box names and their corresponding Excel files
# bbox_names = [f'bbox{i}' for i in range(1, 2)]  # fewer bboxes for a short demo
# all_vol_data = []
# all_syn_df = []

# for bbox_index, bbox_name in enumerate(bbox_names):
#     print(f"Loading data for {bbox_name}...")
#     raw_vol, seg_vol = load_volumes(bbox_name, args.raw_base_dir, args.seg_base_dir)
#     if raw_vol is None or seg_vol is None:
#         print(f"Skipping {bbox_name} due to loading errors.")
#         continue

#     # Suppose we have an Excel file: bbox1.xlsx, bbox2.xlsx, etc.
#     excel_file = f"{bbox_name}.xlsx"
#     if not os.path.exists(excel_file):
#         print(f"Excel file {excel_file} not found. Skipping {bbox_name}.")
#         continue

#     syn_df = pd.read_excel(excel_file)

#     # We assume syn_df has columns: central_coord_1/2/3, side_1_coord_1/2/3, side_2_coord_1/2/3, etc.
#     # We'll just add the bbox_index:
#     syn_df['bbox_index'] = bbox_index

#     all_vol_data.append((raw_vol, seg_vol))
#     all_syn_df.append(syn_df)

# # if not all_syn_df:
# #     print("No synapse data loaded. Exiting.")
# #     wandb.finish()
# #     return

# combined_syn_df = pd.concat(all_syn_df, ignore_index=True)
# print(f"Total synapses loaded: {len(combined_syn_df)}")

# # Split into train/val
# train_syn_df, val_syn_df = train_test_split(combined_syn_df, test_size=0.2, random_state=42)
# print(f"Training synapses: {len(train_syn_df)}, Validation synapses: {len(val_syn_df)}")

# # Build Datasets
# dataset_videomae_train = VideoMAEDataset(
#     vol_data_list=all_vol_data,
#     synapse_df=train_syn_df,
#     processor=processor_videomae,
#     subvol_size=args.subvol_size,
#     num_frames=args.num_frames,
#     alpha=0.3
# )
# dataset_videomae_val = VideoMAEDataset(
#     vol_data_list=all_vol_data,
#     synapse_df=val_syn_df,
#     processor=processor_videomae,
#     subvol_size=args.subvol_size,
#     num_frames=args.num_frames,
#     alpha=0.3
# )

# num_workers = min(4, multiprocessing.cpu_count())
# print(f"Using {num_workers} workers for DataLoader.")

# dataloader_videomae_train = DataLoader(
#     dataset_videomae_train,
#     batch_size=args.batch_size,
#     shuffle=True,
#     num_workers=num_workers,
#     pin_memory=True
# )
# dataloader_videomae_val = DataLoader(
#     dataset_videomae_val,
#     batch_size=args.batch_size,
#     shuffle=False,
#     num_workers=num_workers,
#     pin_memory=True
# )

Loading data for bbox1...
Total synapses loaded: 58
Training synapses: 46, Validation synapses: 12
Using 2 workers for DataLoader.


In [None]:
# from transformers import VideoMAEForVideoClassification, AutoImageProcessor

# # Load the pretrained Video MAE model
# model = VideoMAEForVideoClassification.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics", attn_implementation="sdpa", torch_dtype=torch.float16)
# processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")

# # Define the processor (preprocessing pipeline) for the model
# # processor = VideoMAEProcessor.from_pretrained("facebook/videomae-base")


config.json:   0%|          | 0.00/22.9k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

# Works

In [None]:
# import torch
# import torch.nn as nn
# from transformers import VideoMAEForPreTraining, VideoMAEConfig
# from torch.optim import AdamW
# from torch.optim.lr_scheduler import CosineAnnealingLR
# import wandb
# import numpy as np
# from tqdm.auto import tqdm

# class VideoMAEFineTuner:
#     def __init__(
#         self,
#         model_name="MCG-NJU/videomae-base",
#         learning_rate=1e-4,
#         weight_decay=0.05,
#         num_epochs=100,
#         warmup_epochs=10,
#         mask_ratio=0.75,
#         device='cuda' if torch.cuda.is_available() else 'cpu'
#     ):
#         self.device = device
#         self.num_epochs = num_epochs
#         self.warmup_epochs = warmup_epochs
#         self.mask_ratio = mask_ratio

#         # Initialize model
#         self.model = VideoMAEForPreTraining.from_pretrained(model_name)
#         self.model.to(self.device)

#         # Initialize optimizer and scheduler
#         self.optimizer = AdamW(
#             self.model.parameters(),
#             lr=learning_rate,
#             weight_decay=weight_decay,
#             betas=(0.9, 0.95)
#         )

#         # We'll set the scheduler later when we know the number of steps
#         self.scheduler = None

#     def generate_masks(self, batch_size, sequence_length):
#         """Generate boolean masks for MAE training"""
#         return generate_masked_positions(
#             batch_size=batch_size,
#             sequence_length=sequence_length,
#             mask_ratio=self.mask_ratio,
#             device=self.device
#         )

#     def train_epoch(self, dataloader, epoch):
#         self.model.train()
#         total_loss = 0

#         progress_bar = tqdm(dataloader, desc=f"Training Epoch {epoch}")
#         for batch_idx, (pixel_values, targets) in enumerate(progress_bar):
#             pixel_values = pixel_values.to(self.device)
#             targets = targets.to(self.device)

#             # Generate masks for this batch
#             batch_size = pixel_values.size(0)
#             sequence_length = (self.model.config.num_frames //
#                              self.model.config.tubelet_size) * \
#                             (self.model.config.image_size //
#                              self.model.config.patch_size) ** 2
#             bool_masked_pos = self.generate_masks(batch_size, sequence_length)

#             # Forward pass
#             outputs = self.model(
#                 pixel_values=pixel_values,
#                 bool_masked_pos=bool_masked_pos
#             )

#             loss = outputs.loss

#             # Backward pass
#             self.optimizer.zero_grad()
#             loss.backward()
#             self.optimizer.step()

#             if self.scheduler is not None:
#                 self.scheduler.step()

#             # Update progress
#             total_loss += loss.item()
#             progress_bar.set_postfix({'loss': loss.item()})

#             # Log to wandb
#             wandb.log({
#                 'train_batch_loss': loss.item(),
#                 'learning_rate': self.optimizer.param_groups[0]['lr']
#             })

#         epoch_loss = total_loss / len(dataloader)
#         return epoch_loss

#     def validate(self, dataloader):
#         self.model.eval()
#         total_loss = 0

#         with torch.no_grad():
#             progress_bar = tqdm(dataloader, desc="Validation")
#             for pixel_values, targets in progress_bar:
#                 pixel_values = pixel_values.to(self.device)
#                 targets = targets.to(self.device)

#                 # Generate masks
#                 batch_size = pixel_values.size(0)
#                 sequence_length = (self.model.config.num_frames //
#                                  self.model.config.tubelet_size) * \
#                                 (self.model.config.image_size //
#                                  self.model.config.patch_size) ** 2
#                 bool_masked_pos = self.generate_masks(batch_size, sequence_length)

#                 outputs = self.model(
#                     pixel_values=pixel_values,
#                     bool_masked_pos=bool_masked_pos
#                 )

#                 loss = outputs.loss
#                 total_loss += loss.item()
#                 progress_bar.set_postfix({'val_loss': loss.item()})

#         val_loss = total_loss / len(dataloader)
#         return val_loss

#     def train(self, train_dataloader, val_dataloader):
#         # Initialize the scheduler now that we know the number of steps
#         total_steps = len(train_dataloader) * self.num_epochs
#         warmup_steps = len(train_dataloader) * self.warmup_epochs

#         self.scheduler = CosineAnnealingLR(
#             self.optimizer,
#             T_max=total_steps - warmup_steps,
#             eta_min=0
#         )

#         best_val_loss = float('inf')

#         for epoch in range(self.num_epochs):
#             print(f"\nEpoch {epoch+1}/{self.num_epochs}")

#             # Training phase
#             train_loss = self.train_epoch(train_dataloader, epoch)

#             # Validation phase
#             val_loss = self.validate(val_dataloader)

#             # Log metrics
#             wandb.log({
#                 'epoch': epoch,
#                 'train_epoch_loss': train_loss,
#                 'val_loss': val_loss
#             })

#             # Save best model
#             if val_loss < best_val_loss:
#                 best_val_loss = val_loss
#                 self.model.save_pretrained('best_videomae_model')
#                 print(f"Saved best model with validation loss: {val_loss:.4f}")

#             print(f"Epoch {epoch+1} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

# def main():
#     # Initialize wandb
#     wandb.init(
#         project="videomae-finetuning",
#         config={
#             "learning_rate": 1e-4,
#             "weight_decay": 0.05,
#             "num_epochs": 100,
#             "warmup_epochs": 10,
#             "mask_ratio": 0.75,
#             "batch_size": args.batch_size,
#             "model": "MCG-NJU/videomae-base"
#         }
#     )

#     # Initialize trainer
#     trainer = VideoMAEFineTuner(
#         model_name="MCG-NJU/videomae-base",
#         learning_rate=1e-4,
#         weight_decay=0.05,
#         num_epochs=100,
#         warmup_epochs=10,
#         mask_ratio=0.75
#     )

#     # Train the model
#     trainer.train(dataloader_videomae_train, dataloader_videomae_val)

#     # Close wandb
#     wandb.finish()

# if __name__ == "__main__":
#     main()


Epoch 1/100


Training Epoch 0:   0%|          | 0/23 [00:00<?, ?it/s]

Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Saved best model with validation loss: 0.3429
Epoch 1 - Train Loss: 0.3683, Val Loss: 0.3429

Epoch 2/100


Training Epoch 1:   0%|          | 0/23 [00:00<?, ?it/s]

Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Saved best model with validation loss: 0.3339
Epoch 2 - Train Loss: 0.3393, Val Loss: 0.3339

Epoch 3/100


Training Epoch 2:   0%|          | 0/23 [00:00<?, ?it/s]

Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Saved best model with validation loss: 0.3298
Epoch 3 - Train Loss: 0.3348, Val Loss: 0.3298

Epoch 4/100


Training Epoch 3:   0%|          | 0/23 [00:00<?, ?it/s]

Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 4 - Train Loss: 0.3311, Val Loss: 0.3342

Epoch 5/100


Training Epoch 4:   0%|          | 0/23 [00:00<?, ?it/s]

Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 5 - Train Loss: 0.3303, Val Loss: 0.3307

Epoch 6/100


Training Epoch 5:   0%|          | 0/23 [00:00<?, ?it/s]

Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 6 - Train Loss: 0.3258, Val Loss: 0.3302

Epoch 7/100


Training Epoch 6:   0%|          | 0/23 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c89af8e68c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7c89af8e68c0>
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
Traceback (most recent call last):
    if w.is_alive():  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__

    self._shutdown_workers()  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    assert self._parent_pid == os.getpid(), 'can only test a child process'if w.is_alive():

  File "/usr/lib/python3.10/multiprocessing/process.py", line 1

Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Saved best model with validation loss: 0.3266
Epoch 7 - Train Loss: 0.3224, Val Loss: 0.3266

Epoch 8/100


Training Epoch 7:   0%|          | 0/23 [00:00<?, ?it/s]

Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 8 - Train Loss: 0.3209, Val Loss: 0.3274

Epoch 9/100


Training Epoch 8:   0%|          | 0/23 [00:00<?, ?it/s]

Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Saved best model with validation loss: 0.3253
Epoch 9 - Train Loss: 0.3187, Val Loss: 0.3253

Epoch 10/100


Training Epoch 9:   0%|          | 0/23 [00:00<?, ?it/s]

Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Saved best model with validation loss: 0.3237
Epoch 10 - Train Loss: 0.3179, Val Loss: 0.3237

Epoch 11/100


Training Epoch 10:   0%|          | 0/23 [00:00<?, ?it/s]

Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 11 - Train Loss: 0.3174, Val Loss: 0.3263

Epoch 12/100


Training Epoch 11:   0%|          | 0/23 [00:00<?, ?it/s]

Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 12 - Train Loss: 0.3160, Val Loss: 0.3281

Epoch 13/100


Training Epoch 12:   0%|          | 0/23 [00:00<?, ?it/s]

Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 13 - Train Loss: 0.3129, Val Loss: 0.3263

Epoch 14/100


Training Epoch 13:   0%|          | 0/23 [00:00<?, ?it/s]

# 3 chan without early stop

In [None]:
# - Channel 0: Segmented part with mask_1 in grayscale
# - Channel 1: Original image in grayscale
# - Channel 2: Segmented part with mask_2 in grayscale
import os
import glob
import argparse
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import imageio.v2 as iio
from transformers import (
    VideoMAEForPreTraining,
    VideoMAEImageProcessor,
    get_cosine_schedule_with_warmup,
)
from sklearn.model_selection import train_test_split
import umap.umap_ as umap
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from tqdm import tqdm
import warnings
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast
from collections import deque
import time
import multiprocessing
import imageio
import wandb
import io  # For in-memory file handling
import matplotlib.pyplot as plt

warnings.filterwarnings("ignore", category=UserWarning, module="torch.utils.data.dataloader")




# Data

In [None]:

def load_volumes(bbox_name, raw_base_dir, seg_base_dir):
    """
    Load raw volume and segmentation volume for a bounding box.

    Args:
        bbox_name (str): Name of the bounding box directory.
        raw_base_dir (str): Base directory for raw data.
        seg_base_dir (str): Base directory for segmentation data.

    Returns:
        tuple: (raw_vol, seg_vol) each as np.ndarray
    """
    raw_dir = os.path.join(raw_base_dir, bbox_name)
    seg_dir = os.path.join(seg_base_dir, bbox_name)

    raw_tif_files = sorted(glob.glob(os.path.join(raw_dir, 'slice_*.tif')))
    seg_tif_files = sorted(glob.glob(os.path.join(seg_dir, 'slice_*.tif')))

    if len(raw_tif_files) == 0:
        print(f"No raw files found for {bbox_name} in {raw_dir}")
        return None, None

    if len(seg_tif_files) == 0:
        print(f"No segmentation files found for {bbox_name} in {seg_dir}")
        return None, None

    if len(raw_tif_files) != len(seg_tif_files):
        print(f"Mismatch in number of raw vs seg slices for {bbox_name}. Skipping.")
        return None, None

    try:
        raw_vol = np.stack([iio.imread(f) for f in raw_tif_files], axis=0)  # shape: (Z, Y, X)
        seg_vol = np.stack([iio.imread(f).astype(np.uint32) for f in seg_tif_files], axis=0)
        return raw_vol, seg_vol
    except Exception as e:
        print(f"Error loading volumes for {bbox_name}: {e}")
        return None, None


def create_segmented_cube(
    raw_vol,
    seg_vol,
    central_coord,
    side1_coord,
    side2_coord,
    subvolume_size=80,
    alpha=0.3
):
    """
    Constructs an 80x80x80 segmented 3D cube around the specified synapse coordinates
    and overlays both segmentation masks (side1_coord, side2_coord) on the raw data
    with specified transparency for each slice.

    Returns:
        np.ndarray: Overlaid cube of shape (height, width, 3, depth),
                    i.e., (80, 80, 3, 80) if subvolume_size=80.
    """

    def create_segment_masks(segmentation_volume, s1_coord, s2_coord):
        x1, y1, z1 = s1_coord
        x2, y2, z2 = s2_coord
        # Validate within volume
        if not (0 <= z1 < segmentation_volume.shape[0] and
                0 <= y1 < segmentation_volume.shape[1] and
                0 <= x1 < segmentation_volume.shape[2]):
            raise ValueError("Side1 coordinates are out of bounds.")

        if not (0 <= z2 < segmentation_volume.shape[0] and
                0 <= y2 < segmentation_volume.shape[1] and
                0 <= x2 < segmentation_volume.shape[2]):
            raise ValueError("Side2 coordinates are out of bounds.")

        seg_id_1 = segmentation_volume[z1, y1, x1]
        seg_id_2 = segmentation_volume[z2, y2, x2]

        # If seg_id == 0, it means no segment at that voxel
        if seg_id_1 == 0:
            mask_1 = np.zeros_like(segmentation_volume, dtype=bool)
        else:
            mask_1 = (segmentation_volume == seg_id_1)

        if seg_id_2 == 0:
            mask_2 = np.zeros_like(segmentation_volume, dtype=bool)
        else:
            mask_2 = (segmentation_volume == seg_id_2)

        return mask_1, mask_2

    # Build masks
    mask_1_full, mask_2_full = create_segment_masks(seg_vol, side1_coord, side2_coord)

    # Define subvolume bounds
    half_size = subvolume_size // 2
    cx, cy, cz = central_coord

    x_start, x_end = max(cx - half_size, 0), min(cx + half_size, raw_vol.shape[2])
    y_start, y_end = max(cy - half_size, 0), min(cy + half_size, raw_vol.shape[1])
    z_start, z_end = max(cz - half_size, 0), min(cz + half_size, raw_vol.shape[0])

    # Extract subvolumes
    sub_raw = raw_vol[z_start:z_end, y_start:y_end, x_start:x_end]
    sub_mask_1 = mask_1_full[z_start:z_end, y_start:y_end, x_start:x_end]
    sub_mask_2 = mask_2_full[z_start:z_end, y_start:y_end, x_start:x_end]

    # Pad if smaller than subvolume_size
    pad_z = subvolume_size - sub_raw.shape[0]
    pad_y = subvolume_size - sub_raw.shape[1]
    pad_x = subvolume_size - sub_raw.shape[2]

    if pad_z > 0 or pad_y > 0 or pad_x > 0:
        sub_raw = np.pad(sub_raw, ((0, pad_z), (0, pad_y), (0, pad_x)),
                         mode='constant', constant_values=0)
        sub_mask_1 = np.pad(sub_mask_1, ((0, pad_z), (0, pad_y), (0, pad_x)),
                            mode='constant', constant_values=False)
        sub_mask_2 = np.pad(sub_mask_2, ((0, pad_z), (0, pad_y), (0, pad_x)),
                            mode='constant', constant_values=False)

    # Slice to exact shape
    sub_raw = sub_raw[:subvolume_size, :subvolume_size, :subvolume_size]
    sub_mask_1 = sub_mask_1[:subvolume_size, :subvolume_size, :subvolume_size]
    sub_mask_2 = sub_mask_2[:subvolume_size, :subvolume_size, :subvolume_size]

    # ایجاد آرایه برای تصاویر سه‌کاناله
    # ابعاد: (ارتفاع، عرض، کانال‌ها، عمق)
    overlaid_cube = np.zeros((subvolume_size, subvolume_size, 3, subvolume_size), dtype=np.uint8)

    for z in range(subvolume_size):
        # نرمال‌سازی برش اصلی به محدوده [0, 1]
        raw_slice = sub_raw[z].astype(np.float32)
        mn, mx = raw_slice.min(), raw_slice.max()
        if mx > mn:
            raw_normalized = (raw_slice - mn) / (mx - mn)
        else:
            raw_normalized = raw_slice - mn  # در صورتی که همه مقادیر برابر باشند

        # تبدیل به مقیاس 0-255
        raw_scaled = (raw_normalized * 255).astype(np.uint8)

        # کانال میانی (کانال 1) عکس اصلی
        overlaid_cube[:, :, 1, z] = raw_scaled

        # استخراج ماسک‌ها
        mask1 = sub_mask_1[z].astype(np.uint8)
        mask2 = sub_mask_2[z].astype(np.uint8)

        # کانال 0: فقط قسمت سگمنت شده با mask_1
        overlaid_cube[:, :, 0, z] = raw_scaled * mask1

        # کانال 2: فقط قسمت سگمنت شده با mask_2
        overlaid_cube[:, :, 2, z] = raw_scaled * mask2
    return overlaid_cube


class VideoMAEDataset(Dataset):
    """
    Dataset class that uses segmented volumes (side1 & side2) for VideoMAE pre-training.
    """
    def __init__(self, vol_data_list, synapse_df, processor, subvol_size=80, num_frames=16, alpha=0.3):
        """
        Args:
            vol_data_list (List[Tuple[np.ndarray, np.ndarray]]): List of (raw_vol, seg_vol).
            synapse_df (pd.DataFrame): DataFrame with synapse coordinates (central, side1, side2).
            processor (VideoMAEImageProcessor): Processor for VideoMAE.
            subvol_size (int): Size of the sub-volume to extract.
            num_frames (int): Number of frames for the model.
            alpha (float): Blending alpha for segmentation.
        """
        self.vol_data_list = vol_data_list
        self.synapse_df = synapse_df.reset_index(drop=True)
        self.processor = processor
        self.subvol_size = subvol_size
        self.num_frames = num_frames
        self.alpha = alpha

    def __len__(self):
        return len(self.synapse_df)

    def __getitem__(self, idx):
        syn_info = self.synapse_df.iloc[idx]
        bbox_index = syn_info['bbox_index']

        raw_vol, seg_vol = self.vol_data_list[bbox_index]
        if raw_vol is None or seg_vol is None:
            # Return dummy data if volumes not found
            pixel_values = torch.zeros((self.num_frames, 3, self.subvol_size, self.subvol_size), dtype=torch.float32)
            return pixel_values, pixel_values

        # Coordinates
        central_coord = (
            int(syn_info['central_coord_1']),
            int(syn_info['central_coord_2']),
            int(syn_info['central_coord_3'])
        )
        side1_coord = (
            int(syn_info['side_1_coord_1']),
            int(syn_info['side_1_coord_2']),
            int(syn_info['side_1_coord_3'])
        )
        side2_coord = (
            int(syn_info['side_2_coord_1']),
            int(syn_info['side_2_coord_2']),
            int(syn_info['side_2_coord_3'])
        )

        # Create the overlaid segmented cube
        overlaid_cube = create_segmented_cube(
            raw_vol=raw_vol,
            seg_vol=seg_vol,
            central_coord=central_coord,
            side1_coord=side1_coord,
            side2_coord=side2_coord,
            subvolume_size=self.subvol_size,
            alpha=self.alpha
        )  # shape: (80, 80, 3, 80)

        # We interpret the last dimension (depth) as frames
        frames = []
        for z in range(overlaid_cube.shape[3]):  # 80 slices
            frame_rgb = overlaid_cube[..., z]  # (80, 80, 3)
            frames.append(frame_rgb)

        # Now reduce or expand to self.num_frames
        total_slices = len(frames)  # 80
        if total_slices < self.num_frames:
            while len(frames) < self.num_frames:
                frames.append(frames[-1])
        elif total_slices > self.num_frames:
            indices = np.linspace(0, total_slices - 1, self.num_frames, dtype=int)
            frames = [frames[i] for i in indices]

        # Process using the VideoMAEImageProcessor
        inputs = self.processor(frames, return_tensors="pt")
        pixel_values = inputs["pixel_values"].squeeze(0)  # (num_frames, 3, H, W)
        pixel_values = pixel_values.float()

        # For MAE, target is the same
        return pixel_values, pixel_values


def generate_masked_positions(batch_size, sequence_length, mask_ratio=0.75, device='cuda'):
    """
    Generate a boolean mask indicating which positions are masked.
    """
    masks = torch.zeros(batch_size, sequence_length, dtype=torch.bool, device=device)
    num_mask = int(mask_ratio * sequence_length)
    for i in range(batch_size):
        mask_indices = torch.randperm(sequence_length, device=device)[:num_mask]
        masks[i, mask_indices] = True
    return masks


def log_input_gifs(pixel_values, epoch, prefix="Training"):
    """
    Convert input pixel values to GIFs and log them to WandB.
    """
    # Ensure pixel_values is on CPU and convert to numpy
    pixel_values = pixel_values.cpu().numpy()

    gifs = []
    for i in range(min(2, pixel_values.shape[0])):  # just take up to 2 samples for logging
        frames = pixel_values[i]  # shape: (num_frames, 3, height, width)

        # Normalize frames to [0, 255]
        frames = frames - frames.min()
        if frames.max() != 0:
            frames = frames / frames.max()
        frames = (frames * 255).astype(np.uint8)

        # Rearrange to (num_frames, height, width, 3)
        frames = frames.transpose(0, 2, 3, 1)

        image_list = [f for f in frames]
        gif_buffer = io.BytesIO()
        imageio.mimsave(gif_buffer, image_list, format='GIF', fps=5)
        gif_buffer.seek(0)
        gifs.append(wandb.Video(gif_buffer, format="gif"))

    for idx, gif in enumerate(gifs):
        wandb.log({f"{prefix}_Input_Cube_Sample_{idx+1}_Epoch_{epoch}": gif})


def visualize_before_training(dataloader, epoch, prefix="Pre-Training Visualization"):
    """
    Visualize a few samples from the dataloader before training starts.
    """
    print(f"Starting {prefix}...")
    for batch_idx, (pixel_values, targets) in enumerate(dataloader):
        log_input_gifs(pixel_values, epoch=epoch, prefix=prefix)
        print(f"Logged {prefix} for batch {batch_idx + 1}")
        # Only log the first batch
        break
    print(f"{prefix} completed.")


def save_sample_gifs(dataloader, save_dir, num_gifs=2, prefix="Segmented"):
    """
    Save a specified number of sample GIFs from the dataloader to a directory.
    This will save the segmented cubes as GIFs before training.
    """
    os.makedirs(save_dir, exist_ok=True)
    print(f"Saving {num_gifs} sample GIFs to {save_dir}...")

    saved_gifs = 0
    for batch_idx, (pixel_values, targets) in enumerate(dataloader):
        # pixel_values shape: (batch_size, num_frames, 3, H, W)
        pixel_values = pixel_values.cpu().numpy()
        for i in range(pixel_values.shape[0]):
            if saved_gifs >= num_gifs:
                break
            frames = pixel_values[i]  # (num_frames, 3, H, W)

            # Normalize to [0, 255]
            frames = frames - frames.min()
            if frames.max() > 0:
                frames = frames / frames.max()
            frames = (frames * 255).astype(np.uint8)

            # Rearrange to (num_frames, H, W, 3)
            frames = frames.transpose(0, 2, 3, 1)

            image_list = [frame for frame in frames]

            gif_filename = f"{prefix}_Sample_{saved_gifs + 1}.gif"
            gif_path = os.path.join(save_dir, gif_filename)
            imageio.mimsave(gif_path, image_list, format='GIF', fps=5)
            print(f"Saved GIF: {gif_path}")

            saved_gifs += 1

        if saved_gifs >= num_gifs:
            break
    print(f"Successfully saved {saved_gifs} GIF(s) to {save_dir}.")


# Loss

In [None]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# class DiceLoss(nn.Module):
#     """
#     Dice Loss for binary segmentation tasks.
#     """
#     def __init__(self, smooth=1e-5):
#         super(DiceLoss, self).__init__()
#         self.smooth = smooth

#     def forward(self, preds, targets):
#         """
#         Compute the Dice Loss.

#         Args:
#             preds (torch.Tensor): Predicted probabilities (B, C, H, W) where C is number of channels.
#             targets (torch.Tensor): Ground truth (B, C, H, W) where C is number of channels.

#         Returns:
#             torch.Tensor: Scalar Dice Loss.
#         """
#         preds = torch.sigmoid(preds)  # Ensure predictions are in [0, 1]
#         intersection = (preds * targets).sum(dim=(2, 3))  # Sum over H and W
#         union = preds.sum(dim=(2, 3)) + targets.sum(dim=(2, 3))

#         dice_score = (2.0 * intersection + self.smooth) / (union + self.smooth)
#         dice_loss = 1.0 - dice_score.mean()  # Average across batch and channels
#         return dice_loss


## Main

In [None]:
# args = parse_args()

# wandb.init(
#     project=args.wandb_project,
#     entity=args.wandb_entity,
#     name=args.wandb_run_name,
#     config={
#         "raw_base_dir": args.raw_base_dir,
#         "seg_base_dir": args.seg_base_dir,
#         "csv_output_dir": args.csv_output_dir,
#         "checkpoint_dir": args.checkpoint_dir,
#         "log_dir": args.log_dir,
#         "batch_size": args.batch_size,
#         "num_epochs": args.num_epochs,
#         "learning_rate": args.learning_rate,
#         "weight_decay": args.weight_decay,
#         "subvol_size": args.subvol_size,
#         "num_frames": args.num_frames,
#         "mask_ratio": args.mask_ratio,
#         "resume_checkpoint": args.resume_checkpoint,
#     },
#     save_code=True,
# )

# os.makedirs(args.csv_output_dir, exist_ok=True)
# os.makedirs(args.checkpoint_dir, exist_ok=True)
# os.makedirs(args.log_dir, exist_ok=True)

# # Directory to save GIFs
# saved_gifs_dir = os.path.join(args.log_dir, 'saved_gifs')
# os.makedirs(saved_gifs_dir, exist_ok=True)

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"Using device: {device}")
# os.makedirs(args.csv_output_dir, exist_ok=True)
# os.makedirs(args.checkpoint_dir, exist_ok=True)
# os.makedirs(args.log_dir, exist_ok=True)

# # Directory to save GIFs
# saved_gifs_dir = os.path.join(args.log_dir, 'saved_gifs')
# os.makedirs(saved_gifs_dir, exist_ok=True)

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"Using device: {device}")

# model_name = "MCG-NJU/videomae-base"
# print("Initializing VideoMAE model and processor...")

# model_videomae = VideoMAEForPreTraining.from_pretrained(
#     model_name,
#     attn_implementation="sdpa",
#     torch_dtype=torch.float32
# ).to(device)

# processor_videomae = VideoMAEImageProcessor.from_pretrained(model_name)
# model_videomae.train()
# print("VideoMAE model and processor initialized.")

# # For demonstration, let's assume we have bounding box names and their corresponding Excel files
# bbox_names = [f'bbox{i}' for i in range(1, 2)]  # fewer bboxes for a short demo
# all_vol_data = []
# all_syn_df = []

# for bbox_index, bbox_name in enumerate(bbox_names):
#     print(f"Loading data for {bbox_name}...")
#     raw_vol, seg_vol = load_volumes(bbox_name, args.raw_base_dir, args.seg_base_dir)
#     if raw_vol is None or seg_vol is None:
#         print(f"Skipping {bbox_name} due to loading errors.")
#         continue

#     # Suppose we have an Excel file: bbox1.xlsx, bbox2.xlsx, etc.
#     excel_file = f"{bbox_name}.xlsx"
#     if not os.path.exists(excel_file):
#         print(f"Excel file {excel_file} not found. Skipping {bbox_name}.")
#         continue

#     syn_df = pd.read_excel(excel_file)

#     # We assume syn_df has columns: central_coord_1/2/3, side_1_coord_1/2/3, side_2_coord_1/2/3, etc.
#     # We'll just add the bbox_index:
#     syn_df['bbox_index'] = bbox_index

#     all_vol_data.append((raw_vol, seg_vol))
#     all_syn_df.append(syn_df)

# # if not all_syn_df:
# #     print("No synapse data loaded. Exiting.")
# #     wandb.finish()
# #     return

# combined_syn_df = pd.concat(all_syn_df, ignore_index=True)
# print(f"Total synapses loaded: {len(combined_syn_df)}")

# # Split into train/val
# train_syn_df, val_syn_df = train_test_split(combined_syn_df, test_size=0.2, random_state=42)
# print(f"Training synapses: {len(train_syn_df)}, Validation synapses: {len(val_syn_df)}")

# # Build Datasets
# dataset_videomae_train = VideoMAEDataset(
#     vol_data_list=all_vol_data,
#     synapse_df=train_syn_df,
#     processor=processor_videomae,
#     subvol_size=args.subvol_size,
#     num_frames=args.num_frames,
#     alpha=0.3
# )
# dataset_videomae_val = VideoMAEDataset(
#     vol_data_list=all_vol_data,
#     synapse_df=val_syn_df,
#     processor=processor_videomae,
#     subvol_size=args.subvol_size,
#     num_frames=args.num_frames,
#     alpha=0.3
# )

# num_workers = min(4, multiprocessing.cpu_count())
# print(f"Using {num_workers} workers for DataLoader.")

# dataloader_videomae_train = DataLoader(
#     dataset_videomae_train,
#     batch_size=args.batch_size,
#     shuffle=True,
#     num_workers=num_workers,
#     pin_memory=True
# )
# dataloader_videomae_val = DataLoader(
#     dataset_videomae_val,
#     batch_size=args.batch_size,
#     shuffle=False,
#     num_workers=num_workers,
#     pin_memory=True
# )

# print("Saving 2 sample segmented GIFs before training starts...")
# save_sample_gifs(
#     dataloader=dataloader_videomae_train,
#     save_dir=saved_gifs_dir,
#     num_gifs=2,
#     prefix="Segmented"
# )

# print("Visualizing sample inputs before training starts...")
# visualize_before_training(dataloader_videomae_train, epoch=0, prefix="Pre-Training")
# print("Visualization completed.")


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33malim98barnet[0m ([33malim98barnet-university-of-tehran[0m). Use [1m`wandb login --relogin`[0m to force relogin


Using device: cuda
Using device: cuda
Initializing VideoMAE model and processor...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


VideoMAE model and processor initialized.
Loading data for bbox1...
Total synapses loaded: 58
Training synapses: 46, Validation synapses: 12
Using 2 workers for DataLoader.
Saving 2 sample segmented GIFs before training starts...
Saving 2 sample GIFs to logs/saved_gifs...
Saved GIF: logs/saved_gifs/Segmented_Sample_1.gif
Saved GIF: logs/saved_gifs/Segmented_Sample_2.gif
Successfully saved 2 GIF(s) to logs/saved_gifs.
Visualizing sample inputs before training starts...
Starting Pre-Training...
Logged Pre-Training for batch 1
Pre-Training completed.
Visualization completed.


In [None]:
# def reshape_model_output(logits, num_frames, image_size, patch_size, device):
#     """
#     Reshape the output logits of the model to match the original input pixel values.
#     Note: VideoMAE outputs predictions only for masked patches.

#     Args:
#         logits (torch.Tensor): Model output logits
#         num_frames (int): Number of frames in the input video
#         image_size (int): Size of the input image (assumes square images)
#         patch_size (int): Size of each patch in the input image
#         device (torch.device): Device to use for processing

#     Returns:
#         torch.Tensor: Reconstructed pixel values of shape (batch_size, num_frames, 3, H, W)
#     """
#     batch_size = logits.shape[0]
#     n_patches = (image_size // patch_size) ** 2

#     # Get hidden dimension size from input
#     hidden_dim = logits.shape[-1] // 2  # Divide by 2 since we're working with masked predictions

#     # First reshape to batch and temporal dimensions
#     logits = logits.reshape(batch_size, -1, hidden_dim)

#     # Calculate patches per frame (accounting for masked ratio)
#     patches_per_frame = n_patches // 2  # Half the patches due to masking

#     # Reshape to separate temporal dimension
#     logits = logits.reshape(batch_size, num_frames, -1, hidden_dim)

#     # Reshape patches to spatial dimensions
#     h_patches = w_patches = int((image_size // patch_size) // np.sqrt(2))  # Adjust for masking

#     # Reshape to spatial layout
#     logits = logits.reshape(batch_size, num_frames, h_patches, w_patches, patch_size, patch_size, 3)

#     # Permute to get final image format
#     logits = logits.permute(0, 1, 6, 2, 4, 3, 5).contiguous()
#     logits = logits.reshape(batch_size, num_frames, 3, image_size // 2, image_size // 2)

#     # Upsample to full resolution using bilinear interpolation
#     logits = torch.nn.functional.interpolate(
#         logits.reshape(-1, 3, image_size // 2, image_size // 2),
#         size=(image_size, image_size),
#         mode='bilinear',
#         align_corners=False
#     )

#     # Reshape back to include temporal dimension
#     logits = logits.reshape(batch_size, num_frames, 3, image_size, image_size)

#     return logits.to(device)
# def compute_masked_loss(predictions, targets, bool_masked_pos):
#     """
#     Compute loss only on masked patches.

#     Args:
#         predictions (torch.Tensor): Model predictions [B, T, C, H, W]
#         targets (torch.Tensor): Target values [B, T, C, H, W]
#         bool_masked_pos (torch.Tensor): Boolean mask indicating masked positions

#     Returns:
#         torch.Tensor: Loss computed only on masked positions
#     """
#     B, T, C, H, W = predictions.shape
#     patch_size = 16  # VideoMAE default patch size
#     n_patches = H * W // (patch_size * patch_size)

#     # Reshape to patches
#     predictions = predictions.reshape(B * T, C, H // patch_size, patch_size, W // patch_size, patch_size)
#     targets = targets.reshape(B * T, C, H // patch_size, patch_size, W // patch_size, patch_size)

#     # Rearrange to [B*T, N, P] where N is number of patches and P is pixels per patch
#     predictions = predictions.permute(0, 2, 4, 1, 3, 5).reshape(B * T, -1, C * patch_size * patch_size)
#     targets = targets.permute(0, 2, 4, 1, 3, 5).reshape(B * T, -1, C * patch_size * patch_size)

#     # Reshape bool_masked_pos to match
#     bool_masked_pos = bool_masked_pos.reshape(B * T, -1)

#     # Get masked patches only
#     predictions_masked = predictions[bool_masked_pos]
#     targets_masked = targets[bool_masked_pos]

#     # Compute loss
#     loss = F.mse_loss(predictions_masked, targets_masked)

#     return loss

In [None]:
# def main():

#     optimizer = torch.optim.AdamW(
#         model_videomae.parameters(),
#         lr=args.learning_rate,
#         weight_decay=args.weight_decay
#     )

#     total_steps = len(dataloader_videomae_train) * args.num_epochs
#     scheduler = get_cosine_schedule_with_warmup(
#         optimizer,
#         num_warmup_steps=int(0.1 * total_steps),
#         num_training_steps=total_steps
#     )

#     # Removed EarlyStopping initialization
#     # early_stopping = EarlyStopping(
#     #     patience=args.patience,
#     #     verbose=True,
#     #     path=os.path.join(args.checkpoint_dir, 'best_model.pth')
#     # )


#     scaler = GradScaler()

#     start_epoch = 1
#     if args.resume_checkpoint:
#         if os.path.exists(args.resume_checkpoint):
#             checkpoint = torch.load(args.resume_checkpoint, map_location=device)
#             model_videomae.load_state_dict(checkpoint['model_state_dict'])
#             optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#             scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
#             start_epoch = checkpoint['epoch'] + 1
#             print(f"Resumed training from checkpoint {args.resume_checkpoint} at epoch {start_epoch}")
#             wandb.run.summary["resumed_from_epoch"] = start_epoch
#         else:
#             print(f"Checkpoint {args.resume_checkpoint} not found. Starting from scratch.")
#     # Initialize the loss functions
#     mse_loss_fn = nn.MSELoss()
#     dice_loss_fn = DiceLoss()

#     # Training Loop
#     for epoch in range(start_epoch, args.num_epochs + 1):
#         model_videomae.train()
#         epoch_loss = 0.0

#         train_pbar = tqdm(dataloader_videomae_train, desc=f"Epoch {epoch}/{args.num_epochs} - Train")
#         for batch_idx, (pixel_values, targets) in enumerate(train_pbar):
#             pixel_values = pixel_values.to(device)
#             targets = targets.to(device)
#             optimizer.zero_grad()

#             tubelet_size = model_videomae.config.tubelet_size
#             image_size = model_videomae.config.image_size
#             patch_size = model_videomae.config.patch_size

#             num_patches_per_frame = (image_size // patch_size) ** 2
#             num_tubelets = pixel_values.shape[1] // tubelet_size
#             sequence_length = num_tubelets * num_patches_per_frame

#             bool_masked_pos = generate_masked_positions(
#                 pixel_values.shape[0],
#                 sequence_length,
#                 mask_ratio=args.mask_ratio,
#                 device=device
#             )

#             with autocast():
#                 outputs = model_videomae(
#                     pixel_values=pixel_values,
#                     bool_masked_pos=bool_masked_pos
#                 )

#                 # # Reshape logits to match target
#                 logits_reshaped = reshape_model_output(
#                     logits=outputs.logits,
#                     num_frames=args.num_frames,
#                     image_size=image_size,
#                     patch_size=patch_size,
#                     device=device
#                 )

#                 # # Compute losses
#                 # mse_loss = mse_loss_fn(logits_reshaped, targets)
#                 # dice_loss = dice_loss_fn(logits_reshaped, targets)
#                 # total_loss = mse_loss + dice_loss
#                 loss = compute_masked_loss(logits_reshaped, targets, bool_masked_pos)
#             scaler.scale(loss).backward()
#             scaler.unscale_(optimizer)
#             torch.nn.utils.clip_grad_norm_(model_videomae.parameters(), max_norm=1.0)
#             scaler.step(optimizer)
#             scaler.update()
#             scheduler.step()

#             epoch_loss += loss.item()
#             train_pbar.set_postfix({'loss': loss.item(), 'lr': scheduler.get_last_lr()[0]})

#     avg_epoch_loss = epoch_loss / len(dataloader_videomae_train)
#     print(f"Epoch {epoch} - Training Loss: {avg_epoch_loss:.4f}")
#     wandb.log({
#         'epoch': epoch,
#         'train_loss': avg_epoch_loss,
#         'learning_rate': scheduler.get_last_lr()[0]
#     })

#     # Validation Step (unchanged)
#     model_videomae.eval()
#     val_loss = 0.0
#     with torch.no_grad():
#         val_pbar = tqdm(dataloader_videomae_val, desc=f"Epoch {epoch}/{args.num_epochs} - Val")
#         for batch_idx, (pixel_values, targets) in enumerate(val_pbar):
#             pixel_values = pixel_values.to(device)
#             targets = targets.to(device)

#             with autocast():  # Updated for the future warning
#                 outputs = model_videomae(
#                     pixel_values=pixel_values,
#                     bool_masked_pos=bool_masked_pos
#                 )
#                 logits_reshaped = reshape_model_output(
#                     logits=outputs.logits,
#                     num_frames=args.num_frames,
#                     image_size=image_size,
#                     patch_size=patch_size,
#                     device=device
#                 )
#                 loss = compute_masked_loss(logits_reshaped, targets, bool_masked_pos)

#                 val_loss += loss.item()


#         avg_val_loss = val_loss / len(dataloader_videomae_val)
#         print(f"Epoch {epoch} - Validation Loss: {avg_val_loss:.4f}")
#         wandb.log({'val_loss': avg_val_loss, 'epoch': epoch})


#         checkpoint_path = os.path.join(args.checkpoint_dir, f'epoch_{epoch}.pth')
#         torch.save({
#             'epoch': epoch,
#             'model_state_dict': model_videomae.state_dict(),
#             'optimizer_state_dict': optimizer.state_dict(),
#             'scheduler_state_dict': scheduler.state_dict(),
#             'loss': avg_epoch_loss,
#         }, checkpoint_path)
#         print(f"Checkpoint saved at {checkpoint_path}")
#         wandb.save(checkpoint_path)

#     # Final save
#     final_model_path = os.path.join(args.checkpoint_dir, 'final_checkpoint.pth')
#     torch.save({
#         'epoch': epoch,
#         'model_state_dict': model_videomae.state_dict(),
#         'optimizer_state_dict': optimizer.state_dict(),
#         'scheduler_state_dict': scheduler.state_dict(),
#         'loss': avg_epoch_loss,
#     }, final_model_path)
#     print(f"Training completed. Final checkpoint saved at {final_model_path}")
#     wandb.save(final_model_path)

#     artifact = wandb.Artifact('final_model', type='model')
#     artifact.add_file(final_model_path)
#     wandb.log_artifact(artifact)

#     wandb.finish()


# if __name__ == "__main__":
#     main()


  scaler = GradScaler()
  with autocast():
Epoch 1/100 - Train:   0%|          | 0/23 [00:02<?, ?it/s]


RuntimeError: shape '[2, 16, 9, 9, 16, 16, 3]' is invalid for input of size 2408448

#run 3 channedl seg

In [None]:

# # - Channel 0: Segmented part with mask_1 in grayscale
# # - Channel 1: Original image in grayscale
# # - Channel 2: Segmented part with mask_2 in grayscale
# import os
# import glob
# import argparse
# import numpy as np
# import pandas as pd
# import torch
# from torch.utils.data import Dataset, DataLoader
# import imageio.v2 as iio
# from transformers import (
#     VideoMAEForPreTraining,
#     VideoMAEImageProcessor,
#     get_cosine_schedule_with_warmup,
# )
# from sklearn.model_selection import train_test_split
# import umap.umap_ as umap
# import plotly.express as px
# import plotly.graph_objects as go
# from plotly.subplots import make_subplots
# from tqdm import tqdm
# import warnings
# from torch.utils.tensorboard import SummaryWriter
# from torch.cuda.amp import GradScaler, autocast
# from collections import deque
# import time
# import multiprocessing
# import imageio
# import wandb
# import io  # For in-memory file handling
# import matplotlib.pyplot as plt

# warnings.filterwarnings("ignore", category=UserWarning, module="torch.utils.data.dataloader")


# def parse_args():
#     """
#     Parse command-line arguments for configurable paths and training parameters.
#     """
#     parser = argparse.ArgumentParser(description="VideoMAE Pre-training Script with Segmented Videos")
#     parser.add_argument('--raw_base_dir', type=str, default='./raw', help='Path to raw data directory')
#     parser.add_argument('--seg_base_dir', type=str, default='./seg', help='Path to segmentation data directory')
#     parser.add_argument('--csv_output_dir', type=str, default='csv_outputs', help='Directory to save CSV outputs')
#     parser.add_argument('--checkpoint_dir', type=str, default='checkpoints', help='Directory to save model checkpoints')
#     parser.add_argument('--log_dir', type=str, default='logs', help='Directory for TensorBoard logs')
#     parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training')
#     parser.add_argument('--num_epochs', type=int, default=100, help='Number of training epochs')
#     parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate for optimizer')
#     parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay for optimizer')
#     parser.add_argument('--subvol_size', type=int, default=80, help='Size of the sub-volume to extract')
#     parser.add_argument('--num_frames', type=int, default=16, help='Number of frames per video clip')
#     parser.add_argument('--mask_ratio', type=float, default=0.50, help='Mask ratio for VideoMAE')
#     parser.add_argument('--patience', type=int, default=3, help='Patience for early stopping')
#     parser.add_argument('--resume_checkpoint', type=str, default=None, help='Path to resume checkpoint')

#     # Optionally add WandB specific arguments
#     parser.add_argument('--wandb_project', type=str, default='VideoMAE_PreTraining', help='WandB project name')
#     parser.add_argument('--wandb_entity', type=str, default=None, help='WandB entity/team name')
#     parser.add_argument('--wandb_run_name', type=str, default=None, help='WandB run name')

#     args, _ = parser.parse_known_args()
#     return args


# class EarlyStopping:
#     """
#     Early stopping utility to halt training when validation loss stops improving.
#     """
#     def __init__(self, patience=10, verbose=False, delta=0.0, path='checkpoint.pth'):
#         """
#         Args:
#             patience (int): How long to wait after last time validation loss improved.
#             verbose (bool): If True, prints a message for each validation loss improvement.
#             delta (float): Minimum change in the monitored quantity to qualify as an improvement.
#             path (str): Path for the checkpoint to be saved to.
#         """
#         self.patience = patience
#         self.verbose = verbose
#         self.delta = delta
#         self.path = path
#         self.counter = 0
#         self.best_loss = None
#         self.early_stop = False

#     def __call__(self, current_loss, model, optimizer, scheduler, epoch):
#         if self.best_loss is None:
#             self.best_loss = current_loss
#             self.save_checkpoint(model, optimizer, scheduler, epoch, current_loss)
#         elif current_loss > self.best_loss - self.delta:
#             self.counter += 1
#             if self.verbose:
#                 print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
#             if self.counter >= self.patience:
#                 self.early_stop = True
#         else:
#             self.best_loss = current_loss
#             self.save_checkpoint(model, optimizer, scheduler, epoch, current_loss)
#             self.counter = 0

#     def save_checkpoint(self, model, optimizer, scheduler, epoch, loss):
#         """
#         Saves model when validation loss decreases.
#         """
#         torch.save({
#             'epoch': epoch,
#             'model_state_dict': model.state_dict(),
#             'optimizer_state_dict': optimizer.state_dict(),
#             'scheduler_state_dict': scheduler.state_dict(),
#             'loss': loss,
#         }, self.path)
#         if self.verbose:
#             print(f"Validation loss decreased. Saving model to {self.path}")


# def load_volumes(bbox_name, raw_base_dir, seg_base_dir):
#     """
#     Load raw volume and segmentation volume for a bounding box.

#     Args:
#         bbox_name (str): Name of the bounding box directory.
#         raw_base_dir (str): Base directory for raw data.
#         seg_base_dir (str): Base directory for segmentation data.

#     Returns:
#         tuple: (raw_vol, seg_vol) each as np.ndarray
#     """
#     raw_dir = os.path.join(raw_base_dir, bbox_name)
#     seg_dir = os.path.join(seg_base_dir, bbox_name)

#     raw_tif_files = sorted(glob.glob(os.path.join(raw_dir, 'slice_*.tif')))
#     seg_tif_files = sorted(glob.glob(os.path.join(seg_dir, 'slice_*.tif')))

#     if len(raw_tif_files) == 0:
#         print(f"No raw files found for {bbox_name} in {raw_dir}")
#         return None, None

#     if len(seg_tif_files) == 0:
#         print(f"No segmentation files found for {bbox_name} in {seg_dir}")
#         return None, None

#     if len(raw_tif_files) != len(seg_tif_files):
#         print(f"Mismatch in number of raw vs seg slices for {bbox_name}. Skipping.")
#         return None, None

#     try:
#         raw_vol = np.stack([iio.imread(f) for f in raw_tif_files], axis=0)  # shape: (Z, Y, X)
#         seg_vol = np.stack([iio.imread(f).astype(np.uint32) for f in seg_tif_files], axis=0)
#         return raw_vol, seg_vol
#     except Exception as e:
#         print(f"Error loading volumes for {bbox_name}: {e}")
#         return None, None


# def create_segmented_cube(
#     raw_vol,
#     seg_vol,
#     central_coord,
#     side1_coord,
#     side2_coord,
#     subvolume_size=80,
#     alpha=0.3
# ):
#     """
#     Constructs an 80x80x80 segmented 3D cube around the specified synapse coordinates
#     and overlays both segmentation masks (side1_coord, side2_coord) on the raw data
#     with specified transparency for each slice.

#     Returns:
#         np.ndarray: Overlaid cube of shape (height, width, 3, depth),
#                     i.e., (80, 80, 3, 80) if subvolume_size=80.
#     """

#     def create_segment_masks(segmentation_volume, s1_coord, s2_coord):
#         x1, y1, z1 = s1_coord
#         x2, y2, z2 = s2_coord
#         # Validate within volume
#         if not (0 <= z1 < segmentation_volume.shape[0] and
#                 0 <= y1 < segmentation_volume.shape[1] and
#                 0 <= x1 < segmentation_volume.shape[2]):
#             raise ValueError("Side1 coordinates are out of bounds.")

#         if not (0 <= z2 < segmentation_volume.shape[0] and
#                 0 <= y2 < segmentation_volume.shape[1] and
#                 0 <= x2 < segmentation_volume.shape[2]):
#             raise ValueError("Side2 coordinates are out of bounds.")

#         seg_id_1 = segmentation_volume[z1, y1, x1]
#         seg_id_2 = segmentation_volume[z2, y2, x2]

#         # If seg_id == 0, it means no segment at that voxel
#         if seg_id_1 == 0:
#             mask_1 = np.zeros_like(segmentation_volume, dtype=bool)
#         else:
#             mask_1 = (segmentation_volume == seg_id_1)

#         if seg_id_2 == 0:
#             mask_2 = np.zeros_like(segmentation_volume, dtype=bool)
#         else:
#             mask_2 = (segmentation_volume == seg_id_2)

#         return mask_1, mask_2

#     # Build masks
#     mask_1_full, mask_2_full = create_segment_masks(seg_vol, side1_coord, side2_coord)

#     # Define subvolume bounds
#     half_size = subvolume_size // 2
#     cx, cy, cz = central_coord

#     x_start, x_end = max(cx - half_size, 0), min(cx + half_size, raw_vol.shape[2])
#     y_start, y_end = max(cy - half_size, 0), min(cy + half_size, raw_vol.shape[1])
#     z_start, z_end = max(cz - half_size, 0), min(cz + half_size, raw_vol.shape[0])

#     # Extract subvolumes
#     sub_raw = raw_vol[z_start:z_end, y_start:y_end, x_start:x_end]
#     sub_mask_1 = mask_1_full[z_start:z_end, y_start:y_end, x_start:x_end]
#     sub_mask_2 = mask_2_full[z_start:z_end, y_start:y_end, x_start:x_end]

#     # Pad if smaller than subvolume_size
#     pad_z = subvolume_size - sub_raw.shape[0]
#     pad_y = subvolume_size - sub_raw.shape[1]
#     pad_x = subvolume_size - sub_raw.shape[2]

#     if pad_z > 0 or pad_y > 0 or pad_x > 0:
#         sub_raw = np.pad(sub_raw, ((0, pad_z), (0, pad_y), (0, pad_x)),
#                          mode='constant', constant_values=0)
#         sub_mask_1 = np.pad(sub_mask_1, ((0, pad_z), (0, pad_y), (0, pad_x)),
#                             mode='constant', constant_values=False)
#         sub_mask_2 = np.pad(sub_mask_2, ((0, pad_z), (0, pad_y), (0, pad_x)),
#                             mode='constant', constant_values=False)

#     # Slice to exact shape
#     sub_raw = sub_raw[:subvolume_size, :subvolume_size, :subvolume_size]
#     sub_mask_1 = sub_mask_1[:subvolume_size, :subvolume_size, :subvolume_size]
#     sub_mask_2 = sub_mask_2[:subvolume_size, :subvolume_size, :subvolume_size]

#     # ایجاد آرایه برای تصاویر سه‌کاناله
#     # ابعاد: (ارتفاع، عرض، کانال‌ها، عمق)
#     overlaid_cube = np.zeros((subvolume_size, subvolume_size, 3, subvolume_size), dtype=np.uint8)

#     for z in range(subvolume_size):
#         # نرمال‌سازی برش اصلی به محدوده [0, 1]
#         raw_slice = sub_raw[z].astype(np.float32)
#         mn, mx = raw_slice.min(), raw_slice.max()
#         if mx > mn:
#             raw_normalized = (raw_slice - mn) / (mx - mn)
#         else:
#             raw_normalized = raw_slice - mn  # در صورتی که همه مقادیر برابر باشند

#         # تبدیل به مقیاس 0-255
#         raw_scaled = (raw_normalized * 255).astype(np.uint8)

#         # کانال میانی (کانال 1) عکس اصلی
#         overlaid_cube[:, :, 1, z] = raw_scaled

#         # استخراج ماسک‌ها
#         mask1 = sub_mask_1[z].astype(np.uint8)
#         mask2 = sub_mask_2[z].astype(np.uint8)

#         # کانال 0: فقط قسمت سگمنت شده با mask_1
#         overlaid_cube[:, :, 0, z] = raw_scaled * mask1

#         # کانال 2: فقط قسمت سگمنت شده با mask_2
#         overlaid_cube[:, :, 2, z] = raw_scaled * mask2
#     return overlaid_cube


# class VideoMAEDataset(Dataset):
#     """
#     Dataset class that uses segmented volumes (side1 & side2) for VideoMAE pre-training.
#     """
#     def __init__(self, vol_data_list, synapse_df, processor, subvol_size=80, num_frames=16, alpha=0.3):
#         """
#         Args:
#             vol_data_list (List[Tuple[np.ndarray, np.ndarray]]): List of (raw_vol, seg_vol).
#             synapse_df (pd.DataFrame): DataFrame with synapse coordinates (central, side1, side2).
#             processor (VideoMAEImageProcessor): Processor for VideoMAE.
#             subvol_size (int): Size of the sub-volume to extract.
#             num_frames (int): Number of frames for the model.
#             alpha (float): Blending alpha for segmentation.
#         """
#         self.vol_data_list = vol_data_list
#         self.synapse_df = synapse_df.reset_index(drop=True)
#         self.processor = processor
#         self.subvol_size = subvol_size
#         self.num_frames = num_frames
#         self.alpha = alpha

#     def __len__(self):
#         return len(self.synapse_df)

#     def __getitem__(self, idx):
#         syn_info = self.synapse_df.iloc[idx]
#         bbox_index = syn_info['bbox_index']

#         raw_vol, seg_vol = self.vol_data_list[bbox_index]
#         if raw_vol is None or seg_vol is None:
#             # Return dummy data if volumes not found
#             pixel_values = torch.zeros((self.num_frames, 3, self.subvol_size, self.subvol_size), dtype=torch.float32)
#             return pixel_values, pixel_values

#         # Coordinates
#         central_coord = (
#             int(syn_info['central_coord_1']),
#             int(syn_info['central_coord_2']),
#             int(syn_info['central_coord_3'])
#         )
#         side1_coord = (
#             int(syn_info['side_1_coord_1']),
#             int(syn_info['side_1_coord_2']),
#             int(syn_info['side_1_coord_3'])
#         )
#         side2_coord = (
#             int(syn_info['side_2_coord_1']),
#             int(syn_info['side_2_coord_2']),
#             int(syn_info['side_2_coord_3'])
#         )

#         # Create the overlaid segmented cube
#         overlaid_cube = create_segmented_cube(
#             raw_vol=raw_vol,
#             seg_vol=seg_vol,
#             central_coord=central_coord,
#             side1_coord=side1_coord,
#             side2_coord=side2_coord,
#             subvolume_size=self.subvol_size,
#             alpha=self.alpha
#         )  # shape: (80, 80, 3, 80)

#         # We interpret the last dimension (depth) as frames
#         frames = []
#         for z in range(overlaid_cube.shape[3]):  # 80 slices
#             frame_rgb = overlaid_cube[..., z]  # (80, 80, 3)
#             frames.append(frame_rgb)

#         # Now reduce or expand to self.num_frames
#         total_slices = len(frames)  # 80
#         if total_slices < self.num_frames:
#             while len(frames) < self.num_frames:
#                 frames.append(frames[-1])
#         elif total_slices > self.num_frames:
#             indices = np.linspace(0, total_slices - 1, self.num_frames, dtype=int)
#             frames = [frames[i] for i in indices]

#         # Process using the VideoMAEImageProcessor
#         inputs = self.processor(frames, return_tensors="pt")
#         pixel_values = inputs["pixel_values"].squeeze(0)  # (num_frames, 3, H, W)
#         pixel_values = pixel_values.float()

#         # For MAE, target is the same
#         return pixel_values, pixel_values


# def generate_masked_positions(batch_size, sequence_length, mask_ratio=0.75, device='cuda'):
#     """
#     Generate a boolean mask indicating which positions are masked.
#     """
#     masks = torch.zeros(batch_size, sequence_length, dtype=torch.bool, device=device)
#     num_mask = int(mask_ratio * sequence_length)
#     for i in range(batch_size):
#         mask_indices = torch.randperm(sequence_length, device=device)[:num_mask]
#         masks[i, mask_indices] = True
#     return masks


# def log_input_gifs(pixel_values, epoch, prefix="Training"):
#     """
#     Convert input pixel values to GIFs and log them to WandB.
#     """
#     # Ensure pixel_values is on CPU and convert to numpy
#     pixel_values = pixel_values.cpu().numpy()

#     gifs = []
#     for i in range(min(2, pixel_values.shape[0])):  # just take up to 2 samples for logging
#         frames = pixel_values[i]  # shape: (num_frames, 3, height, width)

#         # Normalize frames to [0, 255]
#         frames = frames - frames.min()
#         if frames.max() != 0:
#             frames = frames / frames.max()
#         frames = (frames * 255).astype(np.uint8)

#         # Rearrange to (num_frames, height, width, 3)
#         frames = frames.transpose(0, 2, 3, 1)

#         image_list = [f for f in frames]
#         gif_buffer = io.BytesIO()
#         imageio.mimsave(gif_buffer, image_list, format='GIF', fps=5)
#         gif_buffer.seek(0)
#         gifs.append(wandb.Video(gif_buffer, format="gif"))

#     for idx, gif in enumerate(gifs):
#         wandb.log({f"{prefix}_Input_Cube_Sample_{idx+1}_Epoch_{epoch}": gif})


# def visualize_before_training(dataloader, epoch, prefix="Pre-Training Visualization"):
#     """
#     Visualize a few samples from the dataloader before training starts.
#     """
#     print(f"Starting {prefix}...")
#     for batch_idx, (pixel_values, targets) in enumerate(dataloader):
#         log_input_gifs(pixel_values, epoch=epoch, prefix=prefix)
#         print(f"Logged {prefix} for batch {batch_idx + 1}")
#         # Only log the first batch
#         break
#     print(f"{prefix} completed.")


# def save_sample_gifs(dataloader, save_dir, num_gifs=2, prefix="Segmented"):
#     """
#     Save a specified number of sample GIFs from the dataloader to a directory.
#     This will save the segmented cubes as GIFs before training.
#     """
#     os.makedirs(save_dir, exist_ok=True)
#     print(f"Saving {num_gifs} sample GIFs to {save_dir}...")

#     saved_gifs = 0
#     for batch_idx, (pixel_values, targets) in enumerate(dataloader):
#         # pixel_values shape: (batch_size, num_frames, 3, H, W)
#         pixel_values = pixel_values.cpu().numpy()
#         for i in range(pixel_values.shape[0]):
#             if saved_gifs >= num_gifs:
#                 break
#             frames = pixel_values[i]  # (num_frames, 3, H, W)

#             # Normalize to [0, 255]
#             frames = frames - frames.min()
#             if frames.max() > 0:
#                 frames = frames / frames.max()
#             frames = (frames * 255).astype(np.uint8)

#             # Rearrange to (num_frames, H, W, 3)
#             frames = frames.transpose(0, 2, 3, 1)

#             image_list = [frame for frame in frames]

#             gif_filename = f"{prefix}_Sample_{saved_gifs + 1}.gif"
#             gif_path = os.path.join(save_dir, gif_filename)
#             imageio.mimsave(gif_path, image_list, format='GIF', fps=5)
#             print(f"Saved GIF: {gif_path}")

#             saved_gifs += 1

#         if saved_gifs >= num_gifs:
#             break
#     print(f"Successfully saved {saved_gifs} GIF(s) to {save_dir}.")


# def main():
#     args = parse_args()

#     wandb.init(
#         project=args.wandb_project,
#         entity=args.wandb_entity,
#         name=args.wandb_run_name,
#         config={
#             "raw_base_dir": args.raw_base_dir,
#             "seg_base_dir": args.seg_base_dir,
#             "csv_output_dir": args.csv_output_dir,
#             "checkpoint_dir": args.checkpoint_dir,
#             "log_dir": args.log_dir,
#             "batch_size": args.batch_size,
#             "num_epochs": args.num_epochs,
#             "learning_rate": args.learning_rate,
#             "weight_decay": args.weight_decay,
#             "subvol_size": args.subvol_size,
#             "num_frames": args.num_frames,
#             "mask_ratio": args.mask_ratio,
#             "patience": args.patience,
#             "resume_checkpoint": args.resume_checkpoint,
#         },
#         save_code=True,
#     )

#     os.makedirs(args.csv_output_dir, exist_ok=True)
#     os.makedirs(args.checkpoint_dir, exist_ok=True)
#     os.makedirs(args.log_dir, exist_ok=True)

#     # Directory to save GIFs
#     saved_gifs_dir = os.path.join(args.log_dir, 'saved_gifs')
#     os.makedirs(saved_gifs_dir, exist_ok=True)

#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     print(f"Using device: {device}")

#     model_name = "MCG-NJU/videomae-base"
#     print("Initializing VideoMAE model and processor...")

#     model_videomae = VideoMAEForPreTraining.from_pretrained(
#         model_name,
#         attn_implementation="sdpa",
#         torch_dtype=torch.float32
#     ).to(device)

#     processor_videomae = VideoMAEImageProcessor.from_pretrained(model_name)
#     model_videomae.train()
#     print("VideoMAE model and processor initialized.")

#     # For demonstration, let's assume we have bounding box names and their corresponding Excel files
#     bbox_names = [f'bbox{i}' for i in range(1, 8)]  # fewer bboxes for a short demo
#     all_vol_data = []
#     all_syn_df = []

#     for bbox_index, bbox_name in enumerate(bbox_names):
#         print(f"Loading data for {bbox_name}...")
#         raw_vol, seg_vol = load_volumes(bbox_name, args.raw_base_dir, args.seg_base_dir)
#         if raw_vol is None or seg_vol is None:
#             print(f"Skipping {bbox_name} due to loading errors.")
#             continue

#         # Suppose we have an Excel file: bbox1.xlsx, bbox2.xlsx, etc.
#         excel_file = f"{bbox_name}.xlsx"
#         if not os.path.exists(excel_file):
#             print(f"Excel file {excel_file} not found. Skipping {bbox_name}.")
#             continue

#         syn_df = pd.read_excel(excel_file)

#         # We assume syn_df has columns: central_coord_1/2/3, side_1_coord_1/2/3, side_2_coord_1/2/3, etc.
#         # We'll just add the bbox_index:
#         syn_df['bbox_index'] = bbox_index

#         all_vol_data.append((raw_vol, seg_vol))
#         all_syn_df.append(syn_df)

#     if not all_syn_df:
#         print("No synapse data loaded. Exiting.")
#         wandb.finish()
#         return

#     combined_syn_df = pd.concat(all_syn_df, ignore_index=True)
#     print(f"Total synapses loaded: {len(combined_syn_df)}")

#     # Split into train/val
#     train_syn_df, val_syn_df = train_test_split(combined_syn_df, test_size=0.2, random_state=42)
#     print(f"Training synapses: {len(train_syn_df)}, Validation synapses: {len(val_syn_df)}")

#     # Build Datasets
#     dataset_videomae_train = VideoMAEDataset(
#         vol_data_list=all_vol_data,
#         synapse_df=train_syn_df,
#         processor=processor_videomae,
#         subvol_size=args.subvol_size,
#         num_frames=args.num_frames,
#         alpha=0.3
#     )
#     dataset_videomae_val = VideoMAEDataset(
#         vol_data_list=all_vol_data,
#         synapse_df=val_syn_df,
#         processor=processor_videomae,
#         subvol_size=args.subvol_size,
#         num_frames=args.num_frames,
#         alpha=0.3
#     )

#     num_workers = min(4, multiprocessing.cpu_count())
#     print(f"Using {num_workers} workers for DataLoader.")

#     dataloader_videomae_train = DataLoader(
#         dataset_videomae_train,
#         batch_size=args.batch_size,
#         shuffle=True,
#         num_workers=num_workers,
#         pin_memory=True
#     )
#     dataloader_videomae_val = DataLoader(
#         dataset_videomae_val,
#         batch_size=args.batch_size,
#         shuffle=False,
#         num_workers=num_workers,
#         pin_memory=True
#     )

#     print("Saving 2 sample segmented GIFs before training starts...")
#     save_sample_gifs(
#         dataloader=dataloader_videomae_train,
#         save_dir=saved_gifs_dir,
#         num_gifs=2,
#         prefix="Segmented"
#     )

#     print("Visualizing sample inputs before training starts...")
#     visualize_before_training(dataloader_videomae_train, epoch=0, prefix="Pre-Training")
#     print("Visualization completed.")

#     optimizer = torch.optim.AdamW(
#         model_videomae.parameters(),
#         lr=args.learning_rate,
#         weight_decay=args.weight_decay
#     )

#     total_steps = len(dataloader_videomae_train) * args.num_epochs
#     scheduler = get_cosine_schedule_with_warmup(
#         optimizer,
#         num_warmup_steps=int(0.1 * total_steps),
#         num_training_steps=total_steps
#     )

#     early_stopping = EarlyStopping(
#         patience=args.patience,
#         verbose=True,
#         path=os.path.join(args.checkpoint_dir, 'best_model.pth')
#     )
#     scaler = GradScaler()

#     start_epoch = 1
#     if args.resume_checkpoint:
#         if os.path.exists(args.resume_checkpoint):
#             checkpoint = torch.load(args.resume_checkpoint, map_location=device)
#             model_videomae.load_state_dict(checkpoint['model_state_dict'])
#             optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#             scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
#             start_epoch = checkpoint['epoch'] + 1
#             print(f"Resumed training from checkpoint {args.resume_checkpoint} at epoch {start_epoch}")
#             wandb.run.summary["resumed_from_epoch"] = start_epoch
#         else:
#             print(f"Checkpoint {args.resume_checkpoint} not found. Starting from scratch.")

#     print("Starting training...")
#     for epoch in range(start_epoch, args.num_epochs + 1):
#         model_videomae.train()
#         epoch_loss = 0.0

#         train_pbar = tqdm(dataloader_videomae_train, desc=f"Epoch {epoch}/{args.num_epochs} - Train")
#         for batch_idx, (pixel_values, targets) in enumerate(train_pbar):
#             pixel_values = pixel_values.to(device)
#             optimizer.zero_grad()

#             tubelet_size = model_videomae.config.tubelet_size
#             image_size = model_videomae.config.image_size
#             patch_size = model_videomae.config.patch_size

#             num_patches_per_frame = (image_size // patch_size) ** 2
#             num_tubelets = pixel_values.shape[1] // tubelet_size
#             sequence_length = num_tubelets * num_patches_per_frame

#             bool_masked_pos = generate_masked_positions(
#                 pixel_values.shape[0],
#                 sequence_length,
#                 mask_ratio=args.mask_ratio,
#                 device=device
#             )

#             with autocast():
#                 outputs = model_videomae(
#                     pixel_values=pixel_values,
#                     bool_masked_pos=bool_masked_pos
#                 )
#                 loss = outputs.loss

#             scaler.scale(loss).backward()
#             scaler.unscale_(optimizer)
#             torch.nn.utils.clip_grad_norm_(model_videomae.parameters(), max_norm=1.0)
#             scaler.step(optimizer)
#             scaler.update()
#             scheduler.step()

#             epoch_loss += loss.item()
#             train_pbar.set_postfix({'loss': loss.item(), 'lr': scheduler.get_last_lr()[0]})

#         avg_epoch_loss = epoch_loss / len(dataloader_videomae_train)
#         print(f"Epoch {epoch} - Training Loss: {avg_epoch_loss:.4f}")
#         wandb.log({
#             'epoch': epoch,
#             'train_loss': avg_epoch_loss,
#             'learning_rate': scheduler.get_last_lr()[0]
#         })

#         # Validation
#         model_videomae.eval()
#         val_loss = 0.0
#         with torch.no_grad():
#             val_pbar = tqdm(dataloader_videomae_val, desc=f"Epoch {epoch}/{args.num_epochs} - Val")
#             for batch_idx, (pixel_values, targets) in enumerate(val_pbar):
#                 pixel_values = pixel_values.to(device)

#                 tubelet_size = model_videomae.config.tubelet_size
#                 image_size = model_videomae.config.image_size
#                 patch_size = model_videomae.config.patch_size

#                 num_patches_per_frame = (image_size // patch_size) ** 2
#                 num_tubelets = pixel_values.shape[1] // tubelet_size
#                 sequence_length = num_tubelets * num_patches_per_frame

#                 bool_masked_pos = generate_masked_positions(
#                     pixel_values.shape[0],
#                     sequence_length,
#                     mask_ratio=args.mask_ratio,
#                     device=device
#                 )

#                 outputs = model_videomae(
#                     pixel_values=pixel_values,
#                     bool_masked_pos=bool_masked_pos
#                 )
#                 loss = outputs.loss
#                 val_loss += loss.item()

#             avg_val_loss = val_loss / len(dataloader_videomae_val)
#             print(f"Epoch {epoch} - Validation Loss: {avg_val_loss:.4f}")
#             wandb.log({'val_loss': avg_val_loss, 'epoch': epoch})

#         # Early stopping
#         early_stopping(avg_val_loss, model_videomae, optimizer, scheduler, epoch)
#         if early_stopping.early_stop:
#             print("Early stopping triggered.")
#             wandb.run.summary["early_stopped_at_epoch"] = epoch
#             break

#         # Save checkpoint every epoch (you can adjust frequency)
#         checkpoint_path = os.path.join(args.checkpoint_dir, f'epoch_{epoch}.pth')
#         torch.save({
#             'epoch': epoch,
#             'model_state_dict': model_videomae.state_dict(),
#             'optimizer_state_dict': optimizer.state_dict(),
#             'scheduler_state_dict': scheduler.state_dict(),
#             'loss': avg_epoch_loss,
#         }, checkpoint_path)
#         print(f"Checkpoint saved at {checkpoint_path}")
#         wandb.save(checkpoint_path)

#     # Final save
#     final_model_path = os.path.join(args.checkpoint_dir, 'final_checkpoint.pth')
#     torch.save({
#         'epoch': epoch,
#         'model_state_dict': model_videomae.state_dict(),
#         'optimizer_state_dict': optimizer.state_dict(),
#         'scheduler_state_dict': scheduler.state_dict(),
#         'loss': avg_epoch_loss,
#     }, final_model_path)
#     print(f"Training completed. Final checkpoint saved at {final_model_path}")
#     wandb.save(final_model_path)

#     artifact = wandb.Artifact('final_model', type='model')
#     artifact.add_file(final_model_path)
#     wandb.log_artifact(artifact)

#     wandb.finish()


# if __name__ == "__main__":
#     main()


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Using device: cuda
Initializing VideoMAE model and processor...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/725 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/377M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/271 [00:00<?, ?B/s]

VideoMAE model and processor initialized.
Loading data for bbox1...
Loading data for bbox2...
Loading data for bbox3...
Loading data for bbox4...
Loading data for bbox5...
Loading data for bbox6...
Loading data for bbox7...
Total synapses loaded: 509
Training synapses: 407, Validation synapses: 102
Using 4 workers for DataLoader.
Saving 2 sample segmented GIFs before training starts...
Saving 2 sample GIFs to logs/saved_gifs...
Saved GIF: logs/saved_gifs/Segmented_Sample_1.gif
Saved GIF: logs/saved_gifs/Segmented_Sample_2.gif
Successfully saved 2 GIF(s) to logs/saved_gifs.
Visualizing sample inputs before training starts...
Starting Pre-Training...
Logged Pre-Training for batch 1


  scaler = GradScaler()


Pre-Training completed.
Visualization completed.
Starting training...


  with autocast():
Epoch 1/100 - Train: 100%|██████████| 204/204 [00:37<00:00,  5.39it/s, loss=0.295, lr=1e-5]


Epoch 1 - Training Loss: 0.2774


Epoch 1/100 - Val: 100%|██████████| 51/51 [00:08<00:00,  5.80it/s]


Epoch 1 - Validation Loss: 0.2716
Validation loss decreased. Saving model to checkpoints/best_model.pth
Checkpoint saved at checkpoints/epoch_1.pth


Epoch 2/100 - Train: 100%|██████████| 204/204 [00:36<00:00,  5.53it/s, loss=0.263, lr=2e-5]


Epoch 2 - Training Loss: 0.2672


Epoch 2/100 - Val: 100%|██████████| 51/51 [00:09<00:00,  5.34it/s]


Epoch 2 - Validation Loss: 0.2651
Validation loss decreased. Saving model to checkpoints/best_model.pth
Checkpoint saved at checkpoints/epoch_2.pth


Epoch 3/100 - Train: 100%|██████████| 204/204 [00:36<00:00,  5.62it/s, loss=0.297, lr=3e-5]


Epoch 3 - Training Loss: 0.2629


Epoch 3/100 - Val: 100%|██████████| 51/51 [00:08<00:00,  5.80it/s]


Epoch 3 - Validation Loss: 0.2639
Validation loss decreased. Saving model to checkpoints/best_model.pth
Checkpoint saved at checkpoints/epoch_3.pth


Epoch 4/100 - Train: 100%|██████████| 204/204 [00:36<00:00,  5.57it/s, loss=0.248, lr=4e-5]


Epoch 4 - Training Loss: 0.2600


Epoch 4/100 - Val: 100%|██████████| 51/51 [00:09<00:00,  5.41it/s]


Epoch 4 - Validation Loss: 0.2594
Validation loss decreased. Saving model to checkpoints/best_model.pth
Checkpoint saved at checkpoints/epoch_4.pth


Epoch 5/100 - Train: 100%|██████████| 204/204 [00:36<00:00,  5.61it/s, loss=0.251, lr=5e-5]


Epoch 5 - Training Loss: 0.2581


Epoch 5/100 - Val: 100%|██████████| 51/51 [00:08<00:00,  5.81it/s]


Epoch 5 - Validation Loss: 0.2593
Validation loss decreased. Saving model to checkpoints/best_model.pth
Checkpoint saved at checkpoints/epoch_5.pth


Epoch 6/100 - Train: 100%|██████████| 204/204 [00:36<00:00,  5.57it/s, loss=0.245, lr=6e-5]


Epoch 6 - Training Loss: 0.2564


Epoch 6/100 - Val: 100%|██████████| 51/51 [00:09<00:00,  5.49it/s]


Epoch 6 - Validation Loss: 0.2575
Validation loss decreased. Saving model to checkpoints/best_model.pth
Checkpoint saved at checkpoints/epoch_6.pth


Epoch 7/100 - Train: 100%|██████████| 204/204 [00:36<00:00,  5.55it/s, loss=0.294, lr=7e-5]


Epoch 7 - Training Loss: 0.2551


Epoch 7/100 - Val: 100%|██████████| 51/51 [00:08<00:00,  5.94it/s]


Epoch 7 - Validation Loss: 0.2565
Validation loss decreased. Saving model to checkpoints/best_model.pth
Checkpoint saved at checkpoints/epoch_7.pth


Epoch 8/100 - Train: 100%|██████████| 204/204 [00:37<00:00,  5.49it/s, loss=0.276, lr=8e-5]


Epoch 8 - Training Loss: 0.2534


Epoch 8/100 - Val: 100%|██████████| 51/51 [00:09<00:00,  5.66it/s]


Epoch 8 - Validation Loss: 0.2561
Validation loss decreased. Saving model to checkpoints/best_model.pth
Checkpoint saved at checkpoints/epoch_8.pth


Epoch 9/100 - Train: 100%|██████████| 204/204 [00:37<00:00,  5.47it/s, loss=0.204, lr=9e-5]


Epoch 9 - Training Loss: 0.2523


Epoch 9/100 - Val: 100%|██████████| 51/51 [00:08<00:00,  5.75it/s]


Epoch 9 - Validation Loss: 0.2548
Validation loss decreased. Saving model to checkpoints/best_model.pth
Checkpoint saved at checkpoints/epoch_9.pth


Epoch 10/100 - Train: 100%|██████████| 204/204 [00:37<00:00,  5.48it/s, loss=0.262, lr=0.0001]


Epoch 10 - Training Loss: 0.2513


Epoch 10/100 - Val: 100%|██████████| 51/51 [00:08<00:00,  5.71it/s]


Epoch 10 - Validation Loss: 0.2555
EarlyStopping counter: 1 out of 3
Checkpoint saved at checkpoints/epoch_10.pth


Epoch 11/100 - Train: 100%|██████████| 204/204 [00:36<00:00,  5.55it/s, loss=0.228, lr=0.0001]


Epoch 11 - Training Loss: 0.2499


Epoch 11/100 - Val: 100%|██████████| 51/51 [00:08<00:00,  5.90it/s]


Epoch 11 - Validation Loss: 0.2545
Validation loss decreased. Saving model to checkpoints/best_model.pth
Checkpoint saved at checkpoints/epoch_11.pth


Epoch 12/100 - Train: 100%|██████████| 204/204 [00:36<00:00,  5.55it/s, loss=0.204, lr=9.99e-5]


Epoch 12 - Training Loss: 0.2483


Epoch 12/100 - Val: 100%|██████████| 51/51 [00:09<00:00,  5.61it/s]


Epoch 12 - Validation Loss: 0.2550
EarlyStopping counter: 1 out of 3
Checkpoint saved at checkpoints/epoch_12.pth


Epoch 13/100 - Train: 100%|██████████| 204/204 [00:37<00:00,  5.42it/s, loss=0.246, lr=9.97e-5]


Epoch 13 - Training Loss: 0.2468


Epoch 13/100 - Val: 100%|██████████| 51/51 [00:08<00:00,  5.84it/s]


Epoch 13 - Validation Loss: 0.2519
Validation loss decreased. Saving model to checkpoints/best_model.pth
Checkpoint saved at checkpoints/epoch_13.pth


Epoch 14/100 - Train: 100%|██████████| 204/204 [00:36<00:00,  5.51it/s, loss=0.206, lr=9.95e-5]


Epoch 14 - Training Loss: 0.2451


Epoch 14/100 - Val: 100%|██████████| 51/51 [00:09<00:00,  5.59it/s]


Epoch 14 - Validation Loss: 0.2517
Validation loss decreased. Saving model to checkpoints/best_model.pth
Checkpoint saved at checkpoints/epoch_14.pth


Epoch 15/100 - Train: 100%|██████████| 204/204 [00:36<00:00,  5.55it/s, loss=0.219, lr=9.92e-5]


Epoch 15 - Training Loss: 0.2438


Epoch 15/100 - Val: 100%|██████████| 51/51 [00:09<00:00,  5.48it/s]


Epoch 15 - Validation Loss: 0.2490
Validation loss decreased. Saving model to checkpoints/best_model.pth
Checkpoint saved at checkpoints/epoch_15.pth


Epoch 16/100 - Train: 100%|██████████| 204/204 [00:36<00:00,  5.56it/s, loss=0.262, lr=9.89e-5]


Epoch 16 - Training Loss: 0.2429


Epoch 16/100 - Val: 100%|██████████| 51/51 [00:08<00:00,  5.85it/s]


Epoch 16 - Validation Loss: 0.2497
EarlyStopping counter: 1 out of 3
Checkpoint saved at checkpoints/epoch_16.pth


Epoch 17/100 - Train: 100%|██████████| 204/204 [00:36<00:00,  5.53it/s, loss=0.221, lr=9.85e-5]


Epoch 17 - Training Loss: 0.2417


Epoch 17/100 - Val: 100%|██████████| 51/51 [00:09<00:00,  5.64it/s]


Epoch 17 - Validation Loss: 0.2494
EarlyStopping counter: 2 out of 3
Checkpoint saved at checkpoints/epoch_17.pth


Epoch 18/100 - Train: 100%|██████████| 204/204 [00:37<00:00,  5.50it/s, loss=0.214, lr=9.81e-5]


Epoch 18 - Training Loss: 0.2401


Epoch 18/100 - Val: 100%|██████████| 51/51 [00:09<00:00,  5.51it/s]


Epoch 18 - Validation Loss: 0.2492
EarlyStopping counter: 3 out of 3
Early stopping triggered.
Training completed. Final checkpoint saved at checkpoints/final_checkpoint.pth


0,1
epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇████
learning_rate,▁▂▃▃▄▅▆▆▇█████████
train_loss,█▆▅▅▄▄▄▄▃▃▃▃▂▂▂▂▁▁
val_loss,█▆▆▄▄▄▃▃▃▃▃▃▂▂▁▁▁▁

0,1
early_stopped_at_epoch,18.0
epoch,18.0
learning_rate,0.0001
train_loss,0.24009
val_loss,0.24924


# Run colored Seg



> to do
>
>
>1.   add Lr-scheduler
>2.   use large model





*   vis samples
*  added wandb



In [None]:
import os
import glob
import argparse
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import imageio.v2 as iio
from transformers import (
    VideoMAEForPreTraining,
    VideoMAEImageProcessor,
    get_cosine_schedule_with_warmup,
)
from sklearn.model_selection import train_test_split
import umap.umap_ as umap  # Not necessarily used, but left for completeness
import plotly.express as px  # Not necessarily used, but left for completeness
import plotly.graph_objects as go  # Not necessarily used, but left for completeness
from plotly.subplots import make_subplots  # Not necessarily used, but left for completeness
from tqdm import tqdm
import warnings
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast
from collections import deque
import time
import multiprocessing
import imageio
import wandb
import io  # For in-memory file handling
import matplotlib.pyplot as plt

warnings.filterwarnings("ignore", category=UserWarning, module="torch.utils.data.dataloader")


def parse_args():
    """
    Parse command-line arguments for configurable paths and training parameters.
    """
    parser = argparse.ArgumentParser(description="VideoMAE Pre-training Script with Segmented Videos")
    parser.add_argument('--raw_base_dir', type=str, default='./raw', help='Path to raw data directory')
    parser.add_argument('--seg_base_dir', type=str, default='./seg', help='Path to segmentation data directory')
    parser.add_argument('--csv_output_dir', type=str, default='csv_outputs', help='Directory to save CSV outputs')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints', help='Directory to save model checkpoints')
    parser.add_argument('--log_dir', type=str, default='logs', help='Directory for TensorBoard logs')
    parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training')
    parser.add_argument('--num_epochs', type=int, default=5, help='Number of training epochs')
    parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate for optimizer')
    parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay for optimizer')
    parser.add_argument('--subvol_size', type=int, default=80, help='Size of the sub-volume to extract')
    parser.add_argument('--num_frames', type=int, default=16, help='Number of frames per video clip')
    parser.add_argument('--mask_ratio', type=float, default=0.75, help='Mask ratio for VideoMAE')
    parser.add_argument('--patience', type=int, default=3, help='Patience for early stopping')
    parser.add_argument('--resume_checkpoint', type=str, default=None, help='Path to resume checkpoint')

    # Optionally add WandB specific arguments
    parser.add_argument('--wandb_project', type=str, default='VideoMAE_PreTraining', help='WandB project name')
    parser.add_argument('--wandb_entity', type=str, default=None, help='WandB entity/team name')
    parser.add_argument('--wandb_run_name', type=str, default=None, help='WandB run name')

    args, _ = parser.parse_known_args()
    return args


class EarlyStopping:
    """
    Early stopping utility to halt training when validation loss stops improving.
    """
    def __init__(self, patience=10, verbose=False, delta=0.0, path='checkpoint.pth'):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
            verbose (bool): If True, prints a message for each validation loss improvement.
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
            path (str): Path for the checkpoint to be saved to.
        """
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, current_loss, model, optimizer, scheduler, epoch):
        if self.best_loss is None:
            self.best_loss = current_loss
            self.save_checkpoint(model, optimizer, scheduler, epoch, current_loss)
        elif current_loss > self.best_loss - self.delta:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = current_loss
            self.save_checkpoint(model, optimizer, scheduler, epoch, current_loss)
            self.counter = 0

    def save_checkpoint(self, model, optimizer, scheduler, epoch, loss):
        """
        Saves model when validation loss decreases.
        """
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': loss,
        }, self.path)
        if self.verbose:
            print(f"Validation loss decreased. Saving model to {self.path}")


def load_volumes(bbox_name, raw_base_dir, seg_base_dir):
    """
    Load raw volume and segmentation volume for a bounding box.

    Args:
        bbox_name (str): Name of the bounding box directory.
        raw_base_dir (str): Base directory for raw data.
        seg_base_dir (str): Base directory for segmentation data.

    Returns:
        tuple: (raw_vol, seg_vol) each as np.ndarray
    """
    raw_dir = os.path.join(raw_base_dir, bbox_name)
    seg_dir = os.path.join(seg_base_dir, bbox_name)

    raw_tif_files = sorted(glob.glob(os.path.join(raw_dir, 'slice_*.tif')))
    seg_tif_files = sorted(glob.glob(os.path.join(seg_dir, 'slice_*.tif')))

    if len(raw_tif_files) == 0:
        print(f"No raw files found for {bbox_name} in {raw_dir}")
        return None, None

    if len(seg_tif_files) == 0:
        print(f"No segmentation files found for {bbox_name} in {seg_dir}")
        return None, None

    if len(raw_tif_files) != len(seg_tif_files):
        print(f"Mismatch in number of raw vs seg slices for {bbox_name}. Skipping.")
        return None, None

    try:
        raw_vol = np.stack([iio.imread(f) for f in raw_tif_files], axis=0)  # shape: (Z, Y, X)
        seg_vol = np.stack([iio.imread(f).astype(np.uint32) for f in seg_tif_files], axis=0)
        return raw_vol, seg_vol
    except Exception as e:
        print(f"Error loading volumes for {bbox_name}: {e}")
        return None, None


def create_segmented_cube(
    raw_vol,
    seg_vol,
    central_coord,
    side1_coord,
    side2_coord,
    subvolume_size=80,
    alpha=0.3
):
    """
    Constructs an 80x80x80 segmented 3D cube around the specified synapse coordinates
    and overlays both segmentation masks (side1_coord, side2_coord) on the raw data
    with specified transparency for each slice.

    Returns:
        np.ndarray: Overlaid cube of shape (height, width, 3, depth),
                    i.e., (80, 80, 3, 80) if subvolume_size=80.
    """

    def create_segment_masks(segmentation_volume, s1_coord, s2_coord):
        x1, y1, z1 = s1_coord
        x2, y2, z2 = s2_coord
        # Validate within volume
        if not (0 <= z1 < segmentation_volume.shape[0] and
                0 <= y1 < segmentation_volume.shape[1] and
                0 <= x1 < segmentation_volume.shape[2]):
            raise ValueError("Side1 coordinates are out of bounds.")

        if not (0 <= z2 < segmentation_volume.shape[0] and
                0 <= y2 < segmentation_volume.shape[1] and
                0 <= x2 < segmentation_volume.shape[2]):
            raise ValueError("Side2 coordinates are out of bounds.")

        seg_id_1 = segmentation_volume[z1, y1, x1]
        seg_id_2 = segmentation_volume[z2, y2, x2]

        # If seg_id == 0, it means no segment at that voxel
        if seg_id_1 == 0:
            mask_1 = np.zeros_like(segmentation_volume, dtype=bool)
        else:
            mask_1 = (segmentation_volume == seg_id_1)

        if seg_id_2 == 0:
            mask_2 = np.zeros_like(segmentation_volume, dtype=bool)
        else:
            mask_2 = (segmentation_volume == seg_id_2)

        return mask_1, mask_2

    # Build masks
    mask_1_full, mask_2_full = create_segment_masks(seg_vol, side1_coord, side2_coord)

    # Define subvolume bounds
    half_size = subvolume_size // 2
    cx, cy, cz = central_coord

    x_start, x_end = max(cx - half_size, 0), min(cx + half_size, raw_vol.shape[2])
    y_start, y_end = max(cy - half_size, 0), min(cy + half_size, raw_vol.shape[1])
    z_start, z_end = max(cz - half_size, 0), min(cz + half_size, raw_vol.shape[0])

    # Extract subvolumes
    sub_raw = raw_vol[z_start:z_end, y_start:y_end, x_start:x_end]
    sub_mask_1 = mask_1_full[z_start:z_end, y_start:y_end, x_start:x_end]
    sub_mask_2 = mask_2_full[z_start:z_end, y_start:y_end, x_start:x_end]

    # Pad if smaller than subvolume_size
    pad_z = subvolume_size - sub_raw.shape[0]
    pad_y = subvolume_size - sub_raw.shape[1]
    pad_x = subvolume_size - sub_raw.shape[2]

    if pad_z > 0 or pad_y > 0 or pad_x > 0:
        sub_raw = np.pad(sub_raw, ((0, pad_z), (0, pad_y), (0, pad_x)),
                         mode='constant', constant_values=0)
        sub_mask_1 = np.pad(sub_mask_1, ((0, pad_z), (0, pad_y), (0, pad_x)),
                            mode='constant', constant_values=False)
        sub_mask_2 = np.pad(sub_mask_2, ((0, pad_z), (0, pad_y), (0, pad_x)),
                            mode='constant', constant_values=False)

    # Slice to exact shape
    sub_raw = sub_raw[:subvolume_size, :subvolume_size, :subvolume_size]
    sub_mask_1 = sub_mask_1[:subvolume_size, :subvolume_size, :subvolume_size]
    sub_mask_2 = sub_mask_2[:subvolume_size, :subvolume_size, :subvolume_size]

    # We'll build an overlaid cube: shape => (H, W, 3, D)
    overlaid_cube = np.zeros((subvolume_size, subvolume_size, 3, subvolume_size), dtype=np.uint8)

    # Colors
    side1_color = np.array([0, 0, 0], dtype=np.float32)  # Red
    side2_color = np.array([0, 0, 1], dtype=np.float32)  # Blue

    for z in range(subvolume_size):
        # Normalize raw slice to [0, 1]
        raw_slice = sub_raw[z].astype(np.float32)
        mn, mx = raw_slice.min(), raw_slice.max()
        if mx > mn:
            raw_slice = (raw_slice - mn) / (mx - mn)
        else:
            raw_slice = raw_slice - mn  # all zeros if mn=mx

        raw_rgb = np.stack([raw_slice]*3, axis=-1)  # shape (H, W, 3)

        # Build colored masks
        mask1_rgb = np.zeros_like(raw_rgb)
        mask1_rgb[sub_mask_1[z]] = side1_color

        mask2_rgb = np.zeros_like(raw_rgb)
        mask2_rgb[sub_mask_2[z]] = side2_color

        # Blend
        overlaid_image = (1 - alpha) * raw_rgb + alpha * (mask1_rgb + mask2_rgb)
        overlaid_image = np.clip(overlaid_image, 0, 1)

        overlaid_image = (overlaid_image * 255).astype(np.uint8)
        overlaid_cube[:, :, :, z] = overlaid_image

    return overlaid_cube


class VideoMAEDataset(Dataset):
    """
    Dataset class that uses segmented volumes (side1 & side2) for VideoMAE pre-training.
    """
    def __init__(self, vol_data_list, synapse_df, processor, subvol_size=80, num_frames=16, alpha=0.3):
        """
        Args:
            vol_data_list (List[Tuple[np.ndarray, np.ndarray]]): List of (raw_vol, seg_vol).
            synapse_df (pd.DataFrame): DataFrame with synapse coordinates (central, side1, side2).
            processor (VideoMAEImageProcessor): Processor for VideoMAE.
            subvol_size (int): Size of the sub-volume to extract.
            num_frames (int): Number of frames for the model.
            alpha (float): Blending alpha for segmentation.
        """
        self.vol_data_list = vol_data_list
        self.synapse_df = synapse_df.reset_index(drop=True)
        self.processor = processor
        self.subvol_size = subvol_size
        self.num_frames = num_frames
        self.alpha = alpha

    def __len__(self):
        return len(self.synapse_df)

    def __getitem__(self, idx):
        syn_info = self.synapse_df.iloc[idx]
        bbox_index = syn_info['bbox_index']

        raw_vol, seg_vol = self.vol_data_list[bbox_index]
        if raw_vol is None or seg_vol is None:
            # Return dummy data if volumes not found
            pixel_values = torch.zeros((self.num_frames, 3, self.subvol_size, self.subvol_size), dtype=torch.float32)
            return pixel_values, pixel_values

        # Coordinates
        central_coord = (
            int(syn_info['central_coord_1']),
            int(syn_info['central_coord_2']),
            int(syn_info['central_coord_3'])
        )
        side1_coord = (
            int(syn_info['side_1_coord_1']),
            int(syn_info['side_1_coord_2']),
            int(syn_info['side_1_coord_3'])
        )
        side2_coord = (
            int(syn_info['side_2_coord_1']),
            int(syn_info['side_2_coord_2']),
            int(syn_info['side_2_coord_3'])
        )

        # Create the overlaid segmented cube
        overlaid_cube = create_segmented_cube(
            raw_vol=raw_vol,
            seg_vol=seg_vol,
            central_coord=central_coord,
            side1_coord=side1_coord,
            side2_coord=side2_coord,
            subvolume_size=self.subvol_size,
            alpha=self.alpha
        )  # shape: (80, 80, 3, 80)

        # We interpret the last dimension (depth) as frames
        frames = []
        for z in range(overlaid_cube.shape[3]):  # 80 slices
            frame_rgb = overlaid_cube[..., z]  # (80, 80, 3)
            frames.append(frame_rgb)

        # Now reduce or expand to self.num_frames
        total_slices = len(frames)  # 80
        if total_slices < self.num_frames:
            while len(frames) < self.num_frames:
                frames.append(frames[-1])
        elif total_slices > self.num_frames:
            indices = np.linspace(0, total_slices - 1, self.num_frames, dtype=int)
            frames = [frames[i] for i in indices]

        # Process using the VideoMAEImageProcessor
        inputs = self.processor(frames, return_tensors="pt")
        pixel_values = inputs["pixel_values"].squeeze(0)  # (num_frames, 3, H, W)
        pixel_values = pixel_values.float()

        # For MAE, target is the same
        return pixel_values, pixel_values


def generate_masked_positions(batch_size, sequence_length, mask_ratio=0.75, device='cuda'):
    """
    Generate a boolean mask indicating which positions are masked.
    """
    masks = torch.zeros(batch_size, sequence_length, dtype=torch.bool, device=device)
    num_mask = int(mask_ratio * sequence_length)
    for i in range(batch_size):
        mask_indices = torch.randperm(sequence_length, device=device)[:num_mask]
        masks[i, mask_indices] = True
    return masks


def log_input_gifs(pixel_values, epoch, prefix="Training"):
    """
    Convert input pixel values to GIFs and log them to WandB.
    """
    # Ensure pixel_values is on CPU and convert to numpy
    pixel_values = pixel_values.cpu().numpy()

    gifs = []
    for i in range(min(2, pixel_values.shape[0])):  # just take up to 2 samples for logging
        frames = pixel_values[i]  # shape: (num_frames, 3, height, width)

        # Normalize frames to [0, 255]
        frames = frames - frames.min()
        if frames.max() != 0:
            frames = frames / frames.max()
        frames = (frames * 255).astype(np.uint8)

        # Rearrange to (num_frames, height, width, 3)
        frames = frames.transpose(0, 2, 3, 1)

        image_list = [f for f in frames]
        gif_buffer = io.BytesIO()
        imageio.mimsave(gif_buffer, image_list, format='GIF', fps=5)
        gif_buffer.seek(0)
        gifs.append(wandb.Video(gif_buffer, format="gif"))

    for idx, gif in enumerate(gifs):
        wandb.log({f"{prefix}_Input_Cube_Sample_{idx+1}_Epoch_{epoch}": gif})


def visualize_before_training(dataloader, epoch, prefix="Pre-Training Visualization"):
    """
    Visualize a few samples from the dataloader before training starts.
    """
    print(f"Starting {prefix}...")
    for batch_idx, (pixel_values, targets) in enumerate(dataloader):
        log_input_gifs(pixel_values, epoch=epoch, prefix=prefix)
        print(f"Logged {prefix} for batch {batch_idx + 1}")
        # Only log the first batch
        break
    print(f"{prefix} completed.")


def save_sample_gifs(dataloader, save_dir, num_gifs=2, prefix="Segmented"):
    """
    Save a specified number of sample GIFs from the dataloader to a directory.
    This will save the segmented cubes as GIFs before training.
    """
    os.makedirs(save_dir, exist_ok=True)
    print(f"Saving {num_gifs} sample GIFs to {save_dir}...")

    saved_gifs = 0
    for batch_idx, (pixel_values, targets) in enumerate(dataloader):
        # pixel_values shape: (batch_size, num_frames, 3, H, W)
        pixel_values = pixel_values.cpu().numpy()
        for i in range(pixel_values.shape[0]):
            if saved_gifs >= num_gifs:
                break
            frames = pixel_values[i]  # (num_frames, 3, H, W)

            # Normalize to [0, 255]
            frames = frames - frames.min()
            if frames.max() > 0:
                frames = frames / frames.max()
            frames = (frames * 255).astype(np.uint8)

            # Rearrange to (num_frames, H, W, 3)
            frames = frames.transpose(0, 2, 3, 1)

            image_list = [frame for frame in frames]

            gif_filename = f"{prefix}_Sample_{saved_gifs + 1}.gif"
            gif_path = os.path.join(save_dir, gif_filename)
            imageio.mimsave(gif_path, image_list, format='GIF', fps=5)
            print(f"Saved GIF: {gif_path}")

            saved_gifs += 1

        if saved_gifs >= num_gifs:
            break
    print(f"Successfully saved {saved_gifs} GIF(s) to {save_dir}.")


def main():
    args = parse_args()

    wandb.init(
        project=args.wandb_project,
        entity=args.wandb_entity,
        name=args.wandb_run_name,
        config={
            "raw_base_dir": args.raw_base_dir,
            "seg_base_dir": args.seg_base_dir,
            "csv_output_dir": args.csv_output_dir,
            "checkpoint_dir": args.checkpoint_dir,
            "log_dir": args.log_dir,
            "batch_size": args.batch_size,
            "num_epochs": args.num_epochs,
            "learning_rate": args.learning_rate,
            "weight_decay": args.weight_decay,
            "subvol_size": args.subvol_size,
            "num_frames": args.num_frames,
            "mask_ratio": args.mask_ratio,
            "patience": args.patience,
            "resume_checkpoint": args.resume_checkpoint,
        },
        save_code=True,
    )

    os.makedirs(args.csv_output_dir, exist_ok=True)
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)

    # Directory to save GIFs
    saved_gifs_dir = os.path.join(args.log_dir, 'saved_gifs')
    os.makedirs(saved_gifs_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model_name = "MCG-NJU/videomae-base"
    print("Initializing VideoMAE model and processor...")

    model_videomae = VideoMAEForPreTraining.from_pretrained(
        model_name,
        attn_implementation="sdpa",
        torch_dtype=torch.float32
    ).to(device)

    processor_videomae = VideoMAEImageProcessor.from_pretrained(model_name)
    model_videomae.train()
    print("VideoMAE model and processor initialized.")

    # For demonstration, let's assume we have bounding box names and their corresponding Excel files
    bbox_names = [f'bbox{i}' for i in range(1, 3)]  # fewer bboxes for a short demo
    all_vol_data = []
    all_syn_df = []

    for bbox_index, bbox_name in enumerate(bbox_names):
        print(f"Loading data for {bbox_name}...")
        raw_vol, seg_vol = load_volumes(bbox_name, args.raw_base_dir, args.seg_base_dir)
        if raw_vol is None or seg_vol is None:
            print(f"Skipping {bbox_name} due to loading errors.")
            continue

        # Suppose we have an Excel file: bbox1.xlsx, bbox2.xlsx, etc.
        excel_file = f"{bbox_name}.xlsx"
        if not os.path.exists(excel_file):
            print(f"Excel file {excel_file} not found. Skipping {bbox_name}.")
            continue

        syn_df = pd.read_excel(excel_file)

        # We assume syn_df has columns: central_coord_1/2/3, side_1_coord_1/2/3, side_2_coord_1/2/3, etc.
        # We'll just add the bbox_index:
        syn_df['bbox_index'] = bbox_index

        all_vol_data.append((raw_vol, seg_vol))
        all_syn_df.append(syn_df)

    if not all_syn_df:
        print("No synapse data loaded. Exiting.")
        wandb.finish()
        return

    combined_syn_df = pd.concat(all_syn_df, ignore_index=True)
    print(f"Total synapses loaded: {len(combined_syn_df)}")

    # Split into train/val
    train_syn_df, val_syn_df = train_test_split(combined_syn_df, test_size=0.2, random_state=42)
    print(f"Training synapses: {len(train_syn_df)}, Validation synapses: {len(val_syn_df)}")

    # Build Datasets
    dataset_videomae_train = VideoMAEDataset(
        vol_data_list=all_vol_data,
        synapse_df=train_syn_df,
        processor=processor_videomae,
        subvol_size=args.subvol_size,
        num_frames=args.num_frames,
        alpha=0.3
    )
    dataset_videomae_val = VideoMAEDataset(
        vol_data_list=all_vol_data,
        synapse_df=val_syn_df,
        processor=processor_videomae,
        subvol_size=args.subvol_size,
        num_frames=args.num_frames,
        alpha=0.3
    )

    num_workers = min(4, multiprocessing.cpu_count())
    print(f"Using {num_workers} workers for DataLoader.")

    dataloader_videomae_train = DataLoader(
        dataset_videomae_train,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True
    )
    dataloader_videomae_val = DataLoader(
        dataset_videomae_val,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    print("Saving 2 sample segmented GIFs before training starts...")
    save_sample_gifs(
        dataloader=dataloader_videomae_train,
        save_dir=saved_gifs_dir,
        num_gifs=2,
        prefix="Segmented"
    )

    print("Visualizing sample inputs before training starts...")
    visualize_before_training(dataloader_videomae_train, epoch=0, prefix="Pre-Training")
    print("Visualization completed.")

    optimizer = torch.optim.AdamW(
        model_videomae.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay
    )

    total_steps = len(dataloader_videomae_train) * args.num_epochs
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )

    early_stopping = EarlyStopping(
        patience=args.patience,
        verbose=True,
        path=os.path.join(args.checkpoint_dir, 'best_model.pth')
    )
    scaler = GradScaler()

    start_epoch = 1
    if args.resume_checkpoint:
        if os.path.exists(args.resume_checkpoint):
            checkpoint = torch.load(args.resume_checkpoint, map_location=device)
            model_videomae.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            start_epoch = checkpoint['epoch'] + 1
            print(f"Resumed training from checkpoint {args.resume_checkpoint} at epoch {start_epoch}")
            wandb.run.summary["resumed_from_epoch"] = start_epoch
        else:
            print(f"Checkpoint {args.resume_checkpoint} not found. Starting from scratch.")

    print("Starting training...")
    for epoch in range(start_epoch, args.num_epochs + 1):
        model_videomae.train()
        epoch_loss = 0.0

        train_pbar = tqdm(dataloader_videomae_train, desc=f"Epoch {epoch}/{args.num_epochs} - Train")
        for batch_idx, (pixel_values, targets) in enumerate(train_pbar):
            pixel_values = pixel_values.to(device)
            optimizer.zero_grad()

            tubelet_size = model_videomae.config.tubelet_size
            image_size = model_videomae.config.image_size
            patch_size = model_videomae.config.patch_size

            num_patches_per_frame = (image_size // patch_size) ** 2
            num_tubelets = pixel_values.shape[1] // tubelet_size
            sequence_length = num_tubelets * num_patches_per_frame

            bool_masked_pos = generate_masked_positions(
                pixel_values.shape[0],
                sequence_length,
                mask_ratio=args.mask_ratio,
                device=device
            )

            with autocast():
                outputs = model_videomae(
                    pixel_values=pixel_values,
                    bool_masked_pos=bool_masked_pos
                )
                loss = outputs.loss

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model_videomae.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            epoch_loss += loss.item()
            train_pbar.set_postfix({'loss': loss.item(), 'lr': scheduler.get_last_lr()[0]})

        avg_epoch_loss = epoch_loss / len(dataloader_videomae_train)
        print(f"Epoch {epoch} - Training Loss: {avg_epoch_loss:.4f}")
        wandb.log({
            'epoch': epoch,
            'train_loss': avg_epoch_loss,
            'learning_rate': scheduler.get_last_lr()[0]
        })

        # Validation
        model_videomae.eval()
        val_loss = 0.0
        with torch.no_grad():
            val_pbar = tqdm(dataloader_videomae_val, desc=f"Epoch {epoch}/{args.num_epochs} - Val")
            for batch_idx, (pixel_values, targets) in enumerate(val_pbar):
                pixel_values = pixel_values.to(device)

                tubelet_size = model_videomae.config.tubelet_size
                image_size = model_videomae.config.image_size
                patch_size = model_videomae.config.patch_size

                num_patches_per_frame = (image_size // patch_size) ** 2
                num_tubelets = pixel_values.shape[1] // tubelet_size
                sequence_length = num_tubelets * num_patches_per_frame

                bool_masked_pos = generate_masked_positions(
                    pixel_values.shape[0],
                    sequence_length,
                    mask_ratio=args.mask_ratio,
                    device=device
                )

                outputs = model_videomae(
                    pixel_values=pixel_values,
                    bool_masked_pos=bool_masked_pos
                )
                loss = outputs.loss
                val_loss += loss.item()

            avg_val_loss = val_loss / len(dataloader_videomae_val)
            print(f"Epoch {epoch} - Validation Loss: {avg_val_loss:.4f}")
            wandb.log({'val_loss': avg_val_loss, 'epoch': epoch})

        # Early stopping
        early_stopping(avg_val_loss, model_videomae, optimizer, scheduler, epoch)
        if early_stopping.early_stop:
            print("Early stopping triggered.")
            wandb.run.summary["early_stopped_at_epoch"] = epoch
            break

        # Save checkpoint every epoch (you can adjust frequency)
        checkpoint_path = os.path.join(args.checkpoint_dir, f'epoch_{epoch}.pth')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model_videomae.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': avg_epoch_loss,
        }, checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path}")
        wandb.save(checkpoint_path)

    # Final save
    final_model_path = os.path.join(args.checkpoint_dir, 'final_checkpoint.pth')
    torch.save({
        'epoch': epoch,
        'model_state_dict': model_videomae.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': avg_epoch_loss,
    }, final_model_path)
    print(f"Training completed. Final checkpoint saved at {final_model_path}")
    wandb.save(final_model_path)

    artifact = wandb.Artifact('final_model', type='model')
    artifact.add_file(final_model_path)
    wandb.log_artifact(artifact)

    wandb.finish()


if __name__ == "__main__":
    main()


Using device: cpu
Initializing VideoMAE model and processor...
VideoMAE model and processor initialized.
Loading data for bbox1...
Loading data for bbox2...
Total synapses loaded: 158
Training synapses: 126, Validation synapses: 32
Using 2 workers for DataLoader.
Saving 2 sample segmented GIFs before training starts...
Saving 2 sample GIFs to logs/saved_gifs...
Saved GIF: logs/saved_gifs/Segmented_Sample_1.gif
Saved GIF: logs/saved_gifs/Segmented_Sample_2.gif
Successfully saved 2 GIF(s) to logs/saved_gifs.
Visualizing sample inputs before training starts...
Starting Pre-Training...


  scaler = GradScaler()


Logged Pre-Training for batch 1
Pre-Training completed.
Visualization completed.
Starting training...


  with autocast():
Epoch 1/5 - Train:   5%|▍         | 3/63 [01:15<25:10, 25.18s/it, loss=0.615, lr=9.68e-6]


KeyboardInterrupt: 

# More masks on the segmented areas

In [None]:
!cp /content/checkpoints/final_model.pth /content/drive/MyDrive

In [None]:
import os
import glob
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import imageio.v2 as iio
from transformers import (
    VideoMAEForPreTraining,
    VideoMAEImageProcessor,
    VideoMAEModel,
)
from sklearn.decomposition import PCA
import umap.umap_ as umap
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from tqdm import tqdm
import warnings
from torch.utils.tensorboard import SummaryWriter
import time
import multiprocessing
from collections import deque

warnings.filterwarnings("ignore", category=UserWarning, module="torch.utils.data.dataloader")

# Directories and configurations
raw_base_dir = '/content/raw'
seg_base_dir = '/content/seg'
bbox_names = [f'bbox{i}' for i in range(1, 8)]

os.makedirs('csv_outputs', exist_ok=True)
os.makedirs('checkpoints', exist_ok=True)  # Directory for saving checkpoints

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

def load_bbox_data(bbox_name, max_slices=None):
    """
    Load raw and segmentation volumes for a bounding box.
    Returns (raw_vol, seg_vol) each shape (Z, Y, X).
    """
    raw_dir = os.path.join(raw_base_dir, bbox_name)
    seg_dir = os.path.join(seg_base_dir, bbox_name)

    raw_tif_files = sorted(glob.glob(os.path.join(raw_dir, 'slice_*.tif')))
    seg_tif_files = sorted(glob.glob(os.path.join(seg_dir, 'slice_*.tif')))

    if max_slices is not None:
        raw_tif_files = raw_tif_files[:max_slices]
        seg_tif_files = seg_tif_files[:max_slices]

    assert len(raw_tif_files) == len(seg_tif_files), f"Raw/Seg mismatch in {bbox_name}"

    raw_slices = [iio.imread(f) for f in raw_tif_files]
    seg_slices = [iio.imread(f).astype(np.uint32) for f in seg_tif_files]

    raw_vol = np.stack(raw_slices, axis=0)  # shape: (Z, Y, X)
    seg_vol = np.stack(seg_slices, axis=0)  # shape: (Z, Y, X)
    return raw_vol, seg_vol

def create_segment_masks(seg_vol, side1_coord, side2_coord):
    """
    Creates boolean masks for side_1 and side_2 coords in the segmentation volume.
    """
    x1, y1, z1 = [int(c) for c in side1_coord]
    x2, y2, z2 = [int(c) for c in side2_coord]

    seg_id_1 = seg_vol[z1, y1, x1]
    seg_id_2 = seg_vol[z2, y2, x2]

    mask_1 = (seg_vol == seg_id_1) if seg_id_1 != 0 else np.zeros_like(seg_vol, dtype=bool)
    mask_2 = (seg_vol == seg_id_2) if seg_id_2 != 0 else np.zeros_like(seg_vol, dtype=bool)
    return mask_1, mask_2

class FeatureExtractionDataset(Dataset):
    """
    Dataset class for feature extraction using the trained VideoMAE encoder.
    Each item consists of a video clip extracted from the sub-volume around a central coordinate.
    """
    def __init__(self, vol_data_list, synapse_df, processor, subvol_size=80, num_frames=16):
        """
        Args:
            vol_data_list (List[Tuple[np.ndarray, np.ndarray]]): List containing tuples of (raw_vol, seg_vol).
            synapse_df (pd.DataFrame): DataFrame containing synapse information.
            processor (VideoMAEImageProcessor): Processor for VideoMAE.
            subvol_size (int, optional): Size of the sub-volume to extract around the central coordinate. Defaults to 80.
            num_frames (int, optional): Number of frames per video clip for VideoMAE. Defaults to 16.
        """
        self.vol_data_list = vol_data_list
        self.synapse_df = synapse_df.reset_index(drop=True)
        self.subvol_size = subvol_size
        self.half_size = subvol_size // 2
        self.num_frames = num_frames
        self.processor = processor

    def __len__(self):
        return len(self.synapse_df)

    def __getitem__(self, idx):
        syn_info = self.synapse_df.iloc[idx]
        bbox_index = syn_info['bbox_index']
        raw_vol, seg_vol = self.vol_data_list[bbox_index]

        # Coordinates
        central_coord = (
            int(syn_info['central_coord_1']),
            int(syn_info['central_coord_2']),
            int(syn_info['central_coord_3'])
        )
        side1_coord = (
            int(syn_info['side_1_coord_1']),
            int(syn_info['side_1_coord_2']),
            int(syn_info['side_1_coord_3'])
        )
        side2_coord = (
            int(syn_info['side_2_coord_1']),
            int(syn_info['side_2_coord_2']),
            int(syn_info['side_2_coord_3'])
        )

        # Create side1 and side2 masks
        mask_1_full, mask_2_full = create_segment_masks(seg_vol, side1_coord, side2_coord)

        # Determine sub-volume bounds
        cx, cy, cz = central_coord
        x_start = max(cx - self.half_size, 0)
        x_end   = min(cx + self.half_size, raw_vol.shape[2])
        y_start = max(cy - self.half_size, 0)
        y_end   = min(cy + self.half_size, raw_vol.shape[1])
        z_start = max(cz - self.half_size, 0)
        z_end   = min(cz + self.half_size, raw_vol.shape[0])

        sub_raw    = raw_vol[z_start:z_end, y_start:y_end, x_start:x_end]
        sub_mask_1 = mask_1_full[z_start:z_end, y_start:y_end, x_start:x_end]
        sub_mask_2 = mask_2_full[z_start:z_end, y_start:y_end, x_start:x_end]

        # Pad sub-volumes to (subvol_size, subvol_size, subvol_size) if near edges
        desired_shape = (self.subvol_size, self.subvol_size, self.subvol_size)
        dz, dy, dx = sub_raw.shape

        padded_sub_raw    = np.zeros(desired_shape, dtype=sub_raw.dtype)
        padded_sub_mask1  = np.zeros(desired_shape, dtype=np.uint8)
        padded_sub_mask2  = np.zeros(desired_shape, dtype=np.uint8)

        padded_sub_raw[:dz, :dy, :dx] = sub_raw
        padded_sub_mask1[:dz, :dy, :dx] = sub_mask_1
        padded_sub_mask2[:dz, :dy, :dx] = sub_mask_2

        # Create RGB-like frames: R = side1 mask, G = raw intensity, B = side2 mask
        frames = []
        for z in range(self.subvol_size):
            frame_raw = padded_sub_raw[z]
            frame_mask1 = padded_sub_mask1[z]
            frame_mask2 = padded_sub_mask2[z]

            # Normalize raw intensity to [0, 1]
            if frame_raw.max() > frame_raw.min():
                frame_raw_norm = (frame_raw - frame_raw.min()) / (frame_raw.max() - frame_raw.min())
            else:
                frame_raw_norm = np.zeros_like(frame_raw)

            # Stack into 3 channels
            frame_rgb = np.stack([frame_mask1, frame_raw_norm, frame_mask2], axis=-1)  # Shape: (Y, X, 3)
            frames.append(frame_rgb)

        if len(frames) < self.num_frames:
            while len(frames) < self.num_frames:
                frames.append(frames[-1])
        elif len(frames) > self.num_frames:
            indices = np.linspace(0, len(frames)-1, self.num_frames, dtype=int)
            frames = [frames[i] for i in indices]

        frames = [ (frame * 255).astype(np.uint8) for frame in frames ]

        inputs = self.processor(
            frames,
            return_tensors="pt"
        )
        pixel_values = inputs["pixel_values"].squeeze(0)  # Shape: (num_frames, num_channels, height, width)

        # Convert to float32 to match the model's dtype
        pixel_values = pixel_values.float()

        syn_info_dict = syn_info.to_dict()

        return pixel_values, syn_info_dict

# Initialize TensorBoard writer for feature extraction
feature_log_dir = os.path.join('logs', 'feature_extraction', time.strftime("%Y%m%d-%H%M%S"))
feature_writer = SummaryWriter(log_dir=feature_log_dir)
print(f"TensorBoard logging initialized at {feature_log_dir}")

# Load the pre-trained VideoMAE model for feature extraction
model_name = "MCG-NJU/videomae-base"
model_save_path = 'checkpoints/best_model.pth'  # Path to your saved checkpoint

print("Initializing VideoMAEModel for feature extraction...")
model_videomae_feature = VideoMAEModel.from_pretrained(model_name).to(device)

# Load pre-trained weights
pretrained_dict = torch.load(model_save_path, map_location=device)

filtered_dict = {}
for key, value in pretrained_dict['model_state_dict'].items():
    if key.startswith('encoder.'):
        new_key = key.replace('encoder.', '')
        filtered_dict[new_key] = value

model_videomae_feature.load_state_dict(filtered_dict, strict=False)

model_videomae_feature.eval()

print("Pre-trained weights loaded into VideoMAEModel for feature extraction")

# Initialize the processor
processor_videomae = VideoMAEImageProcessor.from_pretrained(model_name)

# Load data
all_vol_data = []
all_syn_df = []

for bbox_index, bbox_name in enumerate(bbox_names):
    print(f"Loading data for {bbox_name}...")
    raw_vol, seg_vol = load_bbox_data(bbox_name)
    excel_file = f'/content/{bbox_name}.xlsx'
    syn_df = pd.read_excel(excel_file)

    syn_df['bbox_index'] = bbox_index
    syn_df['bbox_name']  = bbox_name

    # Append to the lists
    all_vol_data.append( (raw_vol, seg_vol) )
    all_syn_df.append(syn_df)

combined_syn_df = pd.concat(all_syn_df, ignore_index=True)
print(f"Total synapses loaded: {len(combined_syn_df)}")

subvol_size = 80
num_frames = 16   # Number of frames VideoMAE expects

# Initialize Feature Extraction Dataset
feature_dataset = FeatureExtractionDataset(
    vol_data_list=all_vol_data,
    synapse_df=combined_syn_df,
    processor=processor_videomae,
    subvol_size=subvol_size,
    num_frames=num_frames
)

# Determine optimal number of workers
num_workers = min(8, multiprocessing.cpu_count())  # Adjust based on your system
print(f"Using {num_workers} workers for FeatureExtraction DataLoader.")

feature_dataloader = DataLoader(
    feature_dataset,
    batch_size=2,        # Adjust based on your GPU memory
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=True  # Keep workers alive between epochs
)

print("Feature Extraction DataLoader created.")

# Function to extract features with VideoMAE
def extract_features_with_videomae(video_batch):
    """
    Extract features using the trained VideoMAE model.

    Args:
        video_batch (torch.Tensor): Tensor of shape [B, num_frames, 3, H, W].

    Returns:
        np.ndarray: Array of extracted features with shape [B, hidden_size].
    """
    with torch.no_grad():
        outputs = model_videomae_feature(
            pixel_values=video_batch.to(device),
            return_dict=True
        )
        last_hidden_states = outputs.last_hidden_state  # Shape: [B, sequence_length, hidden_size]
        # Aggregate features (e.g., mean pooling)
        pooled_features = last_hidden_states.mean(dim=1)  # Shape: [B, hidden_size]

    return pooled_features.cpu().numpy()

# Initialize variables for feature extraction
all_csv_paths = []
start_time_total = time.time()

print("Starting feature extraction with VideoMAE encoder...")
for bbox_idx, bbox_name in enumerate(bbox_names):
    print(f"Processing {bbox_name}...")
    raw_vol, seg_vol = load_bbox_data(bbox_name)
    excel_file = f'/content/{bbox_name}.xlsx'
    syn_df = pd.read_excel(excel_file)

    syn_df['bbox_index'] = bbox_idx
    syn_df['bbox_name']  = bbox_name

    dataset_bbox = FeatureExtractionDataset(
        vol_data_list=all_vol_data,
        synapse_df=syn_df,
        processor=processor_videomae,
        subvol_size=subvol_size,
        num_frames=num_frames
    )
    dataloader_bbox = DataLoader(
        dataset_bbox,
        batch_size=2,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=True
    )

    bbox_features = []
    bbox_syn_info = []

    for batch_idx, (video_batch, syn_infos) in enumerate(tqdm(dataloader_bbox, desc=f"Extracting features for {bbox_name}")):
        feats = extract_features_with_videomae(video_batch)
        bbox_features.append(feats)

        syn_infos_df = pd.DataFrame(syn_infos)
        bbox_syn_info.append(syn_infos_df)

        # Optional: Log progress to TensorBoard
        feature_writer.add_scalar('Features/Processed Batches', batch_idx+1, bbox_idx * len(dataloader_bbox) + batch_idx + 1)

    bbox_features = np.concatenate(bbox_features, axis=0)
    bbox_syn_info = pd.concat(bbox_syn_info, axis=0).reset_index(drop=True)

    feature_cols = [f'feat_{j}' for j in range(bbox_features.shape[1])]
    features_df = pd.DataFrame(bbox_features, columns=feature_cols)

    output_df = pd.concat([bbox_syn_info, features_df], axis=1)

    output_csv_name = f'csv_outputs/{bbox_name}_videomae_features.csv'
    output_df.to_csv(output_csv_name, index=False)
    all_csv_paths.append(output_csv_name)
    print(f"Saved VideoMAE features for {bbox_name} -> {output_csv_name}")

    # Checkpoint: Save after each bbox
    checkpoint_path = f'checkpoints/{bbox_name}_features.pth'
    torch.save({
        'bbox_name': bbox_name,
        'features': bbox_features,
        'syn_info': syn_infos_df
    }, checkpoint_path)
    print(f"Checkpoint saved at {checkpoint_path}")

    # Log to TensorBoard
    feature_writer.add_scalar('Features/BBoxes Processed', bbox_idx + 1, bbox_idx + 1)

print("Feature extraction completed.")

print("Merging all CSV files...")
merged_df = pd.concat([pd.read_csv(p) for p in all_csv_paths], ignore_index=True)
print(f"Merged {len(all_csv_paths)} CSVs into one DataFrame with {len(merged_df)} rows.")

merged_csv = 'csv_outputs/all_features_merged_videomae.csv'
merged_df.to_csv(merged_csv, index=False)
print(f"Final merged CSV: {merged_csv}")

# Close the TensorBoard writer for feature extraction
feature_writer.close()

print("Starting PCA and UMAP dimensionality reduction...")
df = pd.read_csv(merged_csv)

feat_cols = [c for c in df.columns if c.startswith('feat_')]
X = df[feat_cols].values
print(f"Feature matrix shape: {X.shape}")

# Apply PCA
print("Applying PCA...")
pca = PCA(n_components=50, random_state=42)
X_pca = pca.fit_transform(X)
print(f"PCA transformed shape: {X_pca.shape}")

# Apply UMAP for 3D visualization
print("Applying UMAP for 3D dimensionality reduction...")
umap_3d = umap.UMAP(
    n_components=3,
    n_neighbors=15,
    min_dist=0.1,
    random_state=42
)
X_umap3 = umap_3d.fit_transform(X_pca)
df['umap_x'] = X_umap3[:,0]
df['umap_y'] = X_umap3[:,1]
df['umap_z'] = X_umap3[:,2]

print("Creating 3D UMAP visualization...")
fig = px.scatter_3d(
    df,
    x='umap_x',
    y='umap_y',
    z='umap_z',
    color='bbox_name',
    hover_data=['central_coord_1', 'central_coord_2', 'central_coord_3']
)
fig.update_traces(marker=dict(size=3))
fig.update_layout(width=800, height=600)
fig.write_html("videomae_umap3d.html")
fig.show()

print("Creating 2D UMAP projections...")
fig_xy = px.scatter(
    df,
    x="umap_x",
    y="umap_y",
    color="bbox_name",
    title="UMAP (x vs y)",
    hover_data=["umap_x", "umap_y", "bbox_name"]
)
fig_xy.write_html("videomae_umap_x_vs_y.html")
fig_xy.show()

fig_xz = px.scatter(
    df,
    x="umap_x",
    y="umap_z",
    color="bbox_name",
    title="UMAP (x vs z)",
    hover_data=["umap_x", "umap_z", "bbox_name"]
)
fig_xz.write_html("videomae_umap_x_vs_z.html")
fig_xz.show()

fig_yz = px.scatter(
    df,
    x="umap_y",
    y="umap_z",
    color="bbox_name",
    title="UMAP (y vs z)",
    hover_data=["umap_y", "umap_z", "bbox_name"]
)
fig_yz.write_html("videomae_umap_y_vs_z.html")
fig_yz.show()

print("Creating combined 2D UMAP projections...")
fig_combined = make_subplots(
    rows=1, cols=3,
    subplot_titles=[
        "UMAP (x vs y)",
        "UMAP (x vs z)",
        "UMAP (y vs z)"
    ]
)

cat_codes = df["bbox_name"].astype("category").cat.codes

trace_xy = go.Scatter(
    x=df["umap_x"],
    y=df["umap_y"],
    mode="markers",
    name="(x vs y)",
    marker=dict(
        color=cat_codes,
        colorscale="Viridis",
        showscale=True,
        size=5
    ),
    text=df["bbox_name"],    # Hover text
    hovertemplate="bbox_name:%{text}<br>umap_x=%{x}<br>umap_y=%{y}"
)
fig_combined.add_trace(trace_xy, row=1, col=1)

# Trace for (x vs z)
trace_xz = go.Scatter(
    x=df["umap_x"],
    y=df["umap_z"],
    mode="markers",
    name="(x vs z)",
    marker=dict(
        color=cat_codes,
        colorscale="Viridis",
        showscale=False,  # Colorbar already shown in first subplot
        size=5
    ),
    text=df["bbox_name"],
    hovertemplate="bbox_name:%{text}<br>umap_x=%{x}<br>umap_z=%{y}"
)
fig_combined.add_trace(trace_xz, row=1, col=2)

trace_yz = go.Scatter(
    x=df["umap_y"],
    y=df["umap_z"],
    mode="markers",
    name="(y vs z)",
    marker=dict(
        color=cat_codes,
        colorscale="Viridis",
        showscale=False,
        size=5
    ),
    text=df["bbox_name"],
    hovertemplate="bbox_name:%{text}<br>umap_y=%{x}<br>umap_z=%{y}"
)
fig_combined.add_trace(trace_yz, row=1, col=3)

fig_combined.update_layout(
    title="2D UMAP Projections (All Pairwise Components)",
    width=1800,   # Wide figure
    height=600,
    showlegend=False
)

fig_combined.write_html("videomae_combined_umap_projections.html")
fig_combined.show()
print("Dimensionality reduction and visualization completed.")


Using device: cuda
TensorBoard logging initialized at logs/feature_extraction/20250109-103645
Initializing VideoMAEModel for feature extraction...


  pretrained_dict = torch.load(model_save_path, map_location=device)


Pre-trained weights loaded into VideoMAEModel for feature extraction
Loading data for bbox1...
Loading data for bbox2...
Loading data for bbox3...
Loading data for bbox4...
Loading data for bbox5...
Loading data for bbox6...
Loading data for bbox7...
Total synapses loaded: 509
Using 8 workers for FeatureExtraction DataLoader.
Feature Extraction DataLoader created.
Starting feature extraction with VideoMAE encoder...
Processing bbox1...


Extracting features for bbox1: 100%|██████████| 29/29 [00:05<00:00,  5.09it/s]


Saved VideoMAE features for bbox1 -> csv_outputs/bbox1_videomae_features.csv
Checkpoint saved at checkpoints/bbox1_features.pth
Processing bbox2...


Extracting features for bbox2: 100%|██████████| 50/50 [00:07<00:00,  6.63it/s]


Saved VideoMAE features for bbox2 -> csv_outputs/bbox2_videomae_features.csv
Checkpoint saved at checkpoints/bbox2_features.pth
Processing bbox3...


Extracting features for bbox3: 100%|██████████| 31/31 [00:05<00:00,  5.60it/s]


Saved VideoMAE features for bbox3 -> csv_outputs/bbox3_videomae_features.csv
Checkpoint saved at checkpoints/bbox3_features.pth
Processing bbox4...


Extracting features for bbox4: 100%|██████████| 20/20 [00:03<00:00,  5.05it/s]


Saved VideoMAE features for bbox4 -> csv_outputs/bbox4_videomae_features.csv
Checkpoint saved at checkpoints/bbox4_features.pth
Processing bbox5...


Extracting features for bbox5: 100%|██████████| 43/43 [00:06<00:00,  6.54it/s]


Saved VideoMAE features for bbox5 -> csv_outputs/bbox5_videomae_features.csv
Checkpoint saved at checkpoints/bbox5_features.pth
Processing bbox6...


Extracting features for bbox6: 100%|██████████| 49/49 [00:07<00:00,  6.68it/s]


Saved VideoMAE features for bbox6 -> csv_outputs/bbox6_videomae_features.csv
Checkpoint saved at checkpoints/bbox6_features.pth
Processing bbox7...


Extracting features for bbox7: 100%|██████████| 33/33 [00:05<00:00,  5.90it/s]


Saved VideoMAE features for bbox7 -> csv_outputs/bbox7_videomae_features.csv
Checkpoint saved at checkpoints/bbox7_features.pth
Feature extraction completed.
Merging all CSV files...
Merged 7 CSVs into one DataFrame with 509 rows.
Final merged CSV: csv_outputs/all_features_merged_videomae.csv
Starting PCA and UMAP dimensionality reduction...
Feature matrix shape: (509, 768)
Applying PCA...
PCA transformed shape: (509, 50)
Applying UMAP for 3D dimensionality reduction...


  warn(


Creating 3D UMAP visualization...


Creating 2D UMAP projections...


Creating combined 2D UMAP projections...


Dimensionality reduction and visualization completed.


# Gradcam


In [None]:
!mkdir videomae_gradcam


In [None]:
dataset_videomae.synapse_df.head()

Unnamed: 0,Var1,central_coord_1,central_coord_2,central_coord_3,side_1_coord_1,side_1_coord_2,side_1_coord_3,side_2_coord_1,side_2_coord_2,side_2_coord_3,bbox_index,bbox_name
0,non_spine_synapsed_056,171,260,350,171,268,359,171,260,340,0,bbox1
1,non_spine_synapse_057,223,113,425,223,112,438,223,114,407,0,bbox1
2,non_spine_synapse_058,280,102,377,280,94,400,280,108,364,0,bbox1
3,non_spine_synapse_063,455,131,162,455,134,181,455,127,145,0,bbox1
4,non_spine_synapse_062,138,121,302,135,113,298,140,127,312,0,bbox1


In [None]:
# import torch
# import numpy as np
# from PIL import Image
# import cv2
# from typing import List, Optional

# class VideoMAEGradCAM:
#     """
#     GradCAM implementation for VideoMAE model.
#     Generates attention maps for video input.
#     """
#     def __init__(self, model: torch.nn.Module, layer_idx: int = 11):
#         torch.backends.cudnn.enabled = False  # Temporary fix for some CUDA issues
#         self.model = model
#         self.device = next(model.parameters()).device
#         self.gradients = None
#         self.activations = None

#         # Register hooks for the attention output
#         target_layer = self.model.encoder.layer[layer_idx].attention.output.dense
#         self.forward_hook = target_layer.register_forward_hook(self._save_activation)
#         self.backward_hook = target_layer.register_backward_hook(self._save_gradient)

#     def _save_activation(self, module, input, output):
#         self.activations = output

#     def __del__(self):
#         # Remove hooks when the object is deleted
#         self.forward_hook.remove()
#         self.backward_hook.remove()

#     def _save_gradient(self, module, grad_input, grad_output):
#         self.gradients = grad_output[0]

#     def generate_cam(self, video_input: torch.Tensor) -> np.ndarray:
#         """
#         Generate attention map for video input.
#         """
#         # Ensure input is on the same device as model
#         video_input = video_input.to(self.device)

#         self.model.zero_grad()

#         # Forward pass
#         outputs = self.model(pixel_values=video_input)

#         # Use mean of output features as target for visualization
#         target = outputs.last_hidden_state.mean()
#         target.backward()

#         with torch.no_grad():
#             # Get gradients and activations
#             gradients = self.gradients.detach()
#             activations = self.activations.detach()

#             # Calculate importance weights
#             weights = torch.mean(gradients, dim=(0, 1))

#             # Weight the activations by the gradients
#             weighted_activations = torch.einsum('ntd,d->nt', activations, weights)

#             # Get video dimensions
#             num_frames = video_input.size(1)  # Number of input frames
#             patch_size = self.model.config.patch_size
#             tubelet_size = self.model.config.tubelet_size
#             image_size = video_input.size(-1)  # Height/width of input frames

#             # Calculate patches
#             patches_per_frame = (image_size // patch_size) ** 2  # Spatial patches per frame
#             num_total_patches = weighted_activations.size(1)  # Total patches in sequence
#             temporal_patches = num_frames // tubelet_size  # Number of temporal patches

#             # First reshape to (temporal_patches, spatial_patches)
#             attention_map = weighted_activations.view(temporal_patches, patches_per_frame)

#             # Then reshape spatial dimension to square grid
#             spatial_size = int(np.sqrt(patches_per_frame))
#             attention_map = attention_map.view(temporal_patches, spatial_size, spatial_size)

#             # Upsample temporal dimension to match number of frames
#             attention_map = attention_map.unsqueeze(1)
#             attention_map = attention_map.repeat(1, tubelet_size, 1, 1)
#             attention_map = attention_map.reshape(num_frames, spatial_size, spatial_size)

#             # Apply ReLU and normalize
#             attention_map = torch.relu(attention_map)
#             if attention_map.max() > attention_map.min():
#                 attention_map = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min())

#             return attention_map.cpu().numpy()

#     def apply_attention_map(self,
#                           video_frames: List[np.ndarray],
#                           attention_map: np.ndarray,
#                           alpha: float = 0.6) -> List[np.ndarray]:
#         """
#         Apply attention map to original video frames.
#         """
#         result_frames = []

#         for frame_idx, frame in enumerate(video_frames):
#             # Resize attention map to match frame size
#             attention = cv2.resize(attention_map[frame_idx],
#                                  (frame.shape[1], frame.shape[0]),
#                                  interpolation=cv2.INTER_LINEAR)

#             # Create heatmap
#             heatmap = cv2.applyColorMap(
#                 (attention * 255).astype(np.uint8),
#                 cv2.COLORMAP_JET
#             )
#             heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

#             # Overlay heatmap on frame
#             superimposed = cv2.addWeighted(frame, 1-alpha, heatmap, alpha, 0)
#             result_frames.append(superimposed)

#         return result_frames

#     def visualize(self,
#                  video_input: torch.Tensor,
#                  original_frames: List[np.ndarray],
#                  output_path: Optional[str] = None,
#                  alpha: float = 0.6,
#                  fps: int = 4) -> List[np.ndarray]:
#         """
#         Generate and save complete visualization.
#         """
#         # Generate attention map
#         attention_map = self.generate_cam(video_input)

#         # Apply attention map to frames
#         visualization_frames = self.apply_attention_map(
#             original_frames,
#             attention_map,
#             alpha
#         )

#         # Save if output path provided
#         if output_path is not None:
#             if output_path.endswith('.gif'):
#                 # Save as GIF
#                 pil_frames = [Image.fromarray(frame) for frame in visualization_frames]
#                 pil_frames[0].save(
#                     output_path,
#                     save_all=True,
#                     append_images=pil_frames[1:],
#                     duration=int(1000/fps),
#                     loop=0
#                 )
#             elif output_path.endswith('.mp4'):
#                 # Save as MP4
#                 height, width = visualization_frames[0].shape[:2]
#                 fourcc = cv2.VideoWriter_fourcc(*'mp4v')
#                 out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
#                 for frame in visualization_frames:
#                     out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
#                 out.release()

#         return visualization_frames

# def visualize_synapse_attention(model, dataset, idx: int,
#                               layer_idx: int = 11,
#                               output_path: Optional[str] = None):
#     """
#     Generate attention visualization for a specific synapse.
#     """
#     # Get video data and move to model device
#     pixel_values, synapse_info = dataset[idx]
#     pixel_values = pixel_values.unsqueeze(0)  # Add batch dimension

#     # Get original frames (on CPU)
#     original_frames = []
#     for frame_idx in range(pixel_values.size(1)):
#         frame = pixel_values[0, frame_idx].cpu().permute(1, 2, 0).numpy()
#         frame = (frame * 255).astype(np.uint8)
#         original_frames.append(frame)

#     # Initialize GradCAM
#     gradcam = VideoMAEGradCAM(model, layer_idx=layer_idx)

#     # Generate visualization
#     vis_frames = gradcam.visualize(
#         pixel_values,
#         original_frames,
#         output_path=output_path
#     )

#     return vis_frames, synapse_info


# import random
# import os

# # Create output directory
# output_dir = "gradcam_outpu2ts2"
# os.makedirs(output_dir, exist_ok=True)

# # Get 20 random indices
# n_samples = 2
# total_samples = len(dataset_videomae)
# random_indices = random.sample(range(total_samples), min(n_samples, total_samples))

# # Get the synapse DataFrame
# synapse_df = dataset_videomae.synapse_df

# # Generate visualizations
# for i, idx in enumerate(random_indices):
#     # Get Var1 value for this synapse
#     var1_value = synapse_df.iloc[idx]['Var1']

#     # Create output filename using Var1
#     output_path = os.path.join(output_dir, f"{var1_value}_attention.gif")

#     print(f"Generating visualization {i+1}/20 for synapse {var1_value}")

#     vis_frames = visualize_synapse_attention(
#         model_videomae_feature,
#         dataset_videomae,
#         idx,
#         layer_idx=11,
#         output_path=output_path
#     )

#     print(f"Saved visualization to {output_path}")

Generating visualization 1/20 for synapse explorative_2024-08-03_Ali_Karimi_025
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.


In [None]:
# import shutil

# shutil.make_archive("/content/gradcam_outputs_archive", 'zip', "/content/gradcam_outputs")

# print("Folder has been zipped to: /content/gradcam_outputs_archive.zip")

Folder has been zipped to: /content/gradcam_outputs_archive.zip
