## Подготовка данных

Перед началом нужно скачать датасет болезней клубники (Strawberry Disease Detection Dataset) с Kaggle:
https://www.kaggle.com/usmanafzaal/strawberry-disease-detection-dataset

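Положить архив `train.zip` рядом с ноутбуком и распаковать его следующей командой. В результате должна появиться папка `train`, в которой изображения разложены по подпапкам — по одной на каждый класс (такую структуру ожидает `datasets.ImageFolder`):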

In [None]:
# Unpack the dataset archive; the training code below reads images from ./train.
!unzip train.zip

## Обучение

In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F

from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.metrics import classification_report
from sklearn.preprocessing import OneHotEncoder
from torch import nn
from torch import optim

from torchvision import datasets, transforms, models
from tqdm.notebook import tqdm

Загрузка данных:

In [None]:
# Root folder of the unpacked dataset; datasets.ImageFolder (used below) expects
# one sub-folder per class inside it.
data_dir = './train'
def load_split_train_test(datadir, valid_size=.2, batch_size=32, num_workers=6, seed=None):
    """Build train/validation DataLoaders from a single ImageFolder directory.

    The same folder is opened twice: once with augmentations for training and
    once with plain resize + normalisation for evaluation. A random
    ``valid_size`` fraction of the image indices is held out for evaluation.

    Args:
        datadir: root folder in ``torchvision.datasets.ImageFolder`` layout
            (one sub-folder per class).
        valid_size: fraction of the images held out for evaluation; must be
            strictly between 0 and 1.
        batch_size: batch size used by both loaders.
        num_workers: number of worker processes per DataLoader.
        seed: optional seed for the train/validation split. ``None`` shuffles
            with NumPy's global random state, as before. Pass an int to get the
            same split on every run, e.g. when resuming from a checkpoint, so
            validation images do not leak into training.

    Returns:
        ``(trainloader, testloader)``.

    Raises:
        ValueError: if ``valid_size`` is not strictly between 0 and 1.
    """
    from torch.utils.data.sampler import SubsetRandomSampler

    if not 0 < valid_size < 1:
        raise ValueError(f'valid_size must be in (0, 1), got {valid_size!r}')

    # Standard ImageNet per-channel mean/std, shared by both pipelines.
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    # Pixels uncovered by RandomAffine are filled with the mean colour, which
    # becomes roughly 0 after normalisation.
    fill = tuple(int(m * 255) for m in mean)

    # NOTE(review): Resize(224) fixes only the shorter side and keeps the aspect
    # ratio. Default batching needs equally sized tensors, so this assumes all
    # images share one aspect ratio. Confirm, or add CenterCrop(224).
    # Training set: data augmentation.
    train_transforms = transforms.Compose([
        transforms.Resize(224),
        transforms.ColorJitter(0.1, 0.1, 0.1),
        transforms.RandomAffine(degrees=(-30, 30),
                                translate=(0.2, 0.2),
                                scale=(0.8, 1.5),
                                shear=(-15, 15),
                                fill=fill),
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomPerspective(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)])

    # Evaluation set: only resize and normalise.
    test_transforms = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)])

    train_data = datasets.ImageFolder(datadir, transform=train_transforms)
    test_data = datasets.ImageFolder(datadir, transform=test_transforms)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    # Both datasets list the same files in the same order, so one shuffled
    # index list splits them consistently (disjoint train/validation indices).
    if seed is None:
        np.random.shuffle(indices)
    else:
        np.random.RandomState(seed).shuffle(indices)
    train_idx, test_idx = indices[split:], indices[:split]

    trainloader = torch.utils.data.DataLoader(
        train_data, sampler=SubsetRandomSampler(train_idx),
        batch_size=batch_size, num_workers=num_workers)
    testloader = torch.utils.data.DataLoader(
        test_data, sampler=SubsetRandomSampler(test_idx),
        batch_size=batch_size, num_workers=num_workers)
    return trainloader, testloader


batch_size = 16
# Pass batch_size explicitly. Without it load_split_train_test falls back to
# its own default (32) and the value above is silently ignored.
trainloader, testloader = load_split_train_test(data_dir, .2, batch_size=batch_size)
print(trainloader.dataset.classes)

In [None]:
# смотрим на стандартные метрики классификации
# Standard classification metrics.
def calculate_metrics(pred, target, threshold=0.5):
    """Compute micro- and macro-averaged precision, recall and F1.

    Args:
        pred: predicted scores (the training loop passes sigmoid outputs);
            binarised with ``pred > threshold``.
        target: ground-truth 0/1 labels with the same shape as ``pred``.
        threshold: decision threshold applied to ``pred``.

    Returns:
        dict with keys ``'micro/precision'``, ``'micro/recall'``,
        ``'micro/f1'``, ``'macro/precision'``, ``'macro/recall'``,
        ``'macro/f1'`` (in that order).

    Note:
        When called on a single binary column, sklearn averages over both
        labels 0 and 1, so the ``micro`` values equal that column's accuracy.
    """
    pred = np.array(pred > threshold, dtype=float)
    scores = {}
    for average in ('micro', 'macro'):
        # zero_division=0 yields the same number as sklearn's default ('warn')
        # but avoids an UndefinedMetricWarning for every class that has no
        # positive predictions or labels in the current evaluation batch.
        kwargs = dict(y_true=target, y_pred=pred, average=average, zero_division=0)
        scores[f'{average}/precision'] = precision_score(**kwargs)
        scores[f'{average}/recall'] = recall_score(**kwargs)
        scores[f'{average}/f1'] = f1_score(**kwargs)
    return scores


# Per-class metrics.
def calculate_metrics_multiclass(pred, target, threshold=0.5, class_names=None):
    """Apply ``calculate_metrics`` column by column (one column per class).

    Args:
        pred: 2-D array of scores, one column per class.
        target: 2-D one-hot/multi-hot label array with the same shape.
        threshold: decision threshold forwarded to ``calculate_metrics``.
        class_names: names used as result keys, indexed by column. Defaults to
            ``trainloader.dataset.classes`` (the notebook's global loader).

    Returns:
        dict mapping each class name to its metrics dict.
    """
    if class_names is None:
        class_names = trainloader.dataset.classes
    return {
        class_names[i]: calculate_metrics(pred[:, i], target[:, i], threshold)
        for i in range(pred.shape[1])
    }

# Модель

In [None]:
def checkpoint_save(model, optimizer, save_path, epoch):
    """Write model state, optimizer state and the epoch number to ``save_path``.

    The checkpoint is a dict with keys ``'model'``, ``'optimizer'`` and
    ``'epoch'``, serialised with ``torch.save``.
    """
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
    }, save_path)


def checkpoint_load(model, optimizer, load_path, map_location='cpu'):
    """Restore model and optimizer state from a ``checkpoint_save`` file.

    Args:
        model: module whose state was saved; updated in place.
        optimizer: optimizer whose state was saved; updated in place.
        load_path: path previously passed to ``checkpoint_save``.
        map_location: device the stored tensors are loaded onto before being
            copied into ``model``/``optimizer``. Defaults to CPU so a checkpoint
            written on a GPU machine can also be opened on a CPU-only one.
            ``load_state_dict`` then copies the values into the existing
            parameters and optimizer state.

    Returns:
        The epoch number stored in the checkpoint.
    """
    checkpoint = torch.load(load_path, map_location=map_location)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    print('Loaded')
    return checkpoint['epoch']
    
    
# Directory that receives the periodic training checkpoints.
checkpoint_dir = 'checkpoints/'
os.makedirs(checkpoint_dir, exist_ok=True)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

Стандартная предобученная модель ResNeXt-50 (32x4d), у которой последний полносвязный слой `fc` заменён на новый классификатор:

In [None]:
class Resnext50(nn.Module):
    """Pretrained ResNeXt-50 (32x4d) with a freshly initialised classifier head.

    The backbone's final ``fc`` layer is replaced by ``Dropout(p=0.2)``
    followed by a ``Linear`` layer with ``n_classes`` outputs. ``forward``
    returns that layer's raw outputs; no sigmoid is applied.
    """

    def __init__(self, n_classes):
        super().__init__()
        backbone = models.resnext50_32x4d(pretrained=True)
        in_features = backbone.fc.in_features
        backbone.fc = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(in_features=in_features, out_features=n_classes),
        )
        self.base_model = backbone
        # Not used by forward(); kept so existing code referring to it still works.
        self.sigm = nn.Sigmoid()

    def forward(self, x):
        """Run the backbone and return the head's raw (pre-sigmoid) outputs."""
        logits = self.base_model(x)
        return logits

    
# Create the network with one output per class folder found by ImageFolder.
n_classes = len(trainloader.dataset.classes)
model = Resnext50(n_classes)

# Put the model in training mode and move its parameters to the chosen device.
model.train()
model.to(device)
print('Ok')

Обучение:

In [None]:
# Optimisation settings.
max_epoch_number = 64
learning_rate = 1e-4

# BCEWithLogitsLoss applies the sigmoid itself, so it takes raw model outputs.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Main training loop.
# Each epoch trains on trainloader. Every `test_freq` epochs, and on the last
# epoch, it evaluates per-class metrics on testloader. Every `save_freq`
# epochs, and on the last epoch, it writes a checkpoint.
epoch = 0
iteration = 0  # number of optimizer steps taken so far
test_freq = 16
save_freq = 16
# Derive the class count from the data instead of hard-coding it.
n_classes = len(trainloader.dataset.classes)

for epoch in tqdm(range(max_epoch_number)):
    # Without this the final epoch is skipped whenever
    # max_epoch_number - 1 is not a multiple of the frequencies.
    is_last_epoch = epoch == max_epoch_number - 1
    batch_losses = []
    model.train()
    for imgs, target_labels in tqdm(trainloader, leave=False):
        imgs = imgs.to(device)
        target_labels = target_labels.to(device)
        # BCEWithLogitsLoss needs float targets shaped like the logits, so the
        # integer class labels are turned into one-hot rows.
        targets = torch.nn.functional.one_hot(target_labels, num_classes=n_classes).float()

        optimizer.zero_grad()
        model_result = model(imgs)
        loss = criterion(model_result, targets)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
        iteration += 1

    if epoch % test_freq == 0 or is_last_epoch:
        model.eval()
        model_result = []
        targets = []
        with torch.no_grad():
            for imgs, batch_targets in tqdm(testloader, leave=False):
                imgs = imgs.to(device)
                batch_targets = torch.nn.functional.one_hot(batch_targets, num_classes=n_classes)
                # The model returns raw outputs; the metrics threshold probabilities.
                probs = torch.sigmoid(model(imgs))
                model_result.append(probs.cpu().numpy())
                targets.append(batch_targets.cpu().numpy())
        model_result = np.concatenate(model_result)
        targets = np.concatenate(targets)
        result_metrics = calculate_metrics_multiclass(model_result, targets)
        print("epoch:{:2d} iter:{:3d}".format(epoch, iteration))
        display(pd.DataFrame(result_metrics))

    loss_value = np.mean(batch_losses)
    print("epoch:{:2d} iter:{:3d} train: loss:{:.3f}".format(epoch, iteration, loss_value))
    if epoch % save_freq == 0 or is_last_epoch:
        save_path = f'checkpoints/ep_{epoch:02d}'
        checkpoint_save(model, optimizer, save_path, epoch)

Сохраняем модель:

In [None]:
# Move the weights to the CPU before saving so the file can be loaded without a GPU.
# This also leaves `model` itself on the CPU afterwards.
cpu_model = model.cpu()
torch.save(cpu_model.state_dict(), 'strawdisease.pt')