<a href="https://colab.research.google.com/github/alim98/ControlVAE/blob/main/MPI_Data_Preprocessing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
!wget -O downloaded_file.zip "https://drive.usercontent.google.com/download?id=1iHPBdBOPEagvPTHZmrN__LD49emXwReY&export=download&authuser=0&confirm=t&uuid=631d60dd-569c-4bb1-a9e8-d681f0ed3d43&at=APvzH3r4me8x_LwP3n8O7lgPo8oK%3A1733988188000"

!unzip downloaded_file.zip

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: seg/bbox3/slice_301.tif  
  inflating: seg/bbox3/slice_329.tif  
  inflating: seg/bbox3/slice_498.tif  
  inflating: seg/bbox3/slice_117.tif  
  inflating: seg/bbox3/slice_103.tif  
  inflating: seg/bbox3/slice_063.tif  
  inflating: seg/bbox3/slice_077.tif  
  inflating: seg/bbox3/slice_088.tif  
  inflating: seg/bbox3/slice_261.tif  
  inflating: seg/bbox3/slice_507.tif  
  inflating: seg/bbox3/slice_513.tif  
  inflating: seg/bbox3/slice_275.tif  
  inflating: seg/bbox3/slice_249.tif  
  inflating: seg/bbox3/slice_248.tif  
  inflating: seg/bbox3/slice_512.tif  
  inflating: seg/bbox3/slice_274.tif  
  inflating: seg/bbox3/slice_260.tif  
  inflating: seg/bbox3/slice_506.tif  
  inflating: seg/bbox3/slice_089.tif  
  inflating: seg/bbox3/slice_076.tif  
  inflating: seg/bbox3/slice_062.tif  
  inflating: seg/bbox3/slice_102.tif  
  inflating: seg/bbox3/slice_116.tif  
  inflating: seg/bbox3/slice_499.tif  

In [5]:
import os
import glob
import imageio
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.patches import Patch
from IPython.display import HTML
from google.colab import files

class VolumeProcessor:
    def __init__(self, raw_base_dir, seg_base_dir, bbox_name, excel_file):
        self.raw_base_dir = raw_base_dir
        self.seg_base_dir = seg_base_dir
        self.bbox_name = bbox_name
        self.excel_file = excel_file
        self.raw_vol = None
        self.seg_vol = None
        self.synapse_data = None

    def load_volume_data(self, max_slices=None):
        raw_dir = os.path.join(self.raw_base_dir, self.bbox_name)
        seg_dir = os.path.join(self.seg_base_dir, self.bbox_name)

        raw_files = sorted(glob.glob(os.path.join(raw_dir, 'slice_*.tif')))
        seg_files = sorted(glob.glob(os.path.join(seg_dir, 'slice_*.tif')))

        if max_slices:
            raw_files = raw_files[:max_slices]
            seg_files = seg_files[:max_slices]

        assert len(raw_files) == len(seg_files), "Number of raw and segmentation files do not match."

        self.raw_vol = np.stack([imageio.imread(f) for f in raw_files], axis=0)
        self.seg_vol = np.stack([imageio.imread(f).astype(np.uint32) for f in seg_files], axis=0)

    def load_synapse_data(self):
        self.synapse_data = pd.read_excel(self.excel_file)

    def create_segment_masks(self, side1_coord, side2_coord):
        x1, y1, z1 = map(int, side1_coord)
        x2, y2, z2 = map(int, side2_coord)

        seg_id_1 = self.seg_vol[z1, y1, x1]
        seg_id_2 = self.seg_vol[z2, y2, x2]

        mask_1 = (self.seg_vol == seg_id_1) if seg_id_1 != 0 else np.zeros_like(self.seg_vol, dtype=bool)
        mask_2 = (self.seg_vol == seg_id_2) if seg_id_2 != 0 else np.zeros_like(self.seg_vol, dtype=bool)
        return mask_1, mask_2

    def process_synapses(self, start_idx, end_idx, subvolume_size=80):
        half_size = subvolume_size // 2

        for idx in range(start_idx, end_idx):
            syn_info = self.synapse_data.iloc[idx]

            central_coord = (
                int(syn_info['central_coord_1']),
                int(syn_info['central_coord_2']),
                int(syn_info['central_coord_3'])
            )

            side1_coord = (
                int(syn_info['side_1_coord_1']),
                int(syn_info['side_1_coord_2']),
                int(syn_info['side_1_coord_3'])
            )

            side2_coord = (
                int(syn_info['side_2_coord_1']),
                int(syn_info['side_2_coord_2']),
                int(syn_info['side_2_coord_3'])
            )

            mask_1_full, mask_2_full = self.create_segment_masks(side1_coord, side2_coord)

            cx, cy, cz = central_coord

            x_start, x_end = max(cx - half_size, 0), min(cx + half_size, self.raw_vol.shape[2])
            y_start, y_end = max(cy - half_size, 0), min(cy + half_size, self.raw_vol.shape[1])
            z_start, z_end = max(cz - half_size, 0), min(cz + half_size, self.raw_vol.shape[0])

            sub_raw = self.raw_vol[z_start:z_end, y_start:y_end, x_start:x_end]
            sub_mask_1 = mask_1_full[z_start:z_end, y_start:y_end, x_start:x_end]
            sub_mask_2 = mask_2_full[z_start:z_end, y_start:y_end, x_start:x_end]

            self._generate_gif(sub_raw, sub_mask_1, sub_mask_2, syn_info, z_start)

    def _generate_gif(self, sub_raw, sub_mask_1, sub_mask_2, syn_info, z_start):
        fig, ax = plt.subplots(figsize=(5, 5))
        im = ax.imshow(sub_raw[0], cmap='gray', interpolation='nearest')
        ax.axis('off')

        legend_patches = [
            Patch(facecolor='red', edgecolor='red', label='Side 1'),
            Patch(facecolor='blue', edgecolor='blue', label='Side 2')
        ]
        ax.legend(handles=legend_patches, loc='upper right', frameon=True)

        def update(frame):
            overlay = np.stack([sub_raw[frame]] * 3, axis=-1).astype(float)

            overlay[sub_mask_1[frame]] = overlay[sub_mask_1[frame]] * 0.7 + np.array([255, 0, 0]) * 0.3
            overlay[sub_mask_2[frame]] = overlay[sub_mask_2[frame]] * 0.7 + np.array([0, 0, 255]) * 0.3

            overlay = np.clip(overlay, 0, 255).astype(np.uint8)
            ax.set_title(f"{syn_info['Var1']} - Z: {z_start + frame}")
            im.set_data(overlay)
            return [im]

        ani = animation.FuncAnimation(fig, update, frames=sub_raw.shape[0], interval=100, blit=True)

        gif_filename = f"{syn_info['Var1']}.gif"
        ani.save(gif_filename, writer='pillow', fps=10)
        plt.close(fig)
        print(f"Saved: {gif_filename}")
        # files.download(gif_filename)

raw_base_dir = '/content/raw'
seg_base_dir = '/content/seg'
bbox_name = 'bbox1'
excel_file = f'{bbox_name}.xlsx'

processor = VolumeProcessor(raw_base_dir, seg_base_dir, bbox_name, excel_file)
processor.load_volume_data()
processor.load_synapse_data()
processor.process_synapses(start_idx=15, end_idx=26)


  self.raw_vol = np.stack([imageio.imread(f) for f in raw_files], axis=0)
  self.seg_vol = np.stack([imageio.imread(f).astype(np.uint32) for f in seg_files], axis=0)


Saved: non_spine_synapse_047.gif
Saved: non_spine_synapse_046.gif
Saved: non_spine_synapse_045.gif
Saved: non_spine_synapse_044.gif
Saved: non_spine_synapse_040.gif
Saved: non_spine_synapse_038.gif
Saved: non_spine_synapse_037.gif
Saved: non_spine_synapse_036.gif
Saved: non_spine_synapse_035.gif
Saved: non_spine_synapse_034.gif
Saved: non_spine_synapse_033.gif
