In [5]:
# Understand the COCO dataset
from PIL import Image
import json
import os
from pathlib import Path

# NOTE(review): later cells resolve "data/..." relative to the current directory,
# not its parent — confirm which directory this notebook is run from.
project_root = Path.cwd().parent

with open(project_root / "data" / "coco" / "subset" / "annotations" / "instances_subset.json") as f:
    coco = json.load(f)     # parsed COCO file: "images", "annotations", "categories" lists

categories = coco["categories"]

for c in categories:
    print(f"Category ID: {c['id']} Name: {c['name']}")
# Category IDs run from 1 to 90 with gaps (e.g. 12 and 26 are unused), so the
# count printed below is smaller than the largest ID.
print(f"Number of Categories: {len(categories)}")

# Data structure
# Environment values are strings, so compare against string spellings (same rule
# as get_dataset_paths()); comparing to the int 1 was always False.
USE_FULL_DATASET = os.environ.get("USE_FULL_DATASET", "0").lower() in ("1", "true", "yes")
BASE_DIR = "data/coco/images/train2017" if USE_FULL_DATASET else "data/coco/subset/images"

sample_ann = coco["annotations"][2]
print(f"Segmentation: The coordinates for a polygon line of the detected object: {sample_ann['segmentation']} ")
print(f"Area: The area of the detected object: {sample_ann['area']}")
print(f"Annotation ID: {sample_ann['id']}")
print(f"Image ID: {sample_ann['image_id']}")
# One entry per annotated object, so this exceeds the number of images.
print(f"Number of annotations: {len(coco['annotations'])}")


Category ID: 1 Name: person
Category ID: 2 Name: bicycle
Category ID: 3 Name: car
Category ID: 4 Name: motorcycle
Category ID: 5 Name: airplane
Category ID: 6 Name: bus
Category ID: 7 Name: train
Category ID: 8 Name: truck
Category ID: 9 Name: boat
Category ID: 10 Name: traffic light
Category ID: 11 Name: fire hydrant
Category ID: 13 Name: stop sign
Category ID: 14 Name: parking meter
Category ID: 15 Name: bench
Category ID: 16 Name: bird
Category ID: 17 Name: cat
Category ID: 18 Name: dog
Category ID: 19 Name: horse
Category ID: 20 Name: sheep
Category ID: 21 Name: cow
Category ID: 22 Name: elephant
Category ID: 23 Name: bear
Category ID: 24 Name: zebra
Category ID: 25 Name: giraffe
Category ID: 27 Name: backpack
Category ID: 28 Name: umbrella
Category ID: 31 Name: handbag
Category ID: 32 Name: tie
Category ID: 33 Name: suitcase
Category ID: 34 Name: frisbee
Category ID: 35 Name: skis
Category ID: 36 Name: snowboard
Category ID: 37 Name: sports ball
Category ID: 38 Name: kite
Category

In [2]:
# Dataset

import os, json
from PIL import Image
import torch
from torch.utils.data import Dataset

class ImageDataset(Dataset):
    """
    COCO multi-label classification dataset built from an instances-style annotation file.

    Each sample is ``(image, target)``:
      * ``image`` is the RGB image after ``transform``; it is the raw
        ``PIL.Image`` when no transform is given.
      * ``target`` is a float32 multi-hot vector of length ``num_classes``.
        An image can contain several categories, so several entries may be 1.
        The target is meant for an independent per-class sigmoid + BCE loss,
        not a softmax.

    Only images that have at least one annotation are kept.
    """

    def __init__(self,
                 img_dir: str,
                 ann_file: str,
                 transform=None):
        """
        Args:
            img_dir: Directory containing the files named by ``images[].file_name``.
            ann_file: Path to a COCO annotation JSON with ``images``,
                ``annotations`` and ``categories`` lists.
            transform: Optional callable applied to each loaded PIL image.

        Raises:
            RuntimeError: If no image in the file has any annotation.
        """
        self.img_dir = img_dir
        self.transform = transform

        # parse COCO json
        with open(ann_file, "r") as f:
            coco = json.load(f)

        # COCO category IDs are sparse; map them to dense column indices
        # 0..num_classes-1 (sorted by ID) and back.
        self.cat_ids = sorted({c["id"] for c in coco["categories"]})
        self.idx2cat_id = dict(enumerate(self.cat_ids))
        self.cat_id2idx = {cid: i for i, cid in enumerate(self.cat_ids)}
        self.num_classes = len(self.cat_ids)

        # image_id -> category IDs of all its annotations (an ID repeats when the
        # image holds several instances of the same category)
        img_to_cats = {}
        for ann in coco["annotations"]:
            img_to_cats.setdefault(ann["image_id"], []).append(ann["category_id"])

        # (file_name, category IDs) for every image that has annotations
        self.samples = [
            (img["file_name"], img_to_cats[img["id"]])
            for img in coco["images"]
            if img["id"] in img_to_cats
        ]
        if not self.samples:
            raise RuntimeError("No annotated images found!")

    def __len__(self):
        """Number of annotated images."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, multi_hot_target)`` for sample ``idx``."""
        fname, cats = self.samples[idx]
        img_path = os.path.join(self.img_dir, fname)
        # The context manager closes the file handle; convert() returns an
        # independent in-memory copy with exactly 3 (RGB) channels.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")

        # preprocessing (resize, flip, tensor conversion, ...) when configured
        if self.transform is not None:
            img = self.transform(img)

        # multi-hot target: 1.0 for every category present in the image;
        # repeated category IDs simply set the same entry again
        target = torch.zeros(self.num_classes, dtype=torch.float32)
        for cid in cats:
            target[self.cat_id2idx[cid]] = 1.0
        return img, target

In [3]:
# Check ImageDataset
image_ds = ImageDataset(img_dir="data/coco/subset/images", ann_file="data/coco/subset/annotations/instances_subset.json")
# Load sample 33 once; indexing the dataset twice would decode the image twice.
sample_img, sample_target = image_ds[33]
print(f"Data length: {len(image_ds)}")
print(f"Sample Data: {sample_img}")
print(f"Ground Truth Category:\n {sample_target}")
# Translate the hot columns back to COCO category IDs for readability.
sample_cat_ids = [image_ds.idx2cat_id[i] for i in sample_target.nonzero().flatten().tolist()]
print(f"Ground Truth Category IDs: {sample_cat_ids}")

Data length: 990
Sample Data: <PIL.Image.Image image mode=RGB size=640x501 at 0x119A6ADC0>
Ground Truth Category:
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.])


In [4]:
# Define backbone of ResNet18 model

import torch.nn as nn  # Neural Network lib
import torchvision


class MyResNet18(nn.Module):
    """
    ResNet-18 classifier with a task-specific output layer.

    Wraps ``torchvision.models.resnet18`` and replaces its final fully connected
    layer with a freshly initialised ``nn.Linear(in_features, num_classes)``.
    Only the backbone can come from pretrained weights. For the COCO multi-label
    task the logits are meant to go through a per-class sigmoid
    (e.g. ``nn.BCEWithLogitsLoss``).
    """

    def __init__(self, num_classes=1000, weights=None):
        """
        Args:
            num_classes (int):
                Number of outputs of the final linear layer, e.g. 1000 for
                ImageNet or ``dataset.num_classes`` for COCO. This COCO data has
                80 categories, whose IDs span 1-90 with gaps.
            weights:
                Forwarded to ``torchvision.models.resnet18``.
                ``None`` trains from scratch. A ``ResNet18_Weights`` member or its
                name (e.g. ``"DEFAULT"``, ``"IMAGENET1K_V1"``) loads
                ImageNet-pretrained backbone weights. For convenience, ``True``
                is treated as ``"DEFAULT"`` and ``False`` as ``None``.
        """
        super().__init__()
        # torchvision expects an enum member, a string or None (not a bool), so
        # translate a legacy "pretrained"-style flag before forwarding it.
        if weights is True:
            weights = "DEFAULT"
        elif weights is False:
            weights = None
        # load ResNet18 architecture (randomly initialised when weights is None)
        self.model = torchvision.models.resnet18(weights=weights)
        # replace the ImageNet head with one sized for this task
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """
        Forward pass through the network.

        Args:
            x (torch.Tensor): Input tensor of shape [B, 3, H, W],
                              where B = batch size.

        Returns:
            torch.Tensor: Output logits of shape [B, num_classes].
        """
        return self.model(x)


In [5]:
# Transform, image preprocessing
from torchvision import transforms

transform = transforms.Compose(
    [
        # Fixed 480x480 input; the aspect ratio is not preserved.
        transforms.Resize((480, 480)),
        # Training-time augmentation: mirror left-right with probability 0.5.
        transforms.RandomHorizontalFlip(p=0.5),
        # PIL image -> float tensor (C, H, W) with values scaled to [0, 1].
        transforms.ToTensor(),
        # Standardise each channel with the ImageNet mean / std.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


In [25]:
#!/usr/bin/env python3
"""
Training related functions definitions
"""

import os
import logging
from datetime import datetime

import mlflow
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import roc_curve, auc, average_precision_score
from torch.utils.data import DataLoader


# Logging: INFO and above, one timestamped line per record; module-level logger.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)


def get_dataset_paths():
    """
    Resolve the COCO image folder and annotation file for this run.

    The ``USE_FULL_DATASET`` environment variable selects the full train2017
    split when it is "1", "true" or "yes" (case-insensitive). Any other value,
    or leaving it unset, selects the small local subset.

    Returns:
        tuple: (image_directory_path, annotation_file_path), both absolute and
        rooted at the current working directory.
    """
    flag = os.environ.get("USE_FULL_DATASET", "0")
    coco_root = os.path.join(os.path.abspath("."), "data", "coco")

    if flag.lower() not in ("1", "true", "yes"):
        subset_root = os.path.join(coco_root, "subset")
        return (
            os.path.join(subset_root, "images"),
            os.path.join(subset_root, "annotations", "instances_subset.json"),
        )

    return (
        os.path.join(coco_root, "images", "train2017"),
        os.path.join(coco_root, "annotations", "instances_train2017.json"),
    )


def select_device():
    """
    Pick the compute device for training.

    Preference order is Apple MPS, then CUDA, then CPU; the choice is logged.

    Returns:
        torch.device: The selected device.
    """
    candidates = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    # First available accelerator wins; probing stops at the first hit.
    backend = next((name for name, available in candidates if available()), "cpu")

    device = torch.device(backend)
    logger.info(f"Using device: {device}")
    return device


def setup_mlflow(batch_size, learning_rate, num_epochs, device):
    """
    Point MLflow at the tracking store, start a run and record hyperparameters.

    The tracking URI comes from ``MLFLOW_TRACKING_URI`` (default
    ``file:./mlruns``); runs go to the ``ResNet18_COCO`` experiment and are
    named after the current timestamp.

    Args:
        batch_size (int): Training batch size
        learning_rate (float): Learning rate
        num_epochs (int): Number of training epochs
        device (torch.device): Training device

    Returns:
        The active run object returned by ``mlflow.start_run``.
    """
    tracking_uri = os.getenv("MLFLOW_TRACKING_URI", "file:./mlruns")
    mlflow.set_tracking_uri(tracking_uri)
    mlflow.set_experiment("ResNet18_COCO")

    run_name = datetime.now().strftime("%Y%m%d_%H%M%S")
    active_run = mlflow.start_run(run_name=run_name)

    params = dict(
        batch_size=batch_size,
        learning_rate=learning_rate,
        epochs=num_epochs,
        model="ResNet18",
        device=device.type,
    )
    mlflow.log_params(params)
    return active_run



def create_checkpoint_dir():
    """
    Create a timestamp-named folder under ``checkpoints/`` for this run.

    Returns:
        str: Relative path of the created (or already existing) directory.
    """
    run_dir = os.path.join("checkpoints", datetime.now().strftime("%Y%m%d_%H%M%S"))
    os.makedirs(run_dir, exist_ok=True)
    logger.info(f"Checkpoints will be saved to: {run_dir}")
    return run_dir


def save_checkpoint(model, optimizer, epoch, accuracy, save_dir):
    """
    Serialise the training state for one epoch into ``save_dir``.

    The file is named ``resnet18_epoch_<epoch + 1>.pth`` and stores the model
    and optimizer state dicts together with the given ``epoch`` and ``accuracy``
    values. A file with the same name is overwritten.

    Args:
        model (nn.Module): Model whose ``state_dict`` is stored
        optimizer (torch.optim.Optimizer): Optimizer whose ``state_dict`` is stored
        epoch (int): Epoch value to record (train() passes a 0-based index)
        accuracy (float): Accuracy value to record
        save_dir (str): Destination directory

    Returns:
        str: Path of the written checkpoint file.
    """
    payload = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        accuracy=accuracy,
    )
    target_path = os.path.join(save_dir, f"resnet18_epoch_{epoch + 1}.pth")
    torch.save(payload, target_path)
    logger.info(f"Checkpoint saved to {target_path}")
    return target_path


def track_metrics(metrics_dict, epoch, step=None, context=None):
    """
    Log a group of metrics to the active MLflow run in a single call.

    Args:
        metrics_dict (dict): Mapping of metric name to numeric value.
        epoch (int): Current epoch; used as the MLflow step when ``step`` is None.
        step (int, optional): Explicit MLflow step; takes precedence over ``epoch``.
        context (dict, optional): Accepted for backward compatibility; it is not
            recorded anywhere.
    """
    if not metrics_dict:
        return      # nothing to send to the tracking server
    mlflow.log_metrics(metrics_dict, step=epoch if step is None else step)


def calculate_metrics(all_targets, all_predictions):
    """
    Macro-averaged multi-label ranking metrics.

    Args:
        all_targets: Sequence of 0/1 target rows, or an already stacked
            (N, C) array; stacked with ``np.vstack``.
        all_predictions: Matching per-class scores (sigmoid outputs), same layout.

    Returns:
        dict: ``{"avg_precision": float, "roc_auc": float}``.
        * ``avg_precision`` averages per-class AP over classes with at least one positive.
        * ``roc_auc`` averages per-class AUC over classes with both positives and negatives.
        Each is 0.0 when no class qualifies.
    """
    labels = np.vstack(all_targets).astype(np.int8)
    scores = np.vstack(all_predictions)

    ap_per_class = []
    auc_per_class = []
    for col in range(labels.shape[1]):
        truth = labels[:, col]
        col_scores = scores[:, col]
        n_pos = truth.sum()
        n_neg = (1 - truth).sum()

        if n_pos == 0:
            continue        # neither AP nor ROC AUC is defined without positives
        ap_per_class.append(average_precision_score(truth, col_scores))

        if n_neg > 0:       # ROC AUC additionally needs at least one negative
            fpr, tpr, _ = roc_curve(truth, col_scores)
            auc_per_class.append(auc(fpr, tpr))

    return {
        "avg_precision": float(np.mean(ap_per_class)) if ap_per_class else 0.0,
        "roc_auc": float(np.mean(auc_per_class)) if auc_per_class else 0.0,
    }


def train_epoch(model, train_loader, optimizer, criterion, device, epoch):
    """
    Run one training epoch and collect per-sample targets and scores.

    Args:
        model (nn.Module): Network producing one logit per class.
        train_loader (DataLoader): Yields ``(images, targets)`` batches, where
            targets are multi-hot vectors.
        optimizer (torch.optim.Optimizer): Stepped once per batch.
        criterion (nn.Module): Loss applied to ``(logits, targets)``.
        device (torch.device): Device the batches are moved to.
        epoch (int): Zero-based epoch index, used in log lines and to offset
            the MLflow step of the per-batch loss.

    Returns:
        tuple: ``(epoch_accuracy, all_targets, all_predictions)``.
        * ``epoch_accuracy`` is the label-wise accuracy at a 0.5 threshold.
        * ``all_targets`` and ``all_predictions`` are stacked numpy arrays of
          targets and sigmoid scores, one row per sample.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model.train()

    log_interval = 10
    running_loss = 0.0
    running_correct = 0
    running_total = 0

    # MLflow steps must keep increasing across epochs; restarting at 0 each
    # epoch would make every epoch's loss points collide on the same steps.
    try:
        steps_per_epoch = len(train_loader)
    except TypeError:       # loaders over iterable datasets may have no length
        steps_per_epoch = 0
    step_offset = epoch * steps_per_epoch

    all_targets = []
    all_predictions = []

    for i, (images, targets) in enumerate(train_loader):
        images = images.to(device)
        targets = targets.to(device).float()          # multi-hot

        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        # Sigmoid once on detached logits; reused for accuracy and epoch metrics.
        probs = logits.detach().sigmoid()
        preds = probs > 0.5                            # independent threshold per class

        # Label-wise accuracy over all (image, class) entries. With few positive
        # labels per image this is dominated by correctly predicted negatives.
        running_correct += (preds == targets.bool()).sum().item()
        running_total += targets.numel()

        batch_loss = loss.item()
        running_loss += batch_loss

        # MLflow logging every batch, on a run-wide step
        track_metrics({"loss": batch_loss}, epoch, step=step_offset + i)

        # Periodic console log: mean loss of the last `log_interval` batches,
        # accuracy accumulated since the start of the epoch
        if i % log_interval == log_interval - 1:
            avg_loss = running_loss / log_interval
            acc_sofar = running_correct / running_total
            logger.info(f"Epoch {epoch + 1} | Batch {i + 1} | Loss {avg_loss:.4f} | Accuracy {acc_sofar:.4f}")
            running_loss = 0.0

        # save for end-of-epoch metrics
        all_targets.append(targets.cpu().numpy())
        all_predictions.append(probs.cpu().numpy())

    if running_total == 0:
        raise ValueError("train_loader yielded no batches; cannot compute epoch metrics")

    epoch_accuracy = running_correct / running_total
    return epoch_accuracy, np.vstack(all_targets), np.vstack(all_predictions)


In [None]:
# main train logic
from mlflow.models import infer_signature, validate_serving_input, convert_input_example_to_serving_input


def _log_model_for_serving(model, example_batch, device):
    """Log ``model`` to MLflow as ``best_model`` with a signature inferred from ``example_batch``."""
    was_training = model.training
    # Eval mode + no_grad: a forward pass in train mode would update the
    # BatchNorm running statistics with this probe batch.
    model.eval()
    try:
        with torch.no_grad():
            example_output = model(example_batch.to(device)).cpu().numpy()

        # infer signature (input -> output)
        signature = infer_signature(
            example_batch.cpu().numpy().astype(np.float32),
            example_output,
        )

        pip_reqs = [
            f"torch=={torch.__version__}",
            f"torchvision=={torchvision.__version__}",
        ]

        # Logged while still in eval mode so the packaged copy is inference-ready.
        mlflow.pytorch.log_model(
            model,
            artifact_path="best_model",
            signature=signature,            # passes schema, no input_example needed
            pip_requirements=pip_reqs
        )
    finally:
        model.train(was_training)


def train(batch_size=64, num_epochs=10, learning_rate=1e-4):
    """
    Train MyResNet18 on the COCO multi-label task and track the run in MLflow.

    Every epoch writes one checkpoint. When the epoch's label-wise training
    accuracy is the best so far, that checkpoint is also uploaded as an MLflow
    artifact and the model is logged as ``best_model``.

    Args:
        batch_size (int): Images per training batch.
        num_epochs (int): Number of passes over the training set.
        learning_rate (float): Adam learning rate.
    """
    device = select_device()
    mlflow_instance = setup_mlflow(batch_size, learning_rate, num_epochs, device)

    # The run object from setup_mlflow is a context manager. It ends the run as
    # FINISHED on normal exit and as FAILED on any exception (Ctrl-C included),
    # so an interrupted run is not left in RUNNING state.
    with mlflow_instance:
        save_dir = create_checkpoint_dir()

        # Get dataset paths and setup data
        img_dir, ann_file = get_dataset_paths()
        print(f"Training with data from: {img_dir}")
        print(f"Using annotations from: {ann_file}")

        train_dataset = ImageDataset(img_dir=img_dir, ann_file=ann_file, transform=transform)

        # On MPS: no pinned memory (avoids warnings) and no worker processes
        # (avoids issues using MPS on macOS).
        on_mps = device.type == "mps"
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0 if on_mps else 4,
            pin_memory=not on_mps
        )

        # Randomly initialised backbone (no pretrained weights)
        model = MyResNet18(num_classes=train_dataset.num_classes, weights=None).to(device)

        # Multi-label: independent sigmoid per class + binary cross-entropy
        criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)

        # A real transformed sample, so the logged signature has the input size
        # the model is trained on. Assumes `transform` ends in a tensor conversion.
        example_batch = train_dataset[0][0].unsqueeze(0)

        # NOTE: "best" is judged on label-wise *training* accuracy; there is no
        # validation split here.
        best_acc = float("-inf")

        for epoch in range(num_epochs):
            epoch_accuracy, all_targets, all_predictions = train_epoch(
                model, train_loader, optimizer, criterion, device, epoch
            )

            # One checkpoint file per epoch; the best one is also uploaded.
            checkpoint_path = save_checkpoint(model, optimizer, epoch, epoch_accuracy, save_dir)

            if epoch_accuracy > best_acc:
                best_acc = epoch_accuracy
                mlflow.log_artifact(checkpoint_path, artifact_path="checkpoints")
                _log_model_for_serving(model, example_batch, device)

            # Track epoch-level metrics
            mlflow.log_metric("epoch_accuracy", epoch_accuracy, step=epoch)
            print(f'Epoch: {epoch + 1} completed, Accuracy: {epoch_accuracy:.4f}')

            mlflow.log_metric("learning_rate", optimizer.param_groups[0]['lr'], step=epoch)

            # Macro-averaged AP and ROC AUC over this epoch's training predictions
            metrics = calculate_metrics(all_targets, all_predictions)
            mlflow.log_metric("avg_precision", metrics["avg_precision"], step=epoch)
            mlflow.log_metric("roc_auc", metrics["roc_auc"], step=epoch)
        

try:
    train()                             # train() starts its own MLflow run
except Exception:
    # If train() raised, its run may still be active; mark it FAILED so a dead
    # run does not sit in RUNNING and jam the queue (end_run is a no-op when no
    # run is active). Then re-raise so the failure is visible in the notebook.
    mlflow.end_run(status="FAILED")
    raise

2025-05-21 23:10:54,028 - INFO - Using device: mps
2025-05-21 23:10:54,047 - INFO - Checkpoints will be saved to: checkpoints/20250521_231054


Training with data from: /Users/yj/Workspaces/ResNet-CIFAR10/data/coco/subset/images
Using annotations from: /Users/yj/Workspaces/ResNet-CIFAR10/data/coco/subset/annotations/instances_subset.json


2025-05-21 23:11:18,616 - INFO - Epoch 1 | Batch 10 | Loss 0.6221 | Accuracy 0.6889


In [None]:
# End a run left in RUNNING state by an interrupted training cell so it does not
# jam the workflow. mlflow.end_run()'s first argument is the run *status*, not a
# run ID, so address the run explicitly by its ID.
from mlflow.tracking import MlflowClient

STALE_RUN_ID = "6ef3867a369b45f3ba0bc1e2e74a4d90"

active_run = mlflow.active_run()
if active_run is not None and active_run.info.run_id == STALE_RUN_ID:
    # Still the active run in this kernel: end it through the fluent API so the
    # kernel's active-run state is cleared too.
    mlflow.end_run(status="FAILED")
else:
    MlflowClient().set_terminated(STALE_RUN_ID, status="FAILED")