<a href="https://colab.research.google.com/github/alim98/MPI/blob/main/Vit_MAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Essential downloads

In [None]:
!wget -O downloaded_file.zip "https://drive.usercontent.google.com/download?id=1iHPBdBOPEagvPTHZmrN__LD49emXwReY&export=download&authuser=0&confirm=t&uuid=631d60dd-569c-4bb1-a9e8-d681f0ed3d43&at=APvzH3r4me8x_LwP3n8O7lgPo8oK%3A1733988188000"

!unzip downloaded_file.zip

Streaming output truncated to the last 5000 lines.
  inflating: seg/bbox3/slice_301.tif  
  inflating: seg/bbox3/slice_329.tif  
  inflating: seg/bbox3/slice_498.tif  
  inflating: seg/bbox3/slice_117.tif  
  inflating: seg/bbox3/slice_103.tif  
  inflating: seg/bbox3/slice_063.tif  
  inflating: seg/bbox3/slice_077.tif  
  inflating: seg/bbox3/slice_088.tif  
  inflating: seg/bbox3/slice_261.tif  
  inflating: seg/bbox3/slice_507.tif  
  inflating: seg/bbox3/slice_513.tif  
  inflating: seg/bbox3/slice_275.tif  
  inflating: seg/bbox3/slice_249.tif  
  inflating: seg/bbox3/slice_248.tif  
  inflating: seg/bbox3/slice_512.tif  
  inflating: seg/bbox3/slice_274.tif  
  inflating: seg/bbox3/slice_260.tif  
  inflating: seg/bbox3/slice_506.tif  
  inflating: seg/bbox3/slice_089.tif  
  inflating: seg/bbox3/slice_076.tif  
  inflating: seg/bbox3/slice_062.tif  
  inflating: seg/bbox3/slice_102.tif  
  inflating: seg/bbox3/slice_116.tif  
  inflating: seg/bbox3/slice_499.tif  

In [None]:

!pip install transformers scikit-learn matplotlib seaborn torch torchvision umap-learn git+https://github.com/funkelab/funlib.learn.torch.git
!pip install openpyxl


Collecting git+https://github.com/funkelab/funlib.learn.torch.git
  Cloning https://github.com/funkelab/funlib.learn.torch.git to /tmp/pip-req-build-nh9iwyur
  Running command git clone --filter=blob:none --quiet https://github.com/funkelab/funlib.learn.torch.git /tmp/pip-req-build-nh9iwyur
  Resolved https://github.com/funkelab/funlib.learn.torch.git to commit 049729151c7a2c0320a446dc9d3244ac830f7ea8
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting umap-learn
  Downloading umap_learn-0.5.7-py3-none-any.whl.metadata (21 kB)
Collecting pynndescent>=0.5 (from umap-learn)
  Downloading pynndescent-0.5.13-py3-none-any.whl.metadata (6.8 kB)
Downloading umap_learn-0.5.7-py3-none-any.whl (88 kB)
Downloading pynndescent-0.5.13-py3-none-any.w

In [None]:

import os
import glob
import imageio.v2 as iio
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Patch, Rectangle
from torch.utils.data import Dataset, DataLoader
from transformers import ViTImageProcessor, ViTModel
import torch
import torch.nn as nn
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import seaborn as sns
from umap import UMAP
import torch.nn.functional as F
from funlib.learn.torch.models import Vgg3D

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

# ViT-MAE (working)

In [None]:
import os
import glob
import numpy as np
import pandas as pd
import imageio.v3 as iio
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import ViTMAEForPreTraining, ViTImageProcessor
from torchvision import transforms
from sklearn.preprocessing import StandardScaler
import umap
import plotly.express as px
from tqdm import tqdm

# ------------------------------
# 1. Device Configuration
# ------------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# ------------------------------
# 2. Data Loading and Preparation
# ------------------------------

# Define base directories
raw_base_dir = '/content/raw'  # Replace with your actual path
seg_base_dir = '/content/seg'  # Replace with your actual path
bbox_names = [f'bbox{i}' for i in range(1, 4)]  # ['bbox1', 'bbox2', 'bbox3']

def load_bbox_data(bbox_name, max_slices=None):
    """
    Load raw and segmentation volumes for a given bounding box.
    """
    raw_dir = os.path.join(raw_base_dir, bbox_name)
    seg_dir = os.path.join(seg_base_dir, bbox_name)

    raw_tif_files = sorted(glob.glob(os.path.join(raw_dir, 'slice_*.tif')))
    seg_tif_files = sorted(glob.glob(os.path.join(seg_dir, 'slice_*.tif')))

    if max_slices is not None:
        raw_tif_files = raw_tif_files[:max_slices]
        seg_tif_files = seg_tif_files[:max_slices]

    assert len(raw_tif_files) == len(seg_tif_files), f"Raw/Seg mismatch in {bbox_name}"

    raw_slices = [iio.imread(f) for f in raw_tif_files]
    seg_slices = [iio.imread(f).astype(np.uint32) for f in seg_tif_files]

    raw_vol = np.stack(raw_slices, axis=0)  # Shape: (Z, Y, X)
    seg_vol = np.stack(seg_slices, axis=0)  # Shape: (Z, Y, X)
    return raw_vol, seg_vol

def create_segment_masks(seg_vol, side1_coord, side2_coord):
    """
    Create binary masks for two segments based on provided coordinates.
    """
    x1, y1, z1 = [int(c) for c in side1_coord]
    x2, y2, z2 = [int(c) for c in side2_coord]

    seg_id_1 = seg_vol[z1, y1, x1]
    seg_id_2 = seg_vol[z2, y2, x2]

    mask_1 = (seg_vol == seg_id_1) if seg_id_1 != 0 else np.zeros_like(seg_vol, dtype=bool)
    mask_2 = (seg_vol == seg_id_2) if seg_id_2 != 0 else np.zeros_like(seg_vol, dtype=bool)
    return mask_1, mask_2

class SynapseDataset(Dataset):
    def __init__(self, vol_data_list, synapse_df_list, subvol_size=80, transform=None):
        """
        Initialize the dataset with volumes and corresponding synapse annotations.
        """
        self.vol_data_list = vol_data_list  # List of (raw_vol, seg_vol) tuples, one per bounding box
        self.synapse_entries = []  # List to hold all synapse annotations across bboxes

        # Iterate through each bounding box's synapse dataframe and append to the list
        for bbox_idx, synapse_df in enumerate(synapse_df_list):
            for _, row in synapse_df.iterrows():
                entry = {
                    'Var1': row['Var1'],  # Include Var1 from Excel
                    'bbox_name': bbox_names[bbox_idx],  # Store bounding box name
                    'bbox_index': bbox_idx,
                    'central_coord': (
                        int(row['central_coord_1']),
                        int(row['central_coord_2']),
                        int(row['central_coord_3'])
                    ),
                    'side1_coord': (
                        int(row['side_1_coord_1']),
                        int(row['side_1_coord_2']),
                        int(row['side_1_coord_3'])
                    ),
                    'side2_coord': (
                        int(row['side_2_coord_1']),
                        int(row['side_2_coord_2']),
                        int(row['side_2_coord_3'])
                    )
                }
                self.synapse_entries.append(entry)

        self.subvol_size = subvol_size
        self.half_size = subvol_size // 2
        self.transform = transform

    def __len__(self):
        return len(self.synapse_entries)

    def __getitem__(self, idx):
        """
        Retrieve the subvolume around the synapse and extract axial, coronal, sagittal slices.
        Returns:
            - Stacked slices (3-channel image)
            - Var1 value
            - Bounding box name
        """
        syn_info = self.synapse_entries[idx]
        bbox_index = syn_info['bbox_index']
        raw_vol, seg_vol = self.vol_data_list[bbox_index]

        central_coord = syn_info['central_coord']
        side1_coord = syn_info['side1_coord']
        side2_coord = syn_info['side2_coord']

        # The segment masks are not used for the 2D slice input, so the costly
        # full-volume create_segment_masks(seg_vol, side1_coord, side2_coord) call is skipped.

        cx, cy, cz = central_coord
        x_start = max(cx - self.half_size, 0)
        x_end = min(cx + self.half_size, raw_vol.shape[2])
        y_start = max(cy - self.half_size, 0)
        y_end = min(cy + self.half_size, raw_vol.shape[1])
        z_start = max(cz - self.half_size, 0)
        z_end = min(cz + self.half_size, raw_vol.shape[0])

        sub_raw = raw_vol[z_start:z_end, y_start:y_end, x_start:x_end]

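        # Pad if necessary to ensure the subvolume has the desired shape.
        # The crop is placed at an offset so the synapse stays at the cube center
        # even when the window was clipped at the low end of an axis.
        desired_shape = (self.subvol_size, self.subvol_size, self.subvol_size)
        padded_sub_raw = np.zeros(desired_shape, dtype=sub_raw.dtype)
        dz, dy, dx = sub_raw.shape
        oz = z_start - (cz - self.half_size)
        oy = y_start - (cy - self.half_size)
        ox = x_start - (cx - self.half_size)
        padded_sub_raw[oz:oz + dz, oy:oy + dy, ox:ox + dx] = sub_raw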

        # Extract slices
        axial_slice = padded_sub_raw[self.half_size, :, :]      # XY plane
        coronal_slice = padded_sub_raw[:, self.half_size, :]    # XZ plane
        sagittal_slice = padded_sub_raw[:, :, self.half_size]   # YZ plane

        # Stack slices to form a 3-channel image with shape (H, W, C)
        stacked_slices = np.stack([axial_slice, coronal_slice, sagittal_slice], axis=-1)  # Shape: (80,80,3)

        # Add assertion to ensure correct shape
        assert stacked_slices.shape == (self.subvol_size, self.subvol_size, 3), f"Incorrect stacked_slices shape: {stacked_slices.shape}"

        # Keep uint8 and float32 as-is; convert any other dtype (e.g. uint16) to float32
        if stacked_slices.dtype != np.float32 and stacked_slices.dtype != np.uint8:
            stacked_slices = stacked_slices.astype(np.float32)

        # Apply transformations if any
        if self.transform:
            try:
                stacked_slices = self.transform(stacked_slices)
            except Exception as e:
                print(f"Error applying transform on idx {idx}: {e}")
                raise e

        return stacked_slices, syn_info['Var1'], syn_info['bbox_name']

# ------------------------------
# 3. Masked Autoencoder (MAE) Fine-tuning
# ------------------------------

# Initialize the Image Processor
image_processor = ViTImageProcessor.from_pretrained("facebook/vit-mae-base")

# Initialize the MAE model
try:
    mae_model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
except Exception as e:
    print(f"Error loading ViTMAE model: {e}")
    raise e

mae_model.to(device)
mae_model.train()

# Define a custom transform to include the 'size' parameter in image_processor.resize
class ResizeWithProcessor:
    def __init__(self, image_processor, size):
        self.image_processor = image_processor
        self.size = size

    def __call__(self, image):
        return self.image_processor.resize(image, size=self.size)

# # Define transformations including masking (handled by the MAE model)
# mae_transforms = transforms.Compose([
#     transforms.ToPILImage(),  # Converts numpy array (H, W, C) to PIL Image
#     ResizeWithProcessor(image_processor, size=(224, 224)),
#     transforms.ToTensor(),    # Converts PIL Image to Tensor (C, H, W) and scales to [0,1]
#     transforms.Normalize(mean=image_processor.image_mean, std=image_processor.image_std),  # Normalization
# ])


Using device: cuda
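
Border handling in the subvolume crop is easy to get wrong, so it can help to check it on a toy volume. The sketch below is a standalone helper (the name `crop_centered` is hypothetical and not part of the notebook) that zero-pads so the requested center always lands at index `size // 2`. It assumes a `(Z, Y, X)` volume, an `(x, y, z)` center as in the annotation columns, and a center that lies inside the volume.

```python
import numpy as np

def crop_centered(volume, center_xyz, size):
    """Crop a size^3 cube from a (Z, Y, X) volume, zero-padded at the borders.

    The voxel at center_xyz ends up at index size // 2 on every axis.
    Assumes the center lies inside the volume.
    """
    half = size // 2
    cx, cy, cz = center_xyz
    out = np.zeros((size, size, size), dtype=volume.dtype)
    src, dst = [], []
    for c, dim in zip((cz, cy, cx), volume.shape):
        lo, hi = c - half, c - half + size   # requested window on this axis
        s0, s1 = max(lo, 0), min(hi, dim)    # part of the window inside the volume
        src.append(slice(s0, s1))
        dst.append(slice(s0 - lo, s1 - lo))  # where that part goes in the output
    out[tuple(dst)] = volume[tuple(src)]
    return out

# Toy check: a synapse at the volume corner still sits at the cube center.
vol = np.arange(1, 217).reshape(6, 6, 6)
cube = crop_centered(vol, center_xyz=(0, 0, 0), size=4)
assert cube[2, 2, 2] == vol[0, 0, 0]
```

Keeping the center fixed matters here because the axial, coronal, and sagittal slices are all taken at index `half_size` of the padded cube.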


In [None]:
# Modified transform pipeline with proper PIL Image handling
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
from transformers import ViTImageProcessor

def normalize_array(array):
    """Normalize array to the 0-255 range (uint8) for PIL Image conversion."""
    min_val = array.min()
    max_val = array.max()
    if max_val == min_val:
        # Constant input: return uint8 zeros so all channels share one dtype
        return np.zeros_like(array, dtype=np.uint8)
    normalized = ((array - min_val) / (max_val - min_val) * 255).astype(np.uint8)
    return normalized

class CustomImageTransform:
    def __init__(self, image_processor, size=(224, 224)):
        self.image_processor = image_processor
        self.size = size
        self.to_tensor = transforms.ToTensor()

    def __call__(self, image):
        # Ensure image is numpy array
        if not isinstance(image, np.ndarray):
            raise ValueError("Input must be a numpy array")

        # Normalize each channel separately
        normalized_channels = []
        for i in range(image.shape[-1]):  # Iterate over the channels
            channel = image[..., i]
            normalized_channel = normalize_array(channel)
            normalized_channels.append(normalized_channel)

        # Stack normalized channels
        normalized_image = np.stack(normalized_channels, axis=-1)

        # Convert to PIL Image
        pil_image = Image.fromarray(normalized_image)

        # Resize using standard PIL resize
        pil_image = pil_image.resize(self.size, Image.Resampling.BILINEAR)

        # Convert to tensor
        tensor_image = self.to_tensor(pil_image)

        # Normalize using the image processor's mean and std
        normalized_tensor = transforms.Normalize(
            mean=self.image_processor.image_mean,
            std=self.image_processor.image_std
        )(tensor_image)

        return normalized_tensor

# Reload the image processor; its image_mean and image_std are used for normalization
image_processor = ViTImageProcessor.from_pretrained("facebook/vit-mae-base")

# Transform applied to every sample before it reaches the MAE
mae_transforms = CustomImageTransform(image_processor)

# Load synapse annotations
def load_synapse_annotations(bbox_name):
    """
    Load synapse annotations from an Excel file for a given bounding box.
    """
    excel_path = os.path.join('/content', f'{bbox_name}.xlsx')  # Replace with your actual path
    if not os.path.exists(excel_path):
        raise FileNotFoundError(f"Excel file for {bbox_name} not found at {excel_path}")
    synapse_df = pd.read_excel(excel_path)
    # Validate necessary columns
    required_columns = [
        'Var1', 'central_coord_1', 'central_coord_2', 'central_coord_3',
        'side_1_coord_1', 'side_1_coord_2', 'side_1_coord_3',
        'side_2_coord_1', 'side_2_coord_2', 'side_2_coord_3'
    ]
    for col in required_columns:
        if col not in synapse_df.columns:
            raise ValueError(f"Column '{col}' not found in {excel_path}")
    return synapse_df

# Load all synapse annotations
try:
    synapse_df_list = [load_synapse_annotations(bbox) for bbox in bbox_names]
except Exception as e:
    print(f"Error loading synapse annotations: {e}")
    raise e

# Load all volumes
try:
    vol_data_list = [load_bbox_data(bbox) for bbox in bbox_names]
except Exception as e:
    print(f"Error loading bounding box data: {e}")
    raise e

# Initialize the dataset for MAE
mae_dataset = SynapseDataset(vol_data_list, synapse_df_list, subvol_size=80, transform=mae_transforms)

# Create DataLoader
batch_size = 32
num_epochs = 10  # Adjust based on your computational resources
mae_loader = DataLoader(
    mae_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,  # Adjusted to match system recommendations
    pin_memory=True if device.type == 'cuda' else False
)

# Define optimizer. The MAE model computes its own reconstruction loss
# (returned as outputs.loss), so no separate criterion is needed.
optimizer = torch.optim.AdamW(mae_model.parameters(), lr=1e-4)

# Training Loop for MAE
print("Starting MAE Fine-tuning...")
for epoch in range(num_epochs):
    mae_model.train()
    total_loss = 0
    for batch_idx, batch in enumerate(tqdm(mae_loader, desc=f"Epoch {epoch+1}/{num_epochs}")):
        images, _, _ = batch  # We only need images for MAE
        pixel_values = images.to(device)  # Shape: (B, 3, 224, 224)

        optimizer.zero_grad()
        try:
            outputs = mae_model(pixel_values=pixel_values)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        except Exception as e:
            print(f"Error during MAE training at epoch {epoch+1}, batch {batch_idx+1}: {e}")
            continue  # Skip to the next batch

    avg_loss = total_loss / len(mae_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

# Save the fine-tuned MAE model
mae_model_save_path = 'fine_tuned_mae_vit_base_patch16.pth'
torch.save(mae_model.state_dict(), mae_model_save_path)
print(f"Fine-tuned MAE model saved to '{mae_model_save_path}'.")

# ------------------------------
# 4. Feature Extraction
# ------------------------------

# Switch MAE model to evaluation mode
mae_model.eval()

# # Initialize the Feature Extractor (Encoder Only)
# class MAEFeatureExtractor(nn.Module):
#     def __init__(self, mae_model):
#         super(MAEFeatureExtractor, self).__init__()
#         # Extract the encoder from MAE
#         self.encoder = mae_model.encoder  # ViT encoder

#     def forward(self, x):
#         with torch.no_grad():
#             # Pass through encoder
#             outputs = self.encoder(x)
#             # Extract the [CLS] token
#             cls_token = outputs.last_hidden_state[:, 0, :]  # Shape: (B, hidden_size)
#         return cls_token  # Shape: (B, hidden_size)

# # Initialize the feature extractor
# feature_extractor = MAEFeatureExtractor(mae_model)
# feature_extractor.to(device)
# feature_extractor.eval()

# # Initialize the dataset for feature extraction (using the same dataset)
# feature_dataset = SynapseDataset(vol_data_list, synapse_df_list, subvol_size=80, transform=mae_transforms)

# # Create DataLoader for feature extraction
# feature_loader = DataLoader(
#     feature_dataset,
#     batch_size=batch_size,
#     shuffle=False,
#     num_workers=2,  # Adjusted to match system recommendations
#     pin_memory=True if device.type == 'cuda' else False
# )

# # Extract features
# features = []
# var1_list = []
# bbox_name_list = []
# print("Extracting features using MAE encoder...")
# with torch.no_grad():
#     for batch_idx, batch in enumerate(tqdm(feature_loader, desc="Feature Extraction")):
#         images, var1, bbox_name = batch
#         pixel_values = images.to(device)  # Shape: (B, 3, 224, 224)
#         try:
#             cls_tokens = feature_extractor(pixel_values)  # Shape: (B, hidden_size)
#             features.append(cls_tokens.cpu().numpy())
#             var1_list.extend(var1)
#             bbox_name_list.extend(bbox_name)
#         except Exception as e:
#             print(f"Error during feature extraction at batch {batch_idx+1}: {e}")
#             continue  # Skip to the next batch

# # Concatenate all features
# features = np.concatenate(features, axis=0)  # Shape: (Total_Samples, hidden_size)
# print(f"Extracted features shape: {features.shape}")

# ------------------------------
# 5. Dimensionality Reduction and Visualization
# ------------------------------



Starting MAE Fine-tuning...


Epoch 1/10: 100%|██████████| 7/7 [00:46<00:00,  6.70s/it]


Epoch [1/10], Loss: 0.2599


Epoch 2/10: 100%|██████████| 7/7 [00:45<00:00,  6.56s/it]


Epoch [2/10], Loss: 0.2430


Epoch 3/10: 100%|██████████| 7/7 [00:45<00:00,  6.48s/it]


Epoch [3/10], Loss: 0.2354


Epoch 4/10: 100%|██████████| 7/7 [00:45<00:00,  6.44s/it]


Epoch [4/10], Loss: 0.2298


Epoch 5/10: 100%|██████████| 7/7 [00:44<00:00,  6.40s/it]


Epoch [5/10], Loss: 0.2300


Epoch 6/10: 100%|██████████| 7/7 [00:44<00:00,  6.37s/it]


Epoch [6/10], Loss: 0.2272


Epoch 7/10: 100%|██████████| 7/7 [00:44<00:00,  6.40s/it]


Epoch [7/10], Loss: 0.2268


Epoch 8/10: 100%|██████████| 7/7 [00:45<00:00,  6.49s/it]


Epoch [8/10], Loss: 0.2249


Epoch 9/10: 100%|██████████| 7/7 [00:45<00:00,  6.56s/it]


Epoch [9/10], Loss: 0.2248


Epoch 10/10: 100%|██████████| 7/7 [00:45<00:00,  6.48s/it]


Epoch [10/10], Loss: 0.2251
Fine-tuned MAE model saved to 'fine_tuned_mae_vit_base_patch16.pth'.
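
The training loop takes its loss from `outputs.loss` rather than from a separate criterion. In the MAE formulation that loss is a mean squared error between reconstructed and original patch pixels, averaged over the masked patches only. The NumPy sketch below illustrates that reduction; it is not the Hugging Face implementation, the function name and array shapes are assumptions, and the optional per-patch target normalization is left out.

```python
import numpy as np

def masked_patch_mse(pred, target, mask):
    """Mean squared error averaged over masked patches only.

    pred, target: (batch, num_patches, patch_dim) arrays of patch pixels.
    mask: (batch, num_patches) array, 1 for masked (hidden) patches, 0 for visible.
    """
    per_patch = ((pred - target) ** 2).mean(axis=-1)  # (batch, num_patches)
    return (per_patch * mask).sum() / mask.sum()
```

Because visible patches do not contribute, the reported loss reflects only how well the hidden regions are reconstructed.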


AttributeError: 'ViTMAEForPreTraining' object has no attribute 'encoder'
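
The error above comes from the earlier attempt that accessed `mae_model.encoder`; the next cell uses `mae_model.vit` instead. One related detail is worth keeping in mind when extracting features: an MAE encoder only processes the patches that survive random masking. The sketch below works through that arithmetic, assuming 224x224 inputs, 16x16 patches, and a mask ratio of 0.75 (the default in the MAE paper; the value actually in use can be read from `mae_model.config.mask_ratio`). The Hugging Face encoder appears to apply this masking in eval mode too, so it is worth checking before relying on the features.

```python
def visible_patches(image_size=224, patch_size=16, mask_ratio=0.75):
    """Number of patch tokens the encoder keeps after random masking."""
    num_patches = (image_size // patch_size) ** 2
    return int(num_patches * (1 - mask_ratio))

# With the assumed defaults the encoder sees a quarter of the 196 patches;
# a mask ratio of 0 keeps all of them.
assert visible_patches() == 49
assert visible_patches(mask_ratio=0.0) == 196
```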

In [None]:
# Modified Feature Extractor class
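class MAEFeatureExtractor(nn.Module):
    def __init__(self, mae_model):
        super().__init__()
        # ViTMAEForPreTraining exposes its encoder as `.vit` (there is no `.encoder` attribute)
        self.encoder = mae_model.vit
        # The encoder applies random patch masking on every forward pass, even in eval mode.
        # A ratio of 0 lets it see all patches, so features no longer depend on a random mask.
        # Note: the config object is shared with mae_model, so this also affects mae_model.
        self.encoder.config.mask_ratio = 0.0

    def forward(self, x):
        with torch.no_grad():
            # Get the outputs from the encoder
            outputs = self.encoder(x, interpolate_pos_encoding=True)
            # Extract the [CLS] token
            cls_token = outputs.last_hidden_state[:, 0, :]
            return cls_token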

# Initialize the feature extractor with the fine-tuned model
feature_extractor = MAEFeatureExtractor(mae_model)
feature_extractor.to(device)
feature_extractor.eval()

# Initialize the dataset for feature extraction
feature_dataset = SynapseDataset(vol_data_list, synapse_df_list, subvol_size=80, transform=mae_transforms)

# Create DataLoader for feature extraction
feature_loader = DataLoader(
    feature_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=True if device.type == 'cuda' else False
)

# Extract features
print("Extracting features using MAE encoder...")
features = []
var1_list = []
bbox_name_list = []

with torch.no_grad():
    for batch_idx, batch in enumerate(tqdm(feature_loader, desc="Feature Extraction")):
        try:
            images, var1, bbox_name = batch
            pixel_values = images.to(device)

            # Extract features using the feature extractor
            cls_tokens = feature_extractor(pixel_values)

            # Store results
            features.append(cls_tokens.cpu().numpy())
            var1_list.extend(var1)
            bbox_name_list.extend(bbox_name)

        except Exception as e:
            print(f"Error during feature extraction at batch {batch_idx+1}: {e}")
            continue

# Concatenate all features
features = np.concatenate(features, axis=0)
print(f"Extracted features shape: {features.shape}")

# Standardize the features and project them to 2D with UMAP

# Standardize features (zero mean, unit variance per dimension)
scaler = StandardScaler()
features_scaled = scaler.fit_transform(features)

# Perform UMAP dimensionality reduction
print("Performing UMAP dimensionality reduction...")
umap_reducer = umap.UMAP(n_components=2, random_state=42)
features_umap = umap_reducer.fit_transform(features_scaled)
print("UMAP completed.")

# Create a DataFrame for visualization
visualization_df = pd.DataFrame({
    'UMAP_1': features_umap[:, 0],
    'UMAP_2': features_umap[:, 1],
    'Var1': var1_list,
    'BoundingBox': bbox_name_list
})

# Plot with Plotly Express, coloring by BoundingBox and showing Var1 on hover
print("Generating interactive Plotly scatter plot...")
fig = px.scatter(
    visualization_df,
    x='UMAP_1',
    y='UMAP_2',
    color='BoundingBox',
    title='UMAP Visualization of Synapse Features',
    labels={
        'UMAP_1': 'UMAP Component 1',
        'UMAP_2': 'UMAP Component 2',
        'BoundingBox': 'Bounding Box'
    },
    hover_data=['Var1'],  # Include Var1 in hover
    color_discrete_sequence=px.colors.qualitative.Dark24  # A palette with many distinct colors
)

# Enhance the layout
fig.update_layout(
    legend_title_text='Bounding Box',
    template='plotly_white',  # Clean white background
    width=900,
    height=700
)

# Display the plot
fig.show()

# ------------------------------
# 6. Saving Results (Optional)
# ------------------------------

# Save the visualization as an HTML file
# Uncomment the following lines if you wish to save the plot
# output_html_path = 'synapse_umap_visualization.html'
# fig.write_html(output_html_path)
# print(f"Plotly figure saved to '{output_html_path}'.")

# Save the features and metadata to a CSV file
output_csv_path = 'synapse_features_umap.csv'
visualization_df.to_csv(output_csv_path, index=False)
print(f"Feature extraction results saved to '{output_csv_path}'.")

# ------------------------------
# 7. Summary and Next Steps
# ------------------------------

"""
Summary:
1. **Data Loading:** Loaded raw and segmentation volumes for 3 bounding boxes (bbox1 to bbox3), along with their synapse annotations from Excel files, including the `Var1` column.
2. **Dataset Preparation:** Created a custom PyTorch Dataset that extracts an 80-voxel subvolume around each synapse and stacks its central axial, coronal, and sagittal slices into a 3-channel image suitable for the Masked Autoencoder (MAE).
3. **MAE Fine-tuning:** Fine-tuned the pretrained `facebook/vit-mae-base` Masked Autoencoder (Vision Transformer backbone) on the synapse images for 10 epochs.
4. **Feature Extraction:** Used the fine-tuned MAE encoder to extract a 768-dimensional [CLS] feature vector for each synapse image.
5. **Dimensionality Reduction:** Standardized the feature vectors and applied UMAP to reduce them to 2D for visualization.
6. **Visualization:** Created an interactive Plotly scatter plot of the UMAP components, colored by bounding box name, with `Var1` shown on hover.
7. **Results Saving:** Saved the UMAP coordinates and metadata to a CSV file for further analysis.

Next Steps:
- **Fine-tuning MAE Further:** Depending on the initial results, consider fine-tuning the MAE for more epochs or adjusting hyperparameters to enhance feature quality.
- **Advanced Visualization:** Incorporate additional metadata or explore different dimensionality reduction techniques like t-SNE for comparative analysis.
- **Further Analysis:** Analyze the distribution of `Var1` across different bounding boxes to uncover underlying patterns or relationships.
- **Model Deployment:** Integrate the MAE and feature extraction pipeline into a larger framework for automated analysis of new 3D-EM data.
- **Performance Optimization:** If dealing with larger datasets, consider optimizing data loading and processing steps to improve computational efficiency.
- **Error Handling:** Implement more robust error handling mechanisms to manage potential issues during data loading, transformation, or model training phases.
"""

# ------------------------------
# End of Script
# ------------------------------

Extracting features using MAE encoder...


Feature Extraction: 100%|██████████| 7/7 [00:47<00:00,  6.80s/it]


Extracted features shape: (220, 768)
Performing UMAP dimensionality reduction...
UMAP completed.
Generating interactive Plotly scatter plot...


Feature extraction results saved to 'synapse_features_umap.csv'.