Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
alexhernandezgarcia committed May 7, 2024
1 parent 4267acb commit 7e9cc44
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 13 deletions.
21 changes: 9 additions & 12 deletions gflownet/proxy/corners.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,16 @@ def min(self):
return self._min

def __call__(self, states: TensorType["batch", "state_dim"]) -> TensorType["batch"]:
return (
self.mulnormal_norm
* torch.exp(
-0.5
* (
torch.diag(
return self.mulnormal_norm * torch.exp(
-0.5
* (
torch.diag(
torch.tensordot(
torch.tensordot(
torch.tensordot(
(torch.abs(states) - self.mu_vec), self.cov_inv, dims=1
),
(torch.abs(states) - self.mu_vec).T,
dims=1,
)
(torch.abs(states) - self.mu_vec), self.cov_inv, dims=1
),
(torch.abs(states) - self.mu_vec).T,
dims=1,
)
)
)
Expand Down
2 changes: 1 addition & 1 deletion gflownet/proxy/tetris.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def setup(self, env=None):
@property
def norm(self):
if self.normalize:
return (self.height * self.width)
return self.height * self.width
else:
return 1.0

Expand Down

0 comments on commit 7e9cc44

Please sign in to comment.