# ResNet-50 Transfer Learning Testing
In this notebook, we load the last trained checkpoint and compare it to a ResNet-50 trained on ImageNet, using the Flower Classification dataset (CC BY 4.0), [available here](https://data.mendeley.com/datasets/738sdjm6h9/1).

We try to use the same hyperparameters as [this notebook](https://github.com/ovh/ai-training-examples/blob/main/notebooks/computer-vision/image-classification/tensorflow/resnet50/notebook-resnet-transfer-learning-image-classification.ipynb), so that we're able to compare it to a ResNet-50 trained on ImageNet.


### Constants and imports

In [27]:
import os
import torch
import requests
import tarfile
from tqdm.notebook import tqdm
from lightning.pytorch import Trainer, seed_everything, LightningModule
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import utils
from torchvision.io import read_image
from torchvision.transforms import v2

# Local paths, relative to the notebook's working directory.
# CHECKPOINTS_DIRECTORY is not used by the cells shown here; presumably it holds the trained checkpoints.
CHECKPOINTS_DIRECTORY = 'checkpoints'
# Directory the dataset archive is downloaded to and extracted into.
FLOWER_DATASET_DIRECTORY = 'flowers'
# File name of the downloaded archive inside FLOWER_DATASET_DIRECTORY.
FLOWER_DATASET_ARCHIVE_FILENAME = 'flower_photos.tgz'

# Notebook and Keras Adam default values
# (chosen to match the reference notebook so that results are comparable).
EPOCHS = 10
BATCH_SIZE = 32
LEARNING_RATE = 0.001
BETA_1 = 0.9
BETA_2 = 0.999
EPSILON = 1e-7

# Fractions of the full dataset given to random_split for the
# train / validation / test subsets.
TRAIN_DATASET_RATIO = 0.7
VAL_DATASET_RATIO = 0.2
TEST_DATASET_RATIO = 0.1

# Source URL of the flower photos archive.
FLOWER_CLASSIFICATION_URL = 'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz'

### Download the Flower Classification dataset if it isn't already downloaded

In [7]:
# Make sure the dataset directory exists. `exist_ok=True` makes the cell
# idempotent and avoids the check-then-create race of a separate exists() test.
os.makedirs(FLOWER_DATASET_DIRECTORY, exist_ok=True)

In [12]:
# Flags consulted further down in this cell to decide whether the archive
# still has to be downloaded and/or extracted; both start out as "not done".
is_downloaded = is_unpacked = False

def download_file(url, path):
    """Stream the resource at `url` into the file at `path`.

    The body is first written to ``<path>.part`` and only moved to `path`
    once the download has finished, so an interrupted download never leaves
    a truncated file under the final name.

    Args:
        url: Address of the resource to download.
        path: Destination file path.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    partial_path = os.fspath(path) + '.part'
    # The context manager releases the connection even if writing fails.
    with requests.get(url, stream=True) as response:
        # Fail loudly instead of silently producing no file at all.
        response.raise_for_status()
        with open(partial_path, 'wb') as f:
            for chunk in response.iter_content(2048):
                f.write(chunk)
    os.replace(partial_path, path)

def extract_flower_archive(archive_path, target_path):
    """Extract the directories and ``.jpg`` images of a tar archive.

    Members that are neither directories nor files whose name ends in
    ``.jpg`` are skipped.

    Args:
        archive_path: Path of the tar archive to read.
        target_path: Directory the selected members are extracted into.
    """
    # Use the safe 'data' extraction filter where the running Python provides
    # it (guards against path traversal in a downloaded archive); older
    # versions do not accept the `filter` argument at all.
    extract_kwargs = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}

    # The context manager closes the archive handle, which was leaked before.
    with tarfile.open(archive_path, 'r') as tar:
        for member in tar.getmembers():
            if member.isdir() or member.name.endswith('.jpg'):
                tar.extract(member, target_path, **extract_kwargs)

# Inspect the dataset directory to find out which steps can be skipped.
# Assumes FLOWER_DATASET_DIRECTORY already exists.
expected_flower_classes = {'daisy', 'roses', 'dandelion', 'sunflowers', 'tulips'}
flower_photos_directory = os.path.join(FLOWER_DATASET_DIRECTORY, 'flower_photos')
flower_archive_path = os.path.join(FLOWER_DATASET_DIRECTORY, FLOWER_DATASET_ARCHIVE_FILENAME)

flower_directory_listing = os.listdir(FLOWER_DATASET_DIRECTORY)

# The archive counts as downloaded when a file with its name is present.
is_downloaded = FLOWER_DATASET_ARCHIVE_FILENAME in flower_directory_listing

# The dataset counts as unpacked when every expected class directory name
# shows up inside 'flower_photos' (the contents are not verified further).
is_unpacked = False
if os.path.isdir(flower_photos_directory):
    flower_photos_listing = os.listdir(flower_photos_directory)
    is_unpacked = expected_flower_classes.issubset(flower_photos_listing)

if not is_unpacked:
    if not is_downloaded:
        download_file(FLOWER_CLASSIFICATION_URL, flower_archive_path)
    extract_flower_archive(flower_archive_path, FLOWER_DATASET_DIRECTORY)
    

### Create the dataset and dataloader

In [23]:
flower_photos_directory = os.path.join(FLOWER_DATASET_DIRECTORY, 'flower_photos')

# One list of image paths per class directory. Directory and file listings
# are sorted so that the class index assigned below does not depend on the
# arbitrary order returned by os.listdir.
classes = []
for class_name in sorted(os.listdir(flower_photos_directory)):
    class_directory = os.path.join(flower_photos_directory, class_name)
    if os.path.isdir(class_directory):
        classes.append([
            # Store the full path so the image can be opened later on.
            os.path.join(class_directory, image_name)
            for image_name in sorted(os.listdir(class_directory))
        ])

# Flatten into one record per image: the class index and the image path.
dataset = [
    {'class': class_index, 'path': image_path}
    for class_index, image_paths in enumerate(classes)
    for image_path in image_paths
]

# NOTE(review): no seed is set in the cells shown here, so the split differs
# between runs unless the RNG is seeded elsewhere -- confirm.
train_paths, validate_paths, test_paths = random_split(dataset, [TRAIN_DATASET_RATIO, VAL_DATASET_RATIO, TEST_DATASET_RATIO])

In [29]:
class FlowerClassificationDataset(Dataset):
    """Dataset of flower images described by ``{'class', 'path'}`` records.

    Args:
        samples: Indexable collection of dicts, each with the image file
            path under ``'path'`` and its class index under ``'class'``.
        transform: Optional callable applied to the image tensor after it
            has been converted to float and scaled.
    """

    def __init__(self, samples, transform=None):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one sample and return it as ``(image, class_index)``.

        The image is read from disk, converted to float and divided by 255
        before the optional transform is applied.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = self.samples[idx]

        image = read_image(sample['path'])
        image = image.float() / 255

        # Compare against None explicitly: a transform object may define its
        # own truthiness (for example via __len__).
        if self.transform is not None:
            image = self.transform(image)

        # The records are built with the key 'class' (not 'classes').
        return image, sample['class']


In [31]:
# Every image is resized to 224x224 before it is batched.
transforms = v2.Resize(size=(224, 224))

# One dataset per split, all sharing the same transform.
train_dataset, val_dataset, test_dataset = (
    FlowerClassificationDataset(split, transforms)
    for split in (train_paths, validate_paths, test_paths)
)

# Only the training loader shuffles its samples.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=5,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    num_workers=5,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    num_workers=5,
)

### Load the model

### Change layers of the model

### Train model

### Evaluate the model
The ImageNet-trained ResNet-50 achieved a loss of 0.3605 and an accuracy of 88.71%.