In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

import pytorch_lightning as pl
import torch.nn.functional as F
from pytorch_lightning import Trainer

In [2]:
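## What Lightning handles for us (compared with a plain PyTorch training loop):
# - switching the model between training and evaluation mode
# - choosing the device (GPU/CPU) and moving the model to it
# - calling optimizer.zero_grad()
# - calling loss.backward() and optimizer.step()
# - optionally finding a good learning rate automatically
# - doing a quick smoke-test run on a single batch (fast_dev_run)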

In [10]:
# --- Model shape ---
input_size = 28 * 28   # each 28x28 MNIST image is flattened into one vector
hidden_size = 500      # width of the single hidden layer
num_classes = 10       # one logit per digit class

# --- Training ---
num_epochs = 2
batch_size = 100
learning_rate = 1e-3   # Adam step size

In [46]:
class LitNeuralNet(pl.LightningModule):
    """Two-layer fully connected classifier wrapped as a LightningModule.

    Lightning drives the loop: it calls ``training_step`` / ``validation_step``
    for each batch and steps the optimizer returned by ``configure_optimizers``.

    Args:
        input_size: Number of features per flattened input image.
        hidden_size: Width of the hidden layer.
        num_classes: Number of output logits.
        learning_rate: Adam learning rate. Defaults to 1e-3, which matches the
            notebook's ``learning_rate`` config value; pass it explicitly if
            you change that value.
    """

    def __init__(self, input_size, hidden_size, num_classes, learning_rate=1e-3):
        super().__init__()
        self.input_size = input_size
        self.learning_rate = learning_rate
        self.l1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return raw logits for a batch of flattened inputs of width ``input_size``."""
        out = self.l1(x)
        out = self.relu(out)
        # No softmax at the end: F.cross_entropy applies log-softmax itself.
        return self.l2(out)

    def _shared_step(self, batch):
        """Flatten an (images, labels) batch, run the forward pass, and return the cross-entropy loss."""
        images, labels = batch
        # Flatten per sample, keeping the batch dimension explicit. A feature-size
        # mismatch then fails loudly in nn.Linear instead of silently regrouping
        # pixels across samples (as reshape(-1, 28*28) would).
        images = images.reshape(images.size(0), -1)
        logits = self(images)
        return F.cross_entropy(logits, labels)

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        # Sent to the trainer's logger (e.g. TensorBoard).
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("val_loss", loss)
        return loss

    def configure_optimizers(self):
        # Optimize this module's own parameters. The previous version referenced
        # the notebook-global ``model``, which only worked by naming coincidence.
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
    
    

In [47]:
# MNIST training split for fitting; the official test split serves as the
# validation set in this notebook.
data_root = "./data"
to_tensor = transforms.ToTensor()  # PIL image -> float tensor

train_dataset = torchvision.datasets.MNIST(
    root=data_root, train=True, transform=to_tensor, download=True)

# shuffle=True so each epoch sees the training samples in a new order;
# with shuffle=False every epoch replayed the exact same batch sequence.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, num_workers=4, shuffle=True)

# download=True is a no-op when the files are already present, and makes the
# validation split not depend on the training-split download above.
val_dataset = torchvision.datasets.MNIST(
    root=data_root, train=False, transform=to_tensor, download=True)

# Sample order does not affect the validation loss, so no shuffling here.
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset, batch_size=batch_size, num_workers=4, shuffle=False)

In [None]:
# Smoke test: fast_dev_run=True pushes a single batch through training and
# validation so wiring/shape errors surface before the full run.
model = LitNeuralNet(input_size, hidden_size, num_classes)
trainer = pl.Trainer(fast_dev_run=True)
trainer.fit(
    model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

In [48]:
# Full training run. Seed Python/NumPy/torch RNGs first so weight initialisation
# and training-loader shuffling are reproducible across kernel restarts.
pl.seed_everything(42)

model = LitNeuralNet(input_size, hidden_size, num_classes)

trainer = pl.Trainer(max_epochs=num_epochs)
trainer.fit(model, train_dataloaders=train_loader,
            val_dataloaders=val_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name | Type   | Params
--------------------------------
0 | l1   | Linear | 392 K 
1 | relu | ReLU   | 0     
2 | l2   | Linear | 5.0 K 
--------------------------------
397 K     Trainable params
0         Non-trainable params
397 K     Total params
1.590     Total estimated model params size (MB)


Epoch 0:  86%|███████████████████▋   | 600/700 [00:04<00:00, 128.88it/s, loss=0.111, v_num=8]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                    | 0/100 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                       | 0/100 [00:00<?, ?it/s][A
Epoch 0:  86%|███████████████████▋   | 601/700 [00:04<00:00, 125.68it/s, loss=0.111, v_num=8][A
Epoch 0:  86%|███████████████████▊   | 602/700 [00:04<00:00, 125.76it/s, loss=0.111, v_num=8][A
Epoch 0:  86%|███████████████████▊   | 603/700 [00:04<00:00, 125.81it/s, loss=0.111, v_num=8][A
Epoch 0:  86%|███████████████████▊   | 604/700 [00:04<00:00, 125.82it/s, loss=0.111, v_num=8][A
Epoch 0:  86%|███████████████████▉   | 605/700 [00:04<00:00, 125.92it/s, loss=0.111, v_num=8][A
Epoch 0:  87%|███████████████████▉   | 606/700 [00:04<00:00, 126.02it/s, loss=0.111, v_num=8][A
Epoch 0:  87%|███████████████████▉   | 607/700 [00:04<00:00, 126.11it/s, loss=0.111, v_num=8][A

Epoch 1:  94%|████████████████████▌ | 656/700 [00:05<00:00, 129.12it/s, loss=0.0728, v_num=8][A
Epoch 1:  94%|████████████████████▋ | 657/700 [00:05<00:00, 129.17it/s, loss=0.0728, v_num=8][A
Epoch 1:  94%|████████████████████▋ | 658/700 [00:05<00:00, 129.22it/s, loss=0.0728, v_num=8][A
Epoch 1:  94%|████████████████████▋ | 659/700 [00:05<00:00, 129.30it/s, loss=0.0728, v_num=8][A
Epoch 1:  94%|████████████████████▋ | 660/700 [00:05<00:00, 129.37it/s, loss=0.0728, v_num=8][A
Epoch 1:  94%|████████████████████▊ | 661/700 [00:05<00:00, 129.42it/s, loss=0.0728, v_num=8][A
Epoch 1:  95%|████████████████████▊ | 662/700 [00:05<00:00, 129.48it/s, loss=0.0728, v_num=8][A
Validation DataLoader 0:  63%|██████████████████▎          | 63/100 [00:00<00:00, 198.30it/s][A
Epoch 1:  95%|████████████████████▊ | 663/700 [00:05<00:00, 129.52it/s, loss=0.0728, v_num=8][A
Epoch 1:  95%|████████████████████▊ | 664/700 [00:05<00:00, 129.56it/s, loss=0.0728, v_num=8][A
Epoch 1:  95%|████████████████