# Trainer

> Main logic of the `Trainer`: training, validation, fitting and prediction.

In [None]:
#| default_exp trainer

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore.all import *
from fastcore.utils import *

In [None]:
#| export
import pandas as pd
import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, cohen_kappa_score
import wandb


In [None]:
#| export
class Trainer:
    """Hold everything needed to train, validate and test one model.

    Args:
        cfg: experiment configuration. In this module ``cfg.num_epochs``,
            ``cfg.data.label_names`` and ``cfg.model.backbone`` are read.
        model: the network being trained/evaluated.
        loaders: mapping with the keys ``'train'``, ``'val'`` and ``'test'``;
            a missing key raises ``KeyError``.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepping the model's parameters.
        device: target passed to ``.to(...)`` for batches (and the model in
            ``predict``).
        writer: logging sink; this module only calls ``writer.write(dict)``.
    """
    def __init__(self, cfg, model, loaders, criterion, optimizer, device, writer):
        self.cfg, self.model = cfg, model
        # Unpack the per-split loaders into dedicated attributes.
        self.train_loader, self.val_loader, self.test_loader = (
            loaders['train'], loaders['val'], loaders['test']
        )
        self.criterion, self.optimizer = criterion, optimizer
        self.device, self.writer = device, writer

In [None]:
#| export
@patch
def train(self:Trainer):
    """Run one optimisation pass over ``self.train_loader``.

    Puts the model in training mode, then for every batch: zero the grads,
    forward, compute ``self.criterion``, backpropagate and step the optimizer.

    Returns:
        float: the summed ``loss.item() * batch_size`` over all batches divided
        by ``len(self.train_loader.dataset)`` — i.e. a per-sample average,
        assuming the criterion averages over the batch (TODO confirm reduction).
    """
    self.model.train()
    total = 0.0
    device = self.device
    for batch_x, batch_y in self.train_loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        self.optimizer.zero_grad()
        batch_loss = self.criterion(self.model(batch_x), batch_y)
        batch_loss.backward()
        self.optimizer.step()

        # Re-weight by batch size so a smaller last batch counts proportionally.
        total += batch_loss.item() * batch_x.size(0)
    return total / len(self.train_loader.dataset)


In [None]:
#| export
@patch
def eval_(self: Trainer):
    """Compute the validation loss without updating the model.

    The model is switched to eval mode and gradients are disabled while
    iterating ``self.val_loader``.

    Returns:
        float: sum of ``loss.item() * batch_size`` over the loader divided by
        ``len(self.val_loader.dataset)`` (a per-sample average if the
        criterion averages over the batch — TODO confirm reduction).
    """
    self.model.eval()
    total = 0.0
    with torch.no_grad():
        for batch_x, batch_y in self.val_loader:
            batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
            batch_loss = self.criterion(self.model(batch_x), batch_y)
            total += batch_loss.item() * batch_x.size(0)
    return total / len(self.val_loader.dataset)


In [None]:
#| export
@patch
def predict(self: Trainer, threshold=0.5):
    """Evaluate the model on ``self.test_loader`` and report per-label metrics.

    Model outputs go through ``torch.sigmoid`` and are thresholded at
    ``threshold`` to give one 0/1 prediction per column. Column ``i`` is
    scored under the name ``self.cfg.data.label_names[i]`` (assumes the names
    follow the model-output/label column order — TODO confirm). For each
    label, accuracy, macro precision/recall/F1 and Cohen's kappa are printed
    and collected. An ``'avg'`` row aggregates over all labels.

    Args:
        threshold: probability strictly above which a label is predicted
            positive. Defaults to 0.5, the previously hard-coded value.

    Returns:
        pandas.DataFrame with one row per label plus ``'avg'`` (the row name
        is in the ``index`` column) and columns ``accuracy``, ``precision``,
        ``recall``, ``f1_score``, ``cohen_kappa``. Every value is a string
        formatted to 4 decimals. The same table is logged via
        ``self.writer.write`` under ``'Test Metrics'``.

    Raises:
        ValueError: if ``self.test_loader`` yields no batches.
    """
    self.model.to(self.device)
    self.model.eval()
    true_batches, pred_batches = [], []

    with torch.no_grad():
        for imgs, labels in self.test_loader:
            outputs = self.model(imgs.to(self.device))
            probs = torch.sigmoid(outputs).cpu()
            pred_batches.append((probs > threshold).long())
            true_batches.append(labels.cpu())

    if not true_batches:
        raise ValueError("test_loader yielded no batches; nothing to evaluate")

    # Concatenate once; building a tensor from a list of numpy arrays is slow.
    y_true = torch.cat(true_batches).numpy()
    y_pred = torch.cat(pred_batches).numpy()

    res = {}
    kappas = []
    for i, disease in enumerate(self.cfg.data.label_names):  # compute metrics for every disease
        y_t = y_true[:, i]
        y_p = y_pred[:, i]

        acc = accuracy_score(y_t, y_p)
        precision = precision_score(y_t, y_p, average="macro", zero_division=0)
        recall = recall_score(y_t, y_p, average="macro", zero_division=0)
        f1 = f1_score(y_t, y_p, average="macro", zero_division=0)
        kappa = cohen_kappa_score(y_t, y_p)
        kappas.append(kappa)

        print(f"{disease} Results [{self.cfg.model.backbone}]")
        print(f"Accuracy : {acc:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall   : {recall:.4f}")
        print(f"F1-score : {f1:.4f}")
        print(f"Kappa    : {kappa:.4f}")

        res[disease] = {
            "accuracy": f"{acc:.4f}",
            "precision": f"{precision:.4f}",
            "recall": f"{recall:.4f}",
            "f1_score": f"{f1:.4f}",
            "cohen_kappa": f"{kappa:.4f}"
        }

    # NOTE: on a 2-D label matrix accuracy_score is exact-match (subset)
    # accuracy, not the mean of the per-label accuracies above.
    avg_acc = accuracy_score(y_true, y_pred)
    avg_precision = precision_score(y_true, y_pred, average="macro", zero_division=0)
    avg_recall = recall_score(y_true, y_pred, average="macro", zero_division=0)
    avg_f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
    # cohen_kappa_score only accepts 1-D binary/multiclass targets and raises
    # ValueError on a multilabel indicator matrix, so average per-label kappas.
    avg_kappa = sum(kappas) / len(kappas) if kappas else float("nan")
    # Same formatting and key order as the per-label rows so every column of
    # the resulting table holds a single type.
    res['avg'] = {
        "accuracy": f"{avg_acc:.4f}",
        "precision": f"{avg_precision:.4f}",
        "recall": f"{avg_recall:.4f}",
        "f1_score": f"{avg_f1:.4f}",
        "cohen_kappa": f"{avg_kappa:.4f}"
    }

    metrics_df = pd.DataFrame(res).T
    df_result = metrics_df.reset_index()
    self.writer.write({'Test Metrics': wandb.Table(dataframe=df_result)})
    return df_result

In [None]:
#| export
@patch
def fit(self: Trainer):
    """Alternate training and validation for ``self.cfg.num_epochs`` epochs.

    After each epoch, ``{'train_loss': ..., 'val_loss': ...}`` is passed to
    ``self.writer.write``.

    Returns:
        tuple: ``(model, optimizer, train_losses, val_losses)``. The two lists
        hold one loss per epoch, in epoch order.
    """
    history = {'train': [], 'val': []}
    for _epoch in range(self.cfg.num_epochs):
        # Kept for parity with the original; ``train`` also sets this mode.
        self.model.train()
        epoch_train = self.train()
        epoch_val = self.eval_()

        history['train'].append(epoch_train)
        history['val'].append(epoch_val)

        self.writer.write({'train_loss': epoch_train, 'val_loss': epoch_val})

    return self.model, self.optimizer, history['train'], history['val']

In [None]:
#| hide
# Export the `#| export` cells of this notebook to the Python package.
import nbdev
nbdev.nbdev_export()