In [2]:
from dataclasses import dataclass, asdict
from copy import deepcopy
import numpy as np
from sklearn.model_selection import StratifiedKFold
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as D
import torch.optim as optim
import timm
from pathlib import Path

from kuma_utils.torch import TorchTrainer, TorchLogger
from kuma_utils.torch.callbacks import EarlyStopping, SaveSnapshot
from kuma_utils.torch.hooks import SimpleHook
from kuma_utils.metrics import Accuracy
from kuma_utils.torch.optimizer import SAM


@dataclass
class Config:
    """Run-level settings for the demo; converted with ``asdict`` when logged to wandb."""
    # Number of worker processes handed to each DataLoader.
    num_workers: int = 32
    # Mini-batch size handed to each DataLoader.
    batch_size: int = 64
    # Upper bound on training epochs passed to the trainer.
    num_epochs: int = 100
    # Used as the ``patience`` of the EarlyStopping callback.
    early_stopping_rounds: int = 5


In [3]:
def get_dataset():
    """Create the CIFAR-10 train and test datasets used by the demo.

    Both splits are rooted at ``input``, requested with ``download=True``, and
    share one preprocessing pipeline: tensor conversion followed by per-channel
    normalization.

    Returns:
        tuple: ``(train, test)`` ``torchvision.datasets.CIFAR10`` objects.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    def load_split(is_train):
        # Same root and transform for both splits; only the ``train`` flag differs.
        return torchvision.datasets.CIFAR10(
            root='input', train=is_train, download=True, transform=preprocessing)

    train_split = load_split(True)
    test_split = load_split(False)
    return train_split, test_split


def split_dataset(dataset, index):
    """Return a new dataset object restricted to the samples selected by ``index``.

    The input dataset is left untouched. Only the container object is copied
    (shallow copy); ``data`` and ``targets`` are then replaced on the copy by
    their indexed subsets. This avoids deep-copying the full ``data`` array
    just to discard most of it, which matters when the function is called
    twice per fold on a large image array.

    Args:
        dataset: Object exposing ``data`` (indexable by ``index``) and
            ``targets`` (anything ``np.array`` accepts).
        index: Index passed to ``data[...]`` and to the targets array. With an
            integer or boolean index array NumPy returns fresh arrays; with a
            slice the result is a view of the original ``data``.

    Returns:
        A shallow copy of ``dataset`` whose ``data`` is ``dataset.data[index]``
        and whose ``targets`` is a NumPy array of the selected targets. All
        other attributes (e.g. transforms) are shared with ``dataset``.
    """
    # Local import: the file-level import only brings in ``deepcopy``.
    from copy import copy as _shallow_copy

    subset = _shallow_copy(dataset)
    subset.data = dataset.data[index]
    subset.targets = np.array(dataset.targets)[index]
    return subset


def get_model(num_classes, model_name='tf_efficientnet_b0.ns_jft_in1k', pretrained=True):
    """Create a timm classification model with ``num_classes`` outputs.

    Args:
        num_classes: Number of output classes passed to ``timm.create_model``.
        model_name: timm model identifier. Defaults to the EfficientNet-B0
            checkpoint name this demo was written with.
        pretrained: Forwarded to ``timm.create_model``; ``True`` by default,
            matching the original behavior.

    Returns:
        The model object returned by ``timm.create_model``.
    """
    return timm.create_model(
        model_name, pretrained=pretrained, num_classes=num_classes)

In [None]:
# Demo run settings (override the Config defaults).
cfg = Config(
    num_workers=32,
    batch_size=2048,
    num_epochs=10,
    early_stopping_rounds=5,
)
export_dir = Path('results/demo')
export_dir.mkdir(parents=True, exist_ok=True)

train, test = get_dataset()
print('classes', train.classes)

# The test set is the same for every fold, so its loader is built once
# instead of being re-created inside the loop.
loader_test = D.DataLoader(
    test, batch_size=cfg.batch_size, num_workers=cfg.num_workers,
    shuffle=False, pin_memory=True)

predictions = []
splitter = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
for fold, (train_idx, valid_idx) in enumerate(
    splitter.split(train.targets, train.targets)):

    print(f'fold{fold} starting')

    valid_fold = split_dataset(train, valid_idx)
    train_fold = split_dataset(train, train_idx)

    print(f'train: {len(train_fold)} / valid: {len(valid_fold)}')

    loader_train = D.DataLoader(
        train_fold, batch_size=cfg.batch_size, num_workers=cfg.num_workers,
        shuffle=True, pin_memory=True)
    loader_valid = D.DataLoader(
        valid_fold, batch_size=cfg.batch_size, num_workers=cfg.num_workers,
        shuffle=False, pin_memory=True)

    model = get_model(num_classes=len(train.classes))
    optimizer = optim.Adam(model.parameters(), lr=2e-3)
    # optimizer = SAM(model.parameters(), optim.Adam, lr=2e-3)
    # The scheduler is fed 'valid_loss' (see scheduler_target below), which
    # should decrease, so the plateau detector must run in 'min' mode.
    # With mode='max' the learning rate would be cut while the loss is
    # still improving.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2)
    logger = TorchLogger(
        path=export_dir/f'fold{fold}.log',
        log_items='epoch train_loss train_metric valid_loss valid_metric learning_rate early_stop',
        file=True,
        use_wandb=True, wandb_params={
            'project': 'kuma_utils_demo',
            'group': 'demo_cross_validation',
            'name': f'fold{fold}',
            'config': asdict(cfg),
        },
        use_tensorboard=True,  # In addition to epoch summaries, tensorboard can record batch summaries.
        tensorboard_dir=export_dir/'tensorboard',
    )

    trn = TorchTrainer(model, serial=f'fold{fold}')
    trn.train(
        loader=loader_train,
        loader_valid=loader_valid,
        criterion=nn.CrossEntropyLoss(),
        eval_metric=Accuracy().torch,
        monitor_metrics=[
            Accuracy().torch
        ],
        optimizer=optimizer,
        scheduler=scheduler,
        scheduler_target='valid_loss', # ReduceLROnPlateau reads metric each epoch
        num_epochs=cfg.num_epochs,
        hook=SimpleHook(
            evaluate_in_batch=False, clip_grad=None, sam_optimizer=False),
        callbacks=[
            EarlyStopping(
                patience=cfg.early_stopping_rounds,
                target='valid_metric',
                maximize=True),
            SaveSnapshot() # Default snapshot path: {export_dir}/{serial}.pt
        ],
        logger=logger,
        export_dir=export_dir,
        parallel=None, # Supported parallel methods: None, 'dp', 'ddp'
        fp16=True, # Pytorch mixed precision
        deterministic=True,
        random_state=0,
        progress_bar=False, # Progress bar shows batches done
    )

    # Out-of-fold predictions for scoring, test predictions for later use.
    oof = trn.predict(loader_valid)
    predictions.append(trn.predict(loader_test))

    score = Accuracy()(valid_fold.targets, oof)
    print(f'Fold{fold} score: {score:.6f}')
    # Demo only: stop after the first fold. Remove this `break` to run all folds.
    break

![cifar_wandb](images/cifar_wandb.png)
![cifar_tensorboard](images/cifar_tensorboard.png)

## `kuma_utils.torch.callbacks.EarlyStopping`
```python
EarlyStopping(
    patience: int = 5, 
    target: str = 'valid_metric', 
    maximize: bool = False, 
    skip_epoch: int = 0 
)
```
| argument   | description                                                                                         |
|------------|-----------------------------------------------------------------------------------------------------|
| patience   | Epochs to wait before early stop                                                                    |
| target     | Variable name to watch (choose from  `['train_loss', 'train_metric', 'valid_loss', 'valid_metric']`) |
| maximize   | Whether to maximize the target                                                                      |
| skip_epoch | Epochs to skip before early stop counter starts                                                     |


## `kuma_utils.torch.TorchLogger`
```python
TorchLogger(
    path: (str, pathlib.Path),
    log_items: (list, str) = [
        'epoch', 'train_loss', 'valid_loss', 'train_metric', 'valid_metric',
        'train_monitor', 'valid_monitor', 'learning_rate', 'early_stop'
        ],
    verbose_eval: int = 1,
    stdout: bool = True, 
    file: bool = False,
    use_wandb: bool = False,
    wandb_params: dict = {} 
)
```
| argument     | description                                                                                                                                                                                                        |
|--------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| path         | Path to  log.                                                                                                                                                                                                |
| log_items    | Items to be shown in log. Must be a combination of the following items:  `['epoch',  'train_loss', 'valid_loss', 'train_metric' , 'valid_metric', 'train_monitor',  'valid_monitor', 'learning_rate', 'early_stop', 'gpu_memory']`. List or string separated by space (e.g. `'epoch valid_loss learning_rate'`).| 
| verbose_eval | Frequency of log (unit: epoch).                                                                                                                                                                              |
| stdout       | Whether to print log.                                                                                                                                                                            |
| file         | Whether to export log file to the path. (False by default)                                                                                                                                                                          |
| use_wandb         | Whether to use wandb.                                                                                                                                                    |

## Hook
A hook specifies the details of the training and evaluation process.
Usually it is not necessary to modify the training hook, but a custom hook is needed in cases such as:

- training a Graph Neural Network which takes multiple arguments in `.forward`
- training with a special metric which requires extra variables (other than predictions and targets)
- calculating metrics on the whole dataset (not in each mini-batch)

A Hook class contains the following functions:
```python
class TrainHook(HookTemplate):
    """Hook describing how a trainer runs forward passes, back-propagation and metric evaluation.

    Every method receives the ``trainer`` so it can reach ``trainer.model``,
    ``trainer.criterion``, ``trainer.optimizer``, ``trainer.scaler``,
    ``trainer.eval_metric``, ``trainer.monitor_metrics`` and
    ``trainer.epoch_storage``.
    """

    def __init__(self, evaluate_in_batch=False):
        """
        Args:
            evaluate_in_batch: If True, metrics are computed on every
                mini-batch and averaged at the end of the epoch; otherwise
                predictions and targets are stored and scored once per epoch.
        """
        super().__init__()
        self.evaluate_in_batch = evaluate_in_batch

    def _evaluate(self, trainer, approx, target):
        """Compute ``eval_metric`` and every monitor metric for ``approx`` vs ``target``.

        Returns:
            ``(metric_score, monitor_score)``: the eval-metric value (``None``
            when no eval metric is set) and a list with one value per monitor
            metric. Tensor results are converted with ``.item()``.
        """
        if trainer.eval_metric is None:
            metric_score = None
        else:
            metric_score = trainer.eval_metric(approx, target)
            if isinstance(metric_score, torch.Tensor):
                metric_score = metric_score.item()
        monitor_score = []
        for monitor_metric in trainer.monitor_metrics:
            score = monitor_metric(approx, target)
            if isinstance(score, torch.Tensor):
                score = score.item()
            monitor_score.append(score)
        return metric_score, monitor_score

    def forward_train(self, trainer, inputs):
        """Run the model on one batch and return ``(loss, detached predictions)``.

        The last element of ``inputs`` is treated as the target; all preceding
        elements are unpacked as positional arguments to the model.
        """
        target = inputs[-1]
        approx = trainer.model(*inputs[:-1])
        loss = trainer.criterion(approx, target)
        return loss, approx.detach()

    # Validation uses exactly the same forward logic as training.
    forward_valid = forward_train

    def forward_test(self, trainer, inputs):
        """Run the model on one inference batch and return its predictions.

        The last element of ``inputs`` is dropped before calling the model.
        """
        approx = trainer.model(*inputs[:-1])
        return approx

    def backprop(self, trainer, loss, inputs=None):
        """Back-propagate ``loss`` through the gradient scaler, clip, step, and reset gradients."""
        trainer.scaler.scale(loss).backward()
        # NOTE(review): ``max_grad_norm`` and ``clip_grad`` are not set in the
        # ``__init__`` shown here; presumably they are provided elsewhere
        # (e.g. by the base class). Confirm against the library source.
        dispatch_clip_grad(trainer.model.parameters(), self.max_grad_norm, mode=self.clip_grad)
        trainer.scaler.step(trainer.optimizer)
        trainer.scaler.update()
        trainer.optimizer.zero_grad()

    def evaluate_batch(self, trainer, inputs, approx):
        """Record per-batch information in ``trainer.epoch_storage``; returns nothing."""
        target = inputs[-1]
        storage = trainer.epoch_storage
        if self.evaluate_in_batch:
            # Add scores to storage
            metric_score, monitor_score = self._evaluate(trainer, approx, target)
            storage['batch_metric'].append(metric_score)
            storage['batch_monitor'].append(monitor_score)
        else:
            # Add prediction and target to storage
            storage['approx'].append(approx)
            storage['target'].append(target)

    def evaluate_epoch(self, trainer):
        """Produce the epoch-level ``(eval metric, monitor metrics)`` from ``trainer.epoch_storage``."""
        storage = trainer.epoch_storage
        if self.evaluate_in_batch:
            # Calculate mean metrics from all batches
            # NOTE(review): ``.mean(0)`` implies the stored values are
            # array/tensor-like by this point, although they are appended to
            # like lists in ``evaluate_batch``; presumably the storage object
            # converts them. Confirm against the library source.
            metric_total = storage['batch_metric'].mean(0)
            monitor_total = storage['batch_monitor'].mean(0).tolist()

        else: 
            # Calculate scores
            metric_total, monitor_total = self._evaluate(
                trainer, storage['approx'], storage['target'])
        return metric_total, monitor_total
```

`.forward_train()` is called in each mini-batch in training and validation loop. 
This method returns loss and prediction tensors.

`.forward_test()` is called in each mini-batch in inference loop. 
This method returns prediction values tensor.

`.evaluate_batch()` is called in each mini-batch after back-propagation and optimizer.step(). 
This method returns nothing.

`.evaluate_epoch()` is called at the end of each training and validation loop. 
This method returns eval_metric (scalar) and monitor metrics (list).

Note that `trainer.epoch_storage` is a dictionary object you can use. 
In `SimpleHook`, predictions and targets are added to storage in each mini-batch, 
and at the end of the loop, metrics are calculated on the whole dataset 
(tensors are concatenated batch-wise automatically).