# Model training using trainers
This example covers the same model-training use case as before, but now shows how to use a Trainer.

In [1]:
#!pip3 install torchvision

In [2]:
import cascade.data as cdd
import cascade.models as cdm
from cascade.utils.torch import TorchModel
from cascade.utils.sklearn import SkMetric

from tqdm import tqdm
import torch
import torchvision
from torchvision.transforms import functional as F
from torch import nn

  warn(f"Failed to load image Python extension: {e}")


In [3]:
import cascade
cascade.__version__

'0.13.0'

## Defining data pipeline
This part is given without commentary. For a more detailed explanation, please see the [pipeline building example](https://oxid15.github.io/cascade/examples/pipeline_building.html).

In [4]:
# Folder passed as `root` to torchvision's MNIST datasets (files are downloaded there)
MNIST_ROOT = "data"
# Length of one flattened image; batches are reshaped to (-1, INPUT_SIZE) before the forward pass
INPUT_SIZE = 28 * 28
# Number of samples per DataLoader batch
BATCH_SIZE = 10

In [5]:
class NoiseModifier(cdd.Modifier):
    """Dataset modifier that adds uniform random noise to every image it returns."""

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair at ``index`` with noise added to the image.

        Parameters
        ----------
        index
            Index forwarded unchanged to the wrapped dataset.

        Returns
        -------
        tuple
            ``(img, label)`` where ``img`` is a new tensor equal to the source
            image plus noise drawn from ``[0, 0.1)``, clipped to ``[0, 1]``.
        """
        img, label = self._dataset[index]
        # Out-of-place addition: the tensor handed out by the wrapped dataset
        # is left untouched, so noise can never accumulate in the source data.
        img = img + torch.rand_like(img) * 0.1
        # The images in this notebook come from F.to_tensor, which scales
        # pixels to [0, 1]; clip back to that range (an upper bound of 255
        # would never trigger).
        img = torch.clip(img, 0.0, 1.0)
        return img, label


# Raw MNIST splits. The train split is created first because it is the one
# that is asked to download the files.
train_ds = torchvision.datasets.MNIST(
    root=MNIST_ROOT, train=True, transform=F.to_tensor, download=True
)
test_ds = torchvision.datasets.MNIST(
    root=MNIST_ROOT, train=False, transform=F.to_tensor
)

# Train pipeline: Wrapper (with a description) -> NoiseModifier -> CyclicSampler
train_ds = cdd.Wrapper(train_ds)
train_ds.describe("This is MNIST dataset of handwritten images, TRAIN PART")
train_ds = NoiseModifier(train_ds)
# We constrain the number of samples to speed up training in this example
train_ds = cdd.CyclicSampler(train_ds, 10000)

# Test pipeline: the same stages, without a description
test_ds = cdd.CyclicSampler(NoiseModifier(cdd.Wrapper(test_ds)), 5000)

# Batch loaders; only the training data is shuffled
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

In [6]:
# Show the metadata of the train pipeline; the output lists one entry per stage
train_ds.get_meta()

[{'name': 'cascade.data.cyclic_sampler.CyclicSampler',
  'description': None,
  'tags': [],
  'comments': [],
  'links': [],
  'type': 'dataset',
  'len': 10000},
 {'name': '__main__.NoiseModifier',
  'description': None,
  'tags': [],
  'comments': [],
  'links': [],
  'type': 'dataset',
  'len': 60000},
 {'name': 'cascade.data.dataset.Wrapper',
  'description': 'This is MNIST dataset of handwritten images, TRAIN PART',
  'tags': [],
  'comments': [],
  'links': [],
  'type': 'dataset',
  'len': 60000,
  'obj_type': "<class 'torchvision.datasets.mnist.MNIST'>"}]

## Model definition
Before training, we need to define our model. We need a regular `nn.Module` and Cascade's wrapper around it.  
  
The module is defined exactly as in ordinary PyTorch code, except that it now accepts `*args` and `**kwargs` in `__init__`.

In [7]:
class SimpleNN(nn.Module):
    """Fully connected network with one hidden layer: Linear -> ReLU -> Linear."""

    def __init__(self, input_size, hidden_size, num_classes, *args, **kwargs):
        """Create the layers.

        Parameters
        ----------
        input_size
            Number of features in each input vector.
        hidden_size
            Number of units in the hidden layer.
        num_classes
            Number of outputs of the last layer.
        *args, **kwargs
            Accepted and ignored by this module.
        """
        super().__init__()

        # Stored as plain attributes so they can be read back from the module
        self.input_size, self.hidden_size = input_size, hidden_size

        self.l1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.l2 = nn.Linear(in_features=hidden_size, out_features=num_classes)
        self.relu = nn.ReLU()

    def forward(self, y):
        """Map a batch ``y`` of input vectors to the raw outputs of the last layer."""
        hidden = self.relu(self.l1(y))
        return self.l2(hidden)

Next, Cascade's wrapper is defined. Most of the interaction with PyTorch modules is already implemented in `cascade.utils.torch.TorchModel`, so we only need to define how to train and evaluate the model.  
  
The difference between the previous example and this one is in the `fit` function: it now fits only one epoch per call and needs no additional logging, because the Trainer covers this functionality.

In [8]:
class Classifier(TorchModel):
    """Cascade wrapper that trains and evaluates the wrapped torch module.

    The wrapped module is reached through ``self._model`` and is expected to
    expose an ``input_size`` attribute, which is used to flatten image batches.
    """

    def fit(self, train_dl, lr, *args, **kwargs):
        """Run one pass over ``train_dl``, updating the wrapped module.

        This is a regular PyTorch training loop applied to ``self._model``.

        Parameters
        ----------
        train_dl
            Iterable yielding ``(imgs, labels)`` batches.
        lr
            Learning rate given to the Adam optimizer.
        *args, **kwargs
            Accepted and ignored.
        """
        criterion = nn.CrossEntropyLoss()
        # NOTE: the optimizer is created on every call, so Adam's running
        # statistics start from scratch each time fit() is invoked.
        optim = torch.optim.Adam(self._model.parameters(), lr=lr)

        # evaluate() puts the module into eval mode, so switch back explicitly
        self._model.train()
        for imgs, labels in train_dl:
            imgs = imgs.reshape(-1, self._model.input_size)

            out = self._model(imgs)
            loss = criterion(out, labels)

            optim.zero_grad()
            loss.backward()
            optim.step()

    def evaluate(self, test_dl, metrics, *args, **kwargs):
        """Compute ``metrics`` on ``test_dl`` and attach them to the model.

        Nothing is returned: every metric is computed on the collected
        ground truth and predictions and then registered via ``self.add_metric``.

        Parameters
        ----------
        test_dl
            Iterable yielding ``(imgs, labels)`` batches.
        metrics
            Iterable of metric objects exposing ``compute(gt, pred)``.
        *args, **kwargs
            Forwarded to the wrapped module's forward call.
        """
        pred = []
        gt = []

        self._model.eval()
        # Gradients are not needed for inference; skipping them avoids
        # building an autograd graph for every batch.
        with torch.no_grad():
            for imgs, labels in tqdm(test_dl):
                imgs = imgs.reshape(-1, self._model.input_size)
                out = torch.argmax(self._model(imgs, *args, **kwargs), -1)

                pred.append(out)
                gt.append(labels)

        pred = torch.concat(pred).numpy()
        gt = torch.concat(gt).numpy()

        for metric in metrics:
            metric.compute(gt, pred)
            self.add_metric(metric)

### Model initialization

In [9]:
NUM_EPOCHS = 5
LR = 0.001

# Classifier instantiates SimpleNN with every argument given below.
# Some of them are not used by SimpleNN and are passed only to be recorded in metadata.
model = Classifier(
    SimpleNN,
    # Required by SimpleNN; passed as keywords so that they are recorded in meta
    input_size=INPUT_SIZE,
    hidden_size=100,
    num_classes=10,
    # Ignored by SimpleNN, but added to meta
    num_epochs=NUM_EPOCHS,
    lr=LR,
    bs=BATCH_SIZE,
)

## Set up trainer
Let's set up logging first, so that the trainer's log messages are shown.

In [10]:
import logging
import sys

# Route log records of level INFO and above to stdout so they appear in the cell output
logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout)])

In [18]:
# The trainer accepts a ModelRepo object or just a path to a repo;
# here a path is passed
trainer = cdm.BasicTrainer('trainer_repo')

In [19]:
# train() is the trainer's main method: it takes care of
# training, evaluating, saving and logging for us
trainer.train(
    model,
    train_data=train_dl,
    test_data=test_dl,
    epochs=NUM_EPOCHS,
    train_kwargs=dict(lr=LR, bs=BATCH_SIZE),  # will be passed into model.fit()
    test_kwargs=dict(metrics=[SkMetric("accuracy_score")]),  # will be passed into model.evaluate()
    start_from=None,  # give a line name here to start from a checkpoint
    # NOTE(review): presumably "save / evaluate every N epochs" - confirm in the BasicTrainer docs
    save_strategy=2,
    eval_strategy=1,
)

INFO:cascade.models.trainer:Training started with parameters:
{'lr': 0.001, 'bs': 10}
INFO:cascade.models.trainer:repo is ModelRepo in trainer_repo of 1 lines
INFO:cascade.models.trainer:line is 00000
INFO:cascade.models.trainer:training will last 10 epochs


100%|██████████| 500/500 [00:05<00:00, 89.62it/s]

INFO:cascade.models.trainer:Epoch: 0





INFO:cascade.models.trainer:SkMetric(name=accuracy_score, value=0.9348, created_at=2023-11-06T14:21:51.679983+00:00)


100%|██████████| 500/500 [00:05<00:00, 91.46it/s] 

INFO:cascade.models.trainer:Epoch: 1
INFO:cascade.models.trainer:SkMetric(name=accuracy_score, value=0.9376, created_at=2023-11-06T14:21:51.679983+00:00)



100%|██████████| 500/500 [00:10<00:00, 46.28it/s]


INFO:cascade.models.trainer:Epoch: 2
INFO:cascade.models.trainer:SkMetric(name=accuracy_score, value=0.9378, created_at=2023-11-06T14:21:51.679983+00:00)


100%|██████████| 500/500 [00:08<00:00, 57.03it/s]


INFO:cascade.models.trainer:Epoch: 3
INFO:cascade.models.trainer:SkMetric(name=accuracy_score, value=0.9338, created_at=2023-11-06T14:21:51.679983+00:00)


100%|██████████| 500/500 [00:06<00:00, 72.69it/s]


INFO:cascade.models.trainer:Epoch: 4
INFO:cascade.models.trainer:SkMetric(name=accuracy_score, value=0.9336, created_at=2023-11-06T14:21:51.679983+00:00)


100%|██████████| 500/500 [00:07<00:00, 68.04it/s]

INFO:cascade.models.trainer:Epoch: 5





INFO:cascade.models.trainer:SkMetric(name=accuracy_score, value=0.937, created_at=2023-11-06T14:21:51.679983+00:00)


100%|██████████| 500/500 [00:08<00:00, 59.83it/s]


INFO:cascade.models.trainer:Epoch: 6
INFO:cascade.models.trainer:SkMetric(name=accuracy_score, value=0.9416, created_at=2023-11-06T14:21:51.679983+00:00)


100%|██████████| 500/500 [00:06<00:00, 73.83it/s]

INFO:cascade.models.trainer:Epoch: 7
INFO:cascade.models.trainer:SkMetric(name=accuracy_score, value=0.9256, created_at=2023-11-06T14:21:51.679983+00:00)



100%|██████████| 500/500 [00:11<00:00, 43.62it/s]

INFO:cascade.models.trainer:Epoch: 8





INFO:cascade.models.trainer:SkMetric(name=accuracy_score, value=0.9396, created_at=2023-11-06T14:21:51.679983+00:00)


100%|██████████| 500/500 [00:07<00:00, 69.91it/s]

INFO:cascade.models.trainer:Epoch: 9





INFO:cascade.models.trainer:SkMetric(name=accuracy_score, value=0.9394, created_at=2023-11-06T14:21:51.679983+00:00)
INFO:cascade.models.trainer:Training finished in 4 minutes


## Results
We can obtain the results of training from the trainer's metadata.

In [20]:
# Show the trainer's metadata; the output includes the metrics recorded during training
trainer.get_meta()

[{'name': 'cascade.models.trainer.BasicTrainer',
  'description': None,
  'tags': [],
  'comments': [],
  'links': [],
  'metrics': [[SkMetric(name=accuracy_score, value=0.9394, created_at=2023-11-06T14:21:51.679983+00:00)],
   [SkMetric(name=accuracy_score, value=0.9394, created_at=2023-11-06T14:21:51.679983+00:00)],
   [SkMetric(name=accuracy_score, value=0.9394, created_at=2023-11-06T14:21:51.679983+00:00)],
   [SkMetric(name=accuracy_score, value=0.9394, created_at=2023-11-06T14:21:51.679983+00:00)],
   [SkMetric(name=accuracy_score, value=0.9394, created_at=2023-11-06T14:21:51.679983+00:00)],
   [SkMetric(name=accuracy_score, value=0.9394, created_at=2023-11-06T14:21:51.679983+00:00)],
   [SkMetric(name=accuracy_score, value=0.9394, created_at=2023-11-06T14:21:51.679983+00:00)],
   [SkMetric(name=accuracy_score, value=0.9394, created_at=2023-11-06T14:21:51.679983+00:00)],
   [SkMetric(name=accuracy_score, value=0.9394, created_at=2023-11-06T14:21:51.679983+00:00)],
   [SkMetric(na

## Start from checkpoint
Let's try to continue training from where we finished, using the same line as before.

In [22]:
# Same call as before, but resuming from the existing line '00000'
trainer.train(
    model,
    train_data=train_dl,
    test_data=test_dl,
    epochs=5,
    train_kwargs=dict(lr=LR, bs=BATCH_SIZE),  # will be passed into model.fit()
    test_kwargs=dict(metrics=[SkMetric("accuracy_score")]),  # will be passed into model.evaluate()
    start_from='00000',  # name of the line to continue from
    # NOTE(review): presumably "save / evaluate every N epochs" - confirm in the BasicTrainer docs
    save_strategy=4,
    eval_strategy=1,
)

Error when loading an artifact - c:\cascade\cascade\docs\source\examples\trainer_repo\00000\00010\artifacts is not a folder
Error when loading an artifact - c:\cascade\cascade\docs\source\examples\trainer_repo\00000\00009\artifacts is not a folder
INFO:cascade.models.trainer:Training started with parameters:
{'lr': 0.001, 'bs': 10}
INFO:cascade.models.trainer:repo is ModelRepo in trainer_repo of 1 lines
INFO:cascade.models.trainer:line is 00000
INFO:cascade.models.trainer:started from model 8
INFO:cascade.models.trainer:training will last 5 epochs


100%|██████████| 500/500 [00:13<00:00, 36.09it/s]


INFO:cascade.models.trainer:Epoch: 0
INFO:cascade.models.trainer:SkMetric(name=accuracy_score, value=0.941, created_at=2023-11-06T14:27:45.955057+00:00)


100%|██████████| 500/500 [00:07<00:00, 65.59it/s]

INFO:cascade.models.trainer:Epoch: 1
INFO:cascade.models.trainer:SkMetric(name=accuracy_score, value=0.936, created_at=2023-11-06T14:27:45.955057+00:00)



100%|██████████| 500/500 [00:21<00:00, 23.12it/s]


INFO:cascade.models.trainer:Epoch: 2
INFO:cascade.models.trainer:SkMetric(name=accuracy_score, value=0.936, created_at=2023-11-06T14:27:45.955057+00:00)


100%|██████████| 500/500 [00:12<00:00, 40.11it/s]


INFO:cascade.models.trainer:Epoch: 3
INFO:cascade.models.trainer:SkMetric(name=accuracy_score, value=0.941, created_at=2023-11-06T14:27:45.955057+00:00)


100%|██████████| 500/500 [00:11<00:00, 45.25it/s]


INFO:cascade.models.trainer:Epoch: 4
INFO:cascade.models.trainer:SkMetric(name=accuracy_score, value=0.9364, created_at=2023-11-06T14:27:45.955057+00:00)
INFO:cascade.models.trainer:Training finished in 3 minutes


In [23]:
# Metrics held by the trainer; the output shows entries from both training runs
trainer.metrics

[[SkMetric(name=accuracy_score, value=0.9394, created_at=2023-11-06T14:21:51.679983+00:00)],
 [SkMetric(name=accuracy_score, value=0.9394, created_at=2023-11-06T14:21:51.679983+00:00)],
 [SkMetric(name=accuracy_score, value=0.9394, created_at=2023-11-06T14:21:51.679983+00:00)],
 [SkMetric(name=accuracy_score, value=0.9394, created_at=2023-11-06T14:21:51.679983+00:00)],
 [SkMetric(name=accuracy_score, value=0.9394, created_at=2023-11-06T14:21:51.679983+00:00)],
 [SkMetric(name=accuracy_score, value=0.9394, created_at=2023-11-06T14:21:51.679983+00:00)],
 [SkMetric(name=accuracy_score, value=0.9394, created_at=2023-11-06T14:21:51.679983+00:00)],
 [SkMetric(name=accuracy_score, value=0.9394, created_at=2023-11-06T14:21:51.679983+00:00)],
 [SkMetric(name=accuracy_score, value=0.9394, created_at=2023-11-06T14:21:51.679983+00:00)],
 [SkMetric(name=accuracy_score, value=0.9394, created_at=2023-11-06T14:21:51.679983+00:00)],
 [SkMetric(name=accuracy_score, value=0.9364, created_at=2023-11-06T14

## See also:
- [Pipeline building](pipeline_building.html)