# Fine-tuning a Classifier Using Bounding Box Data from a 3LC Table

In this tutorial, we will fine-tune a classifier using bounding box data from a 3LC `Table`.

The Table was created from a COCO-style dataset (COCO128) in an earlier tutorial, and we
will use 3LC to crop bounding boxes out of its images on the fly: each sample is the crop of
a randomly chosen bounding box. These cropped images are used to fine-tune a classifier.

In [13]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules before each cell runs.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Project Setup

In [14]:
# 3LC project / dataset identifiers
PROJECT_NAME = "3LC Tutorials"
DATASET_NAME = "COCO128"

# Training hyper-parameters
EPOCHS = 10
BATCH_SIZE = 32
IMAGES_PER_EPOCH = 100  # number of crops the training sampler draws per epoch

# Data locations (relative paths)
TEST_DATA_PATH = "../../../data"
TRANSIENT_DATA_PATH = "../../../transient_data"

# Set to True to pip-install the notebook's dependencies in the next cell
INSTALL_DEPENDENCIES = False

In [15]:
%%capture
# Optionally install the notebook's dependencies; %%capture hides the pip output.
# NOTE(review): torch/torchvision are pulled from the CUDA 11.8 (cu118) wheel index — adjust
# for CPU-only or other CUDA setups. matplotlib (imported later for the image grids) is not
# installed here; confirm it is provided by the environment.
if INSTALL_DEPENDENCIES:
    %pip --quiet install torch --index-url https://download.pytorch.org/whl/cu118
    %pip --quiet install torchvision --index-url https://download.pytorch.org/whl/cu118
    %pip --quiet install timm
    %pip --quiet install 3lc

## Imports

In [16]:
import json
import os
import random
from collections import defaultdict
from io import BytesIO

import timm
import tlc
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import tqdm.notebook as tqdm
from PIL import Image
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler

## Set device

In [17]:
from tlc_tools.common import infer_torch_device

# Select the torch device used for training and report which one was picked.
device = infer_torch_device()
print("Using device:", device)

Using device: cuda


## Load Input Table

In [18]:
# Re-use the TableFromCoco we created in 1-create-tables/.
# Look it up via the configured names instead of repeating the literals, so that
# changing DATASET_NAME / PROJECT_NAME in the setup cell actually takes effect.
input_table = tlc.Table.from_names("initial", DATASET_NAME, PROJECT_NAME)

In [19]:
# Schema describing a single bounding box in the "bbs.bb_list" column
bb_schema = (
    input_table.schema
    .values["rows"]
    .values["bbs"]
    .values["bb_list"]
)

# Label id -> class name mapping for the bounding box labels
label_map = input_table.get_simple_value_map("bbs.bb_list.label")
print(f"Input table uses {len(label_map)} unique labels: {json.dumps(label_map, indent=2)}")

Input table uses 80 unique labels: {
  "0": "person",
  "1": "bicycle",
  "2": "car",
  "3": "motorcycle",
  "4": "airplane",
  "5": "bus",
  "6": "train",
  "7": "truck",
  "8": "boat",
  "9": "traffic light",
  "10": "fire hydrant",
  "11": "stop sign",
  "12": "parking meter",
  "13": "bench",
  "14": "bird",
  "15": "cat",
  "16": "dog",
  "17": "horse",
  "18": "sheep",
  "19": "cow",
  "20": "elephant",
  "21": "bear",
  "22": "zebra",
  "23": "giraffe",
  "24": "backpack",
  "25": "umbrella",
  "26": "handbag",
  "27": "tie",
  "28": "suitcase",
  "29": "frisbee",
  "30": "skis",
  "31": "snowboard",
  "32": "sports ball",
  "33": "kite",
  "34": "baseball bat",
  "35": "baseball glove",
  "36": "skateboard",
  "37": "surfboard",
  "38": "tennis racket",
  "39": "bottle",
  "40": "wine glass",
  "41": "cup",
  "42": "fork",
  "43": "knife",
  "44": "spoon",
  "45": "bowl",
  "46": "banana",
  "47": "apple",
  "48": "sandwich",
  "49": "orange",
  "50": "broccoli",
  "51": "carro

In [20]:
# Weight every image by how many boxes it contains: crowded images are drawn more often,
# and images without any boxes (weight 0) are never drawn by the training sampler.
num_bbs_per_image = [len(entry["bbs"]["bb_list"]) for entry in input_table.table_rows]
sampler = WeightedRandomSampler(
    num_samples=IMAGES_PER_EPOCH,
    weights=num_bbs_per_image,
)

## Create Dataset

In [None]:
import random
from collections import defaultdict
from io import BytesIO
from PIL import Image
import torch
from torch.utils.data import Dataset


class BBCropDataset(Dataset):
    """Dataset turning table rows into single-object image crops for classification.

    Each sample is normally the crop of one randomly chosen bounding box of a row, labelled
    with that box's ``label`` value. When ``add_background`` is enabled and ``is_train`` is
    True, a fraction ``background_freq`` of the samples are instead random patches that
    overlap no ground-truth box, labelled with the extra class id ``len(label_map)``.
    """

    class BackgroundNotFoundError(RuntimeError):
        """Raised when no box-free background patch could be found in an image."""

    def __init__(
        self,
        table: "tlc.Table",
        transform=None,
        label_map=None,
        add_background=False,
        background_freq=0.5,
        is_train=True,
        max_background_attempts=1000,
    ):
        """
        Custom dataset for cropping bounding boxes and generating background patches.

        Args:
            table: tlc.Table, the input table containing image and bounding box data.
            transform: callable, transformation applied to every crop / background patch.
            label_map: dict, value map of the ``bbs.bb_list.label`` column. Defaults to the
                table's own value map. It is only used to size the label space (the background
                class id is ``len(label_map)``); box labels are NOT remapped through it.
            add_background: bool, whether to include background patches.
            background_freq: float, probability of sampling a background patch.
            is_train: bool, whether the dataset is used for training. Background patches are
                only generated when this is True.
            max_background_attempts: int, number of random proposals tried per image before
                giving up on finding a patch that avoids every ground-truth box.
        """
        self.table = table
        self.transform = transform
        self.label_map = label_map or table.get_value_map("bbs.bb_list.label")
        self.bb_schema = table.schema.values["rows"].values["bbs"].values["bb_list"]
        self.add_background = add_background
        self.background_freq = background_freq
        self.is_train = is_train
        self.max_background_attempts = max_background_attempts
        self.background_label = len(self.label_map) if add_background else None
        # Fixed seed so row/box/background selection is reproducible across runs.
        # NOTE(review): with DataLoader num_workers > 0, each worker gets a copy of this
        # generator in the same state, so workers would draw identical random sequences.
        self.random_gen = random.Random(42)

    def __len__(self):
        return len(self.table)  # Dataset length tied to the number of table rows

    def __getitem__(self, idx):
        """
        Fetch a sample from the dataset.

        Args:
            idx: int, row index provided by the sampler.

        Returns:
            tuple: (crop, label) where crop is the (transformed) patch and label is a scalar
            ``torch.long`` tensor.
        """
        # Background patches are only produced for training datasets that opted in.
        is_background = (
            self.is_train and self.add_background and self.random_gen.random() < self.background_freq
        )

        # Background patches come from a random row; box crops use the requested row.
        row_idx = self.random_gen.randint(0, len(self.table) - 1) if is_background else idx
        row = self.table.table_rows[row_idx]
        bbs = row["bbs"]["bb_list"]

        # A box crop needs at least one box: resample rows until one has boxes. The image is
        # decoded only once, after a suitable row has been found.
        # NOTE(review): this loops forever if no row of the table contains any box.
        while not bbs and not is_background:
            row_idx = self.random_gen.randint(0, len(self.table) - 1)
            row = self.table.table_rows[row_idx]
            bbs = row["bbs"]["bb_list"]

        image = self.load_image_data(row)

        if is_background:
            try:
                crop, label = self.generate_background(image, bbs)
            except self.BackgroundNotFoundError:
                # Image too crowded for a box-free patch. Failure implies the row has boxes
                # (with no boxes the first proposal always succeeds), so a box crop is possible.
                crop, label = self.generate_bb_crop(image, bbs)
        else:
            crop, label = self.generate_bb_crop(image, bbs)

        if self.transform:
            crop = self.transform(crop)

        return crop, label

    def load_image_data(self, row):
        """Read the row's image URL through 3LC and open it as a PIL image."""
        image_bytes = tlc.Url(row["image"]).read()
        image = Image.open(BytesIO(image_bytes))
        return image

    def generate_bb_crop(self, image, bbs):
        """
        Crop a randomly chosen bounding box from the image.

        Args:
            image: PIL.Image, the input image.
            bbs: list, bounding boxes associated with the image.

        Returns:
            tuple: (cropped image, label) where label is a tensor.

        Raises:
            ValueError: if ``bbs`` is empty.
        """
        if not bbs:
            raise ValueError("No bounding boxes found. Check your sampler.")

        # Use the dataset's seeded generator (not the global `random` module) so the choice
        # of box is reproducible together with the rest of the sampling.
        random_bb = self.random_gen.choice(bbs)
        crop = tlc.BBCropInterface.crop(image, random_bb, self.bb_schema)
        return crop, torch.tensor(random_bb["label"], dtype=torch.long)

    def generate_background(self, image, bbs):
        """
        Generate a background patch from the image that overlaps none of its boxes.

        Args:
            image: PIL.Image, the input image.
            bbs: list, bounding boxes associated with the image.

        Returns:
            tuple: (background patch, background label) where label is a tensor.

        Raises:
            BBCropDataset.BackgroundNotFoundError: if no non-overlapping proposal was found
                within ``max_background_attempts`` tries.
        """
        image_width, image_height = image.size
        bb_factory = tlc.BoundingBox.from_schema(self.bb_schema)
        # Ground-truth boxes converted to pixel [x, y, w, h] so they compare with proposals.
        gt_boxes_xywh = [
            bb_factory([bb["x0"], bb["y0"], bb["x1"], bb["y1"]])
            .to_top_left_xywh()
            .denormalize(image_width, image_height)
            for bb in bbs
        ]

        for _ in range(self.max_background_attempts):
            x, y, w, h = self._sample_proposal_box(image_width, image_height)
            proposal_box = [x, y, w, h]
            # Accept the first proposal that does not intersect any ground-truth box
            if not any(self._intersects(proposal_box, gt_box) for gt_box in gt_boxes_xywh):
                background_patch = image.crop((x, y, x + w, y + h))
                return background_patch, torch.tensor(self.background_label, dtype=torch.long)

        raise self.BackgroundNotFoundError(
            f"No background patch found after {self.max_background_attempts} attempts"
        )

    def _sample_proposal_box(self, image_width, image_height):
        """Draw a random pixel box [x, y, w, h] that lies entirely inside the image.

        The top-left corner is drawn from normal distributions centred on the image centre
        and the size from normal distributions around 1/8 of each image dimension. All values
        are clamped so that 0 <= x < width, 0 <= y < height, w, h >= 1 and the box does not
        extend past the right/bottom edge.
        """
        x = max(
            min(int(self.random_gen.normalvariate(mu=image_width // 2, sigma=image_width // 6)), image_width - 1), 0
        )
        y = max(
            min(
                int(self.random_gen.normalvariate(mu=image_height // 2, sigma=image_height // 6)),
                image_height - 1,
            ),
            0,
        )
        w = max(
            min(
                int(self.random_gen.normalvariate(mu=image_width // 8, sigma=image_width // 16)), image_width - x
            ),
            1,
        )
        h = max(
            min(
                int(self.random_gen.normalvariate(mu=image_height // 8, sigma=image_height // 16)), image_height - y
            ),
            1,
        )
        return [x, y, w, h]

    @staticmethod
    def _intersects(box1, box2):
        """
        Check if two bounding boxes intersect.

        Args:
            box1: list[int], first bounding box [x, y, w, h].
            box2: list[int], second bounding box [x, y, w, h].

        Returns:
            bool: True if boxes intersect, otherwise False (touching edges do not count).
        """
        x1, y1, w1, h1 = box1
        x2, y2, w2, h2 = box2
        return not (x1 + w1 <= x2 or x2 + w2 <= x1 or y1 + h1 <= y2 or y2 + h2 <= y1)


In [22]:
# Define the transformations to be applied to the images.
# Every crop is converted to 3-channel RGB first. In the training pipeline this happens
# *before* the colour augmentations, so ColorJitter never receives grayscale, palette or
# RGBA crops.
to_rgb = transforms.Lambda(lambda img: img.convert("RGB"))
to_normalized_tensor = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Standard ImageNet channel mean/std
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]

common_transforms = transforms.Compose([to_rgb, *to_normalized_tensor])

train_transforms = transforms.Compose([
    to_rgb,
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.3),
    transforms.RandomRotation(degrees=10),
    transforms.RandomHorizontalFlip(),
    *to_normalized_tensor,
])

# Build the training and validation datasets (the dataloaders are created in the training
# cell). Both wrap the same input table and differ only in transforms and `is_train`.
# NOTE(review): validation therefore reuses the training images, so val metrics are not
# held-out metrics.
train_dataset = BBCropDataset(input_table, is_train=True, transform=train_transforms)
val_dataset = BBCropDataset(input_table, is_train=False, transform=common_transforms)

In [23]:
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from typing import List
import torch

def save_image_grid(images: List[torch.Tensor], labels: List[str], rows: int, cols: int, save_path: str):
    """
    Display a grid of images with their labels as subplot titles and save it to ``save_path``.

    Args:
        images (List[torch.Tensor]): Image tensors (C x H x W) normalized with mean
            (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225); they are un-normalized
            before display.
        labels (List[str]): Labels corresponding to each image.
        rows (int): Number of rows in the grid.
        cols (int): Number of columns in the grid.
        save_path (str): Path the figure is saved to. If it has no file extension,
            matplotlib appends the one of its default save format.

    Raises:
        ValueError: If images and labels differ in length, or the grid is too small.
    """
    # Explicit checks instead of `assert`, which is stripped when Python runs with -O.
    if len(images) != len(labels):
        raise ValueError("Number of images and labels must be the same.")
    if len(images) > rows * cols:
        raise ValueError("Not enough space in the grid for all images.")

    # Inverse of Normalize(mean, std): ((x - (-mean/std)) / (1/std)) == x * std + mean
    unnormalize = transforms.Normalize(
        mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
        std=[1 / 0.229, 1 / 0.224, 1 / 0.225],
    )
    to_pil = transforms.ToPILImage()

    # squeeze=False always returns a 2-D array of Axes, so flattening also works for 1x1/1xN grids.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4), squeeze=False)
    axes = axes.flatten()

    for ax, image, label in zip(axes, images, labels):
        # Clamp rounding error so ToPILImage's float -> uint8 conversion cannot wrap around.
        pil_image = to_pil(unnormalize(image).clamp(0, 1))
        ax.imshow(pil_image)
        ax.set_title(label)
        ax.axis("off")

    # Hide unused subplots
    for ax in axes[len(images):]:
        ax.axis("off")

    fig.tight_layout()
    # Save before showing: some backends clear the figure once it has been shown.
    fig.savefig(save_path)
    plt.show()


In [24]:
# Preview a 4 x 3 grid of augmented training crops together with their class names.
grid_rows, grid_cols = 4, 3
train_images, train_labels = [], []
for sample_idx in range(grid_rows * grid_cols):
    crop, target = train_dataset[sample_idx]
    train_images.append(crop)
    train_labels.append(label_map[target.item()])

save_image_grid(train_images, train_labels, grid_rows, grid_cols, "Training Images")

UnboundLocalError: cannot access local variable 'image_bytes' where it is not associated with a value

In [None]:
# Preview a 4 x 3 grid of validation crops (no augmentation) together with their class names.
grid_rows, grid_cols = 4, 3
val_images, val_labels = [], []
for sample_idx in range(grid_rows * grid_cols):
    crop, target = val_dataset[sample_idx]
    val_images.append(crop)
    val_labels.append(label_map[target.item()])

save_image_grid(val_images, val_labels, grid_rows, grid_cols, "Validation Images")

## Train Model

In [None]:
# Create the dataloaders: training draws IMAGES_PER_EPOCH samples per epoch via the
# box-count-weighted `sampler`; validation visits every table row once, in order.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Load an EfficientNet model using timm: one output per label plus one extra output for the
# background class (BBCropDataset uses len(label_map) as the background label id).
model = timm.create_model("efficientnet_b0", pretrained=True, num_classes=len(label_map) + 1).to(device)

# Training setup
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# Multiplies the learning rate by 0.9516 each time scheduler.step() is called (once per epoch).
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9516)


def _run_epoch(model, dataloader, criterion, optimizer=None, desc=""):
    """Run one pass over ``dataloader`` and return ``(mean_loss, accuracy)``.

    With an ``optimizer`` the model is trained (train mode, gradients, parameter updates);
    without one it is evaluated in eval mode with gradient tracking disabled.

    Raises:
        ValueError: if the dataloader yields no samples (the means would be undefined).
    """
    is_train = optimizer is not None
    model.train(is_train)
    total_loss, total_correct, total_seen = 0.0, 0, 0
    with torch.set_grad_enabled(is_train):
        for inputs, labels in tqdm.tqdm(dataloader, desc=desc):
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # The criterion returns the batch mean; re-weight by batch size so the epoch value
            # is a per-sample mean even when the last batch is smaller.
            total_loss += loss.item() * inputs.size(0)
            _, preds = outputs.max(1)
            total_correct += (preds == labels).sum().item()
            total_seen += labels.size(0)

    if total_seen == 0:
        raise ValueError(f"{desc}: dataloader produced no samples")
    return total_loss / total_seen, total_correct / total_seen


# Training loop
for epoch in range(EPOCHS):
    train_loss, train_acc = _run_epoch(
        model, train_dataloader, criterion, optimizer, desc=f"Epoch {epoch+1} [Train]"
    )
    val_loss, val_acc = _run_epoch(model, val_dataloader, criterion, desc=f"Epoch {epoch+1} [Val]")

    # Decay the learning rate once per epoch; without this call the schedule has no effect.
    scheduler.step()

    print(f"Epoch {epoch+1}/{EPOCHS}")
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

In [None]:
# Save the fine-tuned weights (state_dict only) to a .pth file. The output directory is
# created first so torch.save does not fail when it does not exist yet.
os.makedirs(TRANSIENT_DATA_PATH, exist_ok=True)
torch.save(model.state_dict(), os.path.join(TRANSIENT_DATA_PATH, "bb_classifier.pth"))