In [1]:
import numpy as np
import torch
from torch import nn
from torchvision import transforms
import cv2 as cv
from torch.utils.data import Dataset, DataLoader
from typing import *
import sys
import os
import random
from utils import Accumulator, train
from net import resnet18


# Base directory against which the dataset code below resolves the "train"
# and "test" image folders.
# sys.path[0] is the directory of the script that started the interpreter (or
# an empty string when that is unavailable).
# NOTE(review): this presumably equals the notebook's working directory here -
# confirm before running the notebook from a different location.
runtime_path = sys.path[0]
runtime_path


'/home/wakinghours/programming/LiMu-DeepLearning/kaggle/cifar10'

In [2]:
def try_all_GPUS() -> List[torch.device]:
    """Return every visible CUDA device, or ``[cpu]`` when there is none.

    Returns:
        A non-empty list of ``torch.device`` objects: one ``cuda:<i>`` entry per
        GPU reported by ``torch.cuda.device_count()``, otherwise a single-element
        list holding the CPU device.
    """
    devices = [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())]
    # The CPU fallback is wrapped in a list so the result always matches the
    # annotated return type and can be iterated / indexed like the GPU case.
    return devices if devices else [torch.device("cpu")]


devices = try_all_GPUS()


In [3]:
def path_join(*args):
    """Join path components; a thin alias for ``os.path.join``."""
    return os.path.join(*args)


class CIFAR10_dataset(Dataset):
    """Kaggle-style CIFAR-10 images on disk, exposed as a map-style ``Dataset``.

    ``type_dataset`` selects the mode (the ``"vaild"`` spelling is kept because
    the mode string and the split file names depend on it):

    * ``"train"``   - re-shuffles the files under ``<runtime_path>/train``,
      rewrites ``train.txt`` / ``train_vaild.txt`` and serves the ``train.txt``
      part with augmentation.
    * ``"vaild"``   - serves the paths listed in ``train_vaild.txt`` without
      augmentation; the file must already exist, i.e. a ``"train"`` instance
      has to be created first.
    * anything else - serves every file under ``<runtime_path>/test`` and yields
      ``(image, file_stem)`` instead of ``(image, class_index)``.

    Args:
        type_dataset: mode selector, see above.
        vaild_rate: fraction of the training files held out for validation;
            only used by the ``"train"`` mode when it regenerates the split.
    """

    def __init__(self, type_dataset: str = "train", vaild_rate=0.1) -> None:
        super().__init__()
        self.type_dataset = type_dataset
        self.vaild_rate = vaild_rate
        # Single anchor for the label CSV and the split files, so they are read
        # from the same directory they are written to regardless of the cwd.
        self.data_root = runtime_path
        self.transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(42),
            transforms.RandomResizedCrop(32, (0.6, 1.0), ratio=(1.0, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.Normalize([0.4914, 0.4822, 0.4465],  # per-channel mean
                                 [0.2023, 0.1994, 0.2010])  # per-channel std
        ])
        self.transform_vaild = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465],  # per-channel mean
                                 [0.2023, 0.1994, 0.2010])  # per-channel std
        ])
        self.labels_dict = self.parse_csv2label()
        self.classes = ['airplane', 'automobile', 'bird', 'cat',
                        'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

        self.root_path = os.path.join(self.data_root, "train")
        if "train" == self.type_dataset:
            # The split files are regenerated only in "train" mode.
            self.shuffle_train_vaild()
            self.file_path_list = self._read_path_list("train.txt")
        elif "vaild" == self.type_dataset:
            self.file_path_list = self._read_path_list("train_vaild.txt")
        else:  # test
            self.root_path = os.path.join(self.data_root, "test")
            self.file_path_list = [os.path.join(self.root_path, file_name)
                                   for file_name in os.listdir(self.root_path)]

    def _read_path_list(self, list_name):
        """Return the raw lines of a split file stored under ``data_root``."""
        with open(os.path.join(self.data_root, list_name), "r") as f:
            return f.readlines()

    def __getitem__(self, index):
        """Load one image and return ``(tensor, class_index)`` or, in test mode,
        ``(tensor, file_stem)``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        file_path = self.file_path_list[index].strip()
        # The stem (text before the first dot) is the key into the label table.
        file_name = os.path.basename(file_path).split(".")[0]
        img = cv.imread(file_path)
        if img is None:
            # cv.imread signals a missing/undecodable file by returning None.
            raise FileNotFoundError(f"cannot read image: {file_path!r}")
        # NOTE(review): cv.imread yields BGR channel order, while the Normalize
        # constants above look like the usual RGB CIFAR-10 statistics. Left
        # unchanged so existing checkpoints stay compatible - confirm intent.
        if "train" == self.type_dataset:
            X = self.transform_train(img)
            return X, self.classes.index(self.labels_dict[file_name])
        elif "vaild" == self.type_dataset:
            X = self.transform_vaild(img)
            return X, self.classes.index(self.labels_dict[file_name])
        else:  # test
            X = self.transform_vaild(img)
            return X, file_name

    def parse_csv2label(self):
        """Read ``trainLabels.csv`` into an ``{id: label}`` dict.

        The first line is treated as a header; lines with fewer than two
        comma-separated fields (e.g. trailing blank lines) are skipped.
        """
        labels = {}
        with open(os.path.join(self.data_root, "trainLabels.csv"), "r") as f:
            next(f, None)  # skip the header row
            for line in f:
                fields = line.strip().split(',')
                if len(fields) < 2:
                    continue
                labels[fields[0]] = fields[1]
        return labels

    def shuffle_train_vaild(self):
        """Randomly split the files in ``root_path`` and write both path lists.

        The first ``int(n * (1 - vaild_rate))`` shuffled files go to
        ``train.txt`` and the rest to ``train_vaild.txt``, one full path per
        line. Write failures are reported on stdout and otherwise ignored
        (best effort), so a previously written split may then be reused.
        """
        # Sorted first so the split depends only on the RNG state, not on the
        # arbitrary order in which the OS lists the directory.
        file_name_list = sorted(os.listdir(self.root_path))
        split_at = int(len(file_name_list) * (1 - self.vaild_rate))

        try:
            random.shuffle(file_name_list)
            parts = (
                ("train.txt", file_name_list[:split_at]),
                ("train_vaild.txt", file_name_list[split_at:]),
            )
            for list_name, names in parts:
                with open(os.path.join(self.data_root, list_name), "w") as writer:
                    writer.write(
                        "\n".join(os.path.join(self.root_path, file_name)
                                  for file_name in names)
                    )
        except Exception as e:
            print("error: ", e)

    def __len__(self):
        """Number of image paths served by this instance."""
        return len(self.file_path_list)




def load_CIFAR10_iter(batch_size=64, num_workers=28):
    """Build the shuffled training and validation ``DataLoader`` pair.

    The "train" dataset is created first on purpose: constructing it rewrites
    the split files that the "vaild" dataset reads afterwards.

    Args:
        batch_size: samples per batch for both loaders.
        num_workers: worker processes used by each loader.

    Returns:
        ``(train_loader, vaild_loader)``.
    """
    loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=num_workers)
    train_loader = DataLoader(CIFAR10_dataset("train"), **loader_options)
    vaild_loader = DataLoader(CIFAR10_dataset("vaild"), **loader_options)
    return train_loader, vaild_loader




In [4]:
# Model under training: the project-local ResNet-18.
net = resnet18()
import torchvision
# Alternative torchvision backbones that can be swapped in instead:
# net = torchvision.models.resnet34(pretrained=True)
# net = torchvision.models.resnet101(pretrained=True)

# Hyper-parameters and loss consumed by the training cell.
batch_size = 512
epoch = 200
lr = 0.1
loss_fn = nn.CrossEntropyLoss()



In [5]:

# Build both loaders. The second one is backed by CIFAR10_dataset("vaild"),
# i.e. it iterates the held-out validation split despite the `test_iter` name.
train_iter, test_iter = load_CIFAR10_iter(batch_size)

# NOTE(review): `train` comes from the project-local `utils` module, whose
# signature is not visible here. The meaning of the positional arguments
# 50, 10, 0.8, 5e-3 and of the trailing .pth path (presumably a checkpoint to
# resume from) must be confirmed against utils.train; consider passing them by
# keyword once confirmed.
train(
    net,
    loss_fn,
    train_iter, test_iter, 
    lr, epoch, 
    50, 10, 0.8, 5e-3,
    "logs/epoch30_testacc0.61_loss1.0_acc0.64.pth"
)



train on:  [device(type='cuda', index=0), device(type='cuda', index=1)]
9 test acc: 0.5320591509342194 train loss: 1.185805307193236 train acc: 0.5836168012835763
