# Projet INF8225
## Apprentissage par transfert


## Importations

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2 #OpenCV documentation : https://docs.opencv.org/4.x/
import matplotlib.patches as patches
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import time
import glob
%pip install wandb > /dev/null
import wandb
from collections import defaultdict

import seaborn as sns
sns.set_theme()

## Preprocessing

### Flowers dataset

In [None]:
import scipy.io

# imagelabels.mat holds one 1-based class label per flower image.
mat = scipy.io.loadmat('data/Flowers/imagelabels.mat')
flower_labels = mat.get('labels')
size = flower_labels.shape[1]  # labels are stored as a 1 x N matrix
print(size)

In [None]:
import scipy.io

# setid.mat lists the image ids belonging to the train / validation / test splits.
mat = scipy.io.loadmat('data/Flowers/setid.mat')
train, valid, test = (mat.get(key) for key in ('trnid', 'valid', 'tstid'))
for split_name, split in (("train: ", train), ("valid: ", valid), ("test: ", test)):
    print(split_name, len(split[0]))

In [None]:
# Show one flower image and print its (1-based) class label.
# Image files are numbered from 1 (image_00001.jpg) while the label array is
# indexed from 0, so the label of image id n is flower_labels[0][n - 1].
image_id = 435
ex = cv2.imread(f'data/Flowers/jpg/image_{image_id:05d}.jpg')
print(flower_labels[0][image_id - 1])
# OpenCV decodes to BGR; matplotlib expects RGB.
ex = cv2.cvtColor(ex, cv2.COLOR_BGR2RGB)
fig, ax = plt.subplots(1)
ax.imshow(ex)
plt.axis('off')
plt.show()

In [None]:
class FlowerDataset(torch.utils.data.dataset.Dataset):
    """Oxford 102 Flowers images paired with their 0-based class labels.

    Args
    ----
        datasetid: Image ids of the split to expose (e.g. a row of
            ``setid.mat``). Id ``n`` refers to the file ``image_{n:05d}.jpg``.
        preprocess: Callable applied to each loaded RGB PIL image.
        labels: 1-based class label of every image of the dataset, in file
            order, so the label of image id ``n`` is ``labels[n - 1]``.
        root: Directory containing the ``image_XXXXX.jpg`` files.
    """

    def __init__(
            self,
            datasetid,
            preprocess,
            labels,
            root='data/Flowers/jpg',
        ):
        super().__init__()

        self.datasetid = datasetid
        self.preprocess = preprocess
        self.labels = labels
        self.root = root

    def __len__(self):
        """Return the number of examples in the dataset.
        """
        return len(self.datasetid)

    def _sample_info(self, index: int) -> tuple:
        """Return ``(image path, 0-based label)`` for sample ``index``."""
        image_id = int(self.datasetid[index])
        path = f'{self.root}/image_{image_id:05d}.jpg'
        # Image ids start at 1 but ``labels`` is indexed from 0, hence
        # ``image_id - 1``; the outer ``- 1`` maps the 1-based label to the
        # 0-based target expected by CrossEntropyLoss. ``int`` turns the numpy
        # scalar into a Python int so the DataLoader collates it to int64.
        label = int(self.labels[image_id - 1]) - 1
        return path, label

    def __getitem__(self, index: int) -> tuple:
        """Return a sample.

        Args
        ----
            index: Index of the sample.

        Output
        ------
            input_tensor : tensor
            label : int
        """
        path, label = self._sample_info(index)
        # Decode inside ``with`` so the file handle is released immediately,
        # and force 3 channels so the 3-channel Normalize always applies.
        with Image.open(path) as image:
            input_image = image.convert('RGB')
        input_tensor = self.preprocess(input_image)
        return input_tensor, label

### iNaturalist dataset

In [None]:
# Root folder of the iNaturalist "Fungi" images: one sub-directory per species.
root_path = 'data/iNaturalist/train_val_images/Fungi'

In [None]:
# Display one sample image of the "Agrocybe parasitica" species.
sample_path = root_path + '/Agrocybe parasitica/028ec1c88701ec38291810cba41df7aa.jpg'
# OpenCV decodes to BGR; matplotlib expects RGB.
ex = cv2.cvtColor(cv2.imread(sample_path), cv2.COLOR_BGR2RGB)
fig, ax = plt.subplots(1)
ax.imshow(ex)
plt.axis('off')
plt.show()

In [None]:
import os

# Map every species sub-directory of root_path to an integer class id and
# collect, for every .jpg file, its file stem (in ``images``) and class id
# (in ``fungi_labels``).
# os.listdir and glob return entries in arbitrary, filesystem-dependent order.
# Sorting keeps the class ids (and so the classifier's output indices) stable
# across machines and runs. Plain files (e.g. .DS_Store) are skipped so they
# do not become empty classes.
categories_dic = {}
categories_list = []
images = []
fungi_labels = []
species_dirs = sorted(
    entry for entry in os.listdir(root_path)
    if os.path.isdir(os.path.join(root_path, entry))
)
for i, directory in enumerate(species_dirs):
    categories_dic[directory] = i
    categories_list.append(directory)

    # glob.escape protects against glob metacharacters in species names.
    pattern = os.path.join(glob.escape(os.path.join(root_path, directory)), '*.jpg')
    for filename in sorted(glob.glob(pattern)):
        # basename + splitext work with any path separator and dotted names.
        images.append(os.path.splitext(os.path.basename(filename))[0])
        fungi_labels.append(i)

In [None]:
# Number of fungi categories (one per species directory).
nb_categories = len(categories_list)
print(nb_categories)

In [None]:
# Number of images in the smallest class (displayed as the cell output).
np.min(np.bincount(fungi_labels))

In [None]:
from sklearn.model_selection import train_test_split

# 80 % train / 20 % validation. A fixed random_state makes the split (and thus
# the validation metrics logged to wandb) reproducible across kernel restarts,
# provided the input lists are built in the same order.
fungi_train, fungi_valid, y_fungi_train, y_fungi_valid = train_test_split(
    images, fungi_labels, train_size=0.8, random_state=42,
)

print(len(fungi_train))

In [None]:
# Sanity check: class ids present in the training split vs. total number of classes.
train_classes = np.unique(y_fungi_train)
print(train_classes)
print(np.unique(fungi_labels).size)

In [None]:
# Inspect the training split, then display its first image with its category.
for split_part in (fungi_train, y_fungi_train):
    print(split_part)
    print(len(split_part))

label = y_fungi_train[0]
category = categories_list[label]
print(category)

image_path = root_path + '/' + category + '/' + fungi_train[0] + '.jpg'
# OpenCV decodes to BGR; matplotlib expects RGB.
ex = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
fig, ax = plt.subplots(1)
ax.imshow(ex)
plt.axis('off')
plt.show()

In [None]:
class iNaturalistDataset(torch.utils.data.dataset.Dataset):
    """iNaturalist images stored as ``<root>/<category>/<file stem>.jpg``.

    Args
    ----
        filenames: Image file stems (no directory, no ``.jpg`` extension).
        preprocess: Callable applied to each loaded RGB PIL image.
        labels: Integer class id of each image, aligned with ``filenames``.
            It is also used as an index into ``categories_list``.
        categories_list: Directory name of every class id. Required despite
            the ``None`` default, which is kept for signature compatibility.
        root: Dataset root directory. When ``None``, the notebook-level
            ``root_path`` global is used.

    Raises
    ------
        ValueError: if ``categories_list`` is missing, or ``filenames`` and
            ``labels`` have different lengths.
    """

    def __init__(
            self,
            filenames,
            preprocess,
            labels,
            categories_list = None,
            root = None,
        ):
        super().__init__()

        if categories_list is None:
            raise ValueError('categories_list is required to locate the image files.')
        if len(filenames) != len(labels):
            raise ValueError(
                f'filenames ({len(filenames)}) and labels ({len(labels)}) must have the same length.'
            )

        self.filenames = filenames
        self.preprocess = preprocess
        self.labels = labels
        self.categories_list = categories_list
        self.root = root

    def __len__(self):
        """Return the number of examples in the dataset.
        """
        return len(self.filenames)

    def _sample_path(self, index: int) -> tuple:
        """Return ``(image path, label)`` for sample ``index``."""
        # Resolved lazily so the notebook global is read at access time.
        root = self.root if self.root is not None else root_path
        label = self.labels[index]
        path = root + '/' + self.categories_list[label] + '/' + self.filenames[index] + '.jpg'
        return path, label

    def __getitem__(self, index: int) -> tuple:
        """Return a sample.

        Args
        ----
            index: Index of the sample.

        Output
        ------
            input_tensor : tensor
            label : int
        """
        path, label = self._sample_path(index)
        # Decode inside ``with`` so the file handle is released immediately.
        # Forcing RGB matters because grayscale / CMYK / RGBA JPEGs would
        # otherwise break the 3-channel Normalize of the preprocessing.
        with Image.open(path) as image:
            input_image = image.convert('RGB')
        input_tensor = self.preprocess(input_image)
        return input_tensor, label

## Architecture

### Training loop

In [None]:
def print_logs(dataset_type: str, logs: dict):
    """Print a one-line summary of metrics.

    Args
    ----
        dataset_type: Label printed at the start of the line ("Train", "Eval", "Test").
        logs: Mapping of metric name to numeric value; values are shown with two decimals.
    """
    header = f'{dataset_type} -\t'
    body = '\t'.join(f'{name}: {value:.2f}' for name, value in logs.items())
    # Tabs become 5-column stops so consecutive lines roughly line up.
    print((header + body).expandtabs(5))


def loss_batch(
        model,
        image,
        label,
        config,
        topk=(1, 5, 10),
    )-> dict:
    """Compute the metrics associated with this batch.

    The metrics are:
        - loss: ``config['loss']`` applied to the logits and the labels
        - top-k accuracy for every k in ``topk`` (keys ``'top-1'``,
          ``'top-5'``, ``'top-10'`` by default)

    Args
    ----
        model: Callable returning either a logits tensor or an object with a
            ``.logits`` attribute (e.g. torchvision Inception outputs).
        image: Batch of inputs; moved to ``config['device']``.
        label: Batch of integer class targets; moved to ``config['device']``.
        config: Must contain ``'device'`` and ``'loss'``.
        topk: Values of k to report. A k larger than the number of classes is
            clamped to that number, which gives an accuracy of 1.

    Output
    ------
        metrics: Dictionary mapping metric names to scalar tensors for this batch.
    """
    device = config['device']
    loss_fn = config['loss'].to(device)
    metrics = dict()

    image, label = image.to(device), label.to(device)

    # Loss
    output = model(image)
    # Some models wrap their logits (e.g. Inception returns an object with a
    # ``.logits`` field in training mode with aux logits).
    pred = output.logits if hasattr(output, 'logits') else output
    pred = pred.to(device)

    metrics['loss'] = loss_fn(pred, label)

    # Accuracy
    total = image.shape[0]
    nb_classes = pred.shape[-1]
    for k in topk:
        # topk raises if k exceeds the number of classes.
        _, pred_k = pred.topk(k=min(k, nb_classes), dim=-1)
        # (batch, k) == (batch, 1): a row has at most one hit since the top-k
        # indices are distinct.
        good = (pred_k == label.unsqueeze(-1))
        metrics[f'top-{k}'] = good.sum() / total

    return metrics


def eval_model(model, dataloader, config) -> dict:
    """Evaluate the model on the given dataloader.

    The model is switched to evaluation mode (dropout off, BatchNorm running
    statistics), and gradients are disabled. Each metric returned by
    ``loss_batch`` is averaged over the batches, weighted by batch size. This
    way a smaller final batch does not count as much as a full one.

    Output
    ------
        logs: Mapping of metric name to averaged value (empty if the
            dataloader yields no batch).
    """
    device = config['device']
    logs = defaultdict(list)
    batch_sizes = []

    model.to(device)
    model.eval()

    with torch.no_grad():
        for image, label in dataloader:
            metrics = loss_batch(model, image, label, config)
            batch_sizes.append(image.shape[0])
            for name, value in metrics.items():
                logs[name].append(value.cpu().item())

    for name, values in logs.items():
        logs[name] = np.average(values, weights=batch_sizes)
    return logs


def train_model(model: nn.Module, config: dict):
    """Train ``model`` with early stopping on the validation loss.

    Args
    ----
        model: Network to train; moved to ``config['device']``.
        config: Must contain 'train_loader', 'val_loader', 'optimizer',
            'device', 'patience', 'epochs', 'log_every', 'freeze' (names of
            top-level child modules whose parameters are frozen) and the keys
            read by ``loss_batch`` ('loss').

    Output
    ------
        The trained model. It is returned both when early stopping triggers
        and after the last epoch.
    """
    train_loader, val_loader = config['train_loader'], config['val_loader']
    optimizer = config['optimizer']
    device = config['device']

    # Early stopping: last_loss keeps the best validation loss seen so far.
    last_loss = float('inf')
    patience = config['patience']
    trigger_times = 0

    model = model.to(device)

    # Freezing of layers (matched on the names of the top-level children)
    for name, module in model.named_children():
        frozen = name in config['freeze']
        for param in module.parameters():
            param.requires_grad = not frozen

    print(f'Starting training for {config["epochs"]} epochs, using {device}.')
    # Defined up-front so reporting still works if the train loader is empty.
    train_logs = {}
    for e in range(config['epochs']):
        print(f'\nEpoch {e+1}')
        debut = time.time()
        model.train(True)
        logs = defaultdict(list)

        for batch_id, (source, label) in enumerate(train_loader):
            optimizer.zero_grad()

            metrics = loss_batch(model, source, label, config)
            metrics['loss'].backward()
            optimizer.step()

            for name, value in metrics.items():
                logs[name].append(value.cpu().item())  # '.item()' detaches from the graph and frees the CUDA memory
            if batch_id % config['log_every'] == 0:
                train_logs = {f'Train - {m}': np.mean(v) for m, v in logs.items()}
                wandb.log(train_logs)
                logs = defaultdict(list)

        # Evaluation mode for reporting (dropout off, BatchNorm running stats).
        model.train(False)

        # The printed train metrics cover the batches since the last wandb
        # upload; if that window is empty, the last uploaded window is reused.
        if logs:
            train_logs = {f'Train - {m}': np.mean(v) for m, v in logs.items()}
        print_logs('Train', {m.split(' - ', 1)[1]: v for m, v in train_logs.items()})

        logs = eval_model(model, val_loader, config)
        print_logs('Eval', logs)
        val_logs = {f'Validation - {m}': v for m, v in logs.items()}

        logs = {**train_logs, **val_logs}  # Merge dictionaries
        wandb.log(logs)  # Upload to the WandB cloud

        duree = time.time() - debut
        print(f'Durée : {duree} secondes')

        # Early stopping
        current_loss = logs['Validation - loss']

        if current_loss > last_loss:
            trigger_times += 1
            print('Trigger Times:', trigger_times)

            if trigger_times >= patience:
                print('Early stopping!')
                return model
        else:
            print('Trigger Times: 0')
            trigger_times = 0
            last_loss = current_loss

    return model


### ResNet

In [None]:
# Data augmentation
# Evaluation pipeline: deterministic resize (shorter side 256) + 224x224
# center crop, then per-channel mean/std normalisation.
resnet_transform_valid = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Training pipeline: same steps plus a random horizontal flip (p=0.5) and a
# random rotation of up to 10 degrees, applied after the crop.
resnet_transform = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
resnet = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
# Swap the classification head for a 102-way linear layer (same input width).
resnet.fc = nn.Linear(resnet.fc.in_features, 102)
resnet.eval()

In [None]:
# List the name of every parameter tensor of the network.
for param_name, _ in resnet.named_parameters():
    print(param_name)

In [None]:
# List the top-level child modules (the names that config['freeze'] can match).
for child_name, _ in resnet.named_children():
    print(child_name)

### Inception

In [None]:
# Evaluation pipeline for Inception: resize (shorter side 299) + 299x299
# center crop, then per-channel mean/std normalisation.
inception_transform_valid = transforms.Compose([
    transforms.Resize(size=299),
    transforms.CenterCrop(size=299),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Training pipeline: same steps plus a random horizontal flip (p=0.5) and a
# random rotation of up to 10 degrees.
inception_transform = transforms.Compose([
    transforms.Resize(size=299),
    transforms.CenterCrop(size=299),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
inception = torch.hub.load('pytorch/vision:v0.10.0', 'inception_v3', pretrained=False, aux_logits=False)
# Swap the classification head for a 102-way linear layer (same input width).
inception.fc = nn.Linear(inception.fc.in_features, 102)
inception.eval()

In [None]:
# List the top-level child modules (the names that config['freeze'] can match).
for child_name, _ in inception.named_children():
    print(child_name)

In [None]:
# List the name of every parameter tensor of the network.
for param_name, _ in inception.named_parameters():
    print(param_name)

### Training

In [None]:
# IPython shell escapes (not plain Python).
# Put the wandb CLI in offline mode, then authenticate with a wandb account.
!wandb offline
!wandb login

# Show the visible NVIDIA GPUs (fails on machines without an NVIDIA driver).
!nvidia-smi

In [None]:
config = {
    # General parameters
    'epochs': 20,
    'batch_size': 20,
    'lr': 1e-3,
    'betas': (0.9, 0.99),
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',

    'log_every': 5,  # Number of batches between each wandb logs
    'patience' : 5,  # Patience for early stopping mechanism
}


config['pretrained'] = True

config['data_augmentation'] = True

# Architecture to fine-tune, compared by identity with the reference models
# `resnet` / `inception` loaded earlier in the notebook.
model_type = inception

# Names of top-level child modules whose parameters train_model freezes.
config['freeze'] = []

if model_type == resnet:
    config['valid_preprocess'] = resnet_transform_valid
    config['preprocess'] = resnet_transform if config['data_augmentation'] else resnet_transform_valid
elif model_type == inception:
    config['valid_preprocess'] = inception_transform_valid
    config['preprocess'] = inception_transform if config['data_augmentation'] else inception_transform_valid
else:
    raise ValueError('model_type must be either resnet or inception')

# Dataset to train on: 'fungi' (iNaturalist Fungi) or 'flower' (Oxford Flowers).
dataset_name = 'fungi'

if dataset_name == 'flower':
    train_set = FlowerDataset(train[0], config['preprocess'], flower_labels[0])
    # Validation uses the deterministic (non-augmented) pipeline.
    val_set = FlowerDataset(valid[0], config['valid_preprocess'], flower_labels[0])
    nb_classes = 102
elif dataset_name == 'fungi':
    train_set = iNaturalistDataset(fungi_train, config['preprocess'], y_fungi_train, categories_list)
    val_set = iNaturalistDataset(fungi_valid, config['valid_preprocess'], y_fungi_valid, categories_list)
    # One class per category found when scanning the dataset directory.
    nb_classes = len(categories_list)
else:
    raise ValueError("dataset_name must be either 'flower' or 'fungi'")

config['train_loader'] = torch.utils.data.DataLoader(
    train_set, batch_size=config['batch_size'], shuffle=True,
)
# Validation metrics do not depend on sample order, so no shuffling.
config['val_loader'] = torch.utils.data.DataLoader(
    val_set, batch_size=config['batch_size'], shuffle=False,
)

config['dataset'] = dataset_name

config['nb_classes'] = nb_classes


if model_type == resnet:
    model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=config['pretrained'])
else:  # inception (validated above)
    # NOTE(review): transform_input=False while inputs get mean/std
    # normalisation; check this matches what the pretrained weights expect.
    model = torch.hub.load('pytorch/vision:v0.10.0', 'inception_v3', pretrained=config['pretrained'],
                           aux_logits=False, transform_input=False)
# New classification head sized for the chosen dataset.
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=config['nb_classes'], bias=True)

# lr / betas come from config, so the values logged to wandb are the ones used.
# Parameters frozen later by train_model stay in the optimizer but receive no
# gradient, so Adam does not update them.
config['optimizer'] = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=config['lr'],
    betas=config['betas'],
)

config['loss'] = torch.nn.CrossEntropyLoss()

In [None]:
#!wandb offline
# Open a wandb run as a context manager around training, so the run is closed
# when train_model returns or raises.
# NOTE(review): the group label is hard-coded; update it by hand when
# config['pretrained'] or the architecture changes.
with wandb.init(
        project='INF8225 - projet',  # wandb project title
        entity="katia_juliette",
        group='pretrained - inception',  # group of runs this run belongs to
        config=config,
        save_code=True,
    ):
    train_model(model, config)