In [17]:
import numpy as np
import torch
import torch.nn as nn
# from torch.distributions.bernoulli import Bernoulli

In [53]:
class ST_Bernoulli(torch.autograd.Function):
    """Bernoulli sampling with a straight-through gradient estimator.

    Forward draws ``torch.bernoulli(p)``. Backward hands the upstream
    gradient to ``p`` unchanged, i.e. the sample is differentiated as if it
    were the identity function of ``p``.
    """

    generate_vmap_rule = True

    def __init__(self):
        super().__init__()

    @staticmethod
    def forward(ctx, p):
        """Sample elementwise from Bernoulli(p).

        Args:
            ctx: autograd context (nothing is stored on it; the backward
                pass does not depend on the sample or on ``p``).
            p: tensor of success probabilities.

        Returns:
            Tensor with the same shape as ``p`` holding the 0/1 draws.
        """
        return torch.bernoulli(p)

    @staticmethod
    def backward(ctx, grad_output):
        """Pass the upstream gradient straight through to ``p``.

        Returning ``grad_output`` itself (instead of multiplying by a freshly
        allocated ``torch.ones(...)``) keeps the gradient on the device and
        in the dtype of the incoming gradient.
        """
        return grad_output

class Bernoulli(torch.autograd.Function):
    """Bernoulli sampling with a smoothed stochastic-derivative backward.

    Forward draws ``torch.bernoulli(p)``. Backward scales the upstream
    gradient by ``1 / (2 * p)`` where the draw was 1 and by
    ``1 / (2 * (1 - p))`` everywhere else.
    """

    generate_vmap_rule = True

    def __init__(self):
        super().__init__()

    @staticmethod
    def forward(ctx, p):
        """Sample elementwise from Bernoulli(p) and keep (sample, p) for backward."""
        sample = torch.bernoulli(p)
        ctx.save_for_backward(sample, p)
        return sample

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` times the per-element jump weight."""
        sample, prob = ctx.saved_tensors

        # Weight of a jump down (1 -> 0); the /2 averages the eps > 0 and
        # eps < 0 perturbation directions.
        down_weight = prob.reciprocal() * 1.0 / 2
        # Weight of a jump up (0 -> 1), averaged the same way.
        up_weight = (1.0 - prob).reciprocal() * 1.0 / 2

        # A draw of 1 can only jump down; any other draw uses the jump-up weight.
        jump_weight = torch.where(sample.eq(1), down_weight, up_weight)
        return grad_output * jump_weight
    
class Binomial(torch.autograd.Function):
    """Binomial sampling with a smoothed stochastic-derivative backward.

    Forward draws from ``torch.distributions.binomial.Binomial(n, p)``.
    Backward returns no gradient for ``n`` and, for ``p``, the upstream
    gradient scaled by the average of the jump-down weight ``result / p``
    and the jump-up weight ``(n - result) / (1 - p)``.
    """

    generate_vmap_rule = True

    def __init__(self):
        super().__init__()

    @staticmethod
    def forward(ctx, n, p):
        """Sample from Binomial(n, p).

        Args:
            ctx: autograd context; stores (result, p, n) for backward.
            n: number of trials, either a tensor or a plain Python number.
            p: tensor of success probabilities.

        Returns:
            Tensor holding the sampled success counts.
        """
        # save_for_backward only accepts tensors, so lift a plain number to a
        # tensor that lives alongside p. Tensor inputs are passed through
        # untouched.
        if not torch.is_tensor(n):
            n = torch.as_tensor(n, dtype=p.dtype, device=p.device)
        result = torch.distributions.binomial.Binomial(n, p).sample()
        ctx.save_for_backward(result, p, n)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(None, grad_p)``; ``n`` receives no gradient."""
        result, p, n = ctx.saved_tensors
        w_minus = result / p  # derivative contribution of a unit jump down
        w_plus = (n - result) / (1.0 - p)  # derivative contribution of a unit jump up

        # A jump down is impossible at result == 0 and a jump up is impossible
        # at result == n. Masking those cases to 0 also discards the 0/0 that
        # appears there when p is exactly 0 or 1.
        wminus_cont = torch.where(result > 0, w_minus, 0)
        wplus_cont = torch.where(result < n, w_plus, 0)

        # Smoothing: average the two one-sided contributions so that both
        # perturbation directions (eps > 0 and eps < 0) are accounted for.
        ws = (wminus_cont + wplus_cont) / 2

        return None, grad_output * ws


In [28]:
# Example: draw one Bernoulli(0.4) sample and backpropagate through it.
p = torch.tensor(0.4, requires_grad=True)  # probability we differentiate with respect to
bernoulli = Bernoulli.apply
r = bernoulli(p)
print(r)
# Backward runs Bernoulli.backward and stores the resulting weight in p.grad.
r.backward()
print(p.grad)

tensor(0., grad_fn=<BernoulliBackward>)
tensor(0.8333)


In [31]:
# Example: straight-through Bernoulli; the sample is random, the gradient is not.
p = torch.tensor(0.4, requires_grad=True)  # probability we differentiate with respect to
st_bernoulli = ST_Bernoulli.apply
r = st_bernoulli(p)
print(r)
# Backward runs ST_Bernoulli.backward and stores its result in p.grad.
r.backward()
print(p.grad)

tensor(1., grad_fn=<ST_BernoulliBackward>)
tensor(1.)


In [49]:
# Example: draw one Binomial(10, 0.4) sample and backpropagate to p.
n = torch.tensor(10)  # number of trials (receives no gradient)
p = torch.tensor(0.4, requires_grad=True)  # success probability
binomial = Binomial.apply

r = binomial(n, p)
print(r)
# Backward runs Binomial.backward and stores the resulting weight in p.grad.
r.backward()
print(p.grad)

tensor(5., grad_fn=<BinomialBackward>)
tensor(10.4167)


In [52]:
# Built-in categorical distribution over three classes, used here only to
# look at what a raw sample looks like (an integer class index).
cate = torch.distributions.categorical.Categorical(
    probs=torch.tensor([0.2, 0.7, 0.1])
)

cate.sample()

tensor(1)

In [55]:
class Categorical(torch.autograd.Function):
    """Categorical sampling with a smoothed stochastic-derivative backward.

    Forward normalises ``probs`` and draws one class index from it. Backward
    produces one weight per class ``j``: ``+1 / (2 * p[x])`` for classes above
    the sampled index ``x``, ``-1 / (2 * p[x])`` for classes below it, and 0
    for ``j == x``.

    NOTE(review): the code assumes ``probs`` is a 1-D tensor (a single
    distribution). The gradient is expressed with respect to the normalised
    probabilities; the Jacobian of the normalisation itself is not applied.
    """

    generate_vmap_rule = True

    def __init__(self):
        super().__init__()

    @staticmethod
    def forward(ctx, probs):
        """Draw one class index from the normalised ``probs``.

        Args:
            ctx: autograd context; stores (result, normalised probs).
            probs: 1-D tensor of non-negative class weights.

        Returns:
            0-dim tensor holding the sampled index, cast to ``probs.dtype``.
            Integer tensors cannot require grad, so the cast is what allows
            ``backward`` to be reached at all.
        """
        probs = probs / probs.sum()
        index = torch.distributions.categorical.Categorical(probs).sample()
        result = index.to(probs.dtype)
        ctx.save_for_backward(result, probs)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Return the upstream gradient times the per-class jump weights."""
        result, probs = ctx.saved_tensors
        index = result.long()  # back to an integer so it can index probs

        # Smoothing over eps > 0 and eps < 0 halves the one-sided weight
        # 1 / p[x]. The parentheses matter: ``1 / 2 * p`` would be ``0.5 * p``.
        half_weight = 1 / (2 * probs[index])

        # direction[j] is +1 for j > x (jump up), -1 for j < x (jump down,
        # hence the negative sign) and 0 for j == x.
        classes = torch.arange(probs.shape[0], device=probs.device)
        direction = torch.sign(classes - index).to(probs.dtype)

        ws = direction * half_weight
        return grad_output * ws