# Collect and Reduce Classifier Embeddings

In this tutorial, we will use the classifier model fine-tuned in the previous notebook to generate per-bounding-box embeddings for a COCO-style object detection dataset. We will then reduce these embeddings to a low-dimensional space (2D by default, controlled by `NUM_COMPONENTS`) using PaCMAP.

Before running this notebook, you must first have run:
* [1-fine-tune-on-crops.ipynb](https://github.com/3lc-ai/3lc-examples/blob/main/tutorials/bb-embeddings/1-fine-tune-on-crops.ipynb)

## Imports

In [None]:
from io import BytesIO

import numpy as np
import pacmap
import timm
import tlc
import torch
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from tlc_tools.common import infer_torch_device

## Project Setup

In [None]:
# Project and file locations (all artifacts live under the transient data dir)
PROJECT_NAME = "3LC Tutorials"
TRANSIENT_DATA_PATH = "../../../transient_data"
EMBEDDING_SAVE_PATH = f"{TRANSIENT_DATA_PATH}/bb_classifier_embeddings.npy"
LABELS_SAVE_PATH = f"{TRANSIENT_DATA_PATH}/bb_classifier_labels.npy"
MODEL_CHECKPOINT = f"{TRANSIENT_DATA_PATH}/bb_classifier.pth"

# Model and inference settings
MODEL_NAME = "efficientnet_b0"
BATCH_SIZE = 32
NUM_COMPONENTS = 2  # output dimensionality passed to PaCMAP

## Set device

In [None]:
# Pick the compute device for inference (selection logic lives in tlc_tools)
device = infer_torch_device()
print("Using device:", device)

## Get Input Table

In [None]:
# Open the Table used in the previous notebook
input_table = tlc.Table.from_names(
    table_name="initial",
    dataset_name="COCO128",
    project_name=PROJECT_NAME,
)

# Get the schema of the bounding box list, required to crop.
# An explicit check is used instead of `assert`, which is skipped under `python -O`.
bb_schema = input_table.rows_schema.values["bbs"].values["bb_list"]
if not bb_schema:
    raise ValueError("Input table has no bounding-box list schema at 'bbs' -> 'bb_list'")

# Number of classes = number of entries in the label value map of the BB schema
NUM_CLASSES = len(bb_schema.values["label"].value.map)


## Get Model

In [None]:
# Re-create the architecture with the weights trained in the previous notebook,
# move it to the selected device and switch it to evaluation mode.
model = (
    timm.create_model(
        MODEL_NAME,
        num_classes=NUM_CLASSES,
        checkpoint_path=MODEL_CHECKPOINT,
    )
    .to(device)
    .eval()
)

# Sub-module whose forward output is captured (via a hook) as the embedding:
# the flatten step of the model's global pooling head.
hidden_layer = model.global_pool.flatten

## Set Up Data Processing

In [None]:
# Preprocessing applied to every bounding-box crop before it is fed to the model:
# resize (int size -> shorter side), center-crop to 224x224, convert to a tensor,
# and normalize per channel. The mean/std values look like the standard ImageNet
# statistics; presumably they match the training-time preprocessing (confirm).
image_transform = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# An iterator over every transformed BB crop in a single sample (image)
def single_sample_bb_crop_iterator(sample):
    """Yield the transformed crop of every bounding box in one table row.

    Args:
        sample: A row of ``input_table``. ``sample["image"]`` is an image URL
            resolved relative to ``input_table.url``; ``sample["bbs"]["bb_list"]``
            holds the bounding boxes to crop.

    Yields:
        ``image_transform`` applied to each bounding-box crop, in the order the
        boxes appear in ``bb_list``.
    """
    image_bytes = tlc.Url(sample["image"]).to_absolute(input_table.url).read()

    # Convert to RGB so grayscale, palette or RGBA images still yield 3-channel
    # tensors, as required by the 3-channel Normalize step in image_transform.
    # `convert` also forces the lazily-opened image to load, so the underlying
    # buffer can be closed right away by the context manager.
    with Image.open(BytesIO(image_bytes)) as raw_image:
        image = raw_image.convert("RGB")
    w, h = image.size

    for bb in sample["bbs"]["bb_list"]:
        bb_crop = tlc.BBCropInterface.crop(image, bb, bb_schema, h, w)
        yield image_transform(bb_crop)

# An iterator over every transformed BB crop in the dataset
def bb_crop_iterator():
    """Yield the transformed bounding-box crops of every row in ``input_table``.

    Rows are visited in table order, and each row's crops are produced by
    ``single_sample_bb_crop_iterator``.
    """
    for row in input_table:
        yield from single_sample_bb_crop_iterator(row)

# A batched iterator over every transformed BB crop in the dataset
def batched_bb_crop_iterator():
    """Group crops from ``bb_crop_iterator`` into batches on ``device``.

    Each yielded item is ``torch.stack`` of ``BATCH_SIZE`` crops, moved to
    ``device``. The last batch holds whatever crops remain and may be smaller;
    nothing is yielded when there are no crops at all.
    """
    pending = []
    for crop in bb_crop_iterator():
        pending.append(crop)
        if len(pending) != BATCH_SIZE:
            continue
        yield torch.stack(pending).to(device)
        pending = []
    # Flush the incomplete trailing batch, if any.
    if len(pending) > 0:
        yield torch.stack(pending).to(device)

## Collect Embeddings

In [None]:
# Forward hook on `hidden_layer`: every time the model runs, the layer's output
# is moved to the CPU and appended to `output_list`, one entry per batch, in
# the order batches pass through the model.
output_list = []


def hook_fn(module, input, output):
    """Append the CPU version of the hooked layer's ``output`` to ``output_list``."""
    captured = output.cpu()
    output_list.append(captured)


hook_handle = hidden_layer.register_forward_hook(hook_fn)

In [None]:
# Use our iterator to run the model on every BB crop in the dataset.
# The forward hook fills `output_list` with embeddings; here we record the
# predicted class index (argmax of the model output) for every crop.
labels_list = []
try:
    # Gradients are never needed here, so disable them once for the whole loop.
    with torch.no_grad():
        for batch in tqdm(batched_bb_crop_iterator(), desc="Running model inference"):
            output = model(batch)
            predicted_labels = torch.argmax(output, dim=1)
            labels_list.extend(predicted_labels.cpu().numpy())
finally:
    # Always detach the hook, even if inference fails part-way. Otherwise the
    # stale hook stays on the model, and re-running the hook cell would add a
    # second one, recording every activation twice and misaligning embeddings
    # with labels.
    hook_handle.remove()

## Dimensionality Reduction

In [None]:
# Stack all the embeddings and labels into a single numpy array
if not output_list:
    raise RuntimeError("No embeddings were collected; is the input table empty?")
embeddings = np.vstack(output_list)
labels = np.array(labels_list)

# Every crop must contribute exactly one embedding row and one label. A mismatch
# means activations were captured more than once (e.g. a forward hook left
# registered from an earlier run), and the saved arrays would not line up.
if len(embeddings) != len(labels):
    raise RuntimeError(
        f"Collected {len(embeddings)} embeddings but {len(labels)} labels; "
        "restart the kernel and re-run the notebook."
    )

# Reduce the activations (embeddings.shape[1] features per crop) to
# NUM_COMPONENTS dimensions using PaCMAP
reducer = pacmap.PaCMAP(n_components=NUM_COMPONENTS)
embeddings_nd = reducer.fit_transform(embeddings)

In [None]:
# Persist the reduced embeddings and the predicted labels as .npy files so the
# next notebook(s) can load them.
for save_path, array in (
    (EMBEDDING_SAVE_PATH, embeddings_nd),
    (LABELS_SAVE_PATH, labels),
):
    np.save(save_path, array)