In [None]:
import os
import random
import matplotlib.pyplot as plt
from torch.nn import functional as F
from PIL import Image
from torch.utils.data import Dataset
import torch
from tqdm import tqdm
import sys
import numpy as np
import math
import json
from sklearn.model_selection import KFold


# build warmup lr
def create_lr_scheduler(optimizer,
                        num_step: int,
                        epochs: int,
                        warmup=True,
                        warmup_epochs=1,
                        warmup_factor=1e-3,
                        end_factor=1e-6):
    """Build a per-step LR schedule: linear warmup followed by cosine decay.

    The returned scheduler is meant to be stepped once per optimizer step
    (i.e. ``num_step`` times per epoch), not once per epoch.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        num_step: Number of optimizer steps in one epoch (must be > 0).
        epochs: Total number of training epochs (must be > 0).
        warmup: Whether to run the linear warmup phase. Any falsy value
            disables it.
        warmup_epochs: Length of the warmup phase, in epochs.
        warmup_factor: LR multiplier at step 0 of the warmup phase; it grows
            linearly to 1.0 at the end of warmup.
        end_factor: LR multiplier reached at the end of the cosine decay.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` wrapping the schedule.

    Raises:
        AssertionError: If ``num_step`` or ``epochs`` is not positive.
    """
    assert num_step > 0 and epochs > 0
    if not warmup:
        warmup_epochs = 0

    # Both phase lengths are loop-invariant, so compute them once.
    warmup_steps = warmup_epochs * num_step
    cosine_steps = (epochs - warmup_epochs) * num_step

    def f(x):
        # Skip the warmup branch entirely when it has zero length; otherwise
        # the ratio below would be 0 / 0 on the very first call.
        if warmup_steps > 0 and x <= warmup_steps:
            alpha = float(x) / warmup_steps
            return warmup_factor * (1 - alpha) + alpha
        if cosine_steps <= 0:
            # Warmup covers the whole run, so there is no decay phase to
            # evaluate: hold the base learning rate instead of dividing by 0.
            return 1.0
        current_step = x - warmup_steps
        return ((1 + math.cos(current_step * math.pi / cosine_steps)) / 2) * (1 - end_factor) + end_factor

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)

def train_one_epoch_classify(model, optimizer, data_loader, device, epoch, lr_scheduler, weight_decay):
    """Train ``model`` for one epoch on a classification task.

    Each batch loss is ``CrossEntropyLoss(pred, labels)`` plus
    ``weight_decay`` times the sum of the (non-squared) L2 norms of all model
    parameters. The optimizer and the LR scheduler are both stepped once per
    batch.

    Args:
        model: Network called as ``model(images)``; its output is arg-maxed
            over dimension 1 to obtain predicted classes.
        optimizer: Optimizer updating ``model``'s parameters.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to before the forward pass.
        epoch: Epoch index, used only in the progress-bar text.
        lr_scheduler: Scheduler stepped after every optimizer step.
        weight_decay: Coefficient of the norm penalty. A falsy value
            (``0`` / ``None``) disables the penalty.

    Returns:
        Tuple ``(mean loss per batch, accuracy, lr)`` where ``lr`` is the
        learning rate of the first param group used for the last batch.

    Note:
        The process is terminated with ``sys.exit(1)`` if a batch produces a
        non-finite loss.
    """
    loss_function = torch.nn.CrossEntropyLoss()
    model.train()
    accu_loss = torch.zeros(1).to(device)  # running sum of batch losses
    accu_num = torch.zeros(1).to(device)  # running count of correct predictions

    sample_num = 0
    # Read the LR up front so the return value is defined for an empty loader.
    lr = optimizer.param_groups[0]["lr"]
    data_loader = tqdm(data_loader, file=sys.stdout)
    step = 0
    for step, (images, labels) in enumerate(data_loader):
        labels = labels.to(device)  # moved once, reused for accuracy and loss
        optimizer.zero_grad()
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels).sum()

        loss = loss_function(pred, labels)

        # Norm penalty, honouring the caller-supplied coefficient (it used to
        # be overwritten by a hard-coded 0.0001 inside this function).
        if weight_decay:
            l2_regularization = torch.tensor(0.).to(device)
            for param in model.parameters():
                l2_regularization = l2_regularization + torch.norm(param, p=2)
            loss = loss + weight_decay * l2_regularization

        # Bail out before backward() so a NaN/inf loss never reaches the
        # gradients or the running statistics.
        if not torch.isfinite(loss):
            print('WARNING: 非有限损失值，结束训练 ', loss)
            sys.exit(1)

        loss.backward()
        accu_loss += loss.detach()
        lr = optimizer.param_groups[0]["lr"]
        data_loader.desc = "[train epoch {}] loss: {:.8f}, acc: {:.8f}, lr: {:.5f}".format(
            epoch,
            accu_loss.item() / (step + 1),
            accu_num.item() / sample_num,
            lr
        )

        optimizer.step()
        # Advance the per-step learning-rate schedule.
        lr_scheduler.step()

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    return accu_loss.item() / (step + 1), accu_num.item() / max(sample_num, 1), lr

@torch.no_grad()
def evaluate_classify(model, data_loader, device, epoch):
    """Evaluate ``model`` on a classification data loader without gradients.

    Args:
        model: Network called as ``model(images)``; its output is treated as
            per-class scores along dimension 1.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to before the forward pass.
        epoch: Epoch index, used only in the progress-bar text.

    Returns:
        Tuple ``(mean loss per batch, accuracy, all_pred, all_prob)`` where
        ``all_pred`` is a flat list of predicted class indices and
        ``all_prob`` is a list with one softmax probability row per sample,
        both in loader order.
    """
    loss_function = torch.nn.CrossEntropyLoss()
    model.eval()
    all_pred = []
    all_prob = []
    accu_num = torch.zeros(1).to(device)  # running count of correct predictions
    accu_loss = torch.zeros(1).to(device)  # running sum of batch losses
    sample_num = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    step = 0

    for step, (images, labels) in enumerate(data_loader):
        sample_num += images.shape[0]
        labels = labels.to(device)  # moved once, reused for accuracy and loss
        pred = model(images.to(device))

        # Explicit dim: the implicit-dim form of softmax is deprecated, and
        # dim=1 matches the arg-max taken over the same axis below.
        probs = F.softmax(pred, dim=1)
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels).sum()

        # extend() is linear overall; rebuilding the list with `+` per batch
        # was quadratic in the number of samples.
        all_pred.extend(pred_classes.cpu().tolist())
        all_prob.extend(probs.cpu().tolist())

        accu_loss += loss_function(pred, labels)

        data_loader.desc = "[valid epoch {}] loss: {:.8f}, acc: {:.8f}".format(
            epoch,
            accu_loss.item() / (step + 1),
            accu_num.item() / sample_num
        )

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    return accu_loss.item() / (step + 1), accu_num.item() / max(sample_num, 1), all_pred, all_prob


class MyDataSet(Dataset):
    """Map-style dataset pairing image file paths with class labels.

    ``images_path[i]`` is opened with PIL on access and paired with
    ``images_class[i]``; an optional ``transform`` is applied to the image.
    """

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        # One sample per stored path.
        return len(self.images_path)

    def __getitem__(self, item):
        """Return ``(image, label)`` for index ``item``."""
        picture = Image.open(self.images_path[item])
        # Anything that is not already RGB is converted to RGB.
        picture = picture if picture.mode == 'RGB' else picture.convert("RGB")
        target = self.images_class[item]

        if self.transform is None:
            return picture, target
        return self.transform(picture), target

    @staticmethod
    def collate_fn(batch):
        """Stack a list of ``(image_tensor, label)`` samples into a batch.

        Returns the images stacked along a new leading dimension and the
        labels converted to a single tensor.
        """
        # Adapted from PyTorch's default collate:
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = zip(*batch)
        return torch.stack(images, dim=0), torch.as_tensor(labels)


def kfold_split_data(root: str, k: int):
    """Scan an image-folder dataset and split it into ``k`` shuffled folds.

    Every immediate sub-directory of ``root`` is treated as one class; class
    indices follow the sorted directory names. The index-to-name mapping is
    written to ``class_indices.json`` (relative path, so it lands in the
    current working directory). Only files ending in ``.jpg``, ``.JPG``,
    ``.png`` or ``.PNG`` are collected.

    Args:
        root: Dataset root directory containing one sub-directory per class.
        k: Number of folds passed to ``KFold(n_splits=k)``.

    Returns:
        A list with one 4-tuple per fold:
        ``(train_paths, train_labels, val_paths, val_labels)``.

    Raises:
        AssertionError: If ``root`` does not exist or no image is found.
    """
    random.seed(1)
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # One class per sub-directory, indexed in sorted-name order.
    class_names = sorted(
        entry for entry in os.listdir(root)
        if os.path.isdir(os.path.join(root, entry))
    )
    name_to_index = {name: idx for idx, name in enumerate(class_names)}
    index_to_name = {idx: name for name, idx in name_to_index.items()}
    json_str = json.dumps(index_to_name, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    supported = (".jpg", ".JPG", ".png", ".PNG")
    all_images = []
    all_labels = []
    for name in class_names:
        class_dir = os.path.join(root, name)
        files = sorted(
            os.path.join(root, name, fname)
            for fname in os.listdir(class_dir)
            if os.path.splitext(fname)[-1] in supported
        )
        all_images += files
        all_labels += [name_to_index[name]] * len(files)

    total = len(all_images)
    print("{} images were found in the dataset.".format(total))
    print("{} images for training.".format(total))
    assert total > 0, "number of images must be greater than 0."

    def pick(indices):
        # Paths and labels at the given sample positions, as two lists.
        return [all_images[i] for i in indices], [all_labels[i] for i in indices]

    folds = KFold(n_splits=k, shuffle=True, random_state=1)
    # Concatenating the two pairs yields (train_x, train_y, val_x, val_y).
    return [pick(train_index) + pick(val_index)
            for train_index, val_index in folds.split(all_images)]
