In [None]:
import os
import random
import time
from pathlib import Path

import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda import amp
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, models, transforms
from tqdm.autonotebook import tqdm, trange

In [None]:
# Resolve the project root relative to this file's location.
try:
    FILE = Path(__file__).resolve()
except NameError:
    # Jupyter kernels do not define __file__. Fall back to the working
    # directory; the file name below is only a placeholder so that
    # parents[1] keeps the same meaning as for a real file path.
    # NOTE(review): assumes the kernel's cwd is the notebook's folder - confirm.
    FILE = Path.cwd().resolve() / "notebook.ipynb"
ROOT = FILE.parents[1]  # program ROOT

# Создание датасета для обучения и тестирования

In [None]:
class CreateDataset(Dataset):
    """
    Image dataset that decodes every image once and keeps it in memory.

    All images listed in the data frame are read and decoded in
    ``__init__``; ``__getitem__`` only applies the optional transform.

    :list_classes: labels taken from the ``'labels'`` column.
    :img_path_list: image paths taken from the ``'paths'`` column.
    :transform: optional callable applied to an image when it is accessed.
    :img_list: decoded RGB images, in the same order as ``img_path_list``.
    """

    def __init__(self, data_frame, transform: transforms.Compose = None):
        """
        :param data_frame: table with ``'labels'`` and ``'paths'`` columns.
        :param transform: optional transform applied in ``__getitem__``.
        :raises ValueError: if one of the images cannot be decoded.
        """
        self.list_classes = data_frame['labels'].to_list()
        self.img_path_list = data_frame['paths'].to_list()
        self.transform = transform
        self.img_list = [self.__get_img_by_path(path)
                         for path in self.img_path_list]

    def __len__(self):
        """Number of loaded images."""
        return len(self.img_list)

    def __getitem__(self, index):
        """
        Return one sample as a dict with keys ``'image'`` and ``'target'``.

        The image is passed through ``self.transform`` when one is set.
        """
        image = self.img_list[index]
        if self.transform:
            image = self.transform(image)

        return {'image': image,
                'target': self.list_classes[index]}

    @staticmethod
    def __get_img_by_path(img_path):
        """
        Read and decode an image from disk.

        :img_path: path to the image file.
        :return: the image as an RGB array.
        :raises ValueError: if the file is empty or cannot be decoded.
        """
        # Read the raw bytes ourselves and decode from memory so that paths
        # with non-ASCII (e.g. Cyrillic) characters are handled as well.
        with open(img_path, "rb") as f:
            chunk = f.read()

        chunk_arr = np.frombuffer(chunk, dtype=np.uint8)
        # imdecode signals failure by returning None rather than raising.
        img = cv.imdecode(chunk_arr, cv.IMREAD_COLOR) if chunk_arr.size else None
        if img is None:
            raise ValueError(f"Could not decode image: {img_path}")

        return cv.cvtColor(img, cv.COLOR_BGR2RGB)

In [None]:
def create_dataloader(path_file_df, batch_size=4, img_size=(256, 256), shuffle=True):
    """
    Build a DataLoader over the images listed in a CSV file.

    The CSV is read with pandas and handed to ``CreateDataset``. Every image
    is converted to a tensor and then resized to ``img_size``.

    :param path_file_df: path to the CSV file describing the dataset.
    :param batch_size: number of samples per batch.
    :param img_size: ``(height, width)`` every image is resized to.
    :param shuffle: whether the loader reshuffles the data every epoch.
    :return: DataLoader yielding dicts with ``'image'`` and ``'target'``.
    """
    df = pd.read_csv(path_file_df)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(size=img_size),
    ])
    data = CreateDataset(df, transform=transform)

    return DataLoader(data, batch_size=batch_size, shuffle=shuffle)

# Обучение и подсчёт метрик на обучении и валидации

In [None]:
def _weighted_f1(preds, targets, num_classes):
    """Support-weighted mean of the per-class F1 scores of hard predictions."""
    total = targets.numel()
    score = torch.tensor(0.0)
    for cls in range(num_classes):
        pred_pos = preds == cls
        true_pos = targets == cls
        tp = (pred_pos & true_pos).sum().item()
        support = true_pos.sum().item()
        # (TP + FP) + (TP + FN) == 2*TP + FP + FN, the F1 denominator.
        denom = pred_pos.sum().item() + support
        if denom:
            score += (2.0 * tp / denom) * (support / total)
    return score


def count_metrics(data_dl, device, model, count_loss = False, num_classes = 2):
    """
    Evaluate the model on a data loader.

    The model is switched to eval mode and run without gradients. Accuracy
    and the weighted F1 score are computed over ALL samples of the loader
    (predictions are collected across batches), not over the last batch only.

    :param data_dl: iterable of batches, each a dict with ``'image'`` and
        ``'target'`` tensors.
    :param device: device the batches are moved to.
    :param model: model to evaluate; its output is arg-maxed over the last
        dimension to obtain the predicted class.
    :param count_loss: when True, also compute the mean cross-entropy loss
        per batch.
    :param num_classes: number of classes used for the weighted F1 score.
    :return: tuple ``(accuracy in percent, weighted F1 as a scalar tensor,
        mean loss or None when count_loss is False)``. For an empty loader
        the result is ``(0.0, tensor(0.), 0.0 or None)``.
    """
    loss_func = nn.CrossEntropyLoss() if count_loss else None
    model.eval()
    loss_sum = 0.0
    n_batches = 0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for batch in tqdm(data_dl):
            xb = batch['image'].to(device)
            yb = batch['target'].to(device)

            # Kept from the original behaviour: the input is cast to float32
            # only when no loss is requested.
            logits = model(xb) if count_loss else model(xb.float())

            if count_loss:
                loss_sum += loss_func(logits, yb).item()

            all_preds.append(torch.argmax(logits, dim=-1).cpu())
            all_targets.append(yb.cpu())
            n_batches += 1

    if n_batches == 0:
        # Nothing to evaluate: return neutral metrics instead of dividing by zero.
        return 0.0, torch.tensor(0.0), (0.0 if count_loss else None)

    preds = torch.cat(all_preds)
    targets = torch.cat(all_targets)

    accuracies = 100 * (preds == targets).sum().item() / targets.numel()
    # Returned as a tensor so callers can keep using ``.item()`` on it.
    f1_sc = _weighted_f1(preds, targets, num_classes)
    losses = (loss_sum / n_batches) if count_loss else None

    return accuracies, f1_sc, losses


def fit(epochs, model, loss_func, opt, train_dl, valid_dl,device, lr_scale = 0.01):
    """
    Train the model and validate it after every epoch.

    Gradients are accumulated over several batches so that one optimiser
    step covers roughly 64 samples. The learning rate is multiplied by a
    factor that decays linearly from 1.0 at epoch 0 towards ``lr_scale``.
    The weights are saved to ``ROOT/models/model_best.pt`` whenever both the
    validation F1 and the validation accuracy exceed their previous best.

    :param epochs: number of epochs.
    :param model: model to train.
    :param loss_func: loss function called as ``loss_func(pred, target)``.
    :param opt: optimiser.
    :param train_dl: training data loader.
    :param valid_dl: validation data loader.
    :param device: device the batches are moved to.
    :param lr_scale: final learning-rate factor of the linear schedule.
    :return: lists with one value per epoch: train loss, validation loss,
        validation accuracy, validation F1.
    """
    nbs = 64  # nominal batch size
    # batch_size can be None (custom batch sampler); then step on every batch.
    accumulate = max(round(nbs / (train_dl.batch_size or nbs)), 1)

    lf = lambda x: (1 - x / epochs) * (1.0 - lr_scale) + lr_scale  # linear
    scheduler = optim.lr_scheduler.LambdaLR(opt, lr_lambda=lf)

    save_path = Path(ROOT, 'models', 'model_best.pt')
    n_batches = len(train_dl)

    train_losses = []
    val_losses = []
    val_accur = []
    val_f1 = []
    best_fitness = 0
    best_acur = 0
    for epoch in range(epochs):
        print("epoch ", epoch)
        model.train()
        loss_sum = 0
        # Start every epoch from clean gradients.
        opt.zero_grad()
        for idx, target in enumerate(tqdm(train_dl)):
            xb, yb = target['image'].to(device),\
                     target['target'].to(device)

            pred = model(xb)  # forward
            # NOTE(review): the loss is not rescaled, so gradients of the
            # accumulated batches are summed - confirm this is intended.
            loss = loss_func(pred, yb)
            loss_sum += loss.item()
            loss.backward()

            # Step after every `accumulate` batches and also on the last
            # batch, so short epochs still update the weights and no
            # gradients leak into the next epoch.
            if (idx + 1) % accumulate == 0 or (idx + 1) == n_batches:
                opt.step()
                opt.zero_grad()

        train_loss = loss_sum / max(n_batches, 1)
        print("train_loss: ", train_loss)
        train_losses.append(train_loss)

        scheduler.step()

        accuracies_val, f1_sc, losses_val = count_metrics(valid_dl, device,
                                                   model, True)
        print("valid loss: ", losses_val)
        val_losses.append(losses_val)

        print("valid accuracies: ", accuracies_val)
        val_accur.append(accuracies_val)

        print("valid f1: ", f1_sc)
        # float() accepts both a scalar tensor and a plain number.
        val_f1.append(float(f1_sc))

        if val_f1[-1] > best_fitness and accuracies_val > best_acur:
            best_fitness = val_f1[-1]
            best_acur = accuracies_val

            # Create the target directory on first use.
            save_path.parent.mkdir(parents=True, exist_ok=True)
            torch.save(model.state_dict(), save_path)

    return train_losses, val_losses, val_accur, val_f1

In [None]:
def plot_training(train_losses, valid_losses, valid_accuracies):
    """
    Draw the training curves after a run.

    The upper panel shows the train and validation loss per epoch, the
    lower panel shows the validation accuracy per epoch.

    :train_losses: loss on the training set, one value per epoch
    :valid_losses: loss on the validation set, one value per epoch
    :valid_accuracies: accuracy on the validation set, one value per epoch
    """
    fig = plt.figure(figsize=(12, 9))

    # Upper panel: both loss curves on shared axes.
    loss_ax = fig.add_subplot(2, 1, 1)
    loss_ax.set_xlabel("epoch")
    for series, name in ((train_losses, "train_loss"),
                         (valid_losses, "valid_loss")):
        loss_ax.plot(series, label=name)
    loss_ax.legend()

    # Lower panel: validation accuracy.
    acc_ax = fig.add_subplot(2, 1, 2)
    acc_ax.set_xlabel("epoch")
    acc_ax.plot(valid_accuracies, label="valid accuracy")
    acc_ax.legend()