## Training

In [7]:
import torch
import torch.nn as nn
from tqdm import tqdm
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support
from datetime import datetime


def other(one_input, fix_length, device):
    """Unpack one batch dict into ``(X, mask, Y)`` tensors placed on ``device``.

    Args:
        one_input: Mapping with the keys ``'input_ids_0'`` and ``'label'``;
            ``'mask_0'`` is read only when ``fix_length`` is falsy.
        fix_length: When truthy, the mask is ignored and a one-element zero
            tensor is returned in its place.
        device: Target passed to ``Tensor.to`` for all three outputs.

    Returns:
        Tuple ``(X, mask, Y)``: the input ids, the per-row decision position
        (or the zero placeholder), and the labels.
    """
    tokens = one_input['input_ids_0']
    labels = one_input['label']

    if fix_length:
        # Placeholder only; note this is a float tensor, unlike the long
        # tensor produced by the variable-length branch.
        positions = torch.zeros(1)
    else:
        # Row-wise sum of the mask minus one.
        # NOTE(review): presumably the index of the last unmasked token,
        # assuming mask_0 is a 0/1 padding mask - confirm against the dataset.
        row_totals = one_input['mask_0'].sum(dim=1)
        positions = row_totals.long() - 1

    tokens = tokens.to(device)
    labels = labels.to(device)
    positions = positions.to(device)
    return tokens, positions, labels


def net_eval(fix_length, val_test, n, eval_loader, device, net, loss, loginf, n_classes, wandb):
    """Evaluate ``net`` on ``eval_loader`` and log the resulting metrics.

    Every batch is run through the network; loss, accuracy, true labels and
    predicted labels are accumulated. A summary line is written through
    ``loginf`` and a confusion matrix plus weighted precision/recall/F1 are
    sent to ``wandb``.

    This function neither switches the network to eval mode nor disables
    gradient tracking; the caller is expected to do both.

    Args:
        fix_length: Forwarded to ``other`` to select how the mask is built.
        val_test: Split name (e.g. ``'Val'`` or ``'Test'``) used in the log
            line and as the prefix of the wandb metric keys.
        n: Width of the underscore separator line logged after the summary.
        eval_loader: Sized iterable of batch dicts accepted by ``other``.
        device: Device the batch tensors are moved to.
        net: Callable invoked as ``net(X, mask)``; the prediction is the
            argmax of its output along dimension 1.
        loss: Callable invoked as ``loss(pred, Y)`` returning a scalar tensor.
        loginf: Callable taking a single string (logging sink).
        n_classes: Number of classes; labels are assumed to be the integer
            indices ``0 .. n_classes - 1`` and are used to name the rows and
            columns of the confusion matrix.
        wandb: The wandb module (or a compatible object exposing ``log`` and
            ``plot.confusion_matrix``).

    Returns:
        Tuple ``(eval_loss_mean, eval_acc)`` where ``eval_acc`` is a
        percentage in ``[0, 100]``.

    Raises:
        ZeroDivisionError: If ``eval_loader`` yields no samples.
    """
    eval_loss = 0
    eval_num = 0
    eval_correct = 0
    eval_start = datetime.now()

    # Labels and predictions gathered over the whole loader for the metrics.
    y_true = []
    y_pred = []

    for one_input in tqdm(eval_loader, total=len(eval_loader)):
        X, mask, Y = other(one_input, fix_length, device)
        pred = net(X, mask)
        eval_loss += loss(pred, Y).item()
        eval_num += len(Y)

        _, predicted = pred.max(1)
        eval_correct += predicted.eq(Y).sum().item()

        # Plain Python ints keep the collected values framework-agnostic for
        # both sklearn and wandb.
        y_true.extend(Y.cpu().tolist())
        y_pred.extend(predicted.cpu().tolist())

    # NOTE(review): the summed per-batch loss is divided by the sample count;
    # this is a per-sample mean only if ``loss`` sums over the batch - confirm
    # the reduction used by the caller's criterion.
    eval_loss_mean = eval_loss / eval_num
    eval_acc = eval_correct / eval_num * 100
    eval_end = datetime.now()
    eval_time = (eval_end - eval_start).total_seconds()

    loginf('{} num: {} — {} loss: {} — {} accuracy: {} — Time: {}'.format(val_test, eval_num, val_test, eval_loss_mean, val_test, eval_acc, eval_time))
    loginf('_' * n)

    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted')

    # Class names come from ``n_classes`` rather than from the labels of the
    # last batch, which may not contain every class (and previously relied on
    # an undefined ``tf`` name).
    class_names = [str(class_idx) for class_idx in range(n_classes)]

    # Log confusion matrix and metrics to wandb
    wandb.log({
        f"{val_test}_confusion_matrix": wandb.plot.confusion_matrix(
            probs=None,
            y_true=y_true,
            preds=y_pred,
            class_names=class_names,
        ),
        f"{val_test}_precision": precision,
        f"{val_test}_recall": recall,
        f"{val_test}_f1": f1,
    })

    return eval_loss_mean, eval_acc


def TrainModel(
        fix_length,
        net,
        device,
        trainloader,
        valloader,
        testloader,
        n_epochs,
        n_classes,
        optimizer,
        loss,
        loginf,
        wandb,
        file_name
):
    """Train ``net`` for ``n_epochs`` epochs with per-epoch validation.

    After each training epoch the network is evaluated on ``valloader``.
    Whenever the validation accuracy is at least as good as the best seen so
    far, the state dict is saved to ``file_name`` and the network is also
    evaluated on ``testloader``. The test accuracy of the most recent such
    checkpoint is logged once training finishes.

    Args:
        fix_length: Forwarded to ``other`` / ``net_eval`` to select how the
            mask is built.
        net: Model invoked as ``net(X, mask)``; must expose ``train()``,
            ``eval()`` and ``state_dict()``.
        device: Device the batch tensors are moved to.
        trainloader: Sized iterable of training batch dicts.
        valloader: Sized iterable of validation batch dicts.
        testloader: Sized iterable of test batch dicts.
        n_epochs: Number of passes over ``trainloader``.
        n_classes: Number of classes, forwarded to ``net_eval``.
        optimizer: Object exposing ``zero_grad()`` and ``step()``.
        loss: Callable invoked as ``loss(pred, Y)`` returning a scalar tensor
            that supports ``backward()``.
        loginf: Callable taking a single string (logging sink).
        wandb: The wandb module (or a compatible object exposing ``log``).
        file_name: Destination passed to ``torch.save`` for the checkpoint.

    Raises:
        ZeroDivisionError: If ``trainloader`` yields no samples in an epoch.
    """
    saving_best = 0
    # Bound up front so the final log line works even when no epoch runs
    # (``n_epochs == 0``) and therefore no test evaluation happened.
    test_acc = None

    for epoch in range(n_epochs):

        # train
        net.train()

        train_loss = 0
        train_num = 0
        t_start = datetime.now()

        for one_input in tqdm(trainloader, total=len(trainloader)):
            optimizer.zero_grad()
            X, mask, Y = other(one_input, fix_length, device)
            pred = net(X, mask)
            batch_loss = loss(pred, Y)
            batch_loss.backward()
            optimizer.step()
            train_loss += batch_loss.item()
            train_num += len(Y)

        # NOTE(review): summed per-batch loss divided by the sample count;
        # a per-sample mean only if ``loss`` sums over the batch - confirm.
        train_loss_mean = train_loss / train_num
        t_end = datetime.now()
        epoch_time = (t_end - t_start).total_seconds()
        loginf('Epoch: {}'.format(epoch))
        loginf('Train num: {} — Train loss: {} — Time: {}'.format(train_num, train_loss_mean, epoch_time))

        # validation and test
        with torch.no_grad():
            net.eval()
            val_loss_mean, val_acc = net_eval(fix_length, 'Val', 80, valloader, device, net, loss, loginf, n_classes, wandb)
            # ``>=`` means a tie also overwrites the checkpoint and re-runs
            # the test evaluation.
            if val_acc >= saving_best:
                saving_best = val_acc
                torch.save(net.state_dict(), file_name)
                _, test_acc = net_eval(fix_length, 'Test', 120, testloader, device, net, loss, loginf, n_classes, wandb)

        wandb.log({
            "train loss": train_loss_mean,
            "val loss": val_loss_mean,
            "val acc": val_acc,
            "epoch": epoch,
        })

    loginf('best test acc: {}'.format(test_acc))
    loginf('_' * 200)
