In [36]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as distr

In [37]:
# Run on the GPU when one is available; fall back to CPU so the notebook
# still survives Restart & Run All on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch (all devices) so the random init of `theta` and any Gumbel
# samples are reproducible from run to run.
seed = 0
torch.manual_seed(seed)

num_states = 10  # number of the die's sides; face values are 1, 2, ..., num_states

# Unnormalised logits of the categorical distribution over die faces;
# softmax(theta) gives the face probabilities being optimised.
theta = torch.randn(num_states, requires_grad=True, device=device)
optimizer = torch.optim.SGD([theta], lr=1e-3, momentum=0.99)

# Softmax temperature schedule: the training loop multiplies `temperature`
# by `temperature_gamma` every 100 steps, never going below `temperature_min`.
temperature = 1.0
temperature_gamma = 1.0  # 1.0 disables annealing; e.g. 0.98 anneals towards temperature_min
temperature_min = 0.1

# Standard Gumbel(0, 1). Tensor parameters placed on `device` make samples
# land there directly instead of being drawn on the CPU and copied.
gumbel_gen = distr.Gumbel(torch.tensor(0.0, device=device), torch.tensor(1.0, device=device))
batch_size = 2**14

num_steps = 10000

In [38]:
class Loss:
    """Callable wrapper that pairs a loss function with a display label.

    Calling the wrapper forwards to the wrapped function, passing either one
    argument or two. ``str(loss)`` returns the label, which the training loop
    uses in its log lines.

    Attributes:
        loss_func: the wrapped callable.
        str_descr: human-readable label returned by ``__str__``.
    """

    def __init__(self, loss_func, str_descr):
        self.loss_func, self.str_descr = loss_func, str_descr

    def __call__(self, x, y=None):
        # Unary form when no second argument is given; otherwise binary.
        if y is None:
            return self.loss_func(x)
        return self.loss_func(x, y)

    def __str__(self):
        return self.str_descr

def entropy_loss(probs):
    """Negative mean Shannon entropy (in nats) of a batch of distributions.

    Minimising this loss maximises entropy.

    Args:
        probs: probabilities along the last dimension. Any leading dimensions
            are treated as batch dimensions and averaged over. Each slice is
            assumed to sum to 1; this is not checked.

    Returns:
        0-d tensor equal to ``-mean(H(p))``, where ``H(p) = -sum_i p_i * log(p_i)``.
    """
    # softmax can underflow to exact zeros, especially at low temperature.
    # p * log(p) would then be 0 * -inf = NaN and poison the gradients.
    # Clamping only the log argument to the smallest positive normal value
    # makes those terms contribute 0 (the correct limit p*log p -> 0). It also
    # keeps the backward pass finite; entries >= tiny are unchanged.
    tiny = torch.finfo(probs.dtype).tiny
    entropy = -(probs * torch.log(probs.clamp_min(tiny))).sum(-1)
    return -(entropy.mean())

def max_mean_loss(probs):
    """Negative expected face value of a die with face probabilities ``probs``.

    Face values are ``1, 2, ..., K`` with ``K = probs.shape[-1]``. Minimising
    this loss maximises ``E[x]``.

    Args:
        probs: probabilities along the last dimension. Any leading dimensions
            are treated as batch dimensions and averaged over.

    Returns:
        0-d tensor ``-mean(E[x])`` with the same dtype as ``probs``.
    """
    # Build the face values directly on probs' device and in its dtype, as
    # max_var_loss does. This avoids allocating an int64 tensor on the CPU
    # and copying it to the device on every call.
    x = torch.arange(1, probs.shape[-1] + 1, device=probs.device, dtype=probs.dtype)
    E_x = (x * probs).sum(-1)
    return -E_x.mean()

def max_var_loss(probs):
    """Negative variance of a die's face value under ``probs``.

    Face values are ``1, 2, ..., K`` with ``K = probs.shape[-1]``. The
    variance is computed as ``E[x^2] - E[x]^2`` per distribution, averaged
    over any leading (batch) dimensions, then negated so that minimising the
    loss maximises the variance.

    Args:
        probs: probabilities along the last dimension.

    Returns:
        0-d tensor ``-mean(Var[x])``.
    """
    faces = torch.arange(1, probs.shape[-1] + 1, device=probs.device, dtype=probs.dtype)

    def moment(power):
        # E[x ** power] under probs, reduced over the last (face) dimension.
        return (faces ** power * probs).sum(-1)

    variance = moment(2) - moment(1) ** 2
    return variance.mean().neg()

In [39]:
# Objective to optimise. Each loss function returns the *negated* objective,
# so minimising it maximises the objective. The label is used in log lines.
# Select a different objective by changing the key below rather than
# commenting lines in and out.
loss_choices = {
    "mean": Loss(max_mean_loss, "Mean"),
    "entropy": Loss(entropy_loss, "Entropy (in nats)"),
    "variance": Loss(max_var_loss, "Variance"),
}
loss_func = loss_choices["variance"]

In [40]:
    
# The Gumbel noise was previously sampled and then overwritten with zeros,
# which made the objective deterministic. That behaviour is kept as the
# default, but it is now explicit and skips the wasted sampling. Set this to
# True to train through the Gumbel-softmax relaxation with `batch_size`
# noisy samples per step.
use_gumbel_noise = False

for step in range(num_steps):
    if use_gumbel_noise:
        # One row of Gumbel(0, 1) noise per sample: shape (batch_size, num_states).
        gumbel_noise = gumbel_gen.sample((batch_size, num_states)).to(device)
        logits = theta + gumbel_noise
    else:
        # Without noise every batch row would be identical, so the single
        # (unbatched) distribution is enough. This avoids drawing and
        # discarding batch_size * num_states samples every step.
        logits = theta

    softmax_probs = F.softmax(logits / temperature, dim=-1)

    loss = loss_func(softmax_probs)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Anneal after every 100 *completed* steps (previously this also fired
    # right after step 0).
    if (step + 1) % 100 == 0:
        temperature = max(temperature * temperature_gamma, temperature_min)

    if step % 1000 == 0:
        # Report the noise-free, temperature-1 distribution. No autograd
        # graph is needed for logging.
        with torch.no_grad():
            true_probs = F.softmax(theta, dim=-1)
        # The loss functions return negated objectives, so -loss is the
        # objective value itself.
        print(f"Step {step}, {str(loss_func)}: {-loss.item():.3f}")
        print(f"Probabilities: {true_probs.detach().cpu().numpy().round(3)}")

Step 0, Variance: 9.097
Probabilities: [0.09  0.085 0.01  0.221 0.08  0.066 0.054 0.018 0.244 0.131]
Step 1000, Variance: 20.250
Probabilities: [0.5 0.  0.  0.  0.  0.  0.  0.  0.  0.5]
Step 2000, Variance: 20.250
Probabilities: [0.5 0.  0.  0.  0.  0.  0.  0.  0.  0.5]
Step 3000, Variance: 20.250
Probabilities: [0.5 0.  0.  0.  0.  0.  0.  0.  0.  0.5]
Step 4000, Variance: 20.250
Probabilities: [0.5 0.  0.  0.  0.  0.  0.  0.  0.  0.5]
Step 5000, Variance: 20.250
Probabilities: [0.5 0.  0.  0.  0.  0.  0.  0.  0.  0.5]
Step 6000, Variance: 20.250
Probabilities: [0.5 0.  0.  0.  0.  0.  0.  0.  0.  0.5]
Step 7000, Variance: 20.250
Probabilities: [0.5 0.  0.  0.  0.  0.  0.  0.  0.  0.5]
Step 8000, Variance: 20.250
Probabilities: [0.5 0.  0.  0.  0.  0.  0.  0.  0.  0.5]
Step 9000, Variance: 20.250
Probabilities: [0.5 0.  0.  0.  0.  0.  0.  0.  0.  0.5]
