Removed cost wrapper, added reduce wrapper
- Removed the `cost` wrapper, as it no longer had much of a logical use (see #32). Costs are now registered explicitly with `storch.add_cost`, as in the sketch below.
- Replaced `DeterministicTensor` with `CostTensor`, as its only remaining purpose was to accommodate positive `is_cost` checks.
- Added an experimental `reduce` wrapper that can reduce a batched dimension without raising an error. It might be removed, as the use case did not actually need it.
- Reworked `storch.nn.b_binary_cross_entropy` to better accommodate the current API.
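
A minimal sketch of the resulting usage pattern, pieced together from the updated `examples/discrete_vae.py` hunks in this commit. `model` stands in for the VAE of that example and is assumed to register its own KL cost internally; this is an illustration, not a reference implementation.

```python
import storch
from storch import backward

def loss_function(recon_x, x):
    # The old @cost decorator is gone; this now just returns a storch Tensor.
    return storch.nn.b_binary_cross_entropy(recon_x, x, reduction="sum")

def train_step(model, optimizer, data):
    optimizer.zero_grad()
    storch.reset()
    # Denote the minibatch dimension as independent (the "data" plate).
    data = storch.denote_independent(data.view(-1, 784), 0, "data")
    recon_batch, kld, z = model(data)
    # Costs are registered explicitly instead of via the removed @cost wrapper.
    storch.add_cost(loss_function(recon_batch, data), "reconstruction")
    cost, loss = backward()  # sums all registered costs; returns (cost, loss)
    optimizer.step()
    return cost, loss
```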
HEmile committed Mar 5, 2020
1 parent 7a15c8e commit d17c2fb
Showing 11 changed files with 163 additions and 146 deletions.
34 changes: 13 additions & 21 deletions examples/discrete_vae.py
@@ -10,7 +10,7 @@
from torch import nn, optim
from torch.nn import functional as F
from torchvision.utils import save_image
from storch import deterministic, cost, backward
from storch import deterministic, backward
import storch
from torch.distributions import OneHotCategorical, RelaxedOneHotCategorical
from examples.dataloader.data_loader import data_loaders
@@ -81,18 +81,17 @@ def decode(self, z):
h4 = self.activation(self.fc5(h3))
return self.fc6(h4).sigmoid()

@cost
def KLD(self, p, q):
# see Appendix B from VAE paper:
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
# https://arxiv.org/abs/1312.6114

div = torch.distributions.kl_divergence(p, q)
return div.sum(-1)
kld = torch.distributions.kl_divergence(p, q).sum(-1)
storch.add_cost(kld, "KL-divergence")
return kld

def forward(self, x):
logits = self.encode(x.view(-1, 784))
logits = storch.denote_independent(logits, 0, "data") # Denote the minibatch dimension as being independent
logits = self.encode(x)
logits = logits.reshape(logits.shape[:-1] + (self.latents, 10))

q = OneHotCategorical(logits=logits)
@@ -108,16 +107,9 @@ def forward(self, x):


# Reconstruction + KL divergence losses summed over all elements and batch
@cost
def loss_function(recon_x, x):
x = x.view(-1, 784)
# TODO: Not going to work: These are now two different data batch links.
# There should be another way to link batch dimensions manually. Possibly requires some different object as key.
x = storch.denote_independent(x, 0, "data")
# print(recon_x, x)
BCE = storch.nn.b_binary_cross_entropy(recon_x, x, reduction="sum")

return BCE
bce = storch.nn.b_binary_cross_entropy(recon_x, x, reduction="sum")
return bce


def train(epoch):
@@ -128,8 +120,10 @@ def train(epoch):
optimizer.zero_grad()
storch.reset()

# Denote the minibatch dimension as being independent
data = storch.denote_independent(data.view(-1, 784), 0, "data")
recon_batch, KLD, z = model(data)
loss_function(recon_batch, data)
storch.add_cost(loss_function(recon_batch, data), "reconstruction")
cond_log = batch_idx % args.log_interval == 0
cost, loss = backward()
train_loss += loss.item()
@@ -150,16 +144,14 @@ def _var(t):
squared_diff = (m - mean)**2
sse = squared_diff.sum(0)
return sse.mean()
avg_loss = loss / len(data)
avg_cost = cost.item() / len(data)
variance = _var(grads_logits)
step = 100. * batch_idx / len(train_loader)
global_step = 100 * (epoch - 1) + step
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tCost: {:.6f}\t Logits var {:.4E}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
step, avg_loss, avg_cost, variance))
writer.add_scalar("train/ELBO", avg_cost, global_step)
writer.add_scalar("train/loss", avg_loss, global_step)
step, loss, cost, variance))
writer.add_scalar("train/ELBO", cost, global_step)
writer.add_scalar("train/loss", loss, global_step)
writer.add_scalar("train/variance", variance, global_step)
avg_train_loss = train_loss / len(train_loader.dataset)
print('====> Epoch: {} Average loss: {:.4f}'.format(
7 changes: 4 additions & 3 deletions storch/__init__.py
@@ -1,13 +1,14 @@
from typing import Callable

from .wrappers import deterministic, stochastic, cost, _exception_wrapper, _unpack_wrapper
from .tensor import Tensor, DeterministicTensor, StochasticTensor
from .wrappers import deterministic, stochastic, reduce, _exception_wrapper, _unpack_wrapper
from .tensor import Tensor, CostTensor, StochasticTensor
from .method import *
from .inference import backward, add_cost, reset, denote_independent
from .util import print_graph
from .storch import *
import storch.typing
_debug = True
import storch.nn
_debug = False

from inspect import isclass
from .excluded_init import _excluded_init, _exception_init, _unwrap_only
21 changes: 10 additions & 11 deletions storch/inference.py
@@ -1,4 +1,4 @@
from storch.tensor import Tensor, StochasticTensor, DeterministicTensor, IndependentTensor
from storch.tensor import Tensor, StochasticTensor, CostTensor, IndependentTensor
import torch
from storch.util import print_graph, reduce_mean
import storch
@@ -8,9 +8,8 @@
from functools import reduce
from typing import Dict, Optional

_cost_tensors: [DeterministicTensor] = []
_cost_tensors: [CostTensor] = []
_backward_indices: Dict[StochasticTensor, int] = {}
_backward_cost: Optional[DeterministicTensor] = None
_accum_grad: bool = False


@@ -23,6 +22,8 @@ def denote_independent(tensor: AnyTensor, dim: int, plate_name: str) -> Independ
:param plate_name: Name of the plate. Reused if called again
:return:
"""
if storch.wrappers._context_stochastic or storch.wrappers._context_deterministic > 0:
raise RuntimeError("Cannot create independent tensors within a deterministic or stochastic context.")
if isinstance(tensor, torch.Tensor):
if dim != 0:
tensor = tensor.transpose(dim, 0)
@@ -40,10 +41,10 @@ def add_cost(cost: Tensor, name: str):
raise ValueError("Can only register cost functions with empty event shapes")
if not name:
raise ValueError("No name provided to register cost node. Make sure to register an unique name with the cost.")
cost.name = name
cost._is_cost = True
cost = CostTensor(cost._tensor, [cost], cost.batch_links, name)
if torch.is_grad_enabled():
storch.inference._cost_tensors.append(cost)
return cost


def _keep_grads_backwards(surrounding_node: Tensor, backwards_tensor: torch.Tensor) -> torch.Tensor:
@@ -89,16 +90,16 @@ def backward(retain_graph=False, debug=False, print_costs=False, accum_grads=Fal
if print_costs:
print(c.name, ":", avg_cost.item())
total_cost += avg_cost
storch.inference._backward_cost = c
for parent in c.walk_parents(depth_first=False):
if parent.stochastic:
# Instance check here instead of parent.stochastic, as backward methods are only used on these.
if isinstance(parent, StochasticTensor):
stochastic_nodes.add(parent)
if not parent.stochastic or not parent.requires_grad:
if not isinstance(parent, StochasticTensor) or not parent.requires_grad:
continue

# Sum out over the plate dimensions of the parent, so that the shape is the same as the parent but the event shape
mean_cost = c._tensor
c_indices = c.multi_dim_plates()
c_indices = list(c.multi_dim_plates())
for index_p, plate in enumerate(parent.multi_dim_plates()):
index_c = c_indices.index(plate)
if not index_c == index_p:
@@ -128,8 +129,6 @@ def backward(retain_graph=False, debug=False, print_costs=False, accum_grads=Fal
accum_loss += avg_cost
total_loss += avg_cost

storch.inference._backward_cost = None

if isinstance(accum_loss, torch.Tensor) and accum_loss.requires_grad:
accum_loss.backward()

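
For context on the new guard in `denote_independent` above: it now raises if called inside a stochastic or deterministic context, so the plate has to be declared on the raw input before it enters wrapped code. A small sketch under that assumption (the shapes are placeholders):

```python
import torch
import storch

x = torch.rand(64, 784)
# Fine: called at the top level, outside any wrapped function.
x = storch.denote_independent(x, 0, "data")

@storch.deterministic
def encode(h):
    # Calling storch.denote_independent in here would now raise a RuntimeError,
    # since we are inside a deterministic context.
    return h * 2.
```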
8 changes: 4 additions & 4 deletions storch/method/baseline.py
@@ -1,13 +1,13 @@
from abc import ABC, abstractmethod
import torch
from storch.tensor import StochasticTensor, DeterministicTensor
from storch.tensor import StochasticTensor, CostTensor

class Baseline(ABC, torch.nn.Module):
def __init__(self):
super().__init__()

@abstractmethod
def compute_baseline(self, tensor: StochasticTensor, cost_node: DeterministicTensor,
def compute_baseline(self, tensor: StochasticTensor, cost_node: CostTensor,
costs: torch.Tensor) -> torch.Tensor:
pass

@@ -18,15 +18,15 @@ def __init__(self, exponential_decay=0.95):
self.register_buffer("exponential_decay", torch.tensor(exponential_decay))
self.register_buffer("moving_average", torch.tensor(0.))

def compute_baseline(self, tensor: StochasticTensor, cost_node: DeterministicTensor, costs: torch.Tensor) -> torch.Tensor:
def compute_baseline(self, tensor: StochasticTensor, cost_node: CostTensor, costs: torch.Tensor) -> torch.Tensor:
avg_cost = costs.mean().detach()
self.moving_average = self.exponential_decay * self.moving_average + (1 - self.exponential_decay) * avg_cost
return self.moving_average


class BatchAverageBaseline(Baseline):
# Uses the means of the other samples
def compute_baseline(self, tensor: StochasticTensor, cost_node: DeterministicTensor,
def compute_baseline(self, tensor: StochasticTensor, cost_node: CostTensor,
costs: torch.Tensor) -> torch.Tensor:
if tensor.n == 1:
raise ValueError("Can only use the batch average baseline if multiple samples are used.")
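
A hypothetical custom baseline against the renamed interface above, just to show the `CostTensor` type in the signature; the median choice is illustrative, not part of the library:

```python
import torch
from storch.method.baseline import Baseline
from storch.tensor import StochasticTensor, CostTensor

class MedianBaseline(Baseline):
    # Uses the median of the sampled costs rather than their mean.
    def compute_baseline(self, tensor: StochasticTensor, cost_node: CostTensor,
                         costs: torch.Tensor) -> torch.Tensor:
        return costs.median().detach()
```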
26 changes: 13 additions & 13 deletions storch/method/method.py
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from torch.distributions import Distribution, Categorical, OneHotCategorical, Bernoulli, RelaxedOneHotCategorical, RelaxedBernoulli
from storch.tensor import DeterministicTensor, StochasticTensor
from storch.tensor import CostTensor, StochasticTensor
import torch
from typing import Optional, Type, Union, Dict
from storch.util import has_differentiable_path, get_distr_parameters
@@ -83,7 +83,7 @@ def sample(self, sample_name: str, distr: Distribution, n: int = 1) -> Stochasti

return s_tensor

def _estimator(self, tensor: StochasticTensor, cost_node: DeterministicTensor, costs: torch.Tensor) -> Optional[torch.Tensor]:
def _estimator(self, tensor: StochasticTensor, cost_node: CostTensor, costs: torch.Tensor) -> Optional[torch.Tensor]:
# For docs: costs here is aligned with the StochasticTensor, that's why there's two different things.
self._estimation_triples.append((tensor, cost_node, costs))
return self.estimator(tensor, cost_node, costs)
@@ -100,10 +100,10 @@ def _sample_tensor(self, distr: Distribution, n: int) -> torch.Tensor:
@abstractmethod
# Estimators should optionally return a torch.Tensor that is going to be added to the total cost function
# In the case of for example reparameterization, None can be returned to denote that no cost function is added
def estimator(self, tensor: StochasticTensor, cost_node: DeterministicTensor, costs: torch.Tensor) -> Optional[torch.Tensor]:
def estimator(self, tensor: StochasticTensor, cost_node: CostTensor, costs: torch.Tensor) -> Optional[torch.Tensor]:
pass

def update_parameters(self, result_triples: [(StochasticTensor, DeterministicTensor, torch.Tensor)]) -> None:
def update_parameters(self, result_triples: [(StochasticTensor, CostTensor, torch.Tensor)]) -> None:
pass


@@ -124,10 +124,10 @@ def __init__(self, distribution_type: Type[Distribution]):
def _sample_tensor(self, distr: Distribution, n: int) -> torch.Tensor:
return self._method._sample_tensor(distr, n)

def estimator(self, tensor: StochasticTensor, cost_node: DeterministicTensor, costs: torch.Tensor) -> Optional[torch.Tensor]:
def estimator(self, tensor: StochasticTensor, cost_node: CostTensor, costs: torch.Tensor) -> Optional[torch.Tensor]:
return self._method.estimator(tensor, cost_node, costs)

def update_parameters(self, result_triples: [(StochasticTensor, DeterministicTensor, torch.Tensor)]):
def update_parameters(self, result_triples: [(StochasticTensor, CostTensor, torch.Tensor)]):
self._method.update_parameters(result_triples)

class Reparameterization(Method):
@@ -147,15 +147,15 @@ def _sample_tensor(self, distr: Distribution, n: int) -> StochasticTensor:
"distribution, make sure to use eg GumbelSoftmax.")
return distr.rsample((n,))

def estimator(self, tensor: StochasticTensor, cost_node: DeterministicTensor, costs: torch.Tensor) -> Optional[torch.Tensor]:
def estimator(self, tensor: StochasticTensor, cost_node: CostTensor, costs: torch.Tensor) -> Optional[torch.Tensor]:
if has_differentiable_path(cost_node, tensor):
# There is a differentiable path, so we will just use reparameterization here.
return None
else:
# No automatic baselines
return self._score_method.estimator(tensor, cost_node, costs)

def update_parameters(self, result_triples: [(StochasticTensor, DeterministicTensor, torch.Tensor)]):
def update_parameters(self, result_triples: [(StochasticTensor, CostTensor, torch.Tensor)]):
self._score_method.update_parameters(result_triples)

class GumbelSoftmax(Method):
@@ -187,14 +187,14 @@ def _sample_tensor(self, distr: DiscreteDistribution, n: int) -> torch.Tensor:
raise ValueError("Using Gumbel Softmax with non-discrete distribution")
return gumbel_distr.rsample((n,))

def estimator(self, tensor: StochasticTensor, cost_node: DeterministicTensor, costs: torch.Tensor) -> Optional[torch.Tensor]:
def estimator(self, tensor: StochasticTensor, cost_node: CostTensor, costs: torch.Tensor) -> Optional[torch.Tensor]:
if has_differentiable_path(cost_node, tensor):
return None
else:
s = ScoreFunction()
return s.estimator(tensor, cost_node, costs)

def update_parameters(self, result_triples: [(StochasticTensor, DeterministicTensor, torch.Tensor)]):
def update_parameters(self, result_triples: [(StochasticTensor, CostTensor, torch.Tensor)]):
if self.training:
self.temperature = torch.max(self.min_temperature, torch.exp(-self.annealing_rate * self.iterations))

@@ -217,7 +217,7 @@ def __init__(self, baseline_factory: Optional[Union[BaselineFactory, str]] = "mo
def _sample_tensor(self, distr: Distribution, n: int) -> torch.Tensor:
return distr.sample((n, ))

def estimator(self, tensor: StochasticTensor, cost_node: DeterministicTensor, costs: torch.Tensor) -> torch.Tensor:
def estimator(self, tensor: StochasticTensor, cost_node: CostTensor, costs: torch.Tensor) -> torch.Tensor:
log_prob = tensor.distribution.log_prob(tensor._tensor)
# Sum out over the even shape
log_prob = log_prob.sum(dim=list(range(len(tensor.batch_links), len(log_prob.shape))))
@@ -243,7 +243,7 @@ def _sample_tensor(self, distr: Distribution, n: int) -> torch.Tensor:
# print(distr.batch_shape, distr.event_shape)
# print(support[9, 1])

def estimator(self, tensor: StochasticTensor, cost_node: DeterministicTensor, costs: torch.Tensor) -> Optional[torch.Tensor]:
def estimator(self, tensor: StochasticTensor, cost_node: CostTensor, costs: torch.Tensor) -> Optional[torch.Tensor]:
pass


@@ -254,5 +254,5 @@ def __init__(self):
def _sample_tensor(self, distr: Distribution, n: int) -> torch.Tensor:
pass

def estimator(self, tensor: StochasticTensor, cost_node: DeterministicTensor, costs: torch.Tensor) -> Optional[torch.Tensor]:
def estimator(self, tensor: StochasticTensor, cost_node: CostTensor, costs: torch.Tensor) -> Optional[torch.Tensor]:
pass
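
To make the `Method` contract above concrete, a stripped-down hypothetical subclass; the import path, the no-argument constructor, and the plate handling are assumptions based only on what is visible in this diff:

```python
import torch
from typing import Optional
from torch.distributions import Distribution
from storch.method.method import Method
from storch.tensor import StochasticTensor, CostTensor

class PlainScoreFunction(Method):
    # A score-function estimator without any baseline (cf. ScoreFunction above).
    def _sample_tensor(self, distr: Distribution, n: int) -> torch.Tensor:
        return distr.sample((n,))

    def estimator(self, tensor: StochasticTensor, cost_node: CostTensor,
                  costs: torch.Tensor) -> Optional[torch.Tensor]:
        log_prob = tensor.distribution.log_prob(tensor._tensor)
        # Sum the log-probability over the event dimensions, leaving the plate dims.
        log_prob = log_prob.sum(dim=list(range(len(tensor.batch_links), len(log_prob.shape))))
        return log_prob * costs.detach()
```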
34 changes: 15 additions & 19 deletions storch/nn/losses.py
@@ -1,11 +1,13 @@
from typing import Union, List

import storch
import torch
from storch import deterministic
from torch._C import _infer_size
from storch.typing import AnyTensor


def b_binary_cross_entropy(input: AnyTensor, target: torch.Tensor, weight=None, reduction: str = 'mean'):
@deterministic(unwrap=False)
def b_binary_cross_entropy(input: storch.Tensor, target: torch.Tensor, dims: Union[str, List[str]] = None, weight=None, reduction: str = 'mean'):
r"""Function that measures the Binary Cross Entropy in a batched way
between the target and the output.
@@ -39,35 +41,29 @@ def b_binary_cross_entropy(input: AnyTensor, target: torch.Tensor, weight=None,
>>> loss = b_binary_cross_entropy(F.sigmoid(input), target)
>>> loss.backward()
"""
if isinstance(input, storch.Tensor):
indices = input.event_dim_indices()
if target.size() != input.event_shape:
raise ValueError("Using a target size ({}) that is different to the input size ({}). "
"Please ensure they have the same size.".format(target.size(), input.event_shape))
else:
offset_dim = input.dim() - target.dim()
indices = list(range(input.dim() - target.dim(), input.dim()))
for i in range(target.dim()):
if input.shape[offset_dim + i] != target.shape[i]:
raise ValueError("Input and target are invalid for broadcasting.")


if weight is not None:
new_size = _infer_size(target.size(), weight.size())
weight = weight.expand(new_size)
else:
weight = 1.

if not dims:
dims = []
if isinstance(dims, str):
dims = [dims]

@deterministic
def _loss(input):
def _loss(input, target, weight):
epsilon = 1e-6
input = input + epsilon
return -weight * (target * input.log() + (1. - target) * (1. - input).log())

unreduced = _loss(input)
unreduced = _loss(input, target, weight)
indices = list(unreduced.event_dim_indices()) + dims

if reduction == "mean":
return unreduced.mean(dim=indices)
return storch.mean(unreduced, indices)
elif reduction == "sum":
return unreduced.sum(dim=indices)
return storch.sum(unreduced, indices)
elif reduction == "none":
return unreduced
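
A rough usage sketch for the reworked loss above, assuming elementwise arithmetic on storch Tensors behaves as in the example code elsewhere in this commit; the decoder stand-in and shapes are placeholders:

```python
import torch
import storch

# Both arguments are storch Tensors in the updated example: the data batch with
# an independent "data" plate, and a reconstruction in (0, 1).
x = storch.denote_independent(torch.rand(64, 784), 0, "data")
recon_x = x * 0.5 + 0.25  # stand-in for a decoder output

# "sum" reduces over the event dimensions (plate names can be added via dims=...).
bce = storch.nn.b_binary_cross_entropy(recon_x, x, reduction="sum")
storch.add_cost(bce, "reconstruction")
```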