Closed
Labels
question, strategy: ddp, ver: 2.0.x
Description
Bug description
When I train my model with PyTorch Lightning (pl), the script only runs when I use one GPU. After I change the Trainer argument "devices=[0]" to "devices=[0,1,2,3]", the script can't finish the sanity check.
What version are you seeing the problem on?
v2.0
How to reproduce the bug
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

# GSDataset and GSTransformer are project-specific classes;
# their definitions are not included in this report.

PAD_IDX = 1

# Data setup
df = pd.read_csv("group_selfies.txt", sep="\t", header=None)
train_df, test_df = train_test_split(df, test_size=0.01, random_state=42)
train_dataset = GSDataset(data=train_df, vocab_file="vocab_gs.txt", max_length=256)
test_dataset = GSDataset(data=test_df, vocab_file="vocab_gs.txt", max_length=256)
print(f"train_dataset:{len(train_dataset)},test_dataset:{len(test_dataset)}")

# Model setup
BATCH = 256
epoch = 150
ntoken = train_dataset.get_vocab_size()
d_model = 64
nhead = 8
d_hid = 4096
nlayers = 4
hparams = {
    "ntoken": ntoken,
    "BATCH": BATCH,
    "epoch": epoch,
    "d_model": d_model,
    "nhead": nhead,
    "d_hid": d_hid,
    "nlayers": nlayers,
    "lr": 1e-3,
    "remarks": """
    decoder-only
    """,
}

# Build the model
model = GSTransformer(
    train_dataset,
    test_dataset,
    ntoken,
    hparams["d_model"],
    hparams["nhead"],
    hparams["d_hid"],
    hparams["nlayers"],
    lr=1e-3,
    hparams=hparams,
)

## re-pretrain
# checkpoint = torch.load(
#     "lightning_logs/version_1/checkpoints/epoch=131-step=17028.ckpt"
# )
# model.load_state_dict(checkpoint["state_dict"])

# Set up callbacks
checkpoint_callback = ModelCheckpoint(monitor="val_loss")
# Note: early_stop_callback is created but not passed to the Trainer below.
early_stop_callback = pl.callbacks.EarlyStopping(
    monitor="val_loss", min_delta=0.00, patience=10, verbose=False, mode="min"
)
lr_monitor = LearningRateMonitor(logging_interval="epoch")

# Train model
trainer = pl.Trainer(
    max_epochs=epoch,
    devices=4,
    accelerator="gpu",
    strategy="ddp",
    callbacks=[checkpoint_callback, lr_monitor],
    accumulate_grad_batches=2,
    precision="16-mixed",
    logger=pl.loggers.TensorBoardLogger(
        "lightning_logs/group-selfies", name="gs2gs-recon"
    ),
)
trainer.fit(
    model,
    DataLoader(train_dataset, batch_size=BATCH, shuffle=True, num_workers=0),
    DataLoader(test_dataset, batch_size=BATCH * 2, shuffle=False, num_workers=0),
)
Error messages and logs
# Error messages and logs here please
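No error output was captured for this report. One possible way to collect some is sketched below, using only the Python standard library plus NCCL's documented NCCL_DEBUG environment variable. The helper names (enable_hang_diagnostics, disable_hang_diagnostics), the log file name, and the 300-second timeout are hypothetical choices for illustration, not part of the original script or of Lightning's API. It is meant to be called before the Trainer is created.

```python
import faulthandler
import os


def enable_hang_diagnostics(timeout_s=300.0, log_dir="."):
    """Arm extra logging for a run that may hang (hypothetical helper)."""
    # NCCL reads NCCL_DEBUG when it initialises; INFO makes it log its setup.
    os.environ["NCCL_DEBUG"] = "INFO"

    # LOCAL_RANK is normally set for DDP worker processes; fall back to "0"
    # when it is absent (assumption: that is the launching process).
    rank = os.environ.get("LOCAL_RANK", "0")
    path = os.path.join(log_dir, f"hang_trace_rank{rank}.log")

    # faulthandler needs a real file descriptor that stays open.
    log_file = open(path, "w")

    # Every timeout_s seconds, dump the stack of all threads to the file,
    # which shows where each process is blocked if the run stalls.
    faulthandler.dump_traceback_later(timeout_s, repeat=True, file=log_file)
    return log_file


def disable_hang_diagnostics(log_file):
    """Stop the periodic dumps and close the log file."""
    faulthandler.cancel_dump_traceback_later()
    log_file.close()
```

If the sanity check stalls, the per-rank trace files and the NCCL log lines on stderr would be the output to paste into this section.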
Environment
Current environment
- CUDA:
- GPU:
- NVIDIA GeForce RTX 4090
- NVIDIA GeForce RTX 4090
- NVIDIA GeForce RTX 4090
- NVIDIA GeForce RTX 4090
- available: True
- version: 11.8
- Lightning:
- lightning-utilities: 0.9.0
- pytorch-lightning: 2.0.6
- torch: 2.0.1
- torchaudio: 2.0.2
- torchmetrics: 1.0.1
- torchvision: 0.15.2
- Packages:
- aiohttp: 3.8.3
- aiosignal: 1.2.0
- appdirs: 1.4.4
- asttokens: 2.0.5
- async-timeout: 4.0.2
- attrs: 22.1.0
- backcall: 0.2.0
- bottleneck: 1.3.5
- brotlipy: 0.7.0
- certifi: 2023.7.22
- cffi: 1.15.1
- charset-normalizer: 2.0.4
- click: 8.0.4
- comm: 0.1.2
- contourpy: 1.0.5
- cryptography: 41.0.2
- cycler: 0.11.0
- datasets: 2.12.0
- debugpy: 1.6.7
- decorator: 5.1.1
- dill: 0.3.6
- emmet-core: 0.63.1
- executing: 0.8.3
- fastprogress: 1.0.3
- filelock: 3.9.0
- fonttools: 4.25.0
- frozenlist: 1.3.3
- fsspec: 2023.4.0
- future: 0.18.3
- gensim: 4.3.1
- global-chem: 1.8.1.2
- gmpy2: 2.1.2
- greenlet: 2.0.1
- group-selfies: 1.0.0
- huggingface-hub: 0.15.1
- idna: 3.4
- importlib-metadata: 6.0.0
- importlib-resources: 5.2.0
- iprogress: 0.4
- ipykernel: 6.19.2
- ipython: 8.12.0
- jedi: 0.18.1
- jinja2: 3.1.2
- joblib: 1.2.0
- jupyter-client: 8.1.0
- jupyter-core: 5.3.0
- kiwisolver: 1.4.4
- latexcodec: 2.0.1
- lightning-utilities: 0.9.0
- markupsafe: 2.1.1
- matplotlib: 3.7.1
- matplotlib-inline: 0.1.6
- mkl-fft: 1.3.6
- mkl-random: 1.2.2
- mkl-service: 2.4.0
- monty: 2023.5.8
- mp-api: 0.33.3
- mpmath: 1.3.0
- msgpack: 1.0.5
- multidict: 6.0.2
- multiprocess: 0.70.14
- munkres: 1.1.4
- nest-asyncio: 1.5.6
- networkx: 3.1
- numexpr: 2.8.4
- numpy: 1.25.0
- packaging: 23.0
- palettable: 3.3.3
- pandas: 1.5.3
- parso: 0.8.3
- pexpect: 4.8.0
- pickleshare: 0.7.5
- pillow: 9.4.0
- pip: 23.2.1
- platformdirs: 2.5.2
- plotly: 5.15.0
- pooch: 1.4.0
- prompt-toolkit: 3.0.36
- protobuf: 3.20.3
- psutil: 5.9.0
- ptyprocess: 0.7.0
- pure-eval: 0.2.2
- pyarrow: 11.0.0
- pybtex: 0.24.0
- pycairo: 1.23.0
- pycparser: 2.21
- pydantic: 1.10.12
- pygments: 2.15.1
- pymatgen: 2023.7.20
- pyopenssl: 23.2.0
- pyparsing: 3.0.9
- pysocks: 1.7.1
- python-dateutil: 2.8.2
- pytorch-lightning: 2.0.6
- pytz: 2022.7
- pyyaml: 6.0
- pyzmq: 25.1.0
- rdkit: 2023.3.2
- regex: 2022.7.9
- reportlab: 3.6.12
- requests: 2.31.0
- responses: 0.13.3
- ruamel.yaml: 0.17.32
- ruamel.yaml.clib: 0.2.7
- sacremoses: 0.0.43
- safetensors: 0.3.1
- scikit-learn: 1.2.2
- scipy: 1.10.1
- setuptools: 68.0.0
- six: 1.16.0
- smart-open: 6.3.0
- smilespe: 0.0.3
- spglib: 2.0.2
- sqlalchemy: 1.4.39
- stack-data: 0.2.0
- sympy: 1.11.1
- tabulate: 0.9.0
- tenacity: 8.2.2
- tensorboardx: 2.2
- threadpoolctl: 2.2.0
- tokenizers: 0.13.2
- torch: 2.0.1
- torchaudio: 2.0.2
- torchmetrics: 1.0.1
- torchvision: 0.15.2
- tornado: 6.3.2
- tqdm: 4.65.0
- traitlets: 5.7.1
- transformers: 4.31.0
- triton: 2.0.0
- typing-extensions: 4.7.1
- uncertainties: 3.1.7
- urllib3: 1.26.16
- wcwidth: 0.2.5
- wheel: 0.38.4
- xxhash: 2.0.2
- yarl: 1.8.1
- zipp: 3.11.0
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.9.17
- release: 5.15.0-60-generic
- version: #66-Ubuntu SMP Fri Jan 20 14:29:49 UTC 2023
More info
I


