In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from LeNet import LeNet, accuracy, evaluate
import tqdm

In [2]:
from torchsummary import summary

model = LeNet(output_classes=10)

summary(model, (1, 32, 32))

Layer (type:depth-idx)                   Output Shape              Param #
├─Conv2d: 1-1                            [-1, 6, 28, 28]           156
├─ReLU: 1-2                              [-1, 6, 28, 28]           --
├─AvgPool2d: 1-3                         [-1, 6, 14, 14]           --
├─Conv2d: 1-4                            [-1, 16, 10, 10]          2,416
├─ReLU: 1-5                              [-1, 16, 10, 10]          --
├─AvgPool2d: 1-6                         [-1, 16, 5, 5]            --
├─Conv2d: 1-7                            [-1, 120, 1, 1]           48,120
├─ReLU: 1-8                              [-1, 120, 1, 1]           --
├─Linear: 1-9                            [-1, 84]                  10,164
├─Linear: 1-10                           [-1, 10]                  850
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0
Total mult-adds (M): 0.42
Input size (MB): 0.00
Forward/backward pass size (MB): 0.05
Params size (MB): 0.24
Estimated Total Size (MB): 0.29


Layer (type:depth-idx)                   Output Shape              Param #
├─Conv2d: 1-1                            [-1, 6, 28, 28]           156
├─ReLU: 1-2                              [-1, 6, 28, 28]           --
├─AvgPool2d: 1-3                         [-1, 6, 14, 14]           --
├─Conv2d: 1-4                            [-1, 16, 10, 10]          2,416
├─ReLU: 1-5                              [-1, 16, 10, 10]          --
├─AvgPool2d: 1-6                         [-1, 16, 5, 5]            --
├─Conv2d: 1-7                            [-1, 120, 1, 1]           48,120
├─ReLU: 1-8                              [-1, 120, 1, 1]           --
├─Linear: 1-9                            [-1, 84]                  10,164
├─Linear: 1-10                           [-1, 10]                  850
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0
Total mult-adds (M): 0.42
Input size (MB): 0.00
Forward/backward pass size (MB): 0.05
Params size (MB): 0.24
Estimated Total Size (MB): 0.29

In [3]:
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor()
])

In [4]:
train_dataset = torchvision.datasets.FashionMNIST(root='../datasets/', train=True, transform=transforms, download=True)
test_dataset = torchvision.datasets.FashionMNIST(root='../datasets/', train=False, transform=transforms, download=True)

In [5]:
train_dl = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True)

In [6]:
device='cuda'

In [8]:
def fit(epochs, model, train_loader, val_loader, opt_func=torch.optim.Adam):
    history = []
    optimizer = opt_func(model.parameters(), 6e-5)
    for epoch in range(epochs):
        lrs = []
        loss = 0
        acc = 0
        for batch in tqdm.tqdm(train_loader):
            loss, acc = model.training_step(batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        print("Epoch [{}], loss: {:.4f}, acc: {:.4f}".format(epoch, loss, acc))
        result = evaluate(model, val_loader)
        model.epoch_end(epoch, result)
        history.append(result)
    return history

In [10]:
history = fit(5, model, train_dl, test_dl)

100%|█████████████████████████████████████| 938/938 [00:07<00:00, 122.17it/s]


Epoch [0], loss: 0.6212, acc: 0.7500
Epoch [0], val_loss: 0.6254, val_acc: 0.7691


100%|█████████████████████████████████████| 938/938 [00:08<00:00, 117.20it/s]


Epoch [1], loss: 0.6917, acc: 0.7500
Epoch [1], val_loss: 0.5951, val_acc: 0.7745


100%|█████████████████████████████████████| 938/938 [00:08<00:00, 113.21it/s]


Epoch [2], loss: 0.4793, acc: 0.8750
Epoch [2], val_loss: 0.5836, val_acc: 0.7811


100%|█████████████████████████████████████| 938/938 [00:07<00:00, 120.85it/s]


Epoch [3], loss: 0.8935, acc: 0.7500
Epoch [3], val_loss: 0.5764, val_acc: 0.7843


100%|█████████████████████████████████████| 938/938 [00:07<00:00, 120.87it/s]


Epoch [4], loss: 0.8195, acc: 0.6562
Epoch [4], val_loss: 0.5605, val_acc: 0.7888
