In [14]:
import os
from torch import optim, nn
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
import torch.nn.functional as F
import pytorch_lightning as pl

# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

# define the LightningModule
class LitAutoEncoder(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # train_step defines the train loop.
        # it is independent of forward
        x, y = batch
        x_hat, loss = self.forward(x)
        self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train/acc", 0.5, on_step=True, on_epoch=True, prog_bar=True)
        return {"loss": loss, "acc": 0.5, "loss2": loss}
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        x_hat, loss = self.forward(x)
        self.log("val/loss", loss)
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        x_hat, loss = self.forward(x)
        # self.log("test/loss", loss)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        return x_hat, loss
    
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
    
    ## HOOKS
    def on_train_batch_start(self, batch, batch_idx):
        if batch_idx == 0:
            print("on_train_batch_start")
            print(type(batch))
            print(batch[0].shape)

    def on_train_batch_end(self, out, batch, batch_idx):
        if batch_idx == 0:
            print("on_train_batch_end")
            print(out)

    def on_train_epoch_end(self, outputs):
        print("on_train_epoch_end")
        print(outputs)


        

In [11]:
# setup data
dataset = MNIST("~/data/", download=True, transform=ToTensor())
train_loader = DataLoader(dataset, num_workers=8, batch_size=128, shuffle=True)
test_loader = DataLoader(dataset, num_workers=8, batch_size=128, shuffle=False)

In [12]:
# DRY RUN
# logger = pl.loggers.TensorBoardLogger("tb_logs", name="autoencoder")
# trainer = pl.Trainer(max_epochs=2, 
#                         accelerator="cpu", devices=1,
#                         logger=logger, fast_dev_run=True)
# trainer.fit(model=autoencoder, 
#     train_dataloaders=train_loader,
    # val_dataloaders=test_loader)

In [15]:
# CREATE MODEL

autoencoder = LitAutoEncoder(encoder, decoder)

# TRAIN FROM SCRATCH

trainer = pl.Trainer(max_epochs=1, 
                        accelerator="mps", devices=1,
                        logger=True, fast_dev_run=False)
trainer.fit(model=autoencoder, 
    train_dataloaders=train_loader,
    val_dataloaders=test_loader)

log_dir = trainer.log_dir

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name    | Type       | Params | Mode 
-----------------------------------------------
0 | encoder | Sequential | 50.4 K | train
1 | decoder | Sequential | 51.2 K | train
-----------------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)
8         Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/opt/homebrew/Caskroom/miniconda/base/envs/deepl/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:420: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.
/opt/homebrew/Caskroom/miniconda/base/envs/deepl/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:420: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Training: |          | 0/? [00:00<?, ?it/s]

on_train_batch_start
<class 'list'>
torch.Size([128, 1, 28, 28])
on_train_batch_end
{'loss': tensor(0.1338, device='mps:0'), 'acc': 0.5, 'loss2': tensor(0.1338, device='mps:0', grad_fn=<MseLossBackward0>)}


Validation: |          | 0/? [00:00<?, ?it/s]

TypeError: LitAutoEncoder.on_train_epoch_end() missing 1 required positional argument: 'outputs'

In [11]:
# RESUME CHEKPOINT
trainer = pl.Trainer(max_epochs=5, 
                        accelerator="cpu", devices=1,
                        logger=True, fast_dev_run=False)
trainer.fit(model=autoencoder, 
    train_dataloaders=train_loader,
    val_dataloaders=test_loader,
    # ckpt_path= log_dir + "/checkpoints/epoch=2-step=1407.ckpt"
    )

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Restoring states from the checkpoint path at /Users/carlo/.julia/dev/Tsunami/examples/pytorch-lighting-examples/lightning_logs/version_3/checkpoints/epoch=2-step=1407.ckpt

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 50.4 K
1 | decoder | Sequential | 51.2 K
---------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)
Restored all states from the checkpoint file at /Users/carlo/.julia/dev/Tsunami/examples/pytorch-lighting-examples/lightning_logs/version_3/checkpoints/epoch=2-step=1407.ckpt


Sanity Checking: 0it [00:00, ?it/s]

Training: 469it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


In [29]:
test_results = trainer.test(model=autoencoder, dataloaders=test_loader, verbose=True)

Testing: 0it [00:00, ?it/s]

In [30]:
test_results

[{'test/loss': 0.04163273051381111}]