In [None]:
%load_ext autoreload
%autoreload 2

import os

import pytorch_lightning as pl
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.datasets import MNIST

from MNIST_lightning import LightningMNISTClassifier
from MNIST_lightning_data import MNISTDataModule
from MNIST_pytorch import MNISTClassifier

In [None]:
# Smoke test: push one dummy batch through both implementations to check that
# they accept the same input and produce outputs of the same shape.
pytorch_model = MNISTClassifier()
lightning_model = LightningMNISTClassifier()

# Dummy batch, presumably NCHW (batch=32, channels=1, 28x28 pixels) for MNIST.
# torch.Tensor(*sizes) returns *uninitialised* memory (may hold NaN/inf), so use
# torch.rand to get well-defined values in [0, 1).
dummy_batch = torch.rand(32, 1, 28, 28)

# No gradients are needed for a forward-only check.
with torch.no_grad():
    pt_out = pytorch_model(dummy_batch)
    pl_out = lightning_model(dummy_batch)

pt_out.shape, pl_out.shape

In [None]:
# --------------------
# TRANSFORMS
# --------------------
# Standard MNIST preprocessing: convert each PIL image to a float tensor, then
# normalise with the mean/std values commonly used for MNIST (0.1307 / 0.3081).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# --------------------
# TRAINING, VAL DATA
# --------------------
# Pass `transform` so samples come back as normalised tensors. Without it MNIST
# yields PIL images, which DataLoader's default collate cannot batch.
mnist_full = MNIST(os.getcwd(), train=True, download=True, transform=transform)

# train (55,000 images), val split (5,000 images).
# The seeded generator makes the split reproducible across kernel restarts.
mnist_train, mnist_val = random_split(
    mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
)

# --------------------
# TEST DATA
# --------------------
mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)

# --------------------
# DATALOADERS
# --------------------
# The dataloaders handle shuffling, batching, etc. DataLoader defaults to
# shuffle=False, so shuffling must be requested explicitly for training.
# Validation/test data are only evaluated, so they stay in fixed order.
# NOTE(review): these rebind the dataset names to loaders; later code in this
# cell iterates them as loaders.
mnist_train = DataLoader(mnist_train, batch_size=64, shuffle=True)
mnist_val = DataLoader(mnist_val, batch_size=64)
mnist_test = DataLoader(mnist_test, batch_size=64)

# --------------------
# OPTIMIZER
# --------------------
# Adam over all parameters of the plain-PyTorch model, learning rate 1e-3.
optimizer = torch.optim.Adam(params=pytorch_model.parameters(), lr=1e-3)

# --------------------
# LOSS
# --------------------
def cross_entropy_loss(logits, labels):
    """Return the batch-mean negative log-likelihood of ``labels``.

    NLL applied to log-probabilities equals cross-entropy on raw scores, so this
    is the true cross-entropy only when ``logits`` are already log-probabilities
    (e.g. the output of ``log_softmax``).
    NOTE(review): assumes the models end in a log-softmax; confirm, otherwise
    switch to ``F.cross_entropy``.
    """
    nll = F.nll_loss(input=logits, target=labels)
    return nll

# --------------------
# TRAINING LOOP
# --------------------
num_epochs = 1
log_every = 100  # print the training loss every N batches instead of every batch

for epoch in range(num_epochs):

    # TRAINING LOOP
    pytorch_model.train()  # training-mode behaviour (e.g. dropout), if the model has any
    for batch_idx, (x, y) in enumerate(mnist_train):
        logits = pytorch_model(x)
        loss = cross_entropy_loss(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % log_every == 0:
            print('train_loss: ', loss.item())

    # VALIDATION LOOP
    pytorch_model.eval()  # inference-mode behaviour, if the model has any mode-dependent layers
    val_loss_sum = 0.0
    val_count = 0
    with torch.no_grad():
        for x, y in mnist_val:
            logits = pytorch_model(x)
            # Weight each batch's mean loss by its size so the shorter final
            # batch does not skew the epoch average.
            batch_size = y.size(0)
            val_loss_sum += cross_entropy_loss(logits, y).item() * batch_size
            val_count += batch_size

    val_loss = val_loss_sum / val_count if val_count else float('nan')
    print(f'epoch {epoch} val_loss: ', val_loss)

# NOTE(review): mnist_test is prepared above but never evaluated; run the same
# loop as validation over it once model selection is finished.

In [None]:
# Train the Lightning version for the same single epoch as the plain-PyTorch
# loop, so the two runs are comparable. Without an explicit limit, pl.Trainer()
# falls back to its default epoch budget, which is far more than one epoch.
# NOTE(review): fit() receives only the model, so this assumes
# LightningMNISTClassifier defines its own dataloaders; otherwise pass
# `datamodule=MNISTDataModule(...)` (imported but currently unused).
trainer = pl.Trainer(max_epochs=1)
trainer.fit(lightning_model)