# Fine-tuning a Classifier Using Bounding Box Data from a 3LC Table

In this tutorial, we will fine-tune a classifier using bounding box data from a 3LC `Table`.

We will load the COCO128 table created in an earlier notebook and use it to build a
`torch.utils.data.Dataset` of bounding box crops. These cropped images will be used to
fine-tune a classifier. In a later tutorial, we will use this trained model to
generate embeddings and predicted labels.

## Imports

In [None]:
import json
from pathlib import Path

import matplotlib.pyplot as plt
import timm
import tlc
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import tqdm.notebook as tqdm
from torch.utils.data import DataLoader, WeightedRandomSampler

from tlc_tools.common import infer_torch_device
from tlc_tools.datasets import BBCropDataset
from tlc_tools.split import split_table

## Project Setup

In [None]:
# Number of train + validation passes run by the training loop.
EPOCHS = 10
# Number of samples the weighted sampler draws for each training epoch.
IMAGES_PER_EPOCH = 1000
# File the trained model's state dict is written to; the parent directory
# must already exist (it is checked before training starts).
MODEL_CHECKPOINT = "../../../transient_data/bb_classifier.pth"
# Architecture name handed to timm.create_model.
MODEL_NAME = "efficientnet_b0"
# Batch size used by both the training and the validation dataloader.
BATCH_SIZE = 32
# When True, the training dataset is built with add_background=True and one
# extra output class is reserved for it.
INCLUDE_BACKGROUND = False
# Crop jitter settings forwarded to the training BBCropDataset only.
# Presumably expressed relative to the bounding box size -- see BBCropDataset
# for the exact semantics.
X_MAX_OFFSET = 0.1
Y_MAX_OFFSET = 0.1
X_SCALE_RANGE = (0.9, 1.1)
Y_SCALE_RANGE = (0.9, 1.1)

## Set Device

In [None]:
# Pick the torch device; the model and every batch are moved to it during training.
DEVICE = infer_torch_device()
print(f"Using device: {DEVICE}")

## Load Input Table

We will reuse the table created in the notebook [create-table-from-coco.ipynb](../../1-create-tables/create-table-from-coco.ipynb).

In [None]:
# Look up the existing table by name: table "initial" of the "COCO128" dataset
# in the "3LC Tutorials" project (created by the referenced notebook).
input_table = tlc.Table.from_names(
    "initial",
    "COCO128",
    "3LC Tutorials",
)

In [None]:
# Get the schema of the bounding box column of the input table
bb_schema = input_table.schema.values["rows"].values["bbs"].values["bb_list"]
# Value map of the bounding box label column. Its length determines the number
# of classes, and the preview cells look class names up in it via `label.item()`.
label_map = input_table.get_simple_value_map("bbs.bb_list.label")
print(f"Input table uses {len(label_map)} unique labels: {json.dumps(label_map, indent=2)}")

In [None]:
# Fail early, before any training time is spent, if the directory the
# checkpoint will be written to is missing (torch.save does not create it).
# An explicit check is used instead of `assert` so it still runs under
# `python -O` and reports which path is missing.
checkpoint_dir = Path(MODEL_CHECKPOINT).parent
if not checkpoint_dir.exists():
    raise FileNotFoundError(
        f"Checkpoint directory '{checkpoint_dir.resolve()}' does not exist; "
        "create it or change MODEL_CHECKPOINT before training."
    )

In [None]:
# One output class per label, plus one more when background crops are included.
background_classes = 1 if INCLUDE_BACKGROUND else 0
NUM_CLASSES = len(label_map) + background_classes

## Split the Table

In [None]:
# Partition the input table: 80% of it for training, 20% for validation
splits = split_table(input_table, {"train": 0.8, "val": 0.2})
train_table, val_table = splits["train"], splits["val"]

print(f"Using table {train_table} for training")
print(f"Using table {val_table} for validation")

## Create Dataset

In [None]:
# Define the transformations to be applied to the images.
# `common_transforms` is the deterministic preprocessing shared by training and
# validation: force 3-channel RGB, resize to 224x224, convert to a tensor and
# normalize with the commonly used ImageNet channel mean/std.
common_transforms = transforms.Compose(
    [
        transforms.Lambda(lambda img: img.convert("RGB")),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Training adds random augmentation (color jitter, small rotation, horizontal
# flip) in front of the shared preprocessing steps.
# NOTE(review): the augmentations run before the RGB conversion, so a crop that
# is not already RGB is jittered in its original mode -- confirm this is intended.
train_transforms = transforms.Compose(
    [
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.3),
        transforms.RandomRotation(degrees=10),
        transforms.RandomHorizontalFlip(),
        *common_transforms.transforms,
    ]
)

# Create the training and validation datasets.
# The training dataset gets the augmenting transforms, the crop offset/scale
# jitter settings, and (optionally) background crops.
train_dataset = BBCropDataset(
    train_table,
    transform=train_transforms,
    add_background=INCLUDE_BACKGROUND,
    is_train=True,
    x_max_offset=X_MAX_OFFSET,
    y_max_offset=Y_MAX_OFFSET,
    x_scale_range=X_SCALE_RANGE,
    y_scale_range=Y_SCALE_RANGE,
)

# The validation dataset uses only the deterministic preprocessing and never
# adds background crops, regardless of INCLUDE_BACKGROUND.
val_dataset = BBCropDataset(
    val_table,
    transform=common_transforms,
    add_background=False,
    is_train=False,
)

In [None]:
def write_image_grid(images, labels, rows: int, cols: int, title: str):
    """Show normalized image tensors in a ``rows`` x ``cols`` matplotlib grid.

    Args:
        images: Sequence of image tensors normalized with the same channel
            mean/std as the dataset transforms (the normalization is undone
            here before display).
        labels: One subplot title per image; must have the same length as
            ``images``.
        rows: Number of grid rows.
        cols: Number of grid columns.
        title: Title drawn above the whole figure.

    Raises:
        AssertionError: If ``images`` and ``labels`` differ in length, or if
            there are more images than grid cells.
    """
    assert len(images) == len(labels), "Number of images and labels must be the same."
    assert len(images) <= rows * cols, "Not enough space in the grid for all images."

    # Normalize computes (x - mean) / std, so these values invert the
    # normalization: x * std_orig + mean_orig.
    unnormalize = transforms.Normalize(
        mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
        std=[1 / 0.229, 1 / 0.224, 1 / 0.225],
    )
    to_pil = transforms.ToPILImage()

    # squeeze=False always returns a 2-D array of axes, so flatten() also works
    # for a 1x1 grid (where plt.subplots would otherwise return a bare Axes).
    _, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4), squeeze=False)
    axes = axes.flatten()

    for ax, image, label in zip(axes, images, labels):
        # Undo the normalization. Round-off can push values marginally outside
        # [0, 1], so clamp before the conversion to an 8-bit PIL image.
        image = unnormalize(image).clamp(0.0, 1.0)
        pil_image = to_pil(image)

        # Plot the image
        ax.imshow(pil_image)
        ax.set_title(label)
        ax.axis("off")

    # Hide unused subplots
    for ax in axes[len(images) :]:
        ax.axis("off")

    # Add the title before tight_layout so the layout reserves room for it
    # instead of letting it overlap the first row, then display the figure.
    plt.suptitle(title)
    plt.tight_layout()
    plt.show()

In [None]:
# Preview the first 12 samples of the (augmented) training dataset. Labels that
# are not present in the label map are displayed as "Background".
train_samples = [train_dataset[idx] for idx in range(4 * 3)]
train_images = [crop for crop, _ in train_samples]
train_labels = [label_map.get(target.item(), "Background") for _, target in train_samples]

write_image_grid(train_images, train_labels, 4, 3, "Training Images")

In [None]:
# Preview the first 12 samples of the validation dataset. Every label is
# expected to be present in the label map here (a missing one raises KeyError).
val_samples = [val_dataset[idx] for idx in range(4 * 3)]
val_images = [crop for crop, _ in val_samples]
val_labels = [label_map[target.item()] for _, target in val_samples]

write_image_grid(val_images, val_labels, 4, 3, "Validation Images")

## Train Model

In [None]:
# Create a 3LC Run
run = tlc.init(project_name=input_table.project_name, run_name="Train Bounding Box Classifier")

# Create sampler: each table row is weighted by its number of bounding boxes,
# so images with more boxes are drawn more often. IMAGES_PER_EPOCH samples are
# drawn per epoch (with replacement, the sampler's default).
num_bbs_per_image = [len(row["bbs"]["bb_list"]) for row in train_table.table_rows]
sampler = WeightedRandomSampler(weights=num_bbs_per_image, num_samples=IMAGES_PER_EPOCH)

# Create the dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Load an EfficientNet model using timm
model = timm.create_model(MODEL_NAME, pretrained=True, num_classes=NUM_CLASSES).to(DEVICE)

# Training setup
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# The learning rate is multiplied by gamma once per epoch.
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9516)

# Training loop
for epoch in range(EPOCHS):
    # Training Phase
    model.train()
    train_loss, train_correct, train_total = 0.0, 0, 0
    for inputs, labels in tqdm.tqdm(train_dataloader, desc=f"Epoch {epoch+1} [Train]"):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # The criterion returns the batch mean; weight it by the batch size so
        # the epoch average stays correct when the last batch is smaller.
        train_loss += loss.item() * inputs.size(0)
        _, preds = outputs.max(1)
        train_correct += (preds == labels).sum().item()
        train_total += labels.size(0)

    train_loss /= train_total
    train_acc = train_correct / train_total

    # Validation Phase
    model.eval()
    val_loss, val_correct, val_total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in tqdm.tqdm(val_dataloader, desc=f"Epoch {epoch+1} [Val]"):
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            val_loss += loss.item() * inputs.size(0)
            _, preds = outputs.max(1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)

    val_loss /= val_total
    val_acc = val_correct / val_total

    # Capture the learning rate this epoch was actually trained with *before*
    # the scheduler decays it; reading it after scheduler.step() would log the
    # next epoch's value against this epoch's metrics.
    epoch_lr = optimizer.param_groups[0]["lr"]

    # Update the learning rate
    scheduler.step()

    print(f"Epoch {epoch+1}/{EPOCHS}")
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
    tlc.log(
        {
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "val_loss": val_loss,
            "val_acc": val_acc,
            "lr": epoch_lr,
        }
    )

run.set_status_completed()

In [None]:
# Save the model to a pth file:
# only the state dict is written, so loading it later requires re-creating a
# model with the same architecture and number of classes first.
torch.save(model.state_dict(), MODEL_CHECKPOINT)