In [1]:
import jax
from typing import Any, Callable, Sequence, Optional
from jax import lax, random, numpy as jnp
import flax
from flax import linen as nn
import sympy as sy
import matplotlib.pyplot as plt

In [3]:
import sys
sys.path.append("..")
from eql.l0_dense import L0Dense

In [62]:
l = L0Dense(3,drop_rate=0.99)

In [63]:
key = random.PRNGKey(0)
key, k1, k2 = random.split(key, 3)

In [64]:
x = random.normal(key, (3, 5))

In [65]:
params = l.init({'l0':k1, 'params': k2}, x)

In [66]:
k2, k3 = random.split(k2)
l.apply(params, x, deterministic=False, rngs={'l0': k2})

DeviceArray([[0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.]], dtype=float32)

In [67]:
l.apply(params, rngs={'l0': k2}, method=l.l0_reg)

DeviceArray(0.714367, dtype=float32)

In [60]:
limit_a, limit_b, epsilon = -.1, 1.1, 1e-6


def init_qz_loga(drop_rate, stddev=1e-2, dtype: Any = jnp.float_) -> Callable:
        mean  = jnp.log(1-drop_rate) - jnp.log(drop_rate)
        def init(key, shape, dtype=jnp.float_):
            return random.normal(key, shape, dtype)*stddev + mean
        return init
    
def hard_tanh(x):
    return jnp.where(x > 1, 1, jnp.where(x < 0, 0, x))

class L0Dense(nn.Module):
    features: int
    drop_rate: float = 0.5
    temperature: float = 2./3.
    kernel_init: Callable = nn.initializers.lecun_normal()
    bias_init: Callable = nn.initializers.zeros
        
    def quantile_concrete(self, x, qz_loga):
        """Implements the quantile, aka inverse CDF, of the 'stretched' concrete distribution"""
        y = nn.sigmoid((jnp.log(x) - jnp.log(1 - x) + qz_loga) / self.temperature)
        return y * (limit_b - limit_a) + limit_a
    
    
    def cdf_qz(self, x):
        qz_loga = self.variables['params']['qz_loga']
        xn = (x - limit_a) / (limit_b - limit_a)
        logits = jnp.log(xn) - jnp.log(1 - xn)
        return lax.clamp(epsilon,
            nn.sigmoid(logits * self.temperature - qz_loga),
            1.-epsilon)
    
    def sample_mask(self, qz_loga, rng):
        shape = qz_loga.shape
        eps = random.uniform(rng, shape, minval=epsilon, maxval=1.-epsilon)
        z = self.quantile_concrete(eps, qz_loga)
        mask = hard_tanh(z)
        return mask
    
    def l0_reg(self):
        return  jnp.sum((1 - self.cdf_qz(0)))
    
    @nn.compact
    def __call__(self, inputs, deterministic: Optional[bool] = False):
        kernel = self.param('kernel',
                        self.kernel_init, # Initialization function
                        (inputs.shape[-1], self.features))  # shape info.
        
        qz_loga = self.param('qz_loga',
                             init_qz_loga(self.drop_rate),
                             (inputs.shape[-1],))
        
        if deterministic:
            pi = nn.sigmoid(qz_loga)
            z = hard_tanh(pi * (limit_b - limit_a) + limit_a)
            mask = jnp.broadcast_to(z[:,None], kernel.shape)
            kernel = kernel*mask
        else:
            rng = self.make_rng('l0')
            mask = self.sample_mask(qz_loga, rng)
            mask = jnp.broadcast_to(mask[:,None], kernel.shape)
            kernel = kernel*mask
            
        y = lax.dot_general(inputs, kernel,
                        (((inputs.ndim - 1,), (0,)), ((), ())),) # TODO Why not jnp.dot?
        
        bias = self.param('bias', self.bias_init, (self.features,))
        
        y = y + bias
        return y

key1, key2, key3 = random.split(random.PRNGKey(0), 3)
x = random.uniform(key1, (10,))

model = L0Dense(features=1, drop_rate=0.9)
params = model.init({'l0': key2, 'params': key3}, x)
y = model.apply(params, x, deterministic=False, rngs={'l0': key3})

#print('initialized parameters:\n', params)
print('output:\n', y)

output:
 [0.12372394]


In [63]:
model.apply(params, rngs={'l0': key3}, method=model.l0_reg)

DeviceArray(3.5409908, dtype=float32)

In [23]:
rrr, _ = random.split(rrr, 2)
y = model.apply(params, x, deterministic=False, rngs={'l0': rrr})
print(y)

[[0.        ]
 [0.        ]
 [0.        ]
 [0.7264574 ]
 [0.9645642 ]
 [0.45473173]
 [0.        ]
 [0.        ]
 [0.        ]
 [0.        ]]
[[ 0.17365593]
 [-0.04896441]
 [ 0.08915189]
 [ 0.07911363]]


In [203]:
key = random.PRNGKey(32)

In [115]:
e=random.normal(key1, (2,))

In [7]:
random.uniform(key, (1024, 1))

DeviceArray([[0.57366943],
             [0.3568976 ],
             [0.6369716 ],
             ...,
             [0.7642505 ],
             [0.10045469],
             [0.311921  ]], dtype=float32)

In [4]:
temperature = 2./3.
qz_loga = random.normal(key, (5,))

def cdf_qz(x):
    """Implements the CDF of the 'stretched' concrete distribution"""
    xn = (x - limit_a) / (limit_b - limit_a)
    logits = jnp.log(xn) - jnp.log(1 - xn)
    return lax.clamp(epsilon,
                     nn.sigmoid(logits * temperature - qz_loga),
                     1.-epsilon)

def quantile_concrete(x):
    """Implements the quantile, aka inverse CDF, of the 'stretched' concrete distribution"""
    y = nn.sigmoid((jnp.log(x) - jnp.log(1 - x) + qz_loga) / temperature)
    return y * (limit_b - limit_a) + limit_a



def get_eps(key, shape):
    return random.uniform(key, shape, minval=epsilon, maxval=1.-epsilon)

NameError: name 'key' is not defined

In [9]:
def sample_z(key, shape, sample=True):
    if sample:
        eps = get_eps(key, shape)
        z = quantile_concrete(eps)
        return hard_tanh(z)
    else:
        pi = nn.sigmoid(qz_loga)
        return hard_tanh(pi * (limit_b - limit_a) + limit_a)
    
def sample_weights(key, weights):
    z = quantile_concrete(get_eps(key, (5,)))
    print(z)
    return hard_tanh(z).T*weights

In [10]:
weights = random.normal(key, (5, 5))
weights

DeviceArray([[ 0.0735873 ,  0.4324186 , -0.8828158 ,  0.14094852,
               0.373134  ],
             [-0.33430755,  0.18413313, -0.66062707,  1.2581732 ,
               0.8564375 ],
             [-1.4437462 , -0.3401676 ,  0.4056625 , -0.38791317,
              -0.83877325],
             [-0.5023728 ,  0.8090666 , -1.788563  , -2.4703    ,
               0.45676166],
             [ 1.0493878 , -0.35580957,  1.109334  ,  0.73159033,
               0.86589175]], dtype=float32)

In [11]:
sample_weights(key, weights)

[ 1.0299255  -0.09998897  0.6923099   0.50926715  0.69103235]


DeviceArray([[ 0.0735873 ,  0.        , -0.6111821 ,  0.07178045,
               0.25784767],
             [-0.33430755,  0.        , -0.45735866,  0.6407463 ,
               0.591826  ],
             [-1.4437462 , -0.        ,  0.28084418, -0.19755143,
              -0.57961947],
             [-0.5023728 ,  0.        , -1.2382399 , -1.2580426 ,
               0.31563708],
             [ 1.0493878 , -0.        ,  0.7680029 ,  0.37257493,
               0.5983592 ]], dtype=float32)

In [12]:
sample_z(key, (5,), sample=True)

DeviceArray([1.        , 0.        , 0.6923099 , 0.50926715, 0.69103235],            dtype=float32)

In [13]:
def sample_z(self, batch_size, sample=True):
    """Sample the hard-concrete gates for training and use a deterministic value for testing"""
    if sample:
        eps = self.get_eps(self.floatTensor(batch_size, self.dim_z))
        z = self.quantile_concrete(eps).view(batch_size, self.dim_z, 1, 1)
        return F.hardtanh(z, min_val=0, max_val=1)
    else:  # mode
        pi = F.sigmoid(self.qz_loga).view(1, self.dim_z, 1, 1)
        return F.hardtanh(pi * (limit_b - limit_a) + limit_a, min_val=0, max_val=1)

def sample_weights(self):
    z = self.quantile_concrete(self.get_eps(self.floatTensor(self.dim_z))).view(self.dim_z, 1, 1, 1)
    return F.hardtanh(z, min_val=0, max_val=1) * self.weights

def forward(self, input_):
    if self.input_shape is None:
        self.input_shape = input_.size()
    b = None if not self.use_bias else self.bias
    if self.local_rep or not self.training:
        output = F.conv2d(input_, self.weights, b, self.stride, self.padding, self.dilation, self.groups)
        z = self.sample_z(output.size(0), sample=self.training)
        return output.mul(z)
    else:
        weights = self.sample_weights()
        output = F.conv2d(input_, weights, None, self.stride, self.padding, self.dilation, self.groups)
        return output