Closed
Labels: bug, help wanted
Description
🐛 Bug
pytorch-lightning>=1.4.0 errors before model training begins for RNNs when DeepSpeed is enabled. The same model works as expected with the earlier pytorch-lightning==1.3.7.
Specific error: RuntimeError: shape '[768, 1]' is invalid for input of size 1
Using deepspeed==0.5.1 and torch==1.9.0+cu111 in both cases.
To Reproduce
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
from deepspeed.ops.adam import DeepSpeedCPUAdam


def create_training_dataloader(
    examples: int,
    features: int,
    sequence_length: int = 1,
    output_size: int = 1,
    batch_size: int = 32,
):
    # synthetic linear-regression targets, taken at the last timestep
    A = torch.randn(features, output_size)
    x = torch.rand((examples, sequence_length, features))
    y = x @ A
    y = y[:, -1, :]
    dataset = TensorDataset(x, y)
    dataloader = DataLoader(dataset, batch_size=batch_size, drop_last=True)
    return dataloader


class GRUNetwork(pl.LightningModule):
    def __init__(
        self,
        features: int,
        output_size: int,
        hidden_size: int,
    ):
        super().__init__()
        self.rnn = nn.GRU(features, hidden_size, num_layers=1, batch_first=True)
        self.output_layer = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x, _ = self.rnn(x)
        # predict from the hidden state at the last timestep so the output
        # shape matches the (batch_size, output_size) targets
        output = self.output_layer(x[:, -1, :])
        return output

    def training_step(self, batch, batch_idx):
        x, y = batch
        yhat = self(x)
        loss = F.mse_loss(yhat, y)
        return loss

    def configure_optimizers(self):
        optimizer = DeepSpeedCPUAdam(self.parameters(), lr=1e-3)
        return optimizer


dataloader = create_training_dataloader(examples=10_000, features=16, sequence_length=32)
model = GRUNetwork(features=16, output_size=1, hidden_size=16)
trainer = pl.Trainer(
    gpus=1,
    max_epochs=5,
    precision=16,
    checkpoint_callback=False,
    logger=False,
    plugins="deepspeed_stage_3_offload",
)
trainer.fit(model, dataloader)
This gives the following:
> python debug_deepspeed.py
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
initializing deepspeed distributed: GLOBAL_RANK: 0, MEMBER: 1/1
Enabling DeepSpeed FP16.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]
You have not specified an optimizer or scheduler within the DeepSpeed config. Using `configure_optimizers` to define optimizer and scheduler.
Using /home/username/.cache/torch_extensions as PyTorch extensions root...
/home/username/.virtualenvs/deepspeed_env/lib64/python3.6/site-packages/torch/utils/cpp_extension.py:287: UserWarning:
!! WARNING !!
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Your compiler (c++) is not compatible with the compiler Pytorch was
built with for this platform, which is g++ on linux. Please
use g++ to to compile your extension. Alternatively, you may
compile PyTorch from source using c++, and then you can also use
c++ to compile your extension.
See https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md for help
with compiling PyTorch from source.
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
!! WARNING !!
platform=sys.platform))
Detected CUDA files, patching ldflags
Emitting ninja build file /home/username/.cache/torch_extensions/cpu_adam/build.ninja...
Building extension module cpu_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module cpu_adam...
Time to load cpu_adam op: 0.42862367630004883 seconds
Traceback (most recent call last):
File "debug_deepspeed.py", line 68, in <module>
trainer.fit(model, dataloader)
File "/home/username/.virtualenvs/deepspeed_env/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 552, in fit
self._run(model)
File "/home/username/.virtualenvs/deepspeed_env/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 911, in _run
self._pre_dispatch()
File "/home/username/.virtualenvs/deepspeed_env/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 939, in _pre_dispatch
self.accelerator.pre_dispatch(self)
File "/home/username/.virtualenvs/deepspeed_env/lib/python3.6/site-packages/pytorch_lightning/accelerators/accelerator.py", line 104, in pre_dispatch
self.training_type_plugin.pre_dispatch()
File "/home/username/.virtualenvs/deepspeed_env/lib/python3.6/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 369, in pre_dispatch
self.init_deepspeed()
File "/home/username/.virtualenvs/deepspeed_env/lib/python3.6/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 386, in init_deepspeed
self._initialize_deepspeed_train(model)
File "/home/username/.virtualenvs/deepspeed_env/lib/python3.6/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 427, in _initialize_deepspeed_train
dist_init_required=False,
File "/home/username/.virtualenvs/deepspeed_env/lib/python3.6/site-packages/deepspeed/__init__.py", line 141, in initialize
config_params=config_params)
File "/home/username/.virtualenvs/deepspeed_env/lib/python3.6/site-packages/deepspeed/runtime/engine.py", line 182, in __init__
self._configure_distributed_model(model)
File "/home/username/.virtualenvs/deepspeed_env/lib/python3.6/site-packages/deepspeed/runtime/engine.py", line 730, in _configure_distributed_model
self.module.to(self.device)
File "/home/username/.virtualenvs/deepspeed_env/lib/python3.6/site-packages/pytorch_lightning/core/mixins/device_dtype_mixin.py", line 109, in to
return super().to(*args, **kwargs)
File "/home/username/.virtualenvs/deepspeed_env/lib64/python3.6/site-packages/torch/nn/modules/module.py", line 852, in to
return self._apply(convert)
File "/home/username/.virtualenvs/deepspeed_env/lib64/python3.6/site-packages/torch/nn/modules/module.py", line 530, in _apply
module._apply(fn)
File "/home/username/.virtualenvs/deepspeed_env/lib64/python3.6/site-packages/torch/nn/modules/module.py", line 530, in _apply
module._apply(fn)
File "/home/username/.virtualenvs/deepspeed_env/lib64/python3.6/site-packages/torch/nn/modules/rnn.py", line 189, in _apply
self.flatten_parameters()
File "/home/username/.virtualenvs/deepspeed_env/lib64/python3.6/site-packages/torch/nn/modules/rnn.py", line 179, in flatten_parameters
self.batch_first, bool(self.bidirectional))
RuntimeError: shape '[768, 1]' is invalid for input of size 1
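The failing reshape is consistent with the GRU's weight shapes: with features=16 and hidden_size=16, weight_ih_l0 has 3 * 16 * 16 = 768 elements, but under ZeRO stage 3 the parameters are partitioned across ranks, so flatten_parameters() (invoked during .to(device)) apparently sees only a small placeholder tensor. A minimal sketch of the suspected failure mode (an assumption based on the traceback, not the actual DeepSpeed internals):

import torch
from torch import nn

gru = nn.GRU(16, 16, num_layers=1, batch_first=True)
print(gru.weight_ih_l0.numel())  # 768, matching the shape in the error

# Hypothetical stand-in for a ZeRO stage 3 partitioned parameter, which
# appears locally as a size-1 placeholder:
placeholder = torch.zeros(1)
placeholder.view(768, 1)  # RuntimeError: shape '[768, 1]' is invalid for input of size 1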
Expected behavior
The model should train as normal (as it does in earlier Lightning versions).
(For this model, DeepSpeed is not expected to reduce GPU memory usage any further than regular PyTorch activation checkpointing would.)
Environment
* CUDA:
- GPU:
- Tesla V100-SXM2-32GB
- Tesla V100-SXM2-32GB
- available: True
- version: 11.1
* Packages:
- numpy: 1.19.5
- pyTorch_debug: False
- pyTorch_version: 1.9.0+cu111
- pytorch-lightning: 1.4.5
- tqdm: 4.62.2
* System:
- OS: Linux
- architecture: 64bit
- processor: x86_64
- python: 3.6.8
- version: #1 SMP Thu Apr 8 19:01:30 UTC 2021
Additional context
deepspeed==0.5.1
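A possible workaround (untested here, and an assumption based on the traceback pointing at stage 3 parameter partitioning interacting with nn.RNNBase.flatten_parameters()) would be to fall back to ZeRO stage 2 offload, which partitions optimizer state and gradients but leaves parameters intact; the plugin alias is assumed to be registered alongside deepspeed_stage_3_offload:

trainer = pl.Trainer(
    gpus=1,
    max_epochs=5,
    precision=16,
    checkpoint_callback=False,
    logger=False,
    plugins="deepspeed_stage_2_offload",  # assumed workaround: stage 2 does not partition parameters
)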