- Efficientnet_v2_s из коробки 
- LabelSmoothingCrossEntropy помогло повысить качество
- При подготовке изображений используется нормализация, схожая с ImageNet, но не идентичная ей: mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010) вместо значений ImageNet

In [None]:
import os
import numpy as np
from torchvision import datasets, models, transforms
from torch.utils import data
import pandas as pd
import torch, sys, os, pdb
import torch.nn as nn
from PIL import Image
import torch.optim as optim
from sklearn.metrics import balanced_accuracy_score, accuracy_score
from tqdm import tqdm
import copy
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torchvision.models import efficientnet_v2_s, resnet18

In [None]:
import torch
# Ask PyTorch's CUDA caching allocator to release unused cached memory
# before the rest of the notebook runs.
torch.cuda.empty_cache()

In [None]:
# Run on the first GPU when one is available; otherwise fall back to the CPU
# so the notebook still executes on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Optimisation hyper-parameters.
LR = 0.01          # learning rate handed to the SGD optimizer
N_EPOCH = 200      # number of epochs; also used as T_max of the cosine schedule
batch_size = 128   # mini-batch size for both data loaders

# Side length (in pixels) images are resized to by the transforms.
IMG_SIZE = 32

In [None]:
# Per-channel mean / std used by both pipelines after ToTensor().
# Alternative kept for reference (ImageNet statistics):
#   mean [0.485, 0.456, 0.406], std [0.229, 0.224, 0.225]
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: resize, random padded crop and horizontal flip as
# augmentation, then tensor conversion and normalisation.
_train_steps = [
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomCrop(IMG_SIZE, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
train_transforms = transforms.Compose(_train_steps)

# Validation pipeline: same preprocessing without the random augmentation.
_val_steps = [
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
val_transforms = transforms.Compose(_val_steps)

In [None]:
class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy against label-smoothed target distributions.

    The target class receives probability ``1 - smoothing``; the remaining
    ``smoothing`` mass is split evenly over the other ``classes - 1`` classes.
    With ``smoothing == 0`` this is the ordinary one-hot cross-entropy.

    Args:
        classes: number of classes, i.e. the size of ``pred`` along ``dim``.
            Must be at least 2.
        smoothing: probability mass moved away from the target class,
            ``0 <= smoothing < 1``.
        dim: dimension of ``pred`` that indexes the classes.
        weight: optional per-class weights (tensor or sequence of length
            ``classes``) multiplied into the log-probabilities.

    Raises:
        ValueError: if ``smoothing`` is outside ``[0, 1)`` or ``classes < 2``.
    """

    def __init__(self, classes, smoothing=0.0, dim=-1, weight=None):
        super().__init__()
        # Validate once at construction with a real exception; an ``assert``
        # would be stripped when Python runs with -O.
        if not 0 <= smoothing < 1:
            raise ValueError(f"smoothing must be in [0, 1), got {smoothing}")
        if classes < 2:
            # smoothing / (classes - 1) below would divide by zero.
            raise ValueError(f"classes must be >= 2, got {classes}")

        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

        if weight is not None:
            weight = torch.as_tensor(weight)
        # A buffer (rather than a plain attribute) follows module.to(device).
        self.register_buffer('weight', weight)

    def forward(self, pred, target):
        """Return the mean smoothed cross-entropy.

        Args:
            pred: raw logits with the classes along ``self.dim``.
            target: integer class indices shaped like ``pred`` with the
                class dimension removed.

        Returns:
            Scalar tensor: the loss averaged over all target positions.
        """
        log_probs = pred.log_softmax(dim=self.dim)

        if self.weight is not None:
            # Broadcast the per-class weights along the class dimension only.
            shape = [1] * log_probs.dim()
            shape[self.dim] = -1
            class_weight = self.weight.to(log_probs.device).reshape(shape)
            log_probs = log_probs * class_weight

        with torch.no_grad():
            off_value = self.smoothing / (self.cls - 1)
            true_dist = torch.full_like(log_probs, off_value)
            # Scatter along the class dimension (not a hard-coded dim 1) so
            # inputs of rank > 2 are handled consistently with log_softmax.
            true_dist.scatter_(self.dim, target.unsqueeze(self.dim), self.confidence)
        return torch.mean(torch.sum(-true_dist * log_probs, dim=self.dim))

In [None]:
def train_model(model, loss, optimizer, scheduler, num_epochs):
    """Train ``model`` and restore the weights of its best validation epoch.

    Every epoch runs a training pass over the module-level
    ``train_dataloader`` followed by an evaluation pass over the module-level
    ``val_dataloader``; batches are moved to the module-level ``device``.
    The state dict with the highest balanced accuracy on the validation
    loader (ties go to the later epoch) is loaded back before returning.

    Args:
        model: network to train; modified in place.
        loss: callable ``loss(preds, labels)`` returning a scalar tensor.
            The reported epoch loss assumes it is a per-batch mean.
        optimizer: optimizer updating the model parameters.
        scheduler: learning-rate scheduler, stepped once per epoch after the
            training phase.
        num_epochs: number of train/val epochs to run.

    Returns:
        Tuple ``(model, best_acc)`` where ``best_acc`` is the best balanced
        accuracy seen on the validation loader (0 if no epoch ran).
    """
    best_acc = 0
    # Snapshot the starting weights so the final load_state_dict() is valid
    # even when the loop body never executes (num_epochs == 0).
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}:'.format(epoch, num_epochs - 1), flush=True)
        gt = []
        net_outputs = []

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                dataloader = train_dataloader
                model.train()  # Set model to training mode
            else:
                dataloader = val_dataloader
                model.eval()   # Set model to evaluate mode

            # Sample-weighted running totals: the last batch may be smaller,
            # so averaging per-batch means would bias the epoch statistics.
            running_loss = 0.
            running_correct = 0
            n_seen = 0

            # Iterate over data.
            for inputs, labels in tqdm(dataloader):
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # forward and backward
                with torch.set_grad_enabled(is_train):
                    preds = model(inputs)
                    loss_value = loss(preds, labels)
                    preds_class = preds.argmax(dim=1)

                    # backward + optimize only if in training phase
                    if is_train:
                        loss_value.backward()
                        optimizer.step()
                    else:
                        gt.extend(labels.cpu().numpy())
                        net_outputs.extend(preds_class.cpu().numpy())

                # statistics (kept as Python numbers, not device tensors)
                batch_n = labels.size(0)
                running_loss += loss_value.item() * batch_n
                running_correct += (preds_class == labels).sum().item()
                n_seen += batch_n

            denom = max(n_seen, 1)  # avoid dividing by zero on an empty loader
            epoch_loss = running_loss / denom
            epoch_acc = running_correct / denom

            if is_train:
                print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc), flush=True)
                scheduler.step()
            else:
                bacc = balanced_accuracy_score(gt, net_outputs)
                print('{} Loss: {:.4f}, balanced_accuracy_score: {:.4f}, accuracy_score: {:.4f}'.format(phase, epoch_loss, bacc, accuracy_score(gt, net_outputs)), flush=True)

                # Keep a copy of the best-scoring weights seen so far.
                if bacc >= best_acc:
                    best_acc = bacc
                    best_model_wts = copy.deepcopy(model.state_dict())

    model.load_state_dict(best_model_wts)
    return model, best_acc

In [None]:
# CIFAR-10 splits (downloaded to ./data when missing): train=True is used for
# training, train=False for validation.
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transforms)
val_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=val_transforms)

# Shuffle only the training data; evaluation metrics do not depend on sample
# order, so the validation loader keeps a fixed, reproducible order.
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# Number of target classes, taken from the dataset's class list.
num_classes = len(train_dataset.classes)

In [None]:
# EfficientNetV2-S initialised with torchvision's default weights.
net = efficientnet_v2_s(weights='DEFAULT')
# Replace the layer at classifier[1] with a new linear layer that keeps the
# same input width but outputs one logit per dataset class.
num_ftrs = net.classifier[1].in_features
net.classifier[1] = nn.Linear(num_ftrs, num_classes)

# Move the whole network to the selected device.
net = net.to(device)

In [None]:
# net = resnet18(weights='DEFAULT')
# num_ftrs = net.fc.in_features
# net.fc = nn.Linear(num_ftrs, num_classes)

# net = net.to(device)

In [None]:
# Plain cross-entropy alternative, kept for reference:
# loss = torch.nn.CrossEntropyLoss()
# Label-smoothed cross-entropy: 0.3 of the probability mass is moved off the target class.
loss = LabelSmoothingCrossEntropy(classes=num_classes, smoothing=0.3, dim=-1, weight = None)

# SGD with momentum over all network parameters.
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)

# Cosine learning-rate schedule spanning the full run (T_max equals the epoch count).
# NOTE(review): `verbose` appears to be deprecated in recent PyTorch releases — confirm
# against the installed version before upgrading.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=N_EPOCH, verbose=True)

# Train; `model` is the same network with the best validation weights loaded back.
model, best_acc = train_model(net, loss, optimizer, scheduler, num_epochs=N_EPOCH)

In [None]:
# Final accuracy of the trained model on the validation loader.
# Set eval mode explicitly instead of relying on whatever mode the training
# loop happened to leave the model in.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    # Unpack the batch directly: a loop variable named `data` would rebind the
    # module-level `data` imported from torch.utils.
    for images, labels in val_dataloader:
        images = images.to(device)
        labels = labels.to(device)
        # calculate outputs by running images through the network
        outputs = model(images)
        # the class with the highest score is the prediction
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# True division keeps the fractional part that `//` used to truncate.
print(f'Accuracy of test images: {100 * correct / total:.2f} %')

In [None]:
# TODO: also save the optimizer state
# The checkpoint currently holds only the model weights under the 'model' key.
checkpoint = {
    'model': model.state_dict()}
torch.save(checkpoint, 'efficientnet_best_model.pth')