- https://huggingface.co/docs/accelerate/concept_guides/fsdp1_vs_fsdp2

In [4]:
from packaging import version
import torch

In [7]:
version.parse(torch.__version__), version.parse(torch.__version__) >= version.parse("2.6")

(<Version('2.6.0+cu124')>, True)

In [5]:
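# FSDP1: the wrapper-class API
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
if version.parse(torch.__version__) >= version.parse("2.6"):
    # torch >= 2.6: FSDP2 (fully_shard) is exported from torch.distributed.fsdp
    from torch.distributed.fsdp import CPUOffloadPolicy, FSDPModule, MixedPrecisionPolicy, fully_shard
elif version.parse(torch.__version__) >= version.parse("2.4"):
    # torch 2.4 to 2.5: FSDP2 lives under the private _composable namespace
    from torch.distributed._composable.fsdp import CPUOffloadPolicy, FSDPModule, MixedPrecisionPolicy, fully_shard
else:
    # FSDP2 is unavailable; use placeholders so later checks can test for None
    fully_shard, MixedPrecisionPolicy, FSDPModule, CPUOffloadPolicy = None, None, None, None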

In [9]:
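def fsdp_version(model):
    # FSDP1: the module is wrapped in a FullyShardedDataParallel instance
    if isinstance(model, FSDP):
        return 1
    # FSDP2: fully_shard() swaps the module's class for a subclass of FSDPModule
    # (FSDPModule is None on torch < 2.4, and isinstance(x, None) would raise TypeError)
    elif FSDPModule is not None and isinstance(model, FSDPModule):
        return 2
    # not sharded by either FSDP version
    else:
        return 0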

## FSDP 1 vs. FSDP 2

In [10]:
from IPython.display import Image

In [11]:
Image(url='https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/fsdp1.png',
      width=400)

In [13]:
Image(url='https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/fsdp2.png',
      width=400)

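- FSDP 1 (traditional FSDP):
    - `FlatParameter` abstraction: multiple parameters are flattened and concatenated into a single 1D tensor
    - Fixed parameter structure: `FlatParameter` changes the structure of the original parameters
- FSDP 2 (composable FSDP):
    - Composable API: the `fully_shard()` function is applied to specific modules
    - No `FlatParameter` (uses DTensor, i.e. distributed tensor): parameters are no longer flattened; each parameter is sharded individually along the `0th dimension`, and parameters are copied for communication
        - In the image above, the tensors are sharded along the `1st dimension` (by columns) only so that the figure fits on the screen; in reality, they are sharded along the `0th dimension` (by rows), as stated above
    - Preserves the original structure: the model's original parameter structure is kept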
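The two sharding schemes can be simulated on a single process with plain tensors. This is a minimal sketch, not how PyTorch actually implements either version: the hypothetical `WORLD_SIZE`, the toy `nn.Linear(3, 4)` model, and plain concatenation standing in for the all-gather collective are all assumptions. The shapes are chosen to divide evenly; real FSDP also pads uneven shards and manages its own communication buffers.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

WORLD_SIZE = 2  # hypothetical number of ranks; no process group is created here

torch.manual_seed(0)
model = nn.Linear(3, 4)  # weight: (4, 3) = 12 elements, bias: (4,) = 4 elements
params = [p.detach() for p in model.parameters()]

# --- FSDP1-style: one flat 1D buffer (the FlatParameter idea) ---
# Flatten every parameter, concatenate, pad to a multiple of WORLD_SIZE,
# and give each rank one equal 1D slice of the concatenated buffer.
flat = torch.cat([p.reshape(-1) for p in params])     # 16 elements
flat = F.pad(flat, (0, (-flat.numel()) % WORLD_SIZE))  # no-op here (16 % 2 == 0)
fsdp1_shards = list(flat.chunk(WORLD_SIZE))            # rank r holds fsdp1_shards[r]

# --- FSDP2-style: per-parameter sharding along dim-0 ---
# Each parameter is split by rows on its own, so each shard keeps the trailing dims.
fsdp2_shards = [list(p.chunk(WORLD_SIZE, dim=0)) for p in params]

for rank in range(WORLD_SIZE):
    print(f"rank {rank}  FSDP1 shard: {tuple(fsdp1_shards[rank].shape)}  "
          f"FSDP2 shards: {[tuple(s[rank].shape) for s in fsdp2_shards]}")

# Simulated all-gather (plain concatenation) to rebuild the full parameters.
gathered = torch.cat(fsdp1_shards)
offset, fsdp1_params = 0, []
for p in params:
    # FSDP1-style: each original parameter is a view into the gathered flat buffer
    fsdp1_params.append(gathered[offset:offset + p.numel()].view_as(p))
    offset += p.numel()
# FSDP2-style: each parameter is rebuilt independently from its own dim-0 shards
fsdp2_params = [torch.cat(shards, dim=0) for shards in fsdp2_shards]

assert all(torch.equal(a, b) for a, b in zip(fsdp1_params, params))
assert all(torch.equal(a, b) for a, b in zip(fsdp2_params, params))
```

Each rank prints an FSDP1 shard of shape `(8,)` and FSDP2 shards of shapes `[(2, 3), (2,)]`. Rank 1's flat slice holds the last 4 weight elements plus the entire bias, so one FSDP1 shard straddles two parameters, whereas every FSDP2 shard belongs to exactly one parameter and keeps its column count.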