In [None]:
## The usual imports
import torch
import torch.nn as nn

# Record the PyTorch version so results can be matched to the release that produced them.
torch_version = torch.__version__
print(torch_version)

In [None]:
import torchvision

# Working directory: MNIST is stored here, and the Trainer later uses it as default_root_dir.
g_ws_dir = '../_exp'

# Only the training split requests a download; the test split is read from the same root.
mnist_train = torchvision.datasets.MNIST(root=g_ws_dir, train=True, download=True)
mnist_test = torchvision.datasets.MNIST(root=g_ws_dir, train=False)

In [None]:
# Number of training labels (one per image).
mnist_train.targets.size()


In [None]:
# Sanity-check the raw test images: later cells reshape them with `.view(-1, 1, 28, 28)`,
# so each image is expected to be 28x28 pixels.
mnist_test.data.shape

In [5]:
# Split the official training set: the first N_TRAIN images are used for fitting,
# the remainder (index N_TRAIN onward) is held out for validation.
N_TRAIN = 55_000
assert 0 < N_TRAIN < len(mnist_train.data), "split must leave both train and val non-empty"

# Dividing by 255.0 scales raw pixel values (presumably uint8 0-255) into float [0, 1].
train_data = mnist_train.data[:N_TRAIN] / 255.0
train_targets = mnist_train.targets[:N_TRAIN]
val_data = mnist_train.data[N_TRAIN:] / 255.0
val_targets = mnist_train.targets[N_TRAIN:]

In [6]:
import torch
from torch.utils.data import TensorDataset, DataLoader

def _as_image_dataset(images, labels):
    """Pair images with labels, reshaping images to (N, 1, 28, 28) for the 1-channel Conv2d."""
    return TensorDataset(images.view(-1, 1, 28, 28), labels)


train_ds = _as_image_dataset(train_data, train_targets)
val_ds = _as_image_dataset(val_data, val_targets)
# Test images were not scaled in the split cell, so scale them here.
test_ds = _as_image_dataset(mnist_test.data / 255.0, mnist_test.targets)

In [7]:
from torch.utils.data import DataLoader

# One batch size shared by all three loaders.
BATCH_SIZE = 128

# Only training data is shuffled; evaluation order does not change the metrics.
train_dataloader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
validate_dataloader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

In [8]:
import pytorch_lightning as pl

class LitMNIST(pl.LightningModule):
    """Small two-conv-layer CNN classifier for 28x28 single-channel images (MNIST).

    Args:
        lr: Learning rate for Adam. The default 1e-3 equals ``torch.optim.Adam``'s
            own default, so ``LitMNIST()`` trains exactly as before this argument existed.
    """

    def __init__(self, lr: float = 1e-3):
        super().__init__()
        # Store constructor args in self.hparams (and in checkpoints).
        self.save_hyperparameters()
        self.model = torch.nn.Sequential(
            # (N, 1, 28, 28) -> (N, 32, 26, 26)
            torch.nn.Conv2d(1, 32, 3),  # in_channels, out_channels, kernel_size
            torch.nn.ReLU(),

            # (N, 32, 26, 26) -> (N, 32, 24, 24)
            torch.nn.Conv2d(32, 32, 3),
            torch.nn.ReLU(),

            # (N, 32, 24, 24) -> (N, 32, 12, 12); Dropout2d drops whole feature maps.
            torch.nn.MaxPool2d(2),
            torch.nn.Dropout2d(0.25),

            # (N, 32, 12, 12) -> (N, 12*12*32)
            torch.nn.Flatten(),
            torch.nn.Linear(12 * 12 * 32, 128),
            torch.nn.ReLU(),
            # Activations are 2-D (N, 128) here, so use element-wise Dropout.
            # Dropout2d is meant for (N, C, H, W) inputs and is deprecated for 2-D tensors.
            torch.nn.Dropout(0.5),

            torch.nn.Linear(128, 10),
        )

    def forward(self, x):
        """Return unnormalized class logits, shape (N, 10), for a batch of images."""
        return self.model(x)

    def _logits_and_loss(self, batch):
        """Run the network on one (images, labels) batch and return (logits, labels, loss)."""
        x, y = batch
        logits = self(x)
        # cross_entropy applies log-softmax itself, so it takes the raw logits.
        loss = torch.nn.functional.cross_entropy(logits, y)
        return logits, y, loss

    @staticmethod
    def _accuracy(logits, y):
        """Fraction of samples whose arg-max logit equals the label, as a 0-d tensor."""
        # Kept as a tensor (no .item()) to avoid a device->host sync every step.
        return (logits.argmax(dim=1) == y).float().mean()

    def _common_step(self, batch, batch_idx):
        """Return the cross-entropy loss for one batch."""
        _, _, loss = self._logits_and_loss(batch)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._common_step(batch, batch_idx)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        """Adam over all module parameters, using the lr passed to __init__."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

    def validation_step(self, batch, batch_idx):
        logits, y, loss = self._logits_and_loss(batch)
        self.log('val_loss', loss)
        self.log('val_acc', self._accuracy(logits, y))

    def test_step(self, batch, batch_idx):
        logits, y, loss = self._logits_and_loss(batch)
        self.log('test_loss', loss)
        self.log('test_acc', self._accuracy(logits, y))

In [None]:
# Seed the global RNGs before building the model, so that weight init, dropout masks
# and training-set shuffling are reproducible across Restart & Run All.
pl.seed_everything(42)

mnist = LitMNIST()

# default_root_dir sends Lightning's logs/checkpoints to the notebook's working directory.
trainer = pl.Trainer(max_epochs=2, default_root_dir=g_ws_dir)
trainer.fit(mnist, train_dataloader, validate_dataloader)

In [None]:
# Evaluate the trained model on the held-out test split and display the returned metrics.
test_results = trainer.test(mnist, test_dataloader)
test_results