Labels: 3rd party (Related to a 3rd-party), bug (Something isn't working), help wanted (Open to be worked on), logger (Related to the Loggers)
Description
🐛 Bug
When the trainer starts fit() and logs the hyperparameters, an error occurs if one of the hyperparameters is a Tensor.
To Reproduce
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


class Model(pl.LightningModule):
    def __init__(self, some_tensor, num):
        super().__init__()
        self.save_hyperparameters()  # Not OK
        # self.save_hyperparameters(ignore="some_tensor")  # OK
        self.tensor0 = torch.nn.Parameter(some_tensor, requires_grad=False)
        self.tensor1 = torch.nn.Parameter(torch.ones(1))
        self.num = num

    def forward(self, batch):
        return self.tensor0.sum() + batch + self.tensor1

    def training_step(self, batch, batch_idx):
        input, target = batch
        pred = self(input)
        return F.mse_loss(pred, target)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


class RandomData(Dataset):
    def __getitem__(self, item):
        return torch.randn(1)

    def __len__(self):
        return 100


if __name__ == "__main__":
    dataset = RandomData()
    pl.seed_everything(1, workers=True)
    tensor = torch.randn(2, 3)
    model = Model(tensor, 1)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=2)
    trainer = pl.Trainer(gpus=[1, 2],
                         max_steps=10,
                         accelerator="ddp")
    trainer.fit(model, dataloader)

Expected behavior
If I ignore some_tensor, the trainer works fine; if I don't, the trainer raises an error and leaves one of my GPUs (GPU 2) running at 100%, which is very weird to me.
The full traceback from the Trainer run is:
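For what it's worth, the error seems reproducible without the Trainer at all. A minimal sketch, assuming PyTorch 1.9, where the TensorBoardLogger ultimately delegates to torch.utils.tensorboard.summary.hparams (see the traceback below):

import torch
from torch.utils.tensorboard.summary import hparams

# Passing a multi-element tensor as a hyperparameter value raises the
# same TypeError as trainer.fit() does in the traceback below.
exp, ssi, sei = hparams({"some_tensor": torch.randn(2, 3)}, {})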
Traceback (most recent call last):
File "/home/liangf/research/SemanticNetMPL/temp/test.py", line 45, in <module>
trainer.fit(model, dataloader)
File "/home/liangf/miniconda3/envs/torch/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 553, in fit
self._run(model)
File "/home/liangf/miniconda3/envs/torch/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 912, in _run
self._pre_dispatch()
File "/home/liangf/miniconda3/envs/torch/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 941, in _pre_dispatch
self._log_hyperparams()
File "/home/liangf/miniconda3/envs/torch/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 968, in _log_hyperparams
self.logger.log_hyperparams(hparams_initial)
File "/home/liangf/miniconda3/envs/torch/lib/python3.9/site-packages/pytorch_lightning/utilities/distributed.py", line 48, in wrapped_fn
return fn(*args, **kwargs)
File "/home/liangf/miniconda3/envs/torch/lib/python3.9/site-packages/pytorch_lightning/loggers/tensorboard.py", line 196, in log_hyperparams
exp, ssi, sei = hparams(params, metrics)
File "/home/liangf/miniconda3/envs/torch/lib/python3.9/site-packages/torch/utils/tensorboard/summary.py", line 192, in hparams
ssi.hparams[k].number_value = v
TypeError: array([0.66135216, 0.2669241 , 0.06167726], dtype=float32) has type numpy.ndarray, but expected one of: int, long, float

This seems to indicate that tensors cannot be saved as hyperparameters. Interestingly, though, there is code that checks whether a hyperparameter is a Tensor and converts it to a numpy array, see this:
if isinstance(v, torch.Tensor):
    v = make_np(v)[0]
    ssi.hparams[k].number_value = v  # torch/lib/python3.9/site-packages/torch/utils/tensorboard/summary.py
    hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_FLOAT64")))
    continue

Note that make_np(v)[0] only yields a scalar for single-element tensors; for the 2x3 tensor above, it is the first row, still a numpy.ndarray, which the protobuf number_value field rejects (see the check after this paragraph).

As for the running-forever problem, I think it is because the process at rank 0 dies from this error, while the spawned process does not and it runs wild.
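A quick check of that conversion, as a minimal sketch assuming PyTorch 1.9 (where make_np lives in torch.utils.tensorboard._convert_np):

import torch
from torch.utils.tensorboard._convert_np import make_np

# Single-element tensor: make_np(v)[0] is a numpy scalar, which the
# protobuf number_value field accepts.
print(type(make_np(torch.tensor([1.0]))[0]))  # <class 'numpy.float32'>

# Multi-element tensor: make_np(v)[0] is the first row, still an ndarray,
# which triggers the TypeError in the traceback above.
print(type(make_np(torch.randn(2, 3))[0]))    # <class 'numpy.ndarray'>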
Environment
- CUDA:
  - GPU:
    - NVIDIA RTX A6000
    - NVIDIA RTX A6000
    - NVIDIA RTX A6000
    - NVIDIA RTX A6000
  - available: True
  - version: 11.1
- Packages:
  - numpy: 1.20.1
  - pyTorch_debug: False
  - pyTorch_version: 1.9.0
  - pytorch-lightning: 1.4.1
  - tqdm: 4.61.2
- System:
  - OS: Linux
  - architecture:
    - 64bit
    - ELF
  - processor: x86_64
  - python: 3.9.2
  - version: #90-Ubuntu SMP Fri Jul 9 22:49:44 UTC 2021
Additional context
I think an old issue may also be caused by this problem, but no one has taken a closer look at it.
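Until there is an upstream fix, one possible workaround, sketched below; sanitize_hparams is a hypothetical helper, not part of Lightning. It flattens tensor hyperparameters to plain scalars or strings before they reach the logger, so the TensorBoard hparams plugin only ever sees types it can serialize.

import torch

def sanitize_hparams(hparams: dict) -> dict:
    # Hypothetical helper: map tensor values to types the TensorBoard
    # hparams plugin accepts (int/float/str/bool). Single-element tensors
    # become Python scalars; larger tensors are stringified.
    out = {}
    for k, v in hparams.items():
        if isinstance(v, torch.Tensor):
            out[k] = v.item() if v.numel() == 1 else str(v.tolist())
        else:
            out[k] = v
    return out

In the repro above, the same effect is achieved more bluntly with self.save_hyperparameters(ignore="some_tensor").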