In [1]:
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

In [2]:
train_augs = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.ToTensor()])
test_augs = train_augs = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

In [3]:
def load_cifar10(is_train, augs, batch_size):
    dataset = torchvision.datasets.CIFAR10(root='../data', train=is_train,
                transform=augs,download=True)
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=is_train,
                batch_size=batch_size, num_workers=d2l.get_dataloader_workers())
    return dataloader

In [4]:
def accuracy(y_hat, y):
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())

In [10]:
def evaluate_accuracy(net, data_iter, device=None):
    if isinstance(net, torch.nn.Module):
        net.eval()
        if not device:
            device = next(iter(net.parameters())).device
    
    with torch.no_grad():
        acc, length = 0.0, 0.0
        for X, y in data_iter:
            if isinstance(X, list):
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)
            acc += accuracy(net(X), y)
            length += len(X)
    return acc/length

In [6]:
def train_batch(net, X, y, loss, trainer, device):
    if isinstance(X, list):
        X = [x.to(device) for x in X]
    else:
        X = X.to(device)
    y = y.to(device)
    net.train()
    trainer.zero_grad()
    pred = net(X)
    l = loss(pred, y)
    l.sum().backward()
    trainer.step()
    train_loss_sum = l.sum()
    train_acc_sum = accuracy(pred, y)
    return train_loss_sum, train_acc_sum

def train(net, train_iter, test_iter, loss,trainer, num_epochs, device):
    net = net.to(device)
    for epoch in range(num_epochs):
        print(f'epoch:{epoch}')
        l, acc = None, None 
        for i, (features, labels) in enumerate(train_iter):
            l, acc = train_batch(net, features, labels, loss, trainer, device)
            l += l
            acc += acc
        print(f'loss:{l/(i+1)}, train_acc:{acc/(i+1)}')
        test_acc = evaluate_accuracy(net, test_iter)
        print(f'test_acc:{test_acc}')

In [7]:
batch_size, device, net = 256, torch.device('cuda:1'), d2l.resnet18(10, 3)

def init_weights(m):
    if type(m) in [nn.Linear, nn.Conv2d]:
        nn.init.xavier_uniform_(m.weight)

net.apply(init_weights)

def train_with_data_aug(train_augs, test_augs, net,batch_size, lr=0.001):
    train_iter = load_cifar10(True, train_augs, batch_size)
    test_iter = load_cifar10(False, test_augs, batch_size)
    loss = nn.CrossEntropyLoss(reduction="none")
    trainer = torch.optim.Adam(net.parameters(), lr)
    train(net, train_iter, test_iter, loss, trainer, 10, device)

In [11]:
train_with_data_aug(train_augs, test_augs, net, batch_size)

Files already downloaded and verified
Files already downloaded and verified
epoch:0
195, loss:124.39905548095703, train_acc:120.0
test_acc:0.7167
epoch:1
195, loss:85.28938293457031, train_acc:132.0
test_acc:0.743
epoch:2
195, loss:60.863609313964844, train_acc:132.0
test_acc:0.717
epoch:3


KeyboardInterrupt: 