
Logging tensors as hparam fails #9022

@ifsheldon

🐛 Bug

When a trainer starts fit() and logs the hyperparameters, an error occurs if one of the hyperparameters is a Tensor.

To Reproduce

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


class Model(pl.LightningModule):
    def __init__(self, some_tensor, num):
        super().__init__()
        self.save_hyperparameters()  # Not OK
        # self.save_hyperparameters(ignore="some_tensor")  # OK
        self.tensor0 = torch.nn.Parameter(some_tensor, requires_grad=False)
        self.tensor1 = torch.nn.Parameter(torch.ones(1))
        self.num = num

    def forward(self, batch):
        return self.tensor0.sum() + batch + self.tensor1

    def training_step(self, batch, batch_idx):
        input, target = batch
        pred = self(input)
        return F.mse_loss(pred, target)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


class RandomData(Dataset):
    def __getitem__(self, item):
        return torch.randn(1)

    def __len__(self):
        return 100


if __name__ == "__main__":
    dataset = RandomData()
    pl.seed_everything(1, workers=True)
    tensor = torch.randn(2, 3)
    model = Model(tensor, 1)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=2)
    trainer = pl.Trainer(gpus=[1, 2],
                         max_steps=10,
                         accelerator="ddp")
    trainer.fit(model, dataloader)

Expected behavior

If I ignore some_tensor, the trainer works fine, but if I don't, the trainer gives me an error and keeps one of my GPUs (GPU-2) running at 100%, which is very weird to me.

The error is:

Traceback (most recent call last):
  File "/home/liangf/research/SemanticNetMPL/temp/test.py", line 45, in <module>
    trainer.fit(model, dataloader)
  File "/home/liangf/miniconda3/envs/torch/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 553, in fit
    self._run(model)
  File "/home/liangf/miniconda3/envs/torch/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 912, in _run
    self._pre_dispatch()
  File "/home/liangf/miniconda3/envs/torch/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 941, in _pre_dispatch
    self._log_hyperparams()
  File "/home/liangf/miniconda3/envs/torch/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 968, in _log_hyperparams
    self.logger.log_hyperparams(hparams_initial)
  File "/home/liangf/miniconda3/envs/torch/lib/python3.9/site-packages/pytorch_lightning/utilities/distributed.py", line 48, in wrapped_fn
    return fn(*args, **kwargs)
  File "/home/liangf/miniconda3/envs/torch/lib/python3.9/site-packages/pytorch_lightning/loggers/tensorboard.py", line 196, in log_hyperparams
    exp, ssi, sei = hparams(params, metrics)
  File "/home/liangf/miniconda3/envs/torch/lib/python3.9/site-packages/torch/utils/tensorboard/summary.py", line 192, in hparams
    ssi.hparams[k].number_value = v
TypeError: array([0.66135216, 0.2669241 , 0.06167726], dtype=float32) has type numpy.ndarray, but expected one of: int, long, float

This seems to indicate that we cannot save tensors as hyperparameters, but interestingly, there is code that checks whether a hyperparameter is a Tensor and converts it to a NumPy array; see this:

# from torch/utils/tensorboard/summary.py
if isinstance(v, torch.Tensor):
    v = make_np(v)[0]
    ssi.hparams[k].number_value = v
    hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_FLOAT64")))
    continue
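
My reading of that snippet (an assumption on my part, I haven't traced the plugin further) is that make_np(v)[0] only strips the leading dimension: a 0-d or 1-d tensor would come out as a plain scalar, but the 2x3 tensor from the repro is still a length-3 NumPy array. A minimal sketch using plain .numpy() in place of the internal make_np helper:

import torch

v = torch.randn(2, 3)
# Indexing [0] only removes the first dimension, so the result is still an
# array rather than a Python number.
first = v.numpy()[0]
print(type(first), first.shape)  # <class 'numpy.ndarray'> (3,)
# Assigning such an ndarray to the protobuf number_value field is what raises
# the TypeError shown in the traceback above.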

As for the running-forever problem, I think it is because the rank-0 process dies from this error, but the spawned process does not, so it keeps running unattended.
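
For what it's worth, I believe the same TypeError can be reproduced outside Lightning (and without DDP) by passing a multi-dimensional tensor straight to TensorBoard's add_hparams, which goes through the same summary.hparams() path; a sketch, untested beyond my reading of the code above:

import torch
from torch.utils.tensorboard import SummaryWriter

with SummaryWriter("/tmp/hparams_tensor_repro") as writer:
    # Expected to fail with the same "has type numpy.ndarray, but expected
    # one of: int, long, float" error, since the 2x3 tensor survives
    # make_np(v)[0] as an ndarray.
    writer.add_hparams({"some_tensor": torch.randn(2, 3)}, {"dummy_metric": 0.0})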

Environment

  • CUDA:
    - GPU:
    - NVIDIA RTX A6000
    - NVIDIA RTX A6000
    - NVIDIA RTX A6000
    - NVIDIA RTX A6000
    - available: True
    - version: 11.1
  • Packages:
    - numpy: 1.20.1
    - pyTorch_debug: False
    - pyTorch_version: 1.9.0
    - pytorch-lightning: 1.4.1
    - tqdm: 4.61.2
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.9.2
- version: #90-Ubuntu SMP Fri Jul 9 22:49:44 UTC 2021

Additional context

I think an old issue may also have been caused by this problem, but no one has taken a closer look at it.

Labels

3rd party, bug, help wanted, logger
