### Requirements

- numpy
- scikit-learn
- pytorch
- torchvision
- pillow
- opencv-python
- h5py
- tqdm


In [None]:
import sys

# Make the local SAM and utils packages importable from this notebook.
sys.path.extend(["./SAM", "./utils"])

In [None]:
import pickle
from pathlib import Path

import h5py
import numpy as np
import torch
from PIL import Image, ImageSequence
from sklearn.ensemble import RandomForestClassifier

from tqdm.notebook import trange, tqdm

import SAM
from SAM.models import LightHQSAM
from utils.data import (
    get_stack_sizes,
    get_num_target_patches
)
from utils.extract import (
    get_patch_sizes,
    get_sam_embeddings_for_slice
)
from utils.postprocess import postprocess_segmentation
from utils.postprocess_with_sam import postprocess_segmentations_with_sam

### Set the Input, Random Forest Model, and Result Directory Paths

In [None]:
# input image: a multi-page TIFF stack
data_path = "../data/Stack02_819_3598_cor_TM1corb_cr/Substack.tif"
data_path = Path(data_path)
print(f"data_path exists: {data_path.exists()}")

# random forest model
rf_model_path = "../data/Stack02_819_3598_cor_TM1corb_cr/rf_model_1.bin"
rf_model_path = Path(rf_model_path)
print(f"rf_model_path exists: {rf_model_path.exists()}")

# result folder, created next to the input stack. data_path points at a *file*,
# so the folder must live under its parent directory; joining onto the file
# path itself would try to create a directory inside "Substack.tif", which
# fails whenever that file exists.
segmentation_dir = data_path.parent.joinpath("segmentation_results")
segmentation_dir.mkdir(parents=True, exist_ok=True)

# temporary storage path for saving extracted embeddings patches
storage_path = "./temp_storage.hdf5"

### Initializing the SAM Model

In [None]:
# Prefer the GPU when one is available; otherwise fall back to the CPU.
_cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_available else torch.device("cpu")
print(f"running on {device}")

In [None]:
# Build the Light HQ-SAM model on the selected device and restore its
# pretrained weights; strict=True requires every state-dict key to match.
sam_model = LightHQSAM.setup()
sam_model = sam_model.to(device)
weights = torch.load("./SAM/models/weights/sam_hq_vit_tiny.pth", map_location=device)
sam_model.load_state_dict(weights, strict=True)
# inference only: disable training-time behaviour (e.g. dropout)
sam_model.eval()

# the image encoder is what get_sam_embeddings_for_slice uses during prediction
sam_encoder = sam_model.image_encoder
print(sam_encoder)

### Utility functions

In [None]:
def predict_slice(rf_model, patch_dataset, img_height, img_width, patch_size, target_patch_size):
    """Classify every feature patch of one slice and stitch the results together.

    Each patch's per-pixel feature vectors are classified independently by the
    random forest. The patch predictions are then tiled back into one 2-D label
    image, and the padding is cropped away.

    Args:
        rf_model: fitted classifier exposing ``predict``.
        patch_dataset: array-like of feature patches with shape
            N x target_patch_size x target_patch_size x C, where C equals
            SAM.ENCODER_OUT_CHANNELS + SAM.EMBED_PATCH_CHANNELS.
        img_height, img_width: size of the original slice in pixels.
        patch_size: input patch size, used to compute the patch grid.
        target_patch_size: side length of each prediction patch.

    Returns:
        uint8 label image cropped to (img_height, img_width). This assumes the
        patch grid covers the whole slice.
    """
    n_features = SAM.ENCODER_OUT_CHANNELS + SAM.EMBED_PATCH_CHANNELS
    # materialize the whole dataset once (e.g. an HDF5 dataset -> ndarray)
    all_patches = patch_dataset[:]

    per_patch_labels = [
        rf_model.predict(all_patches[idx].reshape(-1, n_features)).astype(np.uint8)
        for idx in tqdm(
            range(all_patches.shape[0]),
            desc="Predicting slice patches", position=1, leave=True
        )
    ]
    labels = np.vstack(per_patch_labels)

    grid_rows, grid_cols = get_num_target_patches(
        img_height, img_width, patch_size, target_patch_size
    )
    # (rows, cols, t, t) -> (rows, t, cols, t) so that a plain reshape lays the
    # patches out row-major as one padded image
    tiled = labels.reshape(grid_rows, grid_cols, target_patch_size, target_patch_size)
    tiled = tiled.swapaxes(1, 2).reshape(
        grid_rows * target_patch_size, grid_cols * target_patch_size
    )
    # drop the padding beyond the original slice size
    return tiled[:img_height, :img_width]


def postprocess(segmentation_image, area_threshold, use_sam=False, sam_model=None):
    """Post-process a predicted segmentation, optionally refining it with SAM.

    Args:
        segmentation_image: label image, e.g. as returned by ``predict_slice``.
        area_threshold: threshold given in percent (e.g. 25). It is divided by
            100 before being passed to the post-processing helper.
        use_sam: when True, dispatch to ``postprocess_segmentations_with_sam``
            using ``sam_model``; otherwise use ``postprocess_segmentation``.
        sam_model: SAM model; only used when ``use_sam`` is True.

    Returns:
        Whatever the selected post-processing helper returns.
    """
    threshold_fraction = area_threshold / 100
    if not use_sam:
        return postprocess_segmentation(segmentation_image, threshold_fraction)
    return postprocess_segmentations_with_sam(
        sam_model, segmentation_image, threshold_fraction
    )

### Prepare the Input and Temporary Storage

In [None]:
# Read the stack geometry, then derive the patch sizes from it.
# PIL opens images lazily and keeps the file handle open; the context manager
# releases it once the metadata has been read.
with Image.open(data_path) as input_stack:
    num_slices = input_stack.n_frames
    img_height = input_stack.height
    img_width = input_stack.width

patch_size, target_patch_size = get_patch_sizes(img_height, img_width)

print(num_slices, img_height, img_width)
print(patch_size, target_patch_size)

In [None]:
# Load the trained random-forest classifier and turn off its verbose logging.
# NOTE(review): unpickling can execute arbitrary code — only load model files
# that come from a trusted source.
rf_model = pickle.loads(rf_model_path.read_bytes())
rf_model.set_params(verbose=0)

rf_model

In [None]:
# Temporary HDF5 file for the embedding patches extracted per slice.
# Mode "w" creates the file and truncates any leftover from a previous run.
storage = h5py.File(storage_path, mode="w")
storage_group = storage.create_group(name="slice")

In [None]:
# post-processing parameters
area_threshold = 25  # percent; postprocess() divides it by 100
post_use_sam = False  # True: refine with SAM instead of the plain post-processing
do_postprocess = True  # whether to post-process each predicted slice

### Prediction

In [None]:
# Predict every slice of the input stack and save one segmentation TIFF per slice.
# The input file is closed when the loop ends. The temporary embeddings file is
# closed and removed even if a slice fails part-way through.
try:
    with Image.open(data_path) as tiff_img:
        for i, page in tqdm(
            enumerate(ImageSequence.Iterator(tiff_img)),
            desc="Slices", total=num_slices, position=0
        ):
            # convert the page to 8-bit grayscale ("L" mode)
            slice_img = np.array(page.convert("L"))

            # presumably fills storage_group["sam"] with this slice's embedding
            # patches, which predict_slice reads next — TODO confirm in utils.extract
            get_sam_embeddings_for_slice(
                slice_img, patch_size, target_patch_size,
                sam_encoder, device, storage_group
            )

            segmentation_image = predict_slice(
                rf_model, storage_group["sam"],
                img_height, img_width,
                patch_size, target_patch_size
            )

            if do_postprocess:
                segmentation_image = postprocess(
                    segmentation_image, area_threshold,
                    post_use_sam, sam_model
                )

            # save result
            img = Image.fromarray(segmentation_image)
            img.save(segmentation_dir.joinpath(f"slice_{i}.tiff"))
finally:
    if storage is not None:
        storage.close()
    # tolerate an already-removed file so cleanup itself never raises
    try:
        Path(storage_path).unlink()
    except FileNotFoundError:
        pass

In [None]:
# Fallback cleanup, e.g. after an interrupted prediction run: close the HDF5
# storage and remove the temporary file. The prediction cell normally deletes
# the file itself, so a missing file is expected here and is not an error.
if storage is not None:
    storage.close()
try:
    Path(storage_path).unlink()
except FileNotFoundError:
    pass