In [None]:
# default_exp core

# Korch

> Simple tools to provide a Keras-like interface to PyTorch.

In [None]:
#hide
from nbdev.showdoc import *
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
from torchmetrics import Accuracy, MetricCollection

In [None]:
#hide
#export
import torch
import torch.nn as nn
from tqdm.auto import tqdm

In [None]:
#hide
#exporti
def get_pbar_description_from_batch_metrics(batch_metrics, prefix=""):
    """Render a metrics dict as a single progress-bar description line.

    Every entry becomes `{prefix}{name}: {value:.2f}`, and the entries are
    separated by one space. An empty dict produces an empty string.
    """
    parts = [f'{prefix}{name}: {value:.2f}' for name, value in batch_metrics.items()]
    return ' '.join(parts)

In [None]:
#export
class History(dict):
    """Keras-like record of training metrics.

    Maps every metric name to the list of values logged for it, in logging order.
    """

    def log_dict(self, data, prefix=""):
        """
        Append each value in `data` to the list kept under its name.

        Parameters
        ----------
        data: dict
            Mapping from metric name to the value to record.
        prefix: str
            Prepended to every name before storing (e.g. 'Val_').
        """
        for key, val in data.items():
            # Create the list on first sight of a name, then append to it.
            self.setdefault(prefix + key, []).append(val)

    def aggregate(self, agg_fn=lambda x: sum(x)/len(x)):
        """
        Reduce every stored list with `agg_fn`.

        Parameters
        ----------
        agg_fn: function
            Called with the list of values stored for one name; the arithmetic
            mean by default.

        Returns
        -------
        agg_data: dict
            Mapping from each name to its aggregated value.
        """
        return dict((key, agg_fn(vals)) for key, vals in self.items())

This object inherits from the built-in `dict` to provide a Keras-like `History` object, which automatically logs the loss and the metrics during training. It is used automatically by `.fit()` (as in Keras).

In [None]:
#export
class Module(nn.Module):
    """Keras-style extension of `nn.Module`.

    Subclasses define their layers and `forward` as with a plain `nn.Module`.
    After `compile` attaches a loss, an optimizer and (optionally) metrics,
    `fit` trains the model and `evaluate` scores it on a data loader.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _compute_metrics(self, outputs, labels, loss):
        """Build the per-batch metrics dict: compiled metrics (if any), then 'Loss'."""
        metrics = {}
        if self.metrics is not None:
            # `self.metrics` is expected to be a callable (e.g. a torchmetrics
            # MetricCollection) returning a dict of scalar tensors for this batch.
            batch_values = self.metrics(outputs, labels)
            metrics = {name: value.item() for name, value in batch_values.items()}
        metrics['Loss'] = loss.item()
        return metrics

    def _reset_metrics(self):
        """Clear accumulated metric state; no-op without metrics or a `reset` method."""
        reset = getattr(self.metrics, 'reset', None)
        if reset is not None:
            reset()

    def train_step(self, batch):
        """Run one optimization step on `batch` and return its metrics.

        Parameters
        ----------
        batch: tuple
            `(inputs, labels)` pair as yielded by the training loader.

        Returns
        -------
        metrics: dict
            Compiled metrics (if any) plus 'Loss', as Python floats.
        """
        inputs, labels = batch
        self.optimizer.zero_grad()
        outputs = self(inputs)
        loss = self.loss_fn(outputs, labels)
        loss.backward()
        self.optimizer.step()
        return self._compute_metrics(outputs, labels, loss)

    def validation_step(self, batch):
        """Compute loss and metrics on `batch` without updating the weights.

        Callers are responsible for disabling autograd (`fit` and `evaluate`
        wrap this call in `torch.no_grad()`).

        Returns
        -------
        metrics: dict
            Compiled metrics (if any) plus 'Loss', as Python floats.
        """
        inputs, labels = batch
        outputs = self(inputs)
        loss = self.loss_fn(outputs, labels)
        return self._compute_metrics(outputs, labels, loss)

    def fit(self, trainloader, epochs, validationloader=None):
        """Train for `epochs` epochs, optionally validating after each one.

        Parameters
        ----------
        trainloader: iterable of `(inputs, labels)` batches supporting `len()`.
        epochs: int
            Number of passes over `trainloader`.
        validationloader: same kind as `trainloader`, optional
            When given, its metrics are logged with a 'Val_' prefix.

        Returns
        -------
        history_epoch: History
            For each metric, one value per epoch: the mean of its per-batch values.
        """
        history_epoch = History()
        for _ in tqdm(range(epochs), desc='Epochs', position=0):
            history_batch = History()

            self.train()
            pbar = tqdm(trainloader, total=len(trainloader), position=1, leave=False)
            for batch in pbar:
                batch_metrics = self.train_step(batch)
                history_batch.log_dict(batch_metrics)
                pbar.set_description(get_pbar_description_from_batch_metrics(batch_metrics))

            if validationloader is not None:
                self.eval()
                pbar = tqdm(validationloader, total=len(validationloader), position=2, leave=False)
                with torch.no_grad():
                    for batch in pbar:
                        batch_metrics = self.validation_step(batch)
                        history_batch.log_dict(batch_metrics, prefix='Val_')
                        pbar.set_description(get_pbar_description_from_batch_metrics(batch_metrics, 'Val_'))

            # Previously this called self.metrics.reset() unconditionally and
            # crashed when the model was compiled with metrics=None.
            self._reset_metrics()
            history_epoch.log_dict(history_batch.aggregate())
        return history_epoch

    def compile(self, loss=None, optimizer=None, metrics=None):
        """Attach the training configuration, Keras-style.

        Parameters
        ----------
        loss: callable
            Called as `loss(outputs, labels)`; must return a scalar tensor.
        optimizer: torch.optim.Optimizer
            Optimizer over this model's parameters.
        metrics: torchmetrics.MetricCollection, optional
            Called as `metrics(outputs, labels)`; `None` logs only the loss.
        """
        self.loss_fn = loss
        self.optimizer = optimizer
        self.metrics = metrics

    def evaluate(self, dataloader):
        """Evaluate the model on `dataloader` without updating its weights.

        Parameters
        ----------
        dataloader: iterable of `(inputs, labels)` batches supporting `len()`.

        Returns
        -------
        results: dict
            Mean over batches of each per-batch metric. Every batch weighs the
            same, including a smaller final batch. Empty if the loader is empty.
        """
        self.eval()
        history = History()
        # Gradients are never used here; disabling autograd saves memory and time.
        with torch.no_grad():
            for batch in tqdm(dataloader, total=len(dataloader)):
                history.log_dict(self.validation_step(batch))
        self._reset_metrics()
        return history.aggregate()

This `Module` implements a basic training loop similar to Keras', together with `compile` and `evaluate` methods. The goal is to behave as much like Keras as possible.

### Example of usage

We can run a very simple example on the Fashion-MNIST dataset (as is done in the official [PyTorch docs](https://pytorch.org/tutorials/beginner/introyt/trainingyt.html)).

In [None]:
# Seed torch's global RNG so that shuffling, and the model's weight
# initialisation in a later cell, are reproducible on Restart & Run All.
torch.manual_seed(0)

transform = transforms.Compose([
    transforms.ToTensor(),
    # Rescale the single grayscale channel from [0, 1] to [-1, 1].
    transforms.Normalize((0.5,), (0.5,)),
])

# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)

# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=128, shuffle=True, num_workers=2)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=256, shuffle=False, num_workers=2)

Note that the only difference from plain PyTorch is that we inherit from our custom `Module` instead of PyTorch's `nn.Module`:

In [None]:
class SimpleModel(Module):
    """Small LeNet-style CNN: 1-channel 28x28 images in, 10 class logits out."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 16 channels x 4 x 4 spatial positions remain after the conv/pool stack.
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # 28x28 -> conv5 -> 24x24 -> pool -> 12x12 -> conv5 -> 8x8 -> pool -> 4x4
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Keep the batch dimension explicit. With view(-1, 256), an input of the
        # wrong spatial size could be silently reshaped into a different batch
        # size; flatten makes fc1 reject it instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

Following the usual Keras workflow, we instantiate the model and compile it, providing the *loss*, the *optimizer* and, optionally, the *metrics*:

In [None]:
# Keras-style setup: build the network, then attach the loss, optimizer and metrics.
model = SimpleModel()
model.compile(
    loss=torch.nn.CrossEntropyLoss(),
    optimizer=torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9),
    metrics=MetricCollection([Accuracy()]),
)

In [None]:
# Show the module tree; the loss and metrics attached by compile() appear in it too.
model

SimpleModel(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
  (loss_fn): CrossEntropyLoss()
  (metrics): MetricCollection(
    (Accuracy): Accuracy()
  )
)

In [None]:
# Baseline before training: with 10 classes, accuracy should be close to chance (~0.10).
model.evaluate(training_loader), model.evaluate(validation_loader)

  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

({'Accuracy': 0.10011327292110875, 'Loss': 2.3076691297071577},
 {'Accuracy': 0.09892578125, 'Loss': 2.3078657686710358})

In [None]:
# Train for one epoch; validation metrics are logged under a 'Val_' prefix.
history = model.fit(trainloader=training_loader, epochs=1, validationloader=validation_loader)

Epochs:   0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

In [None]:
# One value per epoch for each metric: the mean of its per-batch values during fit().
history

{'Accuracy': [0.17958200073191352],
 'Loss': [2.293788175338875],
 'Val_Accuracy': [0.36142578125],
 'Val_Loss': [2.26760156750679]}

In [None]:
# Re-evaluate with the trained weights. Training accuracy here can exceed the
# 'Accuracy' in `history`, which averaged batches while the weights were still changing.
model.evaluate(training_loader), model.evaluate(validation_loader)

  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

({'Accuracy': 0.35581578711456835, 'Loss': 2.2677379803362685},
 {'Accuracy': 0.36142578125, 'Loss': 2.26760156750679})