In [1]:
from neuprint import Client, skeleton
from neuprint import fetch_synapses, NeuronCriteria as NC, SynapseCriteria as SC
import numpy as np
from matplotlib.colors import ListedColormap
import importlib
import random
import pickle
import os
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
import time
import tensorstore as ts
from skimage import measure, morphology
import umap
from scipy import stats
from functools import partial
from concurrent.futures import ThreadPoolExecutor, as_completed
import json
import nibabel as nib  # for NIFTI format
import pandas as pd
from skimage import filters
from skimage import io
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
import matplotlib.colors as mcolors
import glob
import re
from scipy.ndimage import gaussian_filter
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from matplotlib.colors import LightSource
from scipy import ndimage
import plotly.graph_objects as go
import plotly.io as pio
import plotly
from time import time
import ast
from skimage.morphology import binary_closing
import napari
from scipy.ndimage import distance_transform_edt
from scipy.spatial.distance import cdist

np.set_printoptions(precision=5, suppress=True)  # print floats with 5 decimals and without scientific notation

def import_module(module_name, file_path):
    """
    Load a Python source file as a module object, without touching sys.path.

    Args:
        module_name (str): Name given to the loaded module (its __name__).
        file_path (str): Path of the source file to execute.

    Returns:
        The executed module object. It is not registered in sys.modules.

    Raises:
        ImportError: If no import spec/loader can be derived from file_path
            (for example, an unsupported file extension).
    """
    # `import importlib` alone does not guarantee that the `util` submodule
    # is loaded, so import it explicitly where it is needed.
    import importlib.util

    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module {module_name!r} from {file_path!r}")

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

# Project root that contains util_files/. Defaults to the original checkout
# location; set the FLY_SEGMENTATION_HOME environment variable to run the
# notebook from a different machine or directory.
home_dir = os.environ.get(
    'FLY_SEGMENTATION_HOME',
    '/Users/aatmikmallya/Desktop/research/fly/segmentation',
)

# Load the project helper modules directly from their source files.
utils = import_module('utils', f'{home_dir}/util_files/utils.py')
config = import_module('config', f'{home_dir}/util_files/config.py')
voxel_utils = import_module('voxel_utils', f'{home_dir}/util_files/voxel_utils.py')
segmentation = import_module('segmentation', f'{home_dir}/util_files/segmentation.py')
analysis = import_module('analysis', f'{home_dir}/util_files/analysis.py')



In [3]:
def create_distance_transform(points, grid_size=100):
    """
    Create an unsigned nearest-point distance transform squashed with tanh.

    Every voxel of a grid_size^3 grid receives tanh(d), where d is the
    Euclidean distance from the voxel's integer coordinate to the closest row
    of `points`. Values are 0 on a point and approach 1 far away; the result
    is never negative.

    Args:
        points: nx3 array-like of coordinates in grid index units
        grid_size: size of cubic output array (default 100)

    Returns:
        grid_size x grid_size x grid_size numpy array with tanh distance transform

    Raises:
        ValueError: if `points` is empty or not of shape (n, 3)
    """
    # Local import: only this function needs the KD-tree.
    from scipy.spatial import cKDTree

    points = np.asarray(points, dtype=float)
    if points.ndim != 2 or points.shape[1] != 3 or points.shape[0] == 0:
        raise ValueError("points must be a non-empty array of shape (n, 3)")

    # Coordinates of every voxel, flattened in C order so that the final
    # reshape puts each distance back at its own voxel.
    axis = np.arange(grid_size)
    x, y, z = np.meshgrid(axis, axis, axis, indexing='ij')
    coords = np.stack([x.ravel(), y.ravel(), z.ravel()], axis=1)

    # A nearest-neighbour query gives the same minimum distances as a dense
    # (grid_size^3 x n) distance matrix, without allocating that matrix.
    min_distances, _ = cKDTree(points).query(coords, k=1)

    # Apply tanh activation and reshape to grid
    return np.tanh(min_distances).reshape(grid_size, grid_size, grid_size)

In [13]:
def create_cube_distance_transform(points, cube_size, grid_size=100, scale=1.0):
    """
    Create a signed distance transform using tanh activation with cubes centered on points.

    A cube of side `cube_size` is rasterised around every point (clipped to
    the grid). Voxels outside all cubes get a positive distance to the nearest
    cube voxel, voxels inside get the negated distance to the nearest outside
    voxel, and the result is squashed with tanh(distance / scale).

    Args:
        points: nx3 numpy array of coordinates; values are truncated to
            integer voxel indices
        cube_size: length of cube sides (must be odd number)
        grid_size: size of cubic output array (default 100)
        scale: positive divisor applied to the distance before tanh
            (the lambda in tanh(d / lambda)); the default 1.0 applies no scaling

    Returns:
        (transform, grid): `transform` is a grid_size^3 float array of tanh
        values, `grid` is the grid_size^3 boolean cube occupancy mask.
        If no cube intersects the grid, `transform` is 1.0 everywhere.

    Raises:
        ValueError: if cube_size is even or scale is not positive
    """
    if cube_size % 2 == 0:
        raise ValueError("cube_size must be odd")
    if scale <= 0:
        raise ValueError("scale must be positive")

    half_size = cube_size // 2

    # Binary occupancy grid of all cubes
    grid = np.zeros((grid_size, grid_size, grid_size), dtype=bool)

    points_idx = np.asarray(points).astype(int)
    for point in points_idx:
        cube = []
        for coord in point[:3]:
            start = max(0, coord - half_size)
            # Clamp the stop index at 0 as well: a cube lying entirely below
            # the grid would otherwise give a negative stop, which numpy
            # treats as "count from the end" and fills most of the axis.
            end = max(0, min(grid_size, coord + half_size + 1))
            cube.append(slice(start, end))
        grid[tuple(cube)] = True

    if not grid.any():
        # Without any cube voxel there is no reference for the outside
        # distance; treat every voxel as infinitely far away (tanh -> 1).
        distances = np.full(grid.shape, np.inf)
    else:
        # Outside voxels: distance to the nearest cube voxel
        pos_dist = distance_transform_edt(~grid)
        # Inside voxels: distance to the nearest outside voxel (0 elsewhere)
        neg_dist = distance_transform_edt(grid)
        # Positive outside cubes, negative inside
        distances = pos_dist - neg_dist

    transform = np.tanh(distances / scale)

    return transform, grid

In [15]:
centerpoint_dir = 'training/labeled_centerpoints'
sdt_dir = 'training/labeled_sdt'

# List the already generated transforms once, instead of re-reading the
# directory for every candidate file.
existing_sdt = set(os.listdir(sdt_dir))

# Build a cube-based distance transform for every labelled centerpoint file
# that does not have one yet.
for file in os.listdir(centerpoint_dir):
    if file.endswith('.npy') and file not in existing_sdt:
        print(file)
        points = np.load(f'{centerpoint_dir}/{file}')
        sdt, grid = create_cube_distance_transform(points, 3)
        # Swap axes 1 and 2 before saving (presumably to match the axis order
        # of the image subvolumes — TODO confirm).
        sdt = np.swapaxes(sdt, 1, 2)
        np.save(f'{sdt_dir}/{file}', sdt)
    

2974912.npy


In [18]:
df = pd.read_csv('training/swc_LC.csv')

size = 100  # Size of the subvolume in voxels
half_size = size // 2

# Snapshot of the subvolumes that already exist, taken once up front rather
# than re-listing the directory for every candidate file.
existing_subvols = set(os.listdir('training/subvols/image/'))

# NOTE(review): relies on `centerpoint_dir` defined in an earlier cell.
for file in os.listdir(centerpoint_dir):
    if file.endswith('.npy') and file not in existing_subvols:
        print(file)
        # The numeric file stem is used as a label into df's index
        name = int(file.split('.')[0])
        row = df.loc[name]

        x, y, z = row[['x', 'y', 'z']]
        bodyId = row.bodyId

        # Bounding box of `size` voxels centred on (x, y, z):
        # first row is the minimum corner, second row the maximum corner
        bbox_xyz = np.array([
            [x - half_size, y - half_size, z - half_size],
            [x + half_size, y + half_size, z + half_size]
        ])

        # Find axis which best aligns with MTs
        # (the result is not used further in this cell)
        axis = utils.find_closest_direction(row.theta, row.phi)[0]
        # Reverse the column order (x, y, z) -> (z, y, x) for the helper
        bbox_zyx = np.flip(bbox_xyz, axis=1).astype(int)

        gray_subvol, mask, mito_mask = utils.get_slice_from_box(bbox_zyx, bodyId)

        np.save(f'training/subvols/image/{name}.npy', gray_subvol)
        np.save(f'training/subvols/cell_mask/{name}.npy', mask)
        np.save(f'training/subvols/mito_mask/{name}.npy', mito_mask)
        

In [29]:
# Load one saved grayscale subvolume and divide by 255
# (assumes 8-bit intensities so the result lies in [0, 1] — TODO confirm).
image = np.load('training/subvols/image/2045541.npy') / 255

In [31]:

# NOTE(review): `sdt` is whatever an earlier cell last left in the kernel;
# confirm it belongs to the same volume as `image` before trusting the overlay.
utils.create_3d_slices_animation([sdt, image],['', ''], 'y')

In [None]:
import os
import numpy as np
from pathlib import Path

def _stamp_disk(slice_mask, y, x, disk, radius):
    """OR the disk footprint into slice_mask centred on (y, x), clipped at the borders."""
    height, width = slice_mask.shape
    # Centers outside the slice are ignored entirely
    if not (0 <= y < height and 0 <= x < width):
        return

    # Destination window inside the slice
    y_start, y_end = max(0, y - radius), min(height, y + radius + 1)
    x_start, x_end = max(0, x - radius), min(width, x + radius + 1)

    # Matching window inside the footprint (shifted when the disk is clipped)
    disk_y = y_start - (y - radius)
    disk_x = x_start - (x - radius)

    slice_mask[y_start:y_end, x_start:x_end] |= disk[
        disk_y:disk_y + (y_end - y_start),
        disk_x:disk_x + (x_end - x_start),
    ]


def create_instance_masks(centerpoints_dir, output_dir, shape=(100, 100, 100), radius=3):
    """
    Create directory of instance masks from centerpoints

    For every ``*.npy`` file in ``centerpoints_dir`` a uint16 volume is built
    in which a filled disk is drawn around each centerpoint. Points are
    grouped by their third column: every distinct value selects one slice and
    that slice receives its own instance id (1, 2, 3, ... in ascending order
    of the slice value). Before saving, axis 0 (the slice axis) is moved to
    position 1.

    Args:
        centerpoints_dir (str): Directory containing .npy files with centerpoints.
            Each file holds an (n, 3) array; columns 0 and 1 index the row and
            column inside a slice, column 2 selects the slice.
        output_dir (str): Directory where mask files will be saved
        shape (tuple): Shape of the mask before the axis move, as
            (slices, rows, cols) (default: (100, 100, 100))
        radius (int): Disk radius in pixels (default: 3, i.e. a 7-pixel-wide disk)

    Returns:
        str: ``output_dir``, unchanged.
    """
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # The footprint is identical for every point, so build it once
    y_offsets, x_offsets = np.ogrid[-radius:radius + 1, -radius:radius + 1]
    disk = x_offsets**2 + y_offsets**2 <= radius**2

    # Materialise the listing first so files written below cannot affect it
    input_files = list(Path(centerpoints_dir).glob('*.npy'))

    for input_file in input_files:
        points = np.load(input_file)
        mask = np.zeros(shape, dtype=np.uint16)

        # A file without any points simply yields an all-background mask
        if points.size:
            instance_id = 1
            for z in np.unique(points[:, 2]):
                z_index = int(z)
                # A negative index would silently wrap to the far end of the
                # volume and a too-large one would raise; skip both.
                if not 0 <= z_index < shape[0]:
                    continue

                slice_mask = np.zeros(shape[1:], dtype=bool)
                for point in points[points[:, 2] == z]:
                    _stamp_disk(slice_mask, int(point[0]), int(point[1]), disk, radius)

                mask[z_index][slice_mask] = instance_id
                instance_id += 1

        # Move the slice axis from position 0 to position 1
        mask = np.rollaxis(mask, 0, 2)

        # Save mask using the same filename as input but in output directory
        output_file = os.path.join(output_dir, input_file.name)
        np.save(output_file, mask)

    return output_dir
    
# First create the instance masks: one output file per labelled centerpoint
# file, written to training/instance_masks under the same file name.
create_instance_masks(
    centerpoints_dir='training/labeled_centerpoints',
    output_dir='training/instance_masks'
)


In [165]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from scipy.ndimage import distance_transform_edt
from skimage.segmentation import find_boundaries


class EMDataset(Dataset):
    """
    Paired EM subvolumes and instance masks served as (volume, label, name).

    Images are read from ``<root_dir>/subvols/image/*.npy`` and masks from
    ``<root_dir>/instance_masks/*.npy``. Files are paired by name. The sorted
    file list is split 90/10: the first 90% form the 'train' split, the rest
    the validation split.
    """

    def __init__(self, root_dir='training', split='train', transform=None):
        """
        Dataset for EM microtubule segmentation

        Args:
            root_dir (str): Path to training directory containing subvols/image and instance_masks
            split (str): 'train' selects the first 90% of the sorted files;
                any other value selects the remaining files (validation)
            transform (callable, optional): Called on the normalized volume in
                __getitem__; its return value replaces the volume

        Raises:
            AssertionError: if the number of images and masks differ
            ValueError: if image and mask file names do not pair up one-to-one
        """
        self.root_dir = Path(root_dir)
        self.transform = transform

        # Get all image and mask paths
        self.image_paths = sorted((self.root_dir / 'subvols' / 'image').glob('*.npy'))
        self.mask_paths = sorted((self.root_dir / 'instance_masks').glob('*.npy'))

        # Verify matching files
        assert len(self.image_paths) == len(self.mask_paths), \
            f"Found {len(self.image_paths)} images but {len(self.mask_paths)} masks"

        # Equal counts are not enough: pairing is positional, so differing
        # file names would silently attach a mask to the wrong image.
        image_stems = [path.stem for path in self.image_paths]
        mask_stems = [path.stem for path in self.mask_paths]
        if image_stems != mask_stems:
            unmatched = sorted(set(image_stems) ^ set(mask_stems))
            raise ValueError(f"Image and mask file names do not match: {unmatched}")

        # Split into train/val
        split_idx = int(0.9 * len(self.image_paths))
        if split == 'train':
            self.image_paths = self.image_paths[:split_idx]
            self.mask_paths = self.mask_paths[:split_idx]
        else:
            self.image_paths = self.image_paths[split_idx:]
            self.mask_paths = self.mask_paths[split_idx:]

        print(f"{split} set: Found {len(self.image_paths)} volumes")

    def create_sam_label(self, instance_mask):
        """
        Convert instance mask to SAM's 4-channel format, accounting for Y-axis alignment

        The mask is processed one 2D slice at a time along axis 1. For every
        non-zero instance id in a slice, a 2D Euclidean distance transform of
        that instance is computed and divided by its maximum.

        Args:
            instance_mask (np.ndarray): 3D integer array, 0 = background.

        Returns:
            np.ndarray: float32 array of shape (4, *instance_mask.shape) with
            channels [foreground, center_distances, boundary_distances, boundaries].
        """
        # Channel 1: Foreground mask
        foreground = (instance_mask > 0).astype(np.float32)

        # Initialize output channels
        center_distances = np.zeros_like(foreground, dtype=np.float32)
        boundary_distances = np.zeros_like(foreground, dtype=np.float32)
        boundaries = np.zeros_like(foreground, dtype=np.float32)

        # Process each Y slice
        for y in range(instance_mask.shape[1]):
            # Get XZ slice
            slice_mask = instance_mask[:, y, :]

            # Process each instance in this slice
            for instance_id in np.unique(slice_mask):
                if instance_id == 0:  # Skip background
                    continue

                # Create binary mask for this instance in this slice
                instance_binary = (slice_mask == instance_id)

                # Compute 2D distance transform for this slice
                # (0 outside the instance, growing towards its interior)
                dist = distance_transform_edt(instance_binary)
                if dist.max() > 0:
                    dist = dist / dist.max()

                # Update distance maps for this slice
                center_distances[:, y, :] = np.maximum(center_distances[:, y, :], dist)
                # NOTE(review): `1 - dist` is 1 everywhere outside this
                # instance, so in any slice containing an instance the
                # background of this channel becomes 1, while slices without
                # instances stay 0 — confirm this is the intended encoding.
                boundary_distances[:, y, :] = np.maximum(boundary_distances[:, y, :], 1 - dist)

                # Compute boundaries for this slice
                slice_boundaries = find_boundaries(instance_binary, mode='thick')
                boundaries[:, y, :] = np.maximum(boundaries[:, y, :], slice_boundaries)

        # Stack all channels
        return np.stack([
            foreground,
            center_distances,
            boundary_distances,
            boundaries
        ])

    def __len__(self):
        """Return the number of volumes in this split."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """
        Load one sample.

        Returns:
            tuple: (volume, label, name) where `volume` is float32 with a
            leading channel axis, `label` is the 4-channel array from
            create_sam_label and `name` is the image file stem.
        """
        # Load volume and mask
        volume = np.load(self.image_paths[idx])
        mask = np.load(self.mask_paths[idx])

        # Divide by 255 (maps to [0, 1] assuming 8-bit intensities)
        volume = volume.astype(np.float32) / 255.0

        # Create SAM label format
        label = self.create_sam_label(mask)

        # Add channel dimension to volume if needed
        if volume.ndim == 3:
            volume = volume[None]

        # Apply the optional image transform (previously stored but unused)
        if self.transform is not None:
            volume = self.transform(volume)

        return volume, label, self.image_paths[idx].stem

def create_dataloaders(root_dir='training', batch_size=1, num_workers=0):
    """
    Create train and validation dataloaders

    Args:
        root_dir (str): Training directory passed on to EMDataset
        batch_size (int): Samples per batch for both loaders
        num_workers (int): Worker processes for both loaders

    Returns:
        tuple: (train_loader, val_loader). The training loader shuffles,
        the validation loader does not.
    """
    shared_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        # Pinned host memory only speeds up host-to-CUDA copies; only request
        # it when a CUDA device is actually available.
        pin_memory=torch.cuda.is_available(),
    )

    train_loader = DataLoader(
        EMDataset(root_dir, split='train'),
        shuffle=True,
        **shared_options
    )

    val_loader = DataLoader(
        EMDataset(root_dir, split='val'),
        shuffle=False,
        **shared_options
    )

    return train_loader, val_loader

def visualize_sample(volume, label, save_path=None):
    """
    Visualize a sample and its labels

    Shows the middle slice along spatial axis 1 of the volume next to the
    same slice of the first four label channels.

    Args:
        volume: 4D array indexed as (channel, a, b, c); channel 0 is shown
        label: 4D array indexed as (channel, a, b, c) with at least 4 channels
        save_path: If given, the figure is written there and closed;
            otherwise it is shown interactively
    """
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    # Take middle slice along Y axis. The labels are sliced along spatial
    # axis 1, so the volume must be cut along that same axis to show the
    # matching plane.
    y_slice = volume.shape[2] // 2

    # Plot original volume
    axes[0, 0].imshow(volume[0, :, y_slice, :], cmap='gray')
    axes[0, 0].set_title('Original Volume')

    # Plot each label channel
    channel_names = ['Foreground', 'Center Distance',
                    'Boundary Distance', 'Boundaries']

    for i, (ax, name) in enumerate(zip(axes.flat[1:], channel_names)):
        ax.imshow(label[i, :, y_slice, :], cmap='viridis')
        ax.set_title(name)

    # Hide ticks/frames on the panels in use and drop the unused ones
    n_used = 1 + len(channel_names)
    for ax in axes.flat[:n_used]:
        ax.axis('off')
    for ax in axes.flat[n_used:]:
        ax.remove()

    # Add overall title
    plt.suptitle(f'Y-axis Slice {y_slice}', y=1.02)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, bbox_inches='tight', dpi=150)
        # Close this specific figure so batch runs do not accumulate figures
        plt.close(fig)
    else:
        plt.show()



def save_processed_dataset(root_dir='training', output_dir='processed_dataset'):
    """
    Process and save the entire dataset in SAM format

    For both splits every sample is loaded through EMDataset and written to
    ``<output_dir>/<split>/volumes`` and ``<output_dir>/<split>/labels``,
    together with a PNG preview in ``<output_dir>/<split>/visualizations``.
    A metadata dict is saved to ``<output_dir>/metadata.npy`` and the result
    is passed to verify_processed_dataset.

    Args:
        root_dir (str): Training directory passed on to EMDataset
        output_dir (str): Destination directory (created if missing)
    """
    output_dir = Path(output_dir)
    splits = ['train', 'val']

    # Create output directories
    for split in splits:
        for subdir in ('volumes', 'labels', 'visualizations'):
            (output_dir / split / subdir).mkdir(parents=True, exist_ok=True)

    # Collected while processing, so the metadata never depends on leaked
    # loop variables and the datasets need not be built a second time.
    split_sizes = {}
    volume_shape = None
    label_shape = None

    # Process each split
    for split in splits:
        print(f"\nProcessing {split} split...")
        dataset = EMDataset(root_dir=root_dir, split=split)
        split_sizes[split] = len(dataset)

        for idx in range(len(dataset)):
            volume, label, name = dataset[idx]
            volume_shape, label_shape = volume.shape, label.shape

            # Save volume and label
            np.save(output_dir / split / 'volumes' / f'{name}.npy', volume)
            np.save(output_dir / split / 'labels' / f'{name}.npy', label)

            # Create and save visualization
            visualize_sample(
                volume, label,
                save_path=output_dir / split / 'visualizations' / f'{name}.png'
            )

            if idx % 10 == 0:
                print(f"Processed {idx + 1}/{len(dataset)} samples")

        print(f"Completed {split} split: {len(dataset)} samples")

    # Save dataset metadata (shapes are those of the last processed sample,
    # or None when no sample exists at all)
    metadata = {
        'n_train': split_sizes['train'],
        'n_val': split_sizes['val'],
        'volume_shape': volume_shape,
        'label_shape': label_shape,
    }

    # Stored as a pickled object array; load with allow_pickle=True
    np.save(output_dir / 'metadata.npy', metadata)
    print(f"\nDataset saved to {output_dir}")
    print(f"Metadata: {metadata}")

    # Verify a few random samples
    verify_processed_dataset(output_dir)

def verify_processed_dataset(processed_dir):
    """
    Verify the saved dataset by loading a few random samples

    For each split one randomly chosen volume and its label are loaded, their
    shapes and value ranges are printed, and the shapes are compared with the
    ones recorded in ``metadata.npy``.

    Args:
        processed_dir: Directory written by save_processed_dataset

    Returns:
        bool: True if every checked sample matches the recorded shapes,
        False if at least one mismatch was reported.
    """
    processed_dir = Path(processed_dir)
    # The metadata dict was pickled by np.save, hence allow_pickle
    metadata = np.load(processed_dir / 'metadata.npy', allow_pickle=True).item()
    expected_volume = metadata['volume_shape']
    expected_label = metadata['label_shape']

    print("\nVerifying processed dataset...")
    print(f"Expected shapes - Volume: {metadata['volume_shape']}, Label: {metadata['label_shape']}")

    all_ok = True
    for split in ['train', 'val']:
        volume_dir = processed_dir / split / 'volumes'
        label_dir = processed_dir / split / 'labels'

        # Get all files
        volume_files = list(volume_dir.glob('*.npy'))

        # Check a random sample
        if volume_files:
            sample_file = np.random.choice(volume_files)
            name = sample_file.stem

            # Load volume and label
            volume = np.load(volume_dir / f'{name}.npy')
            label = np.load(label_dir / f'{name}.npy')

            print(f"\nChecking random {split} sample: {name}")
            print(f"Volume shape: {volume.shape}")
            print(f"Label shape: {label.shape}")
            print(f"Value ranges - Volume: [{volume.min():.3f}, {volume.max():.3f}]")
            print(f"             Label: [{label.min():.3f}, {label.max():.3f}]")

            # Actually compare against the recorded shapes
            volume_ok = expected_volume is not None and tuple(volume.shape) == tuple(expected_volume)
            label_ok = expected_label is not None and tuple(label.shape) == tuple(expected_label)
            if not (volume_ok and label_ok):
                all_ok = False
                print(f"Shape mismatch in {split} sample {name}: "
                      f"volume {volume.shape} vs {expected_volume}, "
                      f"label {label.shape} vs {expected_label}")

    return all_ok

# Runs when this cell/module is executed as the main program (the recorded
# cell output below shows it ran in the notebook kernel).
if __name__ == "__main__":
    # Save the processed dataset
    save_processed_dataset(root_dir='training', output_dir='processed_dataset')

# Example usage
# if __name__ == "__main__":
#     # Create dataloaders
#     train_loader, val_loader = create_dataloaders()
    
#     # Get a sample batch
#     volume_batch, label_batch, names = next(iter(train_loader))  # __getitem__ returns (volume, label, name)
    
#     # Print shapes
#     print(f"Volume batch shape: {volume_batch.shape}")  # Should be [B, 1, 100, 100, 100]
#     print(f"Label batch shape: {label_batch.shape}")    # Should be [B, 4, 100, 100, 100]
    
#     # Visualize a sample
#     visualize_sample(volume_batch[0], label_batch[0])


Processing train split...
train set: Found 2 volumes
Processed 1/2 samples
Completed train split: 2 samples

Processing val split...
val set: Found 1 volumes
Processed 1/1 samples
Completed val split: 1 samples
train set: Found 2 volumes
val set: Found 1 volumes

Dataset saved to processed_dataset
Metadata: {'n_train': 2, 'n_val': 1, 'volume_shape': (1, 100, 100, 100), 'label_shape': (4, 100, 100, 100)}

Verifying processed dataset...
Expected shapes - Volume: (1, 100, 100, 100), Label: (4, 100, 100, 100)

Checking random train sample: 1216641
Volume shape: (1, 100, 100, 100)
Label shape: (4, 100, 100, 100)
Value ranges - Volume: [0.000, 1.000]
             Label: [0.000, 1.000]

Checking random val sample: 6102185
Volume shape: (1, 100, 100, 100)
Label shape: (4, 100, 100, 100)
Value ranges - Volume: [0.000, 1.000]
             Label: [0.000, 1.000]


In [15]:

# Load the labelled centerpoints and the grayscale subvolume of sample 1216641;
# the image is divided by 255 (assumes 8-bit intensities — TODO confirm).
points = np.load('training/labeled_centerpoints/1216641.npy')
image = np.load('training/subvols/image/1216641.npy') / 255

In [11]:
# sdt = create_distance_transform(points)
# Cube-based signed distance transform (cube side 3) of the current `points`;
# `grid` is the boolean cube occupancy mask returned alongside it.
sdt, grid = create_cube_distance_transform(points, 3)


In [20]:
# Show the SDT, the SDT with axes 1 and 2 swapped, and the image side by side
# (the 'y' argument presumably selects the slicing axis — see utils).
utils.create_3d_slices_animation([sdt, np.swapaxes(sdt, 1, 2), image],['', '', ''], 'y')

In [107]:
# NOTE(review): the file name refers to 2045541, but the nearest visible cell
# loads points for 1216641 and the execution counts are out of order —
# confirm which volume this `sdt` belongs to before reusing the file.
np.save('training/labeled_sdt/2045541_y.npy', sdt)