In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.utils.data as data
import torchvision.datasets as datasets
from vgg import VGG
from statistics import mean

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Worker processes each DataLoader uses to load/augment batches in parallel.
num_workers = 4
# Seed PyTorch's RNGs (CPU and every CUDA device) so weight init, shuffling and
# random augmentations are repeatable across runs. cuDNN kernels may still
# introduce small run-to-run differences on the GPU.
SEED = 0
torch.manual_seed(SEED)

In [3]:
# Per-split preprocessing pipelines for CIFAR-10 (32x32 RGB images).
# Training adds light augmentation (padded random crop + horizontal flip);
# both splits are converted to tensors and normalised with identical constants.
# NOTE(review): (0.2023, 0.1994, 0.2010) is the std triple widely copied from
# CIFAR tutorials; the per-channel pixel std of the training set is often
# reported as roughly (0.247, 0.243, 0.261). Confirm before changing, and keep
# train and test in sync if you do.
trans = {
    'train': transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]),
    'test': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]),
}

In [4]:
# CIFAR-10 datasets keyed by split; each split gets its matching transform
# from `trans`. Archives are downloaded into the working directory if absent
# (the train split is built first, so it performs any download).
Data = {
    split: datasets.CIFAR10(
        root='./',
        train=(split == 'train'),
        transform=trans[split],
        download=True,
    )
    for split in ('train', 'test')
}

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./cifar-10-python.tar.gz to ./
Files already downloaded and verified


In [5]:
# One DataLoader per split. Only the training split is shuffled: test
# accuracy does not depend on sample order, so shuffling it would only add
# nondeterminism to evaluation.
dataloaders = {
    split: data.DataLoader(
        Data[split],
        batch_size=512,
        shuffle=(split == 'train'),
        num_workers=num_workers,
    )
    for split in ['train', 'test']
}

In [6]:
# VGG layer layouts: integers are conv output channels, 'M' marks a max-pool.
# NOTE(review): no visible cell reads this table; VGG('VGG16') is imported from
# the local `vgg` module and presumably carries its own copy. Confirm and
# delete this one if redundant.
cfg = dict(
    VGG11=[64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    VGG13=[64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    VGG16=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
           512, 512, 512, 'M', 512, 512, 512, 'M'],
    VGG19=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
           512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
)

In [7]:
# Build the VGG16 variant from the local `vgg` module and move its weights to
# the selected device (Module.to returns the module itself).
vgg = VGG('VGG16').to(device)

In [8]:
# Loss, optimiser and LR schedule. MultiStepLR multiplies the learning rate by
# its default gamma (0.1) at each milestone. train() steps it once per epoch,
# so the LR drops at epochs 150 and 250.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    vgg.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=5e-4,
)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[150, 250])

In [9]:
def _train_one_epoch(model, loader, loss_fn, opt, run_device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    batch_losses = []
    for inputs, targets in loader:
        inputs = inputs.to(run_device)
        targets = targets.to(run_device)

        opt.zero_grad()
        loss = loss_fn(model(inputs), targets)
        batch_losses.append(loss.item())
        loss.backward()
        opt.step()
    return mean(batch_losses)


def _evaluate(model, loader, run_device):
    """Return the top-1 accuracy of ``model`` over every sample in ``loader.dataset``."""
    model.eval()
    correct = 0
    # Evaluation needs no autograd graph; building one only wastes memory.
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(run_device)
            targets = targets.to(run_device)
            # softmax is monotonic, so argmax over the raw logits picks the same class.
            predictions = model(inputs).argmax(dim=1)
            correct += (predictions == targets).sum().item()
    return correct / len(loader.dataset)


def train(init_epoch, num_epochs, model=None, loaders=None, loss_fn=None,
          opt=None, sched=None, save_path='BestModel.pth'):
    """Train a classifier, evaluating and checkpointing after every epoch.

    Runs epochs ``init_epoch`` .. ``num_epochs - 1``. Each epoch does one pass
    over ``loaders['train']``, steps the LR scheduler, and measures accuracy on
    ``loaders['test']``. Whenever the epoch's test accuracy is at least as high
    as every earlier accuracy in this call, the current weights are written to
    ``save_path``. Ties go to the later epoch.

    Args:
        init_epoch: first epoch index (affects only the loop range and log output).
        num_epochs: exclusive upper bound of epoch indices.
        model, loaders, loss_fn, opt, sched: default to the notebook globals
            ``vgg``, ``dataloaders``, ``criterion``, ``optimizer`` and
            ``scheduler`` when left as None.
        save_path: file the best-so-far ``state_dict`` is saved to.

    Returns:
        ``(train_loss, test_acc)``: lists with one mean training loss and one
        test accuracy per epoch run.
    """
    model = vgg if model is None else model
    loaders = dataloaders if loaders is None else loaders
    loss_fn = criterion if loss_fn is None else loss_fn
    opt = optimizer if opt is None else opt
    sched = scheduler if sched is None else sched
    # Batches are moved to wherever the model's weights live (assumes the
    # model has at least one parameter).
    run_device = next(model.parameters()).device

    train_loss = []
    test_acc = []
    best_acc = float('-inf')
    for epoch in range(init_epoch, num_epochs):
        print('Epoch:', epoch)
        epoch_train_loss = _train_one_epoch(model, loaders['train'], loss_fn, opt, run_device)
        sched.step()
        epoch_test_acc = _evaluate(model, loaders['test'], run_device)

        train_loss.append(epoch_train_loss)
        test_acc.append(epoch_test_acc)
        print('Loss:', epoch_train_loss, 'Accuracy:', epoch_test_acc)

        if epoch_test_acc >= best_acc:
            best_acc = epoch_test_acc
            # state_dict() holds references to the live parameter tensors, so it
            # must be serialised now. Saving a remembered state_dict in a later
            # epoch would write that epoch's weights instead of the best ones.
            torch.save(model.state_dict(), save_path)

    return (train_loss, test_acc)
        

In [10]:
train_loss,test_acc = train(0,300)

Epoch: 0
Loss: 1.5341495920200736 Accuracy: 0.5799
Epoch: 1
Loss: 1.02727736867204 Accuracy: 0.6747
Epoch: 2
Loss: 0.8012745380401611 Accuracy: 0.7082
Epoch: 3
Loss: 0.6588711689929573 Accuracy: 0.7547
Epoch: 4
Loss: 0.5816594374423124 Accuracy: 0.7761
Epoch: 5
Loss: 0.5193708411284855 Accuracy: 0.7627
Epoch: 6
Loss: 0.4778398397017498 Accuracy: 0.7938
Epoch: 7
Loss: 0.43445264867373873 Accuracy: 0.807
Epoch: 8
Loss: 0.40074020928266096 Accuracy: 0.8226
Epoch: 9
Loss: 0.3684662987991255 Accuracy: 0.8309
Epoch: 10
Loss: 0.3425835059309492 Accuracy: 0.838
Epoch: 11
Loss: 0.32344281490968196 Accuracy: 0.8556
Epoch: 12
Loss: 0.29856888554534133 Accuracy: 0.8223
Epoch: 13
Loss: 0.2789645445894222 Accuracy: 0.7922
Epoch: 14
Loss: 0.26242311724594664 Accuracy: 0.8277
Epoch: 15
Loss: 0.252370078952945 Accuracy: 0.8282
Epoch: 16
Loss: 0.24091198204123243 Accuracy: 0.8602
Epoch: 17
Loss: 0.22121956807618237 Accuracy: 0.8483
Epoch: 18
Loss: 0.21023415211512117 Accuracy: 0.8552
Epoch: 19
Loss: 0.2

Epoch: 154
Loss: 0.002878567434928133 Accuracy: 0.9176
Epoch: 155
Loss: 0.002564158876264962 Accuracy: 0.9186
Epoch: 156
Loss: 0.0026902308071635626 Accuracy: 0.9203
Epoch: 157
Loss: 0.0020762632763706985 Accuracy: 0.9193
Epoch: 158
Loss: 0.0021053998266603344 Accuracy: 0.9201
Epoch: 159
Loss: 0.0022897616192717484 Accuracy: 0.9199
Epoch: 160
Loss: 0.0017396779528971078 Accuracy: 0.9214
Epoch: 161
Loss: 0.001758775801507148 Accuracy: 0.9211
Epoch: 162
Loss: 0.001840171814309338 Accuracy: 0.9212
Epoch: 163
Loss: 0.0014718816281838 Accuracy: 0.9219
Epoch: 164
Loss: 0.0015691097535321262 Accuracy: 0.9217
Epoch: 165
Loss: 0.0014315696278164088 Accuracy: 0.9205
Epoch: 166
Loss: 0.0015378559711484276 Accuracy: 0.9213
Epoch: 167
Loss: 0.0014703027875523787 Accuracy: 0.9208
Epoch: 168
Loss: 0.0013748229791918695 Accuracy: 0.92
Epoch: 169
Loss: 0.0012821247233956938 Accuracy: 0.9202
Epoch: 170
Loss: 0.0013851311565044203 Accuracy: 0.9218
Epoch: 171
Loss: 0.001479116309735905 Accuracy: 0.9222
Ep

In [11]:
# Report the highest per-epoch test accuracy in the list returned by train().
print(
    'Best model accuracy:',
    max(test_acc),
)

Best model accuracy: 0.9235
