In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
import os
from torchvision import datasets, transforms
import pytorch_lightning as pl
from torch.nn import functional as F

# Data

In [11]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) # these are the mean and variance of pixels when taken from the entire dataset
mnist_train = MNIST(os.getcwd(), train = True, download = True, transform=transform)
mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])

mnist_test =  MNIST(os.getcwd(), train = False, download = True, transform=transform)

mnist_train = DataLoader(mnist_train, batch_size = 64)
mnist_val = DataLoader(mnist_val, batch_size = 64)
mnist_test = DataLoader(mnist_test, batch_size = 64)

##### Lightning has a data module to make sure that every model implemented follows the SAME structure

In [28]:
class MNISTDataModule(pl.LightningDataModule):

    def setup(self, stage):
        # image transforms
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        mnist_train = MNIST(os.getcwd(), train = True, download = True, transform=transform)
        self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])

        self.mnist_test =  MNIST(os.getcwd(), train = False, download = True, transform=transform)
    
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size = 64)
    
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size = 64)
    
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size = 64)

# Models

In [13]:
class MNISTClassifier(nn.Module):

    def __init__(self):
        super(MNISTClassifier, self).__init__()

        self.layer_1 = torch.nn.Linear(28* 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)
    
    def forward(self, x):
        batch_size, channels, width, height = x.size()

        # (b, 1, 28, 28) --> (b, 28 *28)
        x = x.view(batch_size, -1)

        x = self.layer_1(x)
        x = torch.relu(x)
        x = self.layer_2(x)
        x = torch.relu(x)
        x = self.layer_3(x)
        x = torch.log_softmax(x, dim = 1)
        return x
        

In [21]:
class LightningMNISTClassifier(pl.LightningModule):

    def __init__(self):
        super(LightningMNISTClassifier, self).__init__()

        self.layer_1 = torch.nn.Linear(28* 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)
    
    def forward(self, x):
        batch_size, channels, width, height = x.size()

        # (b, 1, 28, 28) --> (b, 28 *28)
        x = x.view(batch_size, -1)

        x = self.layer_1(x)
        x = torch.relu(x)
        x = self.layer_2(x)
        x = torch.relu(x)
        x = self.layer_3(x)
        x = torch.log_softmax(x, dim = 0)
        return x
    
    def configure_optimizers(self):
        optmizer = torch.optim.Adam(self.parameters(), lr = 1e-3)
        return optmizer
    
    def cross_entropy_loss(self, logits, labels):
        return F.nll_loss(logits, labels)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        self.log('train loss', loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        self.log('val_loss', loss)

In [15]:
pytorch_model = MNISTClassifier()
lightning_model = LightningMNISTClassifier()

x = torch.Tensor(32, 1, 28, 28)

pt_out = pytorch_model(x)
pl_out = lightning_model(x)

tensor([[ 0.0652, -0.0154,  0.0236, -0.0407, -0.0659, -0.0504, -0.0175, -0.0093,
         -0.0093, -0.0564],
        [ 0.0652, -0.0154,  0.0236, -0.0407, -0.0659, -0.0504, -0.0175, -0.0093,
         -0.0093, -0.0564],
        [ 0.0652, -0.0154,  0.0236, -0.0407, -0.0659, -0.0504, -0.0175, -0.0093,
         -0.0093, -0.0564],
        [ 0.0652, -0.0154,  0.0236, -0.0407, -0.0659, -0.0504, -0.0175, -0.0093,
         -0.0093, -0.0564],
        [ 0.0652, -0.0154,  0.0236, -0.0407, -0.0659, -0.0504, -0.0175, -0.0093,
         -0.0093, -0.0564],
        [ 0.0652, -0.0154,  0.0236, -0.0407, -0.0659, -0.0504, -0.0175, -0.0093,
         -0.0093, -0.0564],
        [ 0.0652, -0.0154,  0.0236, -0.0407, -0.0659, -0.0504, -0.0175, -0.0093,
         -0.0093, -0.0564],
        [ 0.0652, -0.0154,  0.0236, -0.0407, -0.0659, -0.0504, -0.0175, -0.0093,
         -0.0093, -0.0564],
        [ 0.0652, -0.0154,  0.0236, -0.0407, -0.0659, -0.0504, -0.0175, -0.0093,
         -0.0093, -0.0564],
        [ 0.0652, -

# Optimizer and loss can all get put in the lightning class for pytorch lightning

In [16]:
pytorch_model = MNISTClassifier()
optimizer = torch.optim.Adam(pytorch_model.parameters(), lr = 1e-3)

In [17]:
def cross_entropy_loss(logits, labels):
    return F.nll_loss(logits, labels)

# Training

### Pytorch

In [19]:
num_epochs = 2
for epoch in range(num_epochs):
    for train_batch in mnist_train:
        x, y = train_batch
        logits = pytorch_model(x)
        loss = cross_entropy_loss(logits, y)
        print("train loss: ", loss)

        loss.backward()

        optimizer.step()
        optimizer.zero_grad()
    
    with torch.no_grad():
        val_loss = []
        for val_batch in mnist_val:
            x, y = val_batch
            logits = pytorch_model(x)
            val_loss.append(cross_entropy_loss(logits, y).item())
        
        val_loss = torch.mean(torch.tensor(val_loss))
        print('val_loss: ', val_loss.item())

        

train loss:  tensor(0.1313, grad_fn=<NllLossBackward0>)
train loss:  tensor(0.1242, grad_fn=<NllLossBackward0>)
train loss:  tensor(0.1617, grad_fn=<NllLossBackward0>)
train loss:  tensor(0.0693, grad_fn=<NllLossBackward0>)
train loss:  tensor(0.1786, grad_fn=<NllLossBackward0>)
train loss:  tensor(0.0784, grad_fn=<NllLossBackward0>)
train loss:  tensor(0.1269, grad_fn=<NllLossBackward0>)
train loss:  tensor(0.1411, grad_fn=<NllLossBackward0>)
train loss:  tensor(0.1289, grad_fn=<NllLossBackward0>)
train loss:  tensor(0.0343, grad_fn=<NllLossBackward0>)
train loss:  tensor(0.0923, grad_fn=<NllLossBackward0>)
train loss:  tensor(0.0655, grad_fn=<NllLossBackward0>)
train loss:  tensor(0.1785, grad_fn=<NllLossBackward0>)
train loss:  tensor(0.1756, grad_fn=<NllLossBackward0>)
train loss:  tensor(0.0702, grad_fn=<NllLossBackward0>)
train loss:  tensor(0.1614, grad_fn=<NllLossBackward0>)
train loss:  tensor(0.1275, grad_fn=<NllLossBackward0>)
train loss:  tensor(0.1417, grad_fn=<NllLossBack

### Pytorch lightning

In [29]:
model = LightningMNISTClassifier()
mnist_data = MNISTDataModule()
trainer = pl.Trainer()
trainer.fit(model, mnist_data)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/aryomanpatel/Desktop/Coding Stuff/ML/MnistPytorchLightning/env/lib/python3.11/site-packages/pytorch_lightning/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.

  | Name    | Type   | Params
-----------------------------------
0 | layer_1 | Linear | 100 K 
1 | layer_2 | Linear | 33.0 K
2 | layer_3 | Linear | 2.6 K 
-----------------------------------
136 K     Trainable params
0         Non-trainable params
136 K     Total params
0.544     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/aryomanpatel/Desktop/Coding Stuff/ML/MnistPytorchLightning/env/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]tensor([[ 0.0433, -0.0713, -0.0622,  0.0418, -0.1709,  0.1440,  0.0725,  0.0611,
          0.0058, -0.0647],
        [ 0.0254, -0.1241, -0.0871, -0.1186, -0.1578,  0.0114,  0.0701,  0.1487,
          0.0422,  0.0071],
        [ 0.0207, -0.1525, -0.0498, -0.1664, -0.1659, -0.0971,  0.0867,  0.0483,
          0.0550, -0.0056],
        [-0.0467, -0.1364, -0.1380, -0.0714, -0.2185, -0.1705,  0.0010,  0.0354,
          0.0508,  0.0005],
        [-0.0026, -0.1412, -0.1482, -0.1407, -0.1955, -0.0162,  0.0615,  0.1363,
          0.0833, -0.0115],
        [-0.0638, -0.1570, -0.0526, -0.0739, -0.1687, -0.0850,  0.0935,  0.0562,
          0.1724, -0.0249],
        [-0.0779, -0.0845, -0.0186,  0.0299, -0.2920, -0.1757,  0.0204,  0.0935,
          0.1362, -0.0647],
        [ 0.0050, -0.0914,  0.0085, -0.1254, -0.0990,  0.0686, -0.0032,  0.1529,
          0.1036, -0.0768],
        [-0.1146, -0.1558, -0.1408, -0.1012, -0.1013,  0.0297

/Users/aryomanpatel/Desktop/Coding Stuff/ML/MnistPytorchLightning/env/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Epoch 0:   0%|          | 0/860 [00:00<?, ?it/s] tensor([[-1.7346e-02, -1.0278e-01, -3.1378e-02, -7.4272e-02, -1.6011e-01,
         -4.9602e-02,  1.0983e-02,  1.7006e-01,  1.1776e-01,  1.1393e-02],
        [ 6.0930e-03, -1.6356e-01, -1.2953e-01,  9.0006e-03, -1.5327e-01,
          6.1658e-03,  4.5208e-02,  1.5152e-02,  8.5461e-02, -1.6925e-01],
        [ 2.1044e-02, -4.2154e-02, -4.4681e-02, -5.5241e-02, -1.1005e-01,
          5.6133e-02,  7.1856e-02,  1.2106e-01, -9.9681e-03,  1.2572e-02],
        [-5.1636e-02, -1.8122e-01,  8.7784e-02, -1.1618e-01, -2.3924e-01,
         -1.5834e-01,  1.0383e-01,  5.5011e-02,  1.6022e-01, -5.3401e-02],
        [-8.5789e-02, -2.4634e-01, -2.6506e-02, -1.6633e-01, -3.1059e-01,
         -1.8463e-01,  1.3521e-02,  2.2511e-02,  1.8146e-01, -1.0528e-01],
        [ 7.0104e-02, -1.5040e-01, -1.4090e-02, -7.0608e-02, -2.4198e-02,
          4.4614e-02,  6.9578e-02,  1.3279e-01,  9.6076e-02, -7.0748e-02],
        [-4.8055e-03, -7.6676e-02, -1.7464e-01, -8.8620e-

/Users/aryomanpatel/Desktop/Coding Stuff/ML/MnistPytorchLightning/env/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
