In [4]:
import torch
from torch import nn
import pytorch_lightning as pl

In [5]:
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.optim import SGD
from torch.utils.data import random_split, DataLoader

In [6]:
class model(pl.LightningModule):
    def __init__(self):
        super(model, self).__init__()

        # Defining our model architecture
        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 128)
        self.out = nn.Linear(128, 10)

        # Defining learning rate
        self.lr = 0.01

        # Defining loss
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        # Defining the forward pass of the model
        batch_size, _, _, _ = x.size()
        x = x.view(batch_size, -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.out(x)

    def configure_optimizers(self):
        # Defining and returning the optimizer for our model
        # with the defines parameters
        return torch.optim.SGD(self.parameters(), lr = self.lr)

    def training_step(self, train_batch, batch_idx):
        # Defining training steps for our model
        x, y = train_batch
        logits = self.forward(x)
        loss = self.loss(logits, y)
        return loss

    def validation_step(self, valid_batch, batch_idx):
        # Defining validation steps for our model
        x, y = valid_batch
        logits = self.forward(x)
        loss = self.loss(logits, y)

class DataModuleMNIST(pl.LightningDataModule):
    def __init__(self):
        super().__init__()

        # Directory to store MNIST Data
        self.download_dir = ''

        # Defining batch size of our data
        self.batch_size = 32

        # Defining transforms to be applied on the data
        self.transform = transforms.Compose([transforms.ToTensor()])

    def prepare_data(self):
        # Downloading our data
        datasets.MNIST(self.download_dir, train = True, download = True)

        datasets.MNIST(self.download_dir, train = False, download = True)

    def setup(self, stage=None):
        # Loading our data after applying the transforms
        data = datasets.MNIST(self.download_dir, train = True,
                              transform = self.transform)

        self.train_data, self.valid_data = random_split(data, [55000, 5000])

        self.test_data = datasets.MNIST(self.download_dir, train = False,
                                        transform = self.transform)

    def train_dataloader(self):
        # Generating train_dataloader
        return DataLoader(self.train_data,
                          batch_size = self.batch_size)

    def val_dataloader(self):
        # Generating val_dataloader
        return DataLoader(self.valid_data,
                          batch_size = self.batch_size)

    def test_dataloader(self):
        # Generating test_dataloader
        return DataLoader(self.test_data,
                          batch_size = self.batch_size)

In [7]:
clf = model()
mnist = DataModuleMNIST()
trainer = pl.Trainer()
trainer.fit(clf, mnist)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
  rank_zero_warn(


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 107575324.06it/s]


Extracting MNIST/raw/train-images-idx3-ubyte.gz to MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 16605304.16it/s]


Extracting MNIST/raw/train-labels-idx1-ubyte.gz to MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 39383674.52it/s]

Extracting MNIST/raw/t10k-images-idx3-ubyte.gz to MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 6183229.07it/s]


Extracting MNIST/raw/t10k-labels-idx1-ubyte.gz to MNIST/raw



INFO:pytorch_lightning.callbacks.model_summary:
  | Name | Type             | Params
------------------------------------------
0 | fc1  | Linear           | 200 K 
1 | fc2  | Linear           | 32.9 K
2 | out  | Linear           | 1.3 K 
3 | loss | CrossEntropyLoss | 0     
------------------------------------------
235 K     Trainable params
0         Non-trainable params
235 K     Total params
0.941     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
