In [None]:
%%capture
! pip install pytorch-lightning

In [None]:
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

In [None]:
# The raw MNIST dataset yields single (PIL image, int label) samples. ToTensor
# converts every image to a float tensor so samples can be stacked into batches.
train_dataset = MNIST('', download=True, transform=ToTensor())

# The DataLoader is what a training loop iterates over: it groups samples into
# (images, labels) tensor batches, which is what `training_step` unpacks.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Single linear layer mapping a flattened 28x28 image to 10 class scores.
encoder = nn.Linear(28*28, 10)

optimiser = Adam(encoder.parameters())

In [None]:
# Bare expression as the last line of the cell: the notebook renders its repr.
train_loader

We want to be able to do something like this with the `Trainer` class:
```
cb = NormCallback()
trainer = Trainer(max_epochs=10, callbacks=[cb])
trainer.fit(encoder, train_loader)
```

How should the `Trainer` class be defined?

In [None]:
from torch.nn.functional import cross_entropy

class Trainer():
    """Minimal hand-written training loop with model hooks and callbacks.

    The model passed to :meth:`fit` must provide two hooks:

    * ``on_batch_start(batch)`` -- receives the raw batch and returns the batch
      that will actually be trained on;
    * ``training_step(batch, batch_idx)`` -- computes the loss **and** calls
      ``loss.backward()`` itself; the trainer only steps the optimiser.

    Every callback must provide ``on_batch_start(trainer, model)``.
    """

    def __init__(self, max_epochs, callbacks=None):
        """Store the epoch budget and the callbacks to notify.

        Args:
            max_epochs: number of full passes over the training loader.
            callbacks: optional iterable of callback objects; ``None`` means
                no callbacks.
        """
        self.max_epochs = max_epochs
        # Always define the attribute (previously it was only set when no
        # callbacks were given). Copy so the caller's list is not shared.
        self.callbacks = [] if callbacks is None else list(callbacks)

    def fit(self, model, train_loader, optimiser=None):
        """Train ``model`` on ``train_loader`` for ``max_epochs`` epochs.

        Args:
            model: object exposing ``on_batch_start`` and ``training_step``.
            train_loader: iterable yielding batches.
            optimiser: optional optimiser exposing ``step()`` and
                ``zero_grad()``. When omitted, ``Adam(model.parameters())``
                is created.
        """
        self._train(model, train_loader, optimiser)

    def on_batch_start(self, model=None):
        """Notify every registered callback that a batch is about to start.

        Args:
            model: the model being trained; forwarded to the callbacks.
        """
        for cb in self.callbacks:
            cb.on_batch_start(self, model)

    def _train(self, model, train_loader, optimiser=None):
        """Run the epoch/batch loop. See :meth:`fit` for the arguments."""
        if optimiser is None:
            # Optimise the parameters of the model actually being trained
            # instead of relying on a notebook-global optimiser.
            optimiser = Adam(model.parameters())

        for epoch in range(self.max_epochs):
            for batch_idx, batch in enumerate(train_loader):

                # Callbacks are notified first ...
                self.on_batch_start(model)

                # ... then the "on_batch_start" hook, which can be overridden in
                #   the model class passed to .fit(), may replace the batch.
                batch = model.on_batch_start(batch)

                # Hook "training_step", which can be overridden in the model
                #   class passed to .fit(). It is expected to have called
                #   loss.backward() before returning.
                model.training_step(batch, batch_idx)

                optimiser.step()
                optimiser.zero_grad()
    

In [None]:
class Model(pl.LightningModule):
    """Single linear layer mapping flattened 28x28 inputs to 10 class logits."""

    def __init__(self):
        # The nn.Module machinery must be initialised before any sub-module is
        # assigned; without this call `self.encoder = ...` raises.
        super().__init__()
        self.encoder = nn.Linear(28*28, 10)

    def forward(self, x):
        """Return the logits for an already flattened input ``x``."""
        # for inference on predictions
        return self.encoder(x)

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one batch and backpropagate it.

        Args:
            batch: pair ``(x, y)`` of inputs and targets.
            batch_idx: position of the batch within the epoch (unused here).

        Returns:
            The loss tensor, after ``backward()`` has been called on it.
        """
        x, y = batch
        # Flatten everything after the first (batch) dimension.
        x = x.view(x.size(0), -1)
        logits = self.encoder(x)
        loss = cross_entropy(logits, y)
        # Backpropagation happens here because the hand-written Trainer in this
        # notebook only calls optimiser.step() / optimiser.zero_grad().
        # NOTE(review): the real pl.Trainer normally runs backward itself --
        # confirm before reusing this class with it.
        loss.backward()
        return loss

    def on_batch_start(self, batch):
        """Hook run before each batch; prints a marker and returns the batch unchanged."""
        print('hello')
        return batch

In [None]:
# Peek at the first item yielded by `train_loader`, then stop iterating.
for batch in train_loader:
    # Each item is unpacked into two parts, named x and y here.
    x, y = batch
    display(x)  # rich notebook rendering of x
    print(y)
    break  # only the first item is inspected

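`Callbacks` come from web development. When you send a request, you might need to wait some time for the reply. In other words, you are waiting for whatever received your request to "call back". In a training loop the idea is the same: the `Trainer` "calls back" into code we hand it at fixed points, such as the start of every batch.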

In [None]:
class NormCallback(pl.Callback):
    """Callback whose batch-start hook adds nothing yet: it only defers to the base class.

    NOTE(review): newer pytorch-lightning releases may no longer define
    ``Callback.on_batch_start`` (the training hook appears to have been renamed
    to ``on_train_batch_start``) -- confirm against the installed version.
    """

    def on_batch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Look up the parent implementation, then forward both arguments unchanged.
        parent_hook = super().on_batch_start
        return parent_hook(trainer, pl_module)