[![Fixel Algorithms](https://fixelalgorithms.co/images/CCExt.png)](https://fixelalgorithms.gitlab.io)
 
# Test Case - Train a Model
This notebook follows the object detection test case.  
It uses a transfer learning approach to train the model.

> Notebook by:
> - Royi Avital RoyiAvital@fixelalgorithms.com

Remarks:
- This notebook is self-contained: it downloads the files it requires.

To Do & Ideas:
 - B

## Revision History

| Version | Date       | User        |Content / Changes                                                   |
|---------|------------|-------------|--------------------------------------------------------------------|
| 0.1.000 | 17/03/2023 | Royi Avital | First version                                                      |
|         |            |             |                                                                    |

In [None]:
# General Tools
import numpy as np
import scipy as sp
import pandas as pd

# OpenCV
import cv2 as cv
import PIL 

# PyTorch
import torch
from torch.utils.data import Subset
from torchvision import transforms

# Miscellaneous
import datetime
import math
import os
from platform import python_version
import random
import shutil

import urllib
import warnings
import yaml

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Jupyter
from IPython import get_ipython
from IPython.display import Image, display
from ipywidgets import Dropdown, FloatSlider, IntSlider, Layout
from ipywidgets import interact

In [None]:
# Configuration
%matplotlib inline

warnings.filterwarnings("ignore")

seedNum = 512
np.random.seed(seedNum)
random.seed(seedNum)

# sns.set_theme() #>! Apply SeaBorn theme

runInGoogleColab = 'google.colab' in str(get_ipython())

In [None]:
# Google Colab only: Report the available GPU and import the Colab file helper.
# The `files` module is used later to download the saved model.
if runInGoogleColab:
  !nvidia-smi
  # Google Colab
  from google.colab import files

In [None]:
# Constants

# Folder names (Relative to the working folder)
PROJECT_NAME        = 'StickerDetection' #<! NOTE(review): The notebook detects ships, confirm this name is intended
DATA_FOLDER_NAME    = 'Data' #<! Passed as the `image_folder` of the data sets
MODEL_FOLDER_NAME   = 'Models' #<! Only referenced by commented out code in this notebook

TIME_STAMP_FORMAT = '%Y_%m_%d_%H_%M_%S' #<! For the strftime() formatter

# Data archive: Fetched once and unpacked into the working folder
DATA_FILE_URL   = r'https://drive.google.com/uc?export=download&confirm=9iBg&id=1lVDjxqOzR89ev0UmLGWZFbbcJ4Dpe5F6'
DATA_FILE_NAME  = r'Data.zip' #<! Local name of the downloaded archive

In [None]:
# `import urllib` alone does not guarantee the `request` submodule is loaded,
# hence it is imported explicitly before use.
import urllib.request

# Download and unpack the data archive only when it is not already on disk,
# so re-running the cell does not fetch the data again.
if not os.path.exists(DATA_FILE_NAME):
    urllib.request.urlretrieve(DATA_FILE_URL, DATA_FILE_NAME)

    shutil.unpack_archive(DATA_FILE_NAME, '.') #<! Extract into the working folder

In [None]:
# Detecto
from core import Dataset, Model
from utils import normalize_transform, read_image, reverse_normalize, xml_to_csv
from visualize import show_labeled_image

In [None]:
# Parameters

# Data
lDataClass  = ['Ship'] #<! Class names the model is built with
trainSize   = 0.95 #<! Fraction of the samples used for training (The rest is for validation)

# Detecto
preTrainedMode  = True #<! Passed as `pretrained` to the model constructor
# backEnd         = 'fasterrcnn_resnet50_fpn' #<! 42 [Sec] / 2 Epochs
backEnd         = 'fasterrcnn_mobilenet_v3_large_fpn' #<! 37 [Sec] / 2 Epochs
# backEnd         = 'fasterrcnn_mobilenet_v3_large_320_fpn'
# The following are forwarded as is to the `fit()` method of the model
numEpoch        = 2
learningRate    = 0.00075
momentumFctr    = 0.9
l2RegFctr       = 0.0005 #<! Passed as `weight_decay`
gammaFctr       = 0.85 #<! Passed as `gamma`, presumably the learning rate decay factor (TODO: confirm)
lrStepSize      = 3 #<! Passed as `lr_step_size`
verboseFlag     = True

scoreThr = 0.85 #<! Boxes with a score at or below this value are dropped on the validation display

csvFileName   = 'DataLabels.csv' #<! Labels table file, written and then read by the data sets
modelFileName = 'ShipDetector' #<! Base name of the saved model (A time stamp is appended on save)
modelFileExt  = 'pth' #<! Extension of the saved model file


In [None]:
# Auxiliary Functions

def GenTrainTesIdx(numSamples: int, trainSize: float = 0.8, seedNum: int = 123):
    """
    Splits the indices `0, 1, ..., numSamples - 1` into 2 disjoint sets: Train and test.

    Input:
      - numSamples: Total number of samples to split.
      - trainSize:  Fraction of the samples assigned to the train set.
                    The count is `int(trainSize * numSamples)`, namely truncated.
      - seedNum:    Seed of the local generator. The same seed yields the same split.
    Output:
      - vTrainIdx:  Sorted array of the train indices.
      - vTestIdx:   Sorted array of the remaining (Test) indices.
    """

    numTrainSamples = int(trainSize * numSamples)

    # A local generator keeps the split independent of the global NumPy seed
    oRng = np.random.default_rng(seedNum)

    vTrainIdx = np.sort(oRng.choice(numSamples, numTrainSamples, replace = False)) #<! Draw without repetition
    vTestIdx  = np.sort(np.setdiff1d(np.arange(numSamples), vTrainIdx)) #<! Whatever was not drawn

    return vTrainIdx, vTestIdx

def ExtractBestBoxLabel( boxLabel: list, boxCoord: torch.Tensor, boxScore: torch.Tensor ) -> tuple[list, torch.Tensor, list]:
    """
    Keeps, for each distinct label, the single box with the highest score.

    The scores are not required to be sorted. When they are sorted in descending
    order the result matches keeping the first occurrence of each label.

    Input:
      - boxLabel: List of labels, one per box.
      - boxCoord: Tensor of box coordinates, indexed per box along its first dimension.
      - boxScore: 1D tensor of scores, one per box.
    Output:
      - List of the kept labels, in order of first appearance in `boxLabel`.
      - Tensor of the coordinates of the kept boxes.
      - List of the scores of the kept boxes.
    """

    lScore = boxScore.tolist()

    # Maps label -> index of its best box. Updating an existing key keeps its
    # position in the dictionary, so the output order is by first appearance.
    dBestIdx = {}
    for ii, (label, score) in enumerate(zip(boxLabel, lScore)):
        # Strict comparison: On a tie the earliest box is kept
        if (label not in dBestIdx) or (score > lScore[dBestIdx[label]]):
            dBestIdx[label] = ii

    keepIdx = list(dBestIdx.values())

    return [boxLabel[ii] for ii in keepIdx], boxCoord[keepIdx], boxScore[keepIdx].tolist()

def ExtractBoxLabelThr( boxLabel: list, boxCoord: torch.Tensor, boxScore: torch.Tensor, scoreThr: float = 0.9 ) -> tuple[list, torch.Tensor, list]:
    """
    Keeps only the boxes with a score strictly above a threshold.

    Input:
      - boxLabel: List of labels, one per box.
      - boxCoord: Tensor of box coordinates, indexed per box along its first dimension.
      - boxScore: 1D tensor of scores, one per box.
      - scoreThr: The threshold. A box is kept when its score is greater than it.
    Output:
      - List of the kept labels.
      - Tensor of the coordinates of the kept boxes.
      - List of the scores of the kept boxes.
    """

    vKeepMask  = boxScore > scoreThr #<! Boolean mask, one entry per box
    lKeepLabel = [boxLabel[ii] for ii, keepFlag in enumerate(vKeepMask.tolist()) if keepFlag]

    return lKeepLabel, boxCoord[vKeepMask], boxScore[vKeepMask].tolist()

def CheckNanModel( modelNet: torch.nn.Module ) -> bool:
    """
    Checks whether any parameter of the model contains a NaN value.

    Input:
      - modelNet: The model to inspect. Only what `parameters()` yields is checked.
    Output:
      - `True` if at least one parameter tensor contains a NaN, `False` otherwise
        (Including for a model with no parameters).
    """

    # `any()` over a generator returns `False` for a model with no parameters
    # (Stacking an empty list raises) and stops at the first tensor with a NaN.
    return any(bool(torch.isnan(tParam).any()) for tParam in modelNet.parameters())
    

In [None]:
# Set the Image Transforms

# Training: Color jitter + random horizontal flip as augmentation, then normalization.
# NOTE(review): The flip changes the image geometry, confirm the data set class
# updates the boxes accordingly.
imgTransform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ColorJitter(brightness = 0.1, contrast = 0.1, saturation = 0.05),
    transforms.RandomHorizontalFlip(0.5),
    transforms.ToTensor(),
    normalize_transform(), #<! Project helper, presumably the normalization the backbone expects (TODO: confirm)
])

# Viewing: Only conversion to a tensor, no augmentation and no normalization
imgNoTransform = transforms.Compose([
    transforms.ToTensor(),
])

# Validation: Same as training without the augmentation
imgTransformValid = transforms.Compose([
    transforms.ToTensor(),
    normalize_transform(),
])

In [None]:
# Data Folder
# The archive is unpacked into the working folder and the data sets read their
# images from `DATA_FOLDER_NAME`, so the same constant is used here instead of a
# path which was immediately overwritten by a duplicated literal.
dataFolderPath = DATA_FOLDER_NAME

In [None]:
# Build the labels table from the annotation files in the data folder.
# `xml_to_csv()` is a project helper, presumably it parses the XML annotations
# into a data frame and writes it to `csvFileName` (TODO: confirm).
dfData = xml_to_csv(dataFolderPath, csvFileName)
# Single class problem: Force every box to the class the model is built with,
# taken from `lDataClass` so the labels and the model can not drift apart.
dfData.loc[:, 'class'] = lDataClass[0]
dfData.to_csv(csvFileName, index = False) #<! Overwrite the file with the unified labels

dfData.head(10)

In [None]:
# Define the Data Set
# Three data sets over the same labels file and images, differing only in the transform.
dsShips     = Dataset(csvFileName, image_folder = DATA_FOLDER_NAME, transform = imgTransform) #<! For training (With augmentation)
dsShipsView = Dataset(csvFileName, image_folder = DATA_FOLDER_NAME, transform = imgNoTransform) #<! For view
dsShipsVal  = Dataset(csvFileName, image_folder = DATA_FOLDER_NAME, transform = imgTransformValid) #<! For validation
numSamples = len(dsShips) #<! Total number of samples (Before the split)

print(f'The Number of Samples: {numSamples}')

In [None]:
# Train / Test
# Split the sample indices once, then apply the same split to all the data sets
vTrainIdx, vValIdx = GenTrainTesIdx(numSamples, trainSize = trainSize, seedNum = seedNum)

dsTrain   = Subset(dsShips, vTrainIdx) #<! Augmented samples for training
dsVal     = Subset(dsShipsVal, vValIdx) #<! Normalized samples for validation
dsValView = Subset(dsShipsView, vValIdx) #<! Same validation samples, for display

# NOTE: From here on `numSamples` is the number of training samples, not the total
numSamples = len(dsTrain)
print(f'The Number of Samples for Training: {numSamples}')
print(f'The Number of Samples for Validation: {len(dsVal)}')

In [None]:
# Sample Image + Target
sampleIdx = np.random.randint(0, numSamples)
# The model input is taken from the augmentation free data set (`dsShipsVal`).
# The training data set (`dsShips`) applies a random horizontal flip and color
# jitter, so a prediction on its output may not match the displayed image below.
tSampleImage, dSampleTarget = dsShipsVal[sampleIdx]
tSampleImageV, dSampleTargetV = dsShipsView[sampleIdx]
tSampleImageV = np.transpose(tSampleImageV.numpy(), (1, 2, 0)) #<! Move the first axis last for display
show_labeled_image(tSampleImageV, dSampleTargetV['boxes'], dSampleTargetV['labels'])

In [None]:
# Defining the Model
# Built with the class list, the pre trained flag and the back end name set in the parameters cell
modelDetector = Model(classes = lDataClass, pretrained = preTrainedMode, model_name = backEnd)

In [None]:
# Fit the Model
# All the hyper parameters come from the parameters cell.
# The returned history is plotted as the validation loss in the next cell.
fitHistory = modelDetector.fit(
    dsTrain,
    val_dataset   = dsVal,
    epochs        = numEpoch,
    learning_rate = learningRate,
    momentum      = momentumFctr,
    weight_decay  = l2RegFctr,
    gamma         = gammaFctr,
    lr_step_size  = lrStepSize,
    verbose       = verboseFlag,
)

In [None]:
# Plot Loss over Epochs

hF, hA = plt.subplots(figsize = (8, 6))
hA.plot(fitHistory) #<! One value per entry of the fit history
hA.set(title = 'Validation Loss', xlabel = 'Epoch Index', ylabel = 'Loss')

plt.show()

In [None]:
# Save the Model
# A model with NaN weights is not saved, the cell raises instead.
hasNan = CheckNanModel(modelDetector._model) #<! Check model weights
if hasNan:
    raise ValueError('The model weights contain NaN value(s)')

# Time stamped file name: <Base Name>_<Time Stamp>.<Extension>
# NOTE(review): `modelFileName` is updated in place, running this cell again
# appends a second time stamp to the already stamped name.
filePostfix   = datetime.datetime.now().strftime(TIME_STAMP_FORMAT)
modelFileName += f'_{filePostfix}.{modelFileExt}'
# modelDetector.save(os.path.join(MODEL_FOLDER_NAME, modelFileName))
modelDetector.save(modelFileName)
if runInGoogleColab:
    files.download(modelFileName) #<! Fetch the file from the Colab machine

In [None]:
# Sample Image Output
# Run the trained model on the sample image and print the raw detections
labels, boxes, scores = modelDetector.predict(tSampleImage)

print("Box Labels", labels)
print("Box Coordinates", boxes)
print("Box Scores", scores)



In [None]:
# Extract the Unique Labels
# Keep a single box per label out of the raw detections of the sample image
boxLabel, boxCoord, boxScore = ExtractBestBoxLabel(labels, boxes, scores)

print("Box Labels", boxLabel)
print("Box Coordinates", boxCoord)
print("Box Scores", boxScore)

# Display Result
# Drawn on the view version of the sample image (No normalization)
show_labeled_image(tSampleImageV, boxCoord, boxLabel)

In [None]:
# Test on Validation Data Set
# For each validation sample: Predict on the normalized image, keep the boxes
# above the score threshold and draw them on the matching view image.
for ii in range(len(dsVal)):
    sampleImg, targetData = dsVal[ii]
    sampleImgV, _ = dsValView[ii]
    sampleImgV = sampleImgV.numpy().transpose(1, 2, 0) #<! Move the first axis last for display
    labels, boxes, scores = modelDetector.predict(sampleImg)
    # boxLabel, boxCoord, boxScore = ExtractBestBoxLabel(labels, boxes, scores)
    boxLabel, boxCoord, boxScore = ExtractBoxLabelThr(labels, boxes, scores, scoreThr)
    show_labeled_image(sampleImgV, boxCoord, boxLabel, boxScore)