In [1]:
import torch 
import torch.nn as nn
import torch.optim as optim
import torchvision
#suppress warnings about beta APIs which clutter up output
torch.set_warn_always(False)
torchvision.disable_beta_transforms_warning()
from torchvision.models import resnet50, ResNet50_Weights
import torchvision.transforms as transforms

import torchvision.transforms.v2 as v2
import numpy as np
import scipy
from copy import deepcopy

In [2]:
#Create resnet50 model and load pretrained weights. Fine tuning all the weights as opposed to just the final layer
#produces a higher validation accuracy, so we do not freeze the base model.
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)

#Replace final prediction head with a zero-initialized linear layer for fine tuning
model._modules["fc"] = nn.Linear(2048,102)
torch.nn.init.constant_(model._modules["fc"].weight,0)

Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], requires_grad=True)

Because the Oxford Flowers dataset has large variations in light, scale, and pose, as well as large variations in classes, we need data augmentation to improve performance on the dataset. Out of the preformulated methods AutoAugment, RandAugment, TrivialAugment, and AugMix, the TrivialAugment method was found to have the lowest validation error rate. 
In addition, composing the TrivialAugment method with other augmentations like photometric distortion did not improve the validation performance, so we use only TrivialAugment.
After hyperparameter tuning, we train on both the training and validation sets and report performance on the testing set.

In [3]:
#Combine default preprocessing with the TrivialAugment data augmentation method for training, and remove 
#preprocessing for testing
preprocess = weights.transforms()
transform_train = transforms.Compose([v2.TrivialAugmentWide(), preprocess])
transform_test = preprocess

#Load/download datasets
train_ds = torchvision.datasets.Flowers102("flowers102/train", split="train", download=True, transform=transform_train)
validation_ds = torchvision.datasets.Flowers102("flowers102/validation", split="val", download=True, transform=transform_train)
test_ds = torchvision.datasets.Flowers102("flowers102/test", split="test", download=True, transform=transform_test)

#Combine train and validation sets for training after hyperparameter tuning
train_ds = torch.utils.data.ConcatDataset([train_ds, validation_ds])

#Create dataloaders
trainloader = torch.utils.data.DataLoader(
        train_ds, batch_size=64, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(
        test_ds, batch_size=64, shuffle=True, num_workers=2)

While Adam and its variant AMSGrad are robust and state-of-the art algorithms for gradient descent, many image classification papers still use SGD with momentum. On Oxford Flowers, AMSGrad performs the best on validation error, followed by SGD with momentum, followed by Adam. 

In [4]:
#Fine tune for 30 epochs
num_epochs = 30

#Use cross entropy classification loss
criterion = nn.CrossEntropyLoss()

#Use AMSGrad, which had the lowest validation error out of Adam (default), AMSGrad, 
#and SGD w/ momentum during hyperparameter tuning.
optimizer = optim.Adam(model.parameters(), lr=5E-4, weight_decay=0, amsgrad=True) 
    
#Use a Cosine learning rate decay schedule
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    
#Use GPU when available
device = 'cuda' if torch.cuda.is_available() else 'cpu'

#Training loop. Iterate through dataloader, backpropagating losses through the model, and track the training loss 
#during the epoch

def train(cnn):
    cnn.train()
    train_loss = []
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = cnn(inputs)
        loss = criterion(outputs, targets)
        
        loss.backward()
        optimizer.step()
        train_loss.append(loss.detach().to("cpu").numpy())
    return np.mean(train_loss)


#Model testing. Reports both the average loss over the test set as well as the accuracy. 

def test(cnn):
    cnn.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = cnn(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.mean().item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    return test_loss/(batch_idx+1), 100.*correct/total

In [None]:
#Train model
import warnings
warnings.filterwarnings('ignore')

model.to(device)
import time 
t = time.time()
for epoch in range(num_epochs):
    train_loss = train(model)
    print(f"train loss {train_loss}")
    print(f"TIME {time.time()-t}")
    t = time.time()
    scheduler.step()



train loss 3.4496569633483887
TIME 447.0408010482788




train loss 0.7933040857315063
TIME 429.0125479698181




train loss 0.30601584911346436
TIME 407.79935908317566




train loss 0.21013431251049042
TIME 402.0576858520508




train loss 0.15429770946502686
TIME 415.4211268424988




train loss 0.11646496504545212
TIME 411.57110118865967




train loss 0.13206937909126282
TIME 406.99008798599243




In [None]:
#Evaluate model 
test(model)

Our model achieves a test accuracy of 96.0%. While this model achieves good results (in comparison, the ImageNet1k-pretrained resnet50 model in the BiT paper only has 74.9% accuracy), we can improve performance by choosing a resnet50 model with better pretraining. In the next section, we fine tune the ImageNet21k-pretrained resnet50 model from the BiT paper (https://arxiv.org/pdf/2106.00116v4.pdf) to get significantly lower error rates. 

# BiT Fine Tuning

This section largely follows the BiT paper, and many choices like data augmentation were initialized to their settings before hyperparameter search.

In [None]:
import numpy as np
import os
import tensorflow as tf
from tensorflow.keras import layers
import tensorflow_hub as hub
import tensorflow_datasets as tfds

In [None]:
#Load train, validation, and test datasets
train_ds = tfds.load('oxford_flowers102', split='train', shuffle_files=True, batch_size=32)
validation_ds = tfds.load('oxford_flowers102', split='validation', shuffle_files=True, batch_size=32)
test_ds = tfds.load('oxford_flowers102', split='test', shuffle_files=True, batch_size=32)

train_ds = train_ds.prefetch(buffer_size=tf.data.AUTOTUNE)
validation_ds = validation_ds.prefetch(buffer_size=tf.data.AUTOTUNE)
test_ds = test_ds.prefetch(buffer_size=tf.data.AUTOTUNE)

train_ds = train_ds.map(lambda x:(x["image"], x["label"]))
validation_ds = validation_ds.map(lambda x:(x["image"], x["label"]))
test_ds = test_ds.map(lambda x:(x["image"], x["label"]))

#Combine train and validation sets for training.
train_ds.concatenate(validation_ds)
train_ds.shuffle(buffer_size=train_ds.cardinality())

Due to the large variations in light in the Oxford Flowers dataset, we use random brightness manipulation in addition to the transformations from the BiT paper (https://arxiv.org/pdf/2106.00116v4.pdf). This results in a slight increase in validation accuracy.
As before, and as in the paper, we use the base model as a feature extractor and fine tune all of the weights.

In [None]:
#Preprocessing includes the augmentation techniques used in the BiT paper, 
#as well as a random brightness scaling to deal with the light variations in images
transform = tf.keras.Sequential([
  tf.keras.layers.RandomFlip('horizontal'),
  tf.keras.layers.RandomBrightness(0.2),
  tf.keras.layers.Resizing(512,512),
  tf.keras.layers.RandomCrop(480,480),
  tf.keras.layers.Rescaling(1./255)
])

# Load resnet50 BiT model as a feature extractor
model_url = "https://tfhub.dev/google/bit/m-r50x1/1"
module = hub.KerasLayer(model_url)

#Combine preprocessing, feature extraction, and prediction head into model
model = tf.keras.Sequential([transform, module, tf.keras.layers.Dense(102, kernel_initializer='zeros')])

#Compile model. Since we train for a small number of 30 epochs compared to the fine-tuning done in the paper, 
#removing learning rate annealing improves the validation accuracy as it allows the model to actually fit the dataset.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.003, momentum=0.9)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer=optimizer,
              loss=loss_fn,
              metrics=['accuracy'])

In [None]:
#Train model for 30 epochs
history = model.fit(
    train_ds,
    batch_size=32,
    epochs=30,
    validation_data=test_ds 
)

In [None]:
#Evaluate model
model.evaluate(test_ds)

By using a resnet50 model pretrained on ImageNet21k, we are able to raise the test accuracy by two percentage points to 98.2%, equalling the results reported in the BiT paper