## Блок кода для обучения

In [None]:
import torch
from tqdm import tqdm

# Train on the current CUDA device when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def get_correct_count(pred, labels):
    """Count rows of ``pred`` whose arg-max along dim 1 equals the matching label.

    Args:
        pred: tensor of scores; the predicted class is the index of the maximum
            along dimension 1.
        labels: tensor of target class indices, one per row of ``pred``.

    Returns:
        Number of correct predictions as a Python int.
    """
    predicted_classes = torch.max(pred.data, dim=1).indices
    # Compare on the CPU so ``pred`` and ``labels`` may live on different devices.
    matches = predicted_classes.cpu().eq(labels.cpu())
    return matches.sum().item()


@torch.inference_mode()  # disables autograd tracking for the whole evaluation
def validate(model, test_loader, device="cpu"):
    """Return the classification accuracy of ``model`` over ``test_loader``.

    The model is put into eval mode for the pass, so layers such as dropout
    and batch norm use their inference behaviour. Its previous train/eval mode
    is restored afterwards, even if an exception is raised, so a training loop
    that calls this between epochs keeps training in the mode it expects.

    Args:
        model: torch module mapping an image batch to per-class scores.
        test_loader: iterable yielding ``(imgs, labels)`` batches.
        device: device the images are moved to before the forward pass.

    Returns:
        Fraction of correctly classified samples.

    Raises:
        ZeroDivisionError: if ``test_loader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    try:
        correct, total = 0, 0
        for imgs, labels in test_loader:
            pred = model(imgs.to(device))
            total += labels.size(0)
            correct += get_correct_count(pred, labels)
    finally:
        model.train(was_training)
    return correct / total


class Trainer:
    """Minimal supervised classification training loop with per-epoch plotting.

    Args:
        model: torch module to train; its outputs are fed to
            ``torch.nn.CrossEntropyLoss`` together with the batch labels.
        plotter: object providing ``add_scalar(name, value)`` and
            ``display(groups)``; a fresh ``ProgressPlotter`` is created when
            omitted.
        lr: learning rate of the SGD optimizer.
    """

    def __init__(self, model, plotter=None, lr=0.03):
        self.model = model
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=lr)
        self.criterion = torch.nn.CrossEntropyLoss()
        self.plotter = ProgressPlotter() if plotter is None else plotter
        self.epochs = 25  # used by __call__ when no epoch count is passed
        self.loss_hist = []  # per-batch training losses of the current epoch

    def __call__(self, train_loader, val_loader, epochs=None):
        """Train for ``epochs`` epochs (``self.epochs`` when omitted).

        After every epoch, the plotter receives three values:
        "Loss/train" (mean batch loss), "Accuracy/val" (from ``validate``)
        and "Accuracy/train". The loss and validation-accuracy plots are
        then redrawn. The model and batches are moved to the module-level
        ``device``.
        """
        n_epochs = self.epochs if epochs is None else epochs
        print("Using device:", device)
        self.model.to(device)
        for _ in tqdm(range(n_epochs)):
            # Re-enter train mode each epoch: validation may leave the model in eval mode.
            self.model.train()
            self.loss_hist = []
            correct, total = 0, 0
            for imgs, labels in train_loader:
                correct += self.process_batch(imgs, labels)
                total += len(labels)
            self.plotter.add_scalar("Loss/train", np.mean(self.loss_hist))
            self.plotter.add_scalar(
                "Accuracy/val", validate(self.model, val_loader, device=device)
            )
            self.plotter.add_scalar("Accuracy/train", correct / total)
            self.plotter.display(["Loss/train", "Accuracy/val"])

    def process_batch(self, imgs, labels):
        """Run one SGD step on a batch and return its number of correct predictions.

        The batch loss is appended to ``self.loss_hist``.
        """
        self.optimizer.zero_grad()
        out = self.model(imgs.to(device))
        loss = self.criterion(out, labels.to(device))
        loss.backward()
        self.loss_hist.append(loss.item())
        self.optimizer.step()
        return get_correct_count(out.detach().cpu(), labels)

Поскольку TensorBoard (TB) мы не любим, копируем из прошлых заданий самописный костыль для визуализации.

In [None]:
from IPython.display import clear_output
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np


class ProgressPlotter:
    """Collect scalar series and redraw them in a notebook cell.

    Data is stored as ``group -> tag -> values``. Each *group* (e.g.
    ``"Loss/train"``) is drawn on its own axis. Each *tag* (by default the
    current title, e.g. an experiment name) is one line on that axis.

    Args:
        title: tag used by ``add_scalar``/``add_row`` when none is given.
        groups: groups drawn by ``display()`` when it is called without
            arguments; a single group name or a list of names. When omitted,
            every group recorded so far is drawn.
    """

    def __init__(self, title="default", groups=None) -> None:
        self._history_dict = defaultdict(dict)
        self.set_title(title)
        self.groups = self.get_groups(groups)

    def get_groups(self, groups):
        """Normalize a ``groups`` argument.

        ``None`` becomes a live view of all recorded group names, which grows as
        new groups are added. A string becomes a one-element list. Anything else
        is returned unchanged.
        """
        if groups is None:
            return self._history_dict.keys()
        if isinstance(groups, str):
            return [groups]
        return groups

    def set_title(self, title):
        """Make ``title`` the default tag and reset its series in every existing group."""
        for g in self._history_dict.keys():
            self._history_dict[g][title] = []  # reset data
        self.title = title

    # group e.g. "loss_val" tag e.g. "experiment_1"
    def add_scalar(self, group: str, value, tag=None) -> None:
        """Append ``value`` to the series ``(group, tag)``; ``tag`` defaults to the title."""
        tag = self.title if tag is None else tag
        self._history_dict[group].setdefault(tag, []).append(value)

    def add_row(self, group: str, value, tag=None) -> None:
        """Replace the whole series ``(group, tag)`` with ``value``."""
        tag = self.title if tag is None else tag
        self._history_dict[group][tag] = value

    def display_keys(self, ax, data):
        """Plot every ``tag -> values`` series in ``data`` on the axis ``ax``.

        A legend is added when there is more than one series. Every step gets an
        x tick while the longest series is shorter than 50 points.
        """
        history_len = 0
        ax.grid()
        for key in data:
            ax.plot(data[key], label=key)
            history_len = max(history_len, len(data[key]))
        if len(data) > 1:
            ax.legend(loc="upper right")
        if history_len < 50:
            ax.set_xlabel("step")
            ax.set_xticks(np.arange(history_len))
            ax.set_xticklabels(np.arange(history_len))

    def display(self, groups=None):
        """Clear the cell output and draw one subplot per group, side by side.

        Args:
            groups: group name or flat list of group names, e.g.
                ``["loss_train", "accuracy_val"]``. All tags of a group are drawn
                on that group's axis. Defaults to the groups chosen at
                construction.
        """
        groups = list(self.groups if groups is None else self.get_groups(groups))
        if not groups:
            return  # nothing to draw; also avoids plt.subplots(1, 0) and 48 // 0
        clear_output()
        n_groups = len(groups)
        fig, axes = plt.subplots(1, n_groups, figsize=(48 // n_groups, 3))
        if n_groups == 1:
            axes = [axes]  # subplots returns a bare Axes for a single column
        for ax, g in zip(axes, groups):
            ax.set_ylabel(g)
            # .get avoids a KeyError for a group that has no data yet.
            self.display_keys(ax, self._history_dict.get(g, {}))
        fig.tight_layout()
        plt.show()

    @property
    def history_dict(self):
        """Shallow copy of the stored data as a plain ``dict``."""
        return dict(self._history_dict)