<a href="https://colab.research.google.com/github/alim98/MPI/blob/main/MAE/MPI_video_MAE_f.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Essential downloads

In [1]:
!wget -O downloaded_file.zip "https://drive.usercontent.google.com/download?id=1iHPBdBOPEagvPTHZmrN__LD49emXwReY&export=download&authuser=0&confirm=t&uuid=631d60dd-569c-4bb1-a9e8-d681f0ed3d43&at=APvzH3r4me8x_LwP3n8O7lgPo8oK%3A1733988188000"

!unzip -q downloaded_file.zip

--2025-01-09 09:15:38--  https://drive.usercontent.google.com/download?id=1iHPBdBOPEagvPTHZmrN__LD49emXwReY&export=download&authuser=0&confirm=t&uuid=631d60dd-569c-4bb1-a9e8-d681f0ed3d43&at=APvzH3r4me8x_LwP3n8O7lgPo8oK%3A1733988188000
Resolving drive.usercontent.google.com (drive.usercontent.google.com)... 142.251.2.132, 2607:f8b0:4023:c0d::84
Connecting to drive.usercontent.google.com (drive.usercontent.google.com)|142.251.2.132|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1264688649 (1.2G) [application/octet-stream]
Saving to: ‘downloaded_file.zip’


2025-01-09 09:15:56 (76.4 MB/s) - ‘downloaded_file.zip’ saved [1264688649/1264688649]



In [2]:

!pip install transformers scikit-learn matplotlib seaborn torch torchvision umap-learn
!pip install openpyxl


Collecting umap-learn
  Downloading umap_learn-0.5.7-py3-none-any.whl.metadata (21 kB)
Collecting pynndescent>=0.5 (from umap-learn)
  Downloading pynndescent-0.5.13-py3-none-any.whl.metadata (6.8 kB)
Downloading umap_learn-0.5.7-py3-none-any.whl (88 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.8/88.8 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pynndescent-0.5.13-py3-none-any.whl (56 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.9/56.9 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pynndescent, umap-learn
Successfully installed pynndescent-0.5.13 umap-learn-0.5.7


# Run



> to do
>
>
>1.   add Lr-scheduler
>2.   use large model



In [6]:
import os
import glob
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import imageio.v2 as iio
from transformers import (
    VideoMAEForPreTraining,
    VideoMAEImageProcessor,
    get_linear_schedule_with_warmup,
)
from transformers import get_cosine_schedule_with_warmup  # Added for cosine scheduler
from sklearn.decomposition import PCA
import umap.umap_ as umap
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from tqdm import tqdm
import warnings
from torch.utils.tensorboard import SummaryWriter  # Added for TensorBoard
import time

warnings.filterwarnings("ignore", category=UserWarning, module="torch.utils.data.dataloader")

# Directories and configurations
raw_base_dir = '/content/raw'
seg_base_dir = '/content/seg'
bbox_names = [f'bbox{i}' for i in range(1, 8)]

os.makedirs('csv_outputs', exist_ok=True)
os.makedirs('checkpoints', exist_ok=True)  # Directory for saving checkpoints

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

def load_bbox_data(bbox_name, max_slices=None):
    """
    Load raw and segmentation volumes for a bounding box.
    Returns (raw_vol, seg_vol) each shape (Z, Y, X).
    """
    raw_dir = os.path.join(raw_base_dir, bbox_name)
    seg_dir = os.path.join(seg_base_dir, bbox_name)

    raw_tif_files = sorted(glob.glob(os.path.join(raw_dir, 'slice_*.tif')))
    seg_tif_files = sorted(glob.glob(os.path.join(seg_dir, 'slice_*.tif')))

    if max_slices is not None:
        raw_tif_files = raw_tif_files[:max_slices]
        seg_tif_files = seg_tif_files[:max_slices]

    assert len(raw_tif_files) == len(seg_tif_files), f"Raw/Seg mismatch in {bbox_name}"

    raw_slices = [iio.imread(f) for f in raw_tif_files]
    seg_slices = [iio.imread(f).astype(np.uint32) for f in seg_tif_files]

    raw_vol = np.stack(raw_slices, axis=0)  # shape: (Z, Y, X)
    seg_vol = np.stack(seg_slices, axis=0)  # shape: (Z, Y, X)
    return raw_vol, seg_vol

def create_segment_masks(seg_vol, side1_coord, side2_coord):
    """
    Creates boolean masks for side_1 and side_2 coords in the segmentation volume.
    """
    x1, y1, z1 = [int(c) for c in side1_coord]
    x2, y2, z2 = [int(c) for c in side2_coord]

    seg_id_1 = seg_vol[z1, y1, x1]
    seg_id_2 = seg_vol[z2, y2, x2]

    mask_1 = (seg_vol == seg_id_1) if seg_id_1 != 0 else np.zeros_like(seg_vol, dtype=bool)
    mask_2 = (seg_vol == seg_id_2) if seg_id_2 != 0 else np.zeros_like(seg_vol, dtype=bool)
    return mask_1, mask_2

class VideoMAEDataset(Dataset):
    """
    Dataset class tailored for VideoMAE pre-training.
    Each item consists of a video clip extracted from the sub-volume around a central coordinate.
    """
    def __init__(self, vol_data_list, synapse_df, subvol_size=80, num_frames=16):
        """
        Args:
            vol_data_list (List[Tuple[np.ndarray, np.ndarray]]): List containing tuples of (raw_vol, seg_vol).
            synapse_df (pd.DataFrame): DataFrame containing synapse information.
            subvol_size (int, optional): Size of the sub-volume to extract around the central coordinate. Defaults to 80.
            num_frames (int, optional): Number of frames per video clip for VideoMAE. Defaults to 16.
        """
        self.vol_data_list = vol_data_list
        self.synapse_df = synapse_df.reset_index(drop=True)
        self.subvol_size = subvol_size
        self.half_size = subvol_size // 2
        self.num_frames = num_frames

    def __len__(self):
        return len(self.synapse_df)

    def __getitem__(self, idx):
        syn_info = self.synapse_df.iloc[idx]
        bbox_index = syn_info['bbox_index']
        raw_vol, seg_vol = self.vol_data_list[bbox_index]

        # Coordinates
        central_coord = (
            int(syn_info['central_coord_1']),
            int(syn_info['central_coord_2']),
            int(syn_info['central_coord_3'])
        )
        side1_coord = (
            int(syn_info['side_1_coord_1']),
            int(syn_info['side_1_coord_2']),
            int(syn_info['side_1_coord_3'])
        )
        side2_coord = (
            int(syn_info['side_2_coord_1']),
            int(syn_info['side_2_coord_2']),
            int(syn_info['side_2_coord_3'])
        )

        # Create side1 and side2 masks
        mask_1_full, mask_2_full = create_segment_masks(seg_vol, side1_coord, side2_coord)

        # Determine sub-volume bounds
        cx, cy, cz = central_coord
        x_start = max(cx - self.half_size, 0)
        x_end   = min(cx + self.half_size, raw_vol.shape[2])
        y_start = max(cy - self.half_size, 0)
        y_end   = min(cy + self.half_size, raw_vol.shape[1])
        z_start = max(cz - self.half_size, 0)
        z_end   = min(cz + self.half_size, raw_vol.shape[0])

        sub_raw    = raw_vol[z_start:z_end, y_start:y_end, x_start:x_end]
        sub_mask_1 = mask_1_full[z_start:z_end, y_start:y_end, x_start:x_end]
        sub_mask_2 = mask_2_full[z_start:z_end, y_start:y_end, x_start:x_end]

        # Pad sub-volumes to (subvol_size, subvol_size, subvol_size) if near edges
        desired_shape = (self.subvol_size, self.subvol_size, self.subvol_size)
        dz, dy, dx = sub_raw.shape

        padded_sub_raw    = np.zeros(desired_shape, dtype=sub_raw.dtype)
        padded_sub_mask1  = np.zeros(desired_shape, dtype=np.uint8)
        padded_sub_mask2  = np.zeros(desired_shape, dtype=np.uint8)

        padded_sub_raw[:dz, :dy, :dx] = sub_raw
        padded_sub_mask1[:dz, :dy, :dx] = sub_mask_1
        padded_sub_mask2[:dz, :dy, :dx] = sub_mask_2

        # Create RGB-like frames: R = side1 mask, G = raw intensity, B = side2 mask
        frames = []
        for z in range(self.subvol_size):
            frame_raw = padded_sub_raw[z]
            frame_mask1 = padded_sub_mask1[z]
            frame_mask2 = padded_sub_mask2[z]

            # Normalize raw intensity to [0, 1]
            if frame_raw.max() > frame_raw.min():
                frame_raw_norm = (frame_raw - frame_raw.min()) / (frame_raw.max() - frame_raw.min())
            else:
                frame_raw_norm = np.zeros_like(frame_raw)

            # Stack into 3 channels
            frame_rgb = np.stack([frame_mask1, frame_raw_norm, frame_mask2], axis=-1)  # Shape: (Y, X, 3)
            frames.append(frame_rgb)

        if len(frames) < self.num_frames:
            while len(frames) < self.num_frames:
                frames.append(frames[-1])
        elif len(frames) > self.num_frames:
            indices = np.linspace(0, len(frames)-1, self.num_frames, dtype=int)
            frames = [frames[i] for i in indices]

        frames = [ (frame * 255).astype(np.uint8) for frame in frames ]

        inputs = processor_videomae(
            frames,
            return_tensors="pt"
        )
        pixel_values = inputs["pixel_values"].squeeze(0)  # Shape: (num_frames, num_channels, height, width)

        # Convert to float32 to match the model's dtype
        pixel_values = pixel_values.float()

        # For VideoMAE pre-training, the target is the same as input (autoencoding)
        return pixel_values, pixel_values

# Initialize TensorBoard writer
log_dir = os.path.join('logs', time.strftime("%Y%m%d-%H%M%S"))
writer = SummaryWriter(log_dir=log_dir)

model_name = "MCG-NJU/videomae-base"
print("Initializing VideoMAE model and processor for pre-training...")

# Load the VideoMAE model
model_videomae = VideoMAEForPreTraining.from_pretrained(
    model_name,
    attn_implementation="sdpa",
    torch_dtype=torch.float32
).to(device)

processor_videomae = VideoMAEImageProcessor.from_pretrained(model_name)
model_videomae.train()  # Set to training mode
print("VideoMAE model and processor initialized.")

# Load data
all_vol_data = []
all_syn_df = []

for bbox_index, bbox_name in enumerate(bbox_names):
    print(f"Loading data for {bbox_name}...")
    raw_vol, seg_vol = load_bbox_data(bbox_name)
    excel_file = f'/content/{bbox_name}.xlsx'
    syn_df = pd.read_excel(excel_file)

    syn_df['bbox_index'] = bbox_index
    syn_df['bbox_name']  = bbox_name

    # Append to the lists
    all_vol_data.append( (raw_vol, seg_vol) )
    all_syn_df.append(syn_df)

combined_syn_df = pd.concat(all_syn_df, ignore_index=True)
print(f"Total synapses loaded: {len(combined_syn_df)}")

subvol_size = 80
num_frames = 16   # Number of frames VideoMAE expects

dataset_videomae = VideoMAEDataset(
    vol_data_list=all_vol_data,
    synapse_df=combined_syn_df,
    subvol_size=subvol_size,
    num_frames=num_frames
)

# Determine optimal number of workers
import multiprocessing
num_workers = min(8, multiprocessing.cpu_count())  # Adjust based on your system
print(f"Using {num_workers} workers for DataLoader.")

dataloader_videomae = DataLoader(
    dataset_videomae,
    batch_size=2,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=True  # Keep workers alive between epochs
)

print("VideoMAE DataLoader created.")

optimizer = torch.optim.AdamW(
    model_videomae.parameters(),
    lr=1e-4,
    weight_decay=0.01
)

num_epochs = 80
total_steps = len(dataloader_videomae) * num_epochs

# Initialize scheduler (using Cosine Annealing with Warm Restarts)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_steps),
    num_training_steps=total_steps
)

# Optionally, implement early stopping
from collections import deque

class EarlyStopping:
    def __init__(self, patience=10, verbose=False, delta=0.0, path='checkpoint.pth'):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
            verbose (bool): If True, prints a message for each validation loss improvement.
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
            path (str): Path for the checkpoint to be saved to.
        """
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, current_loss, model, optimizer, epoch):
        if self.best_loss is None:
            self.best_loss = current_loss
            self.save_checkpoint(model, optimizer, epoch, current_loss)
        elif current_loss > self.best_loss - self.delta:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = current_loss
            self.save_checkpoint(model, optimizer, epoch, current_loss)
            self.counter = 0

    def save_checkpoint(self, model, optimizer, epoch, loss):
        '''Saves model when validation loss decrease.'''
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
        }, self.path)
        if self.verbose:
            print(f"Validation loss decreased. Saving model to {self.path}")

# Initialize early stopping (optional)
early_stopping = EarlyStopping(patience=15, verbose=True, path='checkpoints/best_model.pth')

def generate_masked_positions(batch_size, sequence_length, mask_ratio=0.9):
    """
    Generate a boolean mask indicating which positions are masked.

    Args:
        batch_size (int): Number of samples in the batch.
        sequence_length (int): Total number of patches.
        mask_ratio (float): Proportion of patches to mask.

    Returns:
        torch.BoolTensor: Mask of shape [batch_size, sequence_length].
    """
    masks = torch.zeros(batch_size, sequence_length, dtype=torch.bool, device=device)
    num_mask = int(mask_ratio * sequence_length)
    for i in range(batch_size):
        mask_indices = torch.randperm(sequence_length, device=device)[:num_mask]
        masks[i, mask_indices] = True
    return masks

print("Starting VideoMAE pre-training...")

start_time = time.time()
for epoch in range(num_epochs):
    model_videomae.train()
    epoch_loss = 0.0
    progress_bar = tqdm(dataloader_videomae, desc=f"Epoch {epoch+1}/{num_epochs}")

    for batch_idx, (pixel_values, targets) in enumerate(progress_bar):
        pixel_values = pixel_values.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()

        tubelet_size = model_videomae.config.tubelet_size
        image_size = model_videomae.config.image_size
        patch_size = model_videomae.config.patch_size
        num_patches_per_frame = (image_size // patch_size) ** 2
        sequence_length = (pixel_values.shape[1] // tubelet_size) * num_patches_per_frame

        batch_size_current = pixel_values.shape[0]
        bool_masked_pos = generate_masked_positions(batch_size_current, sequence_length, mask_ratio=0.9).to(device)

        outputs = model_videomae(pixel_values=pixel_values, bool_masked_pos=bool_masked_pos)

        loss = outputs.loss
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model_videomae.parameters(), max_norm=1.0)

        optimizer.step()
        scheduler.step()

        epoch_loss += loss.item()
        progress_bar.set_postfix({'loss': loss.item()})

    avg_epoch_loss = epoch_loss / len(dataloader_videomae)
    elapsed = time.time() - start_time
    print(f"Epoch {epoch+1} completed. Average Loss: {avg_epoch_loss:.4f}. Time Elapsed: {elapsed/60:.2f} mins")

    # Log to TensorBoard
    writer.add_scalar('Loss/Train', avg_epoch_loss, epoch+1)

    # Save checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        checkpoint_path = f'checkpoints/epoch_{epoch+1}.pth'
        torch.save({
            'epoch': epoch+1,
            'model_state_dict': model_videomae.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': avg_epoch_loss,
        }, checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path}")

    # Early stopping check
    early_stopping(avg_epoch_loss, model_videomae, optimizer, epoch+1)
    if early_stopping.early_stop:
        print("Early stopping triggered.")
        break

    # Optional: Reset start_time for next epoch
    start_time = time.time()

# Save the final model
final_model_path = 'checkpoints/final_model.pth'
torch.save(model_videomae.state_dict(), final_model_path)
print(f"Training completed. Final model saved at {final_model_path}")

# Close the TensorBoard writer
writer.close()


Using device: cuda
Initializing VideoMAE model and processor for pre-training...
VideoMAE model and processor initialized.
Loading data for bbox1...
Loading data for bbox2...
Loading data for bbox3...
Loading data for bbox4...
Loading data for bbox5...
Loading data for bbox6...
Loading data for bbox7...
Total synapses loaded: 509
Using 8 workers for DataLoader.
VideoMAE DataLoader created.
Starting VideoMAE pre-training...


Epoch 1/80: 100%|██████████| 255/255 [00:33<00:00,  7.56it/s, loss=0.316]


Epoch 1 completed. Average Loss: 0.3294. Time Elapsed: 0.56 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 2/80: 100%|██████████| 255/255 [00:32<00:00,  7.82it/s, loss=0.288]


Epoch 2 completed. Average Loss: 0.3223. Time Elapsed: 0.54 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 3/80: 100%|██████████| 255/255 [00:32<00:00,  7.78it/s, loss=0.314]


Epoch 3 completed. Average Loss: 0.3194. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 4/80: 100%|██████████| 255/255 [00:32<00:00,  7.83it/s, loss=0.347]


Epoch 4 completed. Average Loss: 0.3190. Time Elapsed: 0.54 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 5/80: 100%|██████████| 255/255 [00:32<00:00,  7.83it/s, loss=0.309]


Epoch 5 completed. Average Loss: 0.3182. Time Elapsed: 0.54 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 6/80: 100%|██████████| 255/255 [00:32<00:00,  7.82it/s, loss=0.311]


Epoch 6 completed. Average Loss: 0.3177. Time Elapsed: 0.54 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 7/80: 100%|██████████| 255/255 [00:32<00:00,  7.82it/s, loss=0.295]


Epoch 7 completed. Average Loss: 0.3160. Time Elapsed: 0.54 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 8/80: 100%|██████████| 255/255 [00:32<00:00,  7.77it/s, loss=0.299]


Epoch 8 completed. Average Loss: 0.3162. Time Elapsed: 0.55 mins
EarlyStopping counter: 1 out of 15


Epoch 9/80: 100%|██████████| 255/255 [00:32<00:00,  7.75it/s, loss=0.299]


Epoch 9 completed. Average Loss: 0.3147. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 10/80: 100%|██████████| 255/255 [00:32<00:00,  7.76it/s, loss=0.3]


Epoch 10 completed. Average Loss: 0.3134. Time Elapsed: 0.55 mins
Checkpoint saved at checkpoints/epoch_10.pth
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 11/80: 100%|██████████| 255/255 [00:32<00:00,  7.79it/s, loss=0.319]


Epoch 11 completed. Average Loss: 0.3124. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 12/80: 100%|██████████| 255/255 [00:32<00:00,  7.81it/s, loss=0.364]


Epoch 12 completed. Average Loss: 0.3116. Time Elapsed: 0.54 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 13/80: 100%|██████████| 255/255 [00:32<00:00,  7.79it/s, loss=0.296]


Epoch 13 completed. Average Loss: 0.3100. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 14/80: 100%|██████████| 255/255 [00:32<00:00,  7.79it/s, loss=0.301]


Epoch 14 completed. Average Loss: 0.3097. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 15/80: 100%|██████████| 255/255 [00:32<00:00,  7.75it/s, loss=0.309]


Epoch 15 completed. Average Loss: 0.3084. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 16/80: 100%|██████████| 255/255 [00:32<00:00,  7.76it/s, loss=0.317]


Epoch 16 completed. Average Loss: 0.3079. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 17/80: 100%|██████████| 255/255 [00:32<00:00,  7.77it/s, loss=0.338]


Epoch 17 completed. Average Loss: 0.3075. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 18/80: 100%|██████████| 255/255 [00:32<00:00,  7.80it/s, loss=0.317]


Epoch 18 completed. Average Loss: 0.3056. Time Elapsed: 0.54 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 19/80: 100%|██████████| 255/255 [00:32<00:00,  7.80it/s, loss=0.309]


Epoch 19 completed. Average Loss: 0.3061. Time Elapsed: 0.54 mins
EarlyStopping counter: 1 out of 15


Epoch 20/80: 100%|██████████| 255/255 [00:32<00:00,  7.79it/s, loss=0.296]


Epoch 20 completed. Average Loss: 0.3049. Time Elapsed: 0.55 mins
Checkpoint saved at checkpoints/epoch_20.pth
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 21/80: 100%|██████████| 255/255 [00:33<00:00,  7.72it/s, loss=0.291]


Epoch 21 completed. Average Loss: 0.3050. Time Elapsed: 0.55 mins
EarlyStopping counter: 1 out of 15


Epoch 22/80: 100%|██████████| 255/255 [00:32<00:00,  7.83it/s, loss=0.29]


Epoch 22 completed. Average Loss: 0.3033. Time Elapsed: 0.54 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 23/80: 100%|██████████| 255/255 [00:32<00:00,  7.79it/s, loss=0.345]


Epoch 23 completed. Average Loss: 0.3035. Time Elapsed: 0.55 mins
EarlyStopping counter: 1 out of 15


Epoch 24/80: 100%|██████████| 255/255 [00:32<00:00,  7.76it/s, loss=0.325]


Epoch 24 completed. Average Loss: 0.3023. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 25/80: 100%|██████████| 255/255 [00:32<00:00,  7.78it/s, loss=0.324]


Epoch 25 completed. Average Loss: 0.3020. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 26/80: 100%|██████████| 255/255 [00:33<00:00,  7.69it/s, loss=0.315]


Epoch 26 completed. Average Loss: 0.3017. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 27/80: 100%|██████████| 255/255 [00:32<00:00,  7.74it/s, loss=0.305]


Epoch 27 completed. Average Loss: 0.3002. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 28/80: 100%|██████████| 255/255 [00:32<00:00,  7.81it/s, loss=0.314]


Epoch 28 completed. Average Loss: 0.3003. Time Elapsed: 0.54 mins
EarlyStopping counter: 1 out of 15


Epoch 29/80: 100%|██████████| 255/255 [00:32<00:00,  7.82it/s, loss=0.297]


Epoch 29 completed. Average Loss: 0.2992. Time Elapsed: 0.54 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 30/80: 100%|██████████| 255/255 [00:32<00:00,  7.73it/s, loss=0.305]


Epoch 30 completed. Average Loss: 0.2990. Time Elapsed: 0.55 mins
Checkpoint saved at checkpoints/epoch_30.pth
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 31/80: 100%|██████████| 255/255 [00:32<00:00,  7.76it/s, loss=0.329]


Epoch 31 completed. Average Loss: 0.2983. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 32/80: 100%|██████████| 255/255 [00:32<00:00,  7.78it/s, loss=0.327]


Epoch 32 completed. Average Loss: 0.2974. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 33/80: 100%|██████████| 255/255 [00:33<00:00,  7.72it/s, loss=0.302]


Epoch 33 completed. Average Loss: 0.2967. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 34/80: 100%|██████████| 255/255 [00:32<00:00,  7.81it/s, loss=0.306]


Epoch 34 completed. Average Loss: 0.2965. Time Elapsed: 0.54 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 35/80: 100%|██████████| 255/255 [00:32<00:00,  7.80it/s, loss=0.282]


Epoch 35 completed. Average Loss: 0.2960. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 36/80: 100%|██████████| 255/255 [00:32<00:00,  7.76it/s, loss=0.28]


Epoch 36 completed. Average Loss: 0.2955. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 37/80: 100%|██████████| 255/255 [00:32<00:00,  7.79it/s, loss=0.285]


Epoch 37 completed. Average Loss: 0.2946. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 38/80: 100%|██████████| 255/255 [00:32<00:00,  7.86it/s, loss=0.288]


Epoch 38 completed. Average Loss: 0.2940. Time Elapsed: 0.54 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 39/80: 100%|██████████| 255/255 [00:33<00:00,  7.72it/s, loss=0.325]


Epoch 39 completed. Average Loss: 0.2932. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 40/80: 100%|██████████| 255/255 [00:32<00:00,  7.82it/s, loss=0.282]


Epoch 40 completed. Average Loss: 0.2931. Time Elapsed: 0.54 mins
Checkpoint saved at checkpoints/epoch_40.pth
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 41/80: 100%|██████████| 255/255 [00:33<00:00,  7.66it/s, loss=0.309]


Epoch 41 completed. Average Loss: 0.2928. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 42/80: 100%|██████████| 255/255 [00:32<00:00,  7.79it/s, loss=0.276]


Epoch 42 completed. Average Loss: 0.2916. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 43/80: 100%|██████████| 255/255 [00:32<00:00,  7.77it/s, loss=0.277]


Epoch 43 completed. Average Loss: 0.2912. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 44/80: 100%|██████████| 255/255 [00:32<00:00,  7.80it/s, loss=0.312]


Epoch 44 completed. Average Loss: 0.2906. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 45/80: 100%|██████████| 255/255 [00:32<00:00,  7.73it/s, loss=0.277]


Epoch 45 completed. Average Loss: 0.2893. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 46/80: 100%|██████████| 255/255 [00:32<00:00,  7.76it/s, loss=0.315]


Epoch 46 completed. Average Loss: 0.2891. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 47/80: 100%|██████████| 255/255 [00:32<00:00,  7.77it/s, loss=0.292]


Epoch 47 completed. Average Loss: 0.2890. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 48/80: 100%|██████████| 255/255 [00:33<00:00,  7.69it/s, loss=0.269]


Epoch 48 completed. Average Loss: 0.2884. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 49/80: 100%|██████████| 255/255 [00:32<00:00,  7.80it/s, loss=0.277]


Epoch 49 completed. Average Loss: 0.2874. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 50/80: 100%|██████████| 255/255 [00:32<00:00,  7.73it/s, loss=0.317]


Epoch 50 completed. Average Loss: 0.2872. Time Elapsed: 0.55 mins
Checkpoint saved at checkpoints/epoch_50.pth
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 51/80: 100%|██████████| 255/255 [00:32<00:00,  7.76it/s, loss=0.27]


Epoch 51 completed. Average Loss: 0.2864. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 52/80: 100%|██████████| 255/255 [00:32<00:00,  7.79it/s, loss=0.266]


Epoch 52 completed. Average Loss: 0.2863. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 53/80: 100%|██████████| 255/255 [00:33<00:00,  7.72it/s, loss=0.289]


Epoch 53 completed. Average Loss: 0.2851. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 54/80: 100%|██████████| 255/255 [00:32<00:00,  7.75it/s, loss=0.29]


Epoch 54 completed. Average Loss: 0.2850. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 55/80: 100%|██████████| 255/255 [00:32<00:00,  7.77it/s, loss=0.27]


Epoch 55 completed. Average Loss: 0.2845. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 56/80: 100%|██████████| 255/255 [00:33<00:00,  7.72it/s, loss=0.274]


Epoch 56 completed. Average Loss: 0.2841. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 57/80: 100%|██████████| 255/255 [00:33<00:00,  7.72it/s, loss=0.299]


Epoch 57 completed. Average Loss: 0.2834. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 58/80: 100%|██████████| 255/255 [00:32<00:00,  7.75it/s, loss=0.286]


Epoch 58 completed. Average Loss: 0.2832. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 59/80: 100%|██████████| 255/255 [00:32<00:00,  7.79it/s, loss=0.285]


Epoch 59 completed. Average Loss: 0.2828. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 60/80: 100%|██████████| 255/255 [00:32<00:00,  7.76it/s, loss=0.286]


Epoch 60 completed. Average Loss: 0.2817. Time Elapsed: 0.55 mins
Checkpoint saved at checkpoints/epoch_60.pth
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 61/80: 100%|██████████| 255/255 [00:33<00:00,  7.69it/s, loss=0.281]


Epoch 61 completed. Average Loss: 0.2817. Time Elapsed: 0.55 mins
EarlyStopping counter: 1 out of 15


Epoch 62/80: 100%|██████████| 255/255 [00:32<00:00,  7.82it/s, loss=0.284]


Epoch 62 completed. Average Loss: 0.2815. Time Elapsed: 0.54 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 63/80: 100%|██████████| 255/255 [00:33<00:00,  7.73it/s, loss=0.284]


Epoch 63 completed. Average Loss: 0.2810. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 64/80: 100%|██████████| 255/255 [00:32<00:00,  7.79it/s, loss=0.255]


Epoch 64 completed. Average Loss: 0.2806. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 65/80: 100%|██████████| 255/255 [00:32<00:00,  7.79it/s, loss=0.282]


Epoch 65 completed. Average Loss: 0.2806. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 66/80: 100%|██████████| 255/255 [00:32<00:00,  7.75it/s, loss=0.275]


Epoch 66 completed. Average Loss: 0.2794. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 67/80: 100%|██████████| 255/255 [00:32<00:00,  7.78it/s, loss=0.289]


Epoch 67 completed. Average Loss: 0.2796. Time Elapsed: 0.55 mins
EarlyStopping counter: 1 out of 15


Epoch 68/80: 100%|██████████| 255/255 [00:32<00:00,  7.75it/s, loss=0.296]


Epoch 68 completed. Average Loss: 0.2791. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 69/80: 100%|██████████| 255/255 [00:32<00:00,  7.76it/s, loss=0.265]


Epoch 69 completed. Average Loss: 0.2788. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 70/80: 100%|██████████| 255/255 [00:32<00:00,  7.81it/s, loss=0.303]


Epoch 70 completed. Average Loss: 0.2786. Time Elapsed: 0.54 mins
Checkpoint saved at checkpoints/epoch_70.pth
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 71/80: 100%|██████████| 255/255 [00:33<00:00,  7.71it/s, loss=0.265]


Epoch 71 completed. Average Loss: 0.2789. Time Elapsed: 0.55 mins
EarlyStopping counter: 1 out of 15


Epoch 72/80: 100%|██████████| 255/255 [00:33<00:00,  7.72it/s, loss=0.258]


Epoch 72 completed. Average Loss: 0.2779. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 73/80: 100%|██████████| 255/255 [00:32<00:00,  7.79it/s, loss=0.288]


Epoch 73 completed. Average Loss: 0.2783. Time Elapsed: 0.55 mins
EarlyStopping counter: 1 out of 15


Epoch 74/80: 100%|██████████| 255/255 [00:32<00:00,  7.84it/s, loss=0.267]


Epoch 74 completed. Average Loss: 0.2782. Time Elapsed: 0.54 mins
EarlyStopping counter: 2 out of 15


Epoch 75/80: 100%|██████████| 255/255 [00:32<00:00,  7.80it/s, loss=0.298]


Epoch 75 completed. Average Loss: 0.2784. Time Elapsed: 0.54 mins
EarlyStopping counter: 3 out of 15


Epoch 76/80: 100%|██████████| 255/255 [00:33<00:00,  7.71it/s, loss=0.292]


Epoch 76 completed. Average Loss: 0.2783. Time Elapsed: 0.55 mins
EarlyStopping counter: 4 out of 15


Epoch 77/80: 100%|██████████| 255/255 [00:32<00:00,  7.76it/s, loss=0.256]


Epoch 77 completed. Average Loss: 0.2775. Time Elapsed: 0.55 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 78/80: 100%|██████████| 255/255 [00:32<00:00,  7.80it/s, loss=0.277]


Epoch 78 completed. Average Loss: 0.2774. Time Elapsed: 0.54 mins
Validation loss decreased. Saving model to checkpoints/best_model.pth


Epoch 79/80: 100%|██████████| 255/255 [00:32<00:00,  7.76it/s, loss=0.294]


Epoch 79 completed. Average Loss: 0.2781. Time Elapsed: 0.55 mins
EarlyStopping counter: 1 out of 15


Epoch 80/80: 100%|██████████| 255/255 [00:32<00:00,  7.76it/s, loss=0.335]


Epoch 80 completed. Average Loss: 0.2782. Time Elapsed: 0.55 mins
Checkpoint saved at checkpoints/epoch_80.pth
EarlyStopping counter: 2 out of 15
Training completed. Final model saved at checkpoints/final_model.pth


In [7]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [8]:
!cp /content/checkpoints/final_model.pth /content/drive/MyDrive

In [9]:
import os
import glob
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import imageio.v2 as iio
from transformers import (
    VideoMAEForPreTraining,
    VideoMAEImageProcessor,
    VideoMAEModel,
)
from sklearn.decomposition import PCA
import umap.umap_ as umap
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from tqdm import tqdm
import warnings
from torch.utils.tensorboard import SummaryWriter
import time
import multiprocessing
from collections import deque

warnings.filterwarnings("ignore", category=UserWarning, module="torch.utils.data.dataloader")

# Directories and configurations
raw_base_dir = '/content/raw'
seg_base_dir = '/content/seg'
bbox_names = [f'bbox{i}' for i in range(1, 8)]

os.makedirs('csv_outputs', exist_ok=True)
os.makedirs('checkpoints', exist_ok=True)  # Directory for saving checkpoints

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

def load_bbox_data(bbox_name, max_slices=None):
    """
    Load raw and segmentation volumes for a bounding box.
    Returns (raw_vol, seg_vol) each shape (Z, Y, X).
    """
    raw_dir = os.path.join(raw_base_dir, bbox_name)
    seg_dir = os.path.join(seg_base_dir, bbox_name)

    raw_tif_files = sorted(glob.glob(os.path.join(raw_dir, 'slice_*.tif')))
    seg_tif_files = sorted(glob.glob(os.path.join(seg_dir, 'slice_*.tif')))

    if max_slices is not None:
        raw_tif_files = raw_tif_files[:max_slices]
        seg_tif_files = seg_tif_files[:max_slices]

    assert len(raw_tif_files) == len(seg_tif_files), f"Raw/Seg mismatch in {bbox_name}"

    raw_slices = [iio.imread(f) for f in raw_tif_files]
    seg_slices = [iio.imread(f).astype(np.uint32) for f in seg_tif_files]

    raw_vol = np.stack(raw_slices, axis=0)  # shape: (Z, Y, X)
    seg_vol = np.stack(seg_slices, axis=0)  # shape: (Z, Y, X)
    return raw_vol, seg_vol

def create_segment_masks(seg_vol, side1_coord, side2_coord):
    """
    Creates boolean masks for side_1 and side_2 coords in the segmentation volume.
    """
    x1, y1, z1 = [int(c) for c in side1_coord]
    x2, y2, z2 = [int(c) for c in side2_coord]

    seg_id_1 = seg_vol[z1, y1, x1]
    seg_id_2 = seg_vol[z2, y2, x2]

    mask_1 = (seg_vol == seg_id_1) if seg_id_1 != 0 else np.zeros_like(seg_vol, dtype=bool)
    mask_2 = (seg_vol == seg_id_2) if seg_id_2 != 0 else np.zeros_like(seg_vol, dtype=bool)
    return mask_1, mask_2

class FeatureExtractionDataset(Dataset):
    """
    Dataset class for feature extraction using the trained VideoMAE encoder.
    Each item consists of a video clip extracted from the sub-volume around a central coordinate.
    """
    def __init__(self, vol_data_list, synapse_df, processor, subvol_size=80, num_frames=16):
        """
        Args:
            vol_data_list (List[Tuple[np.ndarray, np.ndarray]]): List containing tuples of (raw_vol, seg_vol).
            synapse_df (pd.DataFrame): DataFrame containing synapse information.
            processor (VideoMAEImageProcessor): Processor for VideoMAE.
            subvol_size (int, optional): Size of the sub-volume to extract around the central coordinate. Defaults to 80.
            num_frames (int, optional): Number of frames per video clip for VideoMAE. Defaults to 16.
        """
        self.vol_data_list = vol_data_list
        self.synapse_df = synapse_df.reset_index(drop=True)
        self.subvol_size = subvol_size
        self.half_size = subvol_size // 2
        self.num_frames = num_frames
        self.processor = processor

    def __len__(self):
        return len(self.synapse_df)

    def __getitem__(self, idx):
        syn_info = self.synapse_df.iloc[idx]
        bbox_index = syn_info['bbox_index']
        raw_vol, seg_vol = self.vol_data_list[bbox_index]

        # Coordinates
        central_coord = (
            int(syn_info['central_coord_1']),
            int(syn_info['central_coord_2']),
            int(syn_info['central_coord_3'])
        )
        side1_coord = (
            int(syn_info['side_1_coord_1']),
            int(syn_info['side_1_coord_2']),
            int(syn_info['side_1_coord_3'])
        )
        side2_coord = (
            int(syn_info['side_2_coord_1']),
            int(syn_info['side_2_coord_2']),
            int(syn_info['side_2_coord_3'])
        )

        # Create side1 and side2 masks
        mask_1_full, mask_2_full = create_segment_masks(seg_vol, side1_coord, side2_coord)

        # Determine sub-volume bounds
        cx, cy, cz = central_coord
        x_start = max(cx - self.half_size, 0)
        x_end   = min(cx + self.half_size, raw_vol.shape[2])
        y_start = max(cy - self.half_size, 0)
        y_end   = min(cy + self.half_size, raw_vol.shape[1])
        z_start = max(cz - self.half_size, 0)
        z_end   = min(cz + self.half_size, raw_vol.shape[0])

        sub_raw    = raw_vol[z_start:z_end, y_start:y_end, x_start:x_end]
        sub_mask_1 = mask_1_full[z_start:z_end, y_start:y_end, x_start:x_end]
        sub_mask_2 = mask_2_full[z_start:z_end, y_start:y_end, x_start:x_end]

        # Pad sub-volumes to (subvol_size, subvol_size, subvol_size) if near edges
        desired_shape = (self.subvol_size, self.subvol_size, self.subvol_size)
        dz, dy, dx = sub_raw.shape

        padded_sub_raw    = np.zeros(desired_shape, dtype=sub_raw.dtype)
        padded_sub_mask1  = np.zeros(desired_shape, dtype=np.uint8)
        padded_sub_mask2  = np.zeros(desired_shape, dtype=np.uint8)

        padded_sub_raw[:dz, :dy, :dx] = sub_raw
        padded_sub_mask1[:dz, :dy, :dx] = sub_mask_1
        padded_sub_mask2[:dz, :dy, :dx] = sub_mask_2

        # Create RGB-like frames: R = side1 mask, G = raw intensity, B = side2 mask
        frames = []
        for z in range(self.subvol_size):
            frame_raw = padded_sub_raw[z]
            frame_mask1 = padded_sub_mask1[z]
            frame_mask2 = padded_sub_mask2[z]

            # Normalize raw intensity to [0, 1]
            if frame_raw.max() > frame_raw.min():
                frame_raw_norm = (frame_raw - frame_raw.min()) / (frame_raw.max() - frame_raw.min())
            else:
                frame_raw_norm = np.zeros_like(frame_raw)

            # Stack into 3 channels
            frame_rgb = np.stack([frame_mask1, frame_raw_norm, frame_mask2], axis=-1)  # Shape: (Y, X, 3)
            frames.append(frame_rgb)

        if len(frames) < self.num_frames:
            while len(frames) < self.num_frames:
                frames.append(frames[-1])
        elif len(frames) > self.num_frames:
            indices = np.linspace(0, len(frames)-1, self.num_frames, dtype=int)
            frames = [frames[i] for i in indices]

        frames = [ (frame * 255).astype(np.uint8) for frame in frames ]

        inputs = self.processor(
            frames,
            return_tensors="pt"
        )
        pixel_values = inputs["pixel_values"].squeeze(0)  # Shape: (num_frames, num_channels, height, width)

        # Convert to float32 to match the model's dtype
        pixel_values = pixel_values.float()

        syn_info_dict = syn_info.to_dict()

        return pixel_values, syn_info_dict

# Initialize TensorBoard writer for feature extraction
feature_log_dir = os.path.join('logs', 'feature_extraction', time.strftime("%Y%m%d-%H%M%S"))
feature_writer = SummaryWriter(log_dir=feature_log_dir)
print(f"TensorBoard logging initialized at {feature_log_dir}")

# Load the pre-trained VideoMAE model for feature extraction
model_name = "MCG-NJU/videomae-base"
model_save_path = 'checkpoints/best_model.pth'  # Path to your saved checkpoint

print("Initializing VideoMAEModel for feature extraction...")
model_videomae_feature = VideoMAEModel.from_pretrained(model_name).to(device)

# Load pre-trained weights
pretrained_dict = torch.load(model_save_path, map_location=device)

filtered_dict = {}
for key, value in pretrained_dict['model_state_dict'].items():
    if key.startswith('encoder.'):
        new_key = key.replace('encoder.', '')
        filtered_dict[new_key] = value

model_videomae_feature.load_state_dict(filtered_dict, strict=False)

model_videomae_feature.eval()

print("Pre-trained weights loaded into VideoMAEModel for feature extraction")

# Initialize the processor
processor_videomae = VideoMAEImageProcessor.from_pretrained(model_name)

# Load data
all_vol_data = []
all_syn_df = []

for bbox_index, bbox_name in enumerate(bbox_names):
    print(f"Loading data for {bbox_name}...")
    raw_vol, seg_vol = load_bbox_data(bbox_name)
    excel_file = f'/content/{bbox_name}.xlsx'
    syn_df = pd.read_excel(excel_file)

    syn_df['bbox_index'] = bbox_index
    syn_df['bbox_name']  = bbox_name

    # Append to the lists
    all_vol_data.append( (raw_vol, seg_vol) )
    all_syn_df.append(syn_df)

combined_syn_df = pd.concat(all_syn_df, ignore_index=True)
print(f"Total synapses loaded: {len(combined_syn_df)}")

subvol_size = 80
num_frames = 16   # Number of frames VideoMAE expects

# Initialize Feature Extraction Dataset
feature_dataset = FeatureExtractionDataset(
    vol_data_list=all_vol_data,
    synapse_df=combined_syn_df,
    processor=processor_videomae,
    subvol_size=subvol_size,
    num_frames=num_frames
)

# Determine optimal number of workers
num_workers = min(8, multiprocessing.cpu_count())  # Adjust based on your system
print(f"Using {num_workers} workers for FeatureExtraction DataLoader.")

feature_dataloader = DataLoader(
    feature_dataset,
    batch_size=2,        # Adjust based on your GPU memory
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=True  # Keep workers alive between epochs
)

print("Feature Extraction DataLoader created.")

# Function to extract features with VideoMAE
def extract_features_with_videomae(video_batch):
    """
    Extract features using the trained VideoMAE model.

    Args:
        video_batch (torch.Tensor): Tensor of shape [B, num_frames, 3, H, W].

    Returns:
        np.ndarray: Array of extracted features with shape [B, hidden_size].
    """
    with torch.no_grad():
        outputs = model_videomae_feature(
            pixel_values=video_batch.to(device),
            return_dict=True
        )
        last_hidden_states = outputs.last_hidden_state  # Shape: [B, sequence_length, hidden_size]
        # Aggregate features (e.g., mean pooling)
        pooled_features = last_hidden_states.mean(dim=1)  # Shape: [B, hidden_size]

    return pooled_features.cpu().numpy()

# Initialize variables for feature extraction
all_csv_paths = []
start_time_total = time.time()

print("Starting feature extraction with VideoMAE encoder...")
for bbox_idx, bbox_name in enumerate(bbox_names):
    print(f"Processing {bbox_name}...")
    raw_vol, seg_vol = load_bbox_data(bbox_name)
    excel_file = f'/content/{bbox_name}.xlsx'
    syn_df = pd.read_excel(excel_file)

    syn_df['bbox_index'] = bbox_idx
    syn_df['bbox_name']  = bbox_name

    dataset_bbox = FeatureExtractionDataset(
        vol_data_list=all_vol_data,
        synapse_df=syn_df,
        processor=processor_videomae,
        subvol_size=subvol_size,
        num_frames=num_frames
    )
    dataloader_bbox = DataLoader(
        dataset_bbox,
        batch_size=2,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=True
    )

    bbox_features = []
    bbox_syn_info = []

    for batch_idx, (video_batch, syn_infos) in enumerate(tqdm(dataloader_bbox, desc=f"Extracting features for {bbox_name}")):
        feats = extract_features_with_videomae(video_batch)
        bbox_features.append(feats)

        syn_infos_df = pd.DataFrame(syn_infos)
        bbox_syn_info.append(syn_infos_df)

        # Optional: Log progress to TensorBoard
        feature_writer.add_scalar('Features/Processed Batches', batch_idx+1, bbox_idx * len(dataloader_bbox) + batch_idx + 1)

    bbox_features = np.concatenate(bbox_features, axis=0)
    bbox_syn_info = pd.concat(bbox_syn_info, axis=0).reset_index(drop=True)

    feature_cols = [f'feat_{j}' for j in range(bbox_features.shape[1])]
    features_df = pd.DataFrame(bbox_features, columns=feature_cols)

    output_df = pd.concat([bbox_syn_info, features_df], axis=1)

    output_csv_name = f'csv_outputs/{bbox_name}_videomae_features.csv'
    output_df.to_csv(output_csv_name, index=False)
    all_csv_paths.append(output_csv_name)
    print(f"Saved VideoMAE features for {bbox_name} -> {output_csv_name}")

    # Checkpoint: Save after each bbox
    checkpoint_path = f'checkpoints/{bbox_name}_features.pth'
    torch.save({
        'bbox_name': bbox_name,
        'features': bbox_features,
        'syn_info': syn_infos_df
    }, checkpoint_path)
    print(f"Checkpoint saved at {checkpoint_path}")

    # Log to TensorBoard
    feature_writer.add_scalar('Features/BBoxes Processed', bbox_idx + 1, bbox_idx + 1)

print("Feature extraction completed.")

print("Merging all CSV files...")
merged_df = pd.concat([pd.read_csv(p) for p in all_csv_paths], ignore_index=True)
print(f"Merged {len(all_csv_paths)} CSVs into one DataFrame with {len(merged_df)} rows.")

merged_csv = 'csv_outputs/all_features_merged_videomae.csv'
merged_df.to_csv(merged_csv, index=False)
print(f"Final merged CSV: {merged_csv}")

# Close the TensorBoard writer for feature extraction
feature_writer.close()

print("Starting PCA and UMAP dimensionality reduction...")
df = pd.read_csv(merged_csv)

feat_cols = [c for c in df.columns if c.startswith('feat_')]
X = df[feat_cols].values
print(f"Feature matrix shape: {X.shape}")

# Apply PCA
print("Applying PCA...")
pca = PCA(n_components=50, random_state=42)
X_pca = pca.fit_transform(X)
print(f"PCA transformed shape: {X_pca.shape}")

# Apply UMAP for 3D visualization
print("Applying UMAP for 3D dimensionality reduction...")
umap_3d = umap.UMAP(
    n_components=3,
    n_neighbors=15,
    min_dist=0.1,
    random_state=42
)
X_umap3 = umap_3d.fit_transform(X_pca)
df['umap_x'] = X_umap3[:,0]
df['umap_y'] = X_umap3[:,1]
df['umap_z'] = X_umap3[:,2]

print("Creating 3D UMAP visualization...")
fig = px.scatter_3d(
    df,
    x='umap_x',
    y='umap_y',
    z='umap_z',
    color='bbox_name',
    hover_data=['central_coord_1', 'central_coord_2', 'central_coord_3']
)
fig.update_traces(marker=dict(size=3))
fig.update_layout(width=800, height=600)
fig.write_html("videomae_umap3d.html")
fig.show()

print("Creating 2D UMAP projections...")
fig_xy = px.scatter(
    df,
    x="umap_x",
    y="umap_y",
    color="bbox_name",
    title="UMAP (x vs y)",
    hover_data=["umap_x", "umap_y", "bbox_name"]
)
fig_xy.write_html("videomae_umap_x_vs_y.html")
fig_xy.show()

fig_xz = px.scatter(
    df,
    x="umap_x",
    y="umap_z",
    color="bbox_name",
    title="UMAP (x vs z)",
    hover_data=["umap_x", "umap_z", "bbox_name"]
)
fig_xz.write_html("videomae_umap_x_vs_z.html")
fig_xz.show()

fig_yz = px.scatter(
    df,
    x="umap_y",
    y="umap_z",
    color="bbox_name",
    title="UMAP (y vs z)",
    hover_data=["umap_y", "umap_z", "bbox_name"]
)
fig_yz.write_html("videomae_umap_y_vs_z.html")
fig_yz.show()

print("Creating combined 2D UMAP projections...")
fig_combined = make_subplots(
    rows=1, cols=3,
    subplot_titles=[
        "UMAP (x vs y)",
        "UMAP (x vs z)",
        "UMAP (y vs z)"
    ]
)

cat_codes = df["bbox_name"].astype("category").cat.codes

trace_xy = go.Scatter(
    x=df["umap_x"],
    y=df["umap_y"],
    mode="markers",
    name="(x vs y)",
    marker=dict(
        color=cat_codes,
        colorscale="Viridis",
        showscale=True,
        size=5
    ),
    text=df["bbox_name"],    # Hover text
    hovertemplate="bbox_name:%{text}<br>umap_x=%{x}<br>umap_y=%{y}"
)
fig_combined.add_trace(trace_xy, row=1, col=1)

# Trace for (x vs z)
trace_xz = go.Scatter(
    x=df["umap_x"],
    y=df["umap_z"],
    mode="markers",
    name="(x vs z)",
    marker=dict(
        color=cat_codes,
        colorscale="Viridis",
        showscale=False,  # Colorbar already shown in first subplot
        size=5
    ),
    text=df["bbox_name"],
    hovertemplate="bbox_name:%{text}<br>umap_x=%{x}<br>umap_z=%{y}"
)
fig_combined.add_trace(trace_xz, row=1, col=2)

trace_yz = go.Scatter(
    x=df["umap_y"],
    y=df["umap_z"],
    mode="markers",
    name="(y vs z)",
    marker=dict(
        color=cat_codes,
        colorscale="Viridis",
        showscale=False,
        size=5
    ),
    text=df["bbox_name"],
    hovertemplate="bbox_name:%{text}<br>umap_y=%{x}<br>umap_z=%{y}"
)
fig_combined.add_trace(trace_yz, row=1, col=3)

fig_combined.update_layout(
    title="2D UMAP Projections (All Pairwise Components)",
    width=1800,   # Wide figure
    height=600,
    showlegend=False
)

fig_combined.write_html("videomae_combined_umap_projections.html")
fig_combined.show()
print("Dimensionality reduction and visualization completed.")


Using device: cuda
TensorBoard logging initialized at logs/feature_extraction/20250109-103645
Initializing VideoMAEModel for feature extraction...


  pretrained_dict = torch.load(model_save_path, map_location=device)


Pre-trained weights loaded into VideoMAEModel for feature extraction
Loading data for bbox1...
Loading data for bbox2...
Loading data for bbox3...
Loading data for bbox4...
Loading data for bbox5...
Loading data for bbox6...
Loading data for bbox7...
Total synapses loaded: 509
Using 8 workers for FeatureExtraction DataLoader.
Feature Extraction DataLoader created.
Starting feature extraction with VideoMAE encoder...
Processing bbox1...


Extracting features for bbox1: 100%|██████████| 29/29 [00:05<00:00,  5.09it/s]


Saved VideoMAE features for bbox1 -> csv_outputs/bbox1_videomae_features.csv
Checkpoint saved at checkpoints/bbox1_features.pth
Processing bbox2...


Extracting features for bbox2: 100%|██████████| 50/50 [00:07<00:00,  6.63it/s]


Saved VideoMAE features for bbox2 -> csv_outputs/bbox2_videomae_features.csv
Checkpoint saved at checkpoints/bbox2_features.pth
Processing bbox3...


Extracting features for bbox3: 100%|██████████| 31/31 [00:05<00:00,  5.60it/s]


Saved VideoMAE features for bbox3 -> csv_outputs/bbox3_videomae_features.csv
Checkpoint saved at checkpoints/bbox3_features.pth
Processing bbox4...


Extracting features for bbox4: 100%|██████████| 20/20 [00:03<00:00,  5.05it/s]


Saved VideoMAE features for bbox4 -> csv_outputs/bbox4_videomae_features.csv
Checkpoint saved at checkpoints/bbox4_features.pth
Processing bbox5...


Extracting features for bbox5: 100%|██████████| 43/43 [00:06<00:00,  6.54it/s]


Saved VideoMAE features for bbox5 -> csv_outputs/bbox5_videomae_features.csv
Checkpoint saved at checkpoints/bbox5_features.pth
Processing bbox6...


Extracting features for bbox6: 100%|██████████| 49/49 [00:07<00:00,  6.68it/s]


Saved VideoMAE features for bbox6 -> csv_outputs/bbox6_videomae_features.csv
Checkpoint saved at checkpoints/bbox6_features.pth
Processing bbox7...


Extracting features for bbox7: 100%|██████████| 33/33 [00:05<00:00,  5.90it/s]


Saved VideoMAE features for bbox7 -> csv_outputs/bbox7_videomae_features.csv
Checkpoint saved at checkpoints/bbox7_features.pth
Feature extraction completed.
Merging all CSV files...
Merged 7 CSVs into one DataFrame with 509 rows.
Final merged CSV: csv_outputs/all_features_merged_videomae.csv
Starting PCA and UMAP dimensionality reduction...
Feature matrix shape: (509, 768)
Applying PCA...
PCA transformed shape: (509, 50)
Applying UMAP for 3D dimensionality reduction...


  warn(


Creating 3D UMAP visualization...


Creating 2D UMAP projections...


Creating combined 2D UMAP projections...


Dimensionality reduction and visualization completed.


# Gradcam


In [None]:
!mkdir videomae_gradcam


In [None]:
dataset_videomae.synapse_df.head()

Unnamed: 0,Var1,central_coord_1,central_coord_2,central_coord_3,side_1_coord_1,side_1_coord_2,side_1_coord_3,side_2_coord_1,side_2_coord_2,side_2_coord_3,bbox_index,bbox_name
0,non_spine_synapsed_056,171,260,350,171,268,359,171,260,340,0,bbox1
1,non_spine_synapse_057,223,113,425,223,112,438,223,114,407,0,bbox1
2,non_spine_synapse_058,280,102,377,280,94,400,280,108,364,0,bbox1
3,non_spine_synapse_063,455,131,162,455,134,181,455,127,145,0,bbox1
4,non_spine_synapse_062,138,121,302,135,113,298,140,127,312,0,bbox1


In [59]:
import torch
import numpy as np
from PIL import Image
import cv2
from typing import List, Optional

class VideoMAEGradCAM:
    """
    GradCAM implementation for VideoMAE model.
    Generates attention maps for video input.
    """
    def __init__(self, model: torch.nn.Module, layer_idx: int = 11):
        torch.backends.cudnn.enabled = False  # Temporary fix for some CUDA issues
        self.model = model
        self.device = next(model.parameters()).device
        self.gradients = None
        self.activations = None

        # Register hooks for the attention output
        target_layer = self.model.encoder.layer[layer_idx].attention.output.dense
        self.forward_hook = target_layer.register_forward_hook(self._save_activation)
        self.backward_hook = target_layer.register_backward_hook(self._save_gradient)

    def _save_activation(self, module, input, output):
        self.activations = output

    def __del__(self):
        # Remove hooks when the object is deleted
        self.forward_hook.remove()
        self.backward_hook.remove()

    def _save_gradient(self, module, grad_input, grad_output):
        self.gradients = grad_output[0]

    def generate_cam(self, video_input: torch.Tensor) -> np.ndarray:
        """
        Generate attention map for video input.
        """
        # Ensure input is on the same device as model
        video_input = video_input.to(self.device)

        self.model.zero_grad()

        # Forward pass
        outputs = self.model(pixel_values=video_input)

        # Use mean of output features as target for visualization
        target = outputs.last_hidden_state.mean()
        target.backward()

        with torch.no_grad():
            # Get gradients and activations
            gradients = self.gradients.detach()
            activations = self.activations.detach()

            # Calculate importance weights
            weights = torch.mean(gradients, dim=(0, 1))

            # Weight the activations by the gradients
            weighted_activations = torch.einsum('ntd,d->nt', activations, weights)

            # Get video dimensions
            num_frames = video_input.size(1)  # Number of input frames
            patch_size = self.model.config.patch_size
            tubelet_size = self.model.config.tubelet_size
            image_size = video_input.size(-1)  # Height/width of input frames

            # Calculate patches
            patches_per_frame = (image_size // patch_size) ** 2  # Spatial patches per frame
            num_total_patches = weighted_activations.size(1)  # Total patches in sequence
            temporal_patches = num_frames // tubelet_size  # Number of temporal patches

            # First reshape to (temporal_patches, spatial_patches)
            attention_map = weighted_activations.view(temporal_patches, patches_per_frame)

            # Then reshape spatial dimension to square grid
            spatial_size = int(np.sqrt(patches_per_frame))
            attention_map = attention_map.view(temporal_patches, spatial_size, spatial_size)

            # Upsample temporal dimension to match number of frames
            attention_map = attention_map.unsqueeze(1)
            attention_map = attention_map.repeat(1, tubelet_size, 1, 1)
            attention_map = attention_map.reshape(num_frames, spatial_size, spatial_size)

            # Apply ReLU and normalize
            attention_map = torch.relu(attention_map)
            if attention_map.max() > attention_map.min():
                attention_map = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min())

            return attention_map.cpu().numpy()

    def apply_attention_map(self,
                          video_frames: List[np.ndarray],
                          attention_map: np.ndarray,
                          alpha: float = 0.6) -> List[np.ndarray]:
        """
        Apply attention map to original video frames.
        """
        result_frames = []

        for frame_idx, frame in enumerate(video_frames):
            # Resize attention map to match frame size
            attention = cv2.resize(attention_map[frame_idx],
                                 (frame.shape[1], frame.shape[0]),
                                 interpolation=cv2.INTER_LINEAR)

            # Create heatmap
            heatmap = cv2.applyColorMap(
                (attention * 255).astype(np.uint8),
                cv2.COLORMAP_JET
            )
            heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

            # Overlay heatmap on frame
            superimposed = cv2.addWeighted(frame, 1-alpha, heatmap, alpha, 0)
            result_frames.append(superimposed)

        return result_frames

    def visualize(self,
                 video_input: torch.Tensor,
                 original_frames: List[np.ndarray],
                 output_path: Optional[str] = None,
                 alpha: float = 0.6,
                 fps: int = 4) -> List[np.ndarray]:
        """
        Generate and save complete visualization.
        """
        # Generate attention map
        attention_map = self.generate_cam(video_input)

        # Apply attention map to frames
        visualization_frames = self.apply_attention_map(
            original_frames,
            attention_map,
            alpha
        )

        # Save if output path provided
        if output_path is not None:
            if output_path.endswith('.gif'):
                # Save as GIF
                pil_frames = [Image.fromarray(frame) for frame in visualization_frames]
                pil_frames[0].save(
                    output_path,
                    save_all=True,
                    append_images=pil_frames[1:],
                    duration=int(1000/fps),
                    loop=0
                )
            elif output_path.endswith('.mp4'):
                # Save as MP4
                height, width = visualization_frames[0].shape[:2]
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
                for frame in visualization_frames:
                    out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                out.release()

        return visualization_frames

def visualize_synapse_attention(model, dataset, idx: int,
                              layer_idx: int = 11,
                              output_path: Optional[str] = None):
    """
    Generate attention visualization for a specific synapse.
    """
    # Get video data and move to model device
    pixel_values, synapse_info = dataset[idx]
    pixel_values = pixel_values.unsqueeze(0)  # Add batch dimension

    # Get original frames (on CPU)
    original_frames = []
    for frame_idx in range(pixel_values.size(1)):
        frame = pixel_values[0, frame_idx].cpu().permute(1, 2, 0).numpy()
        frame = (frame * 255).astype(np.uint8)
        original_frames.append(frame)

    # Initialize GradCAM
    gradcam = VideoMAEGradCAM(model, layer_idx=layer_idx)

    # Generate visualization
    vis_frames = gradcam.visualize(
        pixel_values,
        original_frames,
        output_path=output_path
    )

    return vis_frames, synapse_info


import random
import os

# Create output directory
output_dir = "gradcam_outpu2ts2"
os.makedirs(output_dir, exist_ok=True)

# Get 20 random indices
n_samples = 2
total_samples = len(dataset_videomae)
random_indices = random.sample(range(total_samples), min(n_samples, total_samples))

# Get the synapse DataFrame
synapse_df = dataset_videomae.synapse_df

# Generate visualizations
for i, idx in enumerate(random_indices):
    # Get Var1 value for this synapse
    var1_value = synapse_df.iloc[idx]['Var1']

    # Create output filename using Var1
    output_path = os.path.join(output_dir, f"{var1_value}_attention.gif")

    print(f"Generating visualization {i+1}/20 for synapse {var1_value}")

    vis_frames = visualize_synapse_attention(
        model_videomae_feature,
        dataset_videomae,
        idx,
        layer_idx=11,
        output_path=output_path
    )

    print(f"Saved visualization to {output_path}")

Generating visualization 1/20 for synapse explorative_2024-08-03_Ali_Karimi_025
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.
Forward hook activated: Activations captured.


In [None]:
import shutil

shutil.make_archive("/content/gradcam_outputs_archive", 'zip', "/content/gradcam_outputs")

print("Folder has been zipped to: /content/gradcam_outputs_archive.zip")

Folder has been zipped to: /content/gradcam_outputs_archive.zip
