In [None]:
import os
from os.path import join, exists
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
plt.style.use("ggplot")
# Fix torch's global RNG so every Gumbel / categorical draw below (and therefore
# every figure this notebook saves) is reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

In [None]:
def sample_gumbel(shape, eps=1e-20):
    """Draw standard Gumbel(0, 1) noise by inverse-transform sampling.

    Computes G = -log(-log(U) + eps) with U ~ Uniform[0, 1) from ``torch.rand``;
    ``eps`` is added inside both logarithms to keep them away from log(0).

    Args:
        shape: Output shape, anything ``torch.rand`` accepts (an int or a tuple).
        eps: Small constant guarding both logarithms.

    Returns:
        A tensor of the given shape filled with Gumbel draws.
    """
    uniform = torch.rand(shape)
    neg_log_uniform = -torch.log(uniform + eps)
    return -torch.log(neg_log_uniform + eps)

def categorical_sampling(pi, shape, eps=1e-20):
    """Draw one-hot categorical samples with the Gumbel-max trick.

    ``argmax(log(pi) + g)`` with ``g ~ Gumbel(0, 1)`` is distributed as
    Categorical(pi); the winning index is returned as a one-hot vector.

    Args:
        pi: Category probabilities, broadcast against ``shape`` along the
            last axis.
        shape: Output shape whose last entry is the number of categories,
            e.g. ``(n_samples, n_categories)``. Any leading dimensions are
            treated as batch dimensions.
        eps: Small constant guarding log(0) for zero-probability categories.

    Returns:
        Tensor of ``shape`` in which every vector along the last axis is
        one-hot.
    """
    n_categories = shape[-1]
    perturbed_logits = torch.log(pi + eps) + sample_gumbel(shape)
    # Reduce over the category (last) axis so any number of leading dims works;
    # for the 2-D shapes used in this notebook this is the same axis as dim=1.
    winners = perturbed_logits.argmax(dim=-1)
    return torch.eye(n_categories)[winners]

def gumbel_softmax_sampling(pi, shape, tau, eps=1e-20):
    """Draw relaxed (continuous) one-hot samples with the Gumbel-softmax trick.

    Computes ``softmax((log(pi) + g) / tau)`` with ``g ~ Gumbel(0, 1)``.
    Small ``tau`` pushes samples toward one-hot vectors; large ``tau`` pushes
    them toward the uniform vector.

    Args:
        pi: Category probabilities, broadcast against ``shape`` along the
            last axis.
        shape: Output shape whose last entry is the number of categories.
        tau: Softmax temperature; must be strictly positive.
        eps: Small constant guarding log(0) for zero-probability categories.

    Returns:
        Tensor of ``shape`` whose vectors along the last axis are
        non-negative and sum to 1.

    Raises:
        ValueError: If ``tau`` is not strictly positive.
    """
    # tau == 0 divides by zero (softmax of infinities -> NaN) and tau < 0 flips
    # the ordering of the logits; fail loudly instead of returning garbage.
    if tau <= 0:
        raise ValueError("tau must be positive, got {}".format(tau))
    log_pi = torch.log(pi + eps)
    g = sample_gumbel(shape)
    # Normalize over the category (last) axis; for 2-D shapes this equals dim=1.
    return F.softmax((log_pi + g) / tau, dim=-1)

In [None]:
# Empirical check of sample_gumbel: normalized histogram of 1000 Gumbel(0, 1) draws.
fig, ax = plt.subplots()
ax.set_title("gumbel sampling")
# `density=True` replaces `normed=True`, which was removed in Matplotlib 3.1.
ax.hist(sample_gumbel(1000).numpy(), bins=50, density=True)
ax.set_xlabel("g")
ax.set_ylabel("density")
ax.set_ylim(0, 1)
plt.show()

In [None]:
# Target categorical distribution over 6 categories.
pi = torch.tensor([0.1, 0.6, 0.1, 0.01, 0.0001, 0.1899])
assert torch.isclose(pi.sum(), torch.tensor(1.0), atol=1e-5), "pi must sum to 1"
N_SAMPLES = 100  # number of samples drawn (and plotted) per experiment
# (n_samples, n_categories), derived from pi so the two cannot drift apart.
shape = (N_SAMPLES, pi.numel())
# Bar chart of the target distribution pi; categories are numbered 1..K.
fig, ax = plt.subplots()
ax.set_title(r"$\pi$", fontsize=20)
# x positions derived from pi instead of a hard-coded 7.
ax.bar(np.arange(1, len(pi) + 1), pi)
ax.set_xlabel("category")
ax.set_ylabel("probability")
ax.set_ylim(0, 1)
plt.show()

In [None]:
# Draw one batch of Gumbel-max categorical samples; plot and save each as a bar chart.
categorical_sample_path = "./logs/categorical_samples"
# exist_ok makes the cell idempotent and avoids the check-then-create race.
os.makedirs(categorical_sample_path, exist_ok=True)
# x positions 1..K, derived from the sample shape instead of a hard-coded 7.
categories = np.arange(1, shape[-1] + 1)
for i, z in enumerate(categorical_sampling(pi, shape)):
    fig, ax = plt.subplots()
    ax.bar(categories, z, color="orange")
    ax.set_title("categorical sampling {}".format(i+1))
    ax.set_ylim(0, 1)
    fig.savefig(join(categorical_sample_path, "{}.png".format(i)))
    plt.show()
    # Release each figure: with non-inline backends (e.g. Agg) plt.show() does
    # not close it, so figures would pile up and bars would be drawn onto the
    # same axes.
    plt.close(fig)

In [None]:
def bar_plot_gumbel_softmax(tau, probs=None, sample_shape=None):
    """Plot and save one bar chart per Gumbel-softmax sample at temperature ``tau``.

    Each figure is saved to ``./logs/gumbel_softmax_samples_<tau>/<i>.png``
    (0-based ``i``) and also shown.

    Args:
        tau: Softmax temperature, forwarded to ``gumbel_softmax_sampling``.
        probs: Category probabilities. Defaults to the notebook-level ``pi``.
        sample_shape: ``(n_samples, n_categories)``. Defaults to the
            notebook-level ``shape``.
    """
    if probs is None:
        probs = pi
    if sample_shape is None:
        sample_shape = shape
    gumbel_softmax_sample_path = "./logs/gumbel_softmax_samples_{}".format(tau)
    # exist_ok makes repeated calls safe and avoids the check-then-create race.
    os.makedirs(gumbel_softmax_sample_path, exist_ok=True)
    # x positions 1..K, derived from the sample shape instead of a hard-coded 7.
    categories = np.arange(1, sample_shape[-1] + 1)
    for i, z in enumerate(gumbel_softmax_sampling(probs, sample_shape, tau)):
        fig, ax = plt.subplots()
        ax.bar(categories, z, color="pink")
        # 1-based sample number in the title, consistent with the categorical
        # plots; file names stay 0-based as before.
        ax.set_title(r"gumbel softmax sampling {} ($\tau$={})".format(i + 1, tau))
        ax.set_ylim(0, 1)
        fig.savefig(join(gumbel_softmax_sample_path, "{}.png".format(i)))
        plt.show()
        # Release the figure so non-inline backends don't accumulate figures/bars.
        plt.close(fig)

In [None]:
# Temperature sweep: small tau gives near one-hot samples, large tau near-uniform ones.
for temperature in (0.1, 0.5, 1, 10, 100):
    bar_plot_gumbel_softmax(temperature)