# Arthropod Classification with a pretrained network : Resnet50 by Wu et al.

See article __IP102: A Large-Scale Benchmark Dataset for Insect Pest Recognition__
by Xiaoping Wu, Chi Zhan, Yu-Kun Lai, Ming-Ming Cheng, Jufeng Yang from College of Computer Science, Nankai University, Tianjin, China and School of Computer Science and Informatics, Cardiff University, Cardiff, UK

__Abstract of the article:__

*Insect pests are one of the main factors affecting agricultural product yield. Accurate recognition of insect pests facilitates timely preventive measures to avoid economic losses. However, the existing datasets for the visual classification task mainly focus on common objects, e.g., flowers and dogs. This limits the application of powerful deep learning technology on specific domains like the agricultural field. In this paper, we collect a large-scale dataset named IP102 for insect pest recognition. Specifically, it contains more than 75,000 images belonging to 102 categories, which exhibit a natural long-tailed distribution. In addition, we annotate about 19,000 images with bounding boxes for object detection. The IP102 has a hierarchical taxonomy and the insect pests which mainly affect one specific agricultural product are grouped into the same upper-level category. Furthermore, we perform several baseline experiments on the IP102 dataset, including handcrafted and deep feature based classification methods. Experimental results show that this dataset has the challenges of inter- and intra-class variance and data imbalance. We believe our IP102 will facilitate future research on practical insect pest control, fine-grained visual classification, and imbalanced learning fields. We make the dataset and pre-trained models publicly available at https://github.com/xpwu95/IP102.*

__Contents__

1. Preliminaries
    - Data loading with boto3 (AWS)
    - Dataset class, dataloader, transforms
    - Image visualisation
    
    
2. First runs 

    - A. Transfer learning 
        - utils : learning curves, checkpointing
        - train v4 : history : checkpointing + warm start
    - B. Fine-tuning
    
    
3. Hyperparameter tuning
    - code wrapping for ray.tune 
    - hp tuning for the last layer only (transfer learning)
    - hp tuning for the whole network (fine-tuning)
    
   

## Preliminaries

__AWS specific__

In [None]:
import boto3
s3c = boto3.client('s3')

In [None]:
# Name of the S3 bucket that holds the images and the model files used below
bucket = 'eva-arthropod'
# NOTE(review): root_dir is not referenced anywhere else in the visible code
root_dir =  'arthropod_highres' 

#ARN is arn:aws:s3:::eva-arthropod
# One S3 key prefix per class; the last path component is used as the class name
source_dirs = ['arthropod_highres_us_2/Araneae',
               'arthropod_highres_us_2/Coleoptera',
               'arthropod_highres_us_2/Diptera',
               'arthropod_highres_us_2/Hemiptera',
               'arthropod_highres_us_2/Hymenoptera',
               'arthropod_highres_us_2/Lepidoptera', 
               'arthropod_highres_us_2/Odonata']  

# Smoke test of the S3 access: list one prefix and print only the first key found.
# NOTE(review): indexing ['Contents'] raises KeyError if nothing matches the prefix.
subfolder = 'arthropod_highres/Araneae/' #test
contents = s3c.list_objects(Bucket=bucket, Prefix=subfolder)['Contents']
for f in contents:
    print(f['Key'])
    break

__Regular imports__

In [None]:
!pip install ray[tune]

In [None]:
import os
import shutil
import random
import time
from math import ceil

import torch
import torchvision
from torchvision import transforms, models
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from functools import partial
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler

import numpy as np
import pandas as pd
from sklearn.metrics import classification_report

from PIL import Image
import matplotlib.pyplot as plt

In [None]:
# Select the compute device: the first CUDA GPU when one is usable, the CPU otherwise
c = torch.cuda.device_count()  # number of CUDA devices visible to PyTorch
print("Number of GPUs : ", c)
if c:
    print(torch.cuda.get_device_name(device=None))

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

__Reproducibility__

In [None]:
# Seed every random number generator used in this notebook
def seed_all(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random number generators.

    Also sets cuDNN to deterministic mode and turns its benchmark mode off.

    Parameters
    ----------
    seed : int
        Value handed to every generator.

    Returns
    -------
    None
    """
    print(">>> Using Seed : ", seed, " <<<")
    random.seed(seed)
    np.random.seed(seed)
    for seeder in (torch.manual_seed,
                   torch.cuda.manual_seed_all,
                   torch.cuda.manual_seed):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


seed_all(2905)

# check pytorch version : 1.7.1
print("Using PyTorch version ", torch.__version__)

# for classification saving results in  a dataframe 
saving = False

__Train/Val/Test split__

In [None]:
# Class names are the last path component of each source folder, sorted alphabetically
class_names = sorted(folder_name.split('/')[-1] for folder_name in source_dirs) # to adapt
print(class_names)
# Integer label of a class = its position in the sorted list of names
label_dic = {name: idx for idx, name in enumerate(class_names)}
print(label_dic)

In [None]:
# List the .jpg keys of every class folder of the bucket
paginator = s3c.get_paginator('list_objects_v2')

N_images = []  # number of .jpg keys per folder, in source_dirs order

file_dic = {}  # class name -> list of S3 keys of that class

for folder in source_dirs :
    # class name
    class_name = folder.split('/')[-1]
    print('\n', class_name)
    # iterator over pages
    pages = paginator.paginate(Bucket=bucket, Prefix=folder)
    # count images
    # A list_objects_v2 page carries no 'Contents' key when it holds no object,
    # hence .get() instead of indexing (which raised KeyError on an empty prefix).
    files = [obj['Key']
             for page in pages
             for obj in page.get('Contents', [])
             if obj['Key'].endswith('.jpg')]
    print(len(files))
    N_images.append(len(files))
    file_dic[class_name]=files

# size of the smallest class; the per-class val/test sizes are derived from it
N_min = min(N_images)
print("\nN_min total : ", N_min)


In [None]:
# Train, val and test percentages

p1 = 0.9 # train_val percentage from initial dataset
p2 = 0.2 # val_percentage from train_val

# Test and val sizes are computed from the smallest class (N_min) and are the
# same for every class; all remaining images of a class go to the training set.
N_test = ceil(N_min*(1-p1))
print("N_test per class : ", N_test)

N_val = ceil(N_min*p1*p2)
print("N_val per class : ", N_val)

# Number of training images left for each class once test and val are removed
for name in class_names:
    print("N_train "+name+' : ', len(file_dic[name])-N_test-N_val)

In [None]:
# Per-class split: the first N_test keys go to test, the next N_val to val,
# everything after that to train.
train_files, val_files, test_files = [], [], []

val_end = N_test + N_val  # index of the first training key inside a class listing
for name in class_names:
    files = file_dic[name]
    test_files.extend(files[:N_test])
    val_files.extend(files[N_test:val_end])
    train_files.extend(files[val_end:])

print(len(train_files), len(val_files), len(test_files))
print(train_files[:5])

In [None]:
# shuffle lists : seed has been set above
# random.shuffle works in place and draws from Python's `random` generator,
# which seed_all() seeded; the two prints show the head of the list before/after.
print(train_files[:5])
random.shuffle(train_files)
print(train_files[:5])
random.shuffle(val_files)
random.shuffle(test_files)

""" To verify reproductibility OK
['arthropod_highres_us_2/Araneae/3c628b041db2.jpg', 'arthropod_highres_us_2/Araneae/3c6491416c3f.jpg', 'arthropod_highres_us_2/Araneae/3c9737b52807.jpg', 'arthropod_highres_us_2/Araneae/3cacd14fe17d.jpg', 'arthropod_highres_us_2/Araneae/3cda1ee5743d.jpg']
['arthropod_highres_us_2/Hymenoptera/df4734164730.jpg', 'arthropod_highres_us_2/Araneae/e65766cc8702.jpg', 'arthropod_highres_us_2/Lepidoptera/6396f66b8ed9.jpg', 'arthropod_highres_us_2/Hemiptera/c2ea24c86c40.jpg', 'arthropod_highres_us_2/Hymenoptera/6b5413623f03.jpg']"""

__Data aug__

In [None]:
# get file list from /augmented_images

paginator = s3c.get_paginator('list_objects_v2')

N_img_aug = []     # number of augmented .jpg keys per listed folder
aug_files = []     # every augmented .jpg key
aug_file_dic = {}  # class name -> augmented keys whose path contains that name

for folder in ['augmented_images'] :
    # iterator over pages
    pages = paginator.paginate(Bucket=bucket, Prefix=folder)
    # count images
    # A page without objects has no 'Contents' key, so use .get() to avoid KeyError
    files = [obj['Key']
             for page in pages
             for obj in page.get('Contents', [])
             if obj['Key'].endswith('.jpg')]
    print(len(files))
    print(files[:2])
    N_img_aug.append(len(files))
    aug_files+=files

# The loop variable is `name` (not `c`) so that the GPU count stored in the
# global `c` is not overwritten by a class name.
# NOTE(review): matching is a substring test on the whole key — assumes no class
# name occurs inside the key of another class.
for name in class_names:
    aug_file_dic[name]=[f for f in aug_files if name in f]
    
for name in class_names:
    print("N_aug "+name+' : ', len(aug_file_dic[name]))

In [None]:
# target number of images for each class so that the dataset is balanced

for name in class_names:
    print("N_train "+name+' : ', len(file_dic[name])-N_test-N_val)

N_min_aug = min([len(aug_file_dic[name]) for name in class_names])
print("Min number of images available per class : ", N_min_aug)

N_target = N_min_aug+N_min
print(N_target)

for name in class_names : 
    n_train = sum(1 for f in train_files if name in f)
    n_val = sum(1 for f in val_files if name in f)
    # total number of augmented examples to add for this class.
    # Clamped at 0: a class that already exceeds N_target gave a negative value,
    # and negative slice bounds below would then select almost ALL augmented
    # images of that class instead of none.
    N_to_add = max(0, N_target - n_train - n_val - N_test)
    # 80% of the added images go to train, the remainder to val
    n_aug_train = ceil(N_to_add*0.8)
    train_files = train_files + aug_file_dic[name][:n_aug_train]
    val_files = val_files + aug_file_dic[name][n_aug_train:N_to_add]
    
    
print("Total number of training examples : ", len(train_files))
print("Total number of validation examples : ", len(val_files))

__Data loading and preprocessing__

In [None]:
### Create a custom dataset for the arthropod images ###
# file_list = dataset argument : train, val or test, list of paths as strings


class ArthropodDataset(torch.utils.data.Dataset):
    """Map-style dataset that reads arthropod images from an S3 bucket on demand.

    Parameters
    ----------
    file_list : list of str
        S3 keys of the images (train, val or test list).
    bucket_name : str
        Name of the S3 bucket holding the keys.
    transform : callable or None
        Applied to the PIL image. When None, a default pipeline is used:
        resize to 224x224, conversion to tensor, normalisation with the
        ImageNet mean and standard deviation.

    Each item is a pair (transformed image, integer label). The integer label is
    looked up in the module-level `label_dic` from the class name found in the key.
    """

    def __init__(self, file_list, bucket_name = 'eva-arthropod', transform=None): 
        self.bucket_name = bucket_name
        self.files = file_list
        self.s3_resource = boto3.resource('s3')
        self.transform = transform
        if transform is None:
            self.transform = torchvision.transforms.Compose([
                torchvision.transforms.Resize(size=(224, 224)),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])

    def __len__(self):
        return len(self.files)

    @staticmethod
    def _class_name_from_key(img_name):
        """Return the class name encoded in an S3 key.

        Keys containing "imaug" are augmented images: the class name is what lies
        between the 17-character prefix (the length of 'augmented_images/') and
        the first underscore. Any other key has the class as its second path
        component, e.g. arthropod_highres/Araneae/69e98607339f.jpg --> Araneae.
        """
        if "imaug" in img_name:
            return img_name[17:].split('_')[0]
        return img_name.split('/')[1]

    def __getitem__(self, idx):
        import io  # local import: the only place an in-memory buffer is needed

        img_name = self.files[idx]
        # label is inferred from the filename
        label = int(label_dic[self._class_name_from_key(img_name)]) # label as integer

        # Download into memory rather than into /tmp: no temporary file is left
        # behind for every image read, and two loader workers can never write
        # the same local file name.
        buffer = io.BytesIO()
        self.s3_resource.Object(self.bucket_name, img_name).download_fileobj(buffer)
        buffer.seek(0)
        # convert() forces the pixel data to be decoded now and guarantees three
        # channels, which the 3-value Normalize of the transforms requires
        # (grayscale / RGBA / palette files would otherwise not match).
        image = Image.open(buffer).convert('RGB')
        if self.transform:
            image = self.transform(image)

        return image, label
    

NB: one issue with this model is that we do not know the pixel statistics of the IP102 dataset, so we cannot normalize our data with the same mean and variance as IP102. We have therefore chosen to normalize the dataset with the ImageNet values, which at least gives us an identical, well-known starting point.

In [None]:
### Transforms  ###

# Training pipeline: images are resized to 224*224 (min input size of mobilenet,
# resnet18 etc architectures) and normalised with the ImageNet statistics
# mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225].
train_transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    #transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Test pipeline: same steps, without any data augmentation step;
# same mean and std dev values.
test_transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

In [None]:
### Data loaders ###

train_dataset = ArthropodDataset(train_files, transform=train_transform)
# Validation is an evaluation split: it uses the evaluation transform so that it
# stays free of augmentation if augmentation is enabled in train_transform.
val_dataset = ArthropodDataset(val_files, transform=test_transform)
test_dataset = ArthropodDataset(test_files, transform=test_transform)

# dataloaders objects corresponding to the train and test sets
B=16 # batch size for the data loaders
print("Batch size : ", B)

NUM_WORKERS = 7

# Only the training loader is shuffled; evaluation loaders keep a fixed order
dl_train=torch.utils.data.DataLoader(train_dataset, batch_size=B, shuffle=True, num_workers=NUM_WORKERS) 
dl_val=torch.utils.data.DataLoader(val_dataset, batch_size=B, shuffle=False, num_workers=NUM_WORKERS)  
dl_test=torch.utils.data.DataLoader(test_dataset, batch_size=B, shuffle=False, num_workers=NUM_WORKERS)
print("Number of training batches: ",  len(dl_train))
print("Number of val batches: ",  len(dl_val))
print("Number of test batches: ",  len(dl_test))

In [None]:
### Data visualization ###
# B : batchsize = 16
def show_images(images, labels, preds):
    """Display a batch of normalised images with their true and predicted classes.

    Parameters
    ----------
    images : batch of image tensors in channel-first layout, normalised with the
        ImageNet mean/std (the normalisation is undone before display)
    labels : tensor of true class indices (index into `class_names`)
    preds : tensor of predicted class indices

    The true class is written under each image; the predicted class is written
    on the left, in green when it matches the label and in red otherwise.
    """
    n_cols = 4
    # The grid grows with the batch: the original fixed 4x4 grid failed as soon
    # as a batch held more than 16 images. 16 images still give a 4x4 grid.
    n_rows = max(1, (len(images) + n_cols - 1) // n_cols)
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    plt.figure(figsize=(10, 3 * n_rows))
    for i, image in enumerate(images):
        plt.subplot(n_rows, n_cols, i + 1, xticks=[], yticks=[])
        # .cpu() makes the conversion to numpy work for GPU tensors as well
        image = image.cpu().numpy().transpose((1, 2, 0))
        # undo the normalisation, then clip to the displayable [0, 1] range
        image = np.clip(image * std + mean, 0., 1.)
        plt.imshow(image)
        col = 'green' if int(preds[i]) == int(labels[i]) else 'red'
        plt.xlabel(f'{class_names[int(labels[i])]}')
        plt.ylabel(f'{class_names[int(preds[i])]}', color=col)
    plt.tight_layout()
    plt.show()
    
images, labels = next(iter(dl_train))
show_images(images, labels, labels)

images, labels = next(iter(dl_test))
show_images(images, labels, labels)

# First runs with Resnet50

### A. Transfer learning

__Load the pre-trained Resnet50 network__

The network was pretrained by Wu et al. on IP102, a classification problem with 102 insect pest classes.

In [None]:
# load pre-trained model

# Download the pretrained weights file from the bucket to a local temporary path
boto3.resource('s3').meta.client.download_file(bucket, 
                                               os.path.join('models', 'resnet50_0.497.pkl'),
                                               '/tmp/resnet50_wu')
# map_location places every loaded tensor on the CPU, whatever device it was saved from
# NOTE(review): torch.load unpickles the file — only load files from a trusted source
resnet_dict = torch.load('/tmp/resnet50_wu', map_location=torch.device('cpu'))
print("Successfully downloaded model")

print(type(resnet_dict))

In [None]:
# Build a ResNet50 whose last layer has 102 outputs so that its parameter shapes
# match the downloaded weights, then load those weights into it
resnet = models.resnet50()
resnet.fc = nn.Linear(in_features=resnet.fc.in_features,
                                 out_features=102, # Wu et al. trained their models on a dataset with 102 classes
                                 bias=True)
resnet.load_state_dict(resnet_dict)

In [None]:
# Freeze every pretrained parameter: no gradient is computed for them any more
for param in resnet.parameters():
    param.requires_grad = False

NUM_CLASSES = len(class_names)

# Replace the 102-way head by a new linear layer; it is created after the
# freezing loop above, so that loop did not touch its parameters
resnet.fc = nn.Linear(in_features=resnet.fc.in_features,
                                 out_features=NUM_CLASSES, # one output per class of our own dataset
                                 bias=True)
resnet.to(device) # puts model on GPU / CPU
resnet.train(True) 

In [None]:
# Count the scalar values held by the parameters that still require a gradient
model_parameters = [p for p in resnet.parameters() if p.requires_grad]
params = sum(np.prod(p.size()) for p in model_parameters)
print("Number of trainable parameters : ", params)

__Training : Transfer learning__

In [None]:
# define a loss and an optimizer
# the optimisation is restricted to the parameters of the new layer
loss_fn = nn.CrossEntropyLoss()

# test 1
optimizer = torch.optim.SGD(resnet.fc.parameters(), lr=0.005, momentum=0.9)
# constant factor by which the learning rate is multiplied at each scheduler.step()
lr_lambda = lambda epoch : 0.9991 
scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda, last_epoch=-1, verbose=False)

In [None]:
# checkpointing utils

def load_checkpoint(filename='cp_ResNet', checkpoint_dir='models', bucket=bucket):
    """
    Download a checkpoint from S3 and return the two state dicts it contains.

    Input
    ------
    filename : str, choose something like cp_modelType_version_epoch
    checkpoint_dir : str, key prefix of the checkpoints in the bucket, e.g. "models"
    bucket : str, bucket name, e.g. "eva-arthropod"
    
    Output
    ------
    model_state : model state dict
    optimizer_state : optimizer state dict

    The file is expected to hold a (model_state, optimizer_state) tuple, which is
    what save_checkpoint writes.
    """
    boto3.resource('s3').meta.client.download_file(bucket, 
                                                   os.path.join(checkpoint_dir, filename), 
                                                   '/tmp/loaded_checkpoint')
    # map_location='cpu': without it, a checkpoint saved from a GPU cannot be
    # loaded on a machine without CUDA. load_state_dict copies the values into
    # the existing parameters, so the model keeps its own device afterwards.
    # NOTE(review): torch.load unpickles the file — only use trusted checkpoints.
    model_state, optimizer_state = torch.load('/tmp/loaded_checkpoint', map_location='cpu')
    print("Successfully downloaded checkpoint from s3 bucket : " ,filename)
    return model_state, optimizer_state 

def save_checkpoint(net, optimizer, filename='cp_ResNet', checkpoint_dir='models', bucket=bucket):
    """
    Save the model and optimizer state dicts and upload them to S3.

    Input
    ------
    net : model whose state dict is saved
    optimizer : optimizer whose state dict is saved
    filename : str, name of the checkpoint in the bucket
    checkpoint_dir : str, key prefix placed in front of filename
    bucket : str, destination bucket

    The two state dicts are written as one tuple to a local temporary file,
    which is then uploaded under checkpoint_dir/filename.
    """
    local_path = '/tmp/saved_checkpoint'
    states = (net.state_dict(), optimizer.state_dict())
    torch.save(states, local_path)

    s3_path = os.path.join(checkpoint_dir, filename)
    print(s3_path)

    s3_client = boto3.resource('s3').meta.client
    s3_client.upload_file(Filename=local_path, Bucket=bucket, Key=s3_path)
    print("Successfully uploaded checkpoint %s to s3 bucket under name %s." %(filename, s3_path))

# test
# save_checkpoint(model, optimizer)
# load_checkpoint(filename='cp_Resnet_ft', checkpoint_dir='models', bucket=bucket)

In [None]:
## Training with checkpoints ##

# learning rate decay + save checkpoint only if val accuracy was improved


def train_v4(model, N_epochs, train_loader, test_loader, optimizer, scheduler, make_checkpoints = True, filename = 'cp_resnet', warm_start = False, warm_filename = None):
    """
    Train `model`, validate it after every epoch and optionally checkpoint it.

    The loss function `loss_fn` and the target `device` are read from the
    module-level globals of the same names.

    model : neural network like myCNN()
    N_epochs : int
    train_loader : dataloader instance to iterate over training batches
    test_loader : idem, for val or test set
    optimizer : optimizer updating the parameters to train
    scheduler : learning rate scheduler bound to `optimizer`; stepped after every batch
    make_checkpoints : bool, save a checkpoint after the first epoch and then each
        time the validation accuracy exceeds the best value seen so far
    filename : str, name under which checkpoints are saved
    warm_start : bool, whether to start training from scratch or from a checkpoint
    warm_filename : str or None, checkpoint to load when warm_start is True;
        falls back to `filename` when None

    Returns (model, hist_dic) where hist_dic has the keys 'train_loss' and
    'batches' (one entry every 100 training batches), 'epochs', 'val_loss' and
    'val_acc' (one entry per epoch).
    """
    start = time.time()
    print('Starting training..')
    
    if warm_start :
        checkpoint_name = filename if warm_filename is None else warm_filename
        model_state, optimizer_state = load_checkpoint(filename=checkpoint_name)
        model.load_state_dict(model_state)
        # Restore the state into the optimizer that was passed in. Building a new
        # optimizer here (as was done before, from the global `resnet`) ignored
        # `model` and left `scheduler` attached to the discarded optimizer.
        optimizer.load_state_dict(optimizer_state)
            
    train_loss_history = []
    val_loss_history = []
    val_acc_history = []
    batch_idx_list = []
    epoch_list = [n for n in range(N_epochs)]

    batches_per_epoch = len(train_loader)
    best_val_acc = None  # best validation accuracy seen so far
    
    for epoch in range(N_epochs):
        print('\n'+'='*20+'\nStarting epoch '+str(epoch+1) + '/'+str(N_epochs)+'\n'+'='*20)
        # training
        model.train() # "train" mode affects layers such as dropout or batchnorm
        for batch_idx, (x, target) in enumerate(train_loader):
            optimizer.zero_grad()
            x, target = x.to(device), target.to(device)
            out = model(x)
            loss = loss_fn(out, target)
            loss.backward()
            optimizer.step()
            scheduler.step()
            
            if batch_idx %100 ==0:
                print('epoch {} - batch {} [{}/{}] - training loss: {}'.format(epoch+1,batch_idx,batch_idx*len(x),
                        len(train_loader.dataset),loss.item()))
                train_loss_history.append(loss.item())
                # number of batches already processed; len(train_loader) is used
                # because len(x) is smaller than the batch size on a last, partial batch
                batch_idx_list.append(batch_idx+epoch*batches_per_epoch)
                
        # testing -- val set
        model.eval()
        correct = 0
        loss_sum = 0.0  # sum of per-sample losses over the whole validation set
        n_seen = 0
        with torch.no_grad():
            for x, target in test_loader:
                x, target = x.to(device), target.to(device)
                out = model(x)
                n_batch = target.size(0)
                # loss_fn averages over the batch: weight it by the batch size
                loss_sum += loss_fn(out, target).item() * n_batch
                n_seen += n_batch
                prediction = out.argmax(dim=1) # index of the max log-probability
                correct += (prediction == target).sum().item()
        taux_classif = 100. * correct / len(test_loader.dataset)
        print('Val Accuracy: {}/{} (i.e. {:.2f}%, error: {:.2f}%)\n'.format(correct,
         len(test_loader.dataset), taux_classif, 100.-taux_classif))
        # mean loss over all validation samples (previously: loss of the last batch only)
        val_loss_history.append(loss_sum / n_seen if n_seen else float('nan'))
        val_acc_history.append(taux_classif)

        if make_checkpoints :
            if best_val_acc is None :
                save_checkpoint(model, optimizer, filename)
            elif taux_classif > best_val_acc:
                # Compare with the best accuracy so far, not with the previous
                # epoch: otherwise a partial recovery after a drop would overwrite
                # the checkpoint of a better model.
                print("Improvement in validation accuracy by %s percents" %(taux_classif-best_val_acc))
                save_checkpoint(model, optimizer, filename)
        if best_val_acc is None or taux_classif > best_val_acc:
            best_val_acc = taux_classif
        
        
    print("Elapsed : %s min" %ceil((time.time()-start)/60))
    hist_dic= {'train_loss': train_loss_history, 'epochs' : epoch_list,
             'val_loss': val_loss_history, 'val_acc': val_acc_history, 'batches' : batch_idx_list}
    return model, hist_dic


In [None]:
# Utils : learning curves

def plot_learning_curves(hist_dic):
    """Show three figures from a training history dictionary.

    hist_dic : dict with the keys 'train_loss', 'epochs', 'val_loss', 'val_acc'
        and 'batches'.

    Figures, in order: training loss against the number of processed batches,
    validation loss against epochs, validation accuracy against epochs.
    """
    train_losses, epochs = hist_dic['train_loss'], hist_dic['epochs']
    val_losses, val_accs, batches = hist_dic['val_loss'], hist_dic['val_acc'], hist_dic['batches']

    def draw(xs, ys, colour, x_label, y_label, heading):
        # one dashed curve with star markers, displayed as its own figure
        plt.plot(xs, ys,
                 color=colour, linestyle='dashed',
                 marker='*', markerfacecolor='black')
        plt.xlabel(x_label)
        plt.ylabel(y_label)
        plt.title(heading)
        plt.show()

    # training set
    draw(batches, train_losses, 'blue',
         "Number of processed training batches", "Training loss",
         'Training loss across training')
    # val set
    draw(epochs, val_losses, 'red',
         "Number of epochs", "Val loss",
         'Val loss across training')
    draw(epochs, val_accs, 'green',
         "Number of epochs", "Val accuracy",
         'Val accuracy across training')

__Evaluation__

In [None]:
# BETTER

def get_pred(model, data_loader):
    """Run `model` over `data_loader` and collect predictions and targets.

    Parameters
    ----------
    model : torch.nn.Module returning one score per class for each input
    data_loader : iterable of (data, target) batches

    Returns
    -------
    test_pred : 1-D LongTensor on the CPU, predicted class index of every sample
    test_target : 1-D LongTensor on the CPU, true class index of every sample

    Both tensors are empty when the loader yields no batch.
    """
    model.eval()
    # Send the inputs to the device the model actually lives on, rather than to
    # the GPU whenever one exists.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else None

    pred_batches = []
    target_batches = []
    with torch.no_grad():
        for data, target in data_loader:
            if model_device is not None:
                data = data.to(model_device)

            outputs = model(data)
            # torch's argmax (not numpy's), brought back to the CPU so that it can
            # be concatenated with the CPU targets and passed to scikit-learn
            pred_batches.append(outputs.argmax(dim=1).cpu())
            target_batches.append(target.cpu())

    if not pred_batches:
        return torch.LongTensor(), torch.LongTensor()
    return torch.cat(pred_batches, dim=0), torch.cat(target_batches, dim=0)




### Training 2

In [None]:
# test 2

# Train and plot learning curves
  
N_epochs = 10
optimizer = torch.optim.SGD(resnet.fc.parameters(), lr=0.01, momentum=0.9)
# NOTE(review): the positional arguments of CyclicLR are (optimizer, base_lr, max_lr),
# so as written base_lr=5e-3 is larger than max_lr=5e-5 — confirm this order is intended.
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, 5e-3, 5e-5, step_size_up=2000)

# train_v4 arguments: model, N_epochs, train_loader, test_loader, optimizer, scheduler,
# make_checkpoints = True, filename = 'cp_resnet', warm_start = False, warm_filename = None
trained_model, hist_dic = train_v4(resnet, N_epochs, dl_train, dl_val, optimizer, scheduler, 
                                   filename = 'cp_Resnet_transfered_aug', warm_start = False)

# plot learning curves
plot_learning_curves(hist_dic)

In [None]:
# Classification reports of the transfer-learning model on the three splits.
# `trained_model` is the model returned by train_v4; the name `model` that was
# used here before is not defined anywhere in this notebook (NameError).
train_pred, train_target = get_pred(trained_model, dl_train)
val_pred, val_target = get_pred(trained_model, dl_val)

print("\nPerformance on training set")
print(classification_report(train_target, train_pred, target_names=class_names))
print("\nPerformance on val set")
print(classification_report(val_target, val_pred, target_names=class_names))

test_pred, test_target = get_pred(trained_model, dl_test)
print("\nPerformance on test set")
print(classification_report(test_target, test_pred, target_names=class_names))


### B. Fine-tuning

__Reload pretrained model__

This time we want to learn the parameters of the entire network.

In [None]:
# re-initialize resnet
# The architecture is rebuilt with a 102-way head so that the pretrained weights can be loaded
resnet_ft = models.resnet50()
resnet_ft.fc = nn.Linear(in_features=resnet_ft.fc.in_features,
                         out_features=102, # Wu et al. trained their models on a dataset with 102 classes
                         bias=True)
resnet_ft.load_state_dict(resnet_dict)
resnet_ft.to(device)

NUM_CLASSES = len(class_names)

# Swap in a new head with one output per class of our own dataset
resnet_ft.fc = nn.Linear(in_features=resnet_ft.fc.in_features,
                                 out_features=NUM_CLASSES, # one output per class of our own dataset
                                 bias=True)
resnet_ft.to(device) # puts model on GPU / CPU
resnet_ft.train(True) 

# No parameter is frozen in this cell: every parameter of the network is handed to the optimizer
params_to_update = resnet_ft.parameters()

# number of trainable parameters
model_parameters = filter(lambda p: p.requires_grad, resnet_ft.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print("Number of trainable parameters : ", params)

In [None]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params_to_update, lr=0.001, momentum=0.9)

In [None]:
# train again, this time the whole network
print("Fine-tuning of ResNet50 by Wu")
resnet_ft.train(True)

# train_v4 requires an optimizer and a scheduler; the call used to omit both and
# raised a TypeError. `optimizer` is the SGD optimizer defined just above over all
# the parameters of resnet_ft. A scheduler bound to THIS optimizer is created here
# (the global `scheduler` is still attached to the optimizer of the previous run);
# the multiplicative decay is the one used elsewhere in this notebook.
scheduler_ft = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda epoch: 0.9991)

trained_model_ft, hist_dic_ft = train_v4(resnet_ft, 5, dl_train, dl_val, optimizer, scheduler_ft,
                                         filename = 'cp_Resnet_ft', warm_start = False)
plot_learning_curves(hist_dic_ft)

In [None]:
# Saving
# The state dict is written to a file relative to the current working directory;
# this cell does not upload it to S3.
filename = 'resnetWu_finetuned_1.pth'
torch.save(trained_model_ft.state_dict(), filename)
print("saved model to {}".format(filename))

In [None]:
# Evaluate on test set
# Use the fine-tuned model returned by train_v4 in this section; `trained_model`
# is the transfer-learning model of section A and was evaluated here by mistake.
test_pred, test_target = get_pred(trained_model_ft, dl_test)
print("\nPerformance on test set")
print(classification_report(test_target, test_pred, target_names=class_names))

# Hyperparameter tuning 

In [None]:
### Data loaders ###

def load_data(train_files, val_files, test_files):
    """Build the train, validation and test datasets from three lists of S3 keys.

    Parameters
    ----------
    train_files, val_files, test_files : list of str
        S3 keys of the images of each split.

    Returns
    -------
    train_dataset, val_dataset, test_dataset : ArthropodDataset
        The training set uses the training transform; the validation and test
        sets both use the evaluation transform.
    """

    # train set : 224*224 images as input because min input size of mobilenet, resnet18 etc architectures 
    # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] because they are the values of the ImageNet data     
    train_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(size=(224, 224)),
        #torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # test set : same steps except the data augmentation step     
    # idem for the mean and std dev values    
    test_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(size=(224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    train_dataset = ArthropodDataset(train_files, transform = train_transform)
    # Validation is an evaluation split: it gets the evaluation transform so that
    # enabling augmentation in train_transform never augments validation images.
    # (The two pipelines are identical while the augmentation is commented out.)
    val_dataset = ArthropodDataset(val_files, transform = test_transform)
    test_dataset=ArthropodDataset(test_files, transform = test_transform)

    
    return train_dataset, val_dataset, test_dataset

# test OK

In [None]:
### Search space ###

# Number of epochs run by each trial; read as a global by train_n_tune.
# NOTE(review): main() is later called with max_num_epochs=10, which is larger
# than this value — confirm that the two are meant to differ.
N_epochs = 8

# Number of worker processes of the test DataLoader built in the tuning functions
NUM_WORKERS = 2

# Search space handed to tune.run: learning rate drawn log-uniformly between
# 1e-4 and 1e-1, batch size picked among 4, 8 and 16
config = {            
            "lr": tune.loguniform(1e-4, 1e-1),
            "batch_size": tune.choice([4, 8, 16])
}

### A. Transferred model

In [None]:
# reinitialize resnet
# The global `resnet` built here is the model that train_n_tune trains.

resnet = models.resnet50()
resnet.fc = nn.Linear(in_features=resnet.fc.in_features,
                      out_features=102, # Wu et al. trained their models on a dataset with 102 classes
                      bias=True)
resnet.load_state_dict(resnet_dict)

# Freeze every pretrained parameter (transfer learning)
for param in resnet.parameters():
    param.requires_grad = False

NUM_CLASSES = len(class_names)

# New head, created after the freezing loop above
resnet.fc = nn.Linear(in_features=resnet.fc.in_features,
                      out_features=NUM_CLASSES, # one output per class of our own dataset
                      bias=True)
resnet.to(device) # puts model on GPU / CPU
resnet.train(True) 

In [None]:
### Training ###

# with multiplicative learning rate decay  
# and checkpointing 

def train_n_tune(config, checkpoint_dir=None): # config, N_epochs, train_files, val_files, test_files, 
    """Trainable for ray.tune: train the global `resnet` with one hyperparameter set.

    Parameters
    ----------
    config : dict with the keys "lr" (learning rate of the SGD optimizer) and
        "batch_size" (batch size of the train and validation loaders)
    checkpoint_dir : str or None, directory holding a "checkpoint" file to resume
        from, in the format written at the end of every epoch below

    Reads the globals `resnet`, `N_epochs`, `NUM_WORKERS`, `train_files`,
    `val_files` and `test_files`. After each epoch it saves a checkpoint and
    reports the mean validation loss and the validation accuracy to tune.
    """
    
    net = resnet

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer =  torch.optim.SGD(net.parameters(), lr=config["lr"]) ## To adapt ##
    lr_lambda = lambda x : 0.9991 
    scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda)

    if checkpoint_dir:
        # Resume from the checkpoint tune points to, i.e. the file written at the
        # end of each epoch below. An unrelated checkpoint fetched from S3 was
        # loaded here before.
        model_state, optimizer_state = torch.load(
            os.path.join(checkpoint_dir, "checkpoint"), map_location=device)
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    train_subset, val_subset, testset = load_data(train_files, val_files, test_files)

    trainloader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=int(config["batch_size"]),
        shuffle=True,
        num_workers=2)
    valloader = torch.utils.data.DataLoader(
        val_subset,
        batch_size=int(config["batch_size"]),
        shuffle=True,
        num_workers=2)

    def batches_skipping_bad_images(loader, phase):
        """Yield (batch number, batch); a batch whose loading raises OSError is skipped."""
        # The OSError ("image file is truncated") is raised while the batch is
        # FETCHED, so the fetch itself has to sit inside the try block; a
        # try/except around the training step only never sees it.
        # NOTE(review): assumes the DataLoader iterator stays usable after one
        # batch failed — confirm with the PyTorch version in use.
        iterator = iter(loader)
        number = 0
        while True:
            try:
                batch = next(iterator)
            except StopIteration:
                return
            except OSError as err:
                print("Batch number %s left out / %s (OSError : %s)" % (number, phase, err))
            else:
                yield number, batch
            number += 1

    def test_accuracy(net, device="cpu"): # train_files, val_files, test_files are global 
        testloader = torch.utils.data.DataLoader(testset, batch_size=8, shuffle=False, num_workers=NUM_WORKERS)
        net.eval()  # evaluation mode for layers such as batchnorm
        correct = 0
        total = 0
        with torch.no_grad():
            for data in testloader:
                images, labels = data
                images, labels = images.to(device), labels.to(device)
                outputs = net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        return correct / total
    
    for epoch in range(N_epochs):  # loop over the dataset multiple times
        print("--- Epoch %s ---" %(epoch+1))
        net.train()  # back to training mode after the validation of the previous epoch
        window_loss, window_steps = 0.0, 0  # since the last periodic print
        epoch_loss, epoch_steps = 0.0, 0    # since the start of the epoch
        for i, (inputs, labels) in batches_skipping_bad_images(trainloader, "Training"):
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()

            # print statistics
            batch_loss = loss.item()
            window_loss += batch_loss
            window_steps += 1
            epoch_loss += batch_loss
            epoch_steps += 1
            if i % 200 == 199:  # print every 200 mini-batches
                # mean over the batches since the last print; the sum used to be
                # reset while the step count was not, which shrank the printed value
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
                                                window_loss / window_steps))
                window_loss, window_steps = 0.0, 0
        print("Average loss for epoch %s : %s" %(epoch+1, epoch_loss / max(epoch_steps, 1)))
        
        # Validation loss
        # Evaluation mode: without it batchnorm layers normalise with the
        # statistics of each (small) validation batch and keep updating their
        # running statistics during validation.
        net.eval()
        val_loss = 0.0
        val_steps = 0
        total = 0
        correct = 0
        with torch.no_grad():
            for i, (inputs, labels) in batches_skipping_bad_images(valloader, "Validation"):
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = net(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                val_loss += criterion(outputs, labels).item()
                val_steps += 1

        with tune.checkpoint_dir(epoch) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            torch.save((net.state_dict(), optimizer.state_dict()), path)

        # If no validation batch could be read, report the worst possible values
        # instead of dividing by zero.
        tune.report(loss=(val_loss / val_steps) if val_steps else float("inf"),
                    accuracy=(correct / total) if total else 0.0)

    print("Finished Training")
    test_acc = test_accuracy(net, device)
    print("Current test set accuracy: {}".format(test_acc))
    


In [None]:
### Main : perform hp tuning ###

# train_files, val_files, test_files,  already defined above 

def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
    """Run the hyperparameter search and evaluate the best trial on the test set.

    Parameters
    ----------
    num_samples : int, number of hyperparameter sets drawn from the global `config`
    max_num_epochs : int, `max_t` of the ASHA scheduler
    gpus_per_trial : int, number of GPUs requested for each trial

    Prints the configuration, final validation loss and accuracy of the best
    trial, then the test accuracy of the model restored from its checkpoint.
    """
    
    # Only the test split is used in this function; every trial builds its own
    # train and validation sets.
    _, _, test_dataset = load_data(train_files, val_files, test_files)
    
    scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2)
    
    reporter = CLIReporter(
        # parameter_columns=["lr", "batch_size"],
        metric_columns=["loss", "accuracy", "training_iteration"])
    
    result = tune.run(
        partial(train_n_tune), # config=config, N_epochs=max_num_epochs, train_files, val_files, test_files
        resources_per_trial={"cpu": 8, "gpu": gpus_per_trial},
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter)

    best_trial = result.get_best_trial("loss", "min", "last")
    print("Best trial config: {}".format(best_trial.config))
    print("Best trial final validation loss: {}".format(
        best_trial.last_result["loss"]))
    print("Best trial final validation accuracy: {}".format(
        best_trial.last_result["accuracy"]))

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Every weight is overwritten by the checkpoint below, so the architecture is
    # built without downloading the ImageNet weights first.
    best_trained_model = models.resnet50()
    best_trained_model.fc = nn.Linear(in_features=best_trained_model.fc.in_features, 
                                                 out_features=NUM_CLASSES, bias=True)

    best_checkpoint_dir = best_trial.checkpoint.value
    model_state, _ = torch.load(os.path.join(
        best_checkpoint_dir, "checkpoint"), map_location=device)
    # A trial wraps its model in DataParallel when several GPUs are visible, which
    # prefixes every key with "module."; this function decides on gpus_per_trial
    # instead. Stripping the prefix and loading before any wrapping makes the
    # checkpoint loadable in both cases.
    prefix = "module."
    model_state = {(key[len(prefix):] if key.startswith(prefix) else key): value
                   for key, value in model_state.items()}
    best_trained_model.load_state_dict(model_state)

    if torch.cuda.is_available() and gpus_per_trial > 1:
        best_trained_model = nn.DataParallel(best_trained_model)
    best_trained_model.to(device)
    
    def test_accuracy(net, device="cpu"): # train_files, val_files, test_files are global 
        # `test_dataset` is the split built at the top of main(); the name
        # `testset` used here before does not exist in this scope (NameError).
        testloader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=NUM_WORKERS)
        net.eval()  # evaluation mode for layers such as batchnorm
        correct = 0
        total = 0
        with torch.no_grad():
            for data in testloader:
                images, labels = data
                images, labels = images.to(device), labels.to(device)
                outputs = net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        return correct / total
    
    test_acc = test_accuracy(best_trained_model, device)
    print("Best trial test set accuracy: {}".format(test_acc))

if __name__ == "__main__":
    main(num_samples=6, max_num_epochs=10, gpus_per_trial=0)

### B. Fine-tuned model 

In [None]:
# reinitialize resnet

# Rebind the global `resnet`, the model that train_n_tune trains
resnet = models.resnet50()
resnet.fc = nn.Linear(in_features=resnet.fc.in_features,
                      out_features=102, # Wu et al. trained their models on a dataset with 102 classes
                      bias=True)
resnet.load_state_dict(resnet_dict)

for param in resnet.parameters():
    param.requires_grad = True # this time we want to learn all parameters

NUM_CLASSES = len(class_names)

# New head with one output per class of our own dataset
resnet.fc = nn.Linear(in_features=resnet.fc.in_features,
                      out_features=NUM_CLASSES, # one output per class of our own dataset
                      bias=True)
resnet.to(device) # puts model on GPU / CPU
resnet.train(True) 

In [None]:
# run again main, this time the newly created resnet version is called
main(num_samples=6, max_num_epochs=10, gpus_per_trial=0)