# Requirements

In [1]:
# JAX, Flax, and Optax for the model and training
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax

# Data loading and image processing
import numpy as np
from datasets import load_dataset
from PIL import Image
import io

# Dataset loading

In [2]:
TARGET_SIZE = (300, 300)
CAR_CATEGORY_ID = 2
BATCH_SIZE = 16

In [3]:
def contains_car(example):
    """Filter function to check if an image contains a car."""
    return any(category == CAR_CATEGORY_ID for category in example['objects']['category'])

def preprocess_with_padding(example, target_size=(300, 300), bg_color=(128, 128, 128)):
    """
    Resizes and pads an image to a target square size while maintaining aspect ratio,
    and correctly adjusts the bounding box coordinates.
    """
    # --- THIS IS THE CORRECTED LINE ---
    # We directly use the decoded Pillow image object from the dataset.
    image = example['image'].convert("RGB")

    original_w, original_h = image.size

    # Calculate scaling ratio and new dimensions
    ratio = min(target_size[0] / original_w, target_size[1] / original_h)
    new_w, new_h = int(original_w * ratio), int(original_h * ratio)
    resized_image = image.resize((new_w, new_h), Image.Resampling.LANCZOS)

    # Create a padded canvas and paste the resized image
    padded_image = Image.new("RGB", target_size, bg_color)
    paste_x = (target_size[0] - new_w) // 2
    paste_y = (target_size[1] - new_h) // 2
    padded_image.paste(resized_image, (paste_x, paste_y))

    # Convert to a normalized NumPy array
    image_np = np.array(padded_image, dtype=np.float32) / 255.0

    # Adjust bounding boxes by scaling and adding the padding offset
    new_boxes, new_labels = [], []
    for bbox, category in zip(example['objects']['bbox'], example['objects']['category']):
        x_min, y_min, w, h = bbox
        new_boxes.append([(x_min*ratio)+paste_x, (y_min*ratio)+paste_y, w*ratio, h*ratio])
        new_labels.append(category)

    return {
        'image': image_np,
        'bboxes': np.array(new_boxes, dtype=np.float32),
        'labels': np.array(new_labels, dtype=np.int32)
    }

def collate_fn(batch):
    """Pads bboxes and labels to the max length in a batch to create uniform tensors."""
    max_objects = max(len(item['labels']) for item in batch)
    if max_objects == 0: max_objects = 1  # Avoid shape errors for images with no objects

    padded_batch = {'image': [], 'bboxes': [], 'labels': []}
    for item in batch:
        num_objects = len(item['labels'])
        padding_needed = max_objects - num_objects

        # Use -1 as the padding value for bboxes and labels
        padded_bboxes = np.pad(item['bboxes'], ((0, padding_needed), (0, 0)), mode='constant', constant_values=-1)
        padded_labels = np.pad(item['labels'], (0, padding_needed), mode='constant', constant_values=-1)

        padded_batch['image'].append(item['image'])
        padded_batch['bboxes'].append(padded_bboxes)
        padded_batch['labels'].append(padded_labels)

    # Stack individual examples and convert to JAX arrays
    return {k: jnp.array(np.stack(v, axis=0)) for k, v in padded_batch.items()}

def create_data_loader(dataset, batch_size, shuffle=True):
    """A generator function that yields batches of data."""
    if shuffle:
        # buffer_size controls how many elements are loaded into memory for shuffling
        dataset = dataset.shuffle(buffer_size=1000)

    batch = []
    for example in dataset:
        batch.append(example)
        if len(batch) == batch_size:
            yield collate_fn(batch)
            batch = []

    if batch: # Yield the last, potentially smaller batch
        yield collate_fn(batch)

In [4]:
# Load the raw dataset (streaming to save disk space)
raw_ds = load_dataset("detection-datasets/coco", split="train", streaming=True)

# Apply filtering and preprocessing
car_ds = raw_ds.filter(contains_car)
processed_ds = car_ds.map(preprocess_with_padding)

# Create the final data loader
# NOTE: .take(1000) is used for a quick test. Remove it for full training.
train_loader = create_data_loader(processed_ds.take(1000), BATCH_SIZE)

# --- 4. Test the Loader ---
print("✅ Data loader created successfully!")
print("Testing the first batch...")

for first_batch in train_loader:
    print(f"Image batch shape: {first_batch['image'].shape}")
    print(f"BBoxes batch shape: {first_batch['bboxes'].shape}")
    print(f"Labels batch shape: {first_batch['labels'].shape}")
    break

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/58.0 [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/40 [00:00<?, ?it/s]

dataset_infos.json: 0.00B [00:00, ?B/s]

✅ Data loader created successfully!
Testing the first batch...
Image batch shape: (16, 300, 300, 3)
BBoxes batch shape: (16, 25, 4)
Labels batch shape: (16, 25)


# Pre-Training

In [5]:
class L2Norm(nn.Module):
    """L2 Normalization layer with a learnable scale parameter."""
    n_channels: int
    initial_scale: float = 20.0

    @nn.compact
    def __call__(self, x):
        # Normalize along the channel dimension (axis=-1)
        norm = jnp.sqrt(jnp.sum(x**2, axis=-1, keepdims=True))
        x_norm = x / (norm + 1e-10)

        # Add a learnable scaling factor for each channel
        scale = self.param('scale', nn.initializers.constant(self.initial_scale), (self.n_channels,))
        # Reshape scale for broadcasting: (C,) -> (1, 1, 1, C)
        scale = scale.reshape((1,) * (x.ndim - 1) + (-1,))

        return x_norm * scale

class SSD(nn.Module):
    """
    A standard SSD300 model with a VGG16-based backbone.
    """
    num_classes: int
    num_anchors_per_location: tuple

    @nn.compact
    def __call__(self, x):
        total_classes = self.num_classes + 1
        feature_maps = []

        # --- 1. VGG16 Backbone ---
        # Block 1
        x = nn.Conv(features=64, kernel_size=(3, 3), padding='SAME', name='conv1_1')(x); x = nn.relu(x)
        x = nn.Conv(features=64, kernel_size=(3, 3), padding='SAME', name='conv1_2')(x); x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Block 2
        x = nn.Conv(features=128, kernel_size=(3, 3), padding='SAME', name='conv2_1')(x); x = nn.relu(x)
        x = nn.Conv(features=128, kernel_size=(3, 3), padding='SAME', name='conv2_2')(x); x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Block 3
        x = nn.Conv(features=256, kernel_size=(3, 3), padding='SAME', name='conv3_1')(x); x = nn.relu(x)
        x = nn.Conv(features=256, kernel_size=(3, 3), padding='SAME', name='conv3_2')(x); x = nn.relu(x)
        x = nn.Conv(features=256, kernel_size=(3, 3), padding='SAME', name='conv3_3')(x); x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Block 4 -> Source for first feature map (Conv4_3)
        x = nn.Conv(features=512, kernel_size=(3, 3), padding='SAME', name='conv4_1')(x); x = nn.relu(x)
        x = nn.Conv(features=512, kernel_size=(3, 3), padding='SAME', name='conv4_2')(x); x = nn.relu(x)
        x = nn.Conv(features=512, kernel_size=(3, 3), padding='SAME', name='conv4_3')(x); x = nn.relu(x)

        # >> Add first feature map (38x38) with L2 normalization <<
        fm1 = L2Norm(n_channels=512, name='conv4_3_norm')(x)
        feature_maps.append(fm1)

        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Block 5
        x = nn.Conv(features=512, kernel_size=(3, 3), padding='SAME', name='conv5_1')(x); x = nn.relu(x)
        x = nn.Conv(features=512, kernel_size=(3, 3), padding='SAME', name='conv5_2')(x); x = nn.relu(x)
        x = nn.Conv(features=512, kernel_size=(3, 3), padding='SAME', name='conv5_3')(x); x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(3, 3), strides=(1, 1), padding='SAME') # Modified pooling

        # Converted FC layers of VGG -> Source for second feature map
        x = nn.Conv(features=1024, kernel_size=(3, 3), padding='SAME', name='conv6')(x); x = nn.relu(x)
        x = nn.Conv(features=1024, kernel_size=(1, 1), padding='SAME', name='conv7')(x); x = nn.relu(x)

        # >> Add second feature map (19x19) <<
        feature_maps.append(x)

        # --- 2. Extra Feature Layers ---
        # Extra Layer 1 -> produces 10x10 feature map
        x = nn.Conv(features=256, kernel_size=(1, 1), padding='SAME', name='extra1_1')(x); x = nn.relu(x)
        x = nn.Conv(features=512, kernel_size=(3, 3), strides=(2, 2), padding='SAME', name='extra1_2')(x); x = nn.relu(x)
        feature_maps.append(x)

        # Extra Layer 2 -> produces 5x5 feature map
        x = nn.Conv(features=128, kernel_size=(1, 1), padding='SAME', name='extra2_1')(x); x = nn.relu(x)
        x = nn.Conv(features=256, kernel_size=(3, 3), strides=(2, 2), padding='SAME', name='extra2_2')(x); x = nn.relu(x)
        feature_maps.append(x)

        # Extra Layer 3 -> produces 3x3 feature map
        x = nn.Conv(features=128, kernel_size=(1, 1), padding='SAME', name='extra3_1')(x); x = nn.relu(x)
        x = nn.Conv(features=256, kernel_size=(3, 3), padding='VALID', name='extra3_2')(x); x = nn.relu(x)
        feature_maps.append(x)

        # Extra Layer 4 -> produces 1x1 feature map
        x = nn.Conv(features=128, kernel_size=(1, 1), padding='SAME', name='extra4_1')(x); x = nn.relu(x)
        x = nn.Conv(features=256, kernel_size=(3, 3), padding='VALID', name='extra4_2')(x); x = nn.relu(x)
        feature_maps.append(x)

        # --- 3. Prediction Heads ---
        loc_preds, conf_preds = [], []
        for i, fm in enumerate(feature_maps):
            num_anchors = self.num_anchors_per_location[i]
            # Location head
            loc_pred = nn.Conv(features=num_anchors * 4, kernel_size=(3, 3), padding='SAME', name=f'loc_head_{i}')(fm)
            loc_preds.append(loc_pred.reshape((loc_pred.shape[0], -1, 4)))
            # Confidence head
            conf_pred = nn.Conv(features=num_anchors * total_classes, kernel_size=(3, 3), padding='SAME', name=f'conf_head_{i}')(fm)
            conf_preds.append(conf_pred.reshape((conf_pred.shape[0], -1, total_classes)))

        return jnp.concatenate(loc_preds, axis=1), jnp.concatenate(conf_preds, axis=1)

In [6]:
# We are still only detecting cars
NUM_CLASSES = 1

# THIS IS THE STANDARD CONFIGURATION for an SSD300 with 6 feature maps
ANCHORS_PER_LOCATION = (4, 6, 6, 6, 4, 4)

# 1. Create a PRNG key for initialization
key = jax.random.PRNGKey(0)

# 2. Instantiate the new SSD300 model
model = SSD(num_classes=NUM_CLASSES, num_anchors_per_location=ANCHORS_PER_LOCATION)

# 3. Create a dummy input batch
dummy_input_batch = jnp.ones((BATCH_SIZE, TARGET_SIZE[0], TARGET_SIZE[1], 3))

# 4. Initialize the model's parameters
# Note: This will take a moment longer as the model is much larger
params = model.init(key, dummy_input_batch)['params']
print("✅ SSD300 model initialized successfully!")

# 5. Apply the model to get output shapes
loc_predictions, conf_predictions = model.apply({'params': params}, dummy_input_batch)

print("\n--- Output Shapes ---")
print(f"Location predictions shape: {loc_predictions.shape}")
print(f"Confidence predictions shape: {conf_predictions.shape}")

✅ SSD300 model initialized successfully!

--- Output Shapes ---
Location predictions shape: (16, 8096, 4)
Confidence predictions shape: (16, 8096, 2)


In [7]:
import math

def generate_default_boxes():
    """Generates the 8732 default anchor boxes for the SSD300 architecture."""
    # Configuration from the original SSD paper
    feature_map_sizes = [38, 19, 10, 5, 3, 1]
    min_sizes = [30, 60, 111, 162, 213, 264]
    max_sizes = [60, 111, 162, 213, 264, 315]
    aspect_ratios = [[2], [2, 3], [2, 3], [2, 3], [2], [2]]

    all_default_boxes = []

    for i in range(len(feature_map_sizes)):
        fm_size = feature_map_sizes[i]

        # Create a grid of center points
        # The centers are normalized to be between 0 and 1
        x_centers = (jnp.arange(fm_size) + 0.5) / fm_size
        y_centers = (jnp.arange(fm_size) + 0.5) / fm_size

        for y in y_centers:
            for x in x_centers:
                # Box 1: Small square box
                s_k = min_sizes[i] / TARGET_SIZE[0]
                all_default_boxes.append([x, y, s_k, s_k])

                # Box 2: Large square box
                s_k_prime = math.sqrt(s_k * (max_sizes[i] / TARGET_SIZE[0]))
                all_default_boxes.append([x, y, s_k_prime, s_k_prime])

                # Additional boxes based on aspect ratios
                for ar in aspect_ratios[i]:
                    all_default_boxes.append([x, y, s_k * math.sqrt(ar), s_k / math.sqrt(ar)])
                    all_default_boxes.append([x, y, s_k / math.sqrt(ar), s_k * math.sqrt(ar)])

    # The boxes are currently in [center_x, center_y, width, height] format
    default_boxes_tensor = jnp.array(all_default_boxes, dtype=jnp.float32)

    # Clip box coordinates to be within [0, 1]
    default_boxes_tensor = jnp.clip(default_boxes_tensor, 0.0, 1.0)

    return default_boxes_tensor

# Generate the boxes and inspect the shape
default_boxes = generate_default_boxes()
print(f"✅ Generated default boxes successfully!")
print(f"Shape of default boxes tensor: {default_boxes.shape}")

✅ Generated default boxes successfully!
Shape of default boxes tensor: (8732, 4)


In [8]:
def box_hw_to_corners(boxes):
    """Converts boxes from [x_min, y_min, w, h] to [x_min, y_min, x_max, y_max]."""
    return jnp.concatenate([boxes[..., :2], boxes[..., :2] + boxes[..., 2:]], axis=-1)

def box_corners_to_hw(boxes):
    """Converts boxes from [x_min, y_min, x_max, y_max] to [x_min, y_min, w, h]."""
    return jnp.concatenate([boxes[..., :2], boxes[..., 2:] - boxes[..., :2]], axis=-1)

def box_center_to_corners(boxes):
    """Converts boxes from [cx, cy, w, h] to [x_min, y_min, x_max, y_max]."""
    return jnp.concatenate([boxes[..., :2] - boxes[..., 2:] / 2,
                           boxes[..., :2] + boxes[..., 2:] / 2], axis=-1)


def jaccard_overlap(boxes1, boxes2):
    """Calculates intersection over union for batch of boxes.
    Args:
        boxes1: (N, 4) in corner format [xmin, ymin, xmax, ymax]
        boxes2: (M, 4) in corner format [xmin, ymin, xmax, ymax]
    Returns:
        iou: (N, M) matrix of IoU values.
    """
    # Find intersection corners
    xy_max = jnp.minimum(boxes1[:, None, 2:], boxes2[None, :, 2:])
    xy_min = jnp.maximum(boxes1[:, None, :2], boxes2[None, :, :2])

    # Clip to ensure intersection is never negative
    inter_dims = jnp.clip(xy_max - xy_min, a_min=0)
    inter_area = inter_dims[..., 0] * inter_dims[..., 1]

    # Calculate individual box areas
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])

    # Calculate union
    union_area = area1[:, None] + area2[None, :] - inter_area

    return inter_area / (union_area + 1e-10)

In [21]:
def multibox_loss(loc_preds, conf_preds, gt_boxes, gt_labels, default_boxes, neg_pos_ratio=3):
    """
    Calculates the final, numerically stable MultiBox loss.
    """
    batch_size = gt_labels.shape[0]
    num_default_boxes = default_boxes.shape[0]
    epsilon = 1e-10

    default_boxes_corners = box_center_to_corners(default_boxes)
    true_locs = jnp.zeros_like(loc_preds)
    true_confs = jnp.zeros((batch_size, num_default_boxes), dtype=jnp.int32)

    for i in range(batch_size):
        mask = gt_labels[i] > -1
        image_gt_boxes, image_gt_labels = gt_boxes[i][mask], gt_labels[i][mask]
        if image_gt_boxes.shape[0] == 0: continue

        image_gt_boxes_corners = box_hw_to_corners(image_gt_boxes / TARGET_SIZE[0])
        overlaps = jaccard_overlap(default_boxes_corners, image_gt_boxes_corners)

        best_gt_overlap, best_gt_idx = overlaps.max(axis=1), overlaps.argmax(axis=1)
        pos_mask = best_gt_overlap > 0.5

        # --- THIS IS THE FIX ---
        # For any positive match, we set the target class to 1 (our 'car' class).
        # For any negative match (background), we set the target class to 0.
        conf = jnp.where(pos_mask, 1, 0)
        # ---

        matched_gt_boxes = image_gt_boxes[best_gt_idx]
        cx, cy = matched_gt_boxes[:, 0] + matched_gt_boxes[:, 2] / 2, matched_gt_boxes[:, 1] + matched_gt_boxes[:, 3] / 2
        matched_gt_boxes_center = jnp.stack([cx, cy, matched_gt_boxes[:, 2], matched_gt_boxes[:, 3]], axis=-1)
        matched_gt_boxes_center /= TARGET_SIZE[0]

        offset_cx = (matched_gt_boxes_center[:, 0] - default_boxes[:, 0]) / (default_boxes[:, 2] + epsilon)
        offset_cy = (matched_gt_boxes_center[:, 1] - default_boxes[:, 1]) / (default_boxes[:, 3] + epsilon)
        offset_w = jnp.log((matched_gt_boxes_center[:, 2] + epsilon) / (default_boxes[:, 2] + epsilon))
        offset_h = jnp.log((matched_gt_boxes_center[:, 3] + epsilon) / (default_boxes[:, 3] + epsilon))
        loc = jnp.stack([offset_cx, offset_cy, offset_w, offset_h], axis=-1)

        true_locs = true_locs.at[i].set(jnp.where(pos_mask[:, None], loc, 0))
        true_confs = true_confs.at[i].set(conf)

    pos_mask_batch = true_confs > 0
    loc_loss = optax.losses.huber_loss(loc_preds[pos_mask_batch], true_locs[pos_mask_batch], delta=1.0).sum()
    conf_loss_all = optax.losses.softmax_cross_entropy_with_integer_labels(conf_preds, true_confs)
    pos_conf_loss = conf_loss_all[pos_mask_batch].sum()
    neg_conf_loss = conf_loss_all.copy().at[pos_mask_batch].set(0)
    num_positives = pos_mask_batch.sum()
    num_negatives = num_positives * neg_pos_ratio
    neg_conf_loss_sorted = jnp.sort(neg_conf_loss.flatten())[::-1]
    hard_neg_loss = neg_conf_loss_sorted[:num_negatives.astype(int)].sum()

    total_loss = (loc_loss + pos_conf_loss + hard_neg_loss) / jnp.maximum(1, num_positives)

    return total_loss

# --- Test the loss function with dummy data ---
key, subkey = jax.random.split(key)
dummy_loc_preds = jax.random.normal(key, shape=(BATCH_SIZE, 8732, 4))
dummy_conf_preds = jax.random.normal(subkey, shape=(BATCH_SIZE, 8732, NUM_CLASSES + 1))
dummy_gt_batch = next(iter(train_loader))

loss_value = multibox_loss(
    dummy_loc_preds,
    dummy_conf_preds,
    dummy_gt_batch['bboxes'],
    dummy_gt_batch['labels'],
    default_boxes
)

print(f"✅ Loss function defined and tested successfully!")
print(f"Calculated a dummy loss value of: {loss_value:.4f}")

✅ Loss function defined and tested successfully!
Calculated a dummy loss value of: 12.8586


In [22]:
# We already have the 'params' variable from initializing our model in Cell 6
# We also have the 'model' object defined in Cell 5

# 1. Define a learning rate and optimizer
LEARNING_RATE = 1e-4
optimizer = optax.adam(learning_rate=LEARNING_RATE)

# 2. Create the Training State
# This simple dataclass holds all the state needed for training
training_state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optimizer,
)

print("✅ Training state created successfully!")

✅ Training state created successfully!


In [23]:
@jax.jit
def train_step(state, batch, default_boxes):
    """Performs a single training step on a batch of data."""

    # Define a loss function that takes the model's parameters
    def loss_fn(params):
        # Get model predictions
        loc_preds, conf_preds = state.apply_fn({'params': params}, batch['image'])

        # Calculate the loss
        loss = multibox_loss(
            loc_preds,
            conf_preds,
            batch['bboxes'],
            batch['labels'],
            default_boxes
        )
        return loss

    # Calculate the loss and the gradients of the loss with respect to the parameters
    loss_value, grads = jax.value_and_grad(loss_fn)(state.params)

    # Update the model's state by applying the gradients
    new_state = state.apply_gradients(grads=grads)

    # Return the new state and the loss for logging
    return new_state, loss_value

print("✅ JIT-compiled train_step function created successfully!")

✅ JIT-compiled train_step function created successfully!


In [None]:
from tqdm.notebook import tqdm

NUM_EPOCHS = 10 # Set to a small number for this demo

print("Starting training...")
for epoch in range(NUM_EPOCHS):

    # Create a new data loader for each epoch to get different shuffles
    # For a full run, remove the .take(1000)
    train_loader = create_data_loader(processed_ds.take(1000), BATCH_SIZE)

    # Use tqdm for a progress bar
    batch_losses = []
    pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{NUM_EPOCHS}")

    for batch in pbar:
        # Perform one training step
        training_state, loss = train_step(training_state, batch, default_boxes)

        # Store the loss
        batch_losses.append(loss)

        # Update progress bar description with the latest loss
        pbar.set_description(f"Epoch {epoch + 1}/{NUM_EPOCHS} | Loss: {loss:.4f}")

    # Calculate and print the average loss for the epoch
    avg_loss = jnp.mean(jnp.array(batch_losses))
    print(f"Epoch {epoch + 1} finished. Average Loss: {avg_loss:.4f}\n")

print("🎉 Training finished!")

Starting training...


Epoch 1/10: 0it [00:00, ?it/s]