[Bug] GlobalMutualInformationLoss: bin_centers not registered as buffer — silent gradient tracking + wrong device placement #8866

@Zeesejo

Describe the bug

In GlobalMutualInformationLoss.__init__ (monai/losses/image_dissimilarity.py, ~line 210), the bin_centers tensor is assigned as a plain Python attribute:

self.bin_centers = bin_centers[None, None, ...]  # shape (1, 1, num_bins)

This is the same pattern as the bug fixed in #8819 for LocalNormalizedCrossCorrelationLoss. Because bin_centers is not registered via self.register_buffer(...), it has two related problems:

Problem 1: Wrong device placement
When the user calls loss_fn.cuda() or loss_fn.to(device), PyTorch moves only the module's parameters and registered buffers. The plain self.bin_centers attribute stays on the CPU, so computing the loss on GPU inputs raises RuntimeError: Expected all tensors to be on the same device.
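
The difference can be sketched without a GPU, because .to(dtype) goes through the same module-level conversion as .to(device) and .cuda(). The two classes below are hypothetical stand-ins, not MONAI code; only the way bin_centers is stored differs between them.

```python
import torch
from torch import nn


class PlainAttr(nn.Module):
    """Hypothetical stand-in: stores the tensor as a plain Python attribute."""

    def __init__(self, num_bins: int = 4):
        super().__init__()
        self.bin_centers = torch.linspace(0.0, 1.0, num_bins)[None, None, ...]


class Buffered(nn.Module):
    """Hypothetical stand-in: registers the same tensor as a buffer."""

    def __init__(self, num_bins: int = 4):
        super().__init__()
        centers = torch.linspace(0.0, 1.0, num_bins)[None, None, ...]
        self.register_buffer("bin_centers", centers)


# Converting the module only touches parameters and registered buffers.
plain = PlainAttr().to(torch.float64)
buffered = Buffered().to(torch.float64)

print(plain.bin_centers.dtype)     # plain attribute is left as created
print(buffered.bin_centers.dtype)  # buffer follows the module
```

The same mechanism applies to device moves, which is why the plain attribute is left behind on the CPU.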

Problem 2: Silent gradient tracking
bin_centers is a constant and should never accumulate gradients. Because it is not registered as a buffer, the module does not track it at all: it is not guaranteed to have requires_grad=False across all code paths, and gradient-related checks that inspect the module's parameters and buffers will not see it.
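
A quick way to see what the module tracks is to list its buffers and parameters. This is a minimal sketch on a hypothetical Toy module (the names plain, registered, and weight are illustrative and not taken from MONAI).

```python
import torch
from torch import nn


class Toy(nn.Module):
    """Hypothetical module holding one tensor of each kind."""

    def __init__(self):
        super().__init__()
        self.plain = torch.linspace(0.0, 1.0, 4)                          # plain attribute
        self.register_buffer("registered", torch.linspace(0.0, 1.0, 4))   # buffer
        self.weight = nn.Parameter(torch.zeros(4))                        # trainable parameter


toy = Toy()
buffer_names = [name for name, _ in toy.named_buffers()]
param_names = [name for name, _ in toy.named_parameters()]

print(buffer_names)                  # only the registered buffer is listed
print(param_names)                   # only the parameter is listed
print(toy.registered.requires_grad)  # gradient flag of the buffer
```

The plain attribute appears in neither list, so any check that walks named_buffers() or named_parameters() cannot inspect it.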

To Reproduce

import torch
from monai.losses import GlobalMutualInformationLoss

loss_fn = GlobalMutualInformationLoss().cuda()
x = torch.rand(1, 1, 32, 32, 32, device="cuda")
y = torch.rand(1, 1, 32, 32, 32, device="cuda")
print(loss_fn(x, y))  # RuntimeError: device mismatch

Expected behavior

bin_centers should be registered as a buffer so it automatically moves with the module:

# Instead of:
self.bin_centers = bin_centers[None, None, ...]

# Use:
self.register_buffer("bin_centers", bin_centers[None, None, ...])

This mirrors the fix applied to LocalNormalizedCrossCorrelationLoss in PR #8818.
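
One side effect worth considering: a buffer registered with the default settings also becomes a state_dict() entry. If that is not wanted, register_buffer accepts persistent=False, which keeps the tensor out of state_dict() while still moving it with the module. The sketch below uses hypothetical stand-in classes and makes no claim about which variant the MONAI fix uses.

```python
import torch
from torch import nn


class PersistentBuffer(nn.Module):
    """Hypothetical stand-in: default (persistent) buffer."""

    def __init__(self):
        super().__init__()
        self.register_buffer("bin_centers", torch.linspace(0.0, 1.0, 4)[None, None, ...])


class NonPersistentBuffer(nn.Module):
    """Hypothetical stand-in: buffer excluded from state_dict()."""

    def __init__(self):
        super().__init__()
        self.register_buffer(
            "bin_centers", torch.linspace(0.0, 1.0, 4)[None, None, ...], persistent=False
        )


persistent_keys = list(PersistentBuffer().state_dict().keys())
non_persistent_keys = list(NonPersistentBuffer().state_dict().keys())

# A non-persistent buffer is still converted together with the module.
moved = NonPersistentBuffer().to(torch.float64)

print(persistent_keys)
print(non_persistent_keys)
print(moved.bin_centers.dtype)
```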

Environment

  • MONAI dev branch (confirmed in monai/losses/image_dissimilarity.py)
  • Affects all MONAI versions where GlobalMutualInformationLoss uses kernel_type="gaussian"

Additional context

This was noted as a potential sibling issue in #8819. I'm happy to open a PR with the fix if this is confirmed.
