# Model training using trainers
This use-case is model training - the same, but now the usage of Trainer will be shown.

In [4]:
# !pip3 install torchvision
# !pip3 install scikit-learn

Collecting scikit-learn
  Downloading scikit_learn-1.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.3/13.3 MB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0mm
[?25hCollecting scipy>=1.6.0
  Downloading scipy-1.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (38.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m38.6/38.6 MB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting threadpoolctl>=3.1.0
  Downloading threadpoolctl-3.5.0-py3-none-any.whl (18 kB)
Installing collected packages: threadpoolctl, scipy, scikit-learn
Successfully installed scikit-learn-1.5.0 scipy-1.13.1 threadpoolctl-3.5.0


In [5]:
import cascade.data as cdd
import cascade.models as cdm
from cascade.utils.torch import TorchModel
from cascade.utils.sklearn import SkMetric

from tqdm import tqdm
import torch
import torchvision
from torchvision.transforms import functional as F
from torch import nn

In [6]:
import cascade
cascade.__version__

'0.14.0-alpha'

## Defining data pipeline
This part will be without comments. For more detailed explanations, please see [pipeline building example](https://oxid15.github.io/cascade/examples/pipeline_building.html)

In [7]:
MNIST_ROOT = 'data'
INPUT_SIZE = 784
BATCH_SIZE = 10

In [8]:
class NoiseModifier(cdd.Modifier):
    def __getitem__(self, index):
        img, label = self._dataset[index]
        img += torch.rand_like(img) * 0.1
        img = torch.clip(img, 0, 255)
        return img, label


train_ds = torchvision.datasets.MNIST(root=MNIST_ROOT,
                                     train=True, 
                                     transform=F.to_tensor,
                                     download=True)
test_ds = torchvision.datasets.MNIST(root=MNIST_ROOT, 
                                    train=False, 
                                    transform=F.to_tensor)

train_ds = cdd.Wrapper(train_ds)
train_ds.describe("This is MNIST dataset of handwritten images, TRAIN PART")
test_ds = cdd.Wrapper(test_ds)

train_ds = NoiseModifier(train_ds)
test_ds = NoiseModifier(test_ds)

# We will constraint the number of samples to speed up learning in example
train_ds = cdd.CyclicSampler(train_ds, 10000)
test_ds = cdd.CyclicSampler(test_ds, 5000)

train_dl = torch.utils.data.DataLoader(dataset=train_ds, 
                                       batch_size=BATCH_SIZE,
                                       shuffle=True)
test_dl = torch.utils.data.DataLoader(dataset=test_ds,
                                      batch_size=BATCH_SIZE,
                                      shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:11<00:00, 840403.58it/s] 


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 585963.72it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 2002218.63it/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 760469.79it/s]


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



In [9]:
train_ds.get_meta()

[{'name': 'cascade.data.cyclic_sampler.CyclicSampler',
  'description': None,
  'tags': [],
  'comments': [],
  'links': [],
  'type': 'dataset',
  'len': 10000},
 {'name': '__main__.NoiseModifier',
  'description': None,
  'tags': [],
  'comments': [],
  'links': [],
  'type': 'dataset',
  'len': 60000},
 {'name': 'cascade.data.dataset.Wrapper',
  'description': 'This is MNIST dataset of handwritten images, TRAIN PART',
  'tags': [],
  'comments': [],
  'links': [],
  'type': 'dataset',
  'len': 60000,
  'obj_type': "<class 'torchvision.datasets.mnist.MNIST'>"}]

## Model definition
Before training we need to define our model. We need regular nn.Module and Cascade's wrapper around it.  
  
Module defined without any specific changes in the original pytorch code, except now it accepts `*args` and `**kwargs` in `__init__`.

In [10]:
class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes, *args, **kwargs):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.l1 = nn.Linear(input_size, hidden_size)
        self.l2 = nn.Linear(hidden_size, num_classes)
        self.relu = nn.ReLU()

    def forward(self, y):
         out = self.l1(y)
         out = self.relu(out)
         out = self.l2(out)

         return out

Next Cascade's wrapper is defined. The most of the interactions with pytorch modules is already implemented in `cascade.utils.TorchModel` so we need to only define how to train and evaluate this model.  
  
The difference between previous example and this one is in the `fit` function - now it only fits one epoch per call and doesn't need additional logging - Trainer will cover this functionality.

In [11]:
class Classifier(TorchModel):
    # In train we copy-paste regular pytorch trainloop, 
    # but use self._model, where our SimpleNN is placed
    def fit(self, train_dl, lr, *args, **kwargs):
        criterion = nn.CrossEntropyLoss()
        optim = torch.optim.Adam(self._model.parameters(), lr=lr)

        ds_size = len(train_dl)
        for x, (imgs, labels) in enumerate(train_dl): 
            imgs = imgs.reshape(-1, self._model.input_size)

            out = self._model(imgs)
            loss = criterion(out, labels)

            optim.zero_grad()
            loss.backward()
            optim.step() 


    # Evaluate function takes the metrics from arguments
    # and populates self.metrics without returning anything
    def evaluate(self, test_dl, metrics, *args, **kwargs):
        pred = []
        gt = []
        for imgs, labels in tqdm(test_dl): 
            imgs = imgs.reshape(-1, self._model.input_size)
            out = torch.argmax(self._model(imgs, *args, **kwargs), -1)

            pred.append(out)
            gt.append(labels)

        pred = torch.concat(pred).detach().numpy()
        gt = torch.concat(gt).detach().numpy()

        for metric in metrics:
            metric.compute(gt, pred)
            self.add_metric(metric)

### Model initialization

In [12]:
NUM_EPOCHS = 5
LR = 1e-3

# Classifier will initialize SimpleNN with all the parameters passed
# but some of them are not for the SimpleNN, but to be recorded in metadata
model = Classifier(SimpleNN,
    # These arguments are needed by SimpleNN, 
    # but passed as keywords to be recorded in meta
    input_size=INPUT_SIZE,
    hidden_size=100,
    num_classes=10,
    # These arguments will be skipped by SimpleNN,
    # but will be added to meta
    num_epochs=NUM_EPOCHS,
    lr=LR,
    bs=BATCH_SIZE)

## Set up trainer
Let's set up logging first to catch trainer's logs

In [13]:
import sys
import logging
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    level='INFO'
)

In [14]:
# Trainer accepts ModelRepo object or just a path 
trainer = cdm.BasicTrainer('trainer_repo')

In [15]:
# The main method of course is train
# It will do all the stuff needed for us
# including training, evaluating, saving and logging
trainer.train(
    model,
    train_data=train_dl,
    test_data=test_dl,
    train_kwargs={'lr': LR, 'bs': BATCH_SIZE}, # will be passed into model.fit()
    test_kwargs={"metrics": [SkMetric("accuracy_score")]}, # will be passed into model.evaluate()
    epochs=NUM_EPOCHS,
    start_from=None, # can start from checkpoint if line is specified,
    save_strategy=2,
    eval_strategy=1
)

INFO:cascade.models.trainer:Training started with parameters:
{'lr': 0.001, 'bs': 10}
INFO:cascade.models.trainer:repo is ModelRepo in trainer_repo of 1 lines
INFO:cascade.models.trainer:line is 00000
INFO:cascade.models.trainer:training will last 5 epochs


100%|██████████| 500/500 [00:07<00:00, 66.67it/s]

INFO:cascade.models.trainer:Epoch: 0
INFO:cascade.models.trainer:SkMetric(name=accuracy_score, value=0.884, created_at=2024-06-12 14:52:57.222170+00:00)



100%|██████████| 500/500 [00:06<00:00, 75.63it/s] 

INFO:cascade.models.trainer:Epoch: 1
INFO:cascade.models.trainer:SkMetric(name=accuracy_score, value=0.8878, created_at=2024-06-12 14:52:57.222170+00:00)



100%|██████████| 500/500 [00:05<00:00, 85.85it/s] 

INFO:cascade.models.trainer:Epoch: 2
INFO:cascade.models.trainer:SkMetric(name=accuracy_score, value=0.9066, created_at=2024-06-12 14:52:57.222170+00:00)



100%|██████████| 500/500 [00:06<00:00, 72.60it/s]

INFO:cascade.models.trainer:Epoch: 3
INFO:cascade.models.trainer:SkMetric(name=accuracy_score, value=0.9148, created_at=2024-06-12 14:52:57.222170+00:00)



100%|██████████| 500/500 [00:06<00:00, 82.80it/s] 

INFO:cascade.models.trainer:Epoch: 4
INFO:cascade.models.trainer:SkMetric(name=accuracy_score, value=0.92, created_at=2024-06-12 14:52:57.222170+00:00)
INFO:cascade.models.trainer:Training finished in 2 minutes
INFO:cascade.models.trainer:repo was ModelRepo in trainer_repo of 1 lines
INFO:cascade.models.trainer:line was 00000
INFO:cascade.models.trainer:training ended on 4 epoch
INFO:cascade.models.trainer:Parameters:
{'lr': 0.001, 'bs': 10}
INFO:cascade.models.trainer:Metrics:
INFO:cascade.models.trainer:SkMetric(name=accuracy_score, value=0.92, created_at=2024-06-12 14:52:57.222170+00:00)





## Results
We can obtain the results of training from trainer's meta data.

In [16]:
trainer.get_meta()

[{'name': 'cascade.models.trainer.BasicTrainer',
  'description': None,
  'tags': [],
  'comments': [],
  'links': [],
  'metrics': [[SkMetric(name=accuracy_score, value=0.92, created_at=2024-06-12 14:52:57.222170+00:00)],
   [SkMetric(name=accuracy_score, value=0.92, created_at=2024-06-12 14:52:57.222170+00:00)],
   [SkMetric(name=accuracy_score, value=0.92, created_at=2024-06-12 14:52:57.222170+00:00)],
   [SkMetric(name=accuracy_score, value=0.92, created_at=2024-06-12 14:52:57.222170+00:00)],
   [SkMetric(name=accuracy_score, value=0.92, created_at=2024-06-12 14:52:57.222170+00:00)]],
  'repo': [{'name': 'ModelRepo in trainer_repo of 1 lines',
    'description': None,
    'tags': [],
    'comments': [],
    'links': [],
    'updated_at': '2024-06-12 14:56:31.085887+00:00',
    'root': 'trainer_repo',
    'len': 1,
    'type': 'repo',
    'cascade_version': '0.14.0-alpha',
    'created_at': '2024-06-12 14:52:55.415058+00:00'}],
  'training_started_at': DateTime(2024, 6, 12, 17, 52, 

## Start from checkpoint
Let's try continue learning where we finished using the same line as before.

In [17]:
trainer.train(
    model,
    train_data=train_dl,
    test_data=test_dl,
    train_kwargs={'lr': LR, 'bs': BATCH_SIZE},
    test_kwargs={'metrics': [SkMetric("accuracy_score")]},
    epochs=5,
    start_from='00000',
    save_strategy=4,
    eval_strategy=1
)

INFO:cascade.models.trainer:Training started with parameters:
{'lr': 0.001, 'bs': 10}
INFO:cascade.models.trainer:repo is ModelRepo in trainer_repo of 1 lines
INFO:cascade.models.trainer:line is 00000
INFO:cascade.models.trainer:started from model 4
INFO:cascade.models.trainer:training will last 5 epochs


100%|██████████| 500/500 [00:09<00:00, 50.05it/s]

INFO:cascade.models.trainer:Epoch: 0
INFO:cascade.models.trainer:SkMetric(name=accuracy_score, value=0.9246, created_at=2024-06-12 14:56:48.094497+00:00)



100%|██████████| 500/500 [00:06<00:00, 75.38it/s] 

INFO:cascade.models.trainer:Epoch: 1
INFO:cascade.models.trainer:SkMetric(name=accuracy_score, value=0.9282, created_at=2024-06-12 14:56:48.094497+00:00)



100%|██████████| 500/500 [00:08<00:00, 60.97it/s]

INFO:cascade.models.trainer:Epoch: 2
INFO:cascade.models.trainer:SkMetric(name=accuracy_score, value=0.928, created_at=2024-06-12 14:56:48.094497+00:00)



100%|██████████| 500/500 [00:10<00:00, 45.93it/s]

INFO:cascade.models.trainer:Epoch: 3





INFO:cascade.models.trainer:SkMetric(name=accuracy_score, value=0.9244, created_at=2024-06-12 14:56:48.094497+00:00)


100%|██████████| 500/500 [00:06<00:00, 77.76it/s]


INFO:cascade.models.trainer:Epoch: 4
INFO:cascade.models.trainer:SkMetric(name=accuracy_score, value=0.9298, created_at=2024-06-12 14:56:48.094497+00:00)
INFO:cascade.models.trainer:Training finished in 2 minutes
INFO:cascade.models.trainer:repo was ModelRepo in trainer_repo of 1 lines
INFO:cascade.models.trainer:line was 00000
INFO:cascade.models.trainer:started from model 4
INFO:cascade.models.trainer:training ended on 4 epoch
INFO:cascade.models.trainer:Parameters:
{'lr': 0.001, 'bs': 10}
INFO:cascade.models.trainer:Metrics:
INFO:cascade.models.trainer:SkMetric(name=accuracy_score, value=0.9298, created_at=2024-06-12 14:56:48.094497+00:00)


In [18]:
trainer.metrics

[[SkMetric(name=accuracy_score, value=0.92, created_at=2024-06-12 14:52:57.222170+00:00)],
 [SkMetric(name=accuracy_score, value=0.92, created_at=2024-06-12 14:52:57.222170+00:00)],
 [SkMetric(name=accuracy_score, value=0.92, created_at=2024-06-12 14:52:57.222170+00:00)],
 [SkMetric(name=accuracy_score, value=0.92, created_at=2024-06-12 14:52:57.222170+00:00)],
 [SkMetric(name=accuracy_score, value=0.92, created_at=2024-06-12 14:52:57.222170+00:00)],
 [SkMetric(name=accuracy_score, value=0.9298, created_at=2024-06-12 14:56:48.094497+00:00)],
 [SkMetric(name=accuracy_score, value=0.9298, created_at=2024-06-12 14:56:48.094497+00:00)],
 [SkMetric(name=accuracy_score, value=0.9298, created_at=2024-06-12 14:56:48.094497+00:00)],
 [SkMetric(name=accuracy_score, value=0.9298, created_at=2024-06-12 14:56:48.094497+00:00)],
 [SkMetric(name=accuracy_score, value=0.9298, created_at=2024-06-12 14:56:48.094497+00:00)]]

## See also:
- [Pipeline building](pipeline_building.html)