Skip to content

Commit

Permalink
vi 4 abms
Browse files Browse the repository at this point in the history
  • Loading branch information
arnauqb committed Nov 14, 2023
1 parent 608eae7 commit 759a9f2
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
2 changes: 1 addition & 1 deletion blackbirds/infer/vi.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def loss_aux(params):
continue
loss += float(loss_i)
if type(jacobian) == torch.Tensor:
jacobians_per_rank.append(jacobian.cpu().numpy())
jacobians_per_rank.append(jacobian.detach().cpu().numpy())
else:
jacobians_per_rank.append(jacobian)
indices_per_rank.append(i)
Expand Down
25 changes: 25 additions & 0 deletions blackbirds/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import normflows as nf


def soft_maximum(a: torch.Tensor, b: torch.Tensor, k: float):
Expand All @@ -25,3 +26,27 @@ def soft_minimum(a: torch.Tensor, b: torch.Tensor, k: float):
- `k`: Hardness.
"""
return -soft_maximum(-a, -b, k)


class Sigmoid(nf.flows.Flow):
def __init__(self, min_values, max_values):
super().__init__()
self.min_values = min_values
self.max_values = max_values

def inverse(self, z):
logz = torch.log(z - self.min_values)
log1mz = torch.log(self.max_values - z)
z = logz - log1mz
sum_dims = list(range(1, z.dim()))
log_det = -torch.sum(logz, dim=sum_dims) - torch.sum(log1mz, dim=sum_dims)
return z, log_det

def forward(self, z):
sum_dims = list(range(1, z.dim()))
ls = torch.sum(torch.nn.functional.logsigmoid(z), dim=sum_dims)
mls = torch.sum(torch.nn.functional.logsigmoid(-z), dim=sum_dims)
lls = torch.sum(torch.log(self.max_values - self.min_values))
log_det = ls + mls + lls
z = self.min_values + (self.max_values - self.min_values) * torch.sigmoid(z)
return z, log_det

0 comments on commit 759a9f2

Please sign in to comment.