In this notebook, we review some properties of the Gumbel distribution. Our primary reference here is __Huijben, I. A., Kool, W., Paulus, M. B., & Van Sloun, R. J. (2022). A review of the Gumbel-max trick and its extensions for discrete stochasticity in machine learning. IEEE Transactions on Pattern Analysis and Machine Intelligence, 45(2), 1353-1371.__ The primary goal here is to gain some understanding of the Gumbel-max and Gumbel-softmax tricks.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
class Gumbel:
  ''' Gumbel distribution with location `mu` and scale `beta`.

  With z = (x - mu) / beta the distribution is defined by
    cdf(x) = exp(-exp(-z))
    pdf(x) = exp(-z - exp(-z)) / beta
  '''

  def __init__(self, mu, beta):
    '''
    mu:   location parameter
    beta: scale parameter (expected to be strictly positive)
    '''
    self.mu = mu
    self.beta = beta

  def cdf(self, x):
    ''' Cumulative distribution function evaluated at x (scalar or array). '''
    return np.exp( -np.exp( -(x-self.mu)/self.beta ) )

  def icdf(self, u):
    ''' Quantile function (inverse CDF) evaluated at u in (0, 1). '''
    return  -self.beta * np.log( -np.log(u)) + self.mu

  def pdf(self, x):
    ''' Probability density function evaluated at x (scalar or array). '''
    z = (x - self.mu) / self.beta
    # Evaluate a single exponent instead of exp(-z) * cdf(x): far in the left
    # tail the product becomes inf * 0 = nan, whereas the density is simply 0.
    with np.errstate(over='ignore'):
      return np.exp(-z - np.exp(-z)) / self.beta

  def sample(self, num_samples):
    '''
    Draw samples by inverse-transform sampling.

    num_samples: output shape, either an int (number of draws) or a tuple of ints
    returns:     array of the requested shape
    '''
    if np.isscalar(num_samples):
      shape = (num_samples,)
    else:
      shape = tuple(num_samples)
    u = np.random.random_sample(shape)
    # random_sample draws from [0, 1); u == 0 would be mapped to -inf by icdf,
    # so nudge it to the smallest positive normal float.
    u = np.maximum(u, np.finfo(float).tiny)
    return self.icdf(u)


In [None]:
# Evaluate the standard Gumbel (mu=0, beta=1) on a grid and show its
# CDF, PDF and inverse CDF side by side.
G = Gumbel(0, 1)
x = np.arange(-10, 10, 0.01)
F, f = G.cdf(x), G.pdf(x)

# stay strictly inside (0, 1): the inverse CDF diverges at both end points
u = np.arange(1e-8, 1 - 1e-8, 0.001)
Q = G.icdf(u)

fig, axs = plt.subplots(1, 3, figsize=(10, 5))

# one entry per panel: (abscissa, curve, title, x-axis label)
panels = (
  (x, F, 'CDF', 'x'),
  (x, f, 'PDF', 'x'),
  (u, Q, 'Inverse CDF', 'u'),
)
for ax, (grid, values, title, xlabel) in zip(axs, panels):
  ax.plot(grid, values, 'b', label='Gumbel')
  ax.set_title(title)
  ax.set_xlabel(xlabel)

plt.show()

In [None]:

# Sanity check: the normalised histogram of inverse-CDF samples should follow
# the analytic density (G, x and f are defined in an earlier cell).
num_draws = 10000
g = G.sample((num_draws,))

plt.hist(g, bins=100, density=True)
plt.plot(x, f, 'r', label='pdf')
plt.legend()
plt.show()

In [None]:
class Categorical:
  ''' Unnormalized Categorical distribution, sampled via the Gumbel-max trick '''

  def __init__(self, logits):
    '''
    logits: 1-d array of unnormalized log probabilities, one per category
    '''
    self.logits = logits #unnormalized log probabilities
    # As an exercise, we avoid explicitly computing the partition function
    # and the resulting normalized probabilities:
    #   theta = np.exp(logits); probs = theta / theta.sum()

  def sample(self, num_samples, B = None):
    '''
    sample from the conditional distribution given that index set B has occurred

    num_samples: number of independent draws (int)
    B:           optional collection of category indices (list, tuple or integer
                 array) to condition on; None or an empty collection means
                 "all categories"
    returns:     integer array of shape (num_samples,) holding category indices
                 of the original distribution, i.e. elements of B when B is given
    '''
    logits = np.asarray(self.logits)

    # `not B` is ambiguous for numpy arrays with more than one element,
    # so test for None / emptiness explicitly.
    if B is None or np.size(B) == 0:
      support = None
      sub_logits = logits
    else:
      # resolve B to explicit integer positions so winners can be mapped back
      support = np.arange(logits.shape[0])[np.asarray(B)]
      sub_logits = logits[support]

    G = Gumbel(0,1)
    n = sub_logits.shape[0]

    # Gumbel-max trick: perturb every logit with independent standard Gumbel
    # noise; the argmax is then distributed as Cat(softmax(sub_logits)).
    # (The row-wise maximum itself is Gumbel(logsumexp(sub_logits), 1).)
    H = G.sample((num_samples, n)) + sub_logits
    I = H.argmax(axis=1)

    if support is None:
      return I
    # argmax indexes into the restricted set; translate to original labels
    return support[I]

  
    

In [None]:
# Compare the exact softmax probabilities of random logits with the empirical
# frequencies of samples drawn through the Gumbel-max trick.
n = 10
logits = np.random.randn(n)

# ground truth: explicit normalization of exp(logits)
theta = np.exp(logits)
probs = theta / theta.sum()

C = Categorical(logits)
c = C.sample(10000)

# unit-width bins centred on the integer category labels 0, ..., n-1
bin_edges = [k - 0.5 for k in range(n + 1)]

plt.bar(np.arange(n), probs, color='b', label='truth')
plt.hist(c, bins=bin_edges, density=True,
         color='gray', hatch='/', alpha=0.4, label='empirical')
plt.legend()
plt.show()