In [1]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torchmetrics import Accuracy

  from .autonotebook import tqdm as notebook_tqdm


In [27]:
class MultiLayer( pl.LightningModule ):
    
    def __init__(self, image_shape = (1,28,28), hidden_units = (32,16) ):
        super().__init__()

        # Accuracies.
        self.train_acc = Accuracy(task="multiclass", num_classes=10)
        self.valid_acc = Accuracy(task="multiclass", num_classes=10)
        self.test_acc = Accuracy(task="multiclass", num_classes=10)

        self.training_step_outputs = []

        # Model.
        input_size = image_shape[0] * image_shape[1] * image_shape[2]
        all_layers = [nn.Flatten()]

        for unit in hidden_units:
            all_layers.append(nn.Linear(input_size,unit))
            all_layers.append(nn.ReLU())
            input_size = unit

        all_layers.append( nn.Linear(hidden_units[-1], 10 ) )
        self.Model = nn.Sequential( *all_layers )
        
    
    def forward(self, x):
        return( self.Model(x) )

    def training_step(self, batch, batch_idx):
        x,y = batch
        logits = self(x)                          # Logits are the outputs before activation.
        loss = nn.functional.cross_entropy( logits, y )
        pred = torch.argmax( logits, dim=1 )
        self.train_acc.update( pred, y )
        self.log("Train_loss", loss, prog_bar = True )
        self.training_step_outputs.append(loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x,y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy( logits, y )
        pred = torch.argmax( logits, dim =1 )
        self.valid_acc.update( pred, y )
        self.log( "Eval_loss", loss, prog_bar=True )
        self.log( "Eval_acc", self.valid_acc.compute(), prog_bar = True ) 
        return loss

    def test_step(self, batch, batch_idx):
        x,y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy( logits, y )
        pred = torch.argmax( logits, dim = 1)
        self.test_acc.update( pred, y )
        self.log("Test_loss", loss, prog_bar = True )
        self.log("Test_acc", self.test_acc.compute() )
        
    def on_train_epoch_end(self):
        epoch_average = torch.stack(self.training_step_outputs).mean()
        self.log("training_epoch_average", epoch_average)
        self.training_step_outputs.clear()  # free memory


    def configure_optimizers(self):
        optimizer = torch.optim.Adam( lr=0.01, params=self.Model.parameters() )
        return optimizer

In [8]:
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torch.utils.data import random_split

In [39]:
class MnistDataModule( pl.LightningDataModule ):

    def __init__(self, data_path = "./" ):
        super().__init__()
        self.data_path = data_path
        self.transform = transforms.Compose( [transforms.ToTensor()] )
    
    def prepare_data(self):
        MNIST( root=self.data_path, download=True )

    def setup(self, stage=None):
        mnist_all = MNIST( root = self.data_path, download = False, train= True, transform=self.transform ) 
        self.train, self.val = random_split( mnist_all, [55000,5000], generator=torch.Generator().manual_seed(0) )
        self.test = MNIST( train = False, root = self.data_path, transform=self.transform, download=False )
        
    def train_dataloader(self):
        return DataLoader( self.train, batch_size=64, num_workers=4 )
                         
    def test_dataloader(self):
        return DataLoader( self.test, batch_size=64, num_workers=4 )
                        
    def val_dataloader(self):
        return DataLoader( self.val, batch_size=64, num_workers=4 )

In [40]:
torch.manual_seed(0)
mnist = MnistDataModule()

In [41]:
clf = MultiLayer()

In [42]:
trainer = pl.Trainer( max_epochs = 10 )

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
trainer.fit( datamodule = mnist, model = clf )

In [45]:
trainer.test( datamodule = mnist, model = clf )

Z:\python\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:419: Consider setting `persistent_workers=True` in 'test_dataloader' to speed up the dataloader worker initialization.


Testing DataLoader 0: 100%|█████████████████████████████████████████████████████████| 157/157 [00:00<00:00, 249.60it/s]


[{'Test_loss': 0.1755753457546234, 'Test_acc': 0.9549301862716675}]