In [1]:
import copy
from statistics import mean

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from vgg import VGG

In [2]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Number of worker processes used for data loading.
num_workers = 4

In [3]:
# Per-channel mean / std handed to Normalize for both splits.
_norm_mean = (0.4914, 0.4822, 0.4465)
_norm_std = (0.2023, 0.1994, 0.2010)

# Preprocessing pipelines keyed by split name.
trans = {
    # Training: random 32-pixel crop after 4 pixels of padding and a random
    # horizontal flip, then tensor conversion and normalisation.
    'train': transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_norm_mean, _norm_std),
    ]),
    # Testing: no augmentation, only tensor conversion and normalisation.
    'test': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(_norm_mean, _norm_std),
    ]),
}

In [4]:
# CIFAR-10 datasets keyed by split name. Both are rooted in the current
# directory, use the matching pipeline from `trans`, and are created with
# download=True.
Data = {
    split: datasets.CIFAR10(
        root='./', train=is_train, transform=trans[split], download=True
    )
    for split, is_train in (('train', True), ('test', False))
}

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# One DataLoader per split, 512 samples per batch. Only the training split is
# shuffled: reshuffling the test split does not change the measured accuracy
# and just makes the evaluation order non-deterministic.
dataloaders = {
    split: data.DataLoader(
        Data[split],
        batch_size=512,
        shuffle=(split == 'train'),
        num_workers=num_workers,
    )
    for split in ['train', 'test']
}

In [6]:
# Layer layouts for four VGG variants, one line per stage.
# NOTE(review): presumably integers are conv output-channel counts and 'M' marks
# a pooling step (the usual VGG config convention) - confirm against vgg.py.
# No cell visible in this notebook reads `cfg`; the model is built by the
# imported VGG class.
cfg = dict(
    VGG11=[
        64, 'M',
        128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    VGG13=[
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    VGG16=[
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 'M',
        512, 512, 512, 'M',
        512, 512, 512, 'M',
    ],
    VGG19=[
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 256, 'M',
        512, 512, 512, 512, 'M',
        512, 512, 512, 512, 'M',
    ],
)

In [7]:
# Build the 'VGG16' variant and move it to the selected device in one step.
vgg = VGG('VGG16').to(device)

In [8]:
# Classification loss.
criterion = nn.CrossEntropyLoss()

# SGD with momentum and weight decay over all parameters of `vgg`.
optimizer = optim.SGD(
    vgg.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=5e-4,
)

# Learning-rate schedule with milestones at epochs 150 and 250
# (decay factor left at the MultiStepLR default).
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[150, 250])

In [9]:
def train(init_epoch,num_epochs):
    """Train the global ``vgg`` model, evaluating on the test split after every epoch.

    Relies on the notebook globals ``vgg``, ``dataloaders``, ``Data``,
    ``criterion``, ``optimizer``, ``scheduler`` and ``device``.

    Args:
        init_epoch: First epoch index (inclusive) of ``range(init_epoch, num_epochs)``.
        num_epochs: End epoch index (exclusive).

    Returns:
        ``(train_loss, test_acc)``: two lists with one entry per executed
        epoch, holding the mean training-batch loss and the fraction of
        correctly classified test samples.

    Side effects:
        Prints progress for every epoch and writes a snapshot of the weights
        to ``BestModel.pth`` each time the test accuracy matches or beats the
        best value seen so far during this call.
    """
    train_loss = []
    test_acc = []
    best_acc = float('-inf')

    for epoch in range(init_epoch, num_epochs):
        print('Epoch:',epoch)

        # ---- training pass ----
        vgg.train()
        batch_losses = []
        for x, y in dataloaders['train']:
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()
            out = vgg(x)
            loss = criterion(out, y)
            batch_losses.append(loss.item())
            loss.backward()
            optimizer.step()

        # The scheduler is stepped once per epoch, after the optimizer updates.
        scheduler.step()

        epoch_train_loss = mean(batch_losses)

        # ---- evaluation pass ----
        vgg.eval()
        correct = 0
        # No gradients are needed for evaluation; skipping autograd bookkeeping
        # saves memory and time.
        with torch.no_grad():
            for x, y in dataloaders['test']:
                x = x.to(device)
                y = y.to(device)

                out = vgg(x)
                # argmax over the logits equals argmax over their softmax.
                correct += (out.argmax(dim=1) == y).sum().item()

        epoch_test_acc = correct / len(Data['test'])

        train_loss.append(epoch_train_loss)
        test_acc.append(epoch_test_acc)

        print('Loss:',epoch_train_loss,'Accuracy:',epoch_test_acc)

        if epoch_test_acc >= best_acc:
            best_acc = epoch_test_acc
            # state_dict() returns references to the live parameter tensors, so
            # the snapshot must be deep-copied; otherwise it silently follows
            # every later optimizer step.
            best_model = copy.deepcopy(vgg.state_dict())
            torch.save(best_model, 'BestModel.pth')

    return (train_loss, test_acc)
        

In [10]:
# Run epochs 0..299 and keep the per-epoch loss / accuracy histories.
train_loss, test_acc = train(init_epoch=0, num_epochs=300)

Epoch: 0
Loss: 1.5460122349310894 Accuracy: 0.5019
Epoch: 1
Loss: 1.037760588587547 Accuracy: 0.6414
Epoch: 2
Loss: 0.7891671882600201 Accuracy: 0.7015


KeyboardInterrupt: 