Skip to content

Commit

Permalink
Warn user against using torch tensors as arguments of random variables (
Browse files Browse the repository at this point in the history
facebookresearch#1639)

Summary:
Pull Request resolved: facebookresearch#1639

Bean Machine uses the hash value of the arguments of random variables to identify them, which means that foo(1) and foo(tensor(1)) are considered two different random variables. Furthermore, PyTorch hashes tensors by memory address rather than by value, so two separately created tensors holding the same value can hash differently, i.e. hash(tensor(1)) != hash(tensor(1)). Therefore, it’s not recommended to use tensors as indices of random variables.
In this change,
1. In `rv_identifier.py` we detect when tensors are used as arguments to RVs and warn the user against their use.
2. Added a test case to `rv_identifier_test.py` to check if the warning is triggered correctly when the user provides a tensor instead of a primitive argument.

Differential Revision: D39169577

fbshipit-source-id: 487719d5fd66de2936e7e1f6d8885f6edb5498ac
  • Loading branch information
AishwaryaSivaraman authored and facebook-github-bot committed Aug 31, 2022
1 parent 163f54a commit 13afc25
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Expand Up @@ -24,6 +24,8 @@ filterwarnings = [
"default:pandas.Int64Index is deprecated *:FutureWarning",
# functorch 0.1.0 imports deprecated _stateless module
"default:The `torch.nn.utils._stateless` code is deprecated*:DeprecationWarning",
# BM warns against using torch tensors as arguments of random variables
"default:PyTorch tensors are hashed by memory address instead of value.*:UserWarning",
]

[tool.usort]
Expand Down
11 changes: 11 additions & 0 deletions src/beanmachine/ppl/model/rv_identifier.py
Expand Up @@ -3,9 +3,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import warnings
from dataclasses import dataclass
from typing import Callable, Tuple

import torch


@dataclass(eq=True, frozen=True)
class RVIdentifier:
Expand All @@ -17,6 +20,14 @@ class RVIdentifier:
wrapper: Callable
arguments: Tuple

def __post_init__(self):
for arg in self.arguments:
if torch.is_tensor(arg):
warnings.warn(
"PyTorch tensors are hashed by memory address instead of value. "
"Therefore, it is not recommended to use tensors as indices of random variables."
)

def __str__(self):
return str(self.function.__name__) + str(self.arguments)

Expand Down
23 changes: 23 additions & 0 deletions src/beanmachine/ppl/model/tests/rv_identifier_test.py
Expand Up @@ -5,6 +5,7 @@

import pickle
import unittest
import warnings

import beanmachine.ppl as bm
import torch
Expand Down Expand Up @@ -39,6 +40,28 @@ def foo(self):
def __eq__(self, other):
    """Compare equal to any instance of ``RVIdentifierTest.SampleModelWithEq``."""
    expected_type = RVIdentifierTest.SampleModelWithEq
    return isinstance(other, expected_type)

class SampleModelWithIndex:
    """Test model whose random variable ``foo`` takes an index argument ``u``."""

    @bm.random_variable
    def foo(self, u: int):
        """Return a Normal distribution with loc 0.0 and scale 1.0; ``u`` is unused in the body."""
        loc = torch.tensor(0.0)
        scale = torch.tensor(1.0)
        return dist.Normal(loc, scale)

def test_indexed_model_rv_identifier(self):
    """Check that a tensor index triggers the tensor-hash warning and a primitive index does not."""
    expected_message = (
        "PyTorch tensors are hashed by memory address instead of value. "
        "Therefore, it is not recommended to use tensors as indices of random variables."
    )
    model = self.SampleModelWithIndex()

    # Primitive index: warnings are escalated to errors, so any warning
    # raised by this call fails the test.
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        model.foo(1)

    # Tensor index: a UserWarning must be emitted.
    with self.assertWarns(UserWarning) as caught:
        model.foo(torch.tensor(1))
    # The message is checked after the block because assertWarns only
    # populates ``.warning`` when the context manager exits.
    self.assertEqual(expected_message, str(caught.warning))

def test_pickle_unbound_rv_identifier(self):
original_foo_key = foo()
foo_bytes = pickle.dumps(foo())
Expand Down

0 comments on commit 13afc25

Please sign in to comment.