<a href="https://colab.research.google.com/github/Abrazacs/Anylogic_trial/blob/master/lightning_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install PyTorch Lightning into the notebook environment (the version is not
# pinned, so pip installs whatever release is current), then import the
# libraries used by the model/data cells below.
# NOTE(review): the next cell repeats these same imports, so this cell is only
# strictly needed for the install step.
!pip install pytorch-lightning
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torchmetrics import Accuracy

In [12]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torchmetrics import Accuracy

class MultyLayerPerceptron(pl.LightningModule):
  """Fully connected image classifier implemented as a LightningModule.

  Each input is flattened (``nn.Flatten``), passed through a stack of
  Linear + ReLU hidden layers, and finished with a Linear layer producing one
  logit per class. Training/validation/test use cross-entropy loss; accuracy
  is accumulated per split with its own torchmetrics ``Accuracy`` and logged
  once at the end of each epoch.

  Args:
    image_shape: Shape of one input sample; only the total number of elements
      is used (the size after flattening). Any number of dimensions works.
    hidden_units: Widths of the hidden layers, in order. May be empty, in
      which case the model is a single linear layer.
    num_classes: Number of output logits / target classes.
    learning_rate: Learning rate passed to the Adam optimizer.
  """

  def __init__(self, image_shape=(1, 28, 28), hidden_units=(32, 16),
               num_classes=10, learning_rate=0.001):
    super().__init__()
    # Record constructor arguments so load_from_checkpoint can rebuild
    # models created with non-default settings.
    self.save_hyperparameters()
    self.learning_rate = learning_rate

    # One metric object per split so their running state never mixes.
    self.train_acc = Accuracy(task='multiclass', num_classes=num_classes)
    self.valid_acc = Accuracy(task='multiclass', num_classes=num_classes)
    self.test_acc = Accuracy(task='multiclass', num_classes=num_classes)

    # Number of elements in one flattened sample (e.g. 1*28*28 = 784).
    input_size = torch.Size(image_shape).numel()
    all_layers = [nn.Flatten()]
    for hidden_unit in hidden_units:
      all_layers.append(nn.Linear(input_size, hidden_unit))
      all_layers.append(nn.ReLU())
      input_size = hidden_unit

    # ``input_size`` is now the width of the last layer built (or the
    # flattened input size when there are no hidden layers), so the output
    # layer is correct even for an empty ``hidden_units``.
    all_layers.append(nn.Linear(input_size, num_classes))
    self.layers = nn.Sequential(*all_layers)

  def forward(self, x):
    """Return raw (unnormalized) class logits for the batch ``x``."""
    return self.layers(x)

  def _shared_step(self, batch, metric):
    """Compute cross-entropy loss for ``batch`` and update ``metric``.

    ``batch`` is an ``(inputs, integer_targets)`` pair; ``metric`` receives
    the argmax predictions so the split's accuracy accumulates over the epoch.
    """
    x, y = batch
    logits = self(x)
    loss = nn.functional.cross_entropy(logits, y)
    preds = torch.argmax(logits, dim=1)
    metric.update(preds, y)
    return loss

  def training_step(self, batch, batch_idx):
    loss = self._shared_step(batch, self.train_acc)
    self.log("train_loss", loss, prog_bar=True)
    return loss

  def on_train_epoch_end(self):
    # Log accuracy over the whole epoch, then clear state for the next one.
    self.log("train_acc_epoch", self.train_acc.compute(), prog_bar=True)
    self.train_acc.reset()

  def validation_step(self, batch, batch_idx):
    loss = self._shared_step(batch, self.valid_acc)
    self.log("valid_loss", loss, prog_bar=True)
    return loss

  def on_validation_epoch_end(self):
    self.log("valid_acc_epoch", self.valid_acc.compute(), prog_bar=True)
    self.valid_acc.reset()

  def test_step(self, batch, batch_idx):
    loss = self._shared_step(batch, self.test_acc)
    self.log("test_loss", loss, prog_bar=True)
    return loss

  def on_test_epoch_end(self):
    self.log("test_acc_epoch", self.test_acc.compute(), prog_bar=True)
    self.test_acc.reset()

  def configure_optimizers(self):
    """Optimize all parameters with Adam at the configured learning rate."""
    return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [13]:
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.datasets import MNIST
from torchvision import transforms

class MnistDataModule(pl.LightningDataModule):
  """LightningDataModule providing MNIST train/validation/test loaders.

  The MNIST training split is divided, with a fixed seed, into 55,000
  training and 5,000 validation images; the MNIST test split is used for
  testing. Images are converted to tensors with ``transforms.ToTensor()``.

  Args:
    data_path: Directory where MNIST is downloaded/read.
    batch_size: Batch size used by every dataloader.
    num_workers: Worker processes used by every dataloader.
  """

  def __init__(self, data_path='./', batch_size=64, num_workers=4):
    super().__init__()
    self.data_path = data_path
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.transform = transforms.Compose([transforms.ToTensor()])

  def prepare_data(self):
    # Download only; datasets are (re)constructed in ``setup``.
    MNIST(root=self.data_path, download=True)

  def setup(self, stage=None):
    mnist = MNIST(root=self.data_path, train=True, transform=self.transform)
    # Seeded generator keeps the train/validation split identical across runs.
    self.train_set, self.valid_set = random_split(
        mnist, [55000, 5000], generator=torch.Generator().manual_seed(1)
    )
    self.test_set = MNIST(root=self.data_path, train=False, transform=self.transform)

  def train_dataloader(self):
    # Reshuffle every epoch; without this, random_split's one-time permutation
    # makes every epoch see batches in exactly the same order.
    return DataLoader(self.train_set, batch_size=self.batch_size,
                      shuffle=True, num_workers=self.num_workers)

  def val_dataloader(self):
    return DataLoader(self.valid_set, batch_size=self.batch_size,
                      num_workers=self.num_workers)

  def test_dataloader(self):
    return DataLoader(self.test_set, batch_size=self.batch_size,
                      num_workers=self.num_workers)

In [None]:
# Seed the global RNG so weight init and data shuffling are reproducible.
torch.manual_seed(1)
mnist_dm = MnistDataModule()
mnist_classifier = MultyLayerPerceptron()

# ``Trainer(gpus=...)`` was removed in Lightning 2.0 (current releases, which
# the unpinned ``pip install`` fetches, raise TypeError on it); the
# ``accelerator``/``devices`` pair is the supported way to request one GPU.
if torch.cuda.is_available():
  trainer = pl.Trainer(max_epochs=10, accelerator="gpu", devices=1)
else:
  trainer = pl.Trainer(max_epochs=10)

trainer.fit(model=mnist_classifier, datamodule=mnist_dm)


In [None]:
# Install TensorBoard and open it inline in the notebook, reading event files
# from ./lightning_logs.
# NOTE(review): presumably this is where the Trainer above writes its logs,
# since no logger is passed to it — confirm.
!pip install tensorboard
%load_ext tensorboard
%tensorboard --logdir lightning_logs