In [2]:
import numpy as np
import matplotlib.pyplot as plt
import tifffile
import os
from patchify import patchify  #Only to handle large images
import random
from scipy import ndimage
from pathlib import Path
import zarr
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from natsort import natsorted
import matplotlib.pyplot as plt
import torch
import cv2
import os
from skimage.measure import label, regionprops

%matplotlib inline

# Root of the ventricle training split; later cells read its images/, labels/
# and points/ sub-folders.
BASE_PATH = Path("/scratch/ventricle_dataset/train")

def compute_iou(boxA, boxB):
    """Return the intersection-over-union of two axis-aligned boxes.

    Both boxes are indexable as ``[x_min, y_min, x_max, y_max]``. When the
    boxes do not overlap (including boxes that only touch) the integer ``0``
    is returned; otherwise the ratio is computed with float division.
    """
    left = max(boxA[0], boxB[0])
    top = max(boxA[1], boxB[1])
    right = min(boxA[2], boxB[2])
    bottom = min(boxA[3], boxB[3])

    overlap_w = max(right - left, 0)
    overlap_h = max(bottom - top, 0)
    intersection = abs(overlap_w * overlap_h)
    if intersection == 0:
        return 0

    def _area(box):
        # abs() tolerates boxes whose corners are given in reversed order.
        return abs((box[2] - box[0]) * (box[3] - box[1]))

    # union = |A| + |B| - |A ∩ B|
    return intersection / float(_area(boxA) + _area(boxB) - intersection)


def filter_bboxes_by_area(bboxes, iou_threshold=0.5):
    """Greedy non-maximum suppression that ranks boxes by area instead of score.

    Boxes are visited from largest to smallest area. Each visited box is kept,
    and every remaining box whose IoU with it is ``>= iou_threshold`` is
    discarded.

    Args:
        bboxes: sequence of ``[x_min, y_min, x_max, y_max]`` boxes (not modified).
        iou_threshold: IoU at or above which a smaller box is suppressed.

    Returns:
        List of the kept boxes (the original objects), largest area first.
        Boxes of equal area keep their input order. Empty input gives ``[]``.
    """
    areas = np.array(
        [(box[2] - box[0]) * (box[3] - box[1]) for box in bboxes], dtype=float
    )
    # A stable sort makes tie-breaking between equal-area boxes deterministic
    # (input order) instead of depending on numpy's default sort algorithm.
    order = np.argsort(-areas, kind="stable")
    remaining = [bboxes[i] for i in order]

    kept = []
    while remaining:
        chosen, *remaining = remaining
        kept.append(chosen)
        # Drop every smaller box that overlaps the chosen one too much.
        remaining = [box for box in remaining if compute_iou(chosen, box) < iou_threshold]
    return kept

def get_bounding_box(ground_truth_map, iou_threshold=0.5, max_jitter=20, rng=None):
    """Return jittered box prompts for every connected component of a mask.

    One candidate box is built per connected component found by ``label`` /
    ``regionprops``. Each side is pushed outward by a random
    ``0..max_jitter-1`` pixels and clamped to the image. Overlapping
    candidates are then suppressed once, over the full candidate set, with
    ``filter_bboxes_by_area``.

    Args:
        ground_truth_map: 2D mask (singleton dimensions are squeezed away).
        iou_threshold: passed to ``filter_bboxes_by_area``.
        max_jitter: exclusive upper bound of the random outward padding per
            side; ``0`` or less disables the jitter.
        rng: object with a ``randint(low, high)`` method, e.g.
            ``np.random.RandomState(seed)`` for reproducible prompts. Defaults
            to the global ``np.random`` state.

    Returns:
        List of ``[x_min, y_min, x_max, y_max]`` float lists, largest area
        first. Empty when the mask has no connected components.
        NOTE(review): callers pass this straight to the SAM processor as
        ``input_boxes``; an empty list there is presumably an error. Confirm
        that empty slices are excluded upstream.
    """
    rng = np.random if rng is None else rng

    def _pad():
        # Random outward padding for one side of a box.
        return rng.randint(0, max_jitter) if max_jitter > 0 else 0

    ground_truth_map = np.squeeze(ground_truth_map)
    H, W = ground_truth_map.shape

    bboxes = []
    for region in regionprops(label(ground_truth_map)):
        min_row, min_col, max_row, max_col = region.bbox
        min_col = max(0, min_col - _pad())
        max_col = min(W, max_col + _pad())
        min_row = max(0, min_row - _pad())
        max_row = min(H, max_row + _pad())
        bboxes.append([float(min_col), float(min_row), float(max_col), float(max_row)])

    # Suppress overlaps once, after all candidates exist. Filtering inside the
    # loop (incrementally) can discard a small box that only overlapped a box
    # which was itself later suppressed.
    bboxes = filter_bboxes_by_area(bboxes, iou_threshold=iou_threshold)
    return [list(b) for b in bboxes]

class NumpyDataset(Dataset):
    """Slice-wise dataset of ``.npy`` image/mask pairs prepared for a SAM processor.

    Items are enumerated from the ``*.npy`` files of ``image_path``. The mask
    for an image is the file with the same name in ``labels_path``. The slice
    index is parsed from the trailing ``_<int>.npy`` of the file name.
    """

    def __init__(self, image_path, labels_path, points, processor, transform=None):
        """
        Args:
            image_path (str or Path): Directory of image slices stored as ``*.npy``.
                Its file names also select the matching masks.
            labels_path (str or Path): Directory of masks with the same file names.
            points (numpy array): Annotated points; column 0 holds the slice index.
                Only used to check that every slice has at least one point.
            processor (callable): Image/prompt processor (e.g. ``SamProcessor``)
                called on each image together with its box prompts.
            transform (callable, optional): Called as
                ``transform(image=..., mask=..., bboxes=...)``. It must return a
                dict whose ``'image'``/``'mask'`` are tensors and whose
                ``'bboxes'`` carry a class label as last field (presumably an
                albumentations pipeline ending in ToTensorV2; confirm).
        """
        self.image_base_path = Path(image_path)
        self.labels_base_path = Path(labels_path)
        # Despite the name these are file *names*, shared by images and masks,
        # natural-sorted so that e.g. ..._9.npy comes before ..._10.npy.
        self.labels_paths = natsorted(
            [p.name for p in self.image_base_path.glob("*.npy")]
        )
        self.points = points.copy()
        self.processor = processor
        self.transform = transform

    def __len__(self):
        return len(self.labels_paths)

    def __getitem__(self, idx):
        file_name = self.labels_paths[idx]
        slice_idx = int(file_name.split('_')[-1].replace('.npy', ''))
        # Points are not fed to the model (box prompts only); this only checks
        # that the slice was annotated.
        if not np.any(self.points[:, 0] == slice_idx):
            raise ValueError(f"Problem getting points from image {file_name} {slice_idx}")

        image = np.load(self.image_base_path / file_name).astype(np.float32)
        # Named `mask` so skimage's module-level `label` is not shadowed here.
        mask = np.load(self.labels_base_path / file_name).astype(np.uint8)
        bbox_prompt = get_bounding_box(mask)

        if self.transform:
            # The transform requires a class label as the last field of each box.
            augmented = self.transform(
                image=image,
                mask=mask,
                bboxes=[b + ['ventricle'] for b in bbox_prompt],
            )
            image = augmented['image'].detach().cpu().numpy()
            mask = augmented['mask'].detach().cpu().numpy()
            bbox_prompt = [list(b[:-1]) for b in augmented['bboxes']]
        else:
            # (H, W) -> (1, H, W): the channel-first layout the transform branch yields.
            image = np.expand_dims(image, axis=0)

        # Channel-last, scaled by 1/255 and replicated to 3 channels for the RGB encoder.
        image = image.transpose((1, 2, 0)) / 255
        image = np.concatenate([image] * 3, axis=-1)

        inputs = self.processor(
            image,
            input_boxes=[bbox_prompt],
            return_tensors="pt",
            input_data_format="channels_last",
            do_rescale=False,  # already divided by 255 above
        )
        # The processor adds a leading batch dimension; drop it per sample.
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        inputs["ground_truth_mask"] = mask
        return inputs

In [3]:
from transformers import SamProcessor

def custom_collate(batch):
    """Transpose a list of per-sample dicts into five parallel lists.

    Returns ``(pixel_values, original_sizes, reshaped_input_sizes, input_boxes,
    ground_truth_mask)``, each a list with one entry per sample. Nothing is
    stacked, so samples may carry different numbers of box prompts.
    """
    keys = (
        'pixel_values',
        'original_sizes',
        'reshaped_input_sizes',
        'input_boxes',
        'ground_truth_mask',
    )
    columns = tuple([] for _ in keys)
    for sample in batch:
        for column, key in zip(columns, keys):
            column.append(sample[key])
    return columns


# Preprocessor matching the SAM ViT-B checkpoint; it produces the
# pixel_values / original_sizes / reshaped_input_sizes / input_boxes entries
# that custom_collate unpacks.
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# NOTE(review): same value as the BASE_PATH set in the first cell; kept so
# this cell also runs on its own.
BASE_PATH = Path("/scratch/ventricle_dataset/train")

# Annotated points. NumpyDataset only uses them to check that each slice has
# at least one point.
points = np.load(BASE_PATH.joinpath('points/smartspim_693196_vs_pts.npy'))
ventricle_dataset = NumpyDataset(
    image_path=BASE_PATH.joinpath('images'),
    labels_path=BASE_PATH.joinpath('labels'),
    points=points,
    processor=processor,
    transform=None, # no augmentations at inference time
)
# shuffle=False keeps batches in the dataset's natsorted file order, which the
# later cells rely on to pair each prediction with its slice index.
# custom_collate returns lists rather than stacked tensors.
dataloader = DataLoader(ventricle_dataset, batch_size=1, shuffle=False, collate_fn=custom_collate)

In [6]:
from transformers import SamModel
from tqdm import tqdm
import torch

# First CUDA GPU when one is available; otherwise fall back to CPU so the
# notebook still runs (slowly) instead of failing on the first .to(device).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Base SAM ViT-B architecture; the fine-tuned ventricle weights are loaded next.
model = SamModel.from_pretrained("facebook/sam-vit-base")

In [7]:
# Fine-tuned SAM weights, consumed as a state_dict by load_state_dict.
# (Alternative checkpoint: "only_decoder/sam_ventricle_40_0.7864770293235779.pt")
# - map_location="cpu": loads even on machines without the GPU the checkpoint
#   was saved from, and avoids a second GPU copy before model.to(device).
# - weights_only=True: a state_dict needs only tensors/containers, so restrict
#   unpickling to those. This presumably also silences the torch.load warning
#   echoed in this cell's output.
ventricle_weights = torch.load(
    "sam_ventricle_2_0.7978167533874512.pt",
    map_location="cpu",
    weights_only=True,
)
model.load_state_dict(ventricle_weights)
model.to(device)
model.eval()

  ventricle_weights = torch.load("sam_ventricle_2_0.7978167533874512.pt")#"only_decoder/sam_ventricle_40_0.7864770293235779.pt")


SamModel(
  (shared_image_embedding): SamPositionalEmbedding()
  (vision_encoder): SamVisionEncoder(
    (patch_embed): SamPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (layers): ModuleList(
      (0-11): 12 x SamVisionLayer(
        (layer_norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): SamVisionAttention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (layer_norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): SamMLPBlock(
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(in_features=3072, out_features=768, bias=True)
          (act): GELUActivation()
        )
      )
    )
    (neck): SamVisionNeck(
      (conv1): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (layer_norm1): SamLayerNorm()
     

In [8]:
# Run the fine-tuned SAM with box prompts over every batch of the dataloader.
# Per batch, collect: the processed inputs, the ground-truth masks, the
# post-processed predicted masks and SAM's predicted-IoU scores. Lists follow
# the dataloader order (shuffle=False).
pbar = tqdm(dataloader)

volume_slices = []  # processed pixel_values per batch (printed in a later cell)
volume_masks = []   # ground-truth masks per batch, uint8
volume_pred = []    # predicted masks per batch, uint8
volume_scores = []  # mean of SAM's predicted iou_scores per batch (not IoU vs. ground truth)

# NOTE(review): volume_slices keeps every model input on the host. At the
# (1, 3, 1024, 1024) shape printed later, and if float32, that is ~12 MB per
# slice (~4.8 GB for 385 slices). Consider dropping it if memory is tight.
with torch.no_grad():

    for batch in pbar:
        # Getting data from batch (lists, one entry per image, from custom_collate)
        pixel_values, original_sizes, reshaped_input_sizes, input_boxes, ground_truth_masks = batch
    
        # Stack the per-image entries into batch tensors.
        pixel_values = torch.stack(pixel_values, dim=0).to(device)#, torch.float16)
        ground_truth_masks = torch.stack(
            [torch.tensor(np_arr, dtype=torch.float32)
             for np_arr in ground_truth_masks],
            dim=0
        )
        original_sizes = torch.stack(original_sizes, dim=0)
        reshaped_input_sizes = torch.stack(reshaped_input_sizes, dim=0)
    
        pred_masks = []
        batch_iou = []
        # One forward pass per image, presumably because images can carry
        # different numbers of box prompts (input_boxes stays a list).
        for img_idx in range(len(input_boxes)):
            # [None, ...] re-adds the batch dimension the model expects.
            pix_val = pixel_values[img_idx][None, ...]
            inp_box = input_boxes[img_idx][None, ...]
            
            sam_pred = model(
                pixel_values=pix_val,
                input_boxes=inp_box.to(device),
                multimask_output=False
            )
            
            # SAM's own quality estimate per predicted mask, not a comparison with the ground truth.
            scores = sam_pred.iou_scores
    
            pred_mask = sam_pred.pred_masks.cpu()
            orig_size = original_sizes[img_idx][None, ...].cpu()
            reshaped_size = reshaped_input_sizes[img_idx][None, ...].cpu()
    
            # Map the low-resolution mask predictions back to the original
            # slice size (binarized by default in transformers; confirm).
            up_pred_mask = processor.image_processor.post_process_masks(
                pred_mask,
                orig_size,
                reshaped_size
            )
    
            # With multimask_output=False SAM yields one mask per box. With
            # several boxes, merge them into a single slice mask by pixel-wise
            # max (union) and average the per-box scores.
            # inp_box.shape[1] is the number of boxes for this image.
            if inp_box.shape[1] != 1:
                up_pred_mask, _ = torch.max(up_pred_mask[0].float(), dim=0)
                scores = scores.detach().cpu().numpy().squeeze().mean()
    
            else:
                up_pred_mask = up_pred_mask[0].float()
                scores = scores.detach().cpu().numpy().squeeze()
                
            # squeeze() drops the remaining singleton dims, leaving one 2D mask.
            pred_masks.append(
                up_pred_mask.squeeze()
            )
            batch_iou.append(
                scores
            )
    
        # GT
        ground_truth_masks = ground_truth_masks.float()
        up_pred_masks = torch.stack(pred_masks, dim=0).detach().cpu().numpy().astype(np.uint8)
        batch_iou = np.array(batch_iou).mean()
    
        volume_slices.append(pixel_values.detach().cpu().numpy())
        volume_masks.append(ground_truth_masks.detach().cpu().numpy().astype(np.uint8))
        volume_pred.append(up_pred_masks.copy())
        volume_scores.append(batch_iou)


100%|██████████| 385/385 [03:20<00:00,  1.92it/s]


In [9]:
# Return PyTorch's cached-but-unused GPU memory after inference (tensors that
# are still referenced, such as the model, are not freed).
torch.cuda.empty_cache()

In [10]:
# Sanity check: number of collected batches and the shape of one batch's model input.
print(len(volume_slices), volume_slices[0].shape)

385 (1, 3, 1024, 1024)


In [11]:
import zarr
from pathlib import Path
from natsort import natsorted

# Downsampled brain volume used as the display background.
zarr_brain_data = zarr.load("/scratch/scaled_693196.zarr")
# Slice index of each prediction, in exactly the order the (unshuffled)
# dataloader produced them. Parse the trailing "_<idx>.npy" of the dataset's
# own file list, with the same parsing as NumpyDataset.__getitem__.
# Re-globbing the folder and sorting the integers separately only matches
# that order when every file shares one name prefix.
mask_indices = np.array(
    [
        int(name.split('_')[-1].replace('.npy', ''))
        for name in ventricle_dataset.labels_paths
    ]
)
print(zarr_brain_data.shape, mask_indices.shape)

(458, 1282, 929) (385,)


In [12]:
# Display the volume slice indices that have a ground-truth mask and a prediction.
mask_indices

array([ 73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,  84,  85,
        86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,  98,
        99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
       112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124,
       125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137,
       138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150,
       151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163,
       164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176,
       177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189,
       190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202,
       203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215,
       216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228,
       229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241,
       242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 25

In [13]:
# Stack the per-batch ground-truth masks into one (n_slices, H, W) array.
volume_masks_plot = np.concatenate(volume_masks, axis=0)
print(volume_masks_plot.shape)

(385, 1024, 1024)


In [14]:
from scipy.ndimage import zoom

# In-plane resize factors from the model's mask resolution to the brain
# volume's (rows, cols); the slice axis keeps factor 1.
target_rows, target_cols = zarr_brain_data.shape[-2:]
source_rows, source_cols = volume_masks_plot.shape[-2:]
scaling_factor = (1, target_rows / source_rows, target_cols / source_cols)

# order=0 (nearest neighbour) keeps the masks binary.
volume_masks_plot = zoom(volume_masks_plot, scaling_factor, order=0)
print("Mask reshaped: ", volume_masks_plot.shape)

Mask reshaped:  (385, 1282, 929)


In [15]:
# Stack the per-batch predicted masks into one (n_slices, H, W) array.
volume_pred_plot = np.concatenate(volume_pred, axis=0)
print(volume_pred_plot.shape)

(385, 1024, 1024)


In [16]:
# Reuse the ground-truth scaling_factor. Valid because predictions and masks
# share the same (n_slices, H, W) shape (see the printed shapes above).
volume_pred_plot = zoom(volume_pred_plot, scaling_factor, order=0)
print("pred reshaped: ", volume_pred_plot.shape)

pred reshaped:  (385, 1282, 929)


In [17]:
# Per-slice RGB overlay: red = ground truth, green = prediction, so agreement
# shows as yellow. uint8 0/255 instead of the default float64 cuts memory ~8x
# (385 x 1282 x 929 x 3 is ~11 GB as float64, ~1.4 GB as uint8). imshow reads
# uint8 RGB on a 0-255 scale, so 255 renders like the (clipped) float 1.0 did.
combined_mask = np.zeros(volume_pred_plot.shape + (3, ), dtype=np.uint8)
print(combined_mask.shape)
combined_mask[..., 0][volume_masks_plot > 0] = 255  # RED
combined_mask[..., 1][volume_pred_plot > 0] = 255  # GREEN

(385, 1282, 929, 3)


In [18]:
# volume_slices_plot = np.expand_dims(
#     np.concatenate(
#         volume_slices, axis=0
#     )[:, 0, :, ...],
#     axis=-1
# )
# # volume_slices_plot = (
# #     volume_slices_plot - np.min(volume_slices_plot)
# # ) / (np.max(volume_slices_plot) - np.min(volume_slices_plot))

# # volume_slices_plot *= 255
# print(volume_slices_plot.shape)

In [20]:
# One score per batch (SAM's mean predicted IoU), aligned with mask_indices.
volume_scores_plot = np.array(volume_scores)
print(volume_scores_plot.shape)

(385,)


In [21]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import interact

def show_mask(mask, ax, obj_id=None, random_color=False):
    """Draw *mask* on *ax* as a single-colour RGBA overlay with alpha 0.6.

    The colour is random when ``random_color`` is set. Otherwise it is entry
    ``obj_id`` (default 0) of matplotlib's ``tab10`` palette. *mask* is
    reshaped to its last two dimensions, ``(h, w)``.
    """
    if not random_color:
        palette = plt.get_cmap("tab10")
        rgb = palette(obj_id if obj_id is not None else 0)[:3]
        color = np.array([*rgb, 0.6])
    else:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)

    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * color.reshape(1, 1, -1)
    ax.imshow(overlay)

def show_points(coords, labels, ax, marker_size=200):
    """Scatter point prompts on *ax* as star markers.

    Points with label 1 are drawn green and points with label 0 red, each with
    a white edge. *coords* is indexed as ``(n, 2)`` with x in column 0 and y
    in column 1.
    """
    marker_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    for label_value, colour in ((1, 'green'), (0, 'red')):
        chosen = coords[labels == label_value]
        ax.scatter(chosen[:, 0], chosen[:, 1], color=colour, **marker_style)

def show_box(box, ax):
    """Outline *box* ``[x0, y0, x1, y1]`` on *ax* as an unfilled green rectangle."""
    x_start, y_start, x_end, y_end = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x_start, y_start),
        x_end - x_start,
        y_end - y_start,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)

# Function to plot a specific slice
def plot_slice(volume, combined_mask, scores, mask_indices, slice_idx, cmap='gray',
               vmin=0, vmax=5):
    """Show one slice of *volume*, with its RGB mask overlay when one exists.

    Args:
        volume: array indexed as ``volume[slice_idx, :, :]``.
        combined_mask: per-evaluated-slice RGB overlays; entry ``i`` belongs to
            volume slice ``mask_indices[i]``.
        scores: per-evaluated-slice values printed in the title as "IoU". In
            this notebook they are SAM's predicted iou_scores, not IoU against
            the ground truth.
        mask_indices: 1D array mapping overlay entries to volume slice indices.
        slice_idx: index along axis 0 of *volume* to display.
        cmap: colormap for the background slice.
        vmin, vmax: intensity window for the background slice (default 0-5).
    """
    plt.figure(figsize=(10, 10))
    plt.imshow(volume[slice_idx, :, :], cmap=cmap, vmin=vmin, vmax=vmax)

    msg = f"Slice {slice_idx}"
    # Only slices that were evaluated have an overlay and a score.
    matches = np.where(mask_indices == slice_idx)[0]
    if matches.shape[0]:
        entry = matches[0]
        plt.imshow(combined_mask[entry], alpha=0.5)
        msg += f" - IoU: {scores[entry]}"

    plt.title(msg)
    plt.axis('off')
    plt.show()

# Interactive function for controlling slice selection
def interactive_plot(volume, combined_mask, scores, mask_indices, axis=0, cmap='gray'):
    """Attach an ipywidgets slider that redraws ``plot_slice`` for the chosen slice.

    The slider spans ``0 .. volume.shape[axis] - 1``.
    NOTE(review): ``axis`` only sets the slider range; ``plot_slice`` always
    slices axis 0, so values other than 0 do not change the viewing axis.
    """
    last_slice = volume.shape[axis] - 1

    def update(slice_idx):
        plot_slice(volume, combined_mask, scores, mask_indices, slice_idx, cmap)

    slider = widgets.IntSlider(min=0, max=last_slice, step=1, value=0, description='Slice')
    interact(update, slice_idx=slider)

# Browse the full brain volume. Evaluated slices are overlaid with
# ground truth (red) and prediction (green), and titled with SAM's predicted IoU.
interactive_plot(
    zarr_brain_data,
    combined_mask,
    volume_scores_plot,
    mask_indices
)

interactive(children=(IntSlider(value=0, description='Slice', max=457), Output()), _dom_classes=('widget-inter…