<a href="https://colab.research.google.com/github/alim98/MPI/blob/main/Helper_Functions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import glob
import numpy as np
import imageio.v2 as imageio
import pandas as pd
import matplotlib.pyplot as plt

def get_segmented_images(raw_base_dir, seg_base_dir, bbox_name, excel_file, var1, subvolume_size=80, alpha=0.3):
    """
    Render a 2D RGB overlay of one synapse: the raw slice at the synapse's
    central Z with the Side 1 segment (red) and Side 2 segment (blue)
    alpha-blended on top.

    The crop is a subvolume_size x subvolume_size window centred on the central
    coordinate of the Excel row whose 'Var1' equals `var1`. Parts of the window
    that fall outside the image are zero-filled in place, so the central
    coordinate always lands at index (subvolume_size // 2, subvolume_size // 2)
    of the result, even next to an image border.

    Coordinates from the Excel file are read as (x, y, z) and used directly as
    0-based indices into volumes stacked as (z, y, x).

    Parameters:
        raw_base_dir (str): Directory containing raw image slices.
        seg_base_dir (str): Directory containing segmentation image slices.
        bbox_name (str): Name of the bounding box (sub-directory) to process.
        excel_file (str): Path to the Excel file containing synapse data.
        var1 (str): Var1 identifier of the synapse to process.
        subvolume_size (int): Edge length, in pixels, of the square crop.
        alpha (float): Transparency level for the masks (0.0 transparent, 1.0 opaque).

    Returns:
        np.ndarray: uint8 RGB array of shape (subvolume_size, subvolume_size, 3).

    Raises:
        FileNotFoundError: No 'slice_*.tif' files in the raw or segmentation directory.
        IOError: An image file or the Excel file could not be read.
        KeyError: A required coordinate column is missing from the Excel file.
        ValueError: subvolume_size is not positive, the slice counts differ,
            no row matches `var1`, a coordinate value is not convertible to
            int, a side coordinate is outside the volume, or the central Z
            coordinate is outside the volume.
    """
    if subvolume_size <= 0:
        raise ValueError("subvolume_size must be a positive integer.")

    raw_dir = os.path.join(raw_base_dir, bbox_name)
    seg_dir = os.path.join(seg_base_dir, bbox_name)

    # NOTE: sorted() orders the names lexicographically, so file order matches
    # Z order only if the numeric part of the file names is zero-padded.
    raw_files = sorted(glob.glob(os.path.join(raw_dir, 'slice_*.tif')))
    seg_files = sorted(glob.glob(os.path.join(seg_dir, 'slice_*.tif')))

    if not raw_files:
        raise FileNotFoundError(f"No raw .tif files found in directory: {raw_dir}")
    if not seg_files:
        raise FileNotFoundError(f"No segmentation .tif files found in directory: {seg_dir}")
    if len(raw_files) != len(seg_files):
        raise ValueError("Number of raw and segmentation files do not match.")

    # Stack the slices into (z, y, x) volumes.
    try:
        raw_vol = np.stack([imageio.imread(f) for f in raw_files], axis=0)
        seg_vol = np.stack([imageio.imread(f).astype(np.uint32) for f in seg_files], axis=0)
    except Exception as e:
        raise IOError(f"Error loading image files: {e}") from e

    try:
        synapse_data = pd.read_excel(excel_file)
    except Exception as e:
        raise IOError(f"Error reading Excel file '{excel_file}': {e}") from e

    # Only the first row matching Var1 is used.
    synapse_row = synapse_data[synapse_data['Var1'] == var1]
    if synapse_row.empty:
        raise ValueError(f"No synapse found with Var1: '{var1}' in Excel file '{excel_file}'.")
    syn_info = synapse_row.iloc[0]

    def read_xyz(prefix):
        # Columns are named '<prefix>_coord_1', '<prefix>_coord_2', '<prefix>_coord_3'.
        return tuple(int(syn_info[f'{prefix}_coord_{axis}']) for axis in (1, 2, 3))

    try:
        central_coord = read_xyz('central')
        side1_coord = read_xyz('side_1')
        side2_coord = read_xyz('side_2')
    except KeyError as e:
        raise KeyError(f"Missing column in Excel file: {e}") from e
    except ValueError as e:
        raise ValueError(f"Invalid coordinate value: {e}") from e

    def segment_id_at(coord, side_label):
        # Look up the segment label under an (x, y, z) point, rejecting
        # points that fall outside the segmentation volume.
        x, y, z = coord
        inside = (
            0 <= z < seg_vol.shape[0]
            and 0 <= y < seg_vol.shape[1]
            and 0 <= x < seg_vol.shape[2]
        )
        if not inside:
            raise ValueError(
                f"Error creating segmentation masks: {side_label} coordinates are out of bounds."
            )
        return seg_vol[z, y, x]

    seg_id_1 = segment_id_at(side1_coord, "Side1")
    seg_id_2 = segment_id_at(side2_coord, "Side2")

    cx, cy, cz = central_coord
    if cz < 0 or cz >= raw_vol.shape[0]:
        raise ValueError("Central Z-coordinate is out of bounds.")

    half_size = subvolume_size // 2

    def clip_window(start, extent):
        # Intersect the requested window [start, start + subvolume_size) with
        # the valid index range [0, extent). Returns the slice to read from
        # the volume and the slice of the crop it must be written to; both
        # are empty (never negative) when the window misses the image.
        lo = max(start, 0)
        hi = max(min(start + subvolume_size, extent), lo)
        return slice(lo, hi), slice(lo - start, hi - start)

    src_y, dst_y = clip_window(cy - half_size, raw_vol.shape[1])
    src_x, dst_x = clip_window(cx - half_size, raw_vol.shape[2])

    # Start from zero-filled crops and paste the in-bounds pixels at their
    # true offset, so clipping at any border keeps the synapse centred.
    sub_raw = np.zeros((subvolume_size, subvolume_size), dtype=raw_vol.dtype)
    sub_seg = np.zeros((subvolume_size, subvolume_size), dtype=seg_vol.dtype)
    sub_raw[dst_y, dst_x] = raw_vol[cz, src_y, src_x]
    sub_seg[dst_y, dst_x] = seg_vol[cz, src_y, src_x]

    def mask_for(seg_id):
        # Label 0 is treated as background: it never produces a mask.
        if seg_id == 0:
            return np.zeros(sub_seg.shape, dtype=bool)
        return sub_seg == seg_id

    # Masks are computed on the crop only; the full-volume comparison is not
    # needed because only this window is ever displayed.
    sub_mask_1 = mask_for(seg_id_1)
    sub_mask_2 = mask_for(seg_id_2)

    # Normalize raw crop to [0, 1] for blending (a constant crop becomes all zeros).
    raw_normalized = sub_raw.astype(np.float32)
    raw_min, raw_max = raw_normalized.min(), raw_normalized.max()
    if raw_max != raw_min:
        raw_normalized = (raw_normalized - raw_min) / (raw_max - raw_min)
    else:
        raw_normalized = raw_normalized - raw_min

    raw_rgb = np.stack([raw_normalized] * 3, axis=-1)

    side1_color = np.array([1, 0, 0], dtype=np.float32)  # Red
    side2_color = np.array([0, 0, 1], dtype=np.float32)  # Blue

    # Colours add where both masks cover the same pixel.
    mask_rgb = np.zeros_like(raw_rgb)
    mask_rgb[sub_mask_1] += side1_color
    mask_rgb[sub_mask_2] += side2_color

    blended = np.clip((1 - alpha) * raw_rgb + alpha * mask_rgb, 0, 1)

    return (blended * 255).astype(np.uint8)

def create_segmented_cube(raw_base_dir, seg_base_dir, bbox_name, excel_file, var1, subvolume_size=80, alpha=0.3):
    """
    Build a 3D RGB overlay cube around one synapse: every Z slice of the raw
    subvolume with the Side 1 segment (red) and Side 2 segment (blue)
    alpha-blended on top.

    The cube is a subvolume_size^3 window centred on the central coordinate of
    the Excel row whose 'Var1' equals `var1`. Parts of the window that fall
    outside the volume are zero-filled in place, so the central coordinate
    always lands at index subvolume_size // 2 along each spatial axis, even
    next to a volume border. Each Z slice is intensity-normalized
    independently before blending.

    Coordinates from the Excel file are read as (x, y, z) and used directly as
    0-based indices into volumes stacked as (z, y, x).

    Parameters:
        raw_base_dir (str): Directory containing raw image slices.
        seg_base_dir (str): Directory containing segmentation image slices.
        bbox_name (str): Name of the bounding box (sub-directory) to process.
        excel_file (str): Path to the Excel file containing synapse data.
        var1 (str): Var1 identifier of the synapse to process.
        subvolume_size (int): Edge length, in voxels, of the cube.
        alpha (float): Transparency level for the masks (0.0 transparent, 1.0 opaque).

    Returns:
        np.ndarray: uint8 RGB array with both masks overlaid on each slice.
                   Shape: (height, width, channels, depth) =
                   (subvolume_size, subvolume_size, 3, subvolume_size).

    Raises:
        FileNotFoundError: No 'slice_*.tif' files in the raw or segmentation directory.
        IOError: An image file or the Excel file could not be read.
        KeyError: A required coordinate column is missing from the Excel file.
        ValueError: subvolume_size is negative, the slice counts differ, no
            row matches `var1`, a coordinate value is not convertible to int,
            or a side coordinate is outside the volume.
    """
    if subvolume_size < 0:
        raise ValueError("subvolume_size must be non-negative.")

    raw_dir = os.path.join(raw_base_dir, bbox_name)
    seg_dir = os.path.join(seg_base_dir, bbox_name)

    # NOTE: sorted() orders the names lexicographically, so file order matches
    # Z order only if the numeric part of the file names is zero-padded.
    raw_files = sorted(glob.glob(os.path.join(raw_dir, 'slice_*.tif')))
    seg_files = sorted(glob.glob(os.path.join(seg_dir, 'slice_*.tif')))

    if not raw_files:
        raise FileNotFoundError(f"No raw .tif files found in directory: {raw_dir}")
    if not seg_files:
        raise FileNotFoundError(f"No segmentation .tif files found in directory: {seg_dir}")
    if len(raw_files) != len(seg_files):
        raise ValueError("Number of raw and segmentation files do not match.")

    # Stack the slices into (z, y, x) volumes.
    try:
        raw_vol = np.stack([imageio.imread(f) for f in raw_files], axis=0)
        seg_vol = np.stack([imageio.imread(f).astype(np.uint32) for f in seg_files], axis=0)
    except Exception as e:
        raise IOError(f"Error loading image files: {e}") from e

    try:
        synapse_data = pd.read_excel(excel_file)
    except Exception as e:
        raise IOError(f"Error reading Excel file '{excel_file}': {e}") from e

    # Only the first row matching Var1 is used.
    synapse_row = synapse_data[synapse_data['Var1'] == var1]
    if synapse_row.empty:
        raise ValueError(f"No synapse found with Var1: '{var1}' in Excel file '{excel_file}'.")
    syn_info = synapse_row.iloc[0]

    def read_xyz(prefix):
        # Columns are named '<prefix>_coord_1', '<prefix>_coord_2', '<prefix>_coord_3'.
        return tuple(int(syn_info[f'{prefix}_coord_{axis}']) for axis in (1, 2, 3))

    try:
        central_coord = read_xyz('central')
        side1_coord = read_xyz('side_1')
        side2_coord = read_xyz('side_2')
    except KeyError as e:
        raise KeyError(f"Missing column in Excel file: {e}") from e
    except ValueError as e:
        raise ValueError(f"Invalid coordinate value: {e}") from e

    def segment_id_at(coord, side_label):
        # Look up the segment label under an (x, y, z) point, rejecting
        # points that fall outside the segmentation volume.
        x, y, z = coord
        inside = (
            0 <= z < seg_vol.shape[0]
            and 0 <= y < seg_vol.shape[1]
            and 0 <= x < seg_vol.shape[2]
        )
        if not inside:
            raise ValueError(
                f"Error creating segmentation masks: {side_label} coordinates are out of bounds."
            )
        return seg_vol[z, y, x]

    seg_id_1 = segment_id_at(side1_coord, "Side1")
    seg_id_2 = segment_id_at(side2_coord, "Side2")

    cx, cy, cz = central_coord
    half_size = subvolume_size // 2

    def clip_window(start, extent):
        # Intersect the requested window [start, start + subvolume_size) with
        # the valid index range [0, extent). Returns the slice to read from
        # the volume and the slice of the cube it must be written to; both
        # are empty (never negative) when the window misses the volume.
        lo = max(start, 0)
        hi = max(min(start + subvolume_size, extent), lo)
        return slice(lo, hi), slice(lo - start, hi - start)

    src_z, dst_z = clip_window(cz - half_size, raw_vol.shape[0])
    src_y, dst_y = clip_window(cy - half_size, raw_vol.shape[1])
    src_x, dst_x = clip_window(cx - half_size, raw_vol.shape[2])

    # Start from zero-filled cubes and paste the in-bounds voxels at their
    # true offset, so clipping at any face keeps the synapse centred.
    cube_shape = (subvolume_size, subvolume_size, subvolume_size)
    sub_raw = np.zeros(cube_shape, dtype=raw_vol.dtype)
    sub_seg = np.zeros(cube_shape, dtype=seg_vol.dtype)
    sub_raw[dst_z, dst_y, dst_x] = raw_vol[src_z, src_y, src_x]
    sub_seg[dst_z, dst_y, dst_x] = seg_vol[src_z, src_y, src_x]

    def mask_for(seg_id):
        # Label 0 is treated as background: it never produces a mask.
        if seg_id == 0:
            return np.zeros(sub_seg.shape, dtype=bool)
        return sub_seg == seg_id

    # Masks are computed on the cube only; the full-volume comparison is not
    # needed because only this window is ever displayed.
    sub_mask_1 = mask_for(seg_id_1)
    sub_mask_2 = mask_for(seg_id_2)

    side1_color = np.array([1, 0, 0], dtype=np.float32)  # Red
    side2_color = np.array([0, 0, 1], dtype=np.float32)  # Blue

    # Output axes: (height, width, channels, depth).
    overlaid_cube = np.zeros((subvolume_size, subvolume_size, 3, subvolume_size), dtype=np.uint8)

    for z in range(subvolume_size):
        # Normalize this slice to [0, 1] on its own min/max
        # (a constant slice becomes all zeros).
        plane = sub_raw[z].astype(np.float32)
        plane_min, plane_max = plane.min(), plane.max()
        if plane_max != plane_min:
            plane = (plane - plane_min) / (plane_max - plane_min)
        else:
            plane = plane - plane_min

        raw_rgb = np.stack([plane] * 3, axis=-1)

        # Colours add where both masks cover the same pixel.
        mask_rgb = np.zeros_like(raw_rgb)
        mask_rgb[sub_mask_1[z]] += side1_color
        mask_rgb[sub_mask_2[z]] += side2_color

        blended = np.clip((1 - alpha) * raw_rgb + alpha * mask_rgb, 0, 1)
        overlaid_cube[:, :, :, z] = (blended * 255).astype(np.uint8)

    return overlaid_cube