![TomoSAM logo](https://github.com/fsemerar/SlicerTomoSAM/raw/main/TomoSAM/Resources/Media/tomosam_logo.png)

This notebook generates the SAM image embeddings for every slice of your TIFF stack along the three Cartesian directions. It saves them to a `.pkl` file with the same base name as the stack. You can run it either locally or on Colab. A GPU is recommended to speed up this step; in Colab, select `Runtime`→`Change runtime type` and set the `Hardware accelerator` to GPU. Locally, you first need to create the conda environment as shown in the README. Then skip the Colab-only cells (those that use `google.colab`) and set `img_filename` to the path of your stack.

In [None]:
# Colab-only setup: enable Colab's custom widget manager and install the
# segment-anything package. Skip this cell when running locally, where the
# conda environment from the README provides the dependencies.
from google.colab import output
output.enable_custom_widget_manager()
!pip install segment-anything

In [1]:
# Download the ViT-H SAM checkpoint (the file passed to create_embeddings below)
# into the working directory, skipping the download if it is already present.
![ ! -f "sam_vit_h_4b8939.pth" ] && wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

In [None]:
import cv2
import numpy as np
from segment_anything import sam_model_registry, SamPredictor
import torch
import sys, os
import pickle
def create_embeddings(img_input_filepath, output_filepath, sam_checkpoint_path):
    """Compute SAM image embeddings for every slice of an image stack along all three axes.

    The stack is read with ``cv2.imreadmulti``. Each 2D slice is replicated into
    three channels and passed through a ViT-H ``SamPredictor``. The per-slice
    results are pickled to ``output_filepath + ".pkl"`` as a list of three
    lists, one per slicing axis in the order 'x', 'y', 'z'. Each inner list holds
    one dict per slice with the keys ``'original_size'``, ``'input_size'`` and
    ``'features'`` (the predictor features moved to the CPU).

    Args:
        img_input_filepath: Path of the multi-page image (e.g. a TIFF stack).
        output_filepath: Output path without extension; ".pkl" is appended.
        sam_checkpoint_path: Path to the ViT-H SAM checkpoint.

    Raises:
        FileNotFoundError: If OpenCV fails to read the image file.
        ValueError: If the loaded image array is not 2D or 3D
            (e.g. a multi-channel colour stack).
    """
    check, pages = cv2.imreadmulti(img_input_filepath)
    # Inspect the read status before turning OpenCV's output into an array.
    if not check:
        raise FileNotFoundError(f"Image file not found: {img_input_filepath}")
    img = np.array(pages)
    if img.ndim > 3 or img.ndim < 2:
        raise ValueError(f"Unsupported image type (array shape {img.shape}).")
    if img.ndim == 2:
        img = img[:, :, np.newaxis]

    print(f"Image dimensions: {img.shape}")

    sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint_path)
    if torch.cuda.is_available():
        sam.to(device="cuda")
    predictor = SamPredictor(sam)

    embeddings = [[], [], []]
    for axis, direction in enumerate(['x', 'y', 'z']):
        print(f"\nSlicing along {direction} direction")
        n_slices = img.shape[axis]
        # Moving `axis` to the front yields img[k], img[:, k] or img[:, :, k]
        # for k = 0..n_slices-1, without copying the volume.
        for k, img_slice in enumerate(np.moveaxis(img, axis, 0)):
            sys.stdout.write(f"\rCreating embedding for {k + 1}/{n_slices} image")
            # "\r" progress lines carry no newline, so flush explicitly.
            sys.stdout.flush()
            predictor.reset_image()
            # SAM takes an HxWx3 image: replicate the single-channel slice.
            predictor.set_image(np.repeat(img_slice[:, :, np.newaxis], 3, axis=2))
            embeddings[axis].append({'original_size': predictor.original_size,
                                     'input_size': predictor.input_size,
                                     # store CPU tensors so the pickle holds no GPU tensors
                                     'features': predictor.features.to('cpu')})
    # Terminate the last progress line so the next message starts on its own line.
    print()

    with open(output_filepath + ".pkl", 'wb') as f:
        pickle.dump(embeddings, f)
    # Report only after the file has been closed (and thus fully written).
    print(f"Saved {output_filepath}.pkl")

In [None]:
# Colab: upload the image stack through the browser. Only the first uploaded
# file's name is kept in img_filename.
from google.colab import files
img_filename = list(files.upload().keys())[0]

In [None]:
# for local use: set this to the path of your TIFF stack. Running this cell
# overrides any filename set by the Colab upload cell.
img_filename = ""

In [None]:
# Embeddings are written next to the input, as its name without the extension
# plus ".pkl", using the ViT-H checkpoint downloaded earlier.
create_embeddings(img_filename, os.path.splitext(img_filename)[0], "sam_vit_h_4b8939.pth")

In [None]:
# Download from Colab. create_embeddings wrote the embeddings to the input
# filename with its extension *replaced* by ".pkl" (e.g. stack.tif -> stack.pkl),
# so build the same path here rather than appending ".pkl" to img_filename.
files.download(os.path.splitext(img_filename)[0] + ".pkl")