
Finetuning Torchvision Models
=============================
[original tutorial](https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html) ([cs231n](https://cs231n.github.io/transfer-learning)) | [models](https://pytorch.org/vision/stable/models.html) | [datasets](https://pytorch.org/vision/stable/datasets.html)

TODO:
* try resolution > 224 ([discussion](https://discuss.pytorch.org/t/transfer-learning-usage-with-different-input-size/20744/6))

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
#from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import os
import numpy as np
import json
import time

# Helper Functions
from libs.model_definitions import initialize_model
from libs.train_model import train_model
from libs.dataset_utils import get_transforms
from libs import splitfolders

%matplotlib inline

print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

## Settings

In [33]:
# Dataset locations
raw_data_dir = "./dataset/_raw/"
split_data_dir = "./dataset/"  # [train, val, test] dirs
split_ratio = (.8, .1, .1)  # train, val, test split
copy_dataset = True  # copy or move files

# ImageNet normalization
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# number of data loader workers
dloader_workers = 6

# model types that may be selected below
model_list = [
    "resnet18", "resnet50", "alexnet", "vgg11_bn", "squeezenet",
    "densenet121", "inception_v3", "mobilenet_v2", "mobilenet_v3_large",
    "regnet_y_16gf", "efficientnet_v2_s", "efficientnet_v2_m",
    "convnext_base", "swin_v2_b",
]

# select model type
model_type = "mobilenet_v3_large"

# base name of the checkpoint, metadata and ONNX files
model_name = "mobilenet_v3_large"
checkpoints_dir = "checkpoints/"
checkpoint_path = os.path.join(checkpoints_dir, model_name + ".pt")

# metadata
json_path = os.path.join(checkpoints_dir, model_name + "_metadata.json")

# include softmax layer in model
add_softmax = False

# Feature Augmentation
scale_range = (0.75, 1.2)

# Device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Batch size for training (change depending on how much memory you have)
batch_size = 50

# Number of epochs to train for
num_epochs = 20

learning_rate = 1e-4

optimizer_name = "adam"  # [adam, sgd]

# train all layers or just head layer
train_deep = True

# Check that model_type exists. Fail fast instead of only printing a message,
# so that the following cells cannot run with an unsupported model type.
if model_type not in model_list:
    raise ValueError(f"ERROR: model {model_type} unknown!")

### Split dataset

In [3]:
# Split the raw dataset only if the train/val/test directories do not exist yet
if not splitfolders.check_existence(split_data_dir, dirs=["val", "test", "train"]):

    # copy (copy_dataset=True) or move (copy_dataset=False) the raw files
    # into the train, validation and test directories
    splitfolders.ratio(
        raw_data_dir,
        output=split_data_dir,
        seed=1337,
        ratio=split_ratio,
        group_prefix=None,
        move=not copy_dataset,
    )

In [None]:
# Number of classes in the dataset

def count_directories(path):
    """Return the number of immediate subdirectories of ``path``."""
    subdir_count = 0
    for entry in os.listdir(path):
        # only directories count; regular files are ignored
        if os.path.isdir(os.path.join(path, entry)):
            subdir_count += 1
    return subdir_count

# The training split is later read with datasets.ImageFolder, which uses one
# subdirectory per class, so the subdirectory count is the number of classes.
num_classes = count_directories(os.path.join(split_data_dir, "train"))
print(f"Number of classes: {num_classes}")

----------------
# Training

## initialize model

In [5]:
# Initialize the model for this run.
# initialize_model is a project helper (libs.model_definitions); the returned
# input_size is passed to get_transforms when the data loaders are created.
model, input_size = initialize_model(model_type, num_classes, train_deep, add_softmax=add_softmax)

Replace the Colab `imshow` with a custom function for use in Jupyter.

In [6]:
def convert_tensor_to_array(image, idx=0):
    """Convert one image of a normalized batch tensor into a uint8 image array.

    Args:
        image: batch tensor; ``image[idx]`` must be a 3-D tensor in
            (channels, height, width) order that was normalized with the
            module-level ``mean`` and ``std``.
        idx: index of the image inside the batch.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` in (height, width, channels)
        order with values in [0, 255].
    """
    # NOTE(review): libs.dataset_utils.convert_image_to_cv(image, RGB2BGR=False,
    # normalized=True) presumably does the same conversion - confirm before swapping.

    # Select the image first so that only one sample (not the whole batch)
    # is copied to the CPU; detach() drops any autograd history.
    img = image.detach()[idx].cpu().numpy().transpose((1, 2, 0))  # CHW -> HWC
    img = std * img + mean  # denormalize
    img = np.clip(img * 255, 0, 255).astype(np.uint8)  # convert to uint8
    return img

### Load Data
----------------
Now that we know what the input size must be, we can initialize the data
transforms, image datasets, and the dataloaders. Notice that the models were
pretrained with the hard-coded normalization values, as described
[here](https://pytorch.org/vision/stable/models.html).




In [None]:
# Data augmentation and normalization for training: http://pytorch.org/vision/main/transforms.html
# TODO: or use https://albumentations.ai instead
data_transforms = get_transforms(input_size, scale_range=scale_range, hflip=0.5, mean=mean, std=std)

# One ImageFolder dataset per split, each with the transform of that split
image_datasets = {
    phase: datasets.ImageFolder(os.path.join(split_data_dir, phase), data_transforms[phase])
    for phase in ['train', 'val']
}

# One shuffled data loader per split
dataloaders_dict = {
    phase: torch.utils.data.DataLoader(
        image_datasets[phase],
        batch_size=batch_size,
        shuffle=True,
        num_workers=dloader_workers,
    )
    for phase in ['train', 'val']
}

# Class names as discovered by ImageFolder in the training split
class_labels = image_datasets['train'].classes
print("class_labels:", class_labels)

In [56]:
def imshow(image, title=None):
    """Display a single image with matplotlib.

    Args:
        image: either a ``numpy.ndarray`` that is shown as-is, or any other
            object (a batch tensor), which is first converted with
            ``convert_tensor_to_array``.
        title: optional title shown above the image.
    """
    # arrays are plotted directly, everything else is converted first
    pixels = image if isinstance(image, np.ndarray) else convert_tensor_to_array(image)

    axes = plt.subplot(2, 2, 1)
    axes.axis('off')
    if title:
        axes.set_title(title)
    plt.imshow(pixels)
    plt.pause(0.001)


def visualize_model(model, device=device, num_images=6):
    """Plot predictions of ``model`` for images of the validation loader.

    Images are taken from ``dataloaders_dict['val']`` and drawn in a grid with
    two columns. Each title shows the predicted class label and its softmax
    probability. The training/eval mode of the model is restored afterwards,
    also when an exception is raised while plotting.

    Args:
        model: the network to evaluate; it is called on batches from the
            validation loader.
        device: device the input batches are moved to (the default is the
            module-level ``device`` at definition time).
        num_images: maximum number of images to plot; values <= 0 plot nothing.
    """
    if num_images <= 0:
        return

    was_training = model.training
    model.eval()
    images_so_far = 0
    # Round up so that an odd num_images still fits into the two-column grid
    num_rows = (num_images + 1) // 2
    plt.figure()

    try:
        with torch.no_grad():
            for inputs, _ in dataloaders_dict['val']:
                inputs = inputs.to(device)

                outputs = model(inputs)
                probabilities = torch.nn.functional.softmax(outputs, dim=1)
                top_probs, top_labels = torch.max(probabilities, 1)

                for j in range(inputs.size(0)):
                    images_so_far += 1
                    ax = plt.subplot(num_rows, 2, images_so_far)
                    ax.axis('off')

                    class_label = class_labels[top_labels[j]]
                    probability = top_probs[j].item()
                    ax.set_title(f'predicted: {class_label} ({probability:.2f})')

                    img = convert_tensor_to_array(inputs, idx=j)
                    plt.imshow(img)

                    if images_so_far == num_images:
                        return
    finally:
        # restore the mode the model had before visualization
        model.train(mode=was_training)

## Create the Optimizer
--------------------

Now that the model structure is correct, the final step for finetuning
and feature extracting is to create an optimizer that only updates the
desired parameters. Recall that after loading the pretrained model, but
before reshaping, if ``train_deep=False`` we manually set all of the
parameters' ``.requires_grad`` attributes to False. Then the
reinitialized layer's parameters have ``.requires_grad=True`` by
default. So now we know that *all parameters that have
.requires_grad=True should be optimized.* Next, we make a list of such
parameters and pass this list to the optimizer constructor (Adam or SGD).

To verify this, check out the printed parameters to learn. When
finetuning, this list should be long and include all of the model
parameters. However, when feature extracting this list should be short
and only include the weights and biases of the reshaped layers.

--------------------
[using Adam instead of SGD](https://analyticsindiamag.com/ultimate-guide-to-pytorch-optimizers/)


In [9]:
# Send the model to GPU if possible
model = model.to(device)

# Gather the parameters to be optimized/updated in this run. If we are
#  finetuning we will be updating all parameters. However, if we are
#  doing feature extract method, we will only update the parameters
#  that we have just initialized, i.e. the parameters with requires_grad
#  is True.
if train_deep:
    params_to_update = model.parameters()
else:
    params_to_update = [param for param in model.parameters() if param.requires_grad]

# Create the optimizer. Both variants use the configured learning_rate, so the
# value written to the metadata matches the value used for training.
if optimizer_name == "adam":
    optimizer_ft = optim.Adam(params_to_update, lr=learning_rate)
elif optimizer_name == "sgd":
    optimizer_ft = optim.SGD(params_to_update, lr=learning_rate, momentum=0.9)
else:
    # an unknown name used to fall back to SGD silently
    raise ValueError(f"unknown optimizer_name {optimizer_name!r}, expected 'adam' or 'sgd'")


## Run Training and Validation Step
--------------------------------

Finally, the last step is to setup the loss for the model, then run the
training and validation function for the set number of epochs. Notice,
depending on the number of epochs this step may take a while on a CPU.
Also, the default learning rate is not optimal for all of the models, so
to achieve maximum accuracy it would be necessary to tune for each model
separately.




In [None]:
# Setup the loss function
criterion = nn.CrossEntropyLoss()

# Timestamp taken right before training, used to measure its duration
starttime = time.time()

# Train and evaluate
# train_model is a project helper (libs.train_model). Its second return value
# is used below via max(hist) / hist[-1] as validation accuracies -
# presumably one entry per epoch, TODO confirm against the helper.
model, hist = train_model(model, dataloaders_dict, criterion, optimizer_ft, num_epochs=num_epochs, device=device)

# Elapsed time in seconds (difference of two time.time() values)
training_duration = time.time() - starttime

In [None]:
# hist comes from train_model; it is treated here as a sequence of validation
# accuracies (best = maximum, last = final entry), rounded to 4 decimals.
best_val_acc = round(float(max(hist)), 4)
last_val_acc = round(float(hist[-1]), 4)

print("best_val_acc:", best_val_acc)
print("last_val_acc:", last_val_acc)

In [None]:
# Plot predictions of the trained model for images from the validation loader
visualize_model(model, device=device)

## Save Checkpoint
[saving and loading checkpoints tutorial](https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html) , [stackoverflow](https://stackoverflow.com/questions/42703500/how-do-i-save-a-trained-model-in-pytorch)



In [27]:
def get_current_date():
    """Return the current UTC date and time as ``YYYY-MM-DD HH:MM:SS``."""
    utc_now = time.gmtime()
    return time.strftime("%Y-%m-%d %H:%M:%S", utc_now)

def save_json(data, json_path, update=True):
    """Write model metadata to a JSON file.

    Args:
        data: dict with the values to store.
        json_path: path of the JSON file to write.
        update: if True, merge ``data`` into the content already stored in
            ``json_path`` (keys in ``data`` win). If the file does not exist
            yet, ``data`` is written as-is instead of raising. If False, the
            file is overwritten with ``data``.
    """
    if update:
        try:
            with open(json_path) as f:
                merged = json.load(f)
        except FileNotFoundError:
            # nothing to merge with: behave like a first write
            merged = {}
        merged.update(data)
        data = merged

    with open(json_path, 'w') as f:
        json.dump(data, f)

In [14]:
# EPOCH = num_epochs  # TODO: current epoch / epoch of best val acc
# LOSS = 0.4          # TODO read from hist
#
# torch.save({
#             'epoch': EPOCH,
#             'model_state_dict': model.state_dict(),
#             'optimizer_state_dict': optimizer_ft.state_dict(),
#             'loss': LOSS,
#             }, checkpoint_path)

# torch.save does not create missing directories, so make sure the target
# directory of the checkpoint (and of the metadata file) exists.
os.makedirs(os.path.dirname(checkpoint_path) or ".", exist_ok=True)

# Save the complete model object (architecture and weights)
torch.save(model, checkpoint_path)


# Write variables to a JSON file
data = {
    'date_created':             get_current_date(),
    'model_type':               model_type,
    'input_size':               input_size,
    'has_softmax':              add_softmax,
    'class_labels':             class_labels,
    'initial_learning_rate':    learning_rate,
    'epochs':                   num_epochs,
    'training_time':            training_duration,
    'best_val_acc':             best_val_acc,
    'last_val_acc':             last_val_acc
    }

# update=False: start a fresh metadata file for this training run
save_json(data, json_path, update=False)

### convert to ONNX

In [28]:
model.eval()

# # Define the data transformations (same as used during training)
# data_transforms = transforms.Compose([
#     transforms.Resize(256),
#     transforms.CenterCrop(224),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# ])

# Load a sample image from your validation dataset.
# Dedicated names are used here on purpose: rebinding `image_datasets` would
# replace the {'train', 'val'} dict created in the data loading cell.
images_dir = os.path.join(split_data_dir, "val")
onnx_sample_dataset = datasets.ImageFolder(images_dir, data_transforms["val"])
onnx_sample_loader = torch.utils.data.DataLoader(onnx_sample_dataset, batch_size=1, shuffle=True)

# Get a single batch (one image)
inputs, _ = next(iter(onnx_sample_loader))

# Move the model and inputs to the same device (CPU or GPU)
model.to(device)
inputs = inputs.to(device)


onnx_path = os.path.join(checkpoints_dir, model_name + ".onnx")
# make sure the output directory exists before exporting
os.makedirs(os.path.dirname(onnx_path) or ".", exist_ok=True)

# Export the model; the batch dimension (axis 0) is marked as dynamic
torch.onnx.export(model, inputs, onnx_path,
                  input_names=['input'], output_names=['output'],
                  dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})


----------------
# Test

### Load Checkpoint and Data Loader

[tutorial](https://towardsdatascience.com/how-to-save-and-load-a-model-in-pytorch-with-a-complete-example-c2920e617dee)

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = 64

# # Remember to first initialize the model and optimizer, then load the dictionary locally.
# model, input_size = initialize_model(model_type, num_classes, train_deep)
# model = model.to(device)

# optimizer = optim.Adam(model.parameters(), lr=0.001)

# checkpoint = torch.load(checkpoint_path)
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# epoch = checkpoint['epoch']
# loss = checkpoint['loss']


# The checkpoint stores the complete pickled model (see torch.save(model, ...)),
# so weights_only must be False. map_location makes a checkpoint that was saved
# on a GPU loadable on a CPU-only machine.
# NOTE: unpickling executes arbitrary code - only load checkpoints you trust.
model = torch.load(checkpoint_path, map_location=device, weights_only=False)
model = model.to(device)
model.eval()  # set dropout and batch normalization layers to evaluation mode before running inference


# create data loader for test-data
test_transform = transforms.Compose([
    transforms.Resize(input_size),
    transforms.CenterCrop(input_size),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)])

In [None]:
test_dir = os.path.join(split_data_dir, "test")

testset = datasets.ImageFolder(test_dir, test_transform)
# the sample order does not influence the accuracy, so no shuffling is needed
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

# Count correct predictions over all samples. Averaging per-batch accuracies
# would over-weight the samples of a smaller last batch.
correct_count = 0
with torch.no_grad():
    for samples, labels in test_loader:
        samples, labels = samples.to(device), labels.to(device)
        output = model(samples)

        # predicted class = index of the largest output per sample
        pred = torch.argmax(output, dim=1)
        correct_count += pred.eq(labels).sum().item()

testimage_count = len(testset)
test_result = correct_count / testimage_count if testimage_count else 0.0
print(f'Accuracy of the network on {testimage_count} test images: {round(test_result * 100.0, 2)}%')


In [20]:
# update metadata with test accuracy
# (rounded to 4 decimals, like best_val_acc / last_val_acc in the same file)
data = {'date_modified': get_current_date(), 'test_acc': round(test_result, 4)}
save_json(data, json_path, update=True)

In [None]:
# Plot predictions of the reloaded model for 6 images from the validation loader
visualize_model(model, device=device, num_images=6)

--------------------
# Predict

In [76]:
import os
import random
from PIL import Image


def get_random_image(images_dir, ext=".jpg"):
    """Return the path of a randomly chosen image file in ``images_dir``.

    Args:
        images_dir: directory that is searched (not recursively).
        ext: file extension, or a tuple/list of extensions, to accept.
            Matching is case-insensitive, so ``".jpg"`` also finds ``.JPG``.

    Returns:
        ``images_dir`` joined with the name of the selected file.

    Raises:
        IndexError: if no file in ``images_dir`` has a matching extension.
    """
    # normalize to a tuple of lower-case extensions for str.endswith
    extensions = (ext,) if isinstance(ext, str) else tuple(ext)
    extensions = tuple(e.lower() for e in extensions)

    candidates = [f for f in os.listdir(images_dir) if f.lower().endswith(extensions)]
    if not candidates:
        # same exception type random.choice would raise, with a useful message
        raise IndexError(f"no files with extension {ext!r} found in {images_dir!r}")

    return os.path.join(images_dir, random.choice(candidates))


def load_pt_tensor(image_path, device="cpu"):
    """Load an image file as a single-image batch for the PyTorch model.

    The image is opened with PIL, converted to RGB and passed through the
    module-level ``test_transform``.

    Args:
        image_path: path of the image file.
        device: device the resulting tensor is moved to.

    Returns:
        The transformed image with a leading batch dimension, on ``device``.
    """
    rgb_image = Image.open(image_path).convert('RGB')
    # unsqueeze(0) adds the batch dimension expected by the model
    batch = test_transform(rgb_image).unsqueeze(0)
    return batch.to(device)


def load_class_labels(json_path):
    """Return the ``"class_labels"`` entry of the JSON file at ``json_path``."""
    with open(json_path, 'r') as f:
        metadata = json.load(f)
    return metadata["class_labels"]


# load class labels from metadata.json
class_labels = load_class_labels(json_path)

# Get a random image from the validation set. The class directory is picked
# from the stored class labels instead of a hard-coded name, so this cell
# also works for datasets that have no "ants" class.
image_path = get_random_image(os.path.join(split_data_dir, "val", random.choice(class_labels)))

device = torch.device("cpu")  # ("cuda:0" if torch.cuda.is_available() else "cpu")


### PyTorch inference

In [None]:
def torch_predict(model, image, class_labels):
    """Classify a single-image batch with a PyTorch model.

    Args:
        model: callable that maps ``image`` to a 2-D tensor of class scores.
        image: input batch; only the first item is evaluated in the result.
        class_labels: sequence that maps a class index to its label.

    Returns:
        Tuple ``(label, probability)`` of the most likely class, where the
        probability is the softmax of the model output as a Python float.
    """
    with torch.no_grad():  # no gradients are needed for inference
        logits = model(image)
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        # best class per row; [0][0] selects it for the first batch item
        top_prob, top_class = probabilities.topk(1, dim=1)

    best_index = top_class[0][0]
    return class_labels[best_index], top_prob[0][0].item()


image_tensor = load_pt_tensor(image_path, device=device)

# The checkpoint stores the complete pickled model, so weights_only must be
# False. map_location loads the tensors directly onto `device`, which also
# works when the checkpoint was saved on a GPU and `device` is the CPU.
# NOTE: unpickling executes arbitrary code - only load checkpoints you trust.
model = torch.load(checkpoint_path, map_location=device, weights_only=False)
model = model.to(device)
model.eval()

predicted_label, probability = torch_predict(model, image_tensor, class_labels)

print(f'Predicted: {predicted_label} ({probability:.2f})')
imshow(image_tensor, title=f'{predicted_label} ({probability:.2f})')

-----------
### ONNX inference

In [None]:
import onnxruntime as ort
import torch.nn.functional as F


# Load the ONNX model
onnx_path = os.path.join(checkpoints_dir, model_name + ".onnx")
# Name the execution provider explicitly: onnxruntime builds that ship
# additional providers reject a session without it. Inference runs on the CPU.
ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])


def load_onnx_tensor(image_path):
    """Load and preprocess an image file as input for the ONNX model.

    Args:
        image_path: path of the image file.

    Returns:
        The image after the module-level ``test_transform``, with a leading
        batch dimension.
    """
    # Convert to RGB like load_pt_tensor does, so that grayscale, palette or
    # RGBA files also yield three channels for the normalization.
    image = Image.open(image_path).convert('RGB')
    image_tensor = test_transform(image)
    image_tensor = image_tensor.unsqueeze(0)  # add the batch dimension
    return image_tensor


def onnx_predict(input_batch):
    """Classify a batch with the module-level ONNX session.

    Args:
        input_batch: array fed to the first input of ``ort_session``.

    Returns:
        Tuple ``(label, probability)`` for the first item of the batch: the
        entry of the module-level ``class_labels`` with the highest softmax
        probability, and that probability.
    """
    # feed the batch to the first (only used) model input
    input_name = ort_session.get_inputs()[0].name
    logits = ort_session.run(None, {input_name: input_batch})[0]

    # Convert logits to probabilities
    probabilities = F.softmax(torch.tensor(logits), dim=1).numpy()

    # index of the most likely class of the first batch item
    best_idx = np.argmax(probabilities, axis=1)[0]
    return class_labels[best_idx], probabilities[0][best_idx]


image_tensor = load_onnx_tensor(image_path)
class_labels = load_class_labels(json_path)

# onnxruntime expects a numpy array, not a torch tensor
label, prob = onnx_predict(image_tensor.numpy())

print(f"Predicted: {label} ({prob:.2f})")
# show the ONNX result (label / prob), not the earlier PyTorch prediction
imshow(image_tensor, title=f'{label} ({prob:.2f})')