# trainer

> A module used to train, validate, and test PyTorch models

In [None]:
#| default_exp trainer

In [None]:
#| hide
from nbdev.showdoc import *
from fastcore.test import *
from fastcore.utils import *

In [None]:
#| export
import math
import copy
from tqdm import tqdm

import wandb
import torch

from birdclef.dataset import get_dataloader
from birdclef.network import get_model
from birdclef.training_utils import get_optimizer, get_loss_func, get_callback_func,get_lr_scheduler, compute_metrics, metrics_dict

In [None]:
#| export
def log_weights(model, # A pytorch model
                artifact_name, # The name of the artifact
                config # wandb config
                ):
    "Save the model weights to `{artifact_name}.pth` and log that file to wandb as a model artifact"

    # The weights file is named after the artifact and written to the working directory
    weights_path = f"{artifact_name}.pth"

    # The config is attached to the artifact as metadata
    artifact = wandb.Artifact(artifact_name, type="model", metadata=dict(config))

    # Only the state dict is serialised, not the whole module
    torch.save(model.state_dict(), weights_path)

    artifact.add_file(weights_path)
    wandb.log_artifact(artifact)

In [None]:
#| export
def train_one_epoch(model,                  # A pytorch model
                    train_dl,               # A pytorch dataloader
                    loss_func,              # A function to compute the loss
                    optimizer,              # The optimizer
                    device,                 # The device where the training is executed ('cpu'|'cuda')
                    epoch,                  # The epoch the model is training
                    example_ct,             # The number of examples the model has been trained on
                    step_ct,                # The number of backpropagation steps the model has done
                    n_steps_per_epoch,      # The number of steps for each epoch
                    callback_step,          # Steps indicating when the callback function must be called
                    callback_func,          # Callback function
                    scheduler_step,         # steps indicating when to call the learning rate scheduler
                    scheduler_metric,       # metric used to update the learning rate (read as `train/{scheduler_metric}`)
                    scheduler               # the learning rate scheduler
                    ):
    """Train a pytorch model for one epoch

    Every batch is expected to be a mapping with the keys 'input' and 'label'.
    Returns a tuple `(metrics, example_ct, step_ct)`: the metrics of the last
    training step (an empty dict when `train_dl` yields no batch) and the
    updated example and step counters.
    """

    model.train()

    # Bound before the loop so that an empty dataloader returns an empty dict
    # instead of raising UnboundLocalError at the final `return`.
    metrics = {}

    # Only ReduceLROnPlateau expects the monitored value in `step`; `isinstance`
    # (unlike an exact `type(...) is` check) also recognises its subclasses.
    needs_metric = isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)

    # The context manager closes the progress bar even if a step raises
    with tqdm(total=len(train_dl)) as progress_bar:
        for step, data in enumerate(train_dl):
            inputs, labels = data['input'], data['label']
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)

            train_loss = loss_func(outputs, labels)

            train_loss.backward()
            optimizer.step()

            example_ct += len(inputs)

            # Fractional epoch: the completed epochs plus the progress inside the current one
            epoch_number = (step + 1) / n_steps_per_epoch + epoch
            metrics = compute_metrics('train', outputs, labels, train_loss, example_ct, step_ct, epoch_number)

            # The scheduler is advanced once every `scheduler_step` training steps
            if (step + 1) % scheduler_step == 0:
                if needs_metric:
                    scheduler.step(metrics[f"train/{scheduler_metric}"])
                else:
                    scheduler.step()

            # The metrics of the last step are not logged here: they are returned
            # so that the caller can log them itself.
            if (step + 1) < n_steps_per_epoch:
                # Log train metrics to wandb
                wandb.log(metrics)

            # Run callback func (checked against the global step counter, before it is incremented)
            if callback_func is not None and step_ct % callback_step == 0:
                callback_func(data, outputs)

            step_ct += 1
            progress_bar.update(1)

    return metrics, example_ct, step_ct

In [None]:
#| export
def validate_model(model, # A pytorch model
                   valid_dl, # A pytorch dataloader
                   loss_func, # The loss function
                   device, # The device where the training is executed ('cpu'|'cuda')
                   epoch, # The epoch the model has been trained
                   example_ct, # The number of examples the model has been trained on
                   step_ct, # The number of backpropagation steps the model has done
                   dataset_type='val' # The name of the dataset used
                  ):
    """Test or validate a pytorch model

    Runs the model over every batch of `valid_dl` without tracking gradients,
    averages the loss over the examples that were actually evaluated and
    returns the metrics built by `compute_metrics` on the concatenated outputs
    and labels. Raises `ValueError` when `valid_dl` yields no batch.
    """

    model.eval()

    labels_acc = []
    outputs_acc = []
    loss = 0.0
    n_examples = 0

    # The context manager closes the progress bar even if a batch raises
    with tqdm(total=len(valid_dl)) as progress_bar, torch.inference_mode():
        for data in valid_dl:
            inputs, labels = data['input'], data['label']
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            outputs = model(inputs)

            # Weight the batch loss by the batch size so that a smaller last
            # batch does not skew the average
            # (assumes `loss_func` returns a per-batch mean -- TODO confirm).
            batch_size = labels.size(0)
            loss += loss_func(outputs, labels) * batch_size
            n_examples += batch_size

            # Add labels and outputs to acc
            labels_acc.append(labels)
            outputs_acc.append(outputs)

            progress_bar.update(1)

    if n_examples == 0:
        # Fail with an explicit message instead of an opaque error from
        # `torch.cat` on an empty list.
        raise ValueError(f"The '{dataset_type}' dataloader did not yield any example.")

    # Divide by the number of evaluated examples rather than by
    # `len(valid_dl.dataset)`: the two differ whenever the dataloader does not
    # yield the whole dataset (e.g. `drop_last=True` or a subset sampler).
    val_loss = loss / n_examples

    labels = torch.cat(labels_acc, dim=0)
    outputs = torch.cat(outputs_acc, dim=0)

    metrics = compute_metrics(dataset_type, outputs, labels, val_loss, example_ct, step_ct, epoch)

    return metrics

In [None]:
#| export
def train(conf = None # Wandb configurations containing all hyperparameters
          ):
    """Train, validate and test a model using the given configurations

    Opens a wandb run, builds the dataloaders, model, optimizer, loss,
    callback and learning rate scheduler from `wandb.config`, then trains for
    `config.epochs` epochs. After each epoch the model is validated; whenever
    `config.metric` improves on the validation set the weights are logged as
    an artifact. Finally the best model is evaluated on the test set and the
    test metrics are stored in the run summary.
    """

    # `conf` has to be passed by keyword: the first positional parameter of
    # `wandb.init` is not `config`, so `wandb.init(conf)` would not register
    # the hyperparameters.
    with wandb.init(config=conf) as run:
        config = wandb.config
        run.name = f"{config.run_name}"

        # Checking that the defined metric exist
        assert config.metric in metrics_dict, f'{config.metric} is not an existing metric, choose one from {metrics_dict.keys()}.'

        # Getting dataloaders (the test dataloader reuses the validation kwargs)
        train_dl = get_dataloader(config.train_key, config.train_kwargs)
        valid_dl = get_dataloader(config.val_key, config.val_kwargs)
        test_dl = get_dataloader(config.test_key, config.val_kwargs)

        # Getting model, optimizer and loss function
        model = get_model(config.model_key, num_classes=train_dl.dataset.num_classes)
        model.to(config.device)
        optimizer = get_optimizer(config.optimizer_key, model, config.optimizer_kwargs)
        loss_func = get_loss_func(config.loss_key)
        callback_func = get_callback_func(config.callback_key)

        # Total number of scheduler updates over the whole training. It is
        # stored under both `total_iters` and `T_max`, presumably so that the
        # selected scheduler can pick the keyword it understands -- TODO confirm
        # against `get_lr_scheduler`.
        scheduler_step = config.lr_scheduler_kwargs["scheduler_step"]
        scheduler_iters = (len(train_dl)*config.epochs)//scheduler_step
        config.lr_scheduler_kwargs["total_iters"] = scheduler_iters
        config.lr_scheduler_kwargs["T_max"] = scheduler_iters
        lr_scheduler = get_lr_scheduler(config.lr_scheduler_key, optimizer, config.lr_scheduler_kwargs)

        # `len(train_dl)` is the number of batches the training loop really
        # iterates over. Recomputing it from the dataset length and the batch
        # size gives a different value when the dataloader drops the last
        # incomplete batch or uses a custom sampler.
        n_steps_per_epoch = len(train_dl)

        # Counters
        example_ct = 0
        step_ct = 0

        # Bound before the loop so that `epochs == 0` does not end with a
        # NameError when the best model is tested.
        best_val = None
        best_model = None
        best_example, best_step, best_epoch = example_ct, step_ct, 0

        for epoch in range(config.epochs):
            print(f"Training epoch {epoch}")
            # Train
            metrics, example_ct, step_ct = train_one_epoch(model, train_dl, loss_func, optimizer, config.device, epoch, example_ct, step_ct, n_steps_per_epoch, config.callback_step, callback_func, scheduler_step, config.lr_scheduler_kwargs["scheduler_metric"], lr_scheduler)

            print("\tFinished training. Starting validation")

            # Validate
            val_metrics = validate_model(model, valid_dl, loss_func, config.device, epoch + 1, example_ct, step_ct)

            print('\tFinshed validation')

            # Log train and validation metrics to wandb
            wandb.log({**metrics, **val_metrics})

            print("\tMetrics logged to wandb")

            # If the best metric is reached, save the artifact.
            # `metrics_dict[config.metric]` is called as (new value, best value)
            # and its truthiness decides whether the new value is an improvement.
            current_val = val_metrics[f'val/{config.metric}']
            if best_val is None or metrics_dict[config.metric](current_val, best_val):
                print(f'\t{config.metric} in the validation set has improved!')
                best_val = current_val
                # NOTE(review): validation is reported with `epoch + 1` while the
                # test below uses `best_epoch = epoch` -- confirm whether this
                # off-by-one is intended.
                best_example, best_step, best_epoch = example_ct, step_ct, epoch
                # Keep a frozen copy: `model` keeps training in the next epochs
                best_model = copy.deepcopy(model)
                log_weights(model, config.run_name, config)

        if best_model is None:
            # No epoch was run, so there is no model to test
            print("\tNo model was trained, skipping test")
            return

        print("\tTesting with best model")
        # Test best model
        test_metrics = validate_model(best_model, test_dl, loss_func, config.device, best_epoch, best_example, best_step, dataset_type="test")

        # Load test metrics as summary
        for key, value in test_metrics.items():
            wandb.summary[key] = value

The configuration parameters (`config`) used in the `train` function are read from `wandb.config`; their values are the ones the user provides to `wandb.init`. The `train` function needs the following config parameters:

1. run_name: Name for the run.

2. train_key: Key to obtain the training dataset.

3. train_kwargs: Additional keyword arguments for obtaining the training dataloader.

4. val_key: Key to obtain the validation dataset.

5. test_key: Key to obtain the test dataset.

6. val_kwargs: Additional keyword arguments for obtaining the validation or test dataloader.

7. model_key: Key to obtain the model.

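8. model_kwargs: Additional keyword arguments for obtaining the model. Note that the current `train` function does not pass them to `get_model`: it only passes `num_classes`, taken from the training dataset.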

9. optimizer_key: Key to obtain the optimizer.

10. optimizer_kwargs: Additional keyword arguments for obtaining the optimizer.

11. loss_key: Key to obtain the loss function to be used.

12. device: Device for training (e.g., "cuda" for GPU, "cpu" for CPU).

13. epochs: Number of training epochs.

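14. metric: Metric to use for determining the best model (e.g., accuracy, f1-score). It must be one of the keys of `metrics_dict`.

15. callback_key: Key to obtain the callback function.

16. callback_step: Number of training steps between two calls of the callback function.

17. lr_scheduler_key: Key to obtain the learning rate scheduler.

18. lr_scheduler_kwargs: Additional keyword arguments for obtaining the learning rate scheduler. It must contain `scheduler_step` (the number of training steps between two scheduler updates) and `scheduler_metric` (the name of the training metric passed to the scheduler when it is a `ReduceLROnPlateau`).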

In [None]:
#| hide
# Export the cells marked with `#| export` to the python module
import nbdev

nbdev.nbdev_export()