# 02A - Image Classifier Training for CleanLab

In this notebook, we train a ResNet-50 model for image classification on the Oxford-IIIT Pet dataset using the timm library and PyTorch.

We train the model with K-fold cross-validation and use it to produce out-of-sample predicted class probabilities for each image in our dataset, as well as a feature embedding of each image.

In [1]:
!pip install -r ../requirements.txt -q


In [5]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import numpy as np
from torchvision.datasets import ImageFolder
from tqdm import tqdm
import timm
from sklearn.model_selection import StratifiedKFold
from pathlib import Path
import sys

sys.path.append("../")

from utils import get_oxford_pets3t

## Dataset Preparation

We'll download the dataset to disk and load it as a torchvision dataset, applying the preprocessing transformations defined below.

In [12]:
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(  # ImageNet stats for normalization (mean and std) of RGB channels
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ]
)

OUT_DIR = "../pre_computed_assets/CleanLab"  # Save notebook artifacts in this directory
dataset, df = get_oxford_pets3t(
    root_path="../data", return_dataframe=True, transform=transform
)

n_classes = len(dataset.classes)
labels = np.array(dataset.targets)

dataset

Oxford PetIIIT already downloaded to `../data`.
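
To make the `Normalize` step in the transform above concrete, here is a minimal sketch of the per-channel arithmetic it applies once `ToTensor` has scaled pixel values to the range [0, 1]. The helper name `normalize_pixel` is hypothetical and works on a single RGB pixel in plain Python; the real transform operates on whole image tensors.

```python
# ImageNet channel statistics, copied from the transform above.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    """Normalize one RGB pixel whose channels are already scaled to [0, 1]."""
    return [(value - m) / s for value, m, s in zip(rgb, MEAN, STD)]


# A pixel equal to the channel means maps to zero in every channel.
centered = normalize_pixel(MEAN)
# A pure white pixel (1.0 per channel after scaling to [0, 1]).
white = normalize_pixel([1.0, 1.0, 1.0])

print(centered)
print([round(v, 3) for v in white])
```

Using the ImageNet statistics matches the preprocessing the pretrained weights were trained with, which is why the same values are reused here rather than recomputed from the pet images.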


## Model Training

With the dataset ready, we now define the hyperparameters for training and perform k-fold cross-validation to train a ResNet-50.


In [8]:
batch_size = 64
learning_rate = 0.001
num_epochs = 10
num_folds = 5  # Use 3 for faster training
patience = 2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "resnet50"
model_prefix = "oxford_pet3t" + model_name

In [9]:
print(f"Running on {device}")

Running on cuda


### K-Fold Cross-Validation and Training

In this notebook, we use k-fold cross-validation to train the model and extract out-of-sample predicted probabilities for all data points.
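
The idea of "out-of-sample" predictions can be sketched without any model at all: every sample lands in exactly one held-out fold, and its prediction must come from the model trained on the remaining folds. The snippet below is a simplified stand-in for `StratifiedKFold` (no shuffling, round-robin assignment per class); `stratified_folds` and the toy labels are assumptions made for illustration only.

```python
from collections import defaultdict


def stratified_folds(labels, num_folds):
    """Assign each sample index to one fold, spreading every class evenly.

    Simplified stand-in for StratifiedKFold: no shuffling, round-robin per class.
    """
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    folds = [[] for _ in range(num_folds)]
    for indices in by_class.values():
        for position, idx in enumerate(indices):
            folds[position % num_folds].append(idx)
    return folds


labels = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]  # toy labels, 3 classes
folds = stratified_folds(labels, num_folds=2)

# Record which fold's model would supply the prediction for each sample.
held_out_by = [None] * len(labels)
for fold_id, test_idx in enumerate(folds):
    train_idx = [i for i in range(len(labels)) if i not in test_idx]
    # A real run would train on train_idx here and predict on test_idx.
    for i in test_idx:
        assert i not in train_idx  # the model never saw this sample
        held_out_by[i] = fold_id

print(held_out_by)
```

Because the folds partition the dataset, filling one slot per held-out index yields a complete array of predictions in which no entry was produced by a model that trained on that sample.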

During training, we monitor the validation accuracy on the held-out fold and use it for early stopping: training for a fold stops once more than `patience` epochs have passed since the best validation accuracy. This limits overfitting and gives a rough estimate of how well the model generalizes.
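
As a rough sketch, the stopping rule used in the training cell (`epoch - best_epoch > patience`) can be replayed on a list of per-epoch validation accuracies. The function name `run_with_patience` is hypothetical; the accuracies are the ones logged for the third fold in this notebook's training output.

```python
def run_with_patience(val_accuracies, patience):
    """Replay the stopping rule on a list of per-epoch validation accuracies.

    Returns (best_epoch, stop_epoch), both 1-indexed; stop_epoch is None when
    the rule never fires.
    """
    best_val_accuracy = 0.0
    best_epoch = 0
    for epoch, val_accuracy in enumerate(val_accuracies):
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            best_epoch = epoch
        # Same condition as the training loop: stop once more than
        # `patience` epochs have passed since the best epoch.
        if epoch - best_epoch > patience:
            return best_epoch + 1, epoch + 1
    return best_epoch + 1, None


# Validation accuracies logged for the third fold of this run.
fold_3 = [87.55, 87.96, 88.36, 88.02, 89.45, 89.65, 86.67, 88.16, 87.62]
best, stopped = run_with_patience(fold_3, patience=2)
print(best, stopped)
```

With `patience = 2`, the rule tolerates two epochs without improvement and stops on the third, so the saved checkpoint is always the one from the best epoch rather than the last one.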

While we're not specifically interested in the model artifacts themselves, we aim to get a general idea of whether the chosen model architecture is accurate enough for our purpose.

**Warning**: This cell may take a long time to execute and should be run with a GPU.

In [15]:
kf = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=42)

In [None]:
# Set seed
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
for fold, (train_idx, test_idx) in enumerate(kf.split(dataset, labels)):
    print(f"Fold {fold + 1}/{num_folds}")
    print("-" * 10)

    # Define data loaders for current fold
    train_subset = torch.utils.data.Subset(dataset, train_idx)
    val_subset = torch.utils.data.Subset(dataset, test_idx)
    # Print train and validation set sizes
    print(f"Train set size: {len(train_idx)}")
    print(f"Test set size: {len(test_idx)}")

    train_loader = torch.utils.data.DataLoader(
        train_subset, batch_size=batch_size, shuffle=True
    )
    val_loader = torch.utils.data.DataLoader(
        val_subset, batch_size=batch_size, shuffle=False
    )

    # Initialize model for current fold
    model = timm.create_model(model_name, pretrained=True, num_classes=n_classes)
    model = model.to(device)
    num_features = model.num_features

    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    best_val_accuracy = 0
    best_epoch = 0
    # Train model for current fold
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        model.train()
        for inputs, targets in tqdm(train_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Note: this is the loss of the last batch only, not an epoch average
        print(f"Train loss: {loss.item():.4f}")

        # Evaluate model on the held-out (validation) set for current fold
        model.eval()
        with torch.no_grad():
            correct = 0
            total = 0
            for inputs, targets in tqdm(val_loader):
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
            val_accuracy = 100 * correct / total
            print(f"Validation accuracy: {val_accuracy:.2f}%")

        # Save model checkpoint if it is the best so far
        if val_accuracy > best_val_accuracy:
            print("Saving model...")
            path = f"{model_prefix}_fold_{fold + 1}.pt"
            torch.save(model.state_dict(), path)
            best_val_accuracy = val_accuracy
            best_epoch = epoch

        # Early stopping
        if epoch - best_epoch > patience:
            print(f"Early stopping at epoch {epoch + 1}")
            break

Fold 1/5
----------
Train set size: 5912
Test set size: 1478


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth" to /home/jovyan/.cache/torch/hub/checkpoints/resnet50_a1_0-14fe96d1.pth


Epoch 1/10


100%|██████████| 93/93 [00:48<00:00,  1.90it/s]


Train loss: 0.4253


100%|██████████| 24/24 [00:08<00:00,  2.88it/s]


Validation accuracy: 87.42%
Saving model...
Epoch 2/10


100%|██████████| 93/93 [00:48<00:00,  1.91it/s]


Train loss: 0.2023


100%|██████████| 24/24 [00:08<00:00,  2.84it/s]


Validation accuracy: 91.00%
Saving model...
Epoch 3/10


100%|██████████| 93/93 [00:49<00:00,  1.88it/s]


Train loss: 0.2547


100%|██████████| 24/24 [00:08<00:00,  2.84it/s]


Validation accuracy: 91.00%
Epoch 4/10


100%|██████████| 93/93 [00:49<00:00,  1.88it/s]


Train loss: 0.3091


100%|██████████| 24/24 [00:08<00:00,  2.82it/s]


Validation accuracy: 87.28%
Saving model...
Epoch 2/10


100%|██████████| 93/93 [00:49<00:00,  1.89it/s]


Train loss: 0.2132


100%|██████████| 24/24 [00:08<00:00,  2.88it/s]


Validation accuracy: 90.53%
Saving model...
Epoch 3/10


100%|██████████| 93/93 [00:49<00:00,  1.87it/s]


Train loss: 0.2090


100%|██████████| 24/24 [00:08<00:00,  2.86it/s]


Validation accuracy: 88.84%
Epoch 4/10


100%|██████████| 93/93 [00:49<00:00,  1.88it/s]


Train loss: 0.0691


100%|██████████| 24/24 [00:08<00:00,  2.87it/s]


Validation accuracy: 86.47%
Epoch 5/10


100%|██████████| 93/93 [00:49<00:00,  1.87it/s]


Train loss: 0.1174


100%|██████████| 24/24 [00:08<00:00,  2.90it/s]


Validation accuracy: 90.26%
Early stopping at epoch 5
Fold 3/5
----------
Train set size: 5912
Test set size: 1478
Epoch 1/10


100%|██████████| 93/93 [00:49<00:00,  1.87it/s]


Train loss: 0.3184


100%|██████████| 24/24 [00:08<00:00,  2.79it/s]


Validation accuracy: 87.55%
Saving model...
Epoch 2/10


100%|██████████| 93/93 [00:50<00:00,  1.85it/s]


Train loss: 0.1660


100%|██████████| 24/24 [00:08<00:00,  2.78it/s]


Validation accuracy: 87.96%
Saving model...
Epoch 3/10


100%|██████████| 93/93 [00:50<00:00,  1.84it/s]


Train loss: 0.2799


100%|██████████| 24/24 [00:08<00:00,  2.91it/s]


Validation accuracy: 88.36%
Saving model...
Epoch 4/10


100%|██████████| 93/93 [00:50<00:00,  1.85it/s]


Train loss: 0.0662


100%|██████████| 24/24 [00:08<00:00,  2.85it/s]


Validation accuracy: 88.02%
Epoch 5/10


100%|██████████| 93/93 [00:49<00:00,  1.86it/s]


Train loss: 0.1020


100%|██████████| 24/24 [00:08<00:00,  2.88it/s]


Validation accuracy: 89.45%
Saving model...
Epoch 6/10


100%|██████████| 93/93 [00:50<00:00,  1.85it/s]


Train loss: 0.0188


100%|██████████| 24/24 [00:08<00:00,  2.93it/s]


Validation accuracy: 89.65%
Saving model...
Epoch 7/10


100%|██████████| 93/93 [00:49<00:00,  1.86it/s]


Train loss: 0.2134


100%|██████████| 24/24 [00:08<00:00,  2.87it/s]


Validation accuracy: 86.67%
Epoch 8/10


100%|██████████| 93/93 [00:50<00:00,  1.86it/s]


Train loss: 0.0525


100%|██████████| 24/24 [00:08<00:00,  2.91it/s]


Validation accuracy: 88.16%
Epoch 9/10


100%|██████████| 93/93 [00:49<00:00,  1.87it/s]


Train loss: 0.1209


100%|██████████| 24/24 [00:08<00:00,  2.90it/s]


Validation accuracy: 87.62%
Early stopping at epoch 9
Fold 4/5
----------
Train set size: 5912
Test set size: 1478
Epoch 1/10


100%|██████████| 93/93 [00:49<00:00,  1.87it/s]


Train loss: 0.1700


100%|██████████| 24/24 [00:08<00:00,  2.83it/s]


Validation accuracy: 81.87%
Saving model...
Epoch 2/10


100%|██████████| 93/93 [00:49<00:00,  1.86it/s]


Train loss: 0.3182


100%|██████████| 24/24 [00:08<00:00,  2.88it/s]


Validation accuracy: 86.33%
Saving model...
Epoch 3/10


100%|██████████| 93/93 [00:50<00:00,  1.86it/s]


Train loss: 0.1773


100%|██████████| 24/24 [00:08<00:00,  2.79it/s]


Validation accuracy: 87.75%
Saving model...
Epoch 4/10


100%|██████████| 93/93 [00:49<00:00,  1.86it/s]


Train loss: 0.0923


100%|██████████| 24/24 [00:08<00:00,  2.82it/s]


Validation accuracy: 86.67%
Epoch 5/10


100%|██████████| 93/93 [00:49<00:00,  1.88it/s]


Train loss: 0.1047


100%|██████████| 24/24 [00:08<00:00,  2.77it/s]


Validation accuracy: 87.69%
Epoch 6/10


100%|██████████| 93/93 [00:50<00:00,  1.86it/s]


Train loss: 0.2025


100%|██████████| 24/24 [00:08<00:00,  2.76it/s]


Validation accuracy: 89.72%
Saving model...
Epoch 7/10


100%|██████████| 93/93 [00:50<00:00,  1.85it/s]


Train loss: 0.1443


100%|██████████| 24/24 [00:08<00:00,  2.83it/s]


Validation accuracy: 89.85%
Saving model...
Epoch 8/10


100%|██████████| 93/93 [00:49<00:00,  1.86it/s]


Train loss: 0.1304


100%|██████████| 24/24 [00:08<00:00,  2.86it/s]


Validation accuracy: 91.00%
Saving model...
Epoch 9/10


100%|██████████| 93/93 [00:49<00:00,  1.87it/s]


Train loss: 0.0033


100%|██████████| 24/24 [00:08<00:00,  2.85it/s]


Validation accuracy: 90.19%
Epoch 10/10


100%|██████████| 93/93 [00:49<00:00,  1.86it/s]


Train loss: 0.1873


100%|██████████| 24/24 [00:08<00:00,  2.87it/s]


Validation accuracy: 85.72%
Fold 5/5
----------
Train set size: 5912
Test set size: 1478
Epoch 1/10


100%|██████████| 93/93 [00:49<00:00,  1.87it/s]


Train loss: 0.3110


100%|██████████| 24/24 [00:08<00:00,  2.84it/s]


Validation accuracy: 85.52%
Saving model...
Epoch 2/10


100%|██████████| 93/93 [00:49<00:00,  1.87it/s]


Train loss: 0.2776


100%|██████████| 24/24 [00:08<00:00,  2.85it/s]


Validation accuracy: 88.36%
Saving model...
Epoch 3/10


 20%|██        | 19/93 [00:10<00:39,  1.87it/s]

## Getting Predicted Class Probabilities and Extracting Feature Embeddings 

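After training, we compute predicted class probabilities for the entire dataset. Each image is scored by the model from the fold in which that image was held out, so every prediction is out-of-sample.

In addition, to keep things simple, we'll use the model trained on the first fold as a feature extractor to obtain embeddings for every image in the dataset. Unlike the predicted probabilities, these embeddings are not out-of-sample for the images that the first-fold model was trained on.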

These artifacts will be used by `Datalab` to inspect the dataset for potential issues, so we save them to files used in the next notebook.
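
The model outputs raw scores (logits), and the next cell converts them to class probabilities with a softmax. Here is a minimal plain-Python sketch of that conversion for a single row of logits; the function and the toy values are illustrative and not part of the notebook's pipeline, which uses `nn.functional.softmax` on whole batches.

```python
import math


def softmax(logits):
    """Convert one row of logits into probabilities that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(value - shift) for value in logits]
    total = sum(exps)
    return [value / total for value in exps]


logits = [2.0, 1.0, 0.1]  # toy scores for three classes
probs = softmax(logits)
# The predicted class is the index of the largest probability.
predicted_class = max(range(len(probs)), key=probs.__getitem__)

print([round(p, 3) for p in probs], predicted_class)
```

A quick sanity check on the saved `pred_probs` array follows from this: every row should be non-negative and sum to approximately 1.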

In [None]:
# Feature extractor: the fold-1 model with its classification head removed.
# pretrained=False skips the ImageNet download; the checkpoint overwrites all weights.
feature_extractor = timm.create_model(
    model_name, pretrained=False, num_classes=n_classes
)
path = f"{model_prefix}_fold_1.pt"
feature_extractor.load_state_dict(torch.load(path, map_location=device))
feature_extractor.reset_classifier(0)  # the model now returns pooled features
feature_extractor.eval()
feature_extractor.to(device)
num_features = feature_extractor.num_features

features = np.zeros((len(dataset), num_features))
pred_probs = np.zeros((len(dataset), n_classes))

# kf has a fixed random_state, so this reproduces the splits used for training
for fold, (_, test_idx) in enumerate(kf.split(dataset, labels)):
    # Data loader over the held-out samples of the current fold
    test_subset = torch.utils.data.Subset(dataset, test_idx)
    test_loader = torch.utils.data.DataLoader(
        test_subset, batch_size=batch_size, shuffle=False
    )

    # Load the model that was trained on the other folds
    model = timm.create_model(model_name, pretrained=False, num_classes=n_classes)
    path = f"{model_prefix}_fold_{fold + 1}.pt"
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    model.to(device)

    with torch.no_grad():
        pred_probs_fold = []
        features_fold = []
        for inputs, _ in tqdm(test_loader):
            inputs = inputs.to(device)
            # Out-of-sample predicted probabilities
            outputs = nn.functional.softmax(model(inputs), dim=1)
            pred_probs_fold.append(outputs.cpu().numpy())
            # Embeddings from the fold-1 feature extractor
            features_fold.append(feature_extractor(inputs).cpu().numpy())
        pred_probs[test_idx] = np.concatenate(pred_probs_fold, axis=0)
        features[test_idx] = np.concatenate(features_fold, axis=0)

os.makedirs(OUT_DIR, exist_ok=True)  # make sure the output directory exists
features_path = os.path.join(OUT_DIR, "features.npy")
pred_probs_path = os.path.join(OUT_DIR, "pred_probs.npy")

np.save(features_path, features)
np.save(pred_probs_path, pred_probs)