# Classification with Delira and PyTorch - A very short introduction
*Author: Justus Schock* 

*Date: 31.07.2019*

This example shows how to set up a basic classification model and experiment using PyTorch.

Let's first set up the essential hyperparameters. We will use `delira`'s `Parameters` class for this:

In [None]:
logger = None
import torch
from delira.training import Parameters
params = Parameters(fixed_params={
    "model": {
        "in_channels": 3, 
         "num_classes": 10
    },
    "training": {
        "batch_size": 64, # batchsize to use
        "num_epochs": 10, # number of epochs to train
        "optimizer_cls": torch.optim.Adam, # optimization algorithm to use
        "optimizer_params": {'lr': 1e-3}, # initialization parameters for this algorithm
        "losses": {"CE": torch.nn.CrossEntropyLoss()}, # the loss function
        "lr_sched_cls": None,  # the learning rate scheduling algorithm to use
        "lr_sched_params": {}, # the corresponding initialization parameters
        "metrics": {} # and some evaluation metrics
    }
}) 

Since we did not specify any metrics, only the `CrossEntropyLoss` will be calculated for each batch. For a simple classification task, this is sufficient. We will train our network with a batch size of 64, using `Adam` as the optimizer.
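
To get a feeling for the values this loss produces, the following minimal sketch computes the cross-entropy of a single sample in plain Python. The helper `cross_entropy` is made up for illustration and is not part of `delira` or PyTorch. A model that assigns the same score to each of the 10 classes yields a loss of ln(10), roughly 2.303, which is a handy reference value when training on random data.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one sample: -log(softmax(logits)[target])."""
    # subtract the maximum before exponentiating for numerical stability
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(
        sum(math.exp(v - max_logit) for v in logits))
    return log_sum_exp - logits[target]


num_classes = 10

# identical scores for all classes, i.e. the model is guessing
uniform_loss = cross_entropy([0.0] * num_classes, target=3)
print("loss for uniform scores: %.3f" % uniform_loss)

# a clearly higher score for the correct class lowers the loss
confident_logits = [0.0] * num_classes
confident_logits[3] = 5.0
confident_loss = cross_entropy(confident_logits, target=3)
print("loss for a confident, correct prediction: %.3f" % confident_loss)
```

`torch.nn.CrossEntropyLoss` computes the same quantity per sample and, with its default settings, averages it over the batch.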

## Logging and Visualization
To visualize our results, we need to monitor them during training. For logging we will use `Tensorboard`. By default, the logging directory is the same as our experiment directory.


## Data Preparation
### Loading
Next we will create some fake data. For this we use the `ClassificationFakeData` dataset, which is already implemented in `deliravision`. To avoid getting the exact same data from both datasets, we pass a random-number-generator offset (`rng_offset`) to the validation dataset.

In [None]:
from deliravision.data.fakedata import ClassificationFakeData
dataset_train = ClassificationFakeData(num_samples=1000, 
                                       img_size=(3, 32, 32), 
                                       num_classes=10)
dataset_val = ClassificationFakeData(num_samples=100, 
                                     img_size=(3, 32, 32), 
                                     num_classes=10,
                                     rng_offset=10001
                                     )

### Augmentation
For data augmentation we will apply a few transformations:

In [None]:
from batchgenerators.transforms import RandomCropTransform, \
                                        ContrastAugmentationTransform, Compose
from batchgenerators.transforms.spatial_transforms import ResizeTransform
from batchgenerators.transforms.sample_normalization_transforms import ZeroMeanUnitVarianceTransform

transforms = Compose([
    RandomCropTransform(24), # Perform Random Crops of Size 24 x 24 pixels
    ResizeTransform(32), # Resample these crops back to 32 x 32 pixels
    ContrastAugmentationTransform(), # randomly adjust contrast
    ZeroMeanUnitVarianceTransform()
]) 

With these transformations we can now wrap our datasets into datamanagers:

In [None]:
from delira.data_loading import BaseDataManager, SequentialSampler, RandomSampler

manager_train = BaseDataManager(dataset_train, params.nested_get("batch_size"),
                                transforms=transforms,
                                sampler_cls=RandomSampler,
                                n_process_augmentation=4)

manager_val = BaseDataManager(dataset_val, params.nested_get("batch_size"),
                              transforms=transforms,
                              sampler_cls=SequentialSampler,
                              n_process_augmentation=4)
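
The two managers above differ mainly in their sampler, which determines the order in which sample indices are drawn. The sketch below mimics that idea in plain Python. The `batch_indices` helper is hypothetical and does not reflect how `delira` implements its samplers internally (in particular, how `delira` treats an incomplete last batch is not covered here).

```python
import random


def batch_indices(num_samples, batch_size, shuffle, seed=0):
    """Split the indices 0..num_samples-1 into batches, optionally shuffled."""
    indices = list(range(num_samples))
    if shuffle:
        # random order, as one would want for training
        random.Random(seed).shuffle(indices)
    return [indices[i:i + batch_size]
            for i in range(0, num_samples, batch_size)]


sequential_batches = batch_indices(100, 64, shuffle=False)  # validation-like
random_batches = batch_indices(100, 64, shuffle=True)       # training-like

print([len(b) for b in sequential_batches])
print(sequential_batches[0][:5], random_batches[0][:5])
```

Drawing training samples in random order avoids presenting the network with the same batch composition in every epoch, while a fixed order keeps validation runs comparable.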


## Model

After we have done that, we can specify our model. We will use a smaller version of a [VGG-Network](https://arxiv.org/pdf/1409.1556.pdf) in this case. The convolution and pooling blocks reduce the spatial feature dimensions down to 1 x 1, and we reduce the number of units in the linear part to save memory (we only have to deal with 10 classes, not the 1000 ImageNet classes).

In [None]:
from delira.models import AbstractPyTorchNetwork
from delira.models.backends.torch.utils import scale_loss
import torch

class Flatten(torch.nn.Module):
        
    def forward(self, x):
        return x.view(x.size(0), -1)

class SmallVGGPyTorch(AbstractPyTorchNetwork):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        
        self.model = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, 64, 3, padding=1), # 32 x 32
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2), # 16 x 16
            torch.nn.Conv2d(64, 128, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2), # 8 x 8
            torch.nn.Conv2d(128, 256, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2), # 4 x 4
            torch.nn.Conv2d(256, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2), # 2 x 2
            torch.nn.Conv2d(512, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2), # 1 x 1
            Flatten(),
            torch.nn.Linear(1*1*512, num_classes),
        )
        
    def forward(self, x: torch.Tensor):
        return {"pred": self.model(x)}
    
    @staticmethod
    def prepare_batch(data_dict, input_device, output_device):
        return_dict = {
            "data": torch.from_numpy(data_dict["data"]).to(
                input_device).to(torch.float),
            "label": torch.from_numpy(data_dict["label"]).to(
                input_device).to(torch.long).squeeze(dim=1),
        }

        for key, vals in data_dict.items():
            if key in ["data", "label"]:
                continue
            return_dict[key] = torch.from_numpy(vals).to(output_device).to(
                torch.float)

        return return_dict
    
    @staticmethod
    def closure(model, data_dict: dict, optimizers: dict, losses: dict,
                fold=0, **kwargs):

        loss_vals = {}
        total_loss = 0


        # predict
        inputs = data_dict.pop("data")
        preds = model(inputs)

        # calculate losses
        for key, crit_fn in losses.items():
            _loss_val = crit_fn(preds["pred"], data_dict["label"])
            loss_vals[key] = _loss_val.item()
            total_loss += _loss_val

        optimizers['default'].zero_grad()
        # perform loss scaling via apex if half precision is enabled
        with scale_loss(total_loss, optimizers["default"]) as scaled_loss:
            scaled_loss.backward()
        optimizers['default'].step()
        
        # return metrics, losses, predictions
        return {}, loss_vals, {k: v.detach() for k, v in preds.items()}
    
    

So let's revisit what we have just done.

In `delira` all networks must be derived from `delira.models.AbstractNetwork`. For each backend there is a class derived from this class, handling some backend-specific function calls and registrations. For the `PyTorch` Backend this class is `AbstractPyTorchNetwork` and all PyTorch Networks should be derived from it.

First, we defined the network itself (the part that simply chains the layers into a sequential model). Next, we defined the logic to apply when we want to predict from the model (the `forward` method).
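
The size comments in the model definition can be double-checked with a few lines of plain Python. This is only a sketch: the helper `conv_out` is introduced here for illustration and uses the standard output-size formula for convolution and pooling layers, `(size + 2 * padding - kernel) // stride + 1`.

```python
def conv_out(size, kernel, padding=0, stride=1):
    """Output size of a convolution or pooling layer along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32  # input images are 32 x 32 pixels
sizes = []
for _ in range(5):  # the model stacks five conv + pool blocks
    size = conv_out(size, kernel=3, padding=1)  # 3x3 conv with padding=1 keeps the size
    size = conv_out(size, kernel=2, stride=2)   # MaxPool2d(2) halves the size
    sizes.append(size)

print(sizes)

# 512 channels remain after the last block
flattened_features = 512 * size * size
print(flattened_features)
```

After five blocks the feature map is 1 x 1, which is why the final linear layer expects `1*1*512` input features.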

So far this was plain `PyTorch`. The `prepare_batch` function is not plain PyTorch anymore, but allows us to ensure the data is in the correct shape, has the correct data-type and lies on the correct device. The function above is the standard `prepare_batch` function, which is also implemented in the `AbstractPyTorchNetwork` and just re-implemented here for the sake of completeness.
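
The effect of the conversions in `prepare_batch` can be sketched on a made-up batch. The batch layout below (NumPy arrays, labels of shape `(N, 1)`) is an assumption derived from the `squeeze(dim=1)` call in the code above, not something guaranteed for every dataset.

```python
import numpy as np
import torch

# hypothetical batch in the assumed layout: plain NumPy arrays
batch = {
    "data": np.random.rand(4, 3, 32, 32),             # float64 images
    "label": np.random.randint(0, 10, size=(4, 1)),   # labels with a trailing axis
}

# same conversions as in prepare_batch, without the device handling
data = torch.from_numpy(batch["data"]).to(torch.float)
label = torch.from_numpy(batch["label"]).to(torch.long).squeeze(dim=1)

print(data.dtype, tuple(data.shape))
print(label.dtype, tuple(label.shape))

# CrossEntropyLoss expects float scores of shape (N, C) and long labels of shape (N,)
loss = torch.nn.CrossEntropyLoss()(torch.randn(4, 10), label)
print(loss.item())
```

Without the cast to `torch.long` and the `squeeze`, `CrossEntropyLoss` would reject the class-index labels because of their type or shape.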

The same goes for the `closure` function. This function defines the update rule for our parameters (and how to calculate the losses). These functions are good to go for many simple networks but can be overridden for customization when training more complex networks.
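
Stripped of the `delira` specifics, the update rule in `closure` is an ordinary PyTorch training step. The following sketch reproduces it on a toy linear model with random data. The names `toy_model`, `toy_inputs` and `toy_labels` are placeholders, and the apex loss scaling used above is left out on purpose.

```python
import torch

torch.manual_seed(0)

toy_model = torch.nn.Linear(8, 10)  # stand-in for the real network
optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-3)
losses = {"CE": torch.nn.CrossEntropyLoss()}

toy_inputs = torch.randn(16, 8)            # random stand-in batch
toy_labels = torch.randint(0, 10, (16,))   # random class indices

weights_before = toy_model.weight.detach().clone()

# predict
preds = toy_model(toy_inputs)

# calculate losses
loss_vals = {}
total_loss = 0
for key, crit_fn in losses.items():
    loss_val = crit_fn(preds, toy_labels)
    loss_vals[key] = loss_val.item()    # plain float, e.g. for logging
    total_loss = total_loss + loss_val  # tensor used for backpropagation

optimizer.zero_grad()   # clear gradients left over from a previous step
total_loss.backward()   # backpropagate the summed loss
optimizer.step()        # update the parameters

weights_changed = not torch.equal(weights_before, toy_model.weight.detach())
print(loss_vals, weights_changed)
```

Summing all entries of `losses` before calling `backward` means that several loss functions can be combined simply by adding them to the dictionary.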


## Training
Now that we have defined our network, we can finally specify our experiment and run it.

In [None]:
import warnings
warnings.simplefilter("ignore", UserWarning) # ignore UserWarnings raised by dependency code
warnings.simplefilter("ignore", FutureWarning) # ignore FutureWarnings raised by dependency code


from delira.training import PyTorchExperiment
from delira.training.backends import create_pytorch_optims_default

if logger is not None:
    logger.info("Init Experiment")
experiment = PyTorchExperiment(params, SmallVGGPyTorch,
                               name="ClassificationExample",
                               save_path="./tmp/delira_Experiments",
                               optim_builder=create_pytorch_optims_default,
                               key_mapping={"x": "data"},
                               gpu_ids=[0])
experiment.save()

model = experiment.run(manager_train, manager_val)

Congratulations, you have now trained your first classification model using `delira`. We will now predict a few samples from the validation set to show that the network's predictions are valid (for now, this is done manually, but we also have a `Predictor` class to automate tasks like this):

The accuracy is pretty low because our dataset only consists of random data.

In [None]:
import numpy as np
from tqdm.auto import tqdm # utility for progress bars

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # set device (use GPU if available)
model = model.to(device) # push model to device
preds, labels = [], []

with torch.no_grad():
    for i in tqdm(range(len(dataset_val))):
        img = dataset_val[i]["data"] # get image of the current sample
        img_tensor = torch.from_numpy(img).unsqueeze(0).to(device).to(torch.float) # create a tensor from image, push it to device and add batch dimension
        pred_tensor = model(img_tensor) # feed it through the network
        pred = pred_tensor["pred"].argmax(1).item() # get index with maximum class confidence
        label = np.asarray(dataset_val[i]["label"]).item() # get label of the current sample (np.asscalar was removed in NumPy 1.23)
        if i % 10 == 0: # print every 10th sample (the validation set only has 100 samples)
            print("Prediction: %d \t label: %d" % (pred, label)) # print result
        preds.append(pred)
        labels.append(label)
        
# calculate accuracy
accuracy = (np.asarray(preds) == np.asarray(labels)).sum() / len(preds)
print("Accuracy: %.3f" % accuracy)
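
As a point of reference for the accuracy printed above, the following sketch simulates pure guessing. It assumes that labels are spread roughly uniformly over the 10 classes, which may or may not hold for `ClassificationFakeData`. The names `random_labels` and `random_guesses` are placeholders.

```python
import random

rng = random.Random(0)
num_classes, num_samples = 10, 10000

# labels and predictions drawn independently and uniformly at random
random_labels = [rng.randrange(num_classes) for _ in range(num_samples)]
random_guesses = [rng.randrange(num_classes) for _ in range(num_samples)]

num_correct = sum(g == l for g, l in zip(random_guesses, random_labels))
chance_accuracy = num_correct / num_samples
print("Chance-level accuracy: %.3f" % chance_accuracy)
```

Under this assumption, an accuracy close to 1 / 10 on random data is what one would expect from a network that cannot learn any real structure.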