# Define LightningModel
A LightningModule enables your PyTorch nn.Module to play together in complex ways inside the training_step (there is also an optional validation_step and test_step).

In [2]:
import torch
import os
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L

# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3)) 
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28)) 


# define the LightningModule
class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder  # define encoder
        self.decoder = decoder  # define decoder

    # the training step is the heart of the training process
    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.  # 中文：training_step定义了训练循环。
        # it is independent of forward  # 中文：它独立于forward
        x, y = batch  # x, y = 批次
        x = x.view(x.size(0), -1)  # flatten the image 
        z = self.encoder(x)  # get the latent vector
        x_hat = self.decoder(z)  # reconstruct the image
        loss = nn.functional.mse_loss(x_hat, x)  # calculate the reconstruction loss
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):  
        """
        configure_optimizers is the second required method for LightningModules
        it must return a single optimizer or a list of optimizers, and can optionally return
        a scheduler and a dict of schedulers.

        """
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


# init the autoencoder  # 初始化自动编码器
autoencoder = LitAutoEncoder(encoder, decoder)

# Define a dataset
Lightning supports ANY iterable (DataLoader, numpy, etc…) for the train/val/test/predict splits.

In [3]:
# setup data
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())  # download the MNIST dataset
train_loader = utils.data.DataLoader(dataset)  # put the dataset into a dataloader

# Train the model
- PyTorch Lightning 的 Trainer 提供了一种简单而强大的方式，可以将任何 LightningModule 与任何数据集结合在一起，
- 提供了一个高级的训练循环抽象，使得用户只需专注于定义模型、数据集以及训练逻辑，而不必担心处理大规模训练、分布式训练等底层细节。

In [4]:
# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = L.Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
d:\data\jhq\AI\torchTutorial\.env\lib\site-packages\lightning\pytorch\trainer\connectors\logger_connector\logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
Missing logger folder: d:\data\jhq\AI\torchTutorial\lightning_logs

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 50.4 K
1 | decoder | Sequential | 51.2 K
---------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K    

Epoch 0: 100%|██████████| 100/100 [00:00<00:00, 145.41it/s, v_num=0]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 100/100 [00:00<00:00, 141.73it/s, v_num=0]


# Use the model
- Once you’ve trained the model you can export to onnx, torchscript and put it into production or simply load the weights and run predictions.

In [5]:
# load checkpoint
checkpoint = "./lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt"
autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)

# choose your trained nn.Module
encoder = autoencoder.encoder
encoder.eval()  # set in evaluation mode

# embed 4 fake images!
fake_image_batch = torch.rand(4, 28 * 28, device=autoencoder.device)  # random 4 images
embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)

⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡ 
Predictions (4 image embeddings):
 tensor([[ 0.1404, -0.0600, -0.4762],
        [ 0.2734, -0.0961, -0.5693],
        [ 0.2320, -0.1176, -0.6029],
        [ 0.2487, -0.1015, -0.5748]], grad_fn=<AddmmBackward0>) 
 ⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡


# Visualize training
- If you have tensorboard installed, you can use it for visualizing experiments.
- Run this on your commandline and open your browser to http://localhost:6006/

# TUTORIAL 5: TRANSFORMERS AND MULTI-HEAD ATTENTION