> PyTorch Lightning: a deep learning framework. This notebook uses it to train a small autoencoder on MNIST.

__Outline__
1. Define a LightningModule
2. Define a dataset
3. Train the model
4. Use the model
5. Visualize training
6. Supercharge training


__1. Define a LightningModule__

In [2]:
import os

import lightning as L
import torch
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

In [4]:
# Encoder: 28 * 28 input features -> 64 hidden units -> 3-dimensional code.
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))

# Decoder: mirror image of the encoder, 3-dimensional code -> 64 -> 28 * 28 values.
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))


class LitAutoEncoder(L.LightningModule):
    """LightningModule that trains an encoder/decoder pair on reconstruction MSE."""

    def __init__(self, encoder, decoder):
        """Keep the two sub-networks as attributes of the module.

        Args:
            encoder: module applied to the flattened input batch.
            decoder: module applied to the encoder's output.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        """Compute the reconstruction loss for one batch and log it as "train_loss".

        Args:
            batch: pair whose first element is the input tensor; the second
                element is unpacked but not used.
            batch_idx: index of the batch (not used here).

        Returns:
            The mean-squared error between the decoder output and the
            flattened input.
        """
        inputs, _ = batch
        # Collapse everything after the batch dimension into one vector per sample.
        flat_inputs = inputs.view(inputs.size(0), -1)
        reconstruction = self.decoder(self.encoder(flat_inputs))
        loss = nn.functional.mse_loss(reconstruction, flat_inputs)

        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all of the module's parameters (lr=1e-3)."""
        return optim.Adam(self.parameters(), lr=1e-3)

# Wrap the two networks defined above in the LightningModule.
autoencoder = LitAutoEncoder(encoder=encoder, decoder=decoder)

In [7]:
# MNIST images as tensors, downloaded into the current working directory if needed.
dataset = MNIST(root=os.getcwd(), transform=ToTensor(), download=True)
# NOTE: `trian_loader` is a misspelling of `train_loader`; the name is kept
# because the training cell refers to it.
trian_loader = utils.data.DataLoader(dataset=dataset, batch_size=32)

In [9]:
# Fit the autoencoder for a single epoch.
# NOTE(review): `limit_test_batches` looks like it only applies to
# `trainer.test()`; if the goal was a shorter training run,
# `limit_train_batches` may have been intended — confirm.
trainer = L.Trainer(max_epochs=1, limit_test_batches=100)
trainer.fit(autoencoder, train_dataloaders=trian_loader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs



  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 50.4 K
1 | decoder | Sequential | 51.2 K
---------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)
/Users/zhengpanpan/miniconda3/envs/pytorch/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


In [13]:
# Load the weights saved by the training run above. Asking the trainer for the
# checkpoint path avoids hard-coding a `lightning_logs/version_N` directory,
# whose number changes each time training is re-run.
checkpoint = trainer.checkpoint_callback.best_model_path
autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)

# choose your trained nn.Module
encoder = autoencoder.encoder
encoder.eval()

# embed 4 fake images!
fake_image_batch = torch.rand(4, 28 * 28, device=autoencoder.device)
# Inference only: no gradients are needed, so skip autograd tracking.
with torch.no_grad():
    embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)


⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡ 
Predictions (4 image embeddings):
 tensor([[-0.5026,  0.4083, -0.7412],
        [-0.5384,  0.3476, -0.5428],
        [-0.3514,  0.3526, -0.5581],
        [-0.5821,  0.3804, -0.6357]], device='mps:0',
       grad_fn=<LinearBackward0>) 
 ⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡
