Can't load data with DDP and GPUs=4 #18221

@SyntaxSmith

Bug description

When I train my model with PyTorch Lightning (pl), this Python script only runs when I use one GPU. After I change the Trainer argument "devices=[0]" to "devices=[0,1,2,3]", the script can't finish the data sanity check.

What version are you seeing the problem on?

v2.0

How to reproduce the bug

    PAD_IDX = 1
    # Data setup
    df = pd.read_csv("group_selfies.txt", sep="\t", header=None)
    train_df, test_df = train_test_split(df, test_size=0.01, random_state=42)

    train_dataset = GSDataset(data=train_df, vocab_file="vocab_gs.txt", max_length=256)
    test_dataset = GSDataset(data=test_df, vocab_file="vocab_gs.txt", max_length=256)
    print(f"train_dataset:{len(train_dataset)},test_dataset:{len(test_dataset)}")
    # Model setup
    BATCH = 256
    epoch = 150
    ntoken = train_dataset.get_vocab_size()
    d_model = 64
    nhead = 8
    d_hid = 4096
    nlayers = 4
    hparams = {
        "ntoken": ntoken,
        "BATCH": BATCH,
        "epoch": epoch,
        "d_model": d_model,
        "nhead": nhead,
        "d_hid": d_hid,
        "nlayers": nlayers,
        "lr": 1e-3,
        "remarks": """
        decoder-only
        """,
    }

    # Build the model
    model = GSTransformer(
        train_dataset,
        test_dataset,
        ntoken,
        hparams["d_model"],
        hparams["nhead"],
        hparams["d_hid"],
        hparams["nlayers"],
        lr=1e-3,
        hparams=hparams,
    )
    ## re-pretrain
    # checkpoint = torch.load(
    # "lightning_logs/version_1/checkpoints/epoch=131-step=17028.ckpt"
    # )
    # model.load_state_dict(checkpoint["state_dict"])
    # Set up callbacks
    from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

    checkpoint_callback = ModelCheckpoint(monitor="val_loss")
    early_stop_callback = pl.callbacks.EarlyStopping(
        monitor="val_loss", min_delta=0.00, patience=10, verbose=False, mode="min"
    )
    lr_monitor = LearningRateMonitor(logging_interval="epoch")
    # Train model
    trainer = pl.Trainer(
        max_epochs=epoch,
        devices=4,
        accelerator="gpu",
        strategy="ddp",
        callbacks=[checkpoint_callback, lr_monitor],
        accumulate_grad_batches=2,
        precision="16-mixed",
        logger=pl.loggers.TensorBoardLogger(
            "lightning_logs/group-selfies", name="gs2gs-recon"
        ),
    )
    trainer.fit(
        model,
        DataLoader(train_dataset, batch_size=BATCH, shuffle=True, num_workers=0),
        DataLoader(test_dataset, batch_size=BATCH * 2, shuffle=False, num_workers=0),
    )
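For anyone reading the configuration above, it may help to see what each of the 4 DDP processes is expected to load. The following is only a rough sketch, assuming Lightning's default behaviour of wrapping the dataloaders in a `DistributedSampler` with `drop_last=False`, so each rank receives `ceil(N / world_size)` samples. The helper name `ddp_batch_plan` and the dataset sizes (10,000 validation rows, 990,000 training rows) are hypothetical; the real size of group_selfies.txt is not given in this report.

```python
import math


def ddp_batch_plan(num_samples, batch_size, world_size, accumulate_grad_batches=1):
    """Estimate the per-rank workload when a dataset is sharded across ranks.

    Assumption: samples are split the way torch's DistributedSampler does it
    with drop_last=False, i.e. every rank gets ceil(num_samples / world_size)
    samples (the sampler pads by repeating samples when the split is uneven).
    """
    samples_per_rank = math.ceil(num_samples / world_size)
    batches_per_rank = math.ceil(samples_per_rank / batch_size)
    # Samples contributing to one optimizer step across all ranks.
    effective_batch_size = batch_size * world_size * accumulate_grad_batches
    return {
        "samples_per_rank": samples_per_rank,
        "batches_per_rank": batches_per_rank,
        "effective_batch_size": effective_batch_size,
    }


if __name__ == "__main__":
    # Hypothetical sizes; BATCH=256, devices=4 and accumulate_grad_batches=2
    # are taken from the script in this report.
    val_plan = ddp_batch_plan(num_samples=10_000, batch_size=256 * 2, world_size=4)
    train_plan = ddp_batch_plan(
        num_samples=990_000, batch_size=256, world_size=4, accumulate_grad_batches=2
    )
    print("validation:", val_plan)
    print("training:", train_plan)
```

If the validation set were small enough that `batches_per_rank` fell below the number of sanity-check batches (2 by default via `num_sanity_val_steps`), that would be one thing worth ruling out, although nothing in this report confirms it is the cause of the hang.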

Error messages and logs

[Three images attached]

Environment

Current environment
  • CUDA:
    - GPU:
        - NVIDIA GeForce RTX 4090
        - NVIDIA GeForce RTX 4090
        - NVIDIA GeForce RTX 4090
        - NVIDIA GeForce RTX 4090
    - available: True
    - version: 11.8
  • Lightning:
    - lightning-utilities: 0.9.0
    - pytorch-lightning: 2.0.6
    - torch: 2.0.1
    - torchaudio: 2.0.2
    - torchmetrics: 1.0.1
    - torchvision: 0.15.2
  • Packages:
    - aiohttp: 3.8.3
    - aiosignal: 1.2.0
    - appdirs: 1.4.4
    - asttokens: 2.0.5
    - async-timeout: 4.0.2
    - attrs: 22.1.0
    - backcall: 0.2.0
    - bottleneck: 1.3.5
    - brotlipy: 0.7.0
    - certifi: 2023.7.22
    - cffi: 1.15.1
    - charset-normalizer: 2.0.4
    - click: 8.0.4
    - comm: 0.1.2
    - contourpy: 1.0.5
    - cryptography: 41.0.2
    - cycler: 0.11.0
    - datasets: 2.12.0
    - debugpy: 1.6.7
    - decorator: 5.1.1
    - dill: 0.3.6
    - emmet-core: 0.63.1
    - executing: 0.8.3
    - fastprogress: 1.0.3
    - filelock: 3.9.0
    - fonttools: 4.25.0
    - frozenlist: 1.3.3
    - fsspec: 2023.4.0
    - future: 0.18.3
    - gensim: 4.3.1
    - global-chem: 1.8.1.2
    - gmpy2: 2.1.2
    - greenlet: 2.0.1
    - group-selfies: 1.0.0
    - huggingface-hub: 0.15.1
    - idna: 3.4
    - importlib-metadata: 6.0.0
    - importlib-resources: 5.2.0
    - iprogress: 0.4
    - ipykernel: 6.19.2
    - ipython: 8.12.0
    - jedi: 0.18.1
    - jinja2: 3.1.2
    - joblib: 1.2.0
    - jupyter-client: 8.1.0
    - jupyter-core: 5.3.0
    - kiwisolver: 1.4.4
    - latexcodec: 2.0.1
    - lightning-utilities: 0.9.0
    - markupsafe: 2.1.1
    - matplotlib: 3.7.1
    - matplotlib-inline: 0.1.6
    - mkl-fft: 1.3.6
    - mkl-random: 1.2.2
    - mkl-service: 2.4.0
    - monty: 2023.5.8
    - mp-api: 0.33.3
    - mpmath: 1.3.0
    - msgpack: 1.0.5
    - multidict: 6.0.2
    - multiprocess: 0.70.14
    - munkres: 1.1.4
    - nest-asyncio: 1.5.6
    - networkx: 3.1
    - numexpr: 2.8.4
    - numpy: 1.25.0
    - packaging: 23.0
    - palettable: 3.3.3
    - pandas: 1.5.3
    - parso: 0.8.3
    - pexpect: 4.8.0
    - pickleshare: 0.7.5
    - pillow: 9.4.0
    - pip: 23.2.1
    - platformdirs: 2.5.2
    - plotly: 5.15.0
    - pooch: 1.4.0
    - prompt-toolkit: 3.0.36
    - protobuf: 3.20.3
    - psutil: 5.9.0
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - pyarrow: 11.0.0
    - pybtex: 0.24.0
    - pycairo: 1.23.0
    - pycparser: 2.21
    - pydantic: 1.10.12
    - pygments: 2.15.1
    - pymatgen: 2023.7.20
    - pyopenssl: 23.2.0
    - pyparsing: 3.0.9
    - pysocks: 1.7.1
    - python-dateutil: 2.8.2
    - pytorch-lightning: 2.0.6
    - pytz: 2022.7
    - pyyaml: 6.0
    - pyzmq: 25.1.0
    - rdkit: 2023.3.2
    - regex: 2022.7.9
    - reportlab: 3.6.12
    - requests: 2.31.0
    - responses: 0.13.3
    - ruamel.yaml: 0.17.32
    - ruamel.yaml.clib: 0.2.7
    - sacremoses: 0.0.43
    - safetensors: 0.3.1
    - scikit-learn: 1.2.2
    - scipy: 1.10.1
    - setuptools: 68.0.0
    - six: 1.16.0
    - smart-open: 6.3.0
    - smilespe: 0.0.3
    - spglib: 2.0.2
    - sqlalchemy: 1.4.39
    - stack-data: 0.2.0
    - sympy: 1.11.1
    - tabulate: 0.9.0
    - tenacity: 8.2.2
    - tensorboardx: 2.2
    - threadpoolctl: 2.2.0
    - tokenizers: 0.13.2
    - torch: 2.0.1
    - torchaudio: 2.0.2
    - torchmetrics: 1.0.1
    - torchvision: 0.15.2
    - tornado: 6.3.2
    - tqdm: 4.65.0
    - traitlets: 5.7.1
    - transformers: 4.31.0
    - triton: 2.0.0
    - typing-extensions: 4.7.1
    - uncertainties: 3.1.7
    - urllib3: 1.26.16
    - wcwidth: 0.2.5
    - wheel: 0.38.4
    - xxhash: 2.0.2
    - yarl: 1.8.1
    - zipp: 3.11.0
  • System:
    - OS: Linux
    - architecture:
        - 64bit
        - ELF
    - processor: x86_64
    - python: 3.9.17
    - release: 5.15.0-60-generic
    - version: No back #66-Ubuntu SMP Fri Jan 20 14:29:49 UTC 2023

More info

I

cc @justusschock @awaelchli
