In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch import nn, optim

%run ResNet.ipynb

def main(epochs=2, batchsz=128, lr=1e-3, data_root='../data/cifar'):
    """Train ``ResNet18`` on CIFAR-10 and print test accuracy after each epoch.

    ``ResNet18`` is not defined here; it is expected to be brought into scope
    by the ``%run ResNet.ipynb`` cell.

    Args:
        epochs: number of passes over the training set.
        batchsz: mini-batch size used by both the train and test loaders.
        lr: learning rate for the Adam optimizer.
        data_root: directory CIFAR-10 is read from (downloaded if missing).
    """
    # Use a GPU when one is available; otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # One preprocessing pipeline shared by train and test.
    # NOTE(review): these mean/std values look like ImageNet channel statistics
    # rather than CIFAR-10's -- confirm this was intended.
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    cifar_train = DataLoader(
        datasets.CIFAR10(data_root, True, transform=transform, download=True),
        batch_size=batchsz, shuffle=True)
    # Evaluation order does not affect accuracy, so keep the test set unshuffled.
    cifar_test = DataLoader(
        datasets.CIFAR10(data_root, False, transform=transform, download=True),
        batch_size=batchsz, shuffle=False)

    # Print the shapes of one batch as a sanity check. DataLoader iterators do
    # not provide a ``.next()`` method in current PyTorch; use builtin ``next``.
    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)

    model = ResNet18().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        for x, label in cifar_train:
            x, label = x.to(device), label.to(device)
            logits = model(x)
            loss = criterion(logits, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # This is the loss of the epoch's last mini-batch, not an epoch average.
        print(epoch, 'loss:', loss.item())

        model.eval()
        total_correct = 0
        total_num = 0
        with torch.no_grad():
            for x, label in cifar_test:
                x, label = x.to(device), label.to(device)
                pred = model(x).argmax(dim=1)
                total_correct += torch.eq(pred, label).float().sum().item()
                total_num += x.size(0)

        acc = total_correct / total_num
        print(epoch, 'test acc:', acc)


if __name__ == "__main__":
    main()

block shape: torch.Size([2, 128, 8, 8])
out shape: torch.Size([2, 10])
Files already downloaded and verified
Files already downloaded and verified
x: torch.Size([128, 3, 32, 32]) label: torch.Size([128])
