# Зачем помещать код обучения в класс?

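Предположим, что мы снова решили классифицировать CIFAR10 (чтобы пример работал быстро, ниже вместо него используются случайные данные `FakeData` того же формата: изображения 3×32×32 и 10 классов).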

In [None]:
from torchvision.transforms.transforms import ToTensor
from torchvision.datasets import FakeData
from torch.utils.data import DataLoader


# Synthetic stand-ins for the train/validation splits: random 3x32x32 images
# with 10 classes, converted to tensors by the transform.
trainset = FakeData(
    size=128,
    image_size=(3, 32, 32),
    num_classes=10,
    transform=ToTensor(),
    random_offset=43,
)
valset = FakeData(
    size=16,
    image_size=(3, 32, 32),
    num_classes=10,
    transform=ToTensor(),
    random_offset=42,
)

# Both loaders yield shuffled mini-batches of 8 samples.
trainloader = DataLoader(dataset=trainset, batch_size=8, shuffle=True)
valloader = DataLoader(dataset=valset, batch_size=8, shuffle=True)


Вероятно, сначала наш код будет выглядеть так же, как в примерах с сайта PyTorch:

https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

In [None]:
import torch 
from torchvision.models import resnet18
from torch import nn

# Model, loss and optimizer for the baseline run.
model = resnet18(False)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

for epoch in range(2):  # pass over the dataset several times
    for inputs, labels in trainloader:  # every batch is an (inputs, labels) pair
        optimizer.zero_grad()  # clear gradients left from the previous step
        outputs = model(inputs)  # forward pass
        loss = criterion(outputs, labels)
        loss.backward()  # compute gradients of the loss w.r.t. the weights
        optimizer.step()  # update the weights

Так как обычно модель обучают несколько раз с различными параметрами, то логично поместить этот код в функцию:

In [None]:
def train(net, trainloader, epochs, optimizer, criterion):
    """Run a plain mini-batch training loop for ``net``.

    Args:
        net: model that is called on every batch of inputs.
        trainloader: iterable yielding ``(inputs, labels)`` pairs.
        epochs: number of full passes over ``trainloader``.
        optimizer: optimizer whose ``zero_grad``/``step`` are called per batch.
        criterion: callable computing the loss from ``(outputs, labels)``.
    """
    for _ in range(epochs):
        for batch in trainloader:
            inputs, labels = batch
            optimizer.zero_grad()  # drop gradients of the previous step
            loss = criterion(net(inputs), labels)  # forward pass + loss
            loss.backward()  # backward pass
            optimizer.step()  # weight update

И вызывать её с нужными параметрами:

In [None]:
# Experiment #2: a fresh model, learning rate 0.005, five epochs
model = resnet18(False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)
train(
    net=model,
    trainloader=trainloader,
    epochs=5,
    optimizer=optimizer,
    criterion=nn.CrossEntropyLoss(),
)

Однако обычно в процессе совершенствования модели появляется желание усовершенствовать и сам процесс обучения. Например, добавить логирование и lr_scheduler:








In [None]:
def train(net, trainloader, epochs, optimizer, criterion, log, scheduler):
    """Train ``net``, recording every batch loss and stepping a scheduler.

    Args:
        net: model that is called on every batch of inputs.
        trainloader: iterable yielding ``(inputs, labels)`` pairs.
        epochs: number of full passes over ``trainloader``.
        optimizer: optimizer whose ``zero_grad``/``step`` are called per batch.
        criterion: callable computing the loss from ``(outputs, labels)``.
        log: list that receives one Python float (the batch loss) per step.
        scheduler: object whose ``step(loss)`` is called after every batch,
            e.g. ``ReduceLROnPlateau``.
    """
    for _ in range(epochs):  # loop over the dataset multiple times
        for inputs, labels in trainloader:
            optimizer.zero_grad()  # zero the parameter gradients
            outputs = net(inputs)  # forward
            loss = criterion(outputs, labels)
            loss.backward()  # backward
            optimizer.step()  # optimize
            # --- additions compared to the basic loop ---
            # item() must be *called*: appending the bound method `loss.item`
            # would store a callable (keeping the tensor alive), not a number.
            log.append(loss.item())
            scheduler.step(loss)

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Fresh model plus everything the extended train() signature now requires.
model = resnet18(False)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)
scheduler = ReduceLROnPlateau(optimizer)
log = []  # passed to train() as the `log` argument
train(
    net=model,
    trainloader=trainloader,
    epochs=5,
    optimizer=optimizer,
    criterion=criterion,
    log=log,
    scheduler=scheduler,
)

Такой подход ведёт к дублированию кода и разрастанию сигнатуры функции, то есть нарушает принцип Don't repeat yourself ([DRY](https://en.wikipedia.org/wiki/Don%27t_repeat_yourself)).

Почему дублирование кода — это плохо?


1.   Если нужно внести исправление, приходится делать это в нескольких местах. Соответственно, возрастает и риск ошибиться.
2.   Чтобы понять, что поменялось, надо просмотреть весь код целиком.

Одним из способов решения этой проблемы является помещение кода в класс


In [None]:
class Trainer:
    """Minimal training loop bundled with its model, optimizer and loss.

    ``epochs`` is a plain attribute and can be changed after construction.
    """

    def __init__(self, model):
        self.model = model
        self.epochs = 3
        self.criterion = nn.CrossEntropyLoss()  # loss function
        # Plain SGD performs the weight updates
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.03)

    def __call__(self, train_loader):
        """Train for ``self.epochs`` passes over ``train_loader``."""
        self.model.train()
        for _ in range(self.epochs):
            for batch_imgs, batch_labels in train_loader:
                self.process_batch(batch_imgs, batch_labels)

    def process_batch(self, imgs, labels):
        """Perform one optimization step and return the batch loss tensor."""
        self.optimizer.zero_grad()
        batch_loss = self.criterion(self.model(imgs), labels)
        batch_loss.backward()
        self.optimizer.step()
        return batch_loss

In [None]:
# Fresh network, trained with the defaults defined in Trainer
model = resnet18(False)
trainer = Trainer(model=model)
trainer(train_loader=trainloader)

Теперь, если нам понадобится поменять алгоритм обучения или добавить новый параметр, существующий код можно не переписывать. Например, мы хотим логировать и выводить значение функции потерь:

In [None]:
class EnhancedTrainer(Trainer):
    """Trainer that additionally prints the mean training loss of each epoch.

    The constructor is inherited unchanged from ``Trainer``.
    """

    def __call__(self, train_loader):
        """Train for ``self.epochs`` epochs, printing the average loss per epoch.

        Args:
            train_loader: iterable yielding ``(imgs, labels)`` batches.
        """
        self.model.train()
        for _ in range(self.epochs):
            loss_history = []
            for imgs, labels in train_loader:
                loss = self.process_batch(imgs, labels)
                loss_history.append(loss.item())
            # Average over *all* batches of the epoch; taking the mean of the
            # last scalar `loss` would only report the final batch.
            if loss_history:
                mean_loss = sum(loss_history) / len(loss_history)
            else:
                mean_loss = float("nan")  # empty loader: nothing to average
            print(f"Average loss: {mean_loss:.4f}")

In [None]:
# Wrap the existing `model` in the logging trainer and train for 4 epochs
trainer = EnhancedTrainer(model=model)
trainer.epochs = 4
trainer(train_loader=trainloader)

Или добавить LR_Scheduler

In [None]:
class TrainerWithLRScheduler(EnhancedTrainer):
    """EnhancedTrainer that passes every batch loss to a ReduceLROnPlateau."""

    def __init__(self, model):
        super().__init__(model)
        # The scheduler wraps the optimizer set up by the parent constructor
        self.lr_scheduler = ReduceLROnPlateau(self.optimizer)

    def process_batch(self, imgs, labels):
        """Run the inherited step, then step the scheduler with its loss."""
        batch_loss = super().process_batch(imgs, labels)
        self.lr_scheduler.step(batch_loss)
        return batch_loss

In [None]:
# Same `model` object as before, now trained with the scheduler-aware trainer
trainer = TrainerWithLRScheduler(model=model)
trainer.epochs = 4
trainer(train_loader=trainloader)

Подобный принцип лежит в основе фреймворка [Lightning](https://www.pytorchlightning.ai/), который мы рекомендуем использовать для работы с реальными данными, так же как и [Tensorboard](https://pytorch.org/docs/stable/tensorboard.html).

Однако для работы с учебными блокнотами можно использовать заготовку подобного класса в связке с блоком кода для визуализации.

# Trainer - класс для обучения

In [None]:
import torch
from tqdm import tqdm

class Trainer:
    """Training loop with per-epoch metric logging and plotting.

    ``epochs`` is a plain attribute (default 10) and can be changed after
    construction. The model is moved to CUDA when it is available, else CPU.
    """

    def __init__(self, model, plotter=None, lr=0.03):
        """
        Args:
            model: network to train.
            plotter: object providing ``add_scalar`` and ``display``; a new
                ``ProgressPlotter`` is created when omitted.
            lr: learning rate passed to the SGD optimizer.
        """
        self.model = model
        # Honour the `lr` argument instead of a hard-coded value
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=lr)
        # `torch.nn` keeps this cell independent of a separate `nn` import
        self.criterion = torch.nn.CrossEntropyLoss()
        # Create a new plotter only if none was supplied
        self.plotter = ProgressPlotter() if plotter is None else plotter
        self.epochs = 10
        self.loss_hist = []  # batch losses of the current epoch
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __call__(self, train_loader, val_loader):
        """Train for ``self.epochs`` epochs, logging and plotting after each.

        Args:
            train_loader: iterable yielding ``(imgs, labels)`` training batches.
            val_loader: iterable yielding ``(imgs, labels)`` validation batches.
        """
        self.model.to(self.device)
        for _ in tqdm(range(self.epochs)):
            self.model.train()  # make sure every epoch runs in train mode
            self.loss_hist = []
            correct, total = 0, 0  # for metric calculation
            for imgs, labels in train_loader:
                correct += self.process_batch(imgs, labels)
                total += len(labels)

            # Logging
            self.plotter.add_scalar(
                "Loss/train", torch.tensor(self.loss_hist).mean().item()
            )
            self.plotter.add_scalar("Accuracy/val", self.validate(val_loader))
            self.plotter.add_scalar("Accuracy/train", correct / total)
            self.plotter.display(["Loss/train", "Accuracy/val"])

    def process_batch(self, imgs, labels):
        """Run one optimization step.

        Appends the batch loss to ``self.loss_hist`` and returns the number
        of correctly classified samples in the batch.
        """
        self.optimizer.zero_grad()
        out = self.model(imgs.to(self.device))
        loss = self.criterion(out, labels.to(self.device))
        loss.backward()
        self.loss_hist.append(loss.item())
        self.optimizer.step()
        return self.get_correct_count(out.detach(), labels)

    @staticmethod
    def get_correct_count(pred, labels):
        """Return how many rows of ``pred`` have their arg-max equal to the label.

        ``pred`` is indexed as (sample, class); ``labels`` holds class indices.
        """
        predicted = pred.argmax(dim=1)
        # Compare on CPU so predictions and labels may live on different devices
        matches = predicted.cpu() == labels.cpu()
        return matches.sum().item()

    @torch.inference_mode()  # disables gradient computation
    def validate(self, test_loader):
        """Return the accuracy of the model on a validation or test loader.

        The model is switched to eval mode for the duration of the call so
        that layers such as BatchNorm and Dropout use their inference
        behaviour; the previous mode is restored afterwards.
        """
        was_training = self.model.training
        self.model.eval()
        correct, total = 0, 0
        try:
            for imgs, labels in test_loader:
                pred = self.model(imgs.to(self.device))
                total += labels.size(0)
                correct += self.get_correct_count(pred, labels)
        finally:
            self.model.train(was_training)
        return correct / total


# ProgressPlotter — класс для визуализации процесса обучения

Для тех, у кого не работает [Tensorboard](https://pytorch.org/docs/stable/tensorboard.html)


In [None]:
from IPython.display import clear_output
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np


class ProgressPlotter:
    """Collect scalar training metrics and draw them as line charts.

    Values are stored as ``{group: {tag: [values]}}`` where *group* is a
    logged quantity (e.g. ``"Loss/train"``) and *tag* is an experiment name
    (by default the current ``title``).

    Args:
        title: experiment name, e.g. ``"ResNet_SGD"``.
        groups: group name or list of group names to draw by default; when
            omitted, every group logged so far is drawn.
    """

    def __init__(self, title="default", groups=None) -> None:
        self._history_dict = defaultdict(dict)
        self.set_title(title)
        self.groups = self.get_groups(groups)

    def get_groups(self, groups):
        """Normalize the ``groups`` argument to an iterable of group names."""
        if groups is None:
            # Live view: reflects groups that are added later via add_scalar
            return self._history_dict.keys()
        if isinstance(groups, str):
            return [groups]
        return groups

    def set_title(self, title):
        """Start a new experiment named ``title``.

        Data already stored under the same title is reset in every group.
        """
        for series_by_tag in self._history_dict.values():
            series_by_tag[title] = []  # reset data
        self.title = title

    def add_scalar(self, group: str, value, tag=None) -> None:
        """Append ``value`` to the series ``tag`` of ``group``.

        ``group`` is e.g. ``"loss_val"``; ``tag`` is e.g. ``"experiment_1"``
        and defaults to the current title.
        """
        tag = self.title if tag is None else tag
        self._history_dict[group].setdefault(tag, []).append(value)

    def add_row(self, group: str, value, tag=None) -> None:
        """Replace the whole series ``tag`` of ``group`` with ``value``."""
        tag = self.title if tag is None else tag
        self._history_dict[group][tag] = value

    def display_keys(self, ax, data):
        """Plot every series of ``data`` (``{tag: values}``) on the axis ``ax``."""
        history_len = 0
        ax.grid()
        for key, series in data.items():
            ax.plot(series, label=key)
            history_len = max(history_len, len(series))
        if len(data) > 1:
            ax.legend(loc="upper right")
        # Integer ticks are only readable for short histories
        if history_len < 50:
            ax.set_xlabel("step")
            ax.set_xticks(np.arange(history_len))
            ax.set_xticklabels(np.arange(history_len))

    def display(self, groups=None):
        """Redraw the charts, one axis per group.

        Args:
            groups: group name or list of group names such as
                ``["Loss/train", "Accuracy/val"]``; defaults to
                ``self.groups``. All series (tags) of one group share an axis.
        """
        clear_output()
        if groups is None:
            groups = self.groups
        groups = [groups] if isinstance(groups, str) else list(groups)
        n_groups = len(groups)
        if n_groups == 0:
            return  # nothing logged yet; avoids dividing by zero below
        history = self.history_dict  # copy once instead of once per group
        fig, axes = plt.subplots(1, n_groups, figsize=(48 // n_groups, 3))
        if n_groups == 1:
            axes = [axes]
        for axis, group in zip(axes, groups):
            axis.set_ylabel(group)
            self.display_keys(axis, history[group])
        fig.tight_layout()
        plt.show()

    @property  # accessed without parentheses, e.g. my_plotter.history_dict
    def history_dict(self):
        """Plain-dict copy of the data: ``{"loss": {"experiment1": [0.5, 0.44, ...]}}``."""
        return dict(self._history_dict)

Пример запуска

In [None]:
# Train a fresh network and plot the metrics under the name "Experiment_1"
model = resnet18(False)
trainer = Trainer(model=model)
trainer.plotter.set_title("Experiment_1")
trainer(train_loader=trainloader, val_loader=valloader)