# Image Classifier Training on Caltech-256 Subset


[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cleanlab/examples/blob/master/datalab_image_classification/train_image_classifier.ipynb)

In this notebook, we train a Swin Transformer model for image classification on a subset of the Caltech-256 dataset using the `timm` library and PyTorch.

We train the model with K-fold cross-validation and use it to produce out-of-sample predicted class probabilities for each image in our dataset, as well as a feature embedding of each image.

Please install the dependencies specified in this [requirements-train.txt](https://github.com/cleanlab/examples/blob/master/datalab_image_classification/requirements-train.txt) file before running the notebook.

In [12]:
!pip3 install torch torchvision torchaudio
!pip3 install timm
!pip3 install -U scikit-learn scipy matplotlib



In [9]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import numpy as np
from torchvision.datasets import ImageFolder
from tqdm import tqdm
import timm
from sklearn.model_selection import StratifiedKFold

## Dataset Preparation

We'll download the dataset to disk and load it with Torchvision's `ImageFolder` dataset class, applying the transformations defined below.
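
As a rough illustration of what the `ToTensor` and `Normalize` steps in the transform defined below do to pixel values, here is a NumPy-only sketch. The 2x2 `image` array is a made-up example, and the code mirrors the arithmetic rather than calling torchvision itself.

```python
import numpy as np

# ImageNet channel statistics, as used in the notebook's transform
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

# Hypothetical 2x2 RGB image in height x width x channel layout (uint8, 0-255)
image = np.array(
    [[[0, 0, 0], [255, 255, 255]],
     [[128, 64, 32], [10, 200, 90]]],
    dtype=np.uint8,
)

# ToTensor: scale to [0, 1] and move channels first (C x H x W)
scaled = image.astype(np.float32) / 255.0
chw = scaled.transpose(2, 0, 1)

# Normalize: subtract the per-channel mean and divide by the per-channel std
normalized = (chw - mean[:, None, None]) / std[:, None, None]

print(normalized.shape)     # (3, 2, 2)
print(normalized[:, 0, 0])  # the black pixel maps to -mean / std
```

After this step each channel is roughly zero-centered with unit scale for ImageNet-like images, which is what the pretrained weights expect.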

In [13]:
# macOS only (Homebrew); skip this line if wget is already installed
!brew install wget
!wget -nc 'https://cleanlab-public.s3.amazonaws.com/Datalab/caltech256-subset.tar.gz'
!mkdir -p data
!tar -xf caltech256-subset.tar.gz -C data/

--2023-06-20 23:22:00--  https://cleanlab-public.s3.amazonaws.com/Datalab/caltech256-subset.tar.gz
Resolving cleanlab-public.s3.amazonaws.com (cleanlab-public.s3.amazonaws.com)... 100.64.1.83
Connecting to cleanlab-public.s3.amazonaws.com (cleanlab-public.s3.amazonaws.com)|100.64.1.83|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 43765912 (42M) [application/x-gzip]
Saving to: ‘caltech256-subset.tar.gz’


2023-06-20 23:22:15 (3.12 MB/s) - ‘caltech256-subset.tar.gz’ saved [43765912/43765912]



In [15]:
def convert_to_rgb(img):
    if img.mode == 'L':
        img = img.convert('RGB')
    return img


transform = transforms.Compose([
    transforms.Lambda(convert_to_rgb),
    transforms.Resize((224, 224)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(  # ImageNet stats for normalization (mean and std) of RGB channels
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

DATA_DIR = './data' # Save notebook artifacts in this directory

# Load data from disk
dataset = ImageFolder(
    os.path.join(DATA_DIR, "caltech256-subset"),
    transform=transform
)
n_classes = len(dataset.classes)
labels = np.array(dataset.targets)

## Model Training

With the dataset ready, we now define the hyperparameters for training and perform k-fold cross-validation to train a Swin Transformer.


In [16]:
# Define hyperparameters

batch_size = 32 # Resnet50: 64, Swin-Transformer-patch-4-window-7-224: 32
learning_rate = 0.00001 # Resnet50: 0.001, Swin-Transformer-patch-4-window-7-224: 0.0001
num_epochs = 10
num_folds = 5  # Use 3 for faster training
patience = 2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_name = "swin_base_patch4_window7_224"
model_prefix = "caltech256_subset_" + model_name

### K-Fold Cross-Validation and Training

In this notebook, we use k-fold cross-validation to train the model and extract out-of-sample predicted probabilities for all data points.
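
To make the splitting concrete, the sketch below runs scikit-learn's `StratifiedKFold` on made-up, imbalanced labels (`toy_labels`, not the Caltech-256 subset). It illustrates two properties the notebook relies on: each held-out fold keeps the overall class proportions, and every sample is held out exactly once.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Hypothetical imbalanced labels: 60 / 30 / 10 samples for classes 0 / 1 / 2
toy_labels = np.array([0] * 60 + [1] * 30 + [2] * 10)
toy_data = np.zeros((len(toy_labels), 1))  # placeholder features; only labels drive the split

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

times_held_out = np.zeros(len(toy_labels), dtype=int)
fold_class_counts = []
for train_idx, test_idx in skf.split(toy_data, toy_labels):
    times_held_out[test_idx] += 1
    # Count how many samples of each class land in this held-out fold
    fold_class_counts.append(np.bincount(toy_labels[test_idx], minlength=3))

fold_class_counts = np.array(fold_class_counts)
print(fold_class_counts)     # one row per fold, one column per class
print(times_held_out.min(), times_held_out.max())
```

Because every index appears in exactly one held-out fold, predictions collected fold by fold cover the whole dataset without any image being scored by a model that trained on it.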

During training, we measure accuracy on the held-out fold after each epoch, save a checkpoint whenever it improves, and stop early once it has failed to improve for more than `patience` epochs. This helps limit overfitting and gives a rough estimate of the model's performance.
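
The stopping rule can be sketched in isolation. The helper below, `early_stopping_epoch`, is a hypothetical name introduced only for illustration; it applies the same `epoch - best_epoch > patience` test that the training cell uses, and the accuracy list is copied from the Fold 1 log printed further down in this notebook.

```python
def early_stopping_epoch(val_accuracies, patience):
    """Return the 1-indexed epoch at which training stops, or None if it never stops early."""
    best_accuracy = 0.0
    best_epoch = 0
    for epoch, accuracy in enumerate(val_accuracies):
        # A strictly better accuracy resets the reference epoch
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_epoch = epoch
        # Stop once more than `patience` epochs have passed since the best one
        if epoch - best_epoch > patience:
            return epoch + 1
    return None


# Validation accuracies (%) from the Fold 1 log in this notebook
fold_1 = [80.80, 93.60, 95.20, 96.00, 97.60, 98.40, 97.60, 96.80, 96.80]
stop_epoch = early_stopping_epoch(fold_1, patience=2)
print(stop_epoch)  # 9
```

Note that with `patience = 2` and a strict `>` comparison, training stops after the third consecutive epoch without improvement, not the second.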

While we're not specifically interested in the model artifacts themselves, we aim to get a general idea of whether the chosen model architecture is accurate enough for our purpose.

**Warning**: This cell may take a long time to execute and should be run with a GPU.

In [17]:
kf = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=42)

In [18]:
# Train model

# Set seed
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
for fold, (train_idx, test_idx) in enumerate(kf.split(dataset, labels)):
    print(f'Fold {fold + 1}/{num_folds}')
    print('-' * 10)

    # Define data loaders for current fold
    train_subset = torch.utils.data.Subset(dataset, train_idx)
    val_subset = torch.utils.data.Subset(dataset, test_idx)
    # Print train and validation set sizes
    print(f'Train set size: {len(train_idx)}')
    print(f'Test set size: {len(test_idx)}')

    train_loader = torch.utils.data.DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_subset, batch_size=batch_size, shuffle=False)

    # Initialize model for current fold
    model = timm.create_model(model_name, pretrained=True, num_classes=n_classes)
    model = model.to(device)
    num_features = model.num_features

    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    best_val_accuracy = 0
    best_epoch = 0
    # Train model for current fold
    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        model.train()
        for inputs, targets in tqdm(train_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f'Train loss: {loss.item():.4f}')  # loss of the last batch in this epoch

        # Evaluate model on the held-out validation fold
        model.eval()
        with torch.no_grad():
            correct = 0
            total = 0
            for inputs, targets in tqdm(val_loader):
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
            val_accuracy = 100 * correct / total
            print(f'Validation accuracy: {val_accuracy:.2f}%')

        # Save model checkpoint if it is the best so far
        if val_accuracy > best_val_accuracy:
            print('Saving model...')
            path = f'{model_prefix}_fold_{fold + 1}.pt'
            torch.save(model.state_dict(), path)
            best_val_accuracy = val_accuracy
            best_epoch = epoch

        # Early stopping: stop once more than `patience` epochs pass without improvement
        if epoch - best_epoch > patience:
            print(f'Early stopping at epoch {epoch + 1}')
            break

Fold 1/5
----------
Train set size: 497
Test set size: 125


Downloading model.safetensors: 100%|██████████| 353M/353M [00:18<00:00, 18.9MB/s] 


Epoch 1/10


100%|██████████| 16/16 [01:44<00:00,  6.51s/it]


Train loss: 1.1315


100%|██████████| 4/4 [00:10<00:00,  2.61s/it]


Validation accuracy: 80.80%
Saving model...
Epoch 2/10


100%|██████████| 16/16 [01:42<00:00,  6.44s/it]


Train loss: 0.8746


100%|██████████| 4/4 [00:10<00:00,  2.56s/it]


Validation accuracy: 93.60%
Saving model...
Epoch 3/10


100%|██████████| 16/16 [01:43<00:00,  6.45s/it]


Train loss: 0.3720


100%|██████████| 4/4 [00:10<00:00,  2.64s/it]


Validation accuracy: 95.20%
Saving model...
Epoch 4/10


100%|██████████| 16/16 [01:42<00:00,  6.42s/it]


Train loss: 0.2547


100%|██████████| 4/4 [00:10<00:00,  2.66s/it]


Validation accuracy: 96.00%
Saving model...
Epoch 5/10


100%|██████████| 16/16 [01:44<00:00,  6.51s/it]


Train loss: 0.0865


100%|██████████| 4/4 [00:10<00:00,  2.56s/it]


Validation accuracy: 97.60%
Saving model...
Epoch 6/10


100%|██████████| 16/16 [01:42<00:00,  6.42s/it]


Train loss: 0.0675


100%|██████████| 4/4 [00:10<00:00,  2.56s/it]


Validation accuracy: 98.40%
Saving model...
Epoch 7/10


100%|██████████| 16/16 [01:42<00:00,  6.44s/it]


Train loss: 0.0274


100%|██████████| 4/4 [00:10<00:00,  2.57s/it]


Validation accuracy: 97.60%
Epoch 8/10


100%|██████████| 16/16 [01:46<00:00,  6.67s/it]


Train loss: 0.0243


100%|██████████| 4/4 [00:10<00:00,  2.65s/it]


Validation accuracy: 96.80%
Epoch 9/10


100%|██████████| 16/16 [01:47<00:00,  6.74s/it]


Train loss: 0.0179


100%|██████████| 4/4 [00:10<00:00,  2.66s/it]


Validation accuracy: 96.80%
Early stopping at epoch 9
Fold 2/5
----------
Train set size: 497
Test set size: 125
Epoch 1/10


100%|██████████| 16/16 [01:50<00:00,  6.92s/it]


Train loss: 1.2916


100%|██████████| 4/4 [00:10<00:00,  2.64s/it]


Validation accuracy: 78.40%
Saving model...
Epoch 2/10


100%|██████████| 16/16 [01:49<00:00,  6.81s/it]


Train loss: 0.8155


100%|██████████| 4/4 [00:10<00:00,  2.67s/it]


Validation accuracy: 87.20%
Saving model...
Epoch 3/10


100%|██████████| 16/16 [01:48<00:00,  6.77s/it]


Train loss: 0.2490


100%|██████████| 4/4 [00:10<00:00,  2.60s/it]


Validation accuracy: 89.60%
Saving model...
Epoch 4/10


100%|██████████| 16/16 [01:47<00:00,  6.71s/it]


Train loss: 0.1646


100%|██████████| 4/4 [00:10<00:00,  2.58s/it]


Validation accuracy: 92.80%
Saving model...
Epoch 5/10


100%|██████████| 16/16 [01:46<00:00,  6.67s/it]


Train loss: 0.0606


100%|██████████| 4/4 [00:10<00:00,  2.62s/it]


Validation accuracy: 92.80%
Epoch 6/10


100%|██████████| 16/16 [01:46<00:00,  6.68s/it]


Train loss: 0.0888


100%|██████████| 4/4 [00:10<00:00,  2.52s/it]


Validation accuracy: 92.80%
Epoch 7/10


100%|██████████| 16/16 [01:44<00:00,  6.52s/it]


Train loss: 0.0695


100%|██████████| 4/4 [00:10<00:00,  2.54s/it]


Validation accuracy: 94.40%
Saving model...
Epoch 8/10


100%|██████████| 16/16 [01:43<00:00,  6.49s/it]


Train loss: 0.0280


100%|██████████| 4/4 [00:10<00:00,  2.53s/it]


Validation accuracy: 94.40%
Epoch 9/10


100%|██████████| 16/16 [01:43<00:00,  6.46s/it]


Train loss: 0.0229


100%|██████████| 4/4 [00:10<00:00,  2.54s/it]


Validation accuracy: 94.40%
Epoch 10/10


100%|██████████| 16/16 [01:43<00:00,  6.48s/it]


Train loss: 0.0124


100%|██████████| 4/4 [00:10<00:00,  2.55s/it]


Validation accuracy: 93.60%
Early stopping at epoch 10
Fold 3/5
----------
Train set size: 498
Test set size: 124
Epoch 1/10


100%|██████████| 16/16 [01:43<00:00,  6.47s/it]


Train loss: 1.2679


100%|██████████| 4/4 [00:10<00:00,  2.56s/it]


Validation accuracy: 78.23%
Saving model...
Epoch 2/10


100%|██████████| 16/16 [01:43<00:00,  6.47s/it]


Train loss: 0.8611


100%|██████████| 4/4 [00:10<00:00,  2.53s/it]


Validation accuracy: 91.94%
Saving model...
Epoch 3/10


100%|██████████| 16/16 [01:44<00:00,  6.50s/it]


Train loss: 0.4309


100%|██████████| 4/4 [00:10<00:00,  2.61s/it]


Validation accuracy: 92.74%
Saving model...
Epoch 4/10


100%|██████████| 16/16 [01:43<00:00,  6.50s/it]


Train loss: 0.1647


100%|██████████| 4/4 [00:10<00:00,  2.61s/it]


Validation accuracy: 92.74%
Epoch 5/10


100%|██████████| 16/16 [01:43<00:00,  6.48s/it]


Train loss: 0.1509


100%|██████████| 4/4 [00:10<00:00,  2.57s/it]


Validation accuracy: 95.16%
Saving model...
Epoch 6/10


100%|██████████| 16/16 [01:44<00:00,  6.50s/it]


Train loss: 0.0213


100%|██████████| 4/4 [00:10<00:00,  2.55s/it]


Validation accuracy: 92.74%
Epoch 7/10


100%|██████████| 16/16 [01:44<00:00,  6.50s/it]


Train loss: 0.0056


100%|██████████| 4/4 [00:10<00:00,  2.56s/it]


Validation accuracy: 95.16%
Epoch 8/10


100%|██████████| 16/16 [01:43<00:00,  6.48s/it]


Train loss: 0.0237


100%|██████████| 4/4 [00:10<00:00,  2.57s/it]


Validation accuracy: 95.16%
Early stopping at epoch 8
Fold 4/5
----------
Train set size: 498
Test set size: 124
Epoch 1/10


100%|██████████| 16/16 [01:45<00:00,  6.58s/it]


Train loss: 1.1438


100%|██████████| 4/4 [00:10<00:00,  2.55s/it]


Validation accuracy: 82.26%
Saving model...
Epoch 2/10


100%|██████████| 16/16 [01:44<00:00,  6.54s/it]


Train loss: 0.6534


100%|██████████| 4/4 [00:10<00:00,  2.54s/it]


Validation accuracy: 91.13%
Saving model...
Epoch 3/10


100%|██████████| 16/16 [01:44<00:00,  6.54s/it]


Train loss: 0.3050


100%|██████████| 4/4 [00:10<00:00,  2.54s/it]


Validation accuracy: 92.74%
Saving model...
Epoch 4/10


100%|██████████| 16/16 [01:45<00:00,  6.58s/it]


Train loss: 0.1159


100%|██████████| 4/4 [00:10<00:00,  2.56s/it]


Validation accuracy: 95.16%
Saving model...
Epoch 5/10


100%|██████████| 16/16 [01:44<00:00,  6.56s/it]


Train loss: 0.0519


100%|██████████| 4/4 [00:10<00:00,  2.55s/it]


Validation accuracy: 94.35%
Epoch 6/10


100%|██████████| 16/16 [01:44<00:00,  6.56s/it]


Train loss: 0.0233


100%|██████████| 4/4 [00:10<00:00,  2.56s/it]


Validation accuracy: 94.35%
Epoch 7/10


100%|██████████| 16/16 [01:45<00:00,  6.57s/it]


Train loss: 0.0156


100%|██████████| 4/4 [00:09<00:00,  2.50s/it]


Validation accuracy: 94.35%
Early stopping at epoch 7
Fold 5/5
----------
Train set size: 498
Test set size: 124
Epoch 1/10


100%|██████████| 16/16 [01:42<00:00,  6.38s/it]


Train loss: 1.2440


100%|██████████| 4/4 [00:09<00:00,  2.50s/it]


Validation accuracy: 77.42%
Saving model...
Epoch 2/10


100%|██████████| 16/16 [01:41<00:00,  6.34s/it]


Train loss: 0.8856


100%|██████████| 4/4 [00:09<00:00,  2.46s/it]


Validation accuracy: 93.55%
Saving model...
Epoch 3/10


100%|██████████| 16/16 [01:42<00:00,  6.38s/it]


Train loss: 0.4025


100%|██████████| 4/4 [00:09<00:00,  2.46s/it]


Validation accuracy: 93.55%
Epoch 4/10


100%|██████████| 16/16 [01:42<00:00,  6.43s/it]


Train loss: 0.3152


100%|██████████| 4/4 [00:10<00:00,  2.65s/it]


Validation accuracy: 95.16%
Saving model...
Epoch 5/10


100%|██████████| 16/16 [01:47<00:00,  6.69s/it]


Train loss: 0.1281


100%|██████████| 4/4 [00:09<00:00,  2.47s/it]


Validation accuracy: 95.16%
Epoch 6/10


100%|██████████| 16/16 [01:42<00:00,  6.41s/it]


Train loss: 0.0492


100%|██████████| 4/4 [00:10<00:00,  2.60s/it]


Validation accuracy: 95.16%
Epoch 7/10


100%|██████████| 16/16 [01:45<00:00,  6.57s/it]


Train loss: 0.0208


100%|██████████| 4/4 [00:10<00:00,  2.72s/it]

Validation accuracy: 94.35%
Early stopping at epoch 7





## Getting Predicted Class Probabilities and Extracting Feature Embeddings

After training, we will compute predicted class probabilities for the entire dataset using the trained models from each fold.
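
The bookkeeping behind this step is sketched below with NumPy only. A nearest-centroid classifier stands in for the Swin Transformer, and plain shuffled folds stand in for the stratified ones; the data (`toy_x`, `toy_labels`) are made up. What carries over to the notebook is the pattern: fit on the training folds, apply softmax to the held-out logits, and write the result back to the rows given by the held-out indices.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_classes, n_folds = 20, 3, 5  # toy sizes, not the notebook's

# Hypothetical data: 2-D points whose position shifts with the class label
toy_labels = np.arange(n_samples) % n_classes
toy_x = rng.normal(size=(n_samples, 2)) + 3.0 * toy_labels[:, None]


def softmax(logits):
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


folds = np.array_split(rng.permutation(n_samples), n_folds)
pred_probs = np.zeros((n_samples, n_classes))

for k, held_out in enumerate(folds):
    train = np.concatenate([f for j, f in enumerate(folds) if j != k])
    # Stand-in "model": class centroids fitted on the training folds only
    centroids = np.stack(
        [toy_x[train][toy_labels[train] == c].mean(axis=0) for c in range(n_classes)]
    )
    # Logits are negative squared distances to each centroid
    diffs = toy_x[held_out][:, None, :] - centroids[None, :, :]
    logits = -(diffs ** 2).sum(axis=2)
    # Write each prediction back to its original row, as pred_probs[test_idx] does
    pred_probs[held_out] = softmax(logits)

print(pred_probs.shape)
print(pred_probs.sum(axis=1))  # every row sums to 1
```

Indexing by the held-out indices keeps `pred_probs` aligned with the original dataset order, so row `i` always corresponds to image `i` and label `i`.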

In addition, to keep things simple, we'll use the model trained on the first fold as a feature extractor to obtain embeddings for every image in the dataset.

These artifacts will be used by `Datalab` to inspect the dataset for potential issues, so we save them to files used in the next notebook.
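
Before handing the files to the next notebook, a few consistency checks can catch misaligned or malformed arrays. The sketch below is one possible set of checks, not part of the original workflow: `check_artifacts` is a hypothetical helper, and the small arrays are made-up stand-ins for the real `pred_probs`, `features`, and `labels`.

```python
import os
import tempfile

import numpy as np


def check_artifacts(pred_probs, features, labels):
    """Run basic consistency checks and return the out-of-sample accuracy."""
    n = len(labels)
    if pred_probs.shape[0] != n or features.shape[0] != n:
        raise ValueError("pred_probs, features and labels must have the same number of rows")
    if pred_probs.shape[1] <= labels.max():
        raise ValueError("pred_probs has fewer columns than there are classes")
    if not np.allclose(pred_probs.sum(axis=1), 1.0, atol=1e-4):
        raise ValueError("each row of pred_probs should sum to 1")
    return float((pred_probs.argmax(axis=1) == labels).mean())


# Hypothetical stand-ins for the notebook's arrays
toy_labels = np.array([0, 1, 2, 1])
toy_pred_probs = np.array(
    [[0.8, 0.1, 0.1],
     [0.2, 0.7, 0.1],
     [0.3, 0.3, 0.4],
     [0.6, 0.3, 0.1]]
)
toy_features = np.ones((4, 8))

# Round-trip through .npy files, mirroring how the artifacts are saved and reloaded
with tempfile.TemporaryDirectory() as tmp_dir:
    np.save(os.path.join(tmp_dir, "pred_probs.npy"), toy_pred_probs)
    np.save(os.path.join(tmp_dir, "features.npy"), toy_features)
    loaded_pred_probs = np.load(os.path.join(tmp_dir, "pred_probs.npy"))
    loaded_features = np.load(os.path.join(tmp_dir, "features.npy"))

accuracy = check_artifacts(loaded_pred_probs, loaded_features, toy_labels)
print(accuracy)  # 0.75
```

With the real arrays, the returned accuracy should be close to the validation accuracies logged during training.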

In [19]:
# Load the fold-1 model once; it serves as the feature extractor for every image.
# pretrained=False avoids downloading ImageNet weights that the checkpoint would overwrite anyway.
model = timm.create_model(model_name, pretrained=False, num_classes=n_classes)
path = f'{model_prefix}_fold_1.pt'
model.load_state_dict(torch.load(path, map_location=device))
model.reset_classifier(0)  # drop the classification head so the model returns pooled features
model.eval()
model.to(device)
num_features = model.num_features

features = np.zeros((len(dataset), num_features))
pred_probs = np.zeros((len(dataset), n_classes))

for fold, (_, test_idx) in enumerate(kf.split(dataset, labels)):
    # Held-out (validation) images for the current fold
    test_subset = torch.utils.data.Subset(dataset, test_idx)
    test_loader = torch.utils.data.DataLoader(test_subset, batch_size=batch_size, shuffle=False)

    # Load the model that did not see these images during training
    fold_model = timm.create_model(model_name, pretrained=False, num_classes=n_classes)
    path = f'{model_prefix}_fold_{fold + 1}.pt'
    fold_model.load_state_dict(torch.load(path, map_location=device))
    fold_model.eval()
    fold_model.to(device)

    with torch.no_grad():
        # Out-of-sample predicted probabilities from the current fold's model
        pred_probs_fold = []
        for inputs, _ in tqdm(test_loader):
            inputs = inputs.to(device)
            outputs = fold_model(inputs)
            outputs = nn.functional.softmax(outputs, dim=1)
            pred_probs_fold.append(outputs.cpu().numpy())
        pred_probs[test_idx] = np.concatenate(pred_probs_fold, axis=0)

        # Feature embeddings from the fold-1 feature extractor
        features_fold = []
        for inputs, _ in tqdm(test_loader):
            inputs = inputs.to(device)
            features_fold.append(model(inputs).cpu().numpy())
        features[test_idx] = np.concatenate(features_fold, axis=0)

features_path = os.path.join(DATA_DIR, "features.npy")
pred_probs_path = os.path.join(DATA_DIR, "pred_probs.npy")

np.save(features_path, features)
np.save(pred_probs_path, pred_probs)

100%|██████████| 4/4 [00:10<00:00,  2.59s/it]
100%|██████████| 4/4 [00:10<00:00,  2.69s/it]
100%|██████████| 4/4 [00:10<00:00,  2.73s/it]
100%|██████████| 4/4 [00:10<00:00,  2.72s/it]
100%|██████████| 4/4 [00:10<00:00,  2.61s/it]
100%|██████████| 4/4 [00:10<00:00,  2.59s/it]
100%|██████████| 4/4 [00:10<00:00,  2.53s/it]
100%|██████████| 4/4 [00:10<00:00,  2.63s/it]
100%|██████████| 4/4 [00:10<00:00,  2.65s/it]
100%|██████████| 4/4 [00:10<00:00,  2.57s/it]


In [20]:
import pickle

# Persist the fold-1 feature extractor (classifier head removed) for later use
with open('model.pkl', 'wb') as f:
    pickle.dump(model, f)