# CIFAR-10 - Object Recognition in Images
https://www.kaggle.com/competitions/cifar-10/overview

## 引用相关的库

In [None]:
import torch
import tarfile
from torch import nn
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import py7zr
import os
from PIL import Image

## 数据预处理：解压训练图片，并将类别标签映射为整数编号

In [None]:
# Extract the training images into ./input. The context manager closes the
# archive even if extraction fails part-way (the original closed it manually,
# leaking the handle on error).
# NOTE(review): assumes train.7z contains a top-level "train/" folder so the
# images land in ./input/train — confirm against the archive layout.
with py7zr.SevenZipFile(r'../input/cifar-10/train.7z', mode='r') as archive:
    archive.extractall(path='./input')

# Build the label <-> integer-index mappings from the training labels.
# Classes are numbered in order of first appearance in the CSV (unique()
# preserves appearance order), so the mapping is stable for a given file.
df = pd.read_csv('../input/cifar-10/trainLabels.csv')
pic_types = df['label'].unique()
print(len(pic_types))
index_list = list(range(len(pic_types)))
class_to_num = dict(zip(pic_types, index_list))
num_to_class = dict(zip(index_list, pic_types))

## 构建数据集类（Dataset）

In [None]:
class CIFAR10Dataset(Dataset):
    """Map-style dataset over the Kaggle CIFAR-10 ``<id>.png`` images.

    The CSV at ``csv_path`` must start with a header row. Its first column
    holds image ids and, outside ``'test'`` mode, its second column holds
    string class labels. The first ``int(n_rows * train_ratio)`` data rows
    form the ``'train'`` split, and the remaining rows form the ``'valid'``
    split. ``'test'`` mode uses every data row.

    Args:
        csv_path: Path to the CSV listing image ids (and labels).
        image_path: Directory containing the ``<id>.png`` image files.
        transform: Optional callable applied to each PIL image. When ``None``,
            the (RGB) PIL image is returned unchanged.
        mode: One of ``'train'``, ``'valid'`` or ``'test'``.
        train_ratio: Fraction of data rows assigned to the ``'train'`` split.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """

    def __init__(self, csv_path, image_path, transform=None, mode='train', train_ratio=0.8):
        super().__init__()
        self.csv_path = csv_path
        self.image_path = image_path
        self.mode = mode
        self.transform = transform
        # Let pandas consume the header row. Hand-skipping it under
        # header=None caused off-by-one errors in the split slices.
        self.data_info = pd.read_csv(csv_path)
        self.data_len = len(self.data_info.index)  # number of data rows
        self.train_len = int(self.data_len * train_ratio)
        if mode == 'train':
            rows = self.data_info.iloc[:self.train_len]
        elif mode == 'valid':
            rows = self.data_info.iloc[self.train_len:]
        elif mode == 'test':
            rows = self.data_info
        else:
            raise ValueError(f"mode must be 'train', 'valid' or 'test', got {mode!r}")
        self.id_arr = np.asarray(rows.iloc[:, 0])
        if mode != 'test':
            # Ids and labels come from the same row slice, so they stay aligned.
            self.label_arr = np.asarray(rows.iloc[:, 1])
        self.real_len = len(self.id_arr)

    def __getitem__(self, index):
        """Return ``(image, class_index)``, or only ``image`` in ``'test'`` mode.

        The string label is converted with the module-level ``class_to_num``
        mapping, which must be defined before items are fetched.
        """
        image_file = os.path.join(self.image_path, f'{self.id_arr[index]}.png')
        # convert() returns a new, fully loaded image, so the file handle can
        # be released as soon as the with-block exits.
        with Image.open(image_file) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        if self.mode == 'test':
            return image
        return image, class_to_num[self.label_arr[index]]

    def __len__(self):
        """Return the number of samples in the selected split."""
        return self.real_len