# Collect and Reduce Classifier Embeddings

<div style="display: inline-flex; align-items: center; gap: 10px;">
        <a href="https://colab.research.google.com/github/3lc-ai/3lc-examples/blob/main/tutorials/bb-embeddings/2-collect-embeddings.ipynb"
        target="_blank"
            style="background-color: transparent; text-decoration: none; display: inline-flex; align-items: center;
            padding: 5px 10px; font-family: Arial, sans-serif;"> <img
            src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" style="height: 30px;
            vertical-align: middle;box-shadow: none;"/>
        </a> <a href="https://github.com/3lc-ai/3lc-examples/blob/main/tutorials/bb-embeddings/2-collect-embeddings.ipynb"
            style="text-decoration: none; display: inline-flex; align-items: center; background-color: #ffffff; border:
            1px solid #d1d5da; border-radius: 8px; padding: 2px 10px; color: #333; font-family: Arial, sans-serif;">
            <svg aria-hidden="true" focusable="false" role="img" class="octicon octicon-mark-github" viewBox="0 0 16 16"
            width="20" height="20" fill="#333"
            style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible; margin-right:
            8px;">
                <path d="M8 0c4.42 0 8 3.58 8 8a8.013 8.013 0 0 1-5.45 7.59c-.4.08-.55-.17-.55-.38 0-.27.01-1.13.01-2.2
                0-.75-.25-1.23-.54-1.48 1.78-.2 3.65-.88 3.65-3.95 0-.88-.31-1.59-.82-2.15.08-.2.36-1.02-.08-2.12 0
                0-.67-.22-2.2.82-.64-.18-1.32-.27-2-.27-.68 0-1.36.09-2 .27-1.53-1.03-2.2-.82-2.2-.82-.44 1.1-.16
                1.92-.08 2.12-.51.56-.82 1.28-.82 2.15 0 3.06 1.86 3.75 3.64 3.95-.23.2-.44.55-.51
                1.07-.46.21-1.61.55-2.33-.66-.15-.24-.6-.83-1.23-.82-.67.01-.27.38.01.53.34.19.73.9.82 1.13.16.45.68
                1.31 2.69.94 0 .67.01 1.3.01 1.49 0 .21-.15.45-.55.38A7.995 7.995 0 0 1 0 8c0-4.42 3.58-8 8-8Z"></path>
            </svg> <span style="vertical-align: middle; color: #333;">Open in GitHub</span>
        </a>
</div>

In this tutorial, we use the bounding-box classifier fine-tuned in the previous notebook to generate one embedding per bounding box in a COCO-style object detection dataset. We then reduce these embeddings to 3D using PaCMAP and save them for use in the next notebook.

Before running this notebook, you must have run:
* [1-fine-tune-on-crops.ipynb](https://github.com/3lc-ai/3lc-examples/blob/main/tutorials/bb-embeddings/1-fine-tune-on-crops.ipynb)

## Project Setup

In [None]:
# --- Tutorial configuration ---
# 3LC project and dataset that hold the input Table
PROJECT_NAME = "Bounding Box Classifier"
DATASET_NAME = "Balloons"

# Directory holding the fine-tuned checkpoint; the reduced embeddings are saved here too
TRANSIENT_DATA_PATH = "../../transient_data"

# Number of BB crops fed to the model per forward pass
BATCH_SIZE = 32

# Torch device string, or None to auto-select (CUDA, then MPS, then CPU)
DEVICE = None

# Set to True to run the dependency-install cell below
INSTALL_DEPENDENCIES = False

In [None]:
%%capture
if INSTALL_DEPENDENCIES:
    %pip --quiet install pacmap

## Imports

In [None]:
from io import BytesIO

import numpy as np
import pacmap
import timm
import tlc
import torch
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm

## Set device

In [None]:
# Resolve the compute device: honour an explicit DEVICE override; otherwise
# prefer CUDA, then Apple-silicon MPS, and finally fall back to the CPU.
if DEVICE is not None:
    device = DEVICE
elif torch.cuda.is_available():
    device = "cuda:0"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

device = torch.device(device)
print(f"Using device: {device}")

## Get Input Table

In [None]:
# Open the input Table named "original" for this project/dataset.
# NOTE(review): expected to exist from earlier notebooks in this tutorial series — confirm.
input_table = tlc.Table.from_names(
    table_name="original",
    dataset_name=DATASET_NAME,
    project_name=PROJECT_NAME,
)

# Cropping needs the schema of the per-sample bounding-box list (row field "bbs" -> "bb_list")
bb_schema = input_table.rows_schema.values["bbs"].values["bb_list"]

## Get Model

In [None]:
# Rebuild the EfficientNet-B0 architecture with a 2-class head and load the
# fine-tuned weights saved by the previous notebook.
# num_classes must match the checkpoint or loading will fail.
model = timm.create_model(
    "efficientnet_b0",
    num_classes=2,
    checkpoint_path=f"{TRANSIENT_DATA_PATH}/bb_classifier.pth",
)
model = model.to(device)

# Inference mode (e.g. dropout disabled, batch-norm uses running statistics)
model.eval()

# Embeddings are taken from the flattened output of the global pooling layer,
# i.e. the pooled feature vector that feeds the classifier head
hidden_layer = model.global_pool.flatten

## Set Up Data Processing

In [None]:
# Preprocessing applied to every BB crop before it is fed to the model:
# resize the shorter side to 256, take the central 224x224 patch, convert
# to a float tensor, and normalize with the standard ImageNet mean/std.
# NOTE(review): should match the preprocessing used during fine-tuning — confirm.
image_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# An iterator over every transformed BB crop in a single sample (image)
def single_sample_bb_crop_iterator(sample):
    """Yield one model-ready tensor per bounding box in ``sample``.

    Args:
        sample: A row of ``input_table``. Its ``"image"`` URL is read, and its
            ``["bbs"]["bb_list"]`` entries are the boxes to crop.

    Yields:
        Each bounding-box crop after ``image_transform``, in ``bb_list`` order.
    """
    image_bytes = tlc.Url(sample["image"]).read()
    # Decode and force 3-channel RGB: grayscale, palette or RGBA images would
    # otherwise produce tensors whose channel count does not match the
    # 3-channel Normalize in image_transform. The `with` closes the decoder.
    with Image.open(BytesIO(image_bytes)) as decoded:
        image = decoded.convert("RGB")
    w, h = image.size

    for bb in sample["bbs"]["bb_list"]:
        # Image size is passed as (height, width), in that order
        bb_crop = tlc.BBCropInterface.crop(image, bb, bb_schema, h, w)
        yield image_transform(bb_crop)

# An iterator over every transformed BB crop in the dataset
def bb_crop_iterator():
    """Yield every transformed BB crop in ``input_table``.

    Crops come out in table-row order, and within a row in ``bb_list`` order.
    """
    for row in input_table:
        yield from single_sample_bb_crop_iterator(row)

# A batched iterator over every transformed BB crop in the dataset
def batched_bb_crop_iterator():
    """Group the crops from ``bb_crop_iterator()`` into stacked batches on ``device``.

    Each yielded tensor holds ``BATCH_SIZE`` crops, except possibly the last,
    which holds the remainder. Crop order is preserved.
    """
    pending = []
    for crop in bb_crop_iterator():
        pending.append(crop)
        if len(pending) < BATCH_SIZE:
            continue
        yield torch.stack(pending).to(device)
        pending = []

    # Flush the final, possibly smaller, batch
    if pending:
        yield torch.stack(pending).to(device)

## Collect Embeddings

In [None]:
# Add a model hook which saves activations to output_list.
# If this cell was run before, detach the previous hook first. Otherwise
# re-running the cell stacks a second hook and every embedding is captured twice.
# (RemovableHandle.remove() is a no-op if the hook was already removed.)
if "hook_handle" in globals():
    hook_handle.remove()

output_list = []


def hook_fn(module, input, output):
    """Forward hook: append a CPU copy of the hooked layer's output to output_list."""
    output_list.append(output.detach().cpu())


hook_handle = hidden_layer.register_forward_hook(hook_fn)

In [None]:
# Use our iterator to run the model on every BB crop in the dataset.
# Only the hooked activations are kept; the model's own outputs are discarded.
try:
    with torch.no_grad():
        for batch in tqdm(batched_bb_crop_iterator(), desc="Running model inference"):
            model(batch)
finally:
    # Always detach the hook, even if inference fails part-way, so later
    # forward passes on this model do not keep appending to output_list.
    # After a failure, re-run the hook cell before retrying this one.
    hook_handle.remove()

## Dimensionality Reduction

In [None]:
# Stack all the embeddings into a single numpy array: one row per BB crop,
# in the order the crops were produced (table-row order, then bb_list order).
if not output_list:
    raise RuntimeError(
        "No embeddings were collected; check that the input table contains bounding boxes "
        "and that the hook cell was run before the inference cell."
    )
embeddings = torch.cat(output_list, dim=0).numpy()

# Reduce the 1280-dimensional activations to 3D using PaCMAP.
# PaCMAP is stochastic; a fixed random_state makes the saved 3D embeddings reproducible.
PACMAP_RANDOM_STATE = 42
reducer = pacmap.PaCMAP(n_components=3, random_state=PACMAP_RANDOM_STATE)
embeddings_3d = reducer.fit_transform(embeddings)

In [None]:
# Persist the 3D embeddings for the next notebook(s). Rows follow the same
# order as the crops (table-row order, then bb_list order within each row).
np.save(f"{TRANSIENT_DATA_PATH}/bb_classifier_embeddings.npy", embeddings_3d)