In [None]:
#| default_exp nn

# Custom `Module` and `Sequential`

> Simple tools to provide a Keras-like interface to PyTorch.

In [None]:
#| hide
from nbdev.showdoc import *
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
from torchmetrics import Accuracy, MetricCollection

In [None]:
#| export
from typing import Dict, Callable, Optional
from tqdm.auto import tqdm

import torch
import torch.nn as nn

from fastcore.basics import patch

In [None]:
#| exporti
def get_pbar_description_from_batch_metrics(batch_metrics, prefix=""):
    """Build the progress-bar text for one batch, e.g. `"Loss: 0.12 Accuracy: 0.98"`.

    Every `(name, value)` pair of `batch_metrics` is rendered as
    `<prefix><name>: <value>` with the value formatted to two decimals, and the
    pieces are separated by single spaces. An empty mapping yields `""`.
    """
    pieces = (f'{prefix}{name}: {value:.2f}' for name, value in batch_metrics.items())
    return " ".join(pieces)

In [None]:
#| export
class History(dict):
    """
    A `dict` subclass that mimics Keras' `History`: each key maps to the list of
    values that have been logged under that name. 
    `.fit()` uses it to record the loss and the metrics during training.
    """
    def log_dict(self, 
                 data: Dict, # Mapping from name to the value to store.
                 prefix="", # String prepended to every name before storing.
                 ):
        """Appends every value in `data` to the list kept under its (prefixed) name."""
        for key, value in data.items():
            # Create the list on first sight of a name, then append to it.
            self.setdefault(prefix + key, []).append(value)

    def aggregate(self, 
                  agg_fn: Callable = lambda x: sum(x)/len(x), # Reduction applied to each list of values (mean by default).
                  ) -> Dict: # Plain dictionary mapping each name to its aggregated value.
        """Reduces every stored list of values to a single value using `agg_fn`."""
        aggregated = {}
        for key, values in self.items():
            aggregated[key] = agg_fn(values)
        return aggregated

In [None]:
#| export
class Module(nn.Module):
    """
    Subclass of PyTorch's `nn.Module` meant to be used as the base class of
    models that want a Keras-like workflow: a predefined training loop with
    logging (`.fit()`), plus `.compile()` and `.evaluate()` methods.
    The behaviour is intended to stay as close to Keras as possible.
    """
    def __init__(self, **kwargs):
        # Nothing extra to set up here: just forward the keyword arguments to `nn.Module`.
        super().__init__(**kwargs)

In [None]:
#| exporti
@patch
def train_step(self: Module, 
               batch, # `(inputs, labels)` pair to train on.
               ) -> Dict: # Metric values (plus `'Loss'`) obtained on this batch.
    """Run one optimization step on `batch` and return its loss and metrics as Python numbers."""

    inputs, targets = batch

    # Forward pass, backward pass and parameter update.
    self.optimizer.zero_grad()
    predictions = self(inputs)
    loss = self.loss_fn(predictions, targets)
    loss.backward()
    self.optimizer.step()

    # Metrics (when configured) go first; the loss is always reported under 'Loss'.
    logs = {}
    if self.metrics is not None:
        batch_values = self.metrics(predictions, targets)
        logs = {name: value.item() for name, value in batch_values.items()}
    logs['Loss'] = loss.item()
    return logs


In [None]:
#| exporti
@patch
def validation_step(self: Module, 
                    batch, # `(inputs, labels)` pair to validate on.
                    ) -> Dict: # Metric values (plus `'Loss'`) obtained on this batch.
    """Compute the loss and metrics on `batch` without updating the model."""

    inputs, targets = batch
    predictions = self(inputs)
    loss = self.loss_fn(predictions, targets)

    # Metrics (when configured) go first; the loss is always reported under 'Loss'.
    logs = {}
    if self.metrics is not None:
        batch_values = self.metrics(predictions, targets)
        logs = {name: value.item() for name, value in batch_values.items()}
    logs['Loss'] = loss.item()
    return logs

In [None]:
#| export
@patch
def fit(self: Module, 
        trainloader: torch.utils.data.DataLoader, # Training DataLoader. 
        epochs, # Number of epochs to train.
        validationloader=None, # Validation DataLoader (optional).
        ) -> History: # History object with the training dynamics as in Keras.
    """
    Fit a model to the desired `trainloader` for `epochs` epochs. Returns the corresponding `History`.

    For every epoch the per-batch results of `train_step` (and of `validation_step`, 
    stored with a `Val_` prefix, when `validationloader` is given) are averaged 
    and appended to the returned `History`.
    """

    history_epoch = History()
    for _ in tqdm(range(epochs), desc='Epochs', position=0):
        history_batch = History()

        ## Training pass
        self.train()
        # tqdm infers the total from `len(trainloader)` when it is available, 
        # so loaders without a length still work (just without a total).
        pbar = tqdm(trainloader, position=1, leave=False)
        for batch in pbar:
            batch_metrics = self.train_step(batch)
            history_batch.log_dict(batch_metrics)
            pbar.set_description(get_pbar_description_from_batch_metrics(batch_metrics))

        ## Validation pass (optional)
        if validationloader is not None:
            self.eval()
            pbar = tqdm(validationloader, position=2, leave=False)
            for batch in pbar:
                # No gradients are needed to validate.
                with torch.no_grad():
                    batch_metrics = self.validation_step(batch)
                history_batch.log_dict(batch_metrics, prefix='Val_')
                pbar.set_description(get_pbar_description_from_batch_metrics(batch_metrics, 'Val_'))

        # `metrics` is None when the model was compiled without metrics 
        # (the default in `compile`), in which case there is nothing to reset.
        if self.metrics is not None:
            self.metrics.reset()
        history_epoch.log_dict(history_batch.aggregate())
    return history_epoch

In [None]:
#| export
@patch
def compile(self: Module, 
            loss: Optional=None, # Callable computing the loss from `(outputs, labels)`.
            optimizer: Optional[torch.optim.Optimizer] = None, # Optimizer that updates the model parameters.
            metrics=None, # `MetricCollection` with the metrics to report (`None` reports only the loss).
            ):
    """
    Configures the model for training by storing the loss, the optimizer and the metrics
    in the `loss_fn`, `optimizer` and `metrics` attributes read by the training and validation steps.
    """

    self.loss_fn, self.optimizer, self.metrics = loss, optimizer, metrics

In [None]:
#| export
@patch
def evaluate(self: Module, 
             dataloader, # `DataLoader` to evaluate the model with.
             ) -> Dict: # Results from the evaluation
    """
    Evaluates the model on a set of data.

    Puts the model in evaluation mode, runs `validation_step` on every batch and
    returns the unweighted mean over batches of each reported value. 
    An empty `dataloader` yields an empty dictionary.
    """

    self.eval()
    results = History()
    # tqdm infers the total from `len(dataloader)` when it is available.
    for batch in tqdm(dataloader):
        # Evaluation never calls `backward`, so skip building the autograd graph.
        with torch.no_grad():
            batch_results = self.validation_step(batch)
        results.log_dict(batch_results)
    # The default aggregation of `History` is the mean of the logged values.
    return results.aggregate()

### Example of usage

We can perform a very simple example using the Fashion MNIST dataset (as is done in the official [PyTorch docs](https://pytorch.org/tutorials/beginner/introyt/trainingyt.html)).

In [None]:
# Convert the images to tensors and normalize them with mean 0.5 and std 0.5
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Training and validation splits of Fashion MNIST (downloaded to ./data if missing)
training_set = torchvision.datasets.FashionMNIST(
    root='./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST(
    root='./data', train=False, transform=transform, download=True)

# Only the training data is shuffled; validation keeps its original order
training_loader = torch.utils.data.DataLoader(
    training_set, batch_size=128, shuffle=True, num_workers=0)
validation_loader = torch.utils.data.DataLoader(
    validation_set, batch_size=256, shuffle=False, num_workers=0)

Note that the only difference with respect to plain PyTorch is that we're inheriting from our custom `Module`, not from PyTorch's `nn.Module`:

In [None]:
class SimpleModel(Module):
    """Small convolutional classifier: two conv + max-pool stages followed by three linear layers."""
    def __init__(self):
        super().__init__()
        # Convolutional feature extractor (the same pooling layer is reused after each convolution)
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Fully connected head producing 10 outputs
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        # Flatten to rows of 16 * 4 * 4 values, the input size of `fc1`
        flat = features.view(-1, 16 * 4 * 4)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

Following the usual Keras way, we instantiate the model and compile it, providing the *loss*, the *optimizer* and (optionally) the *metrics*:

In [None]:
model = SimpleModel()

# Loss, optimizer and metrics are attached to the model, as in Keras
model.compile(
    loss=torch.nn.CrossEntropyLoss(),
    optimizer=torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9),
    metrics=MetricCollection([Accuracy()]),
)

In [None]:
model

SimpleModel(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
  (loss_fn): CrossEntropyLoss()
  (metrics): MetricCollection(
    (Accuracy): Accuracy()
  )
)

In [None]:
model.evaluate(training_loader), model.evaluate(validation_loader)

  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

({'Accuracy': 0.11770944386275846, 'Loss': 2.3061464941069514},
 {'Accuracy': 0.12353515625, 'Loss': 2.3057154595851896})

In [None]:
history = model.fit(trainloader=training_loader, epochs=1, validationloader=validation_loader)

Epochs:   0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

In [None]:
history

{'Accuracy': [0.18737784294939752],
 'Loss': [2.2949233914235],
 'Val_Accuracy': [0.2640625],
 'Val_Loss': [2.2760460674762726]}

In [None]:
model.evaluate(training_loader), model.evaluate(validation_loader)

  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

({'Accuracy': 0.25996690654932564, 'Loss': 2.276174947905388},
 {'Accuracy': 0.2640625, 'Loss': 2.2760460674762726})