In [1]:
import numpy as np  # Load required libs
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torchvision
from tqdm.notebook import tqdm
import os
import zipfile
import copy
import time
import sys
from typing import Tuple, List, Type, Dict, Any
import adabelief_pytorch
import gc  # sometimes it is required to clean GPU memory in that way

ModuleNotFoundError: ignored

In [None]:
def clear_cuda():
    """Run Python's garbage collector, then release PyTorch's cached CUDA memory."""
    # Collect first: objects that are only kept alive by unreachable reference
    # cycles are freed here, before the CUDA cache is asked to be emptied.
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
def train_data_processing_celeba(data, model, loss_fn):
    """Compute the training loss for a single ``(images, labels)`` batch.

    Args:
        data: a two-element batch that unpacks into images and labels.
        model: object exposing ``embedder(images)`` and
            ``classifier(embeddings, labels)``; note that the classifier
            receives the labels as its second argument during training.
        loss_fn: callable invoked as ``loss_fn(logits, labels)``.

    Returns:
        Whatever ``loss_fn`` returns for this batch.
    """
    images, targets = data
    return loss_fn(model.classifier(model.embedder(images), targets), targets)

def val_data_processing_celeba(data, model, loss_fn):
    """Evaluate one ``(images, labels)`` batch for validation.

    Unlike the training path, the classifier is not called directly: the
    predictions come from ``model.classifier.calculate_cosines(embeddings)``,
    which does not take the labels.

    Args:
        data: a two-element batch that unpacks into images and labels.
        model: object exposing ``embedder`` and ``classifier.calculate_cosines``.
        loss_fn: callable invoked as ``loss_fn(predictions, labels)``.

    Returns:
        A ``(correct, loss)`` tuple: ``correct`` is the number of samples whose
        arg-max along dimension 1 equals the label (a Python float obtained via
        ``.item()``), and ``loss`` is the unmodified result of ``loss_fn``.
    """
    images, targets = data
    scores = model.classifier.calculate_cosines(model.embedder(images))

    # Count the hits first, then compute the loss on the same predictions.
    hits = scores.argmax(1) == targets
    n_correct = hits.type(torch.float).sum().item()
    return n_correct, loss_fn(scores, targets)

In [None]:
def train_single_epoch(model: torch.nn.Module,
                       optimizer: torch.optim.Optimizer,
                       loss_function: torch.nn.Module,
                       datagiver,
                       train_data_processing,
                       scheduler,
                       epoch_number,
                       scheduler_step_every_epoch = False):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: network to optimise; switched to train mode here.
        optimizer: optimizer whose ``zero_grad``/``step`` are called per batch.
        loss_function: forwarded untouched to ``train_data_processing``.
        datagiver: batch source; must provide ``get_train_steps_per_epoch()``
            and ``get(block=True)``.
        train_data_processing: callable ``(batch, model, loss_function)``
            returning a loss that supports ``.item()`` and ``.backward()``.
        scheduler: learning-rate scheduler.
        epoch_number: index of the current epoch; used only to build the
            fractional epoch handed to the scheduler in per-batch mode.
        scheduler_step_every_epoch: NOTE the name is misleading. When True the
            scheduler is stepped after *every batch* with
            ``epoch_number + step / steps_per_epoch``; when False it is stepped
            once, with no argument, after the last batch.

    Returns:
        Sum of the per-batch loss values divided by the number of steps.
    """
    n_steps = datagiver.get_train_steps_per_epoch()
    step_scheduler_per_batch = scheduler_step_every_epoch

    model.train()
    running_loss = 0
    with tqdm(total=n_steps) as progress:
        for batch_idx in range(n_steps):
            batch = datagiver.get(block=True)
            batch_loss = train_data_processing(batch, model, loss_function)
            # Read the scalar once; it feeds both the running sum and the progress bar.
            batch_loss_value = batch_loss.item()
            running_loss += batch_loss_value

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            if step_scheduler_per_batch:
                # Fractional epoch: whole epochs done plus progress within this one.
                scheduler.step(epoch_number + batch_idx / n_steps)

            progress.update()
            # 'loss_value - ' is the running SUM over the epoch so far, not a mean.
            progress.set_postfix({'loss_value - ': running_loss,
                                  'loss on batch - ': batch_loss_value})

    if not step_scheduler_per_batch:
        scheduler.step()
    clear_cuda()
    return running_loss / n_steps

In [None]:
def validate_single_epoch(model: torch.nn.Module,
                          loss_function: torch.nn.Module,
                          datagiver,
                          val_data_processing):
    """Run one validation pass and report mean loss and accuracy.

    Args:
        model: network to evaluate; switched to eval mode here.
        loss_function: forwarded untouched to ``val_data_processing``.
        datagiver: batch source; must provide ``get_val_steps_per_epoch()``,
            ``get_val_batch_size()`` and ``get(block=True)``.
        val_data_processing: callable ``(batch, model, loss_function)``
            returning ``(correct_count, loss)`` for that batch. Each element
            may be a Python number or a single-element tensor.

    Returns:
        ``{'loss': float, 'accuracy': float}`` where ``loss`` is the mean of
        the per-batch losses and ``accuracy`` is the total correct count
        divided by ``steps * batch_size``.

    Raises:
        ZeroDivisionError: if the datagiver reports zero validation steps (or
            a zero batch size).
    """
    model.eval()
    steps_per_epoch = datagiver.get_val_steps_per_epoch()
    batch_size = datagiver.get_val_batch_size()

    # NOTE(review): the accuracy denominator assumes every batch holds exactly
    # `batch_size` samples; a short final batch would deflate the accuracy.
    size = steps_per_epoch * batch_size
    val_loss, correct = 0.0, 0.0

    with torch.no_grad():
        for _ in range(steps_per_epoch):
            data = datagiver.get(block=True)
            step_correct, step_val_loss = val_data_processing(data, model, loss_function)

            # float() accepts both plain numbers and single-element tensors, so
            # the accumulators (and therefore the returned metrics) are always
            # Python floats rather than tensors that print as `tensor(...)`
            # and leak into the caller's best-loss bookkeeping.
            correct += float(step_correct)
            val_loss += float(step_val_loss)

    clear_cuda()

    val_loss /= steps_per_epoch
    correct /= size
    print('accuracy - {} , loss - {}'.format(correct, val_loss))
    return {'loss': val_loss, 'accuracy' : correct}

In [None]:
def train_model(model: torch.nn.Module, 
                datagiver,
                path_to_save,
                loss_function: torch.nn.Module = torch.nn.CrossEntropyLoss(),
                optimizer_params = 'default',
                lr_scheduler_class = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
                lr_scheduler_params: Dict = {},
                lr_scheduler_step_every_epoch = False,
                max_epochs = 50,
                early_stopping_patience = 15,
                loss_progression = None,
                train_data_processing = train_data_processing_celeba,
                val_data_processing = val_data_processing_celeba,
                model_name_appendix=""):
    """Train ``model`` with AdaBelief, checkpointing the best epoch and stopping early.

    Each epoch switches the datagiver to ``'train'``, runs
    ``train_single_epoch``, switches it to ``'validate'`` and runs
    ``validate_single_epoch``. Whenever the validation loss improves, the whole
    model object is saved with ``torch.save``.

    Args:
        model: network to train.
        datagiver: batch source; must provide ``change_task(name)`` in addition
            to what the single-epoch helpers call on it.
        path_to_save: directory that receives
            ``<ModelClassName>_best<model_name_appendix>.tp``.
        loss_function: loss forwarded to the data-processing callables.
        optimizer_params: the string ``'default'`` for the built-in AdaBelief
            settings, otherwise a dict of keyword arguments for
            ``adabelief_pytorch.AdaBelief``.
        lr_scheduler_class: scheduler class, constructed as
            ``lr_scheduler_class(optimizer, **lr_scheduler_params)``.
        lr_scheduler_params: keyword arguments for the scheduler. The default
            dict is only unpacked, never mutated.
        lr_scheduler_step_every_epoch: forwarded to ``train_single_epoch`` as
            ``scheduler_step_every_epoch``.
        max_epochs: upper bound on the number of epochs.
        early_stopping_patience: training stops once more than this many epochs
            have passed since the best validation loss.
        loss_progression: optional callable invoked as
            ``loss_progression(loss_function)`` after every validation pass.
        train_data_processing: per-batch training callable.
        val_data_processing: per-batch validation callable.
        model_name_appendix: suffix inserted into the checkpoint file name.

    Returns:
        List with one training-loss value per completed epoch. The same list is
        returned whether training ran to ``max_epochs`` or stopped early.
    """
    if optimizer_params == 'default':
        optimizer_params = dict(lr=0.01,
                                betas=(0.9, 0.999),
                                eps=1e-8,
                                weight_decouple=True,
                                rectify=False,
                                weight_decay=1e-2,
                                fixed_decay=False,
                                amsgrad=False)
    optimizer = adabelief_pytorch.AdaBelief(model.parameters(), **optimizer_params)

    scheduler = lr_scheduler_class(optimizer, **lr_scheduler_params)

    # os.path.join also accepts path-like objects and avoids doubled separators.
    checkpoint_path = os.path.join(
        path_to_save,
        model.__class__.__name__ + '_best' + model_name_appendix + '.tp')

    best_val_loss = None
    best_epoch = None

    loss_history = []

    # A single progress bar for the epochs; it is advanced explicitly below.
    with tqdm(total=max_epochs) as pbar:
        for epoch in range(max_epochs):
            datagiver.change_task('train')
            train_epoch_loss = train_single_epoch(model,
                                                  optimizer,
                                                  loss_function,
                                                  datagiver,
                                                  train_data_processing,
                                                  scheduler,
                                                  epoch,
                                                  scheduler_step_every_epoch = lr_scheduler_step_every_epoch)
            loss_history.append(train_epoch_loss)

            datagiver.change_task('validate')
            val_metrics = validate_single_epoch(model, loss_function, datagiver, val_data_processing)

            pbar.update()
            pbar.set_postfix({'Epoch - ': epoch})

            if loss_progression:
                loss_progression(loss_function)

            if best_val_loss is None or best_val_loss > val_metrics['loss']:
                print('Best model yet, saving')
                best_val_loss = val_metrics['loss']
                best_epoch = epoch
                torch.save(model, checkpoint_path)

            if epoch - best_epoch > early_stopping_patience:
                print('Early stopping has been triggered')
                # Leave the loop instead of returning None so the caller still
                # receives the loss history collected so far.
                break
    return loss_history