# Imports

In [1]:
%matplotlib inline
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

from torch import optim
import torch.nn as nn


# ---
import sys; sys.path.append('../')

from commons.imgs_mean_std import *

from commons.dataset import *
from commons.imageutils import *
from commons.HistCollection import *

from modules.img_transforms import *
from modules.train_functions import *
from modules.Comparator import *

from modules.Dataset import *
from modules.EarlyStopper import *

from modules.Model import *

# Constants & hyperparams

In [2]:
# Random seed placeholder (not referenced by the cells in this notebook).
RANDOM_STATE = None

# Input resolution handed to the normalization transform.
IMAGE_SIZE = 220

# Training loop size.
BATCH_SIZE = 32
NUM_EPOCHS = 100

# SGD settings.
LEARNING_RATE = 1e-4
MOMENTUM = 0.9

# Patience values for the early stopper and for the ReduceLROnPlateau scheduler.
EARLY_PATIENCE = 15
SCHEDULER_PATIENCE = 9

# CUDA

In [None]:
print(torch.__version__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# Compare the device *type*: a torch.device never compares equal to a plain
# string such as 'cuda', so `device == 'cuda'` would always be False.
if device.type == 'cuda':
    print(torch.cuda.is_available())
    print(torch.cuda.get_device_name(0))
    print('Devices:', torch.cuda.device_count())

# Loading datasets

In [4]:
train_dataset = pd.read_csv(os.path.join(DATASETS_PATHS.norm_faces.info, 'train_dataset.csv'))
val_dataset = pd.read_csv(os.path.join(DATASETS_PATHS.norm_faces.info, 'val_dataset.csv'))
test_dataset = pd.read_csv(os.path.join(DATASETS_PATHS.norm_faces.info, 'test_dataset.csv'))

# Creating custom datasets

In [5]:
normalization= normalize(IMAGE_SIZE, MEAN, STD)

train_dataset = CD_TrippletsCreator(train_dataset, transform=normalization, data_augmentation_tranforms=[data_augmentation()])
val_dataset = CD_TrippletsCreator(val_dataset, transform=normalization)
test_dataset = CD_TrippletsCreator(test_dataset, transform=normalization)

# Dataloaders

In [6]:
# Only the training loader shuffles and drops the ragged last batch.
# Evaluation loaders keep every sample (drop_last=False) so val/test metrics
# cover the whole split instead of silently skipping up to BATCH_SIZE-1 items.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True, drop_last=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True, drop_last=False)

# Early stopping

In [7]:
def early_callback(**kwargs):
    """Announce that early stopping fired; any keyword arguments are ignored."""
    message = " > Early Stop <"
    print(message)

# Early stopper with patience EARLY_PATIENCE; 1e-3 is presumably the minimum
# improvement that counts (NOTE(review): confirm against EarlyStopper's signature).
early_stopper = EarlyStopper(
    EARLY_PATIENCE,
    1e-3,
    callback=early_callback,
    verbose=True,
)

# Loading model

In [None]:
model = Model()

# Spread each batch across GPUs when more than one is visible. Compare
# device.type: a torch.device never equals the plain string 'cuda'.
# NOTE: once wrapped, state_dict keys gain a 'module.' prefix, so checkpoints
# saved from this model must be loaded into an equally wrapped model (or have
# the prefix stripped).
if device.type == 'cuda' and torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

model.to(device)

# Defining tools

In [None]:
# Triplet margin loss built on the `dst` distance function.
criterion = nn.TripletMarginWithDistanceLoss(distance_function=dst).to(device)

# SGD with momentum; ReduceLROnPlateau multiplies the LR by 0.2 once the
# monitored value has not improved for SCHEDULER_PATIENCE scheduler steps.
optimizer = optim.SGD(
    model.parameters(),
    lr=LEARNING_RATE,
    momentum=MOMENTUM,
)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=.2,
    patience=SCHEDULER_PATIENCE,
    verbose=True,
)

# Training

In [None]:
def callback(**kwargs):
    """Print a one-epoch progress summary; passed to ``train_loop`` as ``callback``.

    Reads ``epoch``, ``train_loss``, ``val_loss``, ``train_h`` and ``val_h``
    from ``kwargs``. For both histories, the last entry of ``posit_dst`` and
    ``negat_dst`` is averaged (presumably the distances of the epoch that just
    finished).
    """
    train_hist, val_hist = kwargs['train_h'], kwargs['val_h']
    # Mean positive / negative distance of the most recent history entry.
    train_pos = np.mean(train_hist.posit_dst[-1])
    train_neg = np.mean(train_hist.negat_dst[-1])
    val_pos = np.mean(val_hist.posit_dst[-1])
    val_neg = np.mean(val_hist.negat_dst[-1])

    # Explicit precision (`.4f`). A spec like `:2f` would only set a minimum
    # field width of 2 and keep the default 6 decimals.
    print(f"Epoch: {kwargs['epoch']} =================================")
    print(f"\ttrain_loss: {kwargs['train_loss']:.4f}\tval_loss: {kwargs['val_loss']:.4f}")
    print(f"\tdst_mean_train (pos, neg): ({train_pos:.4f}, {train_neg:.4f}) dst_mean_val (pos, neg): ({val_pos:.4f}, {val_neg:.4f})")


# First training run. The epoch budget comes from the NUM_EPOCHS constant
# instead of a hard-coded literal.
train_h, val_h = train_loop(model, train_loader, val_loader=val_loader,
                            optimizer=optimizer, criterion=criterion, dst=dst, num_epoch=NUM_EPOCHS, device=device,
                            early_stopper=early_stopper, scheduler=scheduler, callback=callback)

# early_stopper.best_model holds the best state_dict seen; restore it.
model.load_state_dict(early_stopper.best_model)

In [11]:
# Persist the weights restored from early_stopper.best_model after the first run.
# NOTE(review): `complete=False` presumably saves only the state_dict — confirm in save_model.
save_model(model, 'model.pth', complete=False)

# Model evaluation

In [12]:
def train_val_analisys(**kwargs):
    """Draw three side-by-side diagnostic panels for a training run.

    Keyword arguments:
        train_h, val_h: history objects exposing ``loss``, ``posit_dst`` and
            ``negat_dst`` sequences and supporting ``len()``.
        best_loss: ``(epoch, loss)`` pair; the point is marked on the loss
            panel and its epoch is drawn as a dashed red line on every panel.

    Panels: train/val loss, per-entry mean of the positive/negative distances,
    and per-entry median of the same distances.
    """
    train_hist = kwargs['train_h']
    val_hist = kwargs['val_h']

    def _per_entry(stat, entries):
        # Collapse each history entry (a collection of distances) to one number.
        return [stat(values) for values in entries]

    mean_curves = [
        (_per_entry(np.mean, train_hist.posit_dst), 'Train positive'),
        (_per_entry(np.mean, train_hist.negat_dst), 'Train negative'),
        (_per_entry(np.mean, val_hist.posit_dst), 'Val positive'),
        (_per_entry(np.mean, val_hist.negat_dst), 'Val negative'),
    ]
    median_curves = [
        (_per_entry(np.median, train_hist.posit_dst), 'Train positive'),
        (_per_entry(np.median, train_hist.negat_dst), 'Train negative'),
        (_per_entry(np.median, val_hist.posit_dst), 'Val positive'),
        (_per_entry(np.median, val_hist.negat_dst), 'Val negative'),
    ]

    best = kwargs['best_loss']
    best_epoch, best_value = best[0], best[1]
    tick_positions = [1, len(val_hist)]

    _, (ax_loss, ax_mean, ax_median) = plt.subplots(1, 3, figsize=(24, 6))

    ax_loss.set_title('Train / Val loss')
    ax_loss.set_xlabel('Epochs')
    ax_loss.set_ylabel('Loss')
    ax_loss.set_xticks(tick_positions)
    ax_loss.plot(train_hist.loss, label='Train')
    ax_loss.plot(val_hist.loss, label='Val')
    ax_loss.plot(best_epoch, best_value, 'o')
    ax_loss.annotate('best', (best_epoch, best_value))
    ax_loss.axvline(best_epoch, linestyle='dashed', color='red', linewidth=1)
    ax_loss.legend()

    panels = (
        (ax_mean, mean_curves, 'Distances means', 'Epochs', 'Mean'),
        (ax_median, median_curves, 'Distances medians', 'Epoch', 'Median'),
    )
    for ax, curves, title, xlabel, ylabel in panels:
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_xticks(tick_positions)
        for values, label in curves:
            ax.plot(values, label=label)
        ax.axvline(best_epoch, linestyle='dashed', color='red', linewidth=1)
        ax.legend()


In [None]:
# Loss and distance curves of the first run, marking the early stopper's best epoch.
train_val_analisys(
    train_h=train_h,
    val_h=val_h,
    best_loss=(early_stopper.best_epoch, early_stopper.best_loss),
)

# Test

In [None]:
# Evaluate on the held-out test split and summarise the positive/negative distances.
test_h = eval(model, test_loader, device=device, criterion=criterion, dst=dst)
print('Loss', test_h.loss)

# Convert once (np.asarray accepts the same inputs np.mean/np.min did) and use
# numpy reductions throughout; builtin min/max break on array-valued entries.
posit_dst = np.asarray(test_h.posit_dst)
negat_dst = np.asarray(test_h.negat_dst)
for name, dsts in (('Posit', posit_dst), ('Negat', negat_dst)):
    print(f'{name} dst:\t| mean {dsts.mean()}\t| median {np.median(dsts)}\t| std {dsts.std()}\t| min, max {dsts.min()}, {dsts.max()}\t| diff(min,max) {dsts.max() - dsts.min()}')

# Gap between the closest negative and the farthest positive: a positive value
# means a single distance threshold separates every test pair.
print(f'min(negat_dst) - max(posit_dst) {negat_dst.min() - posit_dst.max()}')

# Second training run, starting from the current best model

In [None]:
# Give the second run its own history objects so train_h / val_h are not
# extended in place (assumes HistCollection.copy() does not share the
# underlying lists — TODO confirm).
train_h_ = train_h.copy()
val_h_ = val_h.copy()

In [None]:
early_stopper.reset()

# Continue epoch numbering right after the recorded history, so that epochs of
# this run line up with their positions in train_h_ / val_h_. Restarting at
# early_stopper.best_epoch (e.g. 14) while the copied history already holds
# entries past that epoch makes epoch numbers, best_epoch and the plotted
# curves disagree.
# NOTE(review): assumes len(history) == number of recorded epochs (the plotting
# helper assumes the same); confirm how train_loop uses initial_epoch/num_epoch.
train_h_, val_h_ = train_loop(model, train_loader, val_loader=val_loader, initial_epoch=len(train_h_),
                            optimizer=optimizer, criterion=criterion, dst=dst, num_epoch=NUM_EPOCHS, device=device,
                            early_stopper=early_stopper, scheduler=scheduler, callback=callback, train_h=train_h_, val_h=val_h_)

# Restore the best weights of this run and persist them.
model.load_state_dict(early_stopper.best_model)
save_model(model, 'model.pth', complete=False)

In [16]:
# Re-apply the early stopper's best weights and overwrite the checkpoint.
# This repeats the tail of the previous cell, so it is safe to re-run on its own.
best_state = early_stopper.best_model
model.load_state_dict(best_state)
save_model(model, 'model.pth', complete=False)

# Analysis again

In [None]:
# Curves for the combined history after the second run, marking its best epoch.
train_val_analisys(
    train_h=train_h_,
    val_h=val_h_,
    best_loss=(early_stopper.best_epoch, early_stopper.best_loss),
)

In [None]:
# Evaluate on the held-out test split and summarise the positive/negative distances.
test_h = eval(model, test_loader, device=device, criterion=criterion, dst=dst)
print('Loss', test_h.loss)

# Convert once (np.asarray accepts the same inputs np.mean/np.min did) and use
# numpy reductions throughout; builtin min/max break on array-valued entries.
posit_dst = np.asarray(test_h.posit_dst)
negat_dst = np.asarray(test_h.negat_dst)
for name, dsts in (('Posit', posit_dst), ('Negat', negat_dst)):
    print(f'{name} dst:\t| mean {dsts.mean()}\t| median {np.median(dsts)}\t| std {dsts.std()}\t| min, max {dsts.min()}, {dsts.max()}\t| diff(min,max) {dsts.max() - dsts.min()}')

# Gap between the closest negative and the farthest positive: a positive value
# means a single distance threshold separates every test pair.
print(f'min(negat_dst) - max(posit_dst) {negat_dst.min() - posit_dst.max()}')

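**Note (second training run):** the epoch passed as `initial_epoch` must agree with the history passed in (`train_h_` / `val_h_`). Restarting at `early_stopper.best_epoch` (e.g. epoch 14) while reusing the full previous history (`val_h`) leaves epoch numbers and history entries out of step, so the values do not line up.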