In [None]:
!pip install pytorch_lightning

In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import pytorch_lightning as L
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from sklearn.metrics import confusion_matrix
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
from IPython.core.display import ProgressBar

In [None]:
# Preprocessing shared by every split: PIL image -> float tensor, then
# standardize with mean 0.1307 / std 0.3081 (the commonly quoted MNIST stats).
mnist_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.1307,), std=(0.3081,))]
)

In [None]:
# Load the official training split (later divided into train/val) and the
# held-out test split from ./data (download=True fetches them if needed),
# both passed through the same preprocessing pipeline.
mnist_full, test_dataset = (
    MNIST(root='./data', train=is_train, download=True, transform=mnist_transform)
    for is_train in (True, False)
)

In [None]:
# 70/30 train/validation split of the official training set.
# A fixed-seed generator makes the split identical across kernel restarts,
# so validation numbers from different runs are comparable.
SPLIT_SEED = 42
train_size = int(len(mnist_full) * 0.7)
val_size = len(mnist_full) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    mnist_full,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
# Only the training loader is shuffled; evaluation metrics do not depend on
# sample order, so shuffling val/test just adds nondeterminism.
BATCH_SIZE = 64
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
class CNN(L.LightningModule):
  """Three-block convolutional digit classifier trained with Lightning.

  Each conv block is Conv2d(3x3, padding=1) -> BatchNorm -> ReLU -> MaxPool(2),
  so a 1x28x28 MNIST image shrinks spatially 28 -> 14 -> 7 -> 3; that is where
  the ``128 * 3 * 3`` flatten width comes from. The dropout-regularized head
  outputs 10 class logits.

  Args:
    hidden_units: width of the hidden fully-connected layer (default 64).
    lr: Adam learning rate (default 1e-3).
  """

  def __init__(self, hidden_units=64, lr=1e-3):
    super().__init__()
    self.lr = lr
    self.model = nn.Sequential(
        nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(32),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(128),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Flatten(),
        nn.Dropout(0.6),
        nn.Linear(128 * 3 * 3, hidden_units),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(hidden_units, 10),
    )
    # Per-batch predictions/labels gathered by test_step (e.g. for a confusion
    # matrix). Initialized here so the attributes always exist; they are reset
    # at the start of every test epoch.
    self.test_preds = []
    self.test_labels = []

  def forward(self, x):
    """Return raw (unnormalized) class logits for a batch of images."""
    return self.model(x)

  def _shared_step(self, batch):
    """Compute ``(loss, accuracy, preds)`` for one ``(images, targets)`` batch."""
    x, y = batch
    logits = self(x)
    loss = F.cross_entropy(logits, y)
    preds = torch.argmax(logits, dim=1)
    # Tensor mean instead of `.item()`: avoids a device->host sync every step.
    accuracy = (preds == y).float().mean()
    return loss, accuracy, preds

  def training_step(self, batch, batch_idx):
    """Log training loss/accuracy and return the loss for backprop."""
    loss, accuracy, _ = self._shared_step(batch)
    self.log("train_loss", loss)
    self.log("train_acc", accuracy, prog_bar=True)
    return loss

  def validation_step(self, batch, batch_idx):
    """Log validation loss/accuracy.

    Uses distinct ``val_*`` keys so validation metrics are not merged into the
    ``train_*`` series logged by training_step.
    """
    loss, accuracy, _ = self._shared_step(batch)
    self.log("val_loss", loss)
    self.log("val_acc", accuracy, prog_bar=True)

  def test_step(self, batch, batch_idx):
    """Log test metrics and stash predictions/labels for later analysis."""
    _, y = batch
    loss, accuracy, preds = self._shared_step(batch)
    self.log('test_loss', loss, prog_bar=True)
    self.log('test_acc', accuracy, prog_bar=True)
    # Copy to host numpy so the results outlive the test loop.
    self.test_preds.append(preds.cpu().numpy())
    self.test_labels.append(y.cpu().numpy())

  def configure_optimizers(self):
    """Adam over all parameters with the learning rate given at construction."""
    return torch.optim.Adam(self.parameters(), lr=self.lr)

  def on_test_epoch_start(self):
    """Clear buffers so each `trainer.test` run starts from empty results."""
    self.test_preds = []
    self.test_labels = []

In [None]:
# Seed the global RNGs before building the model so weight initialization and
# DataLoader shuffling are repeatable across kernel restarts (GPU kernels may
# still introduce small nondeterminism).
L.seed_everything(42)
model = CNN()

In [None]:
trainer = L.Trainer(max_epochs=5, logger = L.loggers.TensorBoardLogger('logs/', name = 'mnist_model'))

In [None]:
# Fit on the training split, validating on the held-out 30% (Lightning's
# default is one validation pass per epoch).
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
# Evaluate the current (last-epoch) weights on the official MNIST test split.
trainer.test(model, dataloaders=test_loader)

In [None]:
# Flatten the per-batch predictions/labels collected by CNN.test_step and
# tabulate them (rows = true digit, columns = predicted digit).
# labels=0..9 pins the matrix to 10x10 even if some class is never seen.
preds = np.concatenate(model.test_preds)
labels = np.concatenate(model.test_labels)
conf_matrix = confusion_matrix(labels, preds, labels=np.arange(10))
conf_matrix

In [None]:
# Heatmap of the test-set confusion matrix computed in the previous cell
# (rows = true labels, columns = predicted labels).
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', cbar=True, ax=ax)
ax.set_xlabel('Predicted labels')
ax.set_ylabel('True labels')
ax.set_title('Confusion Matrix ')

# Show the plot
plt.show()