# import

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
from  PIL import Image
import os

# 定义参数

In [None]:
# Where trained model checkpoints are written.
INPUT_PATH = '../input/MyModel'

# All competition files live under one folder; derive the individual paths from it.
_COMPETITION_DIR = '../input/cassava-leaf-disease-classification'
TRAIN_CSV_PATH = _COMPETITION_DIR + '/train.csv'
TRAIN_IMAGE_PATH = _COMPETITION_DIR + '/train_images/'
TEST_IMAGE_PATH = _COMPETITION_DIR + '/test_images/'

# One torch.device per visible CUDA GPU (an empty list when none is available).
DEVICES = [torch.device('cuda', gpu_id) for gpu_id in range(torch.cuda.device_count())]

# Training hyper-parameters.
NUM_EPOCHS = 20
BATCH_SIZE = 16

# tr1.损失函数
Losses used by the 1st-place solution:
1. B4: sigmoid focal cross-entropy loss with label smoothing (well suited to class-imbalance problems)
2. ResNeXt50: cross-entropy loss

In [None]:
def sigmoid_focal_cross_entropy(y_hat, y_true, alpha=0.25, gamma=2.0,
                                smooth_factor=0.1, from_logits=False):
    """Focal binary cross-entropy with label smoothing, summed over the last axis.

    Args:
        y_hat: Predictions as a tensor or array-like. Interpreted as
            probabilities in [0, 1] unless ``from_logits`` is true.
        y_true: Targets with the same shape as ``y_hat`` (tensor or
            array-like). The caller's object is never modified.
        alpha: Weight of the positive class; the negative class gets
            ``1 - alpha``.
        gamma: Exponent of the modulating factor ``(1 - p_t) ** gamma``.
        smooth_factor: Label-smoothing strength. Targets become
            ``y * (1 - smooth_factor) + smooth_factor / y.shape[1]``.
        from_logits: When true, ``y_hat`` holds raw scores and the sigmoid is
            applied inside the loss.

    Returns:
        A tensor holding the per-sample loss (element-wise loss summed over
        the last dimension).
    """
    y_hat = torch.as_tensor(y_hat)
    if not y_hat.is_floating_point():
        y_hat = y_hat.float()
    # Cast targets to the prediction dtype/device so integer labels work too.
    y_true = torch.as_tensor(y_true, dtype=y_hat.dtype, device=y_hat.device)

    # Out-of-place smoothing: an in-place update would silently overwrite the
    # caller's label tensor (and fails outright on integer tensors).
    y_true = y_true * (1 - smooth_factor) + smooth_factor / y_true.shape[1]

    if from_logits:
        cross_entropy = F.binary_cross_entropy_with_logits(y_hat, y_true, reduction='none')
        y_prob = torch.sigmoid(y_hat)
    else:
        cross_entropy = F.binary_cross_entropy(y_hat, y_true, reduction='none')
        y_prob = y_hat

    # p_t: probability mass the prediction puts on the (smoothed) target.
    p_t = y_true * y_prob + (1 - y_true) * (1 - y_prob)
    alpha_t = y_true * alpha + (1 - y_true) * (1 - alpha)
    modulating_factor = (1.0 - p_t).pow(gamma)

    return torch.sum(alpha_t * modulating_factor * cross_entropy, dim=-1)


# tr2.learning_rate

In [None]:
def lr_tune(epoch, num_epochs=NUM_EPOCHS):
    """Return the learning rate for ``epoch``: power warm-up, then cosine decay.

    The schedule rises from ``lr_start`` to ``lr_max`` over the first
    ``lr_warmup_epoch`` epochs (exponent 2.5), optionally holds ``lr_max`` for
    ``lr_sustain_epoch`` epochs, then follows half a cosine down to
    ``lr_final``, which it reaches at epoch ``num_epochs - 1``.

    Args:
        epoch: Zero-based epoch index.
        num_epochs: Total number of epochs the schedule is stretched over.

    Returns:
        The learning rate as a plain Python float in every branch.
    """
    lr_start = 1e-6
    lr_max = 2e-4
    lr_final = 1e-6
    lr_warmup_epoch = 4
    lr_sustain_epoch = 0
    lr_decay_epoch = num_epochs - lr_warmup_epoch - lr_sustain_epoch - 1

    if epoch <= lr_warmup_epoch:
        return lr_start + (lr_max - lr_start) * (epoch / lr_warmup_epoch) ** 2.5
    if epoch < lr_warmup_epoch + lr_sustain_epoch:
        return lr_max
    if lr_decay_epoch <= 0:
        # No room left for a decay phase; avoid dividing by zero.
        return lr_final

    epoch_diff = epoch - lr_warmup_epoch - lr_sustain_epoch
    # Clamp so epochs past the end stay at lr_final instead of climbing back
    # up the next half of the cosine.
    progress = min(epoch_diff / lr_decay_epoch, 1.0)
    # math.cos keeps the result a Python float (no tensor/NumPy round trip).
    decay_factor = (math.cos(progress * math.pi) + 1) / 2
    return lr_final + (lr_max - lr_final) * decay_factor
# Preview the learning-rate schedule: one point per epoch.
x = list(range(NUM_EPOCHS))
y = list(map(lr_tune, x))
plt.plot(x, y)

# tr3.albumentations

# tr4.TTA

# 定义数据集

In [None]:
class MyCassavaLeafDataset(Dataset):
    """Image dataset backed by a CSV of (file name, label) rows.

    In ``'train'`` and ``'valid'`` mode the CSV rows are split by position:
    every ``stride``-th row goes to validation and the rest to training.
    In ``'test'`` mode no CSV is needed; the files found in ``images_path``
    are served without labels.
    """

    @staticmethod
    def generate_index(num_total, ratio):
        """Split ``range(num_total)`` into training and validation positions.

        Every ``stride``-th position (``stride = round(ratio * 10)``, at
        least 1) is a validation position; all others are training positions.

        Returns:
            ``(train_index, valid_index)``: a list of ints and an integer
            NumPy array. The two never overlap and together cover every
            position.
        """
        # The stride must be a whole number: a float step yields float (and
        # possibly fractional) positions that cannot be used for row lookup.
        stride = max(1, int(round(ratio * 10)))
        valid_index = np.arange(0, num_total, stride)
        # Modulo test instead of a per-item scan of the validation array.
        train_index = [i for i in range(num_total) if i % stride != 0]
        return train_index, valid_index

    def __init__(self, csv_path=None, images_path=None, transform=None, mode='train', train_ratio=0.5):
        """Load the file names (and labels) served in the requested mode.

        Args:
            csv_path: CSV whose first column is the image file name and whose
                second column is the label. Required for 'train' / 'valid'.
            images_path: Folder containing the image files.
            transform: Optional callable invoked as ``transform(image=array)``
                and returning a mapping with an ``"image"`` entry. When
                ``None`` the raw image array is returned.
            mode: One of 'train', 'valid' or 'test'.
            train_ratio: Passed to ``generate_index`` to derive the split.

        Raises:
            ValueError: If ``mode`` is unknown, or ``csv_path`` is missing in
                'train' / 'valid' mode.
        """
        super().__init__()
        self.transform = transform
        self.mode = mode
        self.images_path = images_path

        if mode in ('train', 'valid'):
            if csv_path is None:
                raise ValueError(f"csv_path is required when mode is {mode!r}")
            self.data_info = pd.read_csv(csv_path)
            self.data_len = self.data_info.shape[0]
            train_index, valid_index = MyCassavaLeafDataset.generate_index(self.data_len, train_ratio)
            selected = train_index if mode == 'train' else valid_index
            self.image_arr = np.asarray(self.data_info.iloc[selected, 0])
            self.label_arr = np.asarray(self.data_info.iloc[selected, 1])
        elif mode == 'test':
            # No labels at test time: serve every file in the folder, sorted
            # so the order is reproducible.
            self.data_info = None
            self.image_arr = np.asarray(sorted(os.listdir(images_path)))
            self.label_arr = None
            self.data_len = len(self.image_arr)
        else:
            raise ValueError(f"mode must be 'train', 'valid' or 'test', got {mode!r}")

        self.real_len = len(self.image_arr)

    def __getitem__(self, index):
        """Return ``(image, label)``, or only ``image`` in 'test' mode."""
        single_image_name = self.image_arr[index]
        # The context manager closes the file as soon as the pixels are copied.
        with Image.open(os.path.join(self.images_path, single_image_name)) as pil_image:
            image = np.array(pil_image)
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        if self.mode == 'test':
            return image
        return image, self.label_arr[index]

    def __len__(self):
        """Return the number of samples served in the current mode."""
        return self.real_len

# 定义模型

# 定义训练所需函数

In [None]:
class MyTrainer:
    """Bundles a model, optimizer and loss with a simple train/validate loop."""

    @staticmethod
    def accurate_count(y_hat, y_true):
        """Return, as a float, how many rows of ``y_hat`` have their arg-max equal to ``y_true``."""
        predictions = y_hat.argmax(dim=1).type(y_true.dtype)
        # One vectorised comparison instead of a Python loop over the batch.
        return float((predictions == y_true.reshape(predictions.shape)).sum().item())

    @staticmethod
    def calc_valid_acc(model, valid_dataloader):
        """Return the accuracy of ``model`` over ``valid_dataloader``.

        The model is switched to eval mode and inputs are moved to the device
        of its first parameter. Returns 0.0 when the loader yields no samples.
        """
        model.eval()
        device = next(iter(model.parameters())).device
        test_num = 0
        test_acc_num = 0
        # Gradients are not needed for evaluation; skipping them saves memory.
        with torch.no_grad():
            for x, y_true in valid_dataloader:
                if isinstance(x, list):
                    x = [x_1.to(device) for x_1 in x]
                else:
                    x = x.to(device)
                y_true = y_true.to(device)
                test_num += y_true.shape[0]
                test_acc_num += MyTrainer.accurate_count(model(x), y_true)
        if test_num == 0:
            return 0.0
        return test_acc_num / test_num

    def __init__(self, optimizer, model, criterion, devices=DEVICES):
        """Store the training components.

        Args:
            optimizer: Optimizer updating ``model``'s parameters.
            model: The network to train.
            criterion: Callable ``criterion(y_hat, y_true)`` returning a loss
                tensor; it is summed before back-propagation.
            devices: Devices to train on. The first one receives the inputs;
                an empty list means "stay on the model's current device".
        """
        self.optimizer = optimizer
        self.model = model
        self.criterion = criterion
        self.devices = devices

    def _primary_device(self):
        """Return the device inputs are sent to (first configured device, else the model's)."""
        if self.devices:
            return self.devices[0]
        return next(iter(self.model.parameters())).device

    def train_epoch(self, train_dataloader):
        """Run one optimisation pass over ``train_dataloader``.

        Returns:
            ``(mean_loss, accuracy)`` over the samples seen; ``(0.0, 0.0)``
            when the loader yields nothing.
        """
        self.model.train()
        device = self._primary_device()
        total_loss = 0
        train_num = 0
        train_acc_num = 0
        for x, y_true in train_dataloader:
            x, y_true = x.to(device), y_true.to(device)
            self.optimizer.zero_grad()
            y_hat = self.model(x)
            batch_loss = self.criterion(y_hat, y_true).sum()
            batch_loss.backward()
            self.optimizer.step()
            # Detach before accumulating so each batch's autograd graph can be
            # freed instead of being kept alive by the running total.
            total_loss += batch_loss.detach()
            train_num += y_true.shape[0]
            train_acc_num += MyTrainer.accurate_count(y_hat, y_true)
        if train_num == 0:
            return 0.0, 0.0
        return total_loss / train_num, train_acc_num / train_num

    def train(self, train_dataloader, valid_dataloader, num_epochs=NUM_EPOCHS):
        """Train for ``num_epochs`` epochs, checkpointing the best validation accuracy.

        The model's ``state_dict`` is saved to ``INPUT_PATH/best_model.pth``
        whenever the validation accuracy improves on the best seen so far.
        NOTE(review): when the model is wrapped in ``nn.DataParallel`` the
        saved keys carry that wrapper's ``module.`` prefix — confirm the
        loading code expects this.
        """
        best_valid_acc = 0
        device = self._primary_device()
        # Wrap only once so repeated train() calls do not nest DataParallel.
        if self.devices and not isinstance(self.model, nn.DataParallel):
            self.model = nn.DataParallel(self.model, device_ids=self.devices)
        self.model = self.model.to(device)
        for epoch in range(num_epochs):
            train_loss, train_acc = self.train_epoch(train_dataloader)
            valid_acc = MyTrainer.calc_valid_acc(self.model, valid_dataloader)
            if valid_acc > best_valid_acc:
                # Remember the new best, otherwise every epoch would overwrite
                # the checkpoint.
                best_valid_acc = valid_acc
                torch.save(self.model.state_dict(), os.path.join(INPUT_PATH, 'best_model.pth'))
            print(f'epoch{epoch}:train_loss:{float(train_loss)}, train_acc:{train_acc}, valid_acc:{valid_acc}')

    def train_fine_tuning(self, train_dataloader, valid_dataloader, learning_rate, num_epochs=NUM_EPOCHS):
        """Placeholder for a fine-tuning loop; not implemented yet."""
        pass