In [3]:
import pandas as pd
import numpy as np
import glob
from tqdm.notebook import tqdm
import cv2
from sklearn.model_selection import train_test_split


from torchvision import transforms, models
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
import torchvision
from torchvision.models import resnet18
from typing import Callable, Dict, Mapping, Tuple, Optional, Union

from PIL import Image
import matplotlib.pyplot as plt
from IPython.display import clear_output

In [2]:
#from google.colab import drive
#drive.mount('/content/drive')

## Датасет

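Прежде чем разбираться с моделями, нужно разобраться с тем, как загружать датасет. Для этого зададим преобразования изображений и воспользуемся готовым классом `torchvision.datasets.ImageFolder`, который читает картинки из подпапок `dataset/<имя_класса>/`.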

In [4]:
# Image preprocessing pipelines.
# (height, width) every image is resized to, both for training and inference.
IMAGE_SIZE = (1440, 2560)
# ImageNet channel statistics expected by torchvision's pretrained backbones.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training: applied by ImageFolder to PIL images.
train_transform = transforms.Compose([
    # NOTE: ColorJitter() with default arguments leaves the image unchanged;
    # pass e.g. brightness=0.2, contrast=0.2 to actually jitter colours.
    transforms.ColorJitter(),
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomGrayscale(),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Inference: applied in Predict to a numpy (OpenCV) image. The array should be RGB
# to match the PIL images used in training. It is resized exactly like training images.
valid_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [5]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# Images are read from class subfolders of ./dataset; every sample goes through train_transform.
imagenet_data = torchvision.datasets.ImageFolder(root='dataset', transform=train_transform)


In [None]:
# np.shape() on a Dataset makes numpy index (and decode) every sample. Instead, report
# the number of samples and the shape of a single transformed image tensor.
sample_image, _ = imagenet_data[0]
print(len(imagenet_data), tuple(sample_image.shape))

In [None]:
# Folder-name -> integer label mapping assigned by ImageFolder.
dict(imagenet_data.class_to_idx)

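`ImageFolder` назначает индексы классов по отсортированным именам папок, поэтому словарь `classDict` ниже должен соответствовать этому порядку.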

In [None]:
# Train/validation split: hold out `validation_split` of the samples, chosen by a
# reproducible shuffle of the sample indices.
validation_split = .2
shuffle_dataset = True
random_seed = 42
batch_size = 2

dataset_size = len(imagenet_data)
print(dataset_size)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset:
    # A dedicated RandomState gives the same permutation as np.random.seed(random_seed)
    # followed by np.random.shuffle, without resetting numpy's global random state.
    np.random.RandomState(random_seed).shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
valid_sampler = torch.utils.data.SubsetRandomSampler(val_indices)

# NOTE(review): both loaders read the same ImageFolder, so validation images also go
# through train_transform's random augmentations (RandomGrayscale). Consider a second
# ImageFolder with a deterministic transform for validation.
train_loader = torch.utils.data.DataLoader(imagenet_data,
                                           batch_size=batch_size,
                                           sampler=train_sampler)
valid_loader = torch.utils.data.DataLoader(imagenet_data,
                                           batch_size=batch_size,
                                           sampler=valid_sampler)
print(len(train_loader), len(valid_loader))

Отлично, теперь напишем пару вспомогательных функций: одну для визуализации кривых обучения, другую для самого цикла тренировки.

## Вспомогательные функции

In [None]:
def plot_history(train_history, val_history, title='loss'):
    """
    Plot per-step training values as a line and validation values as markers.

    Validation markers are placed at evenly spaced train steps. With one validation
    value per epoch and a fixed number of steps per epoch, each marker lands at the
    end of its epoch.

    Args:
        train_history: sequence of per-step training values (e.g. batch losses).
        val_history: sequence of validation values (e.g. one per epoch). May be empty.
        title: figure title.
    """
    plt.figure()
    plt.title('{}'.format(title))
    plt.plot(train_history, label='train', zorder=1)

    n_train, n_val = len(train_history), len(val_history)
    if n_val > 0:
        # Exactly one x position per validation value, even when n_train is not a
        # multiple of n_val (or is smaller than it).
        steps = [(i + 1) * n_train // n_val for i in range(n_val)]
        plt.scatter(steps, val_history, marker='+', s=180, c='orange', label='val', zorder=2)
    plt.xlabel('train steps')

    plt.legend(loc='best')
    plt.grid()

    plt.show()

In [None]:
def _run_train_epoch(model, criterion, optimizer, dataloader, device, loss_log):
    """One optimisation pass over `dataloader`.

    Appends each batch's loss to `loss_log` and returns
    (summed per-sample loss, number of correct predictions, number of samples).
    Assumes `criterion` returns the batch-mean loss (reduction='mean', the
    nn.CrossEntropyLoss default).
    """
    model.train()
    loss_sum, correct, seen = 0., 0, 0
    for imgs, labels in dataloader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()

        logits = model(imgs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        batch = labels.size(0)
        # Undo the mean reduction so the epoch average weights every sample equally
        loss_sum += loss.item() * batch
        seen += batch
        correct += (logits.argmax(1) == labels).sum().item()
        loss_log.append(loss.item())
    return loss_sum, correct, seen


def _run_eval_epoch(model, criterion, dataloader, device):
    """Evaluate without gradients; return (summed per-sample loss, correct count, sample count)."""
    model.eval()
    loss_sum, correct, seen = 0., 0, 0
    with torch.no_grad():
        for imgs, labels in dataloader:
            imgs, labels = imgs.to(device), labels.to(device)
            logits = model(imgs)

            batch = labels.size(0)
            loss_sum += criterion(logits, labels).item() * batch
            seen += batch
            correct += (logits.argmax(1) == labels).sum().item()
    return loss_sum, correct, seen


def train(model, criterion, optimizer, train_dataloader, test_dataloader, NUM_EPOCH=15):
    """
    Train `model` for NUM_EPOCH epochs and evaluate it after every epoch.

    After each epoch the notebook output is cleared, the loss curves are plotted, and
    mean losses and accuracies (in percent) are printed. Batches are sent to the
    device holding the model's parameters (CPU if the model has none).

    Args:
        model: network to train; it is updated in place.
        criterion: loss returning the batch-mean loss (e.g. nn.CrossEntropyLoss()).
        optimizer: optimizer over the model's parameters.
        train_dataloader: yields (images, labels) batches for training.
        test_dataloader: yields (images, labels) batches for validation.
        NUM_EPOCH: number of epochs.

    Returns:
        (train_loss_log, train_acc_log, val_loss_log, val_acc_log) as lists of floats:
        per-batch training losses, per-epoch training accuracy, per-epoch mean
        validation loss, and per-epoch validation accuracy (fractions in [0, 1]).
    """
    train_loss_log, val_loss_log = [], []
    train_acc_log, val_acc_log = [], []

    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    for _ in tqdm(range(NUM_EPOCH)):
        train_loss, train_correct, train_size = _run_train_epoch(
            model, criterion, optimizer, train_dataloader, device, train_loss_log)
        val_loss, val_correct, val_size = _run_eval_epoch(
            model, criterion, test_dataloader, device)

        # max(..., 1) avoids ZeroDivisionError on an empty dataloader
        mean_train_loss = train_loss / max(train_size, 1)
        mean_val_loss = val_loss / max(val_size, 1)
        train_acc = train_correct / max(train_size, 1)
        val_acc = val_correct / max(val_size, 1)

        train_acc_log.append(train_acc)
        val_loss_log.append(mean_val_loss)
        val_acc_log.append(val_acc)

        clear_output()
        plot_history(train_loss_log, val_loss_log, 'loss')

        print('Train loss:', mean_train_loss)
        print('Val loss:', mean_val_loss)
        print('Train acc:', train_acc * 100)
        print('Val acc:', val_acc * 100)

    return train_loss_log, train_acc_log, val_loss_log, val_acc_log

## Модель
Все, перейдем к обучению самой модели. Воспользуемся предобученной сетью (ResNeXt-50 или VGG-19), заменим у неё последний слой классификатора под число наших классов и дообучим всю модель с небольшим learning rate.

In [None]:
def get_vgg_19(device: str = 'cpu',
               ckpt_path: Optional[str] = None,
               num_classes: int = 182,
               ) -> nn.Module:
    """
    Build an ImageNet-pretrained VGG-19 whose last layer outputs `num_classes` logits.

    Args:
        device: device the model is moved to; checkpoint tensors are mapped there too.
        ckpt_path: optional path to a state_dict saved from this architecture.
        num_classes: size of the final classification layer (182 by default).

    Returns:
        The model on `device`. If the checkpoint cannot be loaded, the reason is
        printed and the model keeps its pretrained weights and freshly initialised head.
    """
    model = models.vgg19(pretrained=True)
    # torchvision's VGG-19 classifier is already Linear(25088, 4096) -> ReLU -> Dropout
    # -> Linear(4096, 4096) -> ReLU -> Dropout -> Linear(4096, 1000). Only the last
    # layer is swapped, so the pretrained fc6/fc7 weights are kept rather than re-initialised.
    model.classifier[-1] = nn.Linear(in_features=model.classifier[-1].in_features,
                                     out_features=num_classes, bias=True)
    model = model.to(device)
    if ckpt_path:
        try:
            # map_location lets a checkpoint saved on GPU load onto `device`
            checkpoint = torch.load(ckpt_path, map_location=device)
            model.load_state_dict(checkpoint)
        except Exception as exc:  # best effort: keep going with the pretrained model
            print("Wrong checkpoint:", exc)
    return model

In [None]:
# Alternative backbone: semi-weakly supervised ResNeXt-50 (32x4d) from torch.hub,
# with its final fully-connected layer resized for this dataset.
model = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnext50_32x4d_swsl')
num_ftrs = model.fc.in_features
# Size the head from the dataset so it always matches the labels ImageFolder produces
model.fc = nn.Linear(num_ftrs, len(imagenet_data.classes))

model = model.cpu()
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# VGG-19 backbone: swap the *last* classifier layer for one sized to the dataset.
# The first classifier layer must stay 25088 -> 4096, because classifier[3]
# expects 4096 input features.
model = get_vgg_19()
num_ftrs = model.classifier[-1].in_features
model.classifier[-1] = nn.Linear(num_ftrs, len(imagenet_data.classes))
model = model.cpu()
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# Full fine-tuning: Adam over every model parameter in one parameter group, lr=1e-5.
plist = [dict(params=model.parameters(), lr=1e-5)]
optimizer = optim.Adam(plist, lr=1e-5)

In [None]:
# Train for 5 epochs, validating after each one.
train_loss_log, train_acc_log, val_loss_log, val_acc_log = train(
    model, criterion, optimizer, train_loader, valid_loader, NUM_EPOCH=5,
)

In [None]:
from typing import Tuple

class F1Score:
    """
    F1 score for single-label multiclass predictions, computed with PyTorch ops.

    Averaging follows scikit-learn's `f1_score` conventions:
      * 'micro'    - fraction of samples predicted correctly;
      * 'macro'    - unweighted mean of per-label F1;
      * 'weighted' - per-label F1 weighted by the label's true count (support);
      * None       - tensor of per-label F1 scores, ordered by ascending label id.
    The labels scored are the union of the values present in `labels` and `predictions`.
    Label values may be arbitrary (e.g. 0-based ImageFolder indices).
    """

    def __init__(self, average: Optional[str] = 'weighted'):
        """
        Init.

        Args:
            average: averaging method, one of None, 'micro', 'macro', 'weighted'.

        Raises:
            ValueError: if `average` is not one of the supported values.
        """
        if average not in (None, 'micro', 'macro', 'weighted'):
            raise ValueError('Wrong value of average parameter')
        self.average = average

    @staticmethod
    def calc_f1_micro(predictions: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Calculate f1 micro (equal to accuracy for single-label multiclass data).

        Args:
            predictions: tensor with predictions
            labels: tensor with original labels

        Returns:
            f1 score
        """
        true_positive = torch.eq(labels, predictions).sum().float()
        f1_score = torch.div(true_positive, len(labels))
        return f1_score

    @staticmethod
    def calc_f1_count_for_label(predictions: torch.Tensor,
                                labels: torch.Tensor,
                                label_id: Union[int, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Calculate f1 and true count for one label.

        Args:
            predictions: tensor with predictions
            labels: tensor with original labels
            label_id: value of the label to score

        Returns:
            f1 score (0 when undefined) and number of samples whose true label is `label_id`
        """
        # label count (support)
        true_count = torch.eq(labels, label_id).sum()

        # true positives: labels equal to prediction and to label_id
        true_positive = torch.logical_and(torch.eq(labels, predictions),
                                          torch.eq(labels, label_id)).sum().float()
        # precision for label; 0/0 (label never predicted) gives nan -> 0
        precision = torch.div(true_positive, torch.eq(predictions, label_id).sum().float())
        precision = torch.where(torch.isnan(precision),
                                torch.zeros_like(precision).type_as(true_positive),
                                precision)

        # recall for label (nan when the label never occurs in `labels`)
        recall = torch.div(true_positive, true_count)
        f1 = 2 * precision * recall / (precision + recall)
        # undefined f1 (precision + recall == 0, or no support) -> 0, as in scikit-learn
        f1 = torch.where(torch.isnan(f1), torch.zeros_like(f1).type_as(true_positive), f1)
        return f1, true_count

    def __call__(self, predictions: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Calculate f1 score based on the averaging method defined in init.

        Args:
            predictions: tensor with predicted label ids
            labels: tensor with original label ids

        Returns:
            f1 score (scalar tensor), or per-label scores when average is None.
            Empty input gives 0 (an empty tensor for average=None).
        """
        if labels.numel() == 0:
            return torch.zeros(0) if self.average is None else torch.tensor(0.)

        # simpler calculation for micro
        if self.average == 'micro':
            return self.calc_f1_micro(predictions, labels)

        # Score the labels that actually occur, not range(1, n + 1): that assumed
        # 1-based contiguous ids and skipped label 0.
        label_ids = torch.unique(torch.cat([labels.flatten(),
                                            predictions.flatten().to(labels.dtype)]))
        per_label = [self.calc_f1_count_for_label(predictions, labels, label_id)
                     for label_id in label_ids]
        f1_scores = torch.stack([f1 for f1, _ in per_label])

        if self.average is None:
            return f1_scores
        if self.average == 'macro':
            return f1_scores.mean()
        # weighted: support-weighted mean over all samples
        supports = torch.stack([count for _, count in per_label]).to(f1_scores.dtype)
        return torch.div((f1_scores * supports).sum(), len(labels))

In [None]:
# Collect validation predictions and labels, then compute macro-averaged F1.
# Inputs go to whichever device holds the model's weights, so this works both for
# the CPU-trained model above and for a model moved to CUDA.
# NOTE(review): valid_loader reads through train_transform (random augmentations).
eval_device = next(model.parameters()).device
model.eval()  # disable dropout for inference

label_batches, pred_batches = [], []
with torch.no_grad():
    for imgs, labels in tqdm(valid_loader):
        pred_batches.append(model(imgs.to(eval_device)).argmax(dim=1))
        label_batches.append(labels.to(eval_device))

# Concatenate once instead of re-allocating a growing tensor every batch
true_labels = torch.cat(label_batches, dim=0)
predictions = torch.cat(pred_batches, dim=0)

f1_metric = F1Score('macro')
f1_metric(predictions, true_labels)

In [None]:
import pickle
import torch
#torch.save(model, 'model.pt')#Сохраняем

In [None]:
# Load the whole pickled model saved earlier. torch.load unpickles arbitrary objects,
# so use it only on trusted files. map_location='cpu' lets a model saved on GPU be
# restored on a CPU-only machine.
# NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects whole-module
# pickles; pass weights_only=False there (trusted files only).
model = torch.load('drive/MyDrive/model.pt', map_location='cpu')  # load

## Посмотрим метрики нашей итоговой модели на валидации

In [None]:
# Human-readable defect names, indexed by the model's output class id.
# NOTE(review): 14 names here, while the model heads above are sized differently.
# Verify this order against imagenet_data.class_to_idx.
classDict = dict(enumerate([
    'Отсутствие клея',
    'Одна полоса клея',
    'Полосы клея смещены к пласти шпона',
    'Полосы клея смещены на край',
    'Часть уса  без клея',
    'Скол по шпону',
    'Разошедшаяся трещина',
    'Заминание шпона',
    'Трещины + участок без клея',
    'Запил',
    'Скол по сучку + участок без клея',
    'Спиливание уса',
    'Нет дефектов',
    'Не видно шпона',
]))

In [None]:
def Predict(path, model_path='drive/MyDrive/model.pt'):
    """
    Classify a single image file with the saved model.

    Args:
        path: path to the image file.
        model_path: path to a whole-model pickle written by torch.save(model, ...).

    Returns:
        The classDict name of the highest-scoring output class.

    Raises:
        FileNotFoundError: if OpenCV cannot read `path`.
    """
    image = cv2.imread(path)
    if image is None:
        # cv2.imread signals a missing or unreadable file by returning None, not raising
        raise FileNotFoundError(f"Cannot read image: {path}")
    # OpenCV decodes to BGR; training images came from ImageFolder's loader as RGB
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # map_location lets a model saved on GPU load on CPU-only machines (trusted file only)
    model = torch.load(model_path, map_location='cpu')
    image = valid_transform(image)
    if torch.cuda.is_available():
        model = model.cuda()
        image = image.cuda()
        print('Есть Cuda')
    else:
        model = model.cpu()
        image = image.cpu()
        print('Нет Cuda')

    model.eval()
    with torch.no_grad():
        logits = model(image[None, ...])  # add a batch dimension of 1
    return classDict[logits.argmax(dim=1).item()]

In [None]:
# Classify a sample image from the working directory.
Predict(path='1.jpg')