# Level 5: Debug, visualize and find performance bottlenecks

## Debug your model (basic)

In [10]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F

from lightning import LightningModule
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.callbacks import ModelSummary
from lightning.pytorch.loggers import CSVLogger

from torchvision import datasets
import torchvision.transforms as transforms

import os

In [11]:
# Load data sets
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="../data/MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="../data/MNIST", download=True, train=False, transform=transform)

# use 20% of training data for validation
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size

# split the train set into two
seed = torch.Generator().manual_seed(42)
train_set, valid_set = random_split(train_set, [train_set_size, valid_set_size], generator=seed)

train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=16, persistent_workers=True, pin_memory=True)
valid_loader = DataLoader(valid_set, batch_size=128, num_workers=16, persistent_workers=True, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=1024, num_workers=16, persistent_workers=True, pin_memory=True)

In [12]:
class Encoder(nn.Module):
    def __init__(self, in_dim=28*28, hidden_nodes_1=64, hidden_nodes_2=64, out_dim=4):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(in_dim, hidden_nodes_1),
            nn.ReLU(),
            nn.Linear(hidden_nodes_1, hidden_nodes_2),
            nn.ReLU(),
            nn.Linear(hidden_nodes_2, out_dim)
        )

    def forward(self, x):
        return self.ff(x)

class Decoder(nn.Module):
    def __init__(self, in_dim=4, hidden_nodes_1=64, hidden_nodes_2=64, out_dim=28*28):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(in_dim, hidden_nodes_1),
            nn.ReLU(),
            nn.Linear(hidden_nodes_1, hidden_nodes_2),
            nn.ReLU(),
            nn.Linear(hidden_nodes_2, out_dim)
        )

    def forward(self, x):
        return self.ff(x)

class LitAutoEncoder(LightningModule):
    def __init__(self, encoder, decoder, lr=1e-5, example_input_array=None):
        super().__init__()
        self.example_input_array = example_input_array
        self.encoder = encoder
        self.decoder = decoder
        self.lr = lr
        self.save_hyperparameters(ignore=["encoder", "decoder"])

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        loss = self._get_loss(batch)
        self.log("train/loss", loss, on_step=True, on_epoch=True)
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        # this is the validation loop
        loss = self._get_loss(batch)
        self.log("val_loss", loss, prog_bar=True)

    def test_step(self, batch, batch_idx):
        # this is the test loop
        loss = self._get_loss(batch)
        self.log("test_loss", loss)

    def _get_loss(self, batch):
        x, _ = batch
        x = x.view(x.size(0), -1)
        x_hat = self.forward(x)
        loss = F.mse_loss(x_hat, x)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def forward(self, x):
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat

In [13]:
model = LitAutoEncoder(
    encoder=Encoder(
        in_dim=28*28,
        hidden_nodes_1=512,
        hidden_nodes_2=256,
        out_dim=100
    ),
    decoder=Decoder(
        in_dim=100,
        hidden_nodes_1=128,
        hidden_nodes_2=256,
        out_dim=28*28
    )
)

In [14]:
logger = CSVLogger(
    save_dir='logs',
    name='debugging_tests',
    version=None,
    prefix='test_'
)

checkpoint_callback = ModelCheckpoint(
    dirpath=os.path.join(logger.log_dir, "checkpoints"),
    filename="autoencoder_best-{epoch:02d}-{val_loss:.3f}",
    monitor="val_loss",    
    mode="min",
    save_top_k=3,     # keep ONLY the best
    save_last=True    # ALSO save last.ckpt
)

early_stop_callback = EarlyStopping(
    monitor="val_loss",
    patience=3,
    verbose=False,
    mode="min"
)

### Run all your model code once quickly
see [here](https://lightning.ai/docs/pytorch/stable/debug/debugging_basic.html#run-all-your-model-code-once-quickly:~:text=x**2%20line.-,Run%20all%20your%20model%20code%20once%20quickly,-If%20you%E2%80%99ve%20ever)

In [15]:
trainer_test = Trainer(
    fast_dev_run=10,
    accelerator="gpu",
    devices=1,
    max_epochs=10
)

ðŸ’¡ Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
Running in `fast_dev_run` mode: will run the requested loop using 10 batch(es). Logging and checkpointing is suppressed.


In [16]:
trainer_test.fit(model, train_loader, valid_loader)
trainer_test.test(model, test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name    | Type    | Params | Mode  | FLOPs
----------------------------------------------------
0 | encoder | Encoder | 558 K  | train | 0    
1 | decoder | Decoder | 247 K  | train | 0    
----------------------------------------------------
806 K     Trainable params
0         Non-trainable params
806 K     Total params
3.226     Total estimated model params size (MB)
14        Modules in train mode
0         Modules in eval mode
0         Total Flops


Epoch 0: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 10/10 [00:00<00:00, 18.25it/s, val_loss=0.113, train_loss=0.110]

`Trainer.fit` stopped: `max_steps=10` reached.


Epoch 0: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 10/10 [00:00<00:00, 18.17it/s, val_loss=0.113, train_loss=0.110]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


Testing DataLoader 0: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 10/10 [00:00<00:00, 130.08it/s]
â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€
       Test metric             DataLoader 0
â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€
        test_loss           0.1138618141412735
â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€

[{'test_loss': 0.1138618141412735}]

In [17]:
trainer = Trainer(
    logger=logger,
    callbacks=[checkpoint_callback, early_stop_callback],
    accelerator="gpu",
    devices=1,
    max_epochs=10
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


In [18]:
trainer.fit(model, train_loader, valid_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name    | Type    | Params | Mode  | FLOPs
----------------------------------------------------
0 | encoder | Encoder | 558 K  | train | 0    
1 | decoder | Decoder | 247 K  | train | 0    
----------------------------------------------------
806 K     Trainable params
0         Non-trainable params
806 K     Total params
3.226     Total estimated model params size (MB)
14        Modules in train mode
0         Modules in eval mode
0         Total Flops


Epoch 0:  88%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Š | 329/375 [00:05<00:00, 65.30it/s, v_num=2]         


Detected KeyboardInterrupt, attempting graceful shutdown ...


SystemExit: 1

In [None]:
# Access human-readable message
if early_stop_callback.stopping_reason_message:
    print(f"Details: {early_stop_callback.stopping_reason_message}")
else:
    print("Early stopping was not triggered.")

Early stopping was not triggered.


### Shorten the epoch length

In [None]:
# use only 10% of training data and 1% of val data
trainer_1 = Trainer(
    limit_train_batches=0.1,
    limit_val_batches=0.1,
    logger=logger,
    callbacks=[checkpoint_callback, early_stop_callback],
    accelerator="gpu",
    devices=1,
    max_epochs=10
)

# use 10 batches of train and 5 batches of val
trainer_2 = Trainer(
    limit_train_batches=10,
    limit_val_batches=5,
    logger=logger,
    callbacks=[checkpoint_callback, early_stop_callback],
    accelerator="gpu",
    devices=1,
    max_epochs=10
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


In [None]:
trainer_1.fit(model, train_loader, valid_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name    | Type    | Params | Mode  | FLOPs
----------------------------------------------------
0 | encoder | Encoder | 558 K  | train | 0    
1 | decoder | Decoder | 247 K  | train | 0    
----------------------------------------------------
806 K     Trainable params
0         Non-trainable params
806 K     Total params
3.226     Total estimated model params size (MB)
14        Modules in train mode
0         Modules in eval mode
0         Total Flops


                                                                            

/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/pytorch/loops/fit_loop.py:317: The number of training batches (37) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 9: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 37/37 [00:00<00:00, 88.23it/s, v_num=0, val_loss=0.0318, train_loss=0.0317] 

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 37/37 [00:00<00:00, 66.12it/s, v_num=0, val_loss=0.0318, train_loss=0.0317]


In [None]:
trainer_2.fit(model, train_loader, valid_loader)

/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib ...
/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:881: Checkpoint directory /storage3/DSIP/rriva/tutorials/pl_tutorial/basic/logs/debugging_tests/version_0/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name    | Type    | Params | Mode  | FLOPs
----------------------------------------------------
0 | encoder | Encoder | 558 K  | train | 0    
1 | decoder | Decoder | 247 K  | train | 0    
----------------------------------------------------
806 K     Trainable params
0         Non-trainable params
806

                                                                            

/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/pytorch/loops/fit_loop.py:317: The number of training batches (10) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 2: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 10/10 [00:00<00:00, 30.44it/s, v_num=0, val_loss=0.0319, train_loss=0.0313]


### Print LightningModule weights summary
Whenever the `.fit()` function gets called, the Trainer will print the weights summary for the LightningModule.

To add the child modules to the summary add a `ModelSummary`

In [None]:
from lightning.pytorch.callbacks import ModelSummary

trainer = Trainer(
    callbacks=[ModelSummary(max_depth=-1)],
    accelerator="gpu",
    devices=1,
    max_epochs=10
)

/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib ...
ðŸ’¡ Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


In [None]:
trainer.fit(model, train_loader, valid_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (896x28 and 784x512)

Epoch 0:   0%|          | 0/10 [15:11<?, ?it/s]
Epoch 0:   0%|          | 0/10 [02:43<?, ?it/s]


To print the model summary if `.fit()` is not called:

In [None]:
model = LitAutoEncoder(Encoder(), Decoder())
summary = ModelSummary(model, max_depth=-1)
print(summary)

   | Name         | Type       | Params | Mode  | FLOPs
-------------------------------------------------------------
0  | encoder      | Encoder    | 54.7 K | train | 0    
1  | encoder.ff   | Sequential | 54.7 K | train | 0    
2  | encoder.ff.0 | Linear     | 50.2 K | train | 0    
3  | encoder.ff.1 | ReLU       | 0      | train | 0    
4  | encoder.ff.2 | Linear     | 4.2 K  | train | 0    
5  | encoder.ff.3 | ReLU       | 0      | train | 0    
6  | encoder.ff.4 | Linear     | 260    | train | 0    
7  | decoder      | Decoder    | 55.4 K | train | 0    
8  | decoder.ff   | Sequential | 55.4 K | train | 0    
9  | decoder.ff.0 | Linear     | 320    | train | 0    
10 | decoder.ff.1 | ReLU       | 0      | train | 0    
11 | decoder.ff.2 | Linear     | 4.2 K  | train | 0    
12 | decoder.ff.3 | ReLU       | 0      | train | 0    
13 | decoder.ff.4 | Linear     | 51.0 K | train | 0    
-------------------------------------------------------------
110 K     Trainable params
0        

To turn off the autosummary use:

In [None]:
trainer = Trainer(
    enable_model_summary=False,
    accelerator="gpu",
    devices=1,
    max_epochs=10
)

/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib ...
ðŸ’¡ Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


In [None]:
trainer.fit(model, train_loader, valid_loader)

/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


Epoch 9: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 375/375 [00:03<00:00, 108.63it/s, v_num=1, val_loss=0.0624, train_loss=0.0628]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 375/375 [00:03<00:00, 106.10it/s, v_num=1, val_loss=0.0624, train_loss=0.0628]


### Print input output layer dimensions
Another debugging tool is to display the intermediate input- and output sizes of all your layers by setting the `example_input_array` attribute in your LightningModule.

In [None]:
class LitModel(LightningModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.example_input_array = torch.Tensor(32, 1, 28, 28)

In [None]:
model = LitAutoEncoder(
    encoder=Encoder(
        in_dim=28*28,
        hidden_nodes_1=512,
        hidden_nodes_2=256,
        out_dim=100
    ),
    decoder=Decoder(
        in_dim=100,
        hidden_nodes_1=128,
        hidden_nodes_2=256,
        out_dim=28*28
    ),
    example_input_array=torch.Tensor(32, 1, 28*28)
)

In [None]:
trainer = Trainer(
    callbacks=[ModelSummary(max_depth=-1)],
    accelerator="gpu",
    devices=1,
    max_epochs=10
)

/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib ...
ðŸ’¡ Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


In [None]:
trainer.fit(model, train_loader, valid_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

   | Name         | Type       | Params | Mode  | FLOPs  | In sizes     | Out sizes   
--------------------------------------------------------------------------------------------
0  | encoder      | Encoder    | 558 K  | train | 35.7 M | [32, 1, 784] | [32, 1, 100]
1  | encoder.ff   | Sequential | 558 K  | train | 35.7 M | [32, 1, 784] | [32, 1, 100]
2  | encoder.ff.0 | Linear     | 401 K  | train | 25.7 M | [32, 1, 784] | [32, 1, 512]
3  | encoder.ff.1 | ReLU       | 0      | train | 0      | [32, 1, 512] | [32, 1, 512]
4  | encoder.ff.2 | Linear     | 131 K  | train | 8.4 M  | [32, 1, 512] | [32, 1, 256]
5  | encoder.ff.3 | ReLU       | 0      | train | 0      | [32, 1, 256] | [32, 1, 256]
6  | encoder.ff.4 | Linear     | 25.7 K | train | 1.6 M  | [32, 1, 256] | [32, 1, 100]
7  | decoder      | Decoder    | 247 K  | train | 15.8 M | [32, 1, 100] | [32, 1, 784]
8  | decoder.ff   | Sequential | 247 K  | train | 15.8 M | [32, 1, 

Epoch 9: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 375/375 [00:03<00:00, 116.06it/s, v_num=6, val_loss=0.0313, train_loss=0.0321]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 375/375 [00:03<00:00, 113.48it/s, v_num=6, val_loss=0.0313, train_loss=0.0321]


## Find bottlenecks in your code (basic)
Profiling helps you find bottlenecks in your code by capturing analytics such as how long a function takes or how much memory is used.

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F

from lightning import LightningModule
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.callbacks import ModelSummary
from lightning.pytorch.loggers import CSVLogger

from torchvision import datasets
import torchvision.transforms as transforms

import os

In [None]:
# Load data sets
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="../data/MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="../data/MNIST", download=True, train=False, transform=transform)

# use 20% of training data for validation
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size

# split the train set into two
seed = torch.Generator().manual_seed(42)
train_set, valid_set = random_split(train_set, [train_set_size, valid_set_size], generator=seed)

train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=16, persistent_workers=True, pin_memory=True)
valid_loader = DataLoader(valid_set, batch_size=128, num_workers=16, persistent_workers=True, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=1024, num_workers=16, persistent_workers=True, pin_memory=True)

In [None]:
class Encoder(nn.Module):
    def __init__(self, in_dim=28*28, hidden_nodes_1=64, hidden_nodes_2=64, out_dim=4):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(in_dim, hidden_nodes_1),
            nn.ReLU(),
            nn.Linear(hidden_nodes_1, hidden_nodes_2),
            nn.ReLU(),
            nn.Linear(hidden_nodes_2, out_dim)
        )

    def forward(self, x):
        return self.ff(x)

class Decoder(nn.Module):
    def __init__(self, in_dim=4, hidden_nodes_1=64, hidden_nodes_2=64, out_dim=28*28):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(in_dim, hidden_nodes_1),
            nn.ReLU(),
            nn.Linear(hidden_nodes_1, hidden_nodes_2),
            nn.ReLU(),
            nn.Linear(hidden_nodes_2, out_dim)
        )

    def forward(self, x):
        return self.ff(x)

class LitAutoEncoder(LightningModule):
    def __init__(self, encoder, decoder, lr=1e-5, example_input_array=None):
        super().__init__()
        self.example_input_array = example_input_array
        self.encoder = encoder
        self.decoder = decoder
        self.lr = lr
        self.save_hyperparameters(ignore=["encoder", "decoder"])

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        loss = self._get_loss(batch)
        self.log("train/loss", loss, on_step=True, on_epoch=True)
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        # this is the validation loop
        loss = self._get_loss(batch)
        self.log("val_loss", loss, prog_bar=True)

    def test_step(self, batch, batch_idx):
        # this is the test loop
        loss = self._get_loss(batch)
        self.log("test_loss", loss)

    def _get_loss(self, batch):
        x, _ = batch
        x = x.view(x.size(0), -1)
        x_hat = self.forward(x)
        loss = F.mse_loss(x_hat, x)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def forward(self, x):
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat

In [None]:
model = LitAutoEncoder(
    encoder=Encoder(
        in_dim=28*28,
        hidden_nodes_1=512,
        hidden_nodes_2=256,
        out_dim=100
    ),
    decoder=Decoder(
        in_dim=100,
        hidden_nodes_1=128,
        hidden_nodes_2=256,
        out_dim=28*28
    ),
    example_input_array=torch.Tensor(32, 1, 28*28)
)

In [None]:
logger = CSVLogger(
    save_dir='logs',
    name='find_bottlenecks',
    version=None
)

checkpoint_callback = ModelCheckpoint(
    dirpath=os.path.join(logger.log_dir, "checkpoints"),
    filename="autoencoder_best-{epoch:02d}-{val_loss:.3f}",
    monitor="val_loss",    
    mode="min",
    save_top_k=3,     # keep ONLY the best
    save_last=True    # ALSO save last.ckpt
)

early_stop_callback = EarlyStopping(
    monitor="val_loss",
    patience=3,
    verbose=False,
    mode="min"
)

model_summary = ModelSummary(-1)
print(model_summary)

<lightning.pytorch.callbacks.model_summary.ModelSummary object at 0x7fa33e1cf380>


In [None]:

from lightning.pytorch.callbacks import ModelSummary as MS_Callback
from lightning.pytorch.utilities.model_summary import ModelSummary

In [None]:
model_summary_callback = MS_Callback(max_depth=-1)
model_summary = ModelSummary(model,-1)
print(model_summary)
print(model_summary_callback)

   | Name         | Type       | Params | Mode  | FLOPs  | In sizes     | Out sizes   
--------------------------------------------------------------------------------------------
0  | encoder      | Encoder    | 558 K  | train | 35.7 M | [32, 1, 784] | [32, 1, 100]
1  | encoder.ff   | Sequential | 558 K  | train | 35.7 M | [32, 1, 784] | [32, 1, 100]
2  | encoder.ff.0 | Linear     | 401 K  | train | 25.7 M | [32, 1, 784] | [32, 1, 512]
3  | encoder.ff.1 | ReLU       | 0      | train | 0      | [32, 1, 512] | [32, 1, 512]
4  | encoder.ff.2 | Linear     | 131 K  | train | 8.4 M  | [32, 1, 512] | [32, 1, 256]
5  | encoder.ff.3 | ReLU       | 0      | train | 0      | [32, 1, 256] | [32, 1, 256]
6  | encoder.ff.4 | Linear     | 25.7 K | train | 1.6 M  | [32, 1, 256] | [32, 1, 100]
7  | decoder      | Decoder    | 247 K  | train | 15.8 M | [32, 1, 100] | [32, 1, 784]
8  | decoder.ff   | Sequential | 247 K  | train | 15.8 M | [32, 1, 100] | [32, 1, 784]
9  | decoder.ff.0 | Linear     | 12.9

### Find training loop bottlenecks
The most basic profile measures all the key methods across Callbacks, DataModules and the LightningModule in the training loop.

```
trainer = Trainer(profiler="simple")
```

In [None]:
trainer = Trainer(
    profiler='simple',
    logger=logger,
    callbacks=[checkpoint_callback, early_stop_callback, model_summary_callback],
    accelerator="gpu",
    devices=1,
    max_epochs=100
)

/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib ...
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


In [None]:
trainer.fit(model, train_loader, valid_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

   | Name         | Type       | Params | Mode  | FLOPs  | In sizes     | Out sizes   
--------------------------------------------------------------------------------------------
0  | encoder      | Encoder    | 558 K  | train | 35.7 M | [32, 1, 784] | [32, 1, 100]
1  | encoder.ff   | Sequential | 558 K  | train | 35.7 M | [32, 1, 784] | [32, 1, 100]
2  | encoder.ff.0 | Linear     | 401 K  | train | 25.7 M | [32, 1, 784] | [32, 1, 512]
3  | encoder.ff.1 | ReLU       | 0      | train | 0      | [32, 1, 512] | [32, 1, 512]
4  | encoder.ff.2 | Linear     | 131 K  | train | 8.4 M  | [32, 1, 512] | [32, 1, 256]
5  | encoder.ff.3 | ReLU       | 0      | train | 0      | [32, 1, 256] | [32, 1, 256]
6  | encoder.ff.4 | Linear     | 25.7 K | train | 1.6 M  | [32, 1, 256] | [32, 1, 100]
7  | decoder      | Decoder    | 247 K  | train | 15.8 M | [32, 1, 100] | [32, 1, 784]
8  | decoder.ff   | Sequential | 247 K  | train | 15.8 M | [32, 1, 

Epoch 99: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 375/375 [00:03<00:00, 108.79it/s, v_num=4, val_loss=0.0107, train_loss=0.0107]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 375/375 [00:03<00:00, 104.29it/s, v_num=4, val_loss=0.0107, train_loss=0.0107]


FIT Profiler Report

----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                               	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                                                                                	|  -       

### Profile the time within every function
`trainer = Trainer(profiler="advanced")`

In [None]:
trainer = Trainer(
    profiler='advanced',
    logger=logger,
    callbacks=[checkpoint_callback, early_stop_callback, model_summary_callback],
    accelerator="gpu",
    devices=1,
    max_epochs=10
)

Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


In [None]:
trainer.fit(model, train_loader, valid_loader)

/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:881: Checkpoint directory /storage3/DSIP/rriva/tutorials/pl_tutorial/basic/logs/find_bottlenecks/version_4/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

   | Name         | Type       | Params | Mode  | FLOPs  | In sizes     | Out sizes   
--------------------------------------------------------------------------------------------
0  | encoder      | Encoder    | 558 K  | train | 35.7 M | [32, 1, 784] | [32, 1, 100]
1  | encoder.ff   | Sequential | 558 K  | train | 35.7 M | [32, 1, 784] | [32, 1, 100]
2  | encoder.ff.0 | Linear     | 401 K  | train | 25.7 M | [32, 1, 784] | [32, 1, 512]
3  | encoder.ff.1 | ReLU       | 0      | train | 0      | [32, 1, 512] | [32, 1, 512]
4  | encoder.ff.2 | Linear     | 131 K  | train | 8.4 M  | [32, 1, 512] | [32, 1, 256]
5  | encoder.ff.3 | ReLU       | 0      | train | 0      |

                                                                            

Epoch 9: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 375/375 [00:06<00:00, 58.08it/s, v_num=4, val_loss=0.0102, train_loss=0.0102]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 375/375 [00:06<00:00, 56.73it/s, v_num=4, val_loss=0.0102, train_loss=0.0102]


FIT Profiler Report
Profile stats for: [LightningModule]LitAutoEncoder.configure_callbacks
         7 function calls in 0.000 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.000    0.000 contextlib.py:145(__exit__)
        1    0.000    0.000    0.000    0.000 {built-in method builtins.next}
        1    0.000    0.000    0.000    0.000 profiler.py:56(profile)
        1    0.000    0.000    0.000    0.000 advanced.py:81(stop)
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
        1    0.000    0.000    0.000    0.000 {method 'get' of 'dict' objects}
        1    0.000    0.000    0.000    0.000 module.py:970(configure_callbacks)



Profile stats for: [Callback]EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}.setup
         10 function calls in 0.000 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall file

If the profiler report becomes too long, you can stream the report to a file:

In [None]:
from lightning.pytorch.profilers import AdvancedProfiler

profiler = AdvancedProfiler(dirpath=".", filename="perf_logs")
trainer = Trainer(
    profiler=profiler,
    logger=logger,
    callbacks=[checkpoint_callback, early_stop_callback, model_summary_callback],
    accelerator="gpu",
    devices=1,
    max_epochs=10
)


/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib ...
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


In [None]:
trainer.fit(model, train_loader, valid_loader)

/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib ...
/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:881: Checkpoint directory /storage3/DSIP/rriva/tutorials/pl_tutorial/basic/logs/find_bottlenecks/version_4/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

   | Name         | Type       | Params | Mode  | FLOPs  | In sizes     | Out sizes   
--------------------------------------------------------------------------------------------
0  | encoder      | Encoder    | 558 K  | train | 35.7 M | [32, 1, 784] | [32, 1, 100]
1  | encoder.ff   | Sequential | 558 K  

Epoch 9: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 375/375 [00:08<00:00, 44.37it/s, v_num=4, val_loss=0.0098, train_loss=0.00981] 

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 375/375 [00:08<00:00, 43.48it/s, v_num=4, val_loss=0.0098, train_loss=0.00981]


### Measure accelerator usage
Another helpful technique to detect bottlenecks is to ensure that youâ€™re using the full capacity of your accelerator (GPU/TPU/HPU). This can be measured with the `DeviceStatsMonitor`:

In [None]:
from lightning.pytorch.callbacks import DeviceStatsMonitor
trainer = Trainer(
    profiler=profiler,
    logger=logger,
    callbacks=[checkpoint_callback, early_stop_callback, model_summary_callback, DeviceStatsMonitor()],
    accelerator="gpu",
    devices=1,
    max_epochs=10
)

Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


In [None]:
trainer.fit(model, train_loader, valid_loader)

/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib ...
/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:881: Checkpoint directory /storage3/DSIP/rriva/tutorials/pl_tutorial/basic/logs/find_bottlenecks/version_4/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

   | Name         | Type       | Params | Mode  | FLOPs  | In sizes     | Out sizes   
--------------------------------------------------------------------------------------------
0  | encoder      | Encoder    | 558 K  | train | 35.7 M | [32, 1, 784] | [32, 1, 100]
1  | encoder.ff   | Sequential | 558 K  

Epoch 9: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 375/375 [00:09<00:00, 41.28it/s, v_num=4, val_loss=0.00944, train_loss=0.00944]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 375/375 [00:09<00:00, 40.62it/s, v_num=4, val_loss=0.00944, train_loss=0.00944]


In [None]:
import pandas as pd
import os

LOG_DIR = "logs/find_bottlenecks"

# find latest version automatically
versions = sorted(
    [v for v in os.listdir(LOG_DIR) if v.startswith("version_")],
    key=lambda x: int(x.split("_")[1])
)
latest_version = versions[-1]

metrics_path = os.path.join(LOG_DIR, latest_version, "metrics.csv")
df = pd.read_csv(metrics_path)

print(f"Loaded metrics from: {metrics_path}")
print(df.columns)

Loaded metrics from: logs/find_bottlenecks/version_4/metrics.csv
Index(['step',
       'test_-DeviceStatsMonitor.on_train_batch_end/active.all.allocated',
       'test_-DeviceStatsMonitor.on_train_batch_end/active.all.current',
       'test_-DeviceStatsMonitor.on_train_batch_end/active.all.freed',
       'test_-DeviceStatsMonitor.on_train_batch_end/active.all.peak',
       'test_-DeviceStatsMonitor.on_train_batch_end/active.large_pool.allocated',
       'test_-DeviceStatsMonitor.on_train_batch_end/active.large_pool.current',
       'test_-DeviceStatsMonitor.on_train_batch_end/active.large_pool.freed',
       'test_-DeviceStatsMonitor.on_train_batch_end/active.large_pool.peak',
       'test_-DeviceStatsMonitor.on_train_batch_end/active.small_pool.allocated',
       ...
       'test_-DeviceStatsMonitor.on_validation_batch_start/segment.large_pool.peak',
       'test_-DeviceStatsMonitor.on_validation_batch_start/segment.small_pool.allocated',
       'test_-DeviceStatsMonitor.on_validation

In [None]:
import pandas as pd
import os

LOG_DIR = "logs/find_bottlenecks"
version = sorted(os.listdir(LOG_DIR))[-1]
metrics_path = os.path.join(LOG_DIR, version, "metrics.csv")

df = pd.read_csv(metrics_path)

gpu_cols = [
    c for c in df.columns
    if "gpu" in c.lower()
]

gpu_df = df[gpu_cols].dropna(how="all")

summary = {
    "gpu_utilization_mean (%)": gpu_df["device/gpu_utilization"].mean(),
    "gpu_utilization_max (%)": gpu_df["device/gpu_utilization"].max(),
    "gpu_mem_allocated_mean (MB)": gpu_df["device/gpu_memory_allocated"].mean(),
    "gpu_mem_allocated_max (MB)": gpu_df["device/gpu_memory_allocated"].max(),
}

print("\n=== GPU USAGE SUMMARY ===")
for k, v in summary.items():
    print(f"{k:40s}: {v:.2f}")

KeyError: 'device/gpu_utilization'

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import os

metrics_path = "logs/find_bottlenecks/version_4/metrics.csv"
df = pd.read_csv(metrics_path)

df = df.dropna(subset=["device/gpu_utilization"])

plt.figure(figsize=(10, 4))
plt.plot(df["device/gpu_utilization"])
plt.xlabel("Training step")
plt.ylabel("GPU utilization (%)")
plt.title("GPU Utilization over Training")
plt.grid(True)
plt.tight_layout()
plt.show()

KeyError: ['device/gpu_utilization']

In [None]:
import pandas as pd

df = pd.read_csv("logs/find_bottlenecks/version_4/metrics.csv")

for c in df.columns:
    if "gpu" in c.lower() or "cuda" in c.lower() or 'device' in c.lower():
        print(c)

test_-DeviceStatsMonitor.on_train_batch_end/active.all.allocated
test_-DeviceStatsMonitor.on_train_batch_end/active.all.current
test_-DeviceStatsMonitor.on_train_batch_end/active.all.freed
test_-DeviceStatsMonitor.on_train_batch_end/active.all.peak
test_-DeviceStatsMonitor.on_train_batch_end/active.large_pool.allocated
test_-DeviceStatsMonitor.on_train_batch_end/active.large_pool.current
test_-DeviceStatsMonitor.on_train_batch_end/active.large_pool.freed
test_-DeviceStatsMonitor.on_train_batch_end/active.large_pool.peak
test_-DeviceStatsMonitor.on_train_batch_end/active.small_pool.allocated
test_-DeviceStatsMonitor.on_train_batch_end/active.small_pool.current
test_-DeviceStatsMonitor.on_train_batch_end/active.small_pool.freed
test_-DeviceStatsMonitor.on_train_batch_end/active.small_pool.peak
test_-DeviceStatsMonitor.on_train_batch_end/active_bytes.all.allocated
test_-DeviceStatsMonitor.on_train_batch_end/active_bytes.all.current
test_-DeviceStatsMonitor.on_train_batch_end/active_bytes.

## Track and Visualize Experiments (basic)

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F

from lightning import LightningModule
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.callbacks import ModelSummary
from lightning.pytorch.loggers import CSVLogger

from torchvision import datasets
import torchvision.transforms as transforms

import os

In [2]:
# Load data sets
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="../data/MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="../data/MNIST", download=True, train=False, transform=transform)

# use 20% of training data for validation
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size

# split the train set into two
seed = torch.Generator().manual_seed(42)
train_set, valid_set = random_split(train_set, [train_set_size, valid_set_size], generator=seed)

train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=16, persistent_workers=True, pin_memory=True)
valid_loader = DataLoader(valid_set, batch_size=128, num_workers=16, persistent_workers=True, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=1024, num_workers=16, persistent_workers=True, pin_memory=True)

### Track metrics
Metric visualization is the most basic but powerful way of understanding how your model is doing throughout the model development process.
To track a metric, simply use the `self.log` method available inside the LightningModule

In [3]:
class Encoder(nn.Module):
    def __init__(self, in_dim=28*28, hidden_nodes_1=64, hidden_nodes_2=64, out_dim=4):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(in_dim, hidden_nodes_1),
            nn.ReLU(),
            nn.Linear(hidden_nodes_1, hidden_nodes_2),
            nn.ReLU(),
            nn.Linear(hidden_nodes_2, out_dim)
        )

    def forward(self, x):
        return self.ff(x)

class Decoder(nn.Module):
    def __init__(self, in_dim=4, hidden_nodes_1=64, hidden_nodes_2=64, out_dim=28*28):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(in_dim, hidden_nodes_1),
            nn.ReLU(),
            nn.Linear(hidden_nodes_1, hidden_nodes_2),
            nn.ReLU(),
            nn.Linear(hidden_nodes_2, out_dim)
        )

    def forward(self, x):
        return self.ff(x)

class LitAutoEncoder(LightningModule):
    def __init__(self, encoder, decoder, lr=1e-5, example_input_array=None):
        super().__init__()
        self.example_input_array = example_input_array
        self.encoder = encoder
        self.decoder = decoder
        self.lr = lr
        self.save_hyperparameters(ignore=["encoder", "decoder"])

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        loss = self._get_loss(batch)
        self.log("train/loss", loss, on_step=True, on_epoch=True)
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        # this is the validation loop
        loss = self._get_loss(batch)
        self.log("val_loss", loss, prog_bar=True)

    def test_step(self, batch, batch_idx):
        # this is the test loop
        loss = self._get_loss(batch)
        self.log("test_loss", loss, prog_bar=True)

    def _get_loss(self, batch):
        x, _ = batch
        x = x.view(x.size(0), -1)
        x_hat = self.forward(x)
        loss = F.mse_loss(x_hat, x)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def forward(self, x):
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat

In [4]:
model = LitAutoEncoder(
    encoder=Encoder(
        in_dim=28*28,
        hidden_nodes_1=512,
        hidden_nodes_2=256,
        out_dim=100
    ),
    decoder=Decoder(
        in_dim=100,
        hidden_nodes_1=128,
        hidden_nodes_2=256,
        out_dim=28*28
    ),
    example_input_array=torch.Tensor(32, 1, 28*28)
)

In [5]:
logger = CSVLogger(
    save_dir='logs',
    name='find_bottlenecks',
    version=None
)

checkpoint_callback = ModelCheckpoint(
    dirpath=os.path.join(logger.log_dir, "checkpoints"),
    filename="autoencoder_best-{epoch:02d}-{val_loss:.3f}",
    monitor="val_loss",    
    mode="min",
    save_top_k=3,     # keep ONLY the best
    save_last=True    # ALSO save last.ckpt
)

early_stop_callback = EarlyStopping(
    monitor="val_loss",
    patience=3,
    verbose=False,
    mode="min"
)

model_summary_callback = ModelSummary(-1)
print(model_summary_callback)

<lightning.pytorch.callbacks.model_summary.ModelSummary object at 0x7fc11f815400>


In [6]:
from lightning.pytorch.loggers import TensorBoardLogger

logger = TensorBoardLogger("tb_logs", name="my_model")


In [7]:
trainer = Trainer(
    #profiler='simple',
    logger=logger,
    callbacks=[checkpoint_callback, early_stop_callback, model_summary_callback],
    accelerator="gpu",
    devices=1,
    max_epochs=100
)

/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib ...
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


In [8]:
trainer.fit(model, train_loader, valid_loader)
trainer.test(model, test_loader)

You are using a CUDA device ('NVIDIA A40') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]



   | Name         | Type       | Params | Mode  | FLOPs  | In sizes     | Out sizes   
--------------------------------------------------------------------------------------------
0  | encoder      | Encoder    | 558 K  | train | 35.7 M | [32, 1, 784] | [32, 1, 100]
1  | encoder.ff   | Sequential | 558 K  | train | 35.7 M | [32, 1, 784] | [32, 1, 100]
2  | encoder.ff.0 | Linear     | 401 K  | train | 25.7 M | [32, 1, 784] | [32, 1, 512]
3  | encoder.ff.1 | ReLU       | 0      | train | 0      | [32, 1, 512] | [32, 1, 512]
4  | encoder.ff.2 | Linear     | 131 K  | train | 8.4 M  | [32, 1, 512] | [32, 1, 256]
5  | encoder.ff.3 | ReLU       | 0      | train | 0      | [32, 1, 256] | [32, 1, 256]
6  | encoder.ff.4 | Linear     | 25.7 K | train | 1.6 M  | [32, 1, 256] | [32, 1, 100]
7  | decoder      | Decoder    | 247 K  | train | 15.8 M | [32, 1, 100] | [32, 1, 784]
8  | decoder.ff   | Sequential | 247 K  | train | 15.8 M | [32, 1, 100] | [32, 1, 784]
9  | decoder.ff.0 | Linear     | 12.

Epoch 99: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 375/375 [00:05<00:00, 72.77it/s, v_num=1, val_loss=0.0109, train_loss=0.0109]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 375/375 [00:05<00:00, 70.61it/s, v_num=1, val_loss=0.0109, train_loss=0.0109]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


Testing DataLoader 0: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 10/10 [00:00<00:00, 159.08it/s]
â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€
       Test metric             DataLoader 0
â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€
        test_loss          0.010606293566524982
â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”

[{'test_loss': 0.010606293566524982}]