In [21]:
import jax
from typing import Any, Callable, Sequence, Optional
from jax import lax, random, numpy as jnp
import flax
from flax import linen as nn
import sympy as sy
import matplotlib.pyplot as plt

In [2]:
key = random.PRNGKey(32)



In [3]:
random.uniform(key, (1024, 1))

DeviceArray([[0.57366943],
             [0.3568976 ],
             [0.6369716 ],
             ...,
             [0.7642505 ],
             [0.10045469],
             [0.311921  ]], dtype=float32)

In [63]:
limit_a, limit_b, epsilon = -.1, 1.1, 1e-6
temperature = 2./3.
qz_loga = random.normal(key, (5,))

def cdf_qz(x):
    """Implements the CDF of the 'stretched' concrete distribution"""
    xn = (x - limit_a) / (limit_b - limit_a)
    logits = jnp.log(xn) - jnp.log(1 - xn)
    return lax.clamp(epsilon,
                     nn.sigmoid(logits * temperature - qz_loga),
                     1.-epsilon)

def quantile_concrete(x):
    """Implements the quantile, aka inverse CDF, of the 'stretched' concrete distribution"""
    y = nn.sigmoid((jnp.log(x) - jnp.log(1 - x) + qz_loga) / temperature)
    return y * (limit_b - limit_a) + limit_a

def hard_tanh(x):
    return jnp.where(x > 1, 1, jnp.where(x < 0, 0, x))

def get_eps(key, shape):
    return random.uniform(key, shape, minval=epsilon, maxval=1.-epsilon)

In [64]:
def sample_z(key, shape, sample=True):
    if sample:
        eps = get_eps(key, shape)
        z = quantile_concrete(eps)
        return hard_tanh(z)
    else:
        pi = nn.sigmoid(qz_loga)
        return hard_tanh(pi * (limit_b - limit_a) + limit_a)

In [65]:
sample_z(key, (4,))

TypeError: add got incompatible shapes for broadcasting: (4,), (5,).

In [58]:
def sample_z(self, batch_size, sample=True):
    """Sample the hard-concrete gates for training and use a deterministic value for testing"""
    if sample:
        eps = self.get_eps(self.floatTensor(batch_size, self.dim_z))
        z = self.quantile_concrete(eps).view(batch_size, self.dim_z, 1, 1)
        return F.hardtanh(z, min_val=0, max_val=1)
    else:  # mode
        pi = F.sigmoid(self.qz_loga).view(1, self.dim_z, 1, 1)
        return F.hardtanh(pi * (limit_b - limit_a) + limit_a, min_val=0, max_val=1)

def sample_weights(self):
    z = self.quantile_concrete(self.get_eps(self.floatTensor(self.dim_z))).view(self.dim_z, 1, 1, 1)
    return F.hardtanh(z, min_val=0, max_val=1) * self.weights

def forward(self, input_):
    if self.input_shape is None:
        self.input_shape = input_.size()
    b = None if not self.use_bias else self.bias
    if self.local_rep or not self.training:
        output = F.conv2d(input_, self.weights, b, self.stride, self.padding, self.dilation, self.groups)
        z = self.sample_z(output.size(0), sample=self.training)
        return output.mul(z)
    else:
        weights = self.sample_weights()
        output = F.conv2d(input_, weights, None, self.stride, self.padding, self.dilation, self.groups)
        return output