In [1]:
import numpy as np
import torch
import cv2
import os
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import zarr
from pathlib import Path
import numpy as np

# Root directory that holds the input zarr volumes.
BASE_PATH = Path("/scratch/")

# Segmentation mask and image volume, both tagged 693196 in their file names.
mask_path = BASE_PATH / "693196-VS_73_sasha_smoothed.zarr"
img_path = BASE_PATH / "scaled_693196.zarr"

# Load both volumes as arrays; the mask is cast to uint8.
mask = zarr.load(mask_path).astype(np.uint8)
img = zarr.load(img_path)

# The point/bbox arrays (sam2_points_693196.npy, sam2_bboxs_693196.npy) are
# not loaded in this cell.
print(img.shape, mask.shape, img.dtype, mask.dtype)

  OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()


(458, 1282, 929) (458, 1282, 929) float16 uint8


In [2]:
from scipy.ndimage import zoom, distance_transform_edt
import matplotlib.pyplot as plt
from pathlib import Path
from skimage.measure import regionprops, label
import heapq
from scipy.spatial.distance import cdist

def compute_dt_per_slice(mask):
    """
    Compute a Euclidean distance transform separately for every slice of a mask.

    Parameters:
    - mask: A NumPy array iterated along its first axis; each slice is transformed on its own.

    Returns:
    - A float32 NumPy array with the same shape as `mask`. Slices that contain at least one
      non-zero value hold their distance transform; all other slices keep the (all-zero)
      mask values cast to float32.
    """
    # astype already returns a fresh array, so the input mask is never modified
    per_slice_dt = mask.astype(np.float32)

    for slice_idx, slice_mask in enumerate(mask):
        # Slices without foreground are skipped and stay as they were cast
        if not np.any(slice_mask):
            continue
        per_slice_dt[slice_idx] = distance_transform_edt(slice_mask).astype(np.float32)

    return per_slice_dt

%matplotlib inline

def get_points_in_distance(points, d):
    """
    Greedily sample points such that each selected point is at least distance `d` away
    from all previously selected points.

    Points are visited in the order given, so earlier points have priority. The first
    point is always selected. A point exactly `d` away from the selected ones is accepted.

    Parameters:
    - points: A NumPy array (or array-like) of shape (n, m), where n is the number of points
      and m is the dimensionality.
    - d: Minimum distance between selected points.

    Returns:
    - A NumPy array of the selected points, in selection order. An empty input yields an
      empty array instead of raising.
    """
    # Convert points to a NumPy array if not already
    points = np.asarray(points)

    # Nothing to sample from: return an empty copy rather than failing on points[0]
    if points.shape[0] == 0:
        return points.copy()

    # The first point seeds the selection unconditionally
    selected_points = [points[0]]

    for point in points[1:]:
        # Distances between the candidate and every point selected so far
        distances = cdist([point], selected_points)

        # Keep the candidate only if it is far enough from all selected points
        if np.all(distances >= d):
            selected_points.append(point)

    return np.array(selected_points)

def sort_points_by_heatmap(heatmap, points):
    """
    Order points from the highest to the lowest heatmap value.

    Parameters:
    - heatmap: A 2D NumPy array representing the heatmap.
    - points: A NumPy array of shape (n, 2), where each row is a point with (row_index, column_index).

    Returns:
    - A tuple `(sorted_points, sorted_values)`: the points reordered by descending heatmap
      value, and the heatmap values at those points in the same order.
    """
    # Look up the heatmap value under every point
    rows = points[:, 0]
    cols = points[:, 1]
    values_at_points = heatmap[rows, cols]

    # argsort is ascending, so the index order is reversed to get largest first
    descending_order = np.argsort(values_at_points)[::-1]

    return points[descending_order], values_at_points[descending_order]

def get_sparsed_points(mask, distance_transform, n_points=5, min_point_distance=10.0, min_slice_pixels=80):
    """
    Pick a sparse set of points inside every connected region of every mask slice.

    For each slice with more than `min_slice_pixels` non-zero pixels, the slice is split into
    connected regions. Within each region the pixels are ranked by their distance-transform
    value (largest first), thinned so that picked points are at least `min_point_distance`
    apart, and the first `n_points` of them are kept.

    Parameters:
    - mask: A NumPy array iterated along its first axis, one 2D mask per slice.
    - distance_transform: Array indexed like `mask`; `distance_transform[i]` is the heatmap
      used to rank the pixels of slice `i`.
    - n_points: Maximum number of points kept per region.
    - min_point_distance: Minimum distance between two points picked in the same region.
    - min_slice_pixels: Slices with this many non-zero pixels or fewer are skipped.

    Returns:
    - A NumPy array of unique rows (slice_index, row_index, column_index). `np.unique` returns
      the rows sorted, so the ranking order is not preserved. When no slice qualifies, an
      empty array of shape (0, 3) is returned instead of raising.
    """
    points = []

    for mask_index, slice_mask in enumerate(mask):
        # Skip slices with too little foreground
        if np.count_nonzero(slice_mask) <= min_slice_pixels:
            continue

        labeled_slice_mask, num_regions = label(slice_mask, return_num=True)

        for region_label in range(1, num_regions + 1):
            region_coords = np.column_stack(np.where(labeled_slice_mask == region_label))

            # Rank the region's pixels by heatmap value; the sorted values are not needed
            region_coords, _ = sort_points_by_heatmap(
                distance_transform[mask_index],
                region_coords
            )

            picked_points = get_points_in_distance(region_coords, d=min_point_distance)

            # Keep the top-ranked points and prepend the slice index as first column
            points.append(np.insert(picked_points[:n_points], 0, mask_index, axis=1))

    # np.concatenate rejects an empty list, so the no-points case is handled explicitly
    if not points:
        return np.empty((0, 3), dtype=np.intp)

    return np.unique(np.concatenate(points, axis=0), axis=0)

In [3]:
def _pick_region_points(mask_2d, n_points, min_point_distance):
    """Pick up to `n_points` spaced-out points per connected region, ranked by distance transform."""
    # Labeling binary areas
    labeled_mask, num_regions = label(mask_2d, return_num=True)
    mask_dt = distance_transform_edt(mask_2d).astype(np.float32)

    picked = []
    for region_label in range(1, num_regions + 1):
        region_coords = np.column_stack(np.where(labeled_mask == region_label))

        # Largest distance-transform values first; the sorted values are not needed
        region_coords, _ = sort_points_by_heatmap(mask_dt, region_coords)

        region_points = get_points_in_distance(region_coords, d=min_point_distance)
        picked.append(region_points[:n_points])

    # np.concatenate rejects an empty list, so a mask without regions yields no rows
    if not picked:
        return np.empty((0, 2), dtype=np.uint32)

    return np.concatenate(picked, axis=0).astype(np.uint32)


def create_slices(
    brain_id,
    image_data,
    mask_data,
    output_images,
    output_masks,
    output_points,
    n_points=5,
    min_point_distance=10.0,
    target_size=1024,
    min_mask_pixels=100,
):
    """
    Export resized image/mask slices and the points picked inside each mask.

    Every slice whose mask has more than `min_mask_pixels` non-zero pixels is resized to
    `target_size` x `target_size`. The image is interpolated with order 3, min-max scaled to
    [0, 255] and stored as float16; the mask is resized with order 0 and stored as uint8.
    Points are picked per connected region of the resized mask.

    Parameters:
    - brain_id: Identifier embedded in the output file names.
    - image_data: Image volume indexed by slice along its first axis.
    - mask_data: Mask volume with a `shape` attribute, indexed like `image_data`.
    - output_images: Directory for `smartspim_{brain_id}_vs_{slice}.npy` image files.
    - output_masks: Directory for the mask files, named like the image files.
    - output_points: Directory for the single `smartspim_{brain_id}_vs_pts.npy` point file.
    - n_points: Maximum number of points kept per connected region.
    - min_point_distance: Minimum distance between two points picked in the same region.
    - target_size: Height and width of the exported slices.
    - min_mask_pixels: Slices with this many non-zero mask pixels or fewer are skipped.

    Returns:
    - The list of slice indices that were exported.

    Raises:
    - ValueError: If a qualifying slice ends up with no points.

    The point file holds uint16 rows (slice_index, row_index, column_index) in resized
    coordinates, so all values must fit in 16 bits. When no slice qualifies, an empty
    (0, 3) array is saved instead of raising.
    """
    output_images = Path(output_images)
    output_masks = Path(output_masks)
    output_points = Path(output_points)

    # exist_ok makes a separate existence check unnecessary
    for out_dir in (output_images, output_masks, output_points):
        out_dir.mkdir(parents=True, exist_ok=True)

    points = []
    generated_idxs = []

    for mask_idx in range(mask_data.shape[0]):
        curr_mask = mask_data[mask_idx]

        if np.count_nonzero(curr_mask) <= min_mask_pixels:
            continue

        print(f"Generating slice for slice: {mask_idx}")
        curr_slice = image_data[mask_idx].astype(np.float32)
        scaling_factor = (
            target_size / curr_slice.shape[0],
            target_size / curr_slice.shape[1],
        )

        curr_slice = zoom(curr_slice, scaling_factor, order=3)
        min_val = np.min(curr_slice)
        value_range = np.max(curr_slice) - min_val

        # A constant slice has zero range; dividing by it would fill the image with NaN
        if value_range == 0:
            curr_slice = np.zeros_like(curr_slice)
        else:
            curr_slice = (curr_slice - min_val) / value_range
        curr_slice = (curr_slice * 255).astype(np.float16)

        # Creating mask (order=0 keeps the original mask values)
        curr_mask = zoom(curr_mask, scaling_factor, order=0).astype(np.uint8)

        # Going through regions to get the points based on edt
        mask_points = _pick_region_points(curr_mask, n_points, min_point_distance)
        mask_points = np.insert(mask_points, 0, mask_idx, axis=1)

        if not mask_points.shape[0]:
            print(
                f"Not saving slice {mask_idx} since there are no points: {mask_points.shape}"
            )
            raise ValueError(f"No points could be picked for slice {mask_idx}")

        # Image and mask share the same file name in their respective folders
        file_name = f"smartspim_{brain_id}_vs_{mask_idx}.npy"
        np.save(str(output_images.joinpath(file_name)), curr_slice)
        np.save(str(output_masks.joinpath(file_name)), curr_mask)

        points.append(mask_points)
        generated_idxs.append(mask_idx)

    if points:
        points = np.concatenate(points, axis=0).astype(np.uint16)
    else:
        # np.concatenate rejects an empty list; save an empty table in the same layout
        print("No slice exceeded the mask pixel threshold; saving an empty point table")
        points = np.empty((0, 3), dtype=np.uint16)

    np.save(
        str(output_points.joinpath(f"smartspim_{brain_id}_vs_pts.npy")),
        points,
    )
    return generated_idxs
    


In [4]:
# Root of the exported dataset; the three outputs go into sub-folders of train/.
dataset_path = Path("/scratch/ventricle_dataset")

export_dirs = {
    "output_images": dataset_path / "train/images",
    "output_masks": dataset_path / "train/labels",
    "output_points": dataset_path / "train/points",
}

# Export the loaded volume with at most one point per connected region;
# `idxs` receives the slice indices that were written.
idxs = create_slices(
    brain_id="693196",
    image_data=img,
    mask_data=mask,
    n_points=1,
    min_point_distance=10.0,
    **export_dirs,
)

Generating slice for slice: 73
Generating slice for slice: 74
Generating slice for slice: 75
Generating slice for slice: 76
Generating slice for slice: 77
Generating slice for slice: 78
Generating slice for slice: 79
Generating slice for slice: 80
Generating slice for slice: 81
Generating slice for slice: 82
Generating slice for slice: 83
Generating slice for slice: 84
Generating slice for slice: 85
Generating slice for slice: 86
Generating slice for slice: 87
Generating slice for slice: 88
Generating slice for slice: 89
Generating slice for slice: 90
Generating slice for slice: 91
Generating slice for slice: 92
Generating slice for slice: 93
Generating slice for slice: 94
Generating slice for slice: 95
Generating slice for slice: 96
Generating slice for slice: 97
Generating slice for slice: 98
Generating slice for slice: 99
Generating slice for slice: 100
Generating slice for slice: 101
Generating slice for slice: 102
Generating slice for slice: 103
Generating slice for slice: 104
Gen

In [5]:
import numpy as np
# BASE_PATH_POINTS = Path("/scratch/ventricle_dataset/train/points")
# Reload the saved point table and report, for every exported slice index,
# the shape of the rows whose first column matches that index.
points = np.load("/scratch/ventricle_dataset/train/points/smartspim_693196_vs_pts.npy")

for i in idxs:
    c = points[points[:, 0] == i]
    if c.shape[0]:
        print(f"Slice {i} has {c.shape}")
    else:
        # An exported slice without any row in the table
        print(f"Problem with {i}")
    

Slice 73 has (1, 3)
Slice 74 has (1, 3)
Slice 75 has (1, 3)
Slice 76 has (1, 3)
Slice 77 has (1, 3)
Slice 78 has (1, 3)
Slice 79 has (1, 3)
Slice 80 has (1, 3)
Slice 81 has (1, 3)
Slice 82 has (1, 3)
Slice 83 has (1, 3)
Slice 84 has (1, 3)
Slice 85 has (1, 3)
Slice 86 has (1, 3)
Slice 87 has (1, 3)
Slice 88 has (1, 3)
Slice 89 has (1, 3)
Slice 90 has (1, 3)
Slice 91 has (1, 3)
Slice 92 has (1, 3)
Slice 93 has (1, 3)
Slice 94 has (1, 3)
Slice 95 has (1, 3)
Slice 96 has (1, 3)
Slice 97 has (1, 3)
Slice 98 has (1, 3)
Slice 99 has (1, 3)
Slice 100 has (1, 3)
Slice 101 has (1, 3)
Slice 102 has (3, 3)
Slice 103 has (2, 3)
Slice 104 has (2, 3)
Slice 105 has (1, 3)
Slice 106 has (2, 3)
Slice 107 has (2, 3)
Slice 108 has (2, 3)
Slice 109 has (2, 3)
Slice 110 has (2, 3)
Slice 111 has (2, 3)
Slice 112 has (2, 3)
Slice 113 has (3, 3)
Slice 114 has (2, 3)
Slice 115 has (4, 3)
Slice 116 has (5, 3)
Slice 117 has (5, 3)
Slice 118 has (3, 3)
Slice 119 has (3, 3)
Slice 120 has (3, 3)
Slice 121 has (3, 3

In [6]:
# Shape of the loaded point table as (number of rows, values per row).
points.shape

(2596, 3)

array([], shape=(0, 3), dtype=uint16)