In [87]:
import torch
import numpy as np

In [88]:
# helper functions
def mock(num_utterances, num_hypotheses):
  """Build a random toy dataset of utterances and candidate hypotheses.

  Returns:
    (mock_X, mock_Y):
      mock_X -- float array of shape (num_utterances, 3, num_hypotheses) with
                values drawn uniformly from [0, 1).
      mock_Y -- float array of shape (num_utterances, num_hypotheses); each
                row is one-hot, marking a uniformly chosen "correct" hypothesis.
  """
  features = np.random.rand(num_utterances, 3, num_hypotheses)
  correct_idx = np.random.choice(num_hypotheses, size=num_utterances)
  # Scatter a single 1.0 per row at the chosen hypothesis index.
  labels = np.zeros((num_utterances, num_hypotheses))
  labels[np.arange(num_utterances), correct_idx] = 1.0
  return features, labels

def show_weights(policy_network):
  """Print the name and current value of every trainable parameter.

  Parameters whose ``requires_grad`` is False (frozen) are skipped.
  """
  trainable = (
      (param_name, tensor)
      for param_name, tensor in policy_network.named_parameters()
      if tensor.requires_grad
  )
  for param_name, tensor in trainable:
    print(param_name, tensor.data)

In [89]:
# Seed every RNG the notebook uses (torch: weight init + action sampling,
# numpy: mock data) so Restart & Run All reproduces the same results.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Scores each hypothesis from its 3 features. Softmax over dim 0 normalises
# across rows, i.e. across hypotheses -- this assumes an unbatched
# (num_hypotheses, 3) input, as fed by Reinforce below.
policy = torch.nn.Sequential(
    torch.nn.Linear(3, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 1),
    torch.nn.Softmax(dim=0)
)
lr = 0.005
optim = torch.optim.SGD(policy.parameters(), lr=lr)

In [90]:
def Reinforce(X, Y, policy, epsilon_depth=5, optimizer=None, lr=0.005):
  """Train ``policy`` with one REINFORCE update per utterance.

  For each utterance an action (hypothesis index) is sampled from the policy's
  categorical distribution. Up to ``epsilon_depth`` samples are drawn, stopping
  early as soon as a sample equals the policy's current argmax ("epsilon
  depth" strategy). The reward is 1 when the final sampled action equals the
  ground-truth index and 0 otherwise; the policy is then updated with the loss
  ``-log_prob(action) * reward``.

  Args:
    X: iterable of per-utterance feature arrays. Each ``x`` is transposed
      before being fed to ``policy`` (for ``mock`` data:
      ``(3, num_hypotheses)`` -> ``(num_hypotheses, 3)``).
    Y: iterable of per-utterance targets; the ground-truth action is
      ``np.argmax(y)`` (one-hot rows in the mock data).
    policy: module mapping the transposed features to one probability per
      hypothesis (its output is flattened before use).
    epsilon_depth: maximum number of samples drawn per utterance. Values
      below 1 are treated as 1, i.e. a single ordinary sample.
    optimizer: optimizer over ``policy``'s parameters. If None, a plain
      ``torch.optim.SGD(policy.parameters(), lr=lr)`` is created so updates
      always target the policy that was passed in.
    lr: learning rate of the default optimizer; ignored if ``optimizer`` is
      given.

  Returns:
    None. ``policy`` is updated in place.
  """
  if optimizer is None:
    # Momentum-free SGD keeps no state between steps, so building it here is
    # equivalent to reusing a notebook-level SGD with the same lr.
    optimizer = torch.optim.SGD(policy.parameters(), lr=lr)
  # Always draw at least once; otherwise `action` would never be assigned.
  max_draws = max(1, epsilon_depth)

  for x, y in zip(X, Y):
    obs = torch.as_tensor(x.T, dtype=torch.float32)
    ground_truth = int(np.argmax(y))

    probs = torch.flatten(policy(obs))
    dist = torch.distributions.Categorical(probs=probs)
    # The greedy action is fixed for this utterance; compute it once.
    best_action = int(torch.argmax(probs).item())

    # Epsilon-depth sampling: resample until the greedy action comes up or the
    # draws run out. NOTE(review): this skews the behaviour policy toward the
    # argmax, so the log-prob below is not that of the actual sampling scheme.
    for _ in range(max_draws):
      action = dist.sample().item()
      if action == best_action:
        break

    # Binary reward: did the chosen hypothesis match the ground truth?
    reward = 1 if action == ground_truth else 0

    # REINFORCE update (no baseline): a zero reward yields a zero gradient.
    log_prob = dist.log_prob(torch.tensor(action, dtype=torch.long))
    loss = -log_prob * reward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [91]:
# Train the policy on a synthetic dataset: 1000 utterances, 6 hypotheses each.
num_utterances, num_hypotheses = 1000, 6
mock_X, mock_Y = mock(num_utterances, num_hypotheses)
Reinforce(mock_X, mock_Y, policy)

In [92]:
# Inspect the trained parameters (prints every trainable tensor in full).
show_weights(policy_network=policy)

0.weight tensor([[-0.0437, -0.0183,  0.0625],
        [ 0.1819,  0.3798,  0.3374],
        [-0.3523,  0.1147, -0.2134],
        [-0.2821,  0.1272,  0.2350],
        [-0.2551, -0.5262, -0.0153],
        [-0.5634, -0.5771,  0.2358],
        [-0.1993, -0.4558, -0.5671],
        [-0.1575,  0.0569,  0.2762],
        [ 0.0601, -0.0345, -0.0258],
        [-0.2387, -0.2367, -0.3468],
        [ 0.0247,  0.4485, -0.3680],
        [ 0.4356, -0.5701,  0.4917],
        [-0.4654,  0.3924, -0.2727],
        [-0.2327,  0.5218, -0.4392],
        [-0.3631, -0.5730, -0.3485],
        [ 0.3869, -0.0382,  0.4705],
        [-0.4228, -0.0621, -0.5355],
        [ 0.2559,  0.5083,  0.5192],
        [ 0.2822,  0.0474,  0.2098],
        [ 0.0037, -0.2619, -0.0835],
        [-0.4801, -0.0810, -0.2359],
        [-0.3143,  0.0900, -0.1902],
        [ 0.1163,  0.4216,  0.1402],
        [-0.1890,  0.1769,  0.5564],
        [ 0.1310,  0.4633,  0.0240],
        [ 0.4634,  0.0628,  0.1605],
        [ 0.1910,  0.5394, -0