# In [None]:
import torchvision
import torchvision.transforms as transforms
import torch
import math
from vgg16 import VGG16



# Optimisation / training hyper-parameters.
LR = 0.01          # initial SGD learning rate
BATCH_SIZE = 64    # samples per mini-batch, shared by train and test loaders
EPOCHS = 250       # number of full passes over the training split
n_classes = 10     # CIFAR-10 has ten target classes

# Per-channel normalisation with the widely used ImageNet RGB mean/std.
# NOTE(review): these are ImageNet statistics rather than CIFAR-10's own;
# harmless as a fixed rescaling, but confirm this is intentional.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Train and test preprocessing are currently identical: PIL image -> tensor,
# then normalise. No data augmentation is applied to the training split.
train_transform = transforms.Compose([transforms.ToTensor(), normalize])
test_transform = transforms.Compose([transforms.ToTensor(), normalize])

# CIFAR-10 train/test splits; download=True fetches the files if they are
# not already present under the given root directories.
train_data = torchvision.datasets.CIFAR10(
    '../data/CIFAR10/train',
    train=True,
    transform=train_transform,
    download=True,
)
test_data = torchvision.datasets.CIFAR10(
    '../data/CIFAR10/test',
    train=False,
    transform=test_transform,
    download=True,
)

# Only the training data is shuffled; evaluation order stays fixed.
train_loader = torch.utils.data.DataLoader(
    train_data, num_workers=4, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    test_data, num_workers=4, batch_size=BATCH_SIZE, shuffle=False)

# Classifier with n_classes outputs, moved onto the GPU.
model = VGG16(n_classes)
model.cuda()

criterion = torch.nn.CrossEntropyLoss()

# SGD with momentum and L2 weight decay.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=LR,
    momentum=0.9,
    weight_decay=5e-4,
)

# mode='min' is ReduceLROnPlateau's default, spelled out here because the
# training loop steps it with a loss (lower is better).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')

best_acc = 0  # highest validation accuracy seen so far (checkpoint trigger)


def validate(val_loader, model, criterion):
    """Compute top-1 classification accuracy of ``model`` over ``val_loader``.

    The model is put in eval mode for the pass and then restored to whatever
    mode (train/eval) it was in before the call, so a surrounding training
    loop is not silently left running in eval mode.

    Args:
        val_loader: iterable of ``(images, labels)`` batches, e.g. a
            ``torch.utils.data.DataLoader``.
        model: ``torch.nn.Module`` whose output has class scores along dim 1.
            Batches are moved to the device of the model's first parameter.
        criterion: accepted for interface compatibility; currently unused.

    Returns:
        Accuracy in percent as a Python float (0.0 for an empty loader).
    """
    was_training = model.training
    model.eval()

    # Evaluate on whatever device the model lives on (CUDA in this script);
    # a parameterless model leaves the batches where they are.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    correct = 0
    total = 0
    try:
        # Inference only: no need to build the autograd graph.
        with torch.no_grad():
            for images, labels in val_loader:
                if device is not None:
                    images = images.to(device)
                    labels = labels.to(device)
                predicted = model(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)

    # Accumulate over ALL batches before computing/reporting the accuracy.
    acc = 100.0 * correct / total if total else 0.0
    print("avg acc: %f" % (acc))
    return acc

# Main training loop: one pass over the training split per epoch, then
# evaluate on the test split, checkpoint the best-accuracy weights, and step
# the plateau scheduler with the epoch's mean training loss.
for epoch in range(EPOCHS):
    # validate() switches the model to eval mode; switch back every epoch so
    # train-time behaviour (e.g. dropout / batch-norm, if VGG16 has them)
    # is actually used during training.
    model.train()
    # Send batches to wherever the model's parameters live (CUDA here).
    device = next(model.parameters()).device
    loss_sum = 0.0
    cnt = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        batch_loss = loss.item()  # plain float; don't accumulate GPU tensors
        loss_sum += batch_loss
        cnt += 1
        print("[E: %d] loss: %f, avg_loss: %f" % (epoch, batch_loss, loss_sum / cnt))
        loss.backward()
        optimizer.step()

    # Mean training loss for the epoch (guarded against an empty loader).
    avg_loss = loss_sum / max(cnt, 1)

    val_acc = validate(test_loader, model, criterion)
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), 'vgg16_cifar10.pkl')

    # Plateau detection on the mean (not the summed) loss; with the default
    # relative threshold this yields the same schedule but a meaningful value.
    scheduler.step(avg_loss)
