In [None]:
import numpy as np  # Load required libs
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torchvision
from tqdm.notebook import tqdm
import os
import zipfile
import copy
import time
import sys
from typing import Tuple, List, Type, Dict, Any
import adabelief_pytorch
import gc  # sometimes it is required to clean GPU memory in that way

In [None]:
def clear_cuda():
    """Run Python garbage collection, then empty PyTorch's CUDA cache.

    Intended to be called between epochs to hand back GPU memory that is
    no longer in use.
    """
    # Collect first: presumably so tensors that are only kept alive by
    # reference cycles are freed before the cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
def train_data_processing_casia(data, model, loss_function):
    """Compute the training loss for one CASIA batch.

    Args:
        data: Indexable batch; items 0, 1 and 2 are read as the anchor
            input, the positive input and the target respectively.
        model: Callable returning a two-element ``(prediction, embedding)``
            result for a single input.
        loss_function: Callable invoked as ``loss_function(anchor_pred,
            anchor_embedding, positive_pred, positive_embedding, target)``.

    Returns:
        Whatever ``loss_function`` returns for this batch.
    """
    anchor_input, positive_input, labels = data[0], data[1], data[2]

    # The anchor is forwarded first, then the positive sample.
    pred_a, emb_a = model(anchor_input)
    pred_p, emb_p = model(positive_input)

    return loss_function(pred_a, emb_a, pred_p, emb_p, labels)

def train_data_processing_cifar(data, model, loss_function):
    """Compute the training loss for one CIFAR batch.

    Args:
        data: Indexable batch; item 0 is the model input, item 1 the target.
        model: Callable mapping the input to a prediction.
        loss_function: Callable invoked as ``loss_function(pred, target)``.

    Returns:
        Whatever ``loss_function`` returns for this batch.
    """
    inputs, labels = data[0], data[1]
    return loss_function(model(inputs), labels)

def val_data_processing_casia(data, model, loss_function):
    """Evaluate one CASIA validation batch.

    Args:
        data: Indexable batch; item 0 is the model input, item 1 the target.
        model: Callable returning a two-element ``(prediction, embedding)``
            result; the embedding is not used here.
        loss_function: Callable invoked as ``loss_function(pred, target)``;
            its result must support ``.item()``.

    Returns:
        ``(correct, val_loss)``: the number of rows whose ``argmax`` over
        dimension 1 equals the target, and the batch loss, both as Python
        scalars.
    """
    inputs, labels = data[0], data[1]
    predictions, _unused_embedding = model(inputs)

    hits = predictions.argmax(1) == labels
    n_correct = hits.float().sum().item()
    batch_loss = loss_function(predictions, labels).item()
    return n_correct, batch_loss

def val_data_processing_cifar(data, model, loss_function):
    """Evaluate one CIFAR validation batch.

    Args:
        data: Indexable batch; item 0 is the model input, item 1 the target.
        model: Callable mapping the input to a prediction tensor.
        loss_function: Callable invoked as ``loss_function(pred, target)``;
            its result must support ``.item()``.

    Returns:
        ``(correct, val_loss)``: the number of rows whose ``argmax`` over
        dimension 1 equals the target, and the batch loss, both as Python
        scalars.
    """
    inputs, labels = data[0], data[1]
    predictions = model(inputs)

    hits = predictions.argmax(1) == labels
    n_correct = hits.float().sum().item()
    batch_loss = loss_function(predictions, labels).item()
    return n_correct, batch_loss

In [None]:
def train_single_epoch(model: torch.nn.Module,
                       optimizer: torch.optim.Optimizer,
                       loss_function: torch.nn.Module,
                       datagiver,
                       data_processing):
    """Train ``model`` for one epoch and return the normalised training loss.

    Args:
        model: Network to optimise; switched to training mode here.
        optimizer: Optimizer stepped once per batch.
        loss_function: Passed through unchanged to ``data_processing``.
        datagiver: Batch source exposing ``get_train_steps_per_epoch()``,
            ``get_train_batch_size()`` and ``get(block=True)``.
        data_processing: Callable ``(batch, model, loss_function) -> loss``
            whose result supports ``.item()`` and ``.backward()``.

    Returns:
        The sum of the per-batch ``loss.item()`` values divided by
        ``steps_per_epoch * batch_size``.
    """
    n_steps = datagiver.get_train_steps_per_epoch()
    n_samples = n_steps * datagiver.get_train_batch_size()

    model.train()
    running_loss = 0
    with tqdm(total=n_steps) as progress:
        for batch_idx in range(n_steps):
            batch = datagiver.get(block=True)
            loss = data_processing(batch, model, loss_function)

            # Read the scalar once; it feeds both the running total and the
            # progress-bar postfix.
            batch_loss = loss.item()
            running_loss += batch_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            progress.update()
            progress.set_postfix({'Batch num - ': batch_idx, 'loss_value - ': running_loss, 'loss on batch - ': batch_loss})

    clear_cuda()

    # NOTE(review): the total is divided by steps * batch_size, whereas
    # validate_single_epoch divides its loss by steps only. If the loss is
    # already averaged per batch the two are on different scales - confirm
    # which normalisation is intended.
    return running_loss / n_samples

In [None]:
def validate_single_epoch(model: torch.nn.Module,
                          loss_function: torch.nn.Module,
                          datagiver,
                          data_processing):
    """Run one validation pass and report mean loss and accuracy.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loss_function: Passed through unchanged to ``data_processing``.
        datagiver: Batch source exposing ``get_val_steps_per_epoch()``,
            ``get_val_batch_size()`` and ``get(block=True)``.
        data_processing: Callable ``(batch, model, loss_function) ->
            (correct, loss)`` returning per-batch numbers.

    Returns:
        ``{'loss': ..., 'accuracy': ...}`` where the loss is the summed batch
        loss divided by the number of steps and the accuracy is the summed
        correct count divided by ``steps * batch_size``. The same two numbers
        are also printed.
    """
    model.eval()
    n_steps = datagiver.get_val_steps_per_epoch()
    n_samples = n_steps * datagiver.get_val_batch_size()

    total_loss = 0
    total_correct = 0

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for _ in range(n_steps):
            batch = datagiver.get(block=True)
            batch_correct, batch_loss = data_processing(batch, model, loss_function)
            total_correct += batch_correct
            total_loss += batch_loss

    clear_cuda()

    mean_loss = total_loss / n_steps
    accuracy = total_correct / n_samples
    print('correct - {} , loss - {}'.format(accuracy, mean_loss))
    return {'loss': mean_loss, 'accuracy' : accuracy}

In [None]:
def train_model(model: torch.nn.Module, 
                datagiver,
                path_to_save,
                loss_function: torch.nn.Module = torch.nn.CrossEntropyLoss(),
                optimizer_params = 'default',
                lr_scheduler_params: Dict = {},
                max_epochs = 50,
                early_stopping_patience = 15,
                save_each_epoch=False,
                loss_progression = None,
                dataset_name = 'cifar'):
    """Train ``model`` with AdaBelief, saving the best-validation checkpoint.

    Each epoch runs one training pass and one validation pass, steps both
    learning-rate schedulers, optionally calls ``loss_progression``, saves the
    model when the validation loss improves, and stops early when it has not
    improved for more than ``early_stopping_patience`` epochs.

    Args:
        model: Network to train.
        datagiver: Batch source; must expose ``change_task('train')`` /
            ``change_task('validate')`` plus the methods used by
            ``train_single_epoch`` and ``validate_single_epoch``.
        path_to_save: String prefix for the checkpoint file. The file name is
            built by plain concatenation
            (``path_to_save + <ModelClass> + '_best.tp'``), so include a
            trailing path separator if a directory is meant.
        loss_function: Training loss. Also used for validation unless
            ``dataset_name == 'casia'``, in which case validation uses a fresh
            ``torch.nn.CrossEntropyLoss()``. The default instance is created
            once, when the function is defined.
        optimizer_params: ``'default'`` for the built-in AdaBelief settings,
            otherwise a dict of keyword arguments for ``AdaBelief``.
        lr_scheduler_params: Currently unused; kept for interface
            compatibility.
        max_epochs: Maximum number of epochs to run.
        early_stopping_patience: Training stops once
            ``epoch - best_epoch`` exceeds this value.
        save_each_epoch: Currently unused; kept for interface compatibility.
        loss_progression: Optional callable invoked as
            ``loss_progression(loss_function)`` after every epoch.
        dataset_name: ``'cifar'`` or ``'casia'``; selects the batch handlers.

    Returns:
        List with one training-loss value per completed epoch, as returned by
        ``train_single_epoch``. The list is returned on early stopping too.

    Raises:
        KeyError: If ``dataset_name`` is not ``'cifar'`` or ``'casia'``.
    """
    train_processing = {'cifar':train_data_processing_cifar, 'casia':train_data_processing_casia}
    val_processing = {'cifar':val_data_processing_cifar, 'casia':val_data_processing_casia}

    # Resolve the handlers up front so an unknown dataset name fails before
    # the optimizer is built or any training work is done.
    train_step = train_processing[dataset_name]
    val_step = val_processing[dataset_name]

    if optimizer_params == 'default':
        optimizer = adabelief_pytorch.AdaBelief(model.parameters(),
                                                lr=0.01,
                                                betas=(0.9, 0.999),
                                                eps=1e-8,
                                                weight_decouple=True,
                                                rectify=False,
                                                weight_decay=1e-2,
                                                fixed_decay=False,
                                                amsgrad=False)
    else:
        optimizer = adabelief_pytorch.AdaBelief(model.parameters(), **optimizer_params)

    # Learning-rate schedule: halve the rate when validation loss plateaus,
    # combined with a steady exponential decay applied every epoch.
    lr_scheduler1 = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    lr_scheduler2 = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.965)

    if dataset_name == 'casia':
        # The CASIA training loss takes embeddings as well, so validation
        # uses a plain cross-entropy on the predictions instead.
        val_loss_fn = torch.nn.CrossEntropyLoss()
    else:
        val_loss_fn = loss_function

    best_val_loss = None
    best_epoch = None

    loss_history = []

    # A single epoch-level progress bar, advanced manually once per epoch.
    with tqdm(total=max_epochs) as pbar:
        for epoch in range(max_epochs):
            datagiver.change_task('train')
            train_epoch_loss = train_single_epoch(model, optimizer, loss_function, datagiver, train_step)
            loss_history.append(train_epoch_loss)

            datagiver.change_task('validate')
            val_metrics = validate_single_epoch(model, val_loss_fn, datagiver, val_step)
            lr_scheduler1.step(val_metrics['loss'])
            lr_scheduler2.step()
            pbar.update()
            pbar.set_postfix({'Epoch - ': epoch})

            if loss_progression:
                loss_progression(loss_function)

            if best_val_loss is None or best_val_loss > val_metrics['loss']:
                print('Best model yet, saving')
                best_val_loss = val_metrics['loss']
                best_epoch = epoch
                torch.save(model, path_to_save + model.__class__.__name__ + '_best' + '.tp')

            # best_epoch is always set by the first epoch, so this
            # subtraction is safe.
            if epoch - best_epoch > early_stopping_patience:
                print('Early stopping has been triggered')
                # Leave the loop rather than returning None, so the caller
                # still receives the loss history collected so far.
                break
    return loss_history