## 1. torchvision自带数据集

In [8]:
from torchvision import transforms, datasets
from torch.utils import data

In [9]:
# Augmenting pipeline: random 224x224 crop and random horizontal flip, then
# tensor conversion and per-channel normalization.
transform = transforms.Compose([
    transforms.RandomResizedCrop(size=224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])





In [10]:
# Root directory of the CIFAR-10 files; download=False means they must already be there.
src = '../../study/data/cifar'

# Evaluation has to be deterministic, so the test split gets its own pipeline
# without random cropping/flipping. It resizes to the same 224x224 size the
# training crop produces and applies the same normalization.
test_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Training split keeps the augmenting `transform` defined earlier in the notebook.
train = datasets.CIFAR10(root=src, train=True, download=False, transform=transform)
test = datasets.CIFAR10(root=src, train=False, download=False, transform=test_transform)

In [11]:
batch_size = 32
# Training batches are reshuffled; the test split is iterated in a fixed order.
train_loader, test_loader = (
    data.DataLoader(split, batch_size=batch_size, shuffle=reshuffle)
    for split, reshuffle in ((train, True), (test, False))
)

## 2. ImageFolder读取本地数据

In [24]:
from torchvision import transforms, datasets
from torch.utils import data

In [25]:
# One preprocessing pipeline per split, keyed by split name.
transform = dict(
    # Training: random crop / rotation / flip augmentation, then a fixed
    # 224x224 center crop, tensor conversion and normalization.
    # NOTE(review): ColorJitter() is called with no arguments, which looks like
    # a no-op under torchvision's defaults -- confirm whether jitter strengths
    # were meant to be passed.
    train=transforms.Compose([
        transforms.RandomResizedCrop(256, scale=(0.8, 1.0)),
        transforms.RandomRotation(15),
        transforms.ColorJitter(),
        transforms.RandomHorizontalFlip(),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
    # Validation: no augmentation, only resize + center crop + normalization.
    valid=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
    # Test: same deterministic steps as validation.
    test=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
)

In [26]:
src = '../../study/data/caltech101/data'
# One ImageFolder per split, read from `<src>/<split>` and paired with the
# transform registered under the same split name.
img_data = {
    split: datasets.ImageFolder(f'{src}/{split}', transform=transform[split])
    for split in ('train', 'valid', 'test')
}

In [27]:
batch_size = 32
# One DataLoader per split; only the training split is shuffled.
loader = {
    split: data.DataLoader(img_data[split], batch_size=batch_size, shuffle=(split == 'train'))
    for split in ('train', 'valid', 'test')
}

## 3. 自定义数据

In [33]:
import torch
from torch.utils.data import Dataset, DataLoader

In [34]:
class Mydata(Dataset):
    """Map-style dataset that pairs each sample in ``x`` with its label in ``y``.

    Args:
        x: Indexable, sized collection of samples (for example a tensor whose
            first dimension indexes samples, or a list).
        y: Indexable, sized collection of labels with the same length as ``x``.

    Raises:
        ValueError: If ``x`` and ``y`` do not contain the same number of items.
    """

    def __init__(self, x, y):
        # Fail early: a length mismatch would otherwise surface later as an
        # IndexError mid-iteration, or silently ignore the extra labels.
        if len(x) != len(y):
            raise ValueError(
                "x and y must have the same length, got %d and %d" % (len(x), len(y))
            )
        self.x = x
        self.y = y

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair stored at ``index``."""
        return self.x[index], self.y[index]

    def __len__(self):
        """Return the number of samples."""
        # len() works for tensors (size of the first dimension) and also for
        # plain sequences, which have no ``.shape`` attribute.
        return len(self.x)

In [38]:
# Toy data: 5 samples with 2 features each, labelled 0..4.
x = torch.arange(10).view((5, 2))
y = torch.arange(5)
print(x)
print(y)

# Named `toy_dataset` rather than `data`: earlier cells import the
# `torch.utils.data` module under the name `data`, and rebinding it here would
# break those cells when they are re-run in the same kernel.
toy_dataset = Mydata(x, y)
data_loader = DataLoader(toy_dataset,
                         batch_size=2,
                         shuffle=True,
                         num_workers=0
                         )
# Separate batch names keep the full `x` / `y` tensors intact after the loop.
for batch_x, batch_y in data_loader:
    print(batch_x, batch_y)

tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7],
        [8, 9]])
tensor([0, 1, 2, 3, 4])
tensor([[2, 3],
        [6, 7]]) tensor([1, 3])
tensor([[0, 1],
        [8, 9]]) tensor([0, 4])
tensor([[4, 5]]) tensor([2])
