In [None]:
# import tqdm.notebook as tqdm

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
import torchvision.transforms as T
from torchmetrics import Accuracy

In [None]:
class ImageVNN(pl.LightningModule):
    """Fully connected image classifier trained with cross-entropy.

    The network flattens its input, applies one ``Linear`` + ``LeakyReLU``
    pair per entry of ``hidden_layers_n`` and ends with a ``Linear`` layer
    producing ``classes_n`` raw scores (no softmax is applied).

    Args:
        image_shape: Shape of a single input sample without the batch
            dimension, e.g. ``(1, 28, 28)``. Any number of dimensions is
            accepted; their product is the size of the first linear layer.
        hidden_layers_n: Width of each hidden layer, in order.
        classes_n: Number of output classes.

    NOTE(review): ``Accuracy()`` without arguments and the ``*_epoch_end``
    hooks match the library versions this notebook was run with; newer
    torchmetrics / Lightning releases reportedly changed both -- confirm
    against the pinned versions before upgrading.
    """

    def __init__(self, image_shape: tuple[int, ...], hidden_layers_n: tuple[int, ...], classes_n: int):
        super().__init__()
        # PL attributes: one metric object per stage so their states never mix.
        self.train_acc = Accuracy()
        self.valid_acc = Accuracy()
        self.test_acc = Accuracy()
        # Model
        input_size = 1
        for dim in image_shape:
            input_size *= dim
        layers = [nn.Flatten()]
        for n in hidden_layers_n:
            # Each hidden block stays wrapped in its own Sequential so that
            # parameter names in the state_dict are unchanged.
            layers.append(nn.Sequential(
                nn.Linear(input_size, n),
                nn.LeakyReLU()))
            input_size = n
        layers.append(nn.Linear(input_size, classes_n))
        self.model = nn.Sequential(*layers)

    def forward(self, x, *_, **__):
        """Return the raw class scores produced by the network for ``x``."""
        return self.model(x)

    def _shared_step(self, batch, metric, loss_name):
        """Run one batch: compute the loss, update ``metric`` and log the loss.

        Args:
            batch: An ``(inputs, targets)`` pair.
            metric: The accuracy metric of the current stage.
            loss_name: Key under which the loss is logged.

        Returns:
            The cross-entropy loss of the batch.
        """
        x, y = batch
        outputs = self(x)
        loss = F.cross_entropy(outputs, y)
        predictions = outputs.argmax(dim=1)
        metric.update(predictions, y)
        self.log(loss_name, loss, prog_bar=True)
        return loss

    def _log_epoch_accuracy(self, name, metric):
        """Log the accumulated accuracy under ``name`` and clear the metric.

        Only the computed value is logged (not the metric object), so the
        metric has to be reset by hand; otherwise its state would keep
        accumulating across epochs and across repeated evaluation runs.
        """
        value = metric.compute()
        metric.reset()
        self.log(name, value)

    def training_step(self, batch, *_, **__):
        """Compute and log the training loss for one batch."""
        return self._shared_step(batch, self.train_acc, "train_loss")

    def training_epoch_end(self, *_, **__):
        """Log the training accuracy of the epoch that just finished."""
        self._log_epoch_accuracy("train_acc", self.train_acc)

    def validation_step(self, batch, *_, **__):
        """Compute and log the validation loss for one batch."""
        return self._shared_step(batch, self.valid_acc, "validation_loss")

    def validation_epoch_end(self, outs, *_, **__):
        """Log the validation accuracy of the validation run that just finished."""
        # Key was previously misspelled as "valididation_acc".
        self._log_epoch_accuracy("validation_acc", self.valid_acc)

    def test_step(self, batch, *_, **__):
        """Compute and log the test loss for one batch."""
        return self._shared_step(batch, self.test_acc, "test_loss")

    def test_epoch_end(self, *_, **__):
        """Log the test accuracy of the test run that just finished."""
        self._log_epoch_accuracy("test_acc", self.test_acc)

    def configure_optimizers(self):
        """Return a single AdamW optimizer (lr=0.01) and no LR schedulers."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=0.01)
        return [optimizer], []

In [16]:
class MNISTDataModule(pl.LightningDataModule):
    """Lightning data module serving MNIST train / validation / test loaders.

    Args:
        path: Root directory handed to torchvision's ``MNIST``.
        download: Forwarded to ``MNIST(download=...)`` in ``prepare_data``.
        batch_size: Batch size of all three loaders.
            NOTE(review): 265 looks like a possible typo for 256; the default
            is kept unchanged to preserve behaviour -- confirm intent.
        num_workers: Worker processes used by each ``DataLoader``.
        val_size: Number of samples held out of the training split for
            validation; the remainder is used for training.
        split_seed: Seed of the generator used for the train/validation
            split, so that every ``setup`` call (and every instance with the
            same seed) produces the same partition.
    """

    def __init__(self, path='./mnist', download=True, batch_size=265,
                 num_workers=8, val_size=5_000, split_seed=42):
        super().__init__()
        self.path = path
        self.download = download
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_size = val_size
        self.split_seed = split_seed
        self.transformer = T.Compose([T.ToTensor()])

    def prepare_data(self):
        """Instantiate ``MNIST`` once so it can download the data if requested."""
        MNIST(root=self.path, download=self.download)

    def setup(self, *_, **__):
        """Create the train, validation and test datasets (the stage is ignored)."""
        train_val_data = MNIST(root=self.path, train=True, transform=self.transformer)
        train_size = len(train_val_data) - self.val_size
        # A fresh, seeded generator keeps the split identical between calls;
        # an unseeded split would move samples between train and validation.
        generator = torch.Generator().manual_seed(self.split_seed)
        self.train_data, self.val_data = random_split(
            train_val_data, [train_size, self.val_size], generator=generator)
        self.test_data = MNIST(root=self.path, train=False, transform=self.transformer)

    def _make_loader(self, dataset, shuffle=False):
        """Build a ``DataLoader`` over ``dataset`` with the module's settings."""
        return DataLoader(dataset, batch_size=self.batch_size,
                          shuffle=shuffle, num_workers=self.num_workers)

    def train_dataloader(self):
        """Return the training loader; samples are reshuffled every epoch."""
        return self._make_loader(self.train_data, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return self._make_loader(self.val_data)

    def test_dataloader(self):
        """Return the test loader (fixed order)."""
        return self._make_loader(self.test_data)

In [None]:
# Classifier for 1x28x28 inputs with two hidden layers (32 and 16 units) and 10 classes.
model = ImageVNN(image_shape=(1, 28, 28), hidden_layers_n=(32, 16), classes_n=10)
# Train for at most 10 epochs; every other Trainer option keeps its default.
trainer = pl.Trainer(max_epochs=10)
# download=False is forwarded to torchvision's MNIST, so the files are
# presumably expected to already exist under the default path -- verify.
data_model = MNISTDataModule(download=False)

In [None]:
# Fit the model on the data module's train/validation loaders
# (the trailing ';' suppresses the cell's return-value output).
trainer.fit(model, datamodule=data_model);

In [17]:
# Evaluate the trained model on the test loader of a freshly built data module.
trainer.test(model, datamodule=MNISTDataModule(download=False))

Testing: 0it [00:00, ?it/s]

[{'test_loss': 0.13443806767463684, 'test_acc': 0.9670000076293945}]