<a href="https://colab.research.google.com/github/alim98/MPI/blob/main/SynapseSegmentationDataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Essential downloads

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip uninstall -y plotly
!pip install plotly==5.3.1

Collecting plotly==5.3.1
  Downloading plotly-5.3.1-py2.py3-none-any.whl.metadata (7.4 kB)
Downloading plotly-5.3.1-py2.py3-none-any.whl (23.9 MB)
Installing collected packages: plotly
  Attempting uninstall: plotly
    Found existing installation: plotly 5.24.1
    Uninstalling plotly-5.24.1:
      Successfully uninstalled plotly-5.24.1
Successfully installed plotly-5.3.1


In [None]:
!wget -O downloaded_file.zip "https://drive.usercontent.google.com/download?id=1iHPBdBOPEagvPTHZmrN__LD49emXwReY&export=download&authuser=0&confirm=t&uuid=631d60dd-569c-4bb1-a9e8-d681f0ed3d43&at=APvzH3r4me8x_LwP3n8O7lgPo8oK%3A1733988188000"
!wget -O vesicle_cloud__syn_interface__mitochondria_annotation.zip "https://drive.usercontent.google.com/download?id=1qRibZL3kr7MQJQRgDFRquHMQlIGCN4XP&export=download&authuser=0&confirm=t&uuid=631d60dd-569c-4bb1-a9e8-d681f0ed3d43&at=APvzH3r4me8x_LwP3n8O7lgPo8oK%3A1733988188000"

!unzip -q downloaded_file.zip
!unzip -q vesicle_cloud__syn_interface__mitochondria_annotation.zip

!pip install transformers scikit-learn matplotlib seaborn torch torchvision umap-learn git+https://github.com/funkelab/funlib.learn.torch.git
!pip install openpyxl


--2025-02-18 10:18:18--  https://drive.usercontent.google.com/download?id=1iHPBdBOPEagvPTHZmrN__LD49emXwReY&export=download&authuser=0&confirm=t&uuid=631d60dd-569c-4bb1-a9e8-d681f0ed3d43&at=APvzH3r4me8x_LwP3n8O7lgPo8oK%3A1733988188000
Resolving drive.usercontent.google.com (drive.usercontent.google.com)... 74.125.130.132, 2404:6800:4003:c01::84
Connecting to drive.usercontent.google.com (drive.usercontent.google.com)|74.125.130.132|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1264688649 (1.2G) [application/octet-stream]
Saving to: ‘downloaded_file.zip’


2025-02-18 10:18:40 (64.8 MB/s) - ‘downloaded_file.zip’ saved [1264688649/1264688649]

--2025-02-18 10:18:40--  https://drive.usercontent.google.com/download?id=1qRibZL3kr7MQJQRgDFRquHMQlIGCN4XP&export=download&authuser=0&confirm=t&uuid=631d60dd-569c-4bb1-a9e8-d681f0ed3d43&at=APvzH3r4me8x_LwP3n8O7lgPo8oK%3A1733988188000
Resolving drive.usercontent.google.com (drive.usercontent.google.com)... 74.125.130.132,


---

# Synapse Dataset Processing

This repository provides a tool for processing 3D volume data of synapse structures for segmentation, visualization, and analysis. The tool allows users to load, segment, and process raw and segmented image data and generate 3D visualizations of synapse structures.

The code processes multiple bounding boxes, each containing raw, segmentation, and additional mask data. Users can customize the segmentation overlay and generate visualizations in the form of segmented 3D cubes.

## Features

- Load raw, segmentation, and additional mask data from directories.
- Customize segmentation overlays for different synapse structures (e.g., vesicles, clefts, mitochondria).
- Process multiple bounding boxes in a single run.
- Generate 3D cubes with customizable segmentation.
- Alpha blending for better visualization of the data.
- Ability to save generated GIFs for visual inspection.
- Efficient processing pipelines for handling large synapse datasets.

## Requirements

Before using the tool, make sure you have the following Python packages installed:

```bash
pip install numpy pandas imageio tqdm torch torchvision scipy scikit-learn openpyxl
```

## Arguments Overview

To run the script, use the following arguments to configure the dataset processing:

### `--raw_base_dir` (Required)
- **Description**: Directory containing the raw image data files (e.g., `.tif` slices).
- **Type**: `str`
- **Default**: `'raw'`
- **Example**:
    ```bash
    --raw_base_dir /path/to/raw/data
    ```

### `--seg_base_dir` (Required)
- **Description**: Directory containing segmentation data files for pre- and postsynaptic structures (e.g., `.tif` slices).
- **Type**: `str`
- **Default**: `'seg'`
- **Example**:
    ```bash
    --seg_base_dir /path/to/segmentation/data
    ```

### `--add_mask_base_dir` (Optional)
- **Description**: Directory containing additional mask files for vesicles, clefts, and mitochondria (e.g., `.tif` slices).
- **Type**: `str`
- **Default**: `''` (empty string, optional)
- **Example**:
    ```bash
    --add_mask_base_dir /path/to/additional/masks
    ```
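
In the loader defined later in this notebook, a bounding box named `bboxN` is looked up under `bbox_N` inside the additional-mask directory, while the raw and segmentation directories use the name unchanged. A minimal sketch of that mapping follows; the helper name `resolve_bbox_dirs` is hypothetical and only illustrates the naming rule.

```python
import os

def resolve_bbox_dirs(bbox_name, raw_base_dir, seg_base_dir, add_mask_base_dir):
    """Return (raw_dir, seg_dir, add_mask_dir) for one bounding box."""
    raw_dir = os.path.join(raw_base_dir, bbox_name)
    seg_dir = os.path.join(seg_base_dir, bbox_name)
    if bbox_name.startswith("bbox"):
        # Mask folders use an underscore: "bbox3" -> "bbox_3".
        mask_name = "bbox_" + bbox_name[len("bbox"):]
    else:
        mask_name = bbox_name
    return raw_dir, seg_dir, os.path.join(add_mask_base_dir, mask_name)

print(resolve_bbox_dirs("bbox3", "raw", "seg", "masks"))
```

Each of the three directories is expected to hold the same number of `slice_*.tif` files; otherwise the loader skips that bounding box.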

### `--bbox_name` (Required)
- **Description**: List of bounding box names to process. Each bounding box corresponds to a set of data files (raw, segmentation, and masks).
- **Type**: `list[str]`
- **Default**: `['bbox1']`
- **Example**:
    ```bash
    --bbox_name bbox1 bbox2 bbox3
    ```

### `--excel_file` (Required)
- **Description**: Path to the directory containing the Excel files with synapse annotations. The script reads one file per bounding box, named `<bbox_name>.xlsx` (e.g., `bbox1.xlsx`).
- **Type**: `str`
- **Default**: `''` (required path to a directory)
- **Example**:
    ```bash
    --excel_file /path/to/excel/files
    ```

### `--csv_output_dir` (Optional)
- **Description**: Directory to save CSV outputs, such as processed data summaries.
- **Type**: `str`
- **Default**: `'csv_outputs'`
- **Example**:
    ```bash
    --csv_output_dir /path/to/csv/outputs
    ```

### `--size` (Optional)
- **Description**: Target size (height, width) of each frame. Frames are resized to this size before they are converted to tensors.
- **Type**: `tuple[int, int]`
- **Default**: `(80, 80)`
- **Example**:
    ```bash
    --size 128 128
    ```
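
One way to accept two integers for an option like `--size` is `nargs=2` with `type=int`. With `type=tuple` alone, argparse would take a single string and split it into characters. A minimal sketch, assuming the option is declared as below (the `metavar` names are illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
# Two integers on the command line, e.g. "--size 128 128".
parser.add_argument("--size", type=int, nargs=2, default=(80, 80),
                    metavar=("HEIGHT", "WIDTH"))

args = parser.parse_args(["--size", "128", "128"])
size = tuple(args.size)  # argparse returns a list; convert for consistency
print(size)
```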

### `--subvol_size` (Optional)
- **Description**: Subvolume size for extracting regions from the full volume. This size determines the 3D crop of the data.
- **Type**: `int`
- **Default**: `80`
- **Example**:
    ```bash
    --subvol_size 128
    ```
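
The crop is a cube centered on the synapse's central coordinate and clamped to the volume, so crops near an edge come out smaller and are zero-padded afterwards. A minimal sketch of the bound computation, assuming coordinates are given as (x, y, z) and the volume is indexed as (z, y, x); the function name `subvolume_bounds` is hypothetical.

```python
def subvolume_bounds(center, shape, subvol_size=80):
    """Clamp a cube of side `subvol_size` around `center` to the volume.

    `center` is (x, y, z); `shape` is the volume shape (z, y, x).
    Returns ((z0, z1), (y0, y1), (x0, x1)).
    """
    half = subvol_size // 2
    cx, cy, cz = center
    bounds = []
    for c, dim in ((cz, shape[0]), (cy, shape[1]), (cx, shape[2])):
        bounds.append((max(c - half, 0), min(c + half, dim)))
    return tuple(bounds)

# The x range is cut off at the volume border (0), so it is only 50 wide.
print(subvolume_bounds((10, 100, 50), (120, 200, 200)))
# ((10, 90), (60, 140), (0, 50))
```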

### `--num_frames` (Optional)
- **Description**: Number of frames to extract from the data.
- **Type**: `int`
- **Default**: `80`
- **Example**:
    ```bash
    --num_frames 16
    ```
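
When the cube has more slices than `--num_frames`, the dataset keeps evenly spaced slices; when it has fewer, it repeats the last slice. A minimal sketch of that index selection, assuming at least one slice is available; the helper name `select_frame_indices` is hypothetical.

```python
import numpy as np

def select_frame_indices(n_available, num_frames):
    """Return the indices of the frames to keep so that `num_frames` result."""
    if n_available <= num_frames:
        # Too few frames: keep all of them and repeat the last one as padding.
        return list(range(n_available)) + [n_available - 1] * (num_frames - n_available)
    # Too many frames: sample evenly spaced indices, endpoints included.
    return np.linspace(0, n_available - 1, num_frames, dtype=int).tolist()

print(select_frame_indices(80, 16))
print(select_frame_indices(3, 5))
```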

### `--save_gifs_dir` (Optional)
- **Description**: Directory to save generated GIFs for each segmentation type.
- **Type**: `str`
- **Default**: `'gifs'`
- **Example**:
    ```bash
    --save_gifs_dir /path/to/save/gifs
    ```

### `--alpha` (Optional)
- **Description**: Alpha blending factor for combining the raw image and the mask. It controls how strongly the unmasked areas are blended with a gray overlay; masked areas keep their raw intensity.
- **Type**: `float`
- **Default**: `0.5`
- **Example**:
    ```bash
    --alpha 0.7
    ```
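
The blending can be sketched with NumPy as follows. This is a simplified illustration, assuming a raw image already normalized to [0, 1], a boolean mask of the same shape, and a gray level of 0.6 as used in the notebook code; the function name `blend_with_gray` is hypothetical.

```python
import numpy as np

def blend_with_gray(raw, mask, alpha=0.5, gray=0.6):
    """Keep raw intensities inside the mask; fade everything else toward gray."""
    dimmed = alpha * gray + (1.0 - alpha) * raw
    return np.where(mask, raw, dimmed)

raw = np.array([[0.0, 1.0], [0.2, 0.8]], dtype=np.float32)
mask = np.array([[True, False], [False, True]])
out = blend_with_gray(raw, mask, alpha=0.5)
print(out)
```

With `alpha=0` the output equals the raw image; with `alpha=1` everything outside the mask becomes uniform gray.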

### `--segmentation_type` (Required)
- **Description**: Defines which type of segmentation overlay to apply to the raw data. This option determines which mask type will be used for overlaying the raw image.
- **Type**: `int`
- **Choices**:
    - `0`: Raw image only (no overlay).
    - `1`: Presynapse region.
    - `2`: Postsynapse region.
    - `3`: Both presynapse and postsynapse.
    - `4`: Closest vesicle cloud + closest cleft (both cleft labels).
    - `5`: Both synaptic sides + closest vesicle cloud + closest cleft.
    - `6`: Vesicle cloud (closest).
    - `7`: Cleft regions only.
    - `8`: Mitochondria regions only.
    - `9`: Closest vesicle cloud + closest cleft (primary cleft label only).
    - `10`: Presynapse region + closest cleft.
- **Default**: `6`
- **Example**:
    ```bash
    --segmentation_type 3
    ```

## Example Usage

Here is an example of how to run the script with the necessary arguments:

```bash
python data_loader.py \
    --raw_base_dir /path/to/raw/data \
    --seg_base_dir /path/to/segmentation/data \
    --add_mask_base_dir /path/to/additional/masks \
    --bbox_name bbox1 bbox2 \
    --excel_file /path/to/excel/files \
    --csv_output_dir /path/to/csv/outputs \
    --save_gifs_dir /path/to/save/gifs \
    --segmentation_type 2 \
    --alpha 0.5
```

## Segmentation Type Handling

The segmentation logic is based on the `segmentation_type` argument. It determines how to combine masks and create the desired visualization.

### Segmentation Logic:

- **`0`**: Raw image (no overlay)
- **`1`**: Presynapse region (the side, side1 or side2, that overlaps more with the closest vesicle cloud).
- **`2`**: Postsynapse region (the other side).
- **`3`**: Both presynapse and postsynapse regions (union of the two side masks).
- **`4`**: Vesicles and clefts (closest components).
- **`5`**: All structures (vesicles, clefts, and both synaptic sides).
- **`6`**: Vesicle cloud (closest to target).
- **`7`**: Cleft regions (closest to target).
- **`8`**: Mitochondria regions (closest to target).
- **`9`**: Combined vesicle + cleft (closest to target).
- **`10`**: Presynapse + cleft (closest to target).
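
"Closest" in the list above means the connected component whose nearest voxel lies closest to the target coordinate, not the component with the nearest centroid. A minimal sketch of that selection using `scipy.ndimage.label`; it omits the subvolume cropping done in the notebook code, takes the target in (z, y, x) index order, and the function name `closest_component` is hypothetical.

```python
import numpy as np
from scipy.ndimage import label

def closest_component(mask, target_zyx):
    """Return a boolean mask of the connected component nearest to the target."""
    labeled, n = label(mask)
    if n == 0:
        return np.zeros_like(mask, dtype=bool)
    target = np.asarray(target_zyx)
    best_label, best_dist = None, np.inf
    for lab in range(1, n + 1):  # component labels start at 1
        coords = np.argwhere(labeled == lab)
        # Distance from the target to the nearest voxel of this component.
        dist = np.sqrt(((coords - target) ** 2).sum(axis=1)).min()
        if dist < best_dist:
            best_label, best_dist = lab, dist
    return labeled == best_label

mask = np.zeros((3, 5, 5), dtype=bool)
mask[1, 0:2, 0:2] = True   # component A
mask[1, 3:5, 3:5] = True   # component B
result = closest_component(mask, (1, 4, 4))
print(result[1].astype(int))
```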


In [None]:
import os
import glob
import io
import argparse
import multiprocessing
from typing import List, Tuple
import imageio
import numpy as np
import pandas as pd
import imageio.v3 as iio
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torchvision import transforms
from scipy.ndimage import label, center_of_mass
from tqdm import tqdm


class Synapse3DProcessor:
    """Resize frames, convert them to single-channel tensors, and stack them."""
    def __init__(self, size=(80, 80), mean=(0.485,), std=(0.229,)):
        # Use Grayscale without converting to 3-channel RGB
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(size),
            transforms.Grayscale(num_output_channels=1),  # Changed to 1 channel (grayscale)
            transforms.ToTensor(),
            # transforms.Normalize(mean=mean, std=std),
        ])


    def __call__(self, frames, return_tensors=None):
        processed_frames = [self.transform(frame) for frame in frames]
        pixel_values = torch.stack(processed_frames)
        if return_tensors == "pt":
            return {"pixel_values": pixel_values}
        else:
            return pixel_values

def load_volumes(bbox_name: str, raw_base_dir: str, seg_base_dir: str, add_mask_base_dir: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    raw_dir = os.path.join(raw_base_dir, bbox_name)
    seg_dir = os.path.join(seg_base_dir, bbox_name)
    if bbox_name.startswith("bbox"):
        bbox_num = bbox_name.replace("bbox", "")
        add_mask_dir = os.path.join(add_mask_base_dir, f"bbox_{bbox_num}")
    else:
        add_mask_dir = os.path.join(add_mask_base_dir, bbox_name)
    raw_tif_files = sorted(glob.glob(os.path.join(raw_dir, 'slice_*.tif')))
    seg_tif_files = sorted(glob.glob(os.path.join(seg_dir, 'slice_*.tif')))
    add_mask_tif_files = sorted(glob.glob(os.path.join(add_mask_dir, 'slice_*.tif')))
    if not (len(raw_tif_files) == len(seg_tif_files) == len(add_mask_tif_files)):
        return None, None, None
    try:
        raw_vol = np.stack([iio.imread(f) for f in raw_tif_files], axis=0)
        seg_vol = np.stack([iio.imread(f).astype(np.uint32) for f in seg_tif_files], axis=0)
        add_mask_vol = np.stack([iio.imread(f).astype(np.uint32) for f in add_mask_tif_files], axis=0)
        return raw_vol, seg_vol, add_mask_vol
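    except Exception as e:
        # Report the failure instead of silently skipping this bounding box.
        print(f"Failed to load volumes for {bbox_name}: {e}")
        return None, None, None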

def parse_args():
    parser = argparse.ArgumentParser(description="Synapse Dataset")
    parser.add_argument('--raw_base_dir', type=str, default='raw')
    parser.add_argument('--seg_base_dir', type=str, default='seg')
    parser.add_argument('--add_mask_base_dir', type=str, default='')
    parser.add_argument('--bbox_name', type=str, default=['bbox1'], nargs='+')
    parser.add_argument('--excel_file', type=str, default='')
    parser.add_argument('--csv_output_dir', type=str, default='csv_outputs')
    parser.add_argument('--size', type=int, nargs=2, default=(80, 80))  # two ints: height width
    parser.add_argument('--subvol_size', type=int, default=80)
    parser.add_argument('--num_frames', type=int, default=80)
    parser.add_argument('--save_gifs_dir', type=str, default='gifs')
    parser.add_argument('--alpha', type=float, default=0.5)
    parser.add_argument('--segmentation_type', type=int, default=6, choices=range(0, 11),
                        help='Type of segmentation overlay (0-10)')
    args, _ = parser.parse_known_args()
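    return args


def get_closest_component_mask(full_mask, z_start, z_end, y_start, y_end, x_start, x_end, target_coord):
    """Keep only the connected component of full_mask (inside the given bounds) nearest to target_coord (x, y, z)."""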
    sub_mask = full_mask[z_start:z_end, y_start:y_end, x_start:x_end]
    labeled_sub_mask, num_features = label(sub_mask)
    if num_features == 0:
        return np.zeros_like(full_mask, dtype=bool)
    else:
        # For each label (vesicle cloud), find the nearest pixel to the target_coord
        cx, cy, cz = target_coord
        min_distance = float('inf')  # Initialize minimum distance as infinity
        closest_label = None

        for label_num in range(1, num_features + 1):  # labels are 1-based, not 0
            # Get the coordinates of all pixels that belong to this label (vesicle cloud)
            vesicle_coords = np.column_stack(np.where(labeled_sub_mask == label_num))

            # Compute the distance of each pixel in the vesicle cloud to the target coordinate
            distances = np.sqrt(
                (vesicle_coords[:, 0] + z_start - cz) ** 2 +
                (vesicle_coords[:, 1] + y_start - cy) ** 2 +
                (vesicle_coords[:, 2] + x_start - cx) ** 2
            )

            # Find the pixel with the minimum distance
            min_dist_for_vesicle = np.min(distances)
            if min_dist_for_vesicle < min_distance:
                min_distance = min_dist_for_vesicle
                closest_label = label_num

        # Now, create a mask for the closest vesicle cloud
        if closest_label is not None:
            filtered_sub_mask = (labeled_sub_mask == closest_label)
            combined_mask = np.zeros_like(full_mask, dtype=bool)
            combined_mask[z_start:z_end, y_start:y_end, x_start:x_end] = filtered_sub_mask
            return combined_mask
        else:
            return np.zeros_like(full_mask, dtype=bool)

def create_segmented_cube(
    raw_vol: np.ndarray,
    seg_vol: np.ndarray,
    add_mask_vol: np.ndarray,
    central_coord: Tuple[int, int, int],
    side1_coord: Tuple[int, int, int],
    side2_coord: Tuple[int, int, int],
    segmentation_type: int,
    subvolume_size: int = 80,
    alpha: float = 0.3,
    bbox_name: str = "",
) -> np.ndarray:
    bbox_num = bbox_name.replace("bbox", "").strip()
    if bbox_num in {'2', '5',}:
        mito_label = 1
        vesicle_label = 3
        cleft_label2 = 4
        cleft_label = 2
    elif bbox_num == '7':
        mito_label = 1
        vesicle_label = 2
        cleft_label2 = 3
        cleft_label = 4
    elif bbox_num == '4':
        mito_label = 3
        vesicle_label = 2
        cleft_label2 = 4
        cleft_label = 1
    elif bbox_num == '3':
        mito_label = 6
        vesicle_label = 7
        cleft_label2 = 8
        cleft_label = 9
    else:  # For bbox1, bbox6, and any other bounding box
        mito_label = 5
        vesicle_label = 6
        cleft_label = 7
        cleft_label2 = 7

    # --- Always calculate subvolume bounds FIRST ---
    half_size = subvolume_size // 2
    cx, cy, cz = central_coord
    x_start = max(cx - half_size, 0)
    x_end = min(cx + half_size, raw_vol.shape[2])
    y_start = max(cy - half_size, 0)
    y_end = min(cy + half_size, raw_vol.shape[1])
    z_start = max(cz - half_size, 0)
    z_end = min(cz + half_size, raw_vol.shape[0])

    # --- Vesicle filtering (critical for presynapse determination) ---
    vesicle_full_mask = (add_mask_vol == vesicle_label)
    vesicle_mask = get_closest_component_mask(
        vesicle_full_mask,
        z_start, z_end,
        y_start, y_end,
        x_start, x_end,
        (cx, cy, cz)
    )

    # --- Side masks ---
    def create_segment_masks(segmentation_volume, s1_coord, s2_coord):
        x1, y1, z1 = s1_coord
        x2, y2, z2 = s2_coord
        seg_id_1 = segmentation_volume[z1, y1, x1]
        seg_id_2 = segmentation_volume[z2, y2, x2]
        mask_1 = (segmentation_volume == seg_id_1) if seg_id_1 != 0 else np.zeros_like(segmentation_volume, dtype=bool)
        mask_2 = (segmentation_volume == seg_id_2) if seg_id_2 != 0 else np.zeros_like(segmentation_volume, dtype=bool)
        return mask_1, mask_2

    mask_1_full, mask_2_full = create_segment_masks(seg_vol, side1_coord, side2_coord)

    # --- Determine pre-synapse side using filtered vesicles ---
    overlap_side1 = np.sum(np.logical_and(mask_1_full, vesicle_mask))
    overlap_side2 = np.sum(np.logical_and(mask_2_full, vesicle_mask))
    presynapse_side = 1 if overlap_side1 > overlap_side2 else 2
    if segmentation_type == 0: # Raw data
        combined_mask_full = np.ones_like(add_mask_vol, dtype=bool)
    elif segmentation_type == 1:  # Presynapse
        combined_mask_full = mask_1_full if presynapse_side == 1 else mask_2_full
    elif segmentation_type == 2:  # Postsynapse
        combined_mask_full = mask_2_full if presynapse_side == 1 else mask_1_full
    elif segmentation_type == 3:  # Both sides
        combined_mask_full = np.logical_or(mask_1_full, mask_2_full)
    elif segmentation_type == 4:  # Vesicles + Cleft (closest only)
        vesicle_closest = get_closest_component_mask(
            (add_mask_vol == vesicle_label), z_start, z_end, y_start, y_end, x_start, x_end, (cx, cy, cz)
        )
        cleft_closest = get_closest_component_mask(
            ((add_mask_vol == cleft_label)), z_start, z_end, y_start, y_end, x_start, x_end, (cx, cy, cz)
        )
        cleft_closest2 = get_closest_component_mask(
            ((add_mask_vol == cleft_label2)), z_start, z_end, y_start, y_end, x_start, x_end, (cx, cy, cz)
        )
        combined_mask_full = np.logical_or(vesicle_closest, np.logical_or(cleft_closest,cleft_closest2))
    elif segmentation_type == 5:  # (closest vesicles/cleft + sides)
        vesicle_closest = get_closest_component_mask(
            (add_mask_vol == vesicle_label), z_start, z_end, y_start, y_end, x_start, x_end, (cx, cy, cz)
        )
        cleft_closest = get_closest_component_mask(
            (add_mask_vol == cleft_label), z_start, z_end, y_start, y_end, x_start, x_end, (cx, cy, cz)
        )
        combined_mask_extra = np.logical_or(vesicle_closest, cleft_closest)
        combined_mask_full = np.logical_or(mask_1_full, np.logical_or(mask_2_full, combined_mask_extra))
    elif segmentation_type == 6:  # Vesicle cloud (closest)
        combined_mask_full = get_closest_component_mask(
            (add_mask_vol == vesicle_label), z_start, z_end, y_start, y_end, x_start, x_end, (cx, cy, cz)
        )
    elif segmentation_type == 7:  # Cleft (closest)
        cleft_closest = get_closest_component_mask(
            ((add_mask_vol == cleft_label)), z_start, z_end, y_start, y_end, x_start, x_end, (cx, cy, cz)
        )
        cleft_closest2 = get_closest_component_mask(
            ((add_mask_vol == cleft_label2)), z_start, z_end, y_start, y_end, x_start, x_end, (cx, cy, cz)
        )
        combined_mask_full =  np.logical_or(cleft_closest,cleft_closest2)
    elif segmentation_type == 8:  # Mitochondria (closest)
        combined_mask_full = get_closest_component_mask(
            (add_mask_vol == mito_label), z_start, z_end, y_start, y_end, x_start, x_end, (cx, cy, cz)
        )
    elif segmentation_type == 10:  #  +Cleft +pre
        cleft_closest = get_closest_component_mask(
            (add_mask_vol == cleft_label), z_start, z_end, y_start, y_end, x_start, x_end, (cx, cy, cz)
        )
        pre_mask_full = mask_1_full if presynapse_side == 1 else mask_2_full

        combined_mask_full = np.logical_or(cleft_closest,pre_mask_full)

    elif segmentation_type == 9:  # cleft+vesicle
        vesicle_closest = get_closest_component_mask(
            (add_mask_vol == vesicle_label), z_start, z_end, y_start, y_end, x_start, x_end, (cx, cy, cz)
        )
        cleft_closest = get_closest_component_mask(
            (add_mask_vol == cleft_label), z_start, z_end, y_start, y_end, x_start, x_end, (cx, cy, cz)
        )

        combined_mask_full = np.logical_or(cleft_closest,vesicle_closest)

    else:
        raise ValueError(f"Unsupported segmentation type: {segmentation_type}")

    # --- Subvolume extraction and processing ---
    sub_raw = raw_vol[z_start:z_end, y_start:y_end, x_start:x_end]
    sub_combined_mask = combined_mask_full[z_start:z_end, y_start:y_end, x_start:x_end]

    # Padding if needed
    pad_z = subvolume_size - sub_raw.shape[0]
    pad_y = subvolume_size - sub_raw.shape[1]
    pad_x = subvolume_size - sub_raw.shape[2]
    if pad_z > 0 or pad_y > 0 or pad_x > 0:
        sub_raw = np.pad(sub_raw, ((0, pad_z), (0, pad_y), (0, pad_x)), mode='constant', constant_values=0)
        sub_combined_mask = np.pad(sub_combined_mask, ((0, pad_z), (0, pad_y), (0, pad_x)), mode='constant', constant_values=False)

    sub_raw = sub_raw[:subvolume_size, :subvolume_size, :subvolume_size]
    sub_combined_mask = sub_combined_mask[:subvolume_size, :subvolume_size, :subvolume_size]

    # Vectorized normalization
    sub_raw = sub_raw.astype(np.float32)
    mins = np.min(sub_raw, axis=(1, 2), keepdims=True)
    maxs = np.max(sub_raw, axis=(1, 2), keepdims=True)
    ranges = np.where(maxs > mins, maxs - mins, 1.0)
    normalized = (sub_raw - mins) / ranges

    # Gray level blended into the unmasked regions
    gray_color = 0.6

    # Vectorized blending with gray color
    raw_rgb = np.repeat(normalized[..., np.newaxis], 3, axis=-1)  # Convert to RGB
    mask_factor = sub_combined_mask[..., np.newaxis]  # Adding an extra dimension to make it (80, 80, 80, 1)

    if alpha < 1:
        blended_part = alpha * gray_color + (1 - alpha) * raw_rgb  # Blend with gray
    else:
        # When alpha is 1, apply gray only to unmasked areas (grayscale), keep raw_rgb in masked areas
        blended_part = gray_color * (1 - mask_factor) + raw_rgb * mask_factor

    # Now, overlaid_image will be computed as follows
    overlaid_image = raw_rgb * mask_factor + (1 - mask_factor) * blended_part

    # Move the z axis last: (z, y, x, channel) -> (y, x, channel, z); values stay float in [0, 1]
    overlaid_cube = np.transpose(overlaid_image, (1, 2, 3, 0))

    return overlaid_cube

class SynapseDataset(Dataset):
    def __init__(self, vol_data_dict: dict, synapse_df: pd.DataFrame, processor,
                 segmentation_type: int, subvol_size: int = 80, num_frames: int = 16,
                 alpha: float = 0.3):
        self.vol_data_dict = vol_data_dict
        self.synapse_df = synapse_df.reset_index(drop=True)
        self.processor = processor
        self.segmentation_type = segmentation_type
        self.subvol_size = subvol_size
        self.num_frames = num_frames
        self.alpha = alpha

    def __len__(self):
        return len(self.synapse_df)

    def __getitem__(self, idx):
        syn_info = self.synapse_df.iloc[idx]
        bbox_name = syn_info['bbox_name']
        raw_vol, seg_vol, add_mask_vol = self.vol_data_dict.get(bbox_name, (None, None, None))
        if raw_vol is None:
            return torch.zeros((self.num_frames, 1, self.subvol_size, self.subvol_size), dtype=torch.float32), syn_info, bbox_name

        central_coord = (int(syn_info['central_coord_1']), int(syn_info['central_coord_2']), int(syn_info['central_coord_3']))
        side1_coord = (int(syn_info['side_1_coord_1']), int(syn_info['side_1_coord_2']), int(syn_info['side_1_coord_3']))
        side2_coord = (int(syn_info['side_2_coord_1']), int(syn_info['side_2_coord_2']), int(syn_info['side_2_coord_3']))

        overlaid_cube = create_segmented_cube(
            raw_vol=raw_vol,
            seg_vol=seg_vol,
            add_mask_vol=add_mask_vol,
            central_coord=central_coord,
            side1_coord=side1_coord,
            side2_coord=side2_coord,
            segmentation_type=self.segmentation_type,
            subvolume_size=self.subvol_size,
            alpha=self.alpha,
            bbox_name=bbox_name,  # Pass bbox_name here
        )
        frames = [overlaid_cube[..., z] for z in range(overlaid_cube.shape[3])]
        if len(frames) < self.num_frames:
            frames += [frames[-1]] * (self.num_frames - len(frames))
        elif len(frames) > self.num_frames:
            indices = np.linspace(0, len(frames)-1, self.num_frames, dtype=int)
            frames = [frames[i] for i in indices]

        inputs = self.processor(frames, return_tensors="pt")
        return inputs["pixel_values"].squeeze(0).float(), syn_info, bbox_name

# Add unique IDs to fixed_samples
fixed_samples = [
    {"id": 1, "bbox_name": "bbox1", "Var1": "non_spine_synapse_004", "slice_number": 25},
    {"id": 2, "bbox_name": "bbox1", "Var1": "non_spine_synapse_006", "slice_number": 40},
    {"id": 4, "bbox_name": "bbox2", "Var1": "explorative_2024-08-28_Cora_Wolter_031", "slice_number": 43},
    {"id": 3, "bbox_name": "bbox2", "Var1": "explorative_2024-08-28_Cora_Wolter_051", "slice_number": 28},
    {"id": 5, "bbox_name": "bbox3", "Var1": "non_spine_synapse_036", "slice_number": 41},
    {"id": 6, "bbox_name": "bbox3", "Var1": "non_spine_synapse_018", "slice_number": 41},
    {"id": 7, "bbox_name": "bbox4", "Var1": "explorative_2024-08-03_Ali_Karimi_023", "slice_number": 28},
    {"id": 8, "bbox_name": "bbox5", "Var1": "non_spine_synapse_033", "slice_number": 48},
    {"id": 9, "bbox_name": "bbox5", "Var1": "non_spine_synapse_045", "slice_number": 40},
    {"id": 10, "bbox_name": "bbox6", "Var1": "spine_synapse_070", "slice_number": 37},
    {"id": 11, "bbox_name": "bbox6", "Var1": "spine_synapse_021", "slice_number": 30},
    {"id": 12, "bbox_name": "bbox7", "Var1": "non_spine_synapse_013", "slice_number": 25},
]


args = parse_args()
args.bbox_name=['bbox1','bbox2','bbox3','bbox4','bbox5','bbox6','bbox7',]
# args.bbox_name=['bbox4']

# Load volumes
vol_data_dict = {}
for bbox_name in args.bbox_name:
    raw_vol, seg_vol, add_mask_vol = load_volumes(
        bbox_name=bbox_name,
        raw_base_dir=args.raw_base_dir,
        seg_base_dir=args.seg_base_dir,
        add_mask_base_dir=args.add_mask_base_dir
    )
    if raw_vol is not None:
        vol_data_dict[bbox_name] = (raw_vol, seg_vol, add_mask_vol)

# Load synapse data
syn_df = pd.concat([
    pd.read_excel(os.path.join(args.excel_file, f"{bbox}.xlsx")).assign(bbox_name=bbox)
    for bbox in args.bbox_name if os.path.exists(os.path.join(args.excel_file, f"{bbox}.xlsx"))
])

# Initialize the frame processor
processor = Synapse3DProcessor(size=args.size)



In [None]:

import argparse
import os
import numpy as np
import pandas as pd
import torch
import imageio
bbox_name=['bbox1','bbox2','bbox3','bbox4','bbox5','bbox6','bbox7',]

def parse_args():
    parser = argparse.ArgumentParser(description="Synapse Dataset")
    parser.add_argument('--raw_base_dir', type=str, default='raw')
    parser.add_argument('--seg_base_dir', type=str, default='seg')
    parser.add_argument('--add_mask_base_dir', type=str, default='')
    parser.add_argument('--bbox_name', type=str, default=['bbox4'], nargs='+')
    parser.add_argument('--excel_file', type=str, default='')
    parser.add_argument('--csv_output_dir', type=str, default='csv_outputs')
    parser.add_argument('--size', type=int, nargs=2, default=(80, 80))  # two ints: height width
    parser.add_argument('--subvol_size', type=int, default=80)
    parser.add_argument('--num_frames', type=int, default=80)
    parser.add_argument('--save_gifs_dir', type=str, default='gifs')
    parser.add_argument('--alpha', type=float, default=0.7)
    parser.add_argument('--segmentation_type', type=int, default=4, choices=range(0, 11),
                        help='Type of segmentation overlay (0-10)')
    args, _ = parser.parse_known_args()
    return args
args = parse_args()
args.bbox_name=bbox_name
args.segmentation_type=1
vol_data_dict = {}
for bbox_name in args.bbox_name:
    raw_vol, seg_vol, add_mask_vol = load_volumes(
        bbox_name=bbox_name,
        raw_base_dir=args.raw_base_dir,
        seg_base_dir=args.seg_base_dir,
        add_mask_base_dir=args.add_mask_base_dir
    )
    if raw_vol is not None:
        vol_data_dict[bbox_name] = (raw_vol, seg_vol, add_mask_vol)

syn_df = pd.concat([
    pd.read_excel(os.path.join(args.excel_file, f"{bbox}.xlsx")).assign(bbox_name=bbox)
    for bbox in args.bbox_name if os.path.exists(os.path.join(args.excel_file, f"{bbox}.xlsx"))
])

processor = Synapse3DProcessor(size=args.size)


dataset = SynapseDataset(
    vol_data_dict=vol_data_dict,
    synapse_df=syn_df,
    processor=processor,
    segmentation_type=args.segmentation_type,
    subvol_size=args.subvol_size,
    num_frames=args.num_frames,
    alpha=0.5
)
