https://lightning.ai/docs/fabric/stable/

https://lightning.ai/docs/fabric/stable/api/fabric_args.html

https://lightning.ai/docs/fabric/stable/api/fabric_methods.html

In [36]:
import torch
from torch import nn
from torch import optim
from torch.utils.data import Dataset, DataLoader
from lightning_fabric import Fabric
from lightning_fabric.loggers import CSVLogger, TensorBoardLogger
from torchmetrics.functional import accuracy
from tqdm import tqdm
from pathlib import Path

In [37]:
# Hyper-parameters and synthetic-data dimensions shared by every cell below.
seed        = 42     # fixed RNG seed so data, weight init and shuffling repeat on re-run
epochs      = 5      # number of passes over the training data
batch_size  = 100    # samples per DataLoader batch
in_features = 10     # width of each random input vector
num_classes = 5      # number of target classes
data_len    = 10000  # samples in each synthetic dataset

# Seed torch's global generator before any random tensor or module is created.
torch.manual_seed(seed)

In [38]:
# One linear layer: in_features inputs -> num_classes outputs (fed to the cross-entropy loss).
model = nn.Linear(in_features=in_features, out_features=num_classes)

In [39]:
# AdamW over all model parameters; every hyper-parameter is spelled out explicitly.
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
)

In [40]:
# Cosine-anneal the learning rate over `epochs` scheduler steps (stepped once per epoch).
# eta_min is the floor the schedule decays towards, so it must be below the optimizer's
# initial lr of 0.001; 0.001 * 100 = 0.1 would make the schedule *raise* the lr 100x.
lr_sche = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0.001 / 100)

In [41]:
# Cross-entropy between logits and integer class targets, averaged over the batch.
loss_fn = nn.CrossEntropyLoss(reduction="mean")

In [42]:
class Dataset(torch.utils.data.Dataset):
    """Synthetic classification dataset made of random features and random labels.

    The base class is referenced by its fully qualified name because this class
    reuses the name ``Dataset``: with ``class Dataset(Dataset)`` every re-run of
    the cell would subclass the previously defined user class instead of the
    torch base class.

    Sizes come from the module-level ``data_len``, ``in_features`` and
    ``num_classes`` settings at construction time.
    """

    def __init__(self) -> None:
        """Draw ``data_len`` random samples and their random integer labels."""
        super().__init__()
        # Features: standard-normal matrix of shape (data_len, in_features).
        self.x = torch.randn(data_len, in_features)
        # Labels: integers drawn uniformly from [0, num_classes), one per sample.
        self.y = torch.randint(0, num_classes, (data_len,))

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair stored at ``index``."""
        return self.x[index], self.y[index]

    def __len__(self):
        """Return the number of samples."""
        return len(self.x)

In [43]:
# Options common to both loaders; only the training loader shuffles its samples.
loader_kwargs = dict(batch_size=batch_size, num_workers=0, pin_memory=True)
train_datalaoder = DataLoader(Dataset(), shuffle=True, **loader_kwargs)
val_datalaoder = DataLoader(Dataset(), shuffle=False, **loader_kwargs)

# Fabric

In [44]:
# Relative directory (under the current working directory) used as the logger root
# and as the checkpoint location.
output_dir = Path(".") / "fabric_checkpoint"
output_dir

WindowsPath('fabric_checkpoint')

## logger

In [45]:
# One CSV logger and one TensorBoard logger, both rooted at output_dir.
# name/version are empty strings -- presumably so no extra sub-directories are created
# under root_dir (verify against the logger docs).
loggers = [
    logger_cls(root_dir=output_dir, name="", version="")
    for logger_cls in (CSVLogger, TensorBoardLogger)
]

In [46]:
# Build the Fabric object that drives device placement, precision and logging.
fabric = Fabric(
    loggers=loggers,
    # Other precision choices: "transformer-engine", "transformer-engine-float16",
    # "16-true", "16-mixed", "bf16-true", "bf16-mixed", "64-true".
    precision="32-true",
    # "auto" or -1 would run on all GPUs.
    devices=1,
    # strategy="ddp",  # also: "dp", "ddp_spawn", "xla", "deepspeed", "fsdp", "auto"
    # Other accelerators: "cpu", "tpu", "auto".
    accelerator="gpu",
)
fabric.launch()

In [47]:
# Display the torch.device this Fabric instance reports for the current process.
fabric.device

device(type='cuda', index=0)

## local_rank

> node
>
> 物理节点，就是一台机器，节点内部可以有多个GPU(一台机器有多卡)。

> rank & local_rank
>
> 用于表示进程的序号，用于进程间通信。每一个进程对应了一个rank。
>
> rank=0的进程就是master进程。
>
> local_rank： rank是指在整个分布式任务中进程的序号；local_rank是指在一台机器上(一个node上)进程的相对序号，例如机器一上有0,1,2,3,4,5,6,7，机器二上也有0,1,2,3,4,5,6,7。local_rank在node之间相互独立。
>
> 单机多卡时，rank就等于local_rank

> nnodes
>
> 物理节点数量

> node_rank
>
> 物理节点的序号

> nproc_per_node
>
> 每个物理节点上面进程的数量。

> group
>
> 进程组。默认只有一个组

> world size 全局的并行数
>
> 全局（一个分布式任务）中，rank的数量。
>
> 每个node包含16个GPU，且nproc_per_node=8，nnodes=3，机器的node_rank=5，请问world_size是多少？
>
> 答案：world_size = nnodes × nproc_per_node = 3*8 = 24（与每个node的GPU数量以及node_rank无关）


```yaml
# 一共有12个rank, nnodes=3, nproc_per_node=4,每个节点都对应一个node_rank

machine0:
    node_rank: 0
        GPU0:
            rank: 0
            local_rank: 0
        GPU1:
            rank: 1
            local_rank: 1
        GPU2:
            rank: 2
            local_rank: 2
        GPU3:
            rank: 3
            local_rank: 3

machine1:
    node_rank: 1
        GPU0:
            rank: 4
            local_rank: 0
        GPU1:
            rank: 5
            local_rank: 1
        GPU2:
            rank: 6
            local_rank: 2
        GPU3:
            rank: 7
            local_rank: 3

machine2:
    node_rank: 2
        GPU0:
            rank: 8
            local_rank: 0
        GPU1:
            rank: 9
            local_rank: 1
        GPU2:
            rank: 10
            local_rank: 2
        GPU3:
            rank: 11
            local_rank: 3
```

In [48]:
# Print the process indices reported by Fabric, one per line:
#   local_rank     - index of this process within its node (0 = main process of the node,
#                    comparable to accelerate's is_local_main_process)
#   node_rank      - index of the node
#   global_rank    - index of this process across all nodes
#   is_global_zero - whether this process is global rank zero
for rank_attr in ("local_rank", "node_rank", "global_rank", "is_global_zero"):
    print(getattr(fabric, rank_attr))

0
0
0
True


## setup

In [49]:
# Let Fabric wrap the model/optimizer pair and both dataloaders.
# NOTE: this cell rebinds its own inputs, so run it exactly once per kernel session.
model, optimizer = fabric.setup(model, optimizer)
train_datalaoder, val_datalaoder = fabric.setup_dataloaders(train_datalaoder, val_datalaoder)

## clip gradients

In [50]:
# Norm-based clipping shown twice: through Fabric, then with the plain torch utility.
# The cell displays the total gradient norm returned by the torch call.
norm_clip = dict(max_norm=1, norm_type=2)
fabric.clip_gradients(module=model, optimizer=optimizer, **norm_clip)
nn.utils.clip_grad_norm_(parameters=model.parameters(), **norm_clip)

tensor(0.)

In [51]:
# fabric.clip_gradients(module=model, optimizer=optimizer, clip_val=0.1)
# torch.nn.utils.clip_grad.clip_grad_value_(parameters=model.parameters(), clip_value=0.1)

## train loop

In [52]:
def _gather_batch(y_pred: torch.Tensor, y: torch.Tensor, loss: torch.Tensor):
    """Collect one batch's logits, targets and loss from every process.

    Logits and loss are detached first so the per-epoch history does not keep
    autograd graphs alive. The gathered tensors are flattened because
    ``fabric.all_gather`` adds a leading world-size dimension when more than one
    process is running (and none when there is a single process).

    Returns:
        Tuple ``(logits, targets, losses)`` shaped ``(N, num_classes)``, ``(N,)``
        and ``(P,)``.
    """
    g_pred, g_tar, g_loss = fabric.all_gather((y_pred.detach(), y, loss.detach()))
    return g_pred.reshape(-1, num_classes), g_tar.reshape(-1), g_loss.reshape(-1)


def _epoch_metrics(predictions, targets, losses):
    """Reduce the per-batch history of one epoch to ``(accuracy, mean_loss)`` tensors."""
    acc = accuracy(
        preds=torch.cat(predictions, dim=0),
        target=torch.cat(targets, dim=0),
        task="multiclass",
        num_classes=num_classes,
    )
    avg_loss = torch.cat(losses, dim=0).mean()
    return acc, avg_loss


for epoch in range(1, epochs + 1):
    # ---- train ----
    model.train()
    # Only the local main process draws a progress bar.
    with tqdm(total=len(train_datalaoder), desc=f"{epoch}/{epochs}", disable=fabric.local_rank != 0) as pbar:
        all_predictions = []
        all_targets = []
        all_losses = []
        for x, y in train_datalaoder:
            optimizer.zero_grad()
            # Forward pass runs under Fabric's automatic autocast:
            # https://lightning.ai/docs/fabric/stable/api/fabric_methods.html#autocast
            y_pred: torch.Tensor = model(x)
            loss: torch.Tensor = loss_fn(y_pred, y)
            fabric.backward(loss)   # replaces loss.backward()
            fabric.clip_gradients(  # gradient clipping
                module=model,
                optimizer=optimizer,
                clip_val=None,  # clip by value (disabled)
                max_norm=1.0,   # clip by gradient norm
                norm_type=2.0,
            )
            optimizer.step()

            pbar.set_postfix({"train/loss": f"{loss.item():.4f}"})
            pbar.update(1)

            # Keep predictions/targets/losses from all processes for the epoch metrics.
            all_pred, all_tar, all_loss = _gather_batch(y_pred, y, loss)
            all_predictions.append(all_pred)
            all_targets.append(all_tar)
            all_losses.append(all_loss)

        train_acc, train_avg_loss = _epoch_metrics(all_predictions, all_targets, all_losses)
        pbar.set_postfix({"train/acc" :f"{train_acc.item():.4f}"})

    # One scheduler step per epoch.
    lr_sche.step()

    # ---- val ----
    model.eval()
    with tqdm(total=len(val_datalaoder), desc=f"{epoch}/{epochs}", disable=fabric.local_rank != 0) as pbar:
        all_predictions = []
        all_targets = []
        all_losses = []
        for x, y in val_datalaoder:
            # No gradients are needed for validation.
            with torch.inference_mode():
                y_pred: torch.Tensor = model(x)
                loss: torch.Tensor = loss_fn(y_pred, y)

            pbar.set_postfix({"val/loss": f"{loss.item():.4f}"})
            pbar.update(1)

            # Keep predictions/targets/losses from all processes for the epoch metrics.
            all_pred, all_tar, all_loss = _gather_batch(y_pred, y, loss)
            all_predictions.append(all_pred)
            all_targets.append(all_tar)
            all_losses.append(all_loss)

        val_acc, val_avg_loss = _epoch_metrics(all_predictions, all_targets, all_losses)
        pbar.set_postfix({"val/acc" :f"{val_acc.item():.4f}"})

    # fabric log
    # Single-metric alternative: fabric.log(name="val/loss", value=val_avg_loss.item(), step=epoch)
    fabric.log_dict(
        metrics={
            "train/acc": train_acc.item(),
            "train/loss": train_avg_loss.item(),
            "val/acc": val_acc.item(),
            "val/loss": val_avg_loss.item(),
        },
        step=epoch,
    )

    # barrier() and save() are collective: EVERY process must reach them.
    # Guarding them with `if fabric.is_global_zero` leaves rank 0 waiting forever
    # for the other ranks as soon as more than one process is used.
    # Like torch.distributed.barrier: wait for all processes to enter this call.
    fabric.barrier()
    # Pass the model / optimizer objects themselves so Fabric can unwrap them and
    # fetch their state dicts; the strategy decides which process writes the file.
    fabric.save(
        path=output_dir / "fabric.last.pth",
        state={
            "model": model,
            "optimizer": optimizer,
            "lr_sche": lr_sche,
        }
    )

# fabric.logger is only the first configured logger; finalize all of them.
for active_logger in fabric.loggers:
    active_logger.finalize("training finish")
fabric.print("training finish")

1/5: 100%|██████████| 100/100 [00:00<00:00, 315.46it/s, train/acc=0.2002]
1/5: 100%|██████████| 100/100 [00:00<00:00, 531.92it/s, val/acc=0.1997]
2/5: 100%|██████████| 100/100 [00:00<00:00, 315.46it/s, train/acc=0.2037] 
2/5: 100%|██████████| 100/100 [00:00<00:00, 563.21it/s, val/acc=0.2016]
3/5: 100%|██████████| 100/100 [00:00<00:00, 323.63it/s, train/acc=0.1959]
3/5: 100%|██████████| 100/100 [00:00<00:00, 568.19it/s, val/acc=0.1980]
4/5: 100%|██████████| 100/100 [00:00<00:00, 337.24it/s, train/acc=0.2047]
4/5: 100%|██████████| 100/100 [00:00<00:00, 570.18it/s, val/acc=0.2093]
5/5: 100%|██████████| 100/100 [00:00<00:00, 334.95it/s, train/acc=0.1936]
5/5: 100%|██████████| 100/100 [00:00<00:00, 571.43it/s, val/acc=0.2077]

training finish





## fabric.load 与 torch.load 类似（从下面的输出看：fabric.load 返回的张量在 CPU 上，torch.load 保留了保存时的 cuda:0）

In [53]:
# Read the checkpoint back through Fabric and display the returned state dictionary.
fabric.load(output_dir.joinpath("fabric.last.pth"))

{'model': OrderedDict([('weight',
               tensor([[-0.1348, -0.0855, -0.0954,  0.0417, -0.3190,  0.1403, -0.1463, -0.0071,
                        -0.1843, -0.1384],
                       [-0.2316, -0.0180,  0.0514,  0.3167, -0.1499,  0.0330,  0.1179,  0.0268,
                         0.0168, -0.0062],
                       [-0.0115,  0.0963,  0.0759,  0.0937,  0.0938, -0.0601, -0.1356,  0.2082,
                        -0.1564, -0.0013],
                       [ 0.0964, -0.0995,  0.0425,  0.0931, -0.0945, -0.0394,  0.0584, -0.0102,
                         0.0834, -0.1264],
                       [-0.0601,  0.0548, -0.0289, -0.0111,  0.0597,  0.0963, -0.0984,  0.2177,
                         0.0635,  0.1728]])),
              ('bias',
               tensor([ 0.1981,  0.0042,  0.1461, -0.2195, -0.0454]))]),
 'optimizer': {'state': {0: {'step': tensor(500.),
    'exp_avg': tensor([[ 0.0004,  0.0075, -0.0126,  0.0096,  0.0147,  0.0048, -0.0042,  0.0065,
             -0.0097,  0.

In [54]:
# Read the same checkpoint file with plain torch for comparison.
torch.load(output_dir.joinpath("fabric.last.pth"))

{'model': OrderedDict([('weight',
               tensor([[-0.1348, -0.0855, -0.0954,  0.0417, -0.3190,  0.1403, -0.1463, -0.0071,
                        -0.1843, -0.1384],
                       [-0.2316, -0.0180,  0.0514,  0.3167, -0.1499,  0.0330,  0.1179,  0.0268,
                         0.0168, -0.0062],
                       [-0.0115,  0.0963,  0.0759,  0.0937,  0.0938, -0.0601, -0.1356,  0.2082,
                        -0.1564, -0.0013],
                       [ 0.0964, -0.0995,  0.0425,  0.0931, -0.0945, -0.0394,  0.0584, -0.0102,
                         0.0834, -0.1264],
                       [-0.0601,  0.0548, -0.0289, -0.0111,  0.0597,  0.0963, -0.0984,  0.2177,
                         0.0635,  0.1728]], device='cuda:0')),
              ('bias',
               tensor([ 0.1981,  0.0042,  0.1461, -0.2195, -0.0454], device='cuda:0'))]),
 'optimizer': {'state': {0: {'step': tensor(500.),
    'exp_avg': tensor([[ 0.0004,  0.0075, -0.0126,  0.0096,  0.0147,  0.0048, -0.0042, 