# Fine-tune SAM using a Table and collect custom metrics and embeddings

<div style="display: inline-flex; align-items: center; gap: 10px;">
        <a href="https://colab.research.google.com/github/3lc-ai/notebook-examples/blob/main/fine-tune-sam.ipynb"
        target="_blank"
            style="background-color: transparent; text-decoration: none; display: inline-flex; align-items: center;
            padding: 5px 10px; font-family: Arial, sans-serif;"> <img
            src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" style="height: 30px;
            vertical-align: middle;box-shadow: none;"/>
        </a> <a href="https://github.com/3lc-ai/notebook-examples/blob/main/fine-tune-sam.ipynb"
            style="text-decoration: none; display: inline-flex; align-items: center; background-color: #ffffff; border:
            1px solid #d1d5da; border-radius: 8px; padding: 2px 10px; color: #333; font-family: Arial, sans-serif;">
            <svg aria-hidden="true" focusable="false" role="img" class="octicon octicon-mark-github" viewBox="0 0 16 16"
            width="20" height="20" fill="#333"
            style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible; margin-right:
            8px;">
                <path d="M8 0c4.42 0 8 3.58 8 8a8.013 8.013 0 0 1-5.45 7.59c-.4.08-.55-.17-.55-.38 0-.27.01-1.13.01-2.2
                0-.75-.25-1.23-.54-1.48 1.78-.2 3.65-.88 3.65-3.95 0-.88-.31-1.59-.82-2.15.08-.2.36-1.02-.08-2.12 0
                0-.67-.22-2.2.82-.64-.18-1.32-.27-2-.27-.68 0-1.36.09-2 .27-1.53-1.03-2.2-.82-2.2-.82-.44 1.1-.16
                1.92-.08 2.12-.51.56-.82 1.28-.82 2.15 0 3.06 1.86 3.75 3.64 3.95-.23.2-.44.55-.51
                1.07-.46.21-1.61.55-2.33-.66-.15-.24-.6-.83-1.23-.82-.67.01-.27.38.01.53.34.19.73.9.82 1.13.16.45.68
                1.31 2.69.94 0 .67.01 1.3.01 1.49 0 .21-.15.45-.55.38A7.995 7.995 0 0 1 0 8c0-4.42 3.58-8 8-8Z"></path>
            </svg> <span style="vertical-align: middle; color: #333;">Open in GitHub</span>
        </a>
</div>

This notebook is a modified version of the official Colab tutorial from Encord, which can be found [here](https://colab.research.google.com/drive/1F6uRommb3GswcRlPZWpkAQRMVNdVH7Ww).

It demonstrates how to use a 3LC Table to fine-tune the Segment Anything Model (SAM), and how 3LC
can collect custom metrics and embeddings within a training loop.

Before running this notebook, you must first run the `create_sam_dataset.ipynb` notebook to create the
Table.

In [None]:
# Parameters
DATASET_NAME = "staver-dataset"  # Must match DATASET_NAME in create_sam_dataset.ipynb
RUN_NAME = "staver-run"  # Name of the 3LC Run created by this notebook

# Training parameters
MODEL_TYPE = "vit_b"  # Key into segment_anything's sam_model_registry
CHECKPOINT = "sam_vit_b_01ec64.pth"  # Local checkpoint file; should hold weights for MODEL_TYPE
DEVICE = "cuda:0"
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0
NUM_EPOCHS = 10  # Metrics and embeddings are only collected during the final epoch

# Embedding parameters (used when reducing the collected embeddings after training)
EMBEDDING_DIM = 2
REDUCTION_METHOD = "pacmap"

INSTALL_DEPENDENCIES = False  # Set to True to pip-install the required packages in the next cell

In [None]:
%%capture
# Optionally install dependencies: CUDA 11.8 builds of torch/torchvision, SAM from GitHub,
# OpenCV, and 3LC with the pacmap extra. %%capture suppresses the cell's output.
if INSTALL_DEPENDENCIES:
    %pip --quiet install torch --index-url https://download.pytorch.org/whl/cu118
    %pip --quiet install torchvision --index-url https://download.pytorch.org/whl/cu118
    %pip --quiet install git+https://github.com/facebookresearch/segment-anything.git
    %pip --quiet install opencv-python
    %pip --quiet install tlc[pacmap]

In [None]:
### HIDDEN CELL ###

# Reloads all modules every time before executing the Python code.
%load_ext autoreload
%autoreload 2

# Make the parent directory importable (sys.path, not the shell PATH) so that
# notebook_tests can be imported.
import os
import sys

sys.path.append('..')
# NOTE(review): notebook_tests is not referenced in the visible cells; presumably
# imported for its side effects (test harness setup) — confirm.
import notebook_tests

# Optionally override the default test data path
# NOTE(review): TEST_DATA_PATH is not used in the visible cells; presumably read by
# the test tooling — verify.
if path := os.getenv("TLC_PUBLIC_EXAMPLES_TEST_DATA_PATH"):
    print(f"Using test data path: {path}")
    TEST_DATA_PATH = path

# Prints the current 3lc configuration
!3lc config --list

In [None]:
import tlc
from typing import List, Union
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
from torchvision.transforms import ToPILImage
import torch
import cv2
import numpy as np
import os

## Downloading weights

In [None]:
# Download the SAM checkpoint only if it is not already present locally.
# The URL is built from CHECKPOINT's file name instead of being hard-coded to the
# vit_b weights, so changing MODEL_TYPE/CHECKPOINT to another official checkpoint
# (assumed to be hosted under the same base URL) downloads the matching file.
if not os.path.exists(CHECKPOINT):
    checkpoint_url = f"https://dl.fbaipublicfiles.com/segment_anything/{os.path.basename(CHECKPOINT)}"
    torch.hub.download_url_to_file(checkpoint_url, CHECKPOINT)

## Fine-tuning

In [None]:
# Constants

# 3LC parameters
TABLE_URL = tlc.Url.create_table_url(project_name="SAM_EXAMPLE", dataset_name=DATASET_NAME)

RUN_URL = tlc.Url.create_run_url(run_name=RUN_NAME)


# Derived Constants
def create_model():
    sam_model = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT)
    sam_model.to(DEVICE)
    sam_model.train()
    return sam_model


sam_model = create_model()


run = tlc.Run(url=RUN_URL)


run.write_to_url()
run.add_input_table(TABLE_URL)


resize_transform = ResizeLongestSide(sam_model.image_encoder.img_size)


optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
loss_fn = torch.nn.MSELoss()

In [None]:
def transform_to_sam_format(sample):
    """Convert a row of the input Table into the inputs used by the SAM training loop.

    Args:
        sample: A table row with an ``"image"`` file path, a ``"mask"`` file path and a
            ``"prompt box"`` whose ``"bb_list"`` holds at least one box with
            ``x0``/``y0``/``x1``/``y1`` keys. Only the first box is used.

    Returns:
        A dict with:
          - ``"image"``: the image tensor on ``DEVICE`` after ``sam_model.preprocess``,
            with a leading batch dimension of 1.
          - ``"input_size"``: (height, width) of the resized image tensor before
            ``sam_model.preprocess``.
          - ``"original_image_size"``: (height, width) of the image as read from disk.
          - ``"ground_truth_masks"``: boolean numpy array, True where the mask pixel is 0.
          - ``"prompt_box"``: numpy array ``[x0, y0, x1, y1]`` of the first prompt box.

    Raises:
        FileNotFoundError: If OpenCV cannot read the image or the mask file.
    """
    image = cv2.imread(sample["image"])
    # cv2.imread reports failure by returning None rather than raising.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {sample['image']}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    transform = ResizeLongestSide(sam_model.image_encoder.img_size)
    resized_image = transform.apply_image(image)
    resized_image_torch = torch.as_tensor(resized_image, device=DEVICE)
    # HWC -> CHW, then add a leading batch dimension of 1.
    transformed_image = resized_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
    input_image = sam_model.preprocess(transformed_image)

    original_image_size = image.shape[:2]

    input_size = tuple(transformed_image.shape[-2:])

    mask = cv2.imread(sample["mask"], cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError(f"Could not read mask: {sample['mask']}")
    # The training target is True where the mask pixel value is 0.
    ground_truth_masks = mask == 0

    first_box = sample["prompt box"]["bb_list"][0]
    prompt_box = np.array([first_box["x0"], first_box["y0"], first_box["x1"], first_box["y1"]])

    transformed_sample = {
        "image": input_image,
        "input_size": input_size,
        "original_image_size": original_image_size,
        "ground_truth_masks": ground_truth_masks,
        "prompt_box": prompt_box,
    }

    return transformed_sample


def pre_reduce_embedding(embedding: torch.Tensor) -> np.ndarray:
    """Prepare a batch of embeddings for writing to a metrics table.

    The embedding is averaged over dimensions 2 and 3, detached from the autograd
    graph, moved to CPU and converted to a numpy array. For an input of shape
    ``(B, C, H, W)`` the result has shape ``(B, C)``.
    """
    # detach() lets this also accept tensors that require grad; .numpy() raises on those.
    return embedding.detach().mean(dim=[2, 3]).cpu().numpy()


def create_metrics_writer() -> tlc.MetricsTableWriter:
    """Build a metrics writer for RUN_URL whose rows refer to TABLE_URL.

    Schemas are overridden for four columns:
      - "image_embedding" and "prompt_embedding": float32 values of exactly 256
        entries, tagged with the NN-embedding number role;
      - "loss": a float32 value;
      - "predicted_mask": an image URL string.
    None of these columns are writable.
    """
    # A single size constraint object shared by both embedding columns.
    embedding_width = tlc.DimensionNumericValue(
        value_min=256,
        value_max=256,
        enforce_min=True,
        enforce_max=True,
    )

    def make_embedding_schema(display_name):
        return tlc.Schema(
            display_name,
            writable=False,
            computable=False,
            value=tlc.Float32Value(number_role=tlc.NUMBER_ROLE_NN_EMBEDDING),
            size0=embedding_width,
        )

    column_schemas = {
        "image_embedding": make_embedding_schema("Image Embedding"),
        "prompt_embedding": make_embedding_schema("Prompt Embedding"),
        "loss": tlc.Schema(writable=False, value=tlc.Float32Value()),
        "predicted_mask": tlc.Schema(writable=False, value=tlc.ImageUrlStringValue()),
    }

    return tlc.MetricsTableWriter(
        run_url=RUN_URL,
        dataset_url=TABLE_URL.to_str(),
        override_column_schemas=column_schemas,
    )


def capture_metrics(
    image_embedding: torch.Tensor,
    prompt_embedding: torch.Tensor,
    loss: float,
    prediction: torch.Tensor,
    example_ids: List[int],
    epoch: int,
    metrics_writer: Union[tlc.MetricsTableWriter, None],
) -> None:
    """Write one row of metrics per example in the batch to a 3LC metrics table.

    Each predicted mask is saved (inverted, as ``1 - mask``) as a PNG under
    ``<metrics dir parent>/predictions/<epoch>/<example_id>.png``. Its URL is stored in
    the ``predicted_mask`` column.

    Args:
        image_embedding: Image-encoder output for the batch; reduced with
            ``pre_reduce_embedding`` (mean over dims 2 and 3).
        prompt_embedding: Sparse prompt embedding for the batch; reduced by
            averaging over dim 1.
        loss: Scalar loss of the batch. It is recorded on every row, so it is a
            per-sample value only when the batch holds a single example.
        prediction: Predicted masks, indexed along dim 0 in the same order as
            ``example_ids``.
        example_ids: Table row indices of the examples in the batch.
        epoch: Current epoch, recorded on every row.
        metrics_writer: Destination writer. When ``None``, nothing is written.
    """
    if metrics_writer is None:
        return

    reduced_image_embedding = pre_reduce_embedding(image_embedding)
    reduced_prompt_embedding = prompt_embedding.detach().mean(dim=[1]).cpu().numpy()

    predictions_dir = metrics_writer.root_metrics_url.parent / "predictions" / str(epoch)

    # One saved mask per example, so the column lengths always match example_ids.
    prediction_urls = []
    for row, example_id in enumerate(example_ids):
        prediction_url = predictions_dir / f"{example_id}.png"
        prediction_url.make_parents(exist_ok=True)

        img = ToPILImage()(1 - prediction[row].detach().cpu().squeeze())
        img.save(prediction_url.to_str())
        prediction_urls.append(prediction_url.to_str())

    num_rows = len(example_ids)
    metrics_writer.add_batch(
        {
            "image_embedding": reduced_image_embedding,
            "prompt_embedding": reduced_prompt_embedding,
            "loss": [loss] * num_rows,
            "predicted_mask": prediction_urls,
            "example_id": example_ids,
            "epoch": [epoch] * num_rows,
        }
    )


def flush_metrics_writer(metrics_writer: Union[tlc.MetricsTableWriter, None]) -> None:
    """Finalize ``metrics_writer`` and register the metrics it wrote on the Run.

    Does nothing when ``metrics_writer`` is ``None``.
    """
    if metrics_writer is not None:
        metrics_writer.finalize()
        run.update_metrics(metrics_writer.get_written_metrics_infos())


def reduce_all_embeddings() -> None:
    """Ask the Run to reduce the embeddings it collected for TABLE_URL.

    Uses REDUCTION_METHOD with EMBEDDING_DIM output components.
    """
    reduction_options = dict(method=REDUCTION_METHOD, n_components=EMBEDDING_DIM)
    run.reduce_embeddings_by_foreign_table_url(TABLE_URL, **reduction_options)

In [None]:
# Load the input Table and attach transform_to_sam_format as a map function, so samples
# read from `table` are the dicts produced by that function.
table = tlc.Table.from_url(TABLE_URL).map(transform_to_sam_format)

In [None]:
from tqdm import tqdm
from torch.nn.functional import threshold, normalize

# Read every (transformed) sample once and reuse the list in every epoch.
cached_samples = list(table)

for epoch in range(NUM_EPOCHS):
    # Create a 3LC metrics writer on the last epoch only. In earlier epochs
    # metrics_writer stays None, which makes capture_metrics/flush_metrics_writer no-ops.
    metrics_writer = None
    if (epoch + 1) == NUM_EPOCHS:
        metrics_writer = create_metrics_writer()

    for i, sample in enumerate(tqdm(cached_samples, desc=f"Epoch {epoch}")):
        # The image and prompt encoders run without autograd. Only the mask decoder's
        # parameters are given to the optimizer, so only the decoder is trained.
        with torch.no_grad():
            image_embedding = sam_model.image_encoder(sample["image"])
            # Transform the prompt box with the same resize that was applied to the image.
            box = resize_transform.apply_boxes(sample["prompt_box"], sample["original_image_size"])
            box_torch = torch.as_tensor(box, dtype=torch.float, device=DEVICE)[None, :]
            sparse_prompt_embedding, dense_prompt_embedding = sam_model.prompt_encoder(
                points=None,
                boxes=box_torch,
                masks=None,
            )

        low_res_masks, _iou_predictions = sam_model.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=sam_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_prompt_embedding,
            dense_prompt_embeddings=dense_prompt_embedding,
            multimask_output=False,
        )

        upscaled_masks = sam_model.postprocess_masks(
            low_res_masks, sample["input_size"], sample["original_image_size"]
        ).to(DEVICE)

        # threshold() zeroes values <= 0; normalize() then L2-normalizes along dim 1.
        binary_mask = normalize(threshold(upscaled_masks, 0.0, 0))

        # Ground truth: add leading batch and channel dims, (H, W) -> (1, 1, H, W),
        # then cast to float for the MSE loss.
        gt_mask = torch.from_numpy(sample["ground_truth_masks"]).to(DEVICE)[None, None]
        gt_binary_mask = (gt_mask > 0).to(torch.float32)

        loss = loss_fn(binary_mask, gt_binary_mask)

        # Capture metrics with 3LC (before the optimizer step, so they reflect the
        # model state that produced this loss).
        capture_metrics(
            image_embedding=image_embedding,
            prompt_embedding=sparse_prompt_embedding,
            example_ids=[i],
            epoch=epoch,
            loss=loss.item(),
            prediction=binary_mask,
            metrics_writer=metrics_writer,
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Flush one epoch of metrics
    flush_metrics_writer(metrics_writer)

# Reduce all captured embeddings
reduce_all_embeddings()