In [6]:
import os

import torch
import pytorch_lightning as pl

from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split, TensorDataset
from torchvision import transforms, datasets
from torchvision.datasets import MNIST

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor

In [None]:
# Base location handed to the checkpoint callback and to the Trainer's
# `default_root_dir` further down. This cell must be executed before the
# training cell, otherwise that cell fails with a NameError.
model_path = "models"

In [7]:
class LightningMNISTClassifier(pl.LightningModule):
    """Fully connected classifier: 784 -> 128 -> 256 -> 10 with ReLU activations.

    `forward` returns log-probabilities (``log_softmax``), so the loss is
    computed with ``F.nll_loss``.

    Args:
        lr_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, lr_rate):
        super().__init__()

        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)
        self.lr_rate = lr_rate

    def forward(self, x):
        """Map a batch of images to per-class log-probabilities.

        Args:
            x: Tensor whose first dimension is the batch and whose remaining
                dimensions multiply to 28 * 28, e.g. (b, 1, 28, 28).

        Returns:
            Tensor of shape (b, 10) holding log-probabilities.
        """
        # (b, 1, 28, 28) -> (b, 1*28*28); only the batch size is needed, so
        # already-flattened (b, 784) input is accepted as well.
        x = x.view(x.size(0), -1)

        # layer 1 (b, 1*28*28) -> (b, 128)
        x = torch.relu(self.layer_1(x))

        # layer 2 (b, 128) -> (b, 256)
        x = torch.relu(self.layer_2(x))

        # layer 3 (b, 256) -> (b, 10)
        x = self.layer_3(x)

        # log of the probability distribution over the 10 classes
        return torch.log_softmax(x, dim=1)

    def cross_entropy_loss(self, logits, labels):
        """Negative log-likelihood of `labels` under the given log-probabilities.

        Despite its name, `logits` must already be log-probabilities (the
        output of `forward`).
        """
        return F.nll_loss(logits, labels)

    def _batch_loss(self, batch):
        """Run the forward pass on an (inputs, targets) batch and return its loss."""
        x, y = batch
        return self.cross_entropy_loss(self(x), y)

    @staticmethod
    def _mean_of(outputs, key):
        """Average the tensors stored under `key` in a list of step-output dicts."""
        return torch.stack([out[key] for out in outputs]).mean()

    def training_step(self, train_batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss = self._batch_loss(train_batch)

        # `self.log` is what makes the metric visible to loggers and callbacks.
        self.log('train_loss', loss)

        # The 'log' entry is kept only so the returned dict has the same keys
        # as before for any code that inspects it.
        logs = {'train_loss': loss}
        return {'loss': loss, 'log': logs}

    def validation_step(self, val_batch, batch_idx):
        """Compute the validation loss for one batch and log its epoch average."""
        loss = self._batch_loss(val_batch)

        # Epoch-level 'val_loss' is the metric EarlyStopping / ModelCheckpoint
        # are configured to monitor.
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)

        return {'val_loss': loss}

    def test_step(self, val_batch, batch_idx):
        """Compute the test loss for one batch and log its epoch average."""
        loss = self._batch_loss(val_batch)
        self.log('test_loss', loss, on_step=False, on_epoch=True)

        return {'test_loss': loss}

    def on_validation_epoch_end(self, outputs=None):
        """Average per-batch validation losses when a list of outputs is given.

        Lightning invokes this hook without arguments; in that case the
        epoch average has already been produced by `self.log` in
        `validation_step` and nothing is returned.

        Args:
            outputs: Optional list of dicts as returned by `validation_step`.

        Returns:
            ``None`` when `outputs` is missing or empty, otherwise a dict with
            'avg_val_loss' and a 'log' sub-dict.
        """
        if not outputs:
            return None

        avg_loss = self._mean_of(outputs, 'val_loss')
        logs = {'val_loss': avg_loss}

        return {'avg_val_loss': avg_loss, 'log': logs}

    def on_test_epoch_end(self, outputs=None):
        """Average per-batch test losses when a list of outputs is given.

        Mirrors `on_validation_epoch_end`: called without arguments by
        Lightning, in which case it does nothing.

        Args:
            outputs: Optional list of dicts as returned by `test_step`.

        Returns:
            ``None`` when `outputs` is missing or empty, otherwise a dict with
            'avg_test_loss' and a 'log' sub-dict.
        """
        if not outputs:
            return None

        avg_loss = self._mean_of(outputs, 'test_loss')
        logs = {'test_loss': avg_loss}

        return {'avg_test_loss': avg_loss, 'log': logs}

    def configure_optimizers(self):
        """Create the Adam optimizer and a StepLR schedule.

        Returns:
            A one-element optimizer list and a one-element scheduler-config list.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr_rate)
        # NOTE(review): step_size=1 with gamma=0.1 divides the learning rate
        # by 10 after every scheduler step - confirm this decay is intended.
        # 'monitor' is presumably only consulted for metric-driven schedulers
        # such as ReduceLROnPlateau; StepLR does not read it.
        lr_scheduler = {'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1),
                        'monitor': 'val_loss'}

        return [optimizer], [lr_scheduler]

In [8]:
# Custom Callbacks
class MyPrintingCallback(pl.Callback):
    """Callback that prints a fixed message from three trainer lifecycle hooks.

    NOTE(review): `on_init_start` / `on_init_end` appear to have been removed
    from newer Lightning releases - confirm they are still invoked by the
    installed version.
    """

    # Messages emitted by the hooks below, kept together for easy editing.
    _INIT_START_MESSAGE = 'Starting to init trainer!'
    _INIT_END_MESSAGE = 'trainer is init now'
    _TRAIN_END_MESSAGE = 'do something when training ends'

    def on_init_start(self, trainer):
        """Announce that trainer initialisation is starting."""
        print(self._INIT_START_MESSAGE)

    def on_init_end(self, trainer):
        """Announce that trainer initialisation has finished."""
        print(self._INIT_END_MESSAGE)

    def on_train_end(self, trainer, pl_module):
        """Announce that training has finished."""
        print(self._TRAIN_END_MESSAGE)

In [10]:

def prepare_data(data_dir=None, n_train=2000, n_val=200, test_range=(3000, 4000), seed=42):
    """Download MNIST and return small train / validation / test subsets.

    Args:
        data_dir: Directory passed to `MNIST` as its root. Defaults to the
            current working directory.
        n_train: Number of samples in the training split.
        n_val: Number of samples in the validation split.
        test_range: ``(start, stop)`` index range taken from the MNIST test set.
        seed: Seed for the train/validation split so it is reproducible
            across runs. Pass ``None`` to use the global RNG instead.

    Returns:
        Tuple ``(mnist_train, mnist_val, mnist_test)``. The first two are the
        subsets produced by `random_split`; the third is a list of
        ``(image, label)`` pairs.
    """
    if data_dir is None:
        data_dir = os.getcwd()

    # transforms for images: tensor conversion + the normalisation constants
    # conventionally used for MNIST
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.1307,), (0.3081,))])

    # Only the first n_train + n_val samples are materialised, which keeps the
    # experiment small.
    mnist_full = MNIST(data_dir, train=True, download=True, transform=transform)
    train_val_samples = [mnist_full[i] for i in range(n_train + n_val)]

    if seed is None:
        mnist_train, mnist_val = random_split(train_val_samples, [n_train, n_val])
    else:
        # A dedicated generator makes the split repeatable without touching
        # the global RNG state.
        split_rng = torch.Generator().manual_seed(seed)
        mnist_train, mnist_val = random_split(train_val_samples, [n_train, n_val],
                                              generator=split_rng)

    test_start, test_stop = test_range
    mnist_test_full = MNIST(data_dir, train=False, download=True, transform=transform)
    mnist_test = [mnist_test_full[i] for i in range(test_start, test_stop)]

    return mnist_train, mnist_val, mnist_test

In [11]:
# Build the train / validation / test subsets; this downloads MNIST into the
# current working directory when it is not already there.
train, val, test = prepare_data()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1006)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to C:\Users\isaac\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [01:57<00:00, 84483.06it/s] 


Extracting C:\Users\isaac\MNIST\raw\train-images-idx3-ubyte.gz to C:\Users\isaac\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1006)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to C:\Users\isaac\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 223748.58it/s]


Extracting C:\Users\isaac\MNIST\raw\train-labels-idx1-ubyte.gz to C:\Users\isaac\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1006)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to C:\Users\isaac\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 2227017.47it/s]


Extracting C:\Users\isaac\MNIST\raw\t10k-images-idx3-ubyte.gz to C:\Users\isaac\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1006)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to C:\Users\isaac\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4538001.14it/s]


Extracting C:\Users\isaac\MNIST\raw\t10k-labels-idx1-ubyte.gz to C:\Users\isaac\MNIST\raw



In [12]:
# One DataLoader per split. Only the training loader shuffles, so each epoch
# sees the samples in a different order while evaluation order stays fixed.
train_loader = DataLoader(train, batch_size=64, shuffle=True)
val_loader = DataLoader(val, batch_size=64)
test_loader = DataLoader(test, batch_size=64)


In [13]:
# Seed Python, NumPy and torch RNGs so weight initialisation and shuffling
# are repeatable between runs.
pl.seed_everything(42)

model = LightningMNISTClassifier(lr_rate=1e-3)

# Learning Rate Logger
lr_logger = LearningRateMonitor()
# Stop training once 'val_loss' has not improved for 5 validation runs
early_stopping = EarlyStopping(monitor='val_loss', mode='min', patience=5)
# Saves the 3 best checkpoints (lowest 'val_loss') inside 'model_path'.
# Directory and file name are passed separately so the files really land in
# that directory instead of getting a 'modelsmnist_...' prefix.
checkpoint_callback = ModelCheckpoint(dirpath=model_path,
                                      filename='mnist_{epoch}-{val_loss:.2f}',
                                      monitor='val_loss', mode='min', save_top_k=3)

# Every callback goes through `callbacks`; the dedicated early-stop and
# checkpoint keyword arguments are not accepted by current Trainer versions.
# NOTE(review): to train on a GPU, add the accelerator/devices arguments
# supported by the installed Lightning version.
trainer = pl.Trainer(max_epochs=30, profiler='simple',
                     callbacks=[lr_logger, early_stopping, checkpoint_callback],
                     default_root_dir=model_path)

# Loaders are passed positionally because the keyword for the training
# loader was renamed between Lightning releases.
trainer.fit(model, train_loader, val_loader)

NameError: name 'model_path' is not defined