In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from VGG import VGG, accuracy, evaluate

In [2]:
x = torch.randn(32, 3, 224, 224)
model = VGG(output_classes=10)

out = model(x)
out.shape

torch.Size([32, 10])

In [3]:
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor()
])

In [4]:
train_dataset = torchvision.datasets.CIFAR10(root='../datasets/', train=True, transform=transforms, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='../datasets/', train=False, transform=transforms, download=True)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
train_dl = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True)

In [6]:
def fit(epochs, model, train_loader, val_loader, opt_func=torch.optim.Adam):
    history = []
    optimizer = opt_func(model.parameters(), 6e-5)
    for epoch in range(epochs):
        lrs = []
        loss = 0
        acc = 0
        for batch in tqdm.tqdm(train_loader):
            loss, acc = model.training_step(batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        print("Epoch [{}], loss: {:.4f}, acc: {:.4f}".format(epoch, loss, acc))
        result = evaluate(model, val_loader)
        model.epoch_end(epoch, result)
        history.append(result)
    return history

In [7]:
device = 'cuda'

In [8]:
model = model.to(device)

In [9]:
import tqdm
history = fit(5, model, train_dl, test_dl)

100%|██████████| 782/782 [06:59<00:00,  1.87it/s]


Epoch [0], loss: 1.7942, acc: 0.4375
Epoch [0], val_loss: 1.0827, val_acc: 0.6152


100%|██████████| 782/782 [06:59<00:00,  1.86it/s]


Epoch [1], loss: 0.8865, acc: 0.6875
Epoch [1], val_loss: 0.7883, val_acc: 0.7194


100%|██████████| 782/782 [06:59<00:00,  1.86it/s]


Epoch [2], loss: 0.1224, acc: 1.0000
Epoch [2], val_loss: 0.6814, val_acc: 0.7632


100%|██████████| 782/782 [06:59<00:00,  1.86it/s]


Epoch [3], loss: 0.2944, acc: 0.8750
Epoch [3], val_loss: 0.5997, val_acc: 0.7958


100%|██████████| 782/782 [06:59<00:00,  1.86it/s]


Epoch [4], loss: 0.3271, acc: 0.8750
Epoch [4], val_loss: 0.5394, val_acc: 0.8150
