In [1]:
import time
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from PIL import Image

import sys

sys.path.append("../..")
import d2lzh_pytorch as d2l
# from apps.chapter import d2lzh_pytorch as d2l

# Use the current CUDA device when PyTorch can see a GPU; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
def load_cifar10(is_train, augs, batch_size, root="~/Datasets/CIFAR", num_workers=None):
    """Build a DataLoader over the CIFAR-10 training or test split.

    Args:
        is_train: If True, use the training split and shuffle it every epoch;
            otherwise use the test split in a fixed order.
        augs: Transform applied to each image (here a ``transforms.Compose``
            ending in ``ToTensor``).
        batch_size: Number of samples per mini-batch.
        root: Dataset directory; torchvision downloads the data there when it
            is not already present (``download=True``).
        num_workers: Number of DataLoader worker processes. ``None`` (default)
            keeps the previous behaviour: 0 on Windows, 4 elsewhere.

    Returns:
        A ``torch.utils.data.DataLoader`` over the selected split.
    """
    dataset = torchvision.datasets.CIFAR10(
        root=root,
        train=is_train,
        transform=augs,
        download=True,
    )
    if num_workers is None:
        # Worker subprocesses are typically problematic on Windows (spawn start
        # method, especially from notebooks), so load in the main process there.
        num_workers = 0 if sys.platform.startswith('win32') else 4
    return DataLoader(dataset, batch_size=batch_size, shuffle=is_train,
                      num_workers=num_workers)

In [3]:
def train_with_data_aug(train_augs, test_augs, lr=0.001, batch_size=256, num_epochs=5, seed=None):
    """Train a fresh ResNet-18 on CIFAR-10 using the given image transforms.

    Args:
        train_augs: Transform applied to training images (the augmentation
            under study).
        test_augs: Transform applied to test images.
        lr: Adam learning rate.
        batch_size: Mini-batch size for both the training and test loaders.
        num_epochs: Number of training epochs.
        seed: Optional seed passed to ``torch.manual_seed`` before the network
            is built, making weight initialisation, shuffling and random flips
            repeatable (GPU kernels may still be nondeterministic). ``None``
            (default) leaves the global RNG untouched.

    Returns:
        The trained network (trained in place by ``d2l.train``).
    """
    if seed is not None:
        torch.manual_seed(seed)

    # NOTE(review): 10 is assumed to be the output-class count for CIFAR-10.
    net = d2l.resnet18(10)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = torch.nn.CrossEntropyLoss()

    train_iter = load_cifar10(True, train_augs, batch_size)
    test_iter = load_cifar10(False, test_augs, batch_size)

    # Train; d2l.train also reports per-epoch test accuracy on test_iter.
    d2l.train(train_iter, test_iter,
              net,
              loss,
              optimizer,
              device,
              num_epochs=num_epochs
              )
    return net

In [4]:
def run():
    """Train on CIFAR-10 with random horizontal flips as the only augmentation.

    Test images are only converted to tensors, so the training and test
    pipelines differ solely by the flip.
    """
    transforms = torchvision.transforms
    flip_then_tensor = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    tensor_only = transforms.Compose([transforms.ToTensor()])

    # Train
    train_with_data_aug(train_augs=flip_then_tensor, test_augs=tensor_only)

In [5]:
# Launch the flip-augmentation training run; per-epoch loss/accuracy is printed in the output below.
run()

Files already downloaded and verified
Files already downloaded and verified
training on  cuda
epoch 1, loss 1.3863, train acc 0.497, test acc 0.424, time 4.3 sec
epoch 2, loss 0.5001, train acc 0.647, test acc 0.566, time 3.5 sec
epoch 3, loss 0.2817, train acc 0.702, test acc 0.655, time 3.5 sec
epoch 4, loss 0.1865, train acc 0.737, test acc 0.666, time 3.5 sec
epoch 5, loss 0.1334, train acc 0.768, test acc 0.703, time 3.5 sec
