In [1]:
import torch
import torchvision
from torch.utils.data import DataLoader
from torch import nn,optim
from torchvision import datasets,transforms

from utils import Lenet5

In [2]:
def main():
    batchsz=128
    epochs=100
    
    cifar_train=datasets.CIFAR10('./data/cifar',train=True,download=True,
                                transform=transforms.Compose([
                                    transforms.Resize((32,32)),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.485,0.456,0.406],
                                                        std=[0.229,0.224,0.225])
                                ]))
    cifar_train=DataLoader(cifar_train,batch_size=batchsz,shuffle=True)
    
    cifar_test=datasets.CIFAR10('./data/cifar',train=False,
                               transform=transforms.Compose([
                                   transforms.Resize((32,32)),
                                   transforms.ToTensor(),
                                   transforms.Normalize(mean=[0.485,0.456,0.406],
                                                       std=[0.229,0.224,0.225])
                               ]))
    cifar_test=DataLoader(cifar_test,batch_size=batchsz,shuffle=True)
    
    
    x,label=iter(cifar_train).next()
    print('x:',x.shape,'label:',label.shape)
    
    
    device=torch.device('cuda')
    model=Lenet5().to(device)
    
    criteon=nn.CrossEntropyLoss().to(device)
    optimizer=optim.Adam(model.parameters(),lr=1e-3)
    print(model)
    
    for epoch in range(epochs):
        
        model.train()
        for batchidx,(x,label) in enumerate(cifar_train):
            x,label=x.to(device),label.to(device)
            
            logits=model(x)
            # logits: [b,10]
            # label: [b]
            # loss: tensor scalar
            loss=criteon(logits,label)
            
            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print('epoch={} loss={}'.format(epoch,loss.item()))
        
        model.eval()
        with torch.no_grad():
            total_correct=0 
            total_num=0 
            for x,label in cifar_test:
                x,label=x.to(device),label.to(device)
                # [b,10]
                logits=model(x)
                pred=logits.argmax(dim=1)
                # [b] vs [b] => scalar tensor
                correct=torch.eq(pred,label).float().sum().item()
                total_correct+=correct
                total_num+=x.size(0)
                # print(correct)
            
            acc=total_correct / total_num
            print('epoch={} test acc={}'.format(epoch,acc))

In [3]:
if __name__=='__main__':
    main()

Files already downloaded and verified
x: torch.Size([128, 3, 32, 32]) label: torch.Size([128])
Lenet5(
  (conv_unit): Sequential(
    (0): Conv2d(3, 16, kernel_size=(5, 5), stride=(1, 1))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1))
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc_unit): Sequential(
    (0): Linear(in_features=800, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=10, bias=True)
  )
)
epoch=0 loss=1.2465145587921143
epoch=0 test acc=0.5319
epoch=1 loss=1.4509888887405396
epoch=1 test acc=0.5803
epoch=2 loss=0.9731515049934387
epoch=2 test acc=0.6113
epoch=3 loss=1.387157917022705
epoch=3 test acc=0.6242
epoch=4 loss=1.0898581743240356
epoch=4 test acc=0.6304
epoch=5 loss=0.7639704942703247
epoch=5 test acc=0.6474
epoch=6 loss=0.7387717962265015
epoch=6 test acc=0.6489
epoch=7 loss=0.8267283