## Initialize Libraries

In [None]:
import os
import glob

import pandas as pd
import numpy as np

import torch
from torch import nn
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.io import read_image

import urllib.request
from PIL import Image
import cv2
import matplotlib.pyplot as plt

from google.colab import output
output.enable_custom_widget_manager()
from tqdm import tqdm, trange

from sklearn.metrics import confusion_matrix
from datetime import datetime

## Mount Google Drive

In [None]:
# Mount Google Drive so the dataset and training-log folders used below are
# reachable under /content/gdrive.
from google.colab import drive
drive.mount(mountpoint='/content/gdrive')

## Initialize Paths for Each Dataset Split

In [None]:
# All three splits live side by side under one Drive folder; each split
# directory holds one sub-directory per category.
_data_root = '/content/gdrive/MyDrive/Colab Notebooks/Final_data'
path_train, path_val, path_test = (
    os.path.join(_data_root, split) for split in ('train', 'val', 'test')
)

## Dataset Pre-Processing

A custom `Dataset` must implement `__init__`, `__len__`, and `__getitem__`.

Reference: https://pytorch.org/tutorials/beginner/basics/data_tutorial.html

`__len__`: https://pytorch.org/docs/master/data.html#torch.utils.data.Dataset

In [None]:
class GreenFingerDataset(Dataset):
    """Image-classification dataset backed by a ``<data_path>/<category>/*.jpg`` tree.

    The category of each image is the name of the directory that contains it,
    and it is mapped to an integer class id through ``label_dict``.

    Args:
        data_path: Root directory of one split (train / val / test).
        label_dict: Mapping from category (directory) name to integer class id.
        oversample: If True, balance the categories by oversampling every
            category up to the size of the largest one (see ``oversample``).
        transforms: Optional callable applied to the RGB ``PIL.Image``,
            typically a torchvision ``Compose`` starting with ``ToTensor``.
    """

    def __init__(self, data_path, label_dict, oversample=False, transforms=None):
        # Every ``<category>/<file>.jpg`` one level below data_path, sorted so
        # that the sample order is deterministic across runs.
        self.files = sorted(glob.glob(os.path.join(data_path, "*", "*.jpg")))
        # One row per image: its category (parent directory name) and path.
        self.df = pd.DataFrame(
            dict(
                cat=[os.path.basename(os.path.dirname(f)) for f in self.files],
                image_path=self.files,
            )
        )
        # __len__ / __getitem__ read from df_oversampled; without oversampling
        # it is simply a copy of df.
        self.df_oversampled = self.df.copy()
        if oversample:
            self.oversample()
        self.transforms = transforms
        self.label_dict = label_dict

    def oversample(self, random_state=None):
        """Balance the categories by random oversampling.

        Every original row is kept. Each category smaller than the largest one
        is topped up with rows drawn with replacement from that same category
        until all categories reach the majority size. The result is stored in
        ``self.df_oversampled``, and ``self.df`` is left untouched. This method
        is called automatically by ``__init__`` when ``oversample=True``.

        Args:
            random_state: Optional int seed for reproducible sampling;
                ``None`` draws fresh randomness.
        """
        if self.df.empty:
            # Nothing to balance (and the max over no categories is undefined).
            self.df_oversampled = self.df.copy()
            return

        rng = np.random.RandomState(random_state)
        n_majority = int(self.df["cat"].value_counts().max())
        parts = []
        # groupby (instead of a query string) also works for category names
        # containing quotes. sort=False keeps first-appearance order.
        for _, df_cat in self.df.groupby("cat", sort=False):
            parts.append(df_cat)
            n_extra = n_majority - len(df_cat)
            if n_extra > 0:
                parts.append(df_cat.sample(n=n_extra, replace=True, random_state=rng))
        self.df_oversampled = pd.concat(parts, axis=0).reset_index(drop=True)

    def __len__(self):
        """Return the number of samples (after oversampling, if enabled)."""
        return len(self.df_oversampled)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for sample ``idx``.

        The image is a float tensor with 3 channels first. The label is a
        1-element ``torch.long`` tensor holding ``label_dict[category]``.
        """
        row = self.df_oversampled.iloc[idx]
        img_src, cat = row["image_path"], row["cat"]
        # The context manager releases the file handle. Forcing RGB matters
        # because grayscale, RGBA and palette JPEGs would otherwise become
        # 1- or 4-channel tensors, which the 3-channel Normalize and the
        # 3-channel network input reject.
        with Image.open(img_src) as img:
            x = img.convert("RGB")
        if self.transforms:
            x = self.transforms(x)
        if not torch.is_tensor(x):
            # No tensor-producing transform: convert HxWxC uint8 pixels to a
            # CxHxW tensor in [0, 1], the same layout/scale ToTensor produces.
            x = torch.from_numpy(np.asarray(x, dtype=np.float32) / 255.0)
            x = x.permute(2, 0, 1).contiguous()
        x = x.float()
        if x.dim() == 2:
            # A transform returned a single-channel (H, W) tensor: add a
            # channel axis and replicate it to 3 channels.
            x = x.unsqueeze(0)
            x = torch.cat((x, x, x), dim=0)
        y = torch.tensor([self.label_dict[cat]], dtype=torch.long)
        return x, y

## Preview Data Augmentation

In [None]:
# Same augmentation steps, in the same order, as transforms_train, minus
# Normalize. The preview therefore shows what training actually applies
# (rotation before the crop), and plt.imshow receives values in [0, 1].
my_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.RandomRotation(degrees=(-10, 10), fill=(0,)),
        transforms.RandomResizedCrop(size=256, scale=(0.2, 1.0), ratio=(1.0, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.2, hue=0.05),
    ]
)

ds_preview = GreenFingerDataset(
    path_train,
    label_dict,
    transforms=my_transforms
)
print("ds", len(ds_preview))

In [None]:
# Draw 25 independently augmented versions of the same training sample
# (index 1) in a 5x5 grid to eyeball how strong the augmentation is.
fig = plt.figure(figsize=(20, 15))
for i in range(25):
    img, _ = ds_preview[1]
    ax = fig.add_subplot(5, 5, i + 1)
    # Tensor layout is C x H x W; imshow wants H x W x C as a numpy array.
    ax.imshow(img.permute(1, 2, 0).numpy())
    ax.axis("off")

## Initialize Neural Network

In [None]:
# ImageNet-pretrained ResNet-50 whose final fully-connected layer is swapped
# for a freshly initialised one with n_class outputs.
# NOTE(review): n_class (like label_dict) is not defined in any visible cell;
# presumably n_class == len(label_dict) — confirm an earlier cell defines it.
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
_fc_in_features = model.fc.in_features
model.fc = nn.Linear(in_features=_fc_in_features, out_features=n_class)
model

## Define Transform Procedure

In [None]:
# Training-time pipeline: random rotation, random square crop resized to
# 256, random horizontal flip and mild colour jitter, followed by the
# normalization statistics tied to the pre-trained weights.
_train_steps = [
    transforms.ToTensor(),
    transforms.RandomRotation(degrees=(-10, 10), fill=(0,)),
    transforms.RandomResizedCrop(size=256, scale=(0.2, 1.0), ratio=(1.0, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.2, hue=0.05),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transforms_train = transforms.Compose(_train_steps)

In [None]:
# Deterministic validation pipeline: resize (shorter side to 256), take the
# central 256x256 square so every image has the same shape, then apply the
# normalization statistics tied to the pre-trained weights.
_val_steps = [
    transforms.ToTensor(),
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transforms_val = transforms.Compose(_val_steps)

In [None]:
# Deterministic test pipeline, identical in steps to the validation one:
# resize, centre 256x256 crop, normalize with the pre-trained statistics.
_test_steps = [
    transforms.ToTensor(),
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transforms_test = transforms.Compose(_test_steps)

## Initialize Dataset

In [None]:
# Training split, class-balanced by oversampling, with random augmentation.
ds_train = GreenFingerDataset(
    data_path=path_train,
    label_dict=label_dict,
    oversample=True,
    transforms=transforms_train,
)
print("ds_train", len(ds_train))

In [None]:
# Validation split: original class distribution, deterministic transform.
ds_val = GreenFingerDataset(
    data_path=path_val,
    label_dict=label_dict,
    transforms=transforms_val,
)
print("ds_val", len(ds_val))

In [None]:
# Test split: original class distribution, deterministic transform.
ds_test = GreenFingerDataset(
    data_path=path_test,
    label_dict=label_dict,
    transforms=transforms_test,
)
print("ds_test", len(ds_test))

## Training & Validation

In [None]:
# Hyper-parameters: number of epochs, Adam learning rate, mini-batch size.
n_epoch, lr, batch_size = 10, 2e-4, 128

# Running count of training images processed so far (across epochs).
seen = 0

# Per-user folder on Drive for checkpoints / training logs.
user = 'Tom'  # Tom, CY, Zephyr --> Change your name accordingly
_log_root = '/content/gdrive/MyDrive/Colab Notebooks/Training_Log'
log_path = os.path.join(_log_root, user)

In [None]:
# Training batches are shuffled; the incomplete last batch is dropped so
# every training step sees exactly batch_size samples.
dl_train = DataLoader(ds_train, batch_size=batch_size, shuffle=True, drop_last=True)
# Validation keeps the dataset order and every sample, so predictions line
# up position-by-position with the validation files.
dl_val = DataLoader(ds_val, batch_size=batch_size, shuffle=False, drop_last=False)

## Save & Load Model

Reference: https://pytorch.org/tutorials/beginner/saving_loading_models.html

In [None]:
# Initialize Loss Function
loss_fn = nn.CrossEntropyLoss()

# Use the GPU when available, otherwise fall back to CPU so the cell still
# runs (slowly) on a CPU-only runtime instead of raising.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Move the model before constructing the optimizer, as torch.optim recommends.
model = model.to(device)

# Initialize Optimizer
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-7)

# torch.save below does not create missing directories.
os.makedirs(log_path, exist_ok=True)

# Validate (and checkpoint) every `val_every` epochs.
val_every = 1

# Initialize Log: one dict per validated epoch
train_log = []

for i in range(n_epoch):
    # Re-enter training mode each epoch, because the validation pass below
    # switches the model to eval().
    model.train()
    epoch_loss_sum = 0.0
    epoch_correct = 0
    epoch_count = 0
    for (x, y) in tqdm(dl_train, desc='Training', position=0, leave=True):
        x = x.to(device)
        y = y.to(device)
        # Gradients accumulate across backward() calls, so they must be cleared
        # before this batch's backward(). Calling it anywhere between the
        # previous optimizer.step() and loss.backward() is equivalent.
        optimizer.zero_grad()
        # Named `logits`, not `output`: `output` is the google.colab module
        # imported in the first cell and must not be shadowed.
        logits = model(x)
        loss = loss_fn(logits, y[:, 0])
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            n_batch = x.size(0)
            n_correct = (torch.argmax(logits, dim=1) == y[:, 0]).sum().item()
            batch_acc = n_correct / n_batch
            seen = seen + n_batch
            epoch_loss_sum += loss.item() * n_batch
            epoch_correct += n_correct
            epoch_count += n_batch
            print(f' | Seen {seen} | Loss {loss.item()} | Train Acc: {batch_acc}')

    # Epoch-level training metrics; the last batch alone is a noisy estimate.
    loss_train = epoch_loss_sum / epoch_count
    acc_train = epoch_correct / epoch_count

    if (i + 1) % val_every == 0:
        model.eval()

        predictions = []
        labels = []
        val_loss_sum = 0.0

        with torch.no_grad():
            for (x, y) in tqdm(dl_val, desc='Validation', position=0, leave=True):
                x = x.to(device)
                y = y.to(device)
                logits = model(x)
                val_loss_sum += loss_fn(logits, y[:, 0]).item() * x.size(0)
                predictions.append(logits.cpu().numpy())
                # Labels are taken from this same (unshuffled) pass. This avoids
                # iterating ds_val again, which would re-read and re-transform
                # every validation image just to obtain its label.
                labels.append(y[:, 0].cpu().numpy())

        # `preds` and `y` are kept at module level: later cells use them to
        # inspect the misclassified validation images.
        preds = np.argmax(np.concatenate(predictions, axis=0), axis=1)
        y = np.concatenate(labels, axis=0)
        acc_val = float(np.mean(preds == y))
        loss_val = val_loss_sum / len(y)
        print(f' | Seen: {seen} | Train Loss: {loss_train} | Val Loss: {loss_val} | Val Acc: {acc_val}')

        # Append Training Log
        train_log.append(
            dict(
                ePoch=i + 1,
                seen=seen,
                loss_train=loss_train,
                acc_train=acc_train,
                acc_val=acc_val,
                loss_val=loss_val,
            )
        )

        torch.save(
            {
                'epoch': i + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss.item(),
            },
            # Timestamp without spaces or colons so the file name is portable.
            os.path.join(log_path, f'epoch_{i + 1}_{datetime.now():%Y%m%d_%H%M%S}.pt'),
        )

In [None]:
# One row per logged (validated) epoch, one column per train_log key.
df_train = pd.DataFrame.from_records(train_log)
df_train

In [None]:
# Training vs validation accuracy per logged epoch. The x axis is the running
# count of training images seen (the "seen" column), not an iteration count.
plt.figure()
plt.plot(df_train["seen"], df_train["acc_train"], marker="o", label="training")
plt.plot(df_train["seen"], df_train["acc_val"], marker="o", label="validation")
plt.legend()
plt.xlabel("training images seen")
plt.ylabel("accuracy")
plt.show()

## Visualize Unmatched Result for Validation Dataset

Note: this uses `preds` and `y` from the last validation pass (i.e. the final validated epoch).

In [None]:
# Display-only transform: no augmentation and no Normalize, shrunk to a
# 128x128 centre crop.
transforms_val_for_viz = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(128),
    transforms.CenterCrop(128),
])

# Same validation files as ds_val (same sorted glob, no oversampling), so its
# row positions line up with `preds` and `y`.
ds_val_for_viz = GreenFingerDataset(
    data_path=path_val,
    label_dict=label_dict,
    transforms=transforms_val_for_viz,
)
print("ds_val_for_viz", len(ds_val_for_viz))

In [None]:
# Positions of misclassified validation images (prediction != true label).
# For a 1-D mask, np.nonzero returns the same 1-tuple of index arrays as the
# single-argument np.where.
mismatch_mask = preds != y
debug_idxs = np.nonzero(mismatch_mask)
debug_idxs

In [None]:
# Rows of the visualization dataframe for the misclassified images. The
# dataframe has the default 0..N-1 index, so its labels equal positions.
_viz_frame = ds_val_for_viz.df
debug_df = _viz_frame.filter(items=debug_idxs[0], axis='index')
debug_df

In [None]:
# Predicted class ids of the misclassified images, in the same order as the
# rows of debug_df.
mismatch_positions = debug_idxs[0]
debug_preds = preds[mismatch_positions]
debug_preds

In [None]:
# Visualize unmatched items: one panel per misclassified validation image,
# titled with the predicted (P) and true (T) category.
n_cols = 5
# The grid grows with the number of mismatches (ceil division). The previous
# fixed 16x5 grid raised once there were more than 80 misclassified images.
n_rows = max(1, -(-len(debug_df) // n_cols))
# Invert label_dict (class id -> category name) once instead of doing a linear
# search per image. setdefault keeps the first name for a duplicated id, as
# list.index did.
idx_to_label = {}
for label_name, label_idx in label_dict.items():
    idx_to_label.setdefault(label_idx, label_name)

fig = plt.figure()
# Same per-row height as the original 30x150-inch figure with 16 rows.
fig.set_size_inches(30, 150 / 16 * n_rows)
for i, (img_src, cat) in enumerate(zip(debug_df["image_path"], debug_df["cat"])):
    pred = idx_to_label[debug_preds[i]]
    plt.subplot(n_rows, n_cols, i + 1)
    # Read the pixels eagerly so the file handle can be closed right away.
    with Image.open(img_src) as img:
        plt.imshow(np.asarray(img))
    plt.title(f'P: {pred} \n vs\n T: {cat}')
    plt.axis("off")