In [1]:
from torch import optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from tqdm import tqdm

from models.lenet import LeNet

### Utility Functions

In [2]:
def load_mnist_dataset(root='data', download=True):
    """Load the Fashion-MNIST training and test splits.

    Note: despite the function name, this builds ``datasets.FashionMNIST``
    objects, not the original handwritten-digit MNIST dataset.

    Args:
        root: Directory passed to ``datasets.FashionMNIST`` as its ``root``
            argument. Defaults to ``'data'``.
        download: Value passed to ``datasets.FashionMNIST`` as its
            ``download`` argument. Defaults to ``True``.

    Returns:
        A ``(training_data, test_data)`` tuple. Both datasets are created with
        a ``ToTensor()`` transform.
    """
    def _build_split(train):
        # The two splits differ only in the ``train`` flag; each gets its own
        # ToTensor() instance.
        return datasets.FashionMNIST(
            root=root,
            train=train,
            download=download,
            transform=ToTensor()
        )

    training_data = _build_split(True)
    test_data = _build_split(False)
    return training_data, test_data


### LeNet

In [3]:
# Hyper parameters
batch_size = 100  # number of samples per DataLoader batch
num_epochs = 100  # number of full passes over the training loader

# TODO: convert this setup into a function
training_data, test_data = load_mnist_dataset()
training_loader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
# NOTE(review): in_shape=0 looks like a placeholder — confirm against
# models.lenet.LeNet what this argument is expected to be.
model = LeNet(in_shape=0, out_shape=10)
# Adam with its default settings, optimising every parameter of the model.
optimizer = optim.Adam(model.parameters())

# TODO: convert to use GPU
model.train(True)  # put the model in training mode

# One entry per completed epoch: the SUM (not the mean) of the batch losses.
loss_per_epoch = []

# Use a named loop variable instead of `_`: binding `_` at notebook top level
# shadows IPython's `_` (last cell output) variable.
for epoch in tqdm(range(num_epochs)):
    total_loss = 0
    for images, labels in training_loader:
        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        pred = model(images)

        # `loss` is a method provided by LeNet (defined in models.lenet, not
        # shown here); it is called with the predictions and the labels.
        loss = model.loss(pred, labels)

        loss.backward()
        optimizer.step()

        # .item() extracts a Python number so the computation graph is not
        # kept alive by the running total.
        total_loss += loss.item()

    loss_per_epoch.append(total_loss)
    # tqdm.write prints the same text as the previous print() call, but
    # without overlapping the live progress bar.
    tqdm.write(f'Epoch Loss:  {total_loss}')


  1%|          | 1/100 [00:08<13:54,  8.42s/it]

Epoch Loss:  831.49112200737


  2%|▏         | 2/100 [00:16<13:27,  8.24s/it]

Epoch Loss:  430.64438703656197


  3%|▎         | 3/100 [00:25<13:44,  8.50s/it]

Epoch Loss:  358.7724778652191


  3%|▎         | 3/100 [00:30<16:20, 10.11s/it]


KeyboardInterrupt: 