# 1. Preemptive Research

The Oxford Flowers102 dataset is a popular dataset on Kaggle. According to this leaderboard, https://paperswithcode.com/sota/image-classification-on-flowers-102, the best classification accuracy achieved is 99.76%, by a CvT-W24 neural network trained with extra data. Without extra data, the best network is a TransBoost-ResNet50 model with 97.85% accuracy. This notebook attempts to recreate the TransBoost-ResNet50 model by acquiring a pretrained ResNet50 model from PyTorch (trained on ImageNet) and performing transfer learning on it by training it on the Flowers102 dataset.

# 2. Training and saving the model

## Import Pytorch and other dependencies

In [None]:
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.datasets import Flowers102
from torchvision.transforms import transforms
from torch.optim import lr_scheduler
from torchvision import datasets, transforms, models


import torchvision
import torch
import matplotlib.pyplot as plt
import numpy as np
import json
import copy
import time
import seaborn as sns

import torch.nn as nn
import torch.optim as optim

from PIL import Image

In [None]:
# Use the first CUDA GPU when one is visible; otherwise fall back to the CPU.

if torch.cuda.is_available():
    print(f'Found {torch.cuda.device_count()} GPU(s).')
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f'The device is set to {device}')


## Load model and image dataset

In [None]:
# Start from torchvision's ResNet-50 with its default pretrained weights.
model = resnet50(weights=ResNet50_Weights.DEFAULT)

In [None]:
# Per-channel mean / std applied to every split after ToTensor().
normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])

data_transforms = {
    # Training images are augmented with random rotation, crop and flip.
    'train': transforms.Compose([
        transforms.RandomRotation(45),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]),
}

# Validation and test use the same deterministic resize + center-crop recipe,
# each split getting its own Compose instance.
for split in ('validation', 'test'):
    data_transforms[split] = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

training_set = datasets.ImageFolder('data2/train', transform=data_transforms['train'])
validation_set = datasets.ImageFolder('data2/valid', transform=data_transforms['validation'])
test_set = datasets.ImageFolder('data2/test', transform=data_transforms['test'])

print(np.max(training_set.targets))

# ImageFolder assigns its own integer index to every class folder, so keep the
# inverse map (index -> folder name); folder names are later translated to
# flower names.
idx_to_class = {index: folder for folder, index in training_set.class_to_idx.items()}

# Test labels are only meaningful if both splits index their folders identically.
assert training_set.class_to_idx == test_set.class_to_idx

print(idx_to_class)

# Data loaders: training and validation batches are drawn in shuffled order,
# the test loader keeps dataset order.
loader_options = {'batch_size': 64, 'num_workers': 0}
training_loader = torch.utils.data.DataLoader(training_set, shuffle=True, **loader_options)
validation_loader = torch.utils.data.DataLoader(validation_set, shuffle=True, **loader_options)
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, **loader_options)

# Phase-keyed lookups consumed by the training loop.
dataloaders = {'train': training_loader, 'val': validation_loader}
dataset_sizes = {phase: len(loader.dataset) for phase, loader in dataloaders.items()}

# Report split sizes
print(f'Training set has {len(training_set)} instances')
print(f'Validation set has {len(validation_set)} instances')
print(f'Test set has {len(test_set)} instances')

### Sanity Check (Display images)

In [None]:
def imshow(image, ax=None, title=None):
    """Draw a normalized, channel-first image tensor on a matplotlib axes.

    Args:
        image: Tensor laid out as (channels, height, width) that was normalized
            with the mean / std constants used below.
        ax: Axes to draw on. A new figure and axes are created when omitted.
        title: Optional title for the axes; falsy values leave it untitled.

    Returns:
        The axes the image was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()
    if title:
        # Title the axes we actually draw on rather than whichever axes pyplot
        # currently considers "current".
        ax.set_title(title)

    # PyTorch tensors put the color channel first, matplotlib expects it last.
    # detach()/cpu() make the conversion work for GPU or grad-tracking tensors.
    pixels = image.detach().cpu().permute(1, 2, 0).numpy()

    # Undo preprocessing
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    pixels = std * pixels + mean

    # Image needs to be clipped between 0 and 1
    pixels = np.clip(pixels, 0, 1)

    ax.imshow(pixels)

    return ax

# Lookup from class-folder name to flower name.
# from https://github.com/bdevnani3/oxfordflowers102-label-name-mapping/blob/main/mapping.json
with open('udacity_label_to_name.json', 'r') as f:
    flower_to_name = json.load(f)

# Pull a single training batch and preview its first eight samples.
dataiter = iter(training_loader)
images, labels = next(dataiter)
print(labels[:8])

# Tile the previewed images into one grid image, titled with their flower names.
out = torchvision.utils.make_grid(images[:8])
preview_names = []
for j in range(8):
    folder_name = idx_to_class[labels[j].item()]
    preview_names.append(flower_to_name[folder_name])
imshow(out, title=preview_names)

In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train ``model`` and return it loaded with its best validation weights.

    Every epoch runs a 'train' phase followed by a 'val' phase over the
    module-level ``dataloaders`` dict. Batches are moved to the module-level
    ``device`` and per-phase averages are computed with the module-level
    ``dataset_sizes``. The model itself is not moved to ``device`` here.

    Args:
        model: Network to optimise; modified in place.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: LR scheduler stepped once per epoch, after the train phase.
        num_epochs: Number of train + val passes to run.

    Returns:
        The same ``model`` object, with the state dict from the epoch that had
        the highest validation accuracy loaded into it.
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # Track autograd history only while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Accumulate plain Python numbers so the epoch statistics (and
                # best_acc) stay floats instead of turning into tensors.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()
            if is_train:
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Snapshot the weights whenever validation accuracy improves.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    # Four decimals, matching the per-phase accuracy lines above.
    print(f'Best val Acc: {best_acc:.4f}')

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [None]:
# Modify the architecture for transfer learning: replace the pretrained
# classification head with one sized to the number of flower classes.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(flower_to_name))
model_ft = model.to(device)
# NOTE(review): `epochs` is not read by any cell shown here; the training call
# passes num_epochs explicitly. Confirm whether it is still needed.
epochs = 9
learning_rate = 0.001
# The new head is a plain nn.Linear and therefore emits raw logits.
# CrossEntropyLoss applies log-softmax itself; NLLLoss would instead expect
# log-probabilities as input.
criterion = nn.CrossEntropyLoss()
# torchvision's ResNet exposes its head as `fc` (it has no `classifier`
# attribute). Only the head's parameters are handed to the optimizer.
optimizer_ft = optim.Adam(model.fc.parameters(), lr=learning_rate)

# Decay LR by a factor of 0.1 every 2 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=2, gamma=0.1)

In [None]:
# Train with transfer learning for 12 epochs. train_model returns the same
# model, loaded with the weights from its best validation epoch.
model_tl = train_model(model, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=12)

### Saving and Reloading the model

In [None]:
# Move the model to the CPU before taking the state dict that gets serialized.
model_tl.cpu()

# Bundle everything needed to rebuild the classifier later: an architecture
# tag, the weights, and the class-folder -> index mapping used in training.
checkpoint = {
    'arch': 'resNet50',
    'state_dict': model_tl.state_dict(),
    'class_to_idx': training_set.class_to_idx,
}
torch.save(checkpoint, 'resNet50classifierHigherAcc.pth')

In [None]:
def load_model(checkpoint_path):
    """Rebuild the fine-tuned ResNet-50 classifier from a checkpoint file.

    The checkpoint must be a dict with the keys 'arch', 'state_dict' and
    'class_to_idx', as written by the save cell of this notebook.

    Args:
        checkpoint_path: Path of the file to read with ``torch.load``.

    Returns:
        The model on the module-level ``device``, in evaluation mode, with all
        backbone parameters frozen and a ``class_to_idx`` attribute attached.

    Raises:
        ValueError: If the checkpoint's 'arch' is not 'resNet50'.
    """
    # Deserialize onto the CPU so the file loads regardless of where it was
    # saved; the model is moved to `device` below.
    chpt = torch.load(checkpoint_path, map_location='cpu')

    if chpt['arch'] != 'resNet50':
        # Raise instead of calling exit(), so a bad checkpoint surfaces as an
        # ordinary error rather than terminating the interpreter.
        raise ValueError(
            f"Sorry base architecture not recognized: {chpt['arch']!r}")

    # No pretrained download needed: load_state_dict below overwrites every
    # parameter and buffer of the network.
    model = models.resnet50(weights=None)
    for param in model.parameters():
        param.requires_grad = False

    model.class_to_idx = chpt['class_to_idx']
    state_dict = chpt['state_dict']

    # Size the replacement head from the saved weights themselves so it always
    # matches the checkpoint.
    num_ftrs = model.fc.in_features
    num_classes = state_dict['fc.weight'].shape[0]
    model.fc = nn.Linear(num_ftrs, num_classes)

    model.load_state_dict(state_dict)
    model.to(device)
    # The restored model is only used for inference in this notebook; eval mode
    # makes BatchNorm use its stored running statistics.
    model.eval()
    return model

In [None]:
# NOTE(review): this loads 'resNet50classifier.pth', whereas the save cell in
# this notebook writes 'resNet50classifierHigherAcc.pth'. Confirm which
# checkpoint is intended.
model_tl = load_model('resNet50classifier.pth')

# 3. Testing Model

In [None]:
# Measure top-1 accuracy on the test split.
# Evaluation mode makes BatchNorm use its running statistics instead of
# per-batch statistics (which it would also keep updating in training mode).
model_tl.eval()

running_accuracy = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        predicted_outputs = model_tl(inputs)
        _, predicted = torch.max(predicted_outputs, 1)
        total += labels.size(0)
        running_accuracy += (predicted == labels).sum().item()

# Report the number of samples actually evaluated, and keep two decimals so
# small accuracy differences are not truncated away.
print(f'Accuracy of the model based on the test set of {total} inputs is: '
      f'{100 * running_accuracy / total:.2f} %')

### Perform transforms that were used in training/validation

In [None]:
def process_image(image_path):
    """Load an image file and apply the validation / test preprocessing.

    Args:
        image_path: Path of the image file to open with PIL.

    Returns:
        A normalized tensor holding the 224x224 center crop, channel-first.
    """
    adjust = transforms.Compose([transforms.Resize(256),
                                 transforms.CenterCrop(224),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406],
                                                      [0.229, 0.224, 0.225])])
    # The context manager closes the underlying file once the pixels are read.
    with Image.open(image_path) as img:
        # Force three channels: grayscale, palette or RGBA files would not
        # match the 3-channel Normalize above. ImageFolder's default loader
        # converts to RGB the same way.
        return adjust(img.convert('RGB'))

## Sanity Checking

Determining if the model finds the right label with high probability based on a hand-picked test set (smoke-testing)

In [None]:
def predict(image_path, model, topk=5):
    """Predict the most likely classes of an image with a trained model.

    Args:
        image_path: Path of the image, preprocessed via ``process_image``.
        model: Classifier carrying a ``class_to_idx`` attribute that maps
            class-folder names to output indices. Its training / eval mode is
            restored before returning.
        topk: Number of highest-probability classes to return.

    Returns:
        A tuple ``(top_probs, top_labels, top_flowers)``: a CPU tensor of shape
        (1, topk) with the class probabilities, the matching class ids (folder
        names converted to int), and the matching flower names looked up in
        the module-level ``flower_to_name``.
    """
    batch = process_image(image_path).unsqueeze(0).to(device)

    # Run inference in eval mode without autograd, then put the model back
    # into whatever mode the caller had it in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits = model(batch)
    finally:
        model.train(was_training)

    # The network head outputs raw logits, so softmax (not exp) yields
    # probabilities.
    probs = torch.softmax(logits, dim=1)
    top_probs, top_idx = probs.topk(topk)
    top_probs = top_probs.cpu()
    top_idx = top_idx.cpu()

    idx_to_class = {index: folder for folder, index in model.class_to_idx.items()}

    # Model index -> folder name. Both the numeric label and the flower name
    # derive from the folder name; a class id must never be fed back into
    # idx_to_class, whose keys are model indices.
    top_folders = [idx_to_class[index] for index in top_idx[0].tolist()]
    top_labels = [int(folder) for folder in top_folders]
    top_flowers = [flower_to_name[folder] for folder in top_folders]

    return top_probs, top_labels, top_flowers

In [None]:
# Smoke test: top-5 prediction for one image from the class-1 test folder.
predict("./data2/test/1/image_06743.jpg", model_tl)

In [None]:
def plot_solution(image_path, model):
    """Show an image with its true label above a bar chart of top predictions.

    The true class id is read from the name of the image's parent directory
    (for example ``data2/test/1/image_06743.jpg`` gives ``'1'``) and translated
    through the module-level ``flower_to_name``.

    Args:
        image_path: Path of the image to display and classify.
        model: Trained classifier passed on to ``predict``.
    """
    # Local import keeps this cell runnable without touching the import cell.
    from pathlib import PurePath

    # Sets up our plot
    plt.figure(figsize=(6, 10))
    ax = plt.subplot(2, 1, 1)

    # Parent directory name is the class id; unlike a fixed split('/') index
    # this also works for paths with a leading './' or deeper nesting.
    flower_num = PurePath(image_path).parent.name
    title_ = flower_to_name[flower_num]

    # Plot flower. The name goes into the title text itself; the second
    # positional argument of plt.title is a fontdict, not more text.
    img = process_image(image_path)
    ax.set_title(f"Actual Label: {title_}")
    imshow(img, ax)

    # Make prediction
    top_probs, top_labels, top_flowers = predict(image_path, model)
    top_probs = top_probs[0].detach().numpy()  # tensor row -> numpy array

    # Plot bar chart
    plt.subplot(2, 1, 2)
    sns.barplot(x=top_probs, y=top_flowers, color=sns.color_palette()[0])
    plt.show()

    print(top_probs, top_labels, top_flowers)

In [None]:
# Visual check on one test image from the class-1 folder.
plot_solution("data2/test/1/image_06743.jpg", model_tl)

In [None]:
def visualize_model(model, num_images=6):
    """Show predictions for the first ``num_images`` validation images.

    Images come from ``dataloaders['val']`` and are laid out in a two-column
    grid, each titled with the predicted flower name. The model is moved to
    the module-level ``device``, evaluated in eval mode, and put back into its
    previous training / eval mode before returning.

    Args:
        model: Trained classifier to visualise.
        num_images: Number of images to display.
    """
    model.to(device)
    was_training = model.training
    model.eval()
    images_so_far = 0
    plt.figure()
    # Round up so an odd num_images still fits in the two-column grid.
    n_rows = (num_images + 1) // 2

    try:
        with torch.no_grad():
            for inputs, _ in dataloaders['val']:
                inputs = inputs.to(device)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)

                for j in range(inputs.size(0)):
                    if images_so_far >= num_images:
                        return
                    images_so_far += 1
                    ax = plt.subplot(n_rows, 2, images_so_far)
                    ax.axis('off')
                    ax.set_title(f'predicted: {flower_to_name[idx_to_class[preds[j].item()]]}')
                    # Draw into the titled subplot; without `ax`, imshow opens
                    # a new figure and leaves this subplot empty.
                    imshow(inputs[j].cpu(), ax)
    finally:
        # Restore the caller's mode on every exit path, including errors.
        model.train(mode=was_training)

In [None]:
# Show predictions for validation images (num_images defaults to 6).
visualize_model(model_tl)