In [20]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch import nn, optim

%run test.ipynb

if __name__ == "__main__":
    batchsz = 128
    
    cifar_train = datasets.CIFAR10('../data/cifar', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                            std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)
    
    cifar_test = datasets.CIFAR10('../data/cifar', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                            std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)
    
    x, label = iter(cifar_train).next()
    print('x:', x.shape, 'label:', label.shape)

    net = Lenet5()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=1e-3)
    
    for epoch in range(5):
        net.train()
        
        for batchidx, (x, label) in enumerate(cifar_train):
            logits = net(x)
            loss = criterion(logits, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        print(epoch, 'loss:', loss.item())
        
        net.eval()
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            
            for x, label in cifar_test:
                logits = net(x)
                pred = logits.argmax(dim=1)
                correct = pred.eq(label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)
                
            print('acc :', total_correct / total_num)

conv out: torch.Size([2, 16, 5, 5])
lenet out: torch.Size([2, 10])
Files already downloaded and verified
Files already downloaded and verified
x: torch.Size([128, 3, 32, 32]) label: torch.Size([128])
conv out: torch.Size([2, 16, 5, 5])
0 loss: 1.7710403203964233
acc : 0.441
1 loss: 1.405535340309143
acc : 0.5043
2 loss: 1.3211902379989624
acc : 0.5181
3 loss: 1.603732705116272
acc : 0.5399
4 loss: 1.3879693746566772
acc : 0.5527
