# Catalyst - customizing what happens in `train()`
Based on the Keras guide "Customizing what happens in `fit()`".

## Introduction

When you're doing supervised learning, you can use `train()` and everything works smoothly.

A core principle of Catalyst is **progressive disclosure of complexity**. You should always be able to get into lower-level workflows in a gradual way. You shouldn't fall off a cliff if the high-level functionality doesn't exactly match your use case. You should be able to gain more control over the small details while retaining a commensurate amount of high-level convenience.

When you need to customize what `train()` does, you should **override the `handle_batch` function of the `Runner` class**. This is the function that is called by `train()` for every batch of data. You will then be able to call `train()` as usual -- and it will be running your own learning algorithm.
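To make that contract concrete, here is a toy, pure-Python sketch of the loop that `train()` runs around `handle_batch`. `MiniRunner` and `RecordingRunner` are hypothetical names, and the hook order shown (loader start, one `handle_batch` call per batch, loader end) is a simplification for illustration, not Catalyst's source code:

```python
# Toy, hypothetical sketch of the loop `train()` runs around `handle_batch`;
# it mimics the hook order only and is not Catalyst's source code.
from typing import Any, Dict, Iterable, List


class MiniRunner:
    """Calls on_loader_start -> handle_batch (once per batch) -> on_loader_end."""

    def on_loader_start(self) -> None:
        pass

    def handle_batch(self, batch: Any) -> None:
        raise NotImplementedError  # subclasses implement their step here

    def on_loader_end(self) -> None:
        pass

    def train(self, loaders: Dict[str, Iterable[Any]], num_epochs: int) -> None:
        for epoch in range(1, num_epochs + 1):
            self.epoch = epoch
            for loader_key, loader in loaders.items():
                self.loader_key = loader_key
                self.on_loader_start()
                for batch in loader:
                    self.handle_batch(batch)  # your custom logic runs here
                self.on_loader_end()


class RecordingRunner(MiniRunner):
    """Overrides only the hooks; `train()` itself is inherited unchanged."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def on_loader_start(self) -> None:
        self.calls.append(f"start:{self.loader_key}")

    def handle_batch(self, batch: Any) -> None:
        self.calls.append(f"batch:{self.loader_key}:{batch}")

    def on_loader_end(self) -> None:
        self.calls.append(f"end:{self.loader_key}")


runner = RecordingRunner()
runner.train(loaders={"train": [1, 2], "valid": [3]}, num_epochs=2)
print(runner.calls[:4])
```

The point of the design is that the subclass never touches the loop itself: it only fills in the per-batch step, and the outer loop stays reusable.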

Note that this pattern does not restrict how you build your model: it works with **any** PyTorch model (`torch.nn.Module`).

Let's see how that works.

## Setup

In [1]:
!pip install catalyst[ml]==21.4.2
# restart the Colab runtime after installation so that `PIL` works correctly

In [2]:
import catalyst
from catalyst import dl, metrics, utils
catalyst.__version__

'21.04.2'

## A first simple example

Let's start from a simple example:

- We create a new runner that subclasses `dl.Runner`.
- We just override the `handle_batch(self, batch)` method with custom train-step logic.
- And update the `on_loader_start`/`on_loader_end` handlers for correct aggregation of custom metrics.

The input argument `batch` is what gets passed to `train()` as training data. If you pass a `torch.utils.data.DataLoader` by calling `train(loaders={"train": loader, "valid": loader}, ...)`, then `batch` is whatever `loader` yields at each iteration.
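For the data used in this tutorial (10,000 samples, `batch_size=32`), a pure-Python stand-in for the `DataLoader` shows what `handle_batch` receives and how often it is called. `iterate_batches` is a hypothetical helper that assumes the `DataLoader` defaults `shuffle=False` and `drop_last=False`:

```python
# Pure-Python stand-in for `DataLoader(dataset, batch_size=32)` with the defaults
# shuffle=False and drop_last=False; `iterate_batches` is a hypothetical helper.
import math
from typing import Iterator, List, Sequence, Tuple


def iterate_batches(
    xs: Sequence[float], ys: Sequence[float], batch_size: int
) -> Iterator[Tuple[List[float], List[float]]]:
    """Yield consecutive (x_batch, y_batch) pairs; the last batch may be smaller."""
    for start in range(0, len(xs), batch_size):
        yield list(xs[start:start + batch_size]), list(ys[start:start + batch_size])


num_samples, batch_size = 10_000, 32
xs = [float(i) for i in range(num_samples)]
ys = [2.0 * x for x in xs]

batches = list(iterate_batches(xs, ys, batch_size))
# one `handle_batch` call per batch, i.e. ceil(num_samples / batch_size) calls per loader
print(len(batches), math.ceil(num_samples / batch_size))
print(len(batches[-1][0]))  # size of the final, partial batch
```

So each pass over a loader here means 313 `handle_batch` calls, and the last one sees only 16 samples.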

In the body of the `handle_batch` method, we implement a regular training update, similar to what you are already familiar with. Importantly, **we log batch-based metrics via `self.batch_metrics`**, which passes them to the loggers.

Additionally, we use [`AdditiveValueMetric`](https://catalyst-team.github.io/catalyst/api/metrics.html#additivevaluemetric) in `on_loader_start` and `on_loader_end` to aggregate metrics correctly over the whole loader. Importantly, **we log loader-based metrics via `self.loader_metrics`**, which passes them to the loggers.
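Why aggregate with a dedicated metric instead of averaging batch values? With 10,000 samples and `batch_size=32`, the last batch holds only 16 samples, so a plain mean of batch means would over-weight it. A minimal sketch of sample-weighted aggregation, in the spirit of the `update(value, self.batch_size)` calls in the code below; `RunningMean` is a hypothetical name, not Catalyst's class:

```python
# Sample-weighted aggregation of batch-level metrics; `RunningMean` is a
# hypothetical illustration of the idea, not Catalyst's `AdditiveValueMetric`.
from typing import List


class RunningMean:
    """Accumulate batch means, weighting each by its number of samples."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, value: float, num_samples: int) -> None:
        self.total += value * num_samples
        self.count += num_samples

    def compute(self) -> float:
        return self.total / self.count


# Two full batches of 32 samples and a final batch of 16 samples
batch_means: List[float] = [0.2, 0.4, 1.0]
batch_sizes: List[int] = [32, 32, 16]

meter = RunningMean()
for value, size in zip(batch_means, batch_sizes):
    meter.update(value, size)

naive = sum(batch_means) / len(batch_means)  # over-weights the small last batch
print(meter.compute(), naive)
```

The weighted result equals the mean over all 80 individual samples, which is what a per-loader metric should report.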

In [3]:
import torch
from torch.nn import functional as F

class CustomRunner(dl.Runner):
    
    def on_loader_start(self, runner):
        super().on_loader_start(runner)
        self.meters = {
            key: metrics.AdditiveValueMetric(compute_on_call=False)
            for key in ["loss", "mae"]
        }

    def handle_batch(self, batch):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `train()`.
        x, y = batch

        y_pred = self.model(x) # Forward pass

        # Compute the loss value
        loss = F.mse_loss(y_pred, y)

        # Update metrics (includes the metric that tracks the loss)
        self.batch_metrics.update({"loss": loss, "mae": F.l1_loss(y_pred, y)})
        for key in ["loss", "mae"]:
            self.meters[key].update(self.batch_metrics[key].item(), self.batch_size)

        if self.is_train_loader:
            # Compute gradients
            loss.backward()
            # Update weights
            # (the optimizer passed to `train()` is available as `self.optimizer`)
            self.optimizer.step()
            self.optimizer.zero_grad()
    
    def on_loader_end(self, runner):
        for key in ["loss", "mae"]:
            self.loader_metrics[key] = self.meters[key].compute()[0]
        super().on_loader_end(runner)

Let's try this out:

In [4]:
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# Construct custom data
num_samples, num_features = int(1e4), int(1e1)
X, y = torch.rand(num_samples, num_features), torch.rand(num_samples, 1)
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loaders = {"train": loader, "valid": loader}

# and model
model = torch.nn.Linear(num_features, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

# and use `train`
runner = CustomRunner()
runner.train(
  model=model, 
  optimizer=optimizer, 
  loaders=loaders, 
  num_epochs=3,
  verbose=True,  # print the training progress to the console (pass False to silence it)
  timeit=False, # you can pass True to measure execution time of different parts of train process
)

train (1/3) loss: 0.10750187281370165 | mae: 0.2735895418167114

valid (1/3) loss: 0.09517744991779326 | mae: 0.26153938560485857
* Epoch (1/3)

train (2/3) loss: 0.09045877856016156 | mae: 0.25704234609603893

valid (2/3) loss: 0.08689186211824419 | mae: 0.25343582811355586
* Epoch (2/3)

train (3/3) loss: 0.0855684122681618 | mae: 0.25226863002777095

valid (3/3) loss: 0.08450214244127267 | mae: 0.2511766980171202
* Epoch (3/3)


## Going high-level

Naturally, you don't have to call `loss.backward()` and `optimizer.step()` in `handle_batch()` yourself: you can delegate them to `Callbacks` passed to `train()`, and the same goes for metrics. Here's a higher-level example that uses `handle_batch()` only for the model forward pass and metrics computation, while `dl.OptimizerCallback` runs the backward pass and the optimizer step:

In [5]:
import numpy as np
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset


class CustomRunner(dl.Runner):
    
    def on_loader_start(self, runner):
        super().on_loader_start(runner)
        self.meters = {
            key: metrics.AdditiveValueMetric(compute_on_call=False)
            for key in ["loss", "mae"]
        }

    def handle_batch(self, batch):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `train()`.
        x, y = batch

        y_pred = self.model(x) # Forward pass

        # Compute the loss value
        # (likewise, the criterion passed to `train()` is available as `self.criterion`)
        loss = self.criterion(y_pred, y)

        # Update metrics (includes the metric that tracks the loss)
        self.batch_metrics.update({"loss": loss, "mae": F.l1_loss(y_pred, y)})
        for key in ["loss", "mae"]:
            self.meters[key].update(self.batch_metrics[key].item(), self.batch_size)

    def on_loader_end(self, runner):
        for key in ["loss", "mae"]:
            self.loader_metrics[key] = self.meters[key].compute()[0]
        super().on_loader_end(runner)


# Construct custom data
num_samples, num_features = int(1e4), int(1e1)
X, y = torch.rand(num_samples, num_features), torch.rand(num_samples, 1)
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loaders = {"train": loader, "valid": loader}

# and model
model = torch.nn.Linear(num_features, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

# Just use `train` as usual
runner = CustomRunner()
runner.train(
  model=model, 
  optimizer=optimizer,
  criterion=criterion,       # you could also pass any PyTorch criterion for loss computation
  scheduler=None,            # or scheduler, but let's simplify the train loop for now :)
  loaders=loaders, 
  num_epochs=3,
  verbose=True,
  timeit=False,
  callbacks={
    "optimizer": dl.OptimizerCallback(
      metric_key="loss",     # you can also pass 'mae' to optimize it instead;
                             # generally, you can optimize any differentiable metric from `runner.batch_metrics`
      accumulation_steps=1,  # pass any number of steps for gradient accumulation
      grad_clip_fn=None,     # for gradient clipping during training, you can use
      grad_clip_params=None, #   `grad_clip_fn=torch.nn.utils.clip_grad_norm_`
                             #   with `grad_clip_params={"max_norm": 1, "norm_type": 2}`,
                             #   or `grad_clip_fn=torch.nn.utils.clip_grad_value_`
                             #   with `grad_clip_params={"clip_value": 1}`;
                             # see the PyTorch docs for more on gradient clipping:
                             # https://pytorch.org/docs/stable/nn.html#clip-grad-norm
    )
  }
)

train (1/3) loss: 0.3298786108493806 | lr: 0.001 | mae: 0.46371007761955274 | momentum: 0.9

valid (1/3) loss: 0.13439255228042615 | lr: 0.001 | mae: 0.30216759462356574 | momentum: 0.9
* Epoch (1/3)

train (2/3) loss: 0.12848652594089513 | lr: 0.001 | mae: 0.29636721019744844 | momentum: 0.9

valid (2/3) loss: 0.1227676458120346 | lr: 0.001 | mae: 0.29078147163391116 | momentum: 0.9
* Epoch (2/3)

train (3/3) loss: 0.11714917050600063 | lr: 0.001 | mae: 0.285176122999191 | momentum: 0.9

valid (3/3) loss: 0.11182150180339812 | lr: 0.001 | mae: 0.279757907629013 | momentum: 0.9
* Epoch (3/3)

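The `accumulation_steps` and `grad_clip_fn`/`grad_clip_params` options of `dl.OptimizerCallback` above correspond to two standard techniques. A hypothetical pure-Python sketch of both, assuming the optimizer steps once every `accumulation_steps` batches and that clipping follows the formula documented for `torch.nn.utils.clip_grad_norm_` (scale all gradients by `max_norm / (total_norm + 1e-6)` when that factor is below 1); neither function is Catalyst's implementation:

```python
# Hypothetical pure-Python sketches of gradient accumulation scheduling and
# gradient-norm clipping; they illustrate the techniques, not Catalyst's code.
import math
from typing import List


def optimizer_step_batches(num_batches: int, accumulation_steps: int) -> List[int]:
    """1-based batch indices after which an accumulating optimizer would step,
    assuming one step every `accumulation_steps` batches."""
    return [i for i in range(1, num_batches + 1) if i % accumulation_steps == 0]


def clip_grad_norm(grads: List[float], max_norm: float, norm_type: float = 2.0) -> List[float]:
    """Rescale gradients so that their total p-norm is at most about `max_norm`."""
    total_norm = sum(abs(g) ** norm_type for g in grads) ** (1.0 / norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef >= 1.0:
        return list(grads)  # already small enough: leave untouched
    return [g * clip_coef for g in grads]


# 313 batches per epoch, accumulating over 4 batches of 32 -> effective batch of 128 samples
steps = optimizer_step_batches(num_batches=313, accumulation_steps=4)
print(len(steps), steps[:3])

clipped = clip_grad_norm([3.0, 4.0], max_norm=1.0)  # L2 norm 5.0 is rescaled to about 1.0
print(clipped, math.hypot(*clipped))
```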

## Metrics support through Callbacks

Let's go even deeper! Could we move the computation of different metrics and criteria to `Callbacks` too? Of course! If you want to support different losses, you simply do the following:

- Do your model forward pass as usual.
- Save all batch-based artefacts to `self.batch`, so that callbacks can find them.
- Add extra callbacks that use the data from `runner.batch` during training.

That's it. That's the list. Let's see the example:

In [6]:
import numpy as np
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset


class CustomRunner(dl.Runner):

    def handle_batch(self, batch):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `train()`.
        x, y = batch

        y_pred = self.model(x) # Forward pass

        # pass all batch-based artefacts to `self.batch`;
        # we recommend key-value storage (a dict) to make it callback-friendly
        self.batch = {"features": x, "targets": y, "logits": y_pred}


# Construct custom data
num_samples, num_features = int(1e4), int(1e1)
X, y = torch.rand(num_samples, num_features), torch.rand(num_samples, 1)
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loaders = {"train": loader, "valid": loader}

# and model
model = torch.nn.Linear(num_features, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

# Just use `train` as usual
runner = CustomRunner()
runner.train(
  model=model, 
  optimizer=optimizer,
  criterion=criterion,
  scheduler=None,
  loaders=loaders, 
  num_epochs=3,
  verbose=True,
  timeit=False,
  callbacks={
    # alias for 
    # `runner.batch_metrics[metric_key] = \
    #     runner.criterion[criterion_key](runner.batch[input_key], runner.batch[target_key])`
    "criterion": dl.CriterionCallback(  # special Callback for criterion computation
      input_key="logits",               # `input_key` specifies model predictions (`y_pred`) from `runner.batch`
      target_key="targets",             # `target_key` specifies correct labels (or `y_true`) from `runner.batch`       
      metric_key="loss",                # `metric_key` - key to use with `runner.batch_metrics`
      criterion_key=None,               # `criterion_key` specifies criterion in case of key-value runner.criterion
                                        #   if `criterion_key=None`, runner.criterion used for computation
    ), 
    # alias for 
    # `runner.batch_metrics[metric_key] = \
    #     metric_fn(runner.batch[input_key], runner.batch[target_key])`
    "metric": dl.FunctionalMetricCallback( # special Callback for metrics computation
      input_key="logits",                  # the same logic as with `CriterionCallback`
      target_key="targets",                # the same logic as with `CriterionCallback`
      metric_key="loss_mae",               # the same logic as with `CriterionCallback`
      metric_fn=F.l1_loss,                 # metric function to use
    ),  
    "optimizer": dl.OptimizerCallback(
      metric_key="loss", 
      accumulation_steps=1,
      grad_clip_fn=None,
      grad_clip_params=None,
    )
  }
)

train (1/3) loss: 0.3298785984516144 | loss/mean: 0.3298785984516144 | loss/std: 0.25835494177194396 | loss_mae: 0.46371006965637207 | loss_mae/mean: 0.46371006965637207 | loss_mae/std: 0.19783500851214642 | lr: 0.001 | momentum: 0.9

valid (1/3) loss: 0.13439255952835083 | loss/mean: 0.13439255952835083 | loss/std: 0.02956744992561198 | loss_mae: 0.3021675944328308 | loss_mae/mean: 0.3021675944328308 | loss_mae/std: 0.037032035958603216 | lr: 0.001 | momentum: 0.9
* Epoch (1/3)

train (2/3) loss: 0.1284865289926529 | loss/mean: 0.1284865289926529 | loss/std: 0.027590900125039258 | loss_mae: 0.29636719822883606 | loss_mae/mean: 0.29636719822883606 | loss_mae/std: 0.035602213959118445 | lr: 0.001 | momentum: 0.9

valid (2/3) loss: 0.12276764214038849 | loss/mean: 0.12276764214038849 | loss/std: 0.025995866888168892 | loss_mae: 0.2907814681529999 | loss_mae/mean: 0.2907814681529999 | loss_mae/std: 0.034684277526328175 | lr: 0.001 | momentum: 0.9
* Epoch (2/3)

train (3/3) loss: 0.11714917421340942 | loss/mean: 0.11714917421340942 | loss/std: 0.024331286056113455 | loss_mae: 0.2851761281490326 | loss_mae/mean: 0.2851761281490326 | loss_mae/std: 0.03356506820218203 | lr: 0.001 | momentum: 0.9

valid (3/3) loss: 0.1118215024471283 | loss/mean: 0.1118215024471283 | loss/std: 0.022747544874339398 | loss_mae: 0.27975791692733765 | loss_mae/mean: 0.27975791692733765 | loss_mae/std: 0.032593717547230144 | lr: 0.001 | momentum: 0.9
* Epoch (3/3)

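Conceptually, the metric callbacks above only read from `runner.batch` and write to `runner.batch_metrics`, as the `# alias for` comments in the cell spell out. A hypothetical pure-Python sketch of that contract, with plain lists instead of tensors; `MiniFunctionalMetricCallback` is an illustrative name, and running it in an `on_batch_end` hook is an assumption:

```python
# Hypothetical sketch of the metric-callback contract: read `runner.batch[input_key]`
# and `runner.batch[target_key]`, write `runner.batch_metrics[metric_key]`.
from types import SimpleNamespace
from typing import Callable, List


def mean_absolute_error(preds: List[float], targets: List[float]) -> float:
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(preds)


class MiniFunctionalMetricCallback:
    def __init__(self, input_key: str, target_key: str, metric_key: str,
                 metric_fn: Callable[[List[float], List[float]], float]) -> None:
        self.input_key = input_key
        self.target_key = target_key
        self.metric_key = metric_key
        self.metric_fn = metric_fn

    def on_batch_end(self, runner: SimpleNamespace) -> None:
        runner.batch_metrics[self.metric_key] = self.metric_fn(
            runner.batch[self.input_key], runner.batch[self.target_key]
        )


# What `handle_batch` stores in `self.batch`, with plain lists instead of tensors
runner = SimpleNamespace(
    batch={"features": [[0.0], [1.0]], "logits": [0.5, 2.0], "targets": [1.0, 1.0]},
    batch_metrics={},
)
callback = MiniFunctionalMetricCallback("logits", "targets", "loss_mae", mean_absolute_error)
callback.on_batch_end(runner)
print(runner.batch_metrics)
```

Because the callback only depends on dictionary keys, the same callback works with any runner that fills `self.batch` with matching keys.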

## Simplify it a bit - SupervisedRunner

But can we simplify the last example a bit? What if we know that we are going to train a `supervised` model that takes some `features` in and outputs some `logits` back? This looks like a common case... could we automate it? Let's check it out!

In [7]:
import numpy as np
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset

# Construct custom data
num_samples, num_features = int(1e4), int(1e1)
X, y = torch.rand(num_samples, num_features), torch.rand(num_samples, 1)
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loaders = {"train": loader, "valid": loader}

# and model
model = torch.nn.Linear(num_features, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

# Just use `train` as usual
# `SupervisedRunner` works with any model of the form `some_output = model(some_input)`:
# - if your dataloader yields an `(x, y)` tuple, it is converted to
#   `{input_key: x, target_key: y}` and stored in `runner.batch`
# - the model is then called as
#   `runner.batch[output_key] = model(runner.batch[input_key])`
# - the loss is supposed to be computed as
#   `loss = criterion(runner.batch[output_key], runner.batch[target_key])`
#   and stored in `runner.batch_metrics[loss_key]`
runner = dl.SupervisedRunner(
    input_key="features",
    output_key="logits",
    target_key="targets",
    loss_key="loss",
)

# thanks to the pre-specified `input_key`, `output_key`, `target_key` and `loss_key`,
#   `SupervisedRunner` automatically adds the required `CriterionCallback` and `OptimizerCallback`
# moreover, with `logdir`, `valid_loader` and `valid_metric` specified,
#   the runner adds a `CheckpointCallback` and tracks the best-performing checkpoint by the selected metric
runner.train(
  model=model, 
  optimizer=optimizer,
  criterion=criterion,
  scheduler=None,
  loaders=loaders, 
  num_epochs=3,
  verbose=True,
  timeit=False,
  valid_loader="valid",        # `loader_key` from loaders to use for model selection
  valid_metric="loss",         # `metric_key` to use for model selection
  logdir="./logs_supervised",  # logdir to store models checkpoints
  callbacks={
#     "criterion_mse": dl.CriterionCallback(
#       input_key="logits",
#       target_key="targets",
#       metric_key="loss",
#     ),
    "criterion_mae": dl.FunctionalMetricCallback(
      input_key="logits",
      target_key="targets",
      metric_key="mae",
      metric_fn=F.l1_loss,
    ),
#     "optimizer": dl.OptimizerCallback(
#       metric_key="loss", 
#       accumulation_steps=1,
#       grad_clip_fn=None,
#       grad_clip_params=None,
#     )
  }
)

train (1/3) loss: 0.3298785984516144 | loss/mean: 0.3298785984516144 | loss/std: 0.25835494177194396 | lr: 0.001 | mae: 0.46371006965637207 | mae/mean: 0.46371006965637207 | mae/std: 0.19783500851214642 | momentum: 0.9

valid (1/3) loss: 0.13439255952835083 | loss/mean: 0.13439255952835083 | loss/std: 0.02956744992561198 | lr: 0.001 | mae: 0.3021675944328308 | mae/mean: 0.3021675944328308 | mae/std: 0.037032035958603216 | momentum: 0.9
* Epoch (1/3)

train (2/3) loss: 0.1284865289926529 | loss/mean: 0.1284865289926529 | loss/std: 0.027590900125039258 | lr: 0.001 | mae: 0.29636719822883606 | mae/mean: 0.29636719822883606 | mae/std: 0.035602213959118445 | momentum: 0.9

valid (2/3) loss: 0.12276764214038849 | loss/mean: 0.12276764214038849 | loss/std: 0.025995866888168892 | lr: 0.001 | mae: 0.2907814681529999 | mae/mean: 0.2907814681529999 | mae/std: 0.034684277526328175 | momentum: 0.9
* Epoch (2/3)

train (3/3) loss: 0.11714917421340942 | loss/mean: 0.11714917421340942 | loss/std: 0.024331286056113455 | lr: 0.001 | mae: 0.2851761281490326 | mae/mean: 0.2851761281490326 | mae/std: 0.03356506820218203 | momentum: 0.9

valid (3/3) loss: 0.1118215024471283 | loss/mean: 0.1118215024471283 | loss/std: 0.022747544874339398 | lr: 0.001 | mae: 0.27975791692733765 | mae/mean: 0.27975791692733765 | mae/std: 0.032593717547230144 | momentum: 0.9
* Epoch (3/3)
Top best models:
logs_supervised/checkpoints/train.3.pth	0.1118

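The key bookkeeping that `SupervisedRunner` performs, as described in its comments above, fits in a few lines of plain Python. `supervised_step` is a hypothetical helper that uses lists instead of tensors; it only mirrors the documented flow from an `(x, y)` tuple to `runner.batch` and `runner.batch_metrics`:

```python
# Hypothetical sketch of the documented `SupervisedRunner` key flow:
# (x, y) -> {input_key, target_key} -> model output -> loss in batch_metrics.
from typing import Any, Callable, Dict, List, Tuple


def supervised_step(
    batch: Tuple[Any, Any],
    model: Callable[[Any], Any],
    criterion: Callable[[Any, Any], float],
    input_key: str = "features",
    output_key: str = "logits",
    target_key: str = "targets",
    loss_key: str = "loss",
) -> Tuple[Dict[str, Any], Dict[str, float]]:
    x, y = batch
    runner_batch = {input_key: x, target_key: y}               # tuple -> dict
    runner_batch[output_key] = model(runner_batch[input_key])  # forward pass
    loss = criterion(runner_batch[output_key], runner_batch[target_key])
    return runner_batch, {loss_key: loss}


def double(xs: List[float]) -> List[float]:
    return [2.0 * v for v in xs]


def mse(preds: List[float], targets: List[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


runner_batch, batch_metrics = supervised_step(([1.0, 2.0], [2.0, 5.0]), double, mse)
print(runner_batch["logits"], batch_metrics)
```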

## Providing your own inference step

But let's return to the basics.

What if you want to do the same customization for calls to `runner.predict_*()`? Then you would override `predict_batch` in exactly the same way. Here's what it looks like:

In [8]:
import torch
from torch.nn import functional as F

class CustomRunner(dl.Runner):
    
    def predict_batch(self, batch):                  # here is the trick
        return self.model(batch[0].to(self.device))  # you can write any prediction logic here

    # the rest is the same as in our first example
    def on_loader_start(self, runner):
        super().on_loader_start(runner)
        self.meters = {
            key: metrics.AdditiveValueMetric(compute_on_call=False)
            for key in ["loss", "mae"]
        }

    def handle_batch(self, batch):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `train()`.
        x, y = batch

        y_pred = self.model(x) # Forward pass

        # Compute the loss value
        loss = F.mse_loss(y_pred, y)

        # Update metrics (includes the metric that tracks the loss)
        self.batch_metrics.update({"loss": loss, "mae": F.l1_loss(y_pred, y)})
        for key in ["loss", "mae"]:
            self.meters[key].update(self.batch_metrics[key].item(), self.batch_size)

        if self.is_train_loader:
            # Compute gradients
            loss.backward()
            # Update weights
            # (the optimizer passed to `train()` is available as `self.optimizer`)
            self.optimizer.step()
            self.optimizer.zero_grad()
    
    def on_loader_end(self, runner):
        for key in ["loss", "mae"]:
            self.loader_metrics[key] = self.meters[key].compute()[0]
        super().on_loader_end(runner)

In [9]:
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# Construct custom data
num_samples, num_features = int(1e4), int(1e1)
X, y = torch.rand(num_samples, num_features), torch.rand(num_samples, 1)
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loaders = {"train": loader, "valid": loader}

# and model
model = torch.nn.Linear(num_features, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

# Just use `train` as usual
runner = CustomRunner()
runner.train(
  model=model, 
  optimizer=optimizer, 
  loaders=loaders, 
  num_epochs=3,
  verbose=True,
  timeit=False,
  valid_loader="valid",   # `loader_key` from loaders to use for model selection
  valid_metric="loss",    # `metric_key` to use for model selection
  load_best_on_end=True,  # flag to load best model at the end of the training process
  logdir="./logs",        # logdir to store models checkpoints (required for `load_best_on_end`)
)
# and use `batch` prediction
prediction = runner.predict_batch(next(iter(loader)))  # let's sample the first batch from the loader
# or `loader` prediction
for prediction in runner.predict_loader(loader=loader):
    assert prediction.detach().cpu().numpy().shape[-1] == 1  # as we have single-output regression

train (1/3) loss: 0.3298786108493806 | mae: 0.46371007761955274

valid (1/3) loss: 0.13439255228042615 | mae: 0.30216759462356574
* Epoch (1/3)

train (2/3) loss: 0.12848652594089513 | mae: 0.29636721019744844

valid (2/3) loss: 0.1227676458120346 | mae: 0.29078147163391116
* Epoch (2/3)

train (3/3) loss: 0.11714917050600063 | mae: 0.285176122999191

valid (3/3) loss: 0.11182150180339812 | mae: 0.279757907629013
* Epoch (3/3)
Top best models:
logs/checkpoints/train.3.pth	0.1118

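`runner.predict_loader(loader=...)` in the cell above yields one prediction per batch by delegating to your own `predict_batch`. A simplified, hypothetical sketch of that generator pattern; Catalyst's real method presumably also handles setup such as the model and device, which this sketch leaves out:

```python
# Simplified, hypothetical sketch of a `predict_loader`-style generator that
# delegates each batch to a user-defined `predict_batch`.
from typing import Any, Iterable, Iterator, List, Tuple


class MiniPredictRunner:
    def predict_batch(self, batch: Any) -> Any:
        raise NotImplementedError

    def predict_loader(self, loader: Iterable[Any]) -> Iterator[Any]:
        for batch in loader:
            yield self.predict_batch(batch)  # lazily, one prediction per batch


class SumRunner(MiniPredictRunner):
    def predict_batch(self, batch: Tuple[List[List[float]], List[float]]) -> List[float]:
        features = batch[0]  # like `batch[0]` in `CustomRunner.predict_batch`
        return [sum(row) for row in features]


loader = [([[1.0, 2.0], [3.0, 4.0]], [0.0, 0.0]), ([[5.0, 5.0]], [0.0])]
predictions = list(SumRunner().predict_loader(loader))
print(predictions)
```

Because `predict_loader` is a generator, predictions are produced batch by batch, so a large loader never has to fit in memory at once.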

Finally, after model training and evaluation, it's time to prepare the model for deployment. PyTorch supports model tracing for production-friendly deployment of deep learning models.

Could we make it quick with Catalyst? Sure!

In [10]:
features_batch = next(iter(loaders["valid"]))[0].to(runner.device)
# model stochastic weight averaging
model.load_state_dict(utils.get_averaged_weights_by_path_mask(logdir="./logs", path_mask="*.pth"))
# model tracing (keep the returned traced model)
traced_model = utils.trace_model(model=runner.model, batch=features_batch)
# model quantization (keep the returned quantized model)
quantized_model = utils.quantize_model(model=runner.model)
# model pruning (modifies the model in place)
utils.prune_model(model=runner.model, pruning_fn="l1_unstructured", amount=0.8)
# onnx export, catalyst[onnx] or catalyst[onnx-gpu] required
# utils.onnx_export(model=runner.model, batch=features_batch, file="./logs/mnist.onnx", verbose=True)

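Two of the steps above are easy to illustrate without PyTorch. A hypothetical pure-Python sketch of (1) averaging parameters across checkpoints, assuming `get_averaged_weights_by_path_mask` takes a plain element-wise mean of the matching checkpoints, and (2) L1 unstructured pruning, assuming `amount=0.8` is the fraction of weights to zero out, as in `torch.nn.utils.prune.l1_unstructured`. Weights are flat lists here instead of tensors:

```python
# Hypothetical pure-Python sketches of checkpoint weight averaging and
# L1 unstructured (magnitude) pruning; flat lists stand in for tensors.
from typing import Dict, List


def average_state_dicts(state_dicts: List[Dict[str, List[float]]]) -> Dict[str, List[float]]:
    """Element-wise mean of same-named parameters across checkpoints."""
    n = len(state_dicts)
    return {
        name: [sum(values) / n for values in zip(*(sd[name] for sd in state_dicts))]
        for name in state_dicts[0]
    }


def l1_unstructured_prune(weights: List[float], amount: float) -> List[float]:
    """Zero out the `round(amount * len(weights))` entries with the smallest |w|."""
    num_to_prune = round(amount * len(weights))
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = set(order[:num_to_prune])
    return [0.0 if i in pruned else w for i, w in enumerate(weights)]


checkpoints = [{"weight": [1.0, 2.0], "bias": [0.0]}, {"weight": [3.0, 4.0], "bias": [1.0]}]
print(average_state_dicts(checkpoints))

weights = [0.1, -0.5, 0.3, 0.05, 0.9, -0.2, 0.7, 0.01, -0.4, 0.6]
print(l1_unstructured_prune(weights, amount=0.8))  # keeps only the 2 largest-magnitude weights
```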
## Wrapping up: an end-to-end GAN example

Let's walk through an end-to-end example that leverages everything you just learned.

Let's consider:

- A generator network meant to generate 28x28x1 images.
- A discriminator network meant to classify 28x28x1 images into two classes ("fake" - 1 and "real" - 0).
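The labelling convention above (fake = 1, real = 0) pairs with the `nn.BCEWithLogitsLoss` criteria defined in the next cell. A pure-Python sketch of that loss on raw logits, assuming the default mean reduction and written in the numerically stable form `max(x, 0) - x * y + log(1 + exp(-|x|))`, which is mathematically equal to `-[y log σ(x) + (1 - y) log(1 - σ(x))]`. The noisy labels mirror the `0.05 * torch.rand(...)` trick used later in `handle_batch`:

```python
# Pure-Python sketch of binary cross-entropy on logits plus the GAN label setup;
# an illustration under the assumptions stated above, not PyTorch's code.
import math
import random
from typing import List


def bce_with_logits(logits: List[float], targets: List[float]) -> float:
    """Mean BCE on raw logits: max(x, 0) - x * y + log(1 + exp(-|x|))."""
    losses = [
        max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
        for x, y in zip(logits, targets)
    ]
    return sum(losses) / len(losses)


batch_size = 4
# generated ("fake" = 1) images come first, real ("real" = 0) images second
labels = [1.0] * batch_size + [0.0] * batch_size
# random label noise, mirroring `labels += 0.05 * torch.rand(labels.shape)`
rng = random.Random(0)
noisy_labels = [y + 0.05 * rng.random() for y in labels]

print(bce_with_logits([0.0], [1.0]))               # log(2): an undecided discriminator
print(bce_with_logits([10.0, -10.0], [1.0, 0.0]))  # confident and correct: small loss

# the generator is scored against "all real" (0) labels, like `misleading_labels`
generator_loss = bce_with_logits([2.0, -1.0], [0.0, 0.0])
```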



In [11]:
import torch
from torch import nn
from torch.nn import functional as F
from catalyst.contrib.nn.modules import Flatten, GlobalMaxPool2d, Lambda

# Create the discriminator
discriminator = nn.Sequential(
    nn.Conv2d(1, 64, (3, 3), stride=(2, 2), padding=1),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(64, 128, (3, 3), stride=(2, 2), padding=1),
    nn.LeakyReLU(0.2, inplace=True),
    GlobalMaxPool2d(),
    Flatten(),
    nn.Linear(128, 1),
)

# Create the generator
latent_dim = 128
generator = nn.Sequential(
    # Project the latent vector to 128 * 7 * 7 coefficients and reshape them into a 128x7x7 map
    nn.Linear(latent_dim, 128 * 7 * 7),
    nn.LeakyReLU(0.2, inplace=True),
    Lambda(lambda x: x.view(x.size(0), 128, 7, 7)),
    nn.ConvTranspose2d(128, 128, (4, 4), stride=(2, 2), padding=1),
    nn.LeakyReLU(0.2, inplace=True),
    nn.ConvTranspose2d(128, 128, (4, 4), stride=(2, 2), padding=1),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(128, 1, (7, 7), padding=3),
    nn.Sigmoid(),
)

# Final model
model = {"generator": generator, "discriminator": discriminator}
criterion = {"generator": nn.BCEWithLogitsLoss(), "discriminator": nn.BCEWithLogitsLoss()}
optimizer = {
    "generator": torch.optim.Adam(generator.parameters(), lr=0.0003, betas=(0.5, 0.999)),
    "discriminator": torch.optim.Adam(discriminator.parameters(), lr=0.0003, betas=(0.5, 0.999)),
}

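The layer hyperparameters above are chosen so that the spatial sizes line up: the generator upsamples a 7x7 map to 28x28, and the discriminator downsamples 28x28 to 7x7 before global max pooling. A quick check with the output-size formulas from the PyTorch `Conv2d`/`ConvTranspose2d` documentation, assuming `dilation=1` and `output_padding=0` (the defaults used here):

```python
# Output-size arithmetic for the GAN layers, using the documented PyTorch
# formulas with dilation=1 and output_padding=0.
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    return (size - 1) * stride - 2 * padding + kernel


# Generator: 7x7 -> 14x14 -> 28x28, then a 7x7 conv with padding 3 keeps 28x28
g = conv_transpose2d_out(7, kernel=4, stride=2, padding=1)
g = conv_transpose2d_out(g, kernel=4, stride=2, padding=1)
g = conv2d_out(g, kernel=7, padding=3)

# Discriminator: 28x28 -> 14x14 -> 7x7, then global max pooling leaves 128 features
d = conv2d_out(28, kernel=3, stride=2, padding=1)
d = conv2d_out(d, kernel=3, stride=2, padding=1)

print(g, d)
```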
Here's a feature-complete `GANRunner` that overrides `predict_batch()` to sample new images from the generator (ignoring its input batch) and implements the entire GAN training step in `handle_batch`:

In [12]:
class GANRunner(dl.Runner):
  
    def __init__(self, latent_dim: int):
        super().__init__()
        self.latent_dim = latent_dim

    def predict_batch(self, batch):
        batch_size = 1
        # Sample random points in the latent space
        random_latent_vectors = torch.randn(batch_size, self.latent_dim).to(self.device)
        # Decode them to fake images
        generated_images = self.model["generator"](random_latent_vectors).detach()
        return generated_images

    def handle_batch(self, batch):
        real_images, _ = batch
        batch_size = real_images.shape[0]

        # Sample random points in the latent space
        random_latent_vectors = torch.randn(batch_size, self.latent_dim).to(self.device)

        # Decode them to fake images
        generated_images = self.model["generator"](random_latent_vectors).detach()
        # Combine them with real images
        combined_images = torch.cat([generated_images, real_images])

        # Assemble labels discriminating real from fake images
        labels = torch.cat(
            [torch.ones((batch_size, 1)), torch.zeros((batch_size, 1))]
        ).to(self.device)
        # Add random noise to the labels - important trick!
        labels += 0.05 * torch.rand(labels.shape).to(self.device)

        # Discriminator forward
        combined_predictions = self.model["discriminator"](combined_images)

        # Sample random points in the latent space
        random_latent_vectors = torch.randn(batch_size, self.latent_dim).to(self.device)
        # Assemble labels that say "all real images"
        misleading_labels = torch.zeros((batch_size, 1)).to(self.device)

        # Generator forward
        generated_images = self.model["generator"](random_latent_vectors)
        generated_predictions = self.model["discriminator"](generated_images)

        self.batch = {
            "combined_predictions": combined_predictions,
            "labels": labels,
            "generated_predictions": generated_predictions,
            "misleading_labels": misleading_labels,
        }