# Get image embeddings using SAM

<div style="display: inline-flex; align-items: center; gap: 10px;">
        <a href="https://colab.research.google.com/github/3lc-ai/3lc-examples/blob/main/tutorials/sam-embeddings.ipynb"
        target="_blank"
            style="background-color: transparent; text-decoration: none; display: inline-flex; align-items: center;
            padding: 5px 10px; font-family: Arial, sans-serif;"> <img
            src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" style="height: 30px;
            vertical-align: middle;box-shadow: none;"/>
        </a> <a href="https://github.com/3lc-ai/3lc-examples/blob/main/tutorials/sam-embeddings.ipynb"
            style="text-decoration: none; display: inline-flex; align-items: center; background-color: #ffffff; border:
            1px solid #d1d5da; border-radius: 8px; padding: 2px 10px; color: #333; font-family: Arial, sans-serif;">
            <svg aria-hidden="true" focusable="false" role="img" class="octicon octicon-mark-github" viewBox="0 0 16 16"
            width="20" height="20" fill="#333"
            style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible; margin-right:
            8px;">
                <path d="M8 0c4.42 0 8 3.58 8 8a8.013 8.013 0 0 1-5.45 7.59c-.4.08-.55-.17-.55-.38 0-.27.01-1.13.01-2.2
                0-.75-.25-1.23-.54-1.48 1.78-.2 3.65-.88 3.65-3.95 0-.88-.31-1.59-.82-2.15.08-.2.36-1.02-.08-2.12 0
                0-.67-.22-2.2.82-.64-.18-1.32-.27-2-.27-.68 0-1.36.09-2 .27-1.53-1.03-2.2-.82-2.2-.82-.44 1.1-.16
                1.92-.08 2.12-.51.56-.82 1.28-.82 2.15 0 3.06 1.86 3.75 3.64 3.95-.23.2-.44.55-.51
                1.07-.46.21-1.61.55-2.33-.66-.15-.24-.6-.83-1.23-.82-.67.01-.27.38.01.53.34.19.73.9.82 1.13.16.45.68
                1.31 2.69.94 0 .67.01 1.3.01 1.49 0 .21-.15.45-.55.38A7.995 7.995 0 0 1 0 8c0-4.42 3.58-8 8-8Z"></path>
            </svg> <span style="vertical-align: middle; color: #333;">Open in GitHub</span>
        </a>
</div>

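In this example, we extract an image embedding for every image in a folder using the Segment Anything Model (SAM), store the embeddings in a 3LC Run, and then reduce them to a low-dimensional projection (configured by `REDUCTION_METHOD` and `EMBEDDING_DIM`).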

## Constants

In [14]:
from pathlib import Path

import torch

# Folder containing the images to embed. By default this is the coco128 sample in the
# 3lc-examples repository, resolved relative to the working directory (assumes the
# notebook is run from the repository's `tutorials/` folder). Point it at your own
# image folder as needed; the value is always made absolute.
ABSOLUTE_PATH_TO_IMAGES = str(Path("../data/coco128/images").resolve())

# Which SAM variant to load, where to download its weights from, and the local file name.
MODEL_TYPE = "vit_b"
MODEL_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
CHECKPOINT = "sam_vit_b_01ec64.pth"

# Pick the best available accelerator instead of hard-coding one: CUDA (e.g. Colab),
# then Apple-silicon MPS, then plain CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

# Number of components kept after reducing the embeddings, and the reduction method.
EMBEDDING_DIM = 3
REDUCTION_METHOD = "pacmap"

BATCH_SIZE = 4
INSTALL_DEPENDENCIES = False

## Install dependencies

In [15]:
%%capture
# Optionally install this notebook's third-party packages (useful on a fresh
# environment such as Colab). Controlled by INSTALL_DEPENDENCIES in the constants cell.
# `%%capture` (which must stay the first line of the cell) hides pip's output; remove it
# to debug installation problems.
# NOTE(review): torch itself is not installed here — presumably it is already present
# in the environment. Versions are unpinned; consider pinning them for reproducibility.
if INSTALL_DEPENDENCIES:
    %pip --quiet install 3lc segment_anything pacmap opencv-python

## Imports

In [16]:
from pathlib import Path
import tlc
import torch
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
import cv2
from tqdm import tqdm
from torch.utils.data import DataLoader

## Download model weights

In [17]:
# Fetch the SAM checkpoint only once; later runs reuse the file already on disk.
checkpoint_already_downloaded = Path(CHECKPOINT).exists()
if not checkpoint_already_downloaded:
    torch.hub.download_url_to_file(MODEL_URL, CHECKPOINT)

## Set up model and preprocessing

In [18]:
def create_model():
    """Build the SAM model selected by MODEL_TYPE and prepare it for inference.

    The weights are loaded from CHECKPOINT, moved to DEVICE, and the module is
    switched to evaluation mode.

    Returns:
        The SAM model, resident on DEVICE and in eval mode.
    """
    # nn.Module.to() and nn.Module.eval() both return the module itself, so they chain.
    return sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT).to(DEVICE).eval()

In [19]:
# Load SAM once and derive the input pipeline from it: images are first resized so their
# longest side matches the image encoder's input size, then passed through SAM's own
# `preprocess` method.
sam_model = create_model()
PREPROCESS_TRANSFORM = sam_model.preprocess
RESIZE_TRANSFORM = ResizeLongestSide(sam_model.image_encoder.img_size)

def transform_to_sam_format(sample):
    """Load one table sample's image and convert it into SAM image-encoder input.

    Args:
        sample: A table row whose ``"image"`` entry is the path of an image file.

    Returns:
        ``{"image": tensor}``, where the tensor is on DEVICE, channel-first, resized by
        RESIZE_TRANSFORM and then processed by PREPROCESS_TRANSFORM.

    Raises:
        OSError: If OpenCV cannot read the file (missing, unreadable, or not an image).
    """
    image_path = sample["image"]
    bgr_image = cv2.imread(image_path)
    # cv2.imread reports failure by returning None rather than raising; without this check
    # the problem only surfaces as an opaque assertion error inside cv2.cvtColor.
    if bgr_image is None:
        raise OSError(f"cv2.imread could not read image: {image_path!r}")

    # OpenCV loads BGR; the model pipeline expects RGB.
    image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    image = RESIZE_TRANSFORM.apply_image(image)
    # HWC (OpenCV layout) -> CHW (PyTorch layout)
    image = torch.as_tensor(image, device=DEVICE).permute(2, 0, 1).contiguous()
    image = PREPROCESS_TRANSFORM(image)

    return {"image": image}

## Create 3LC Table and Run

In [20]:
# The paths of all the images in the specified directory. Only regular files with a common
# image extension are kept, so stray entries (e.g. macOS `.DS_Store`, sidecar files,
# sub-folders) never reach cv2.imread and crash the transform.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}
image_paths = sorted(
    str(p)
    for p in Path(ABSOLUTE_PATH_TO_IMAGES).iterdir()
    if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
)
if not image_paths:
    raise FileNotFoundError(f"No image files found in {ABSOLUTE_PATH_TO_IMAGES!r}")

# Create a table with the image paths and apply the transformation defined above
table = tlc.Table.from_dict(
    data={"image": image_paths},
    structure={"image": tlc.ImagePath},
    project_name="SAM-embeddings",
    dataset_name="Name of my dataset",
    table_name="initial",
).map(transform_to_sam_format)

# Initialize a 3LC Run
run = tlc.init(project_name="SAM-embeddings", run_name="Collect embeddings")

# Add our Table to the Run
run.add_input_table(table)

## Collect embeddings using SAM

In [21]:
# Define the Schema for our MetricsTableWriter.
# Each image is described by one fixed-length float vector: the spatially averaged output
# of SAM's image encoder, whose feature map has 256 channels.
EMBEDDING_LENGTH = 256
embedding_size = tlc.DimensionNumericValue(
    value_min=EMBEDDING_LENGTH,
    value_max=EMBEDDING_LENGTH,
    enforce_min=True,
    enforce_max=True,
)
image_embedding_schema = tlc.Schema(
    "Image Embedding",
    writable=False,
    computable=False,
    value=tlc.Float32Value(number_role=tlc.NUMBER_ROLE_NN_EMBEDDING),
    size0=embedding_size,
)

# Create a MetricsTableWriter that writes per-example metrics for our Table within the Run
metrics_writer = tlc.MetricsTableWriter(
    run.url,
    table.url,
    column_schemas={"embedding": image_embedding_schema},
)

# Iterate over the Table in order (DataLoader does not shuffle by default) and collect one
# embedding per example. Example ids come from a running offset sized by each actual batch,
# not from `batch_index * BATCH_SIZE`: the final batch is shorter whenever the number of
# images is not a multiple of BATCH_SIZE, and fixed-size id ranges would then overshoot.
next_example_id = 0
for batch in tqdm(DataLoader(table, batch_size=BATCH_SIZE)):
    with torch.no_grad():
        feature_maps = sam_model.image_encoder(batch["image"])
    # Average over dims 2 and 3 (the spatial grid of the feature map) -> one vector per image
    embeddings = feature_maps.mean(dim=[2, 3]).cpu().numpy()
    batch_len = len(embeddings)
    metrics_writer.add_batch(
        {
            "embedding": embeddings,
            "example_id": list(range(next_example_id, next_example_id + batch_len)),
        }
    )
    next_example_id += batch_len

assert next_example_id == len(table), "every example should receive exactly one embedding"

# Finalize the MetricsTableWriter and update the Run with the collected metrics
metrics_writer.finalize()
metrics_infos = metrics_writer.get_written_metrics_infos()
run.update_metrics(metrics_infos)

# Reduce the embeddings using the specified method and number of dimensions
run.reduce_embeddings_by_foreign_table_url(
    table.url, method=REDUCTION_METHOD, n_components=EMBEDDING_DIM
)

100%|██████████| 32/32 [01:18<00:00,  2.44s/it]


{Url('/Users/markus/Library/Application Support/3LC/projects/SAM-embeddings/runs/Collect embeddings_0004/metrics_0000'): Url('/Users/markus/Library/Application Support/3LC/projects/SAM-embeddings/runs/Collect embeddings_0004/metrics_0000_reduced_embedding_00')}