In [0]:
import collections
import queue

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.utils.data
from torch import distributions as dist
from torch import nn
from torch import optim

from core import *

In [0]:
# Reward handed out for eating a mushroom, keyed by its true class.
POISONOUS_REWARD = -35.0
EDIBLE_REWARD = 5.0

# Width of the value network's input: the context vector with one extra slot
# for the action appended to it.
# NOTE(review): 117 is presumably the number of one-hot columns produced from
# the mushroom table -- confirm against the loaded features.
CONTEXT_SIZE = 1 + 117

# Number of network forward passes averaged per action when choosing a move.
SAMPLE_COUNT = 2

# Capacity of each of the agent's past-play buffers.
AGENT_MEMORY_LEN = 4096

In [0]:
# Load the raw mushroom table; the 'class' column holds 'p' (poisonous) or
# 'e' (edible) and every other column is a categorical feature.
mushroom_dataset = pd.read_csv('mushrooms.csv')

# Turn each class letter into the reward obtained by eating that mushroom.
# An explicit mapping (instead of Series.replace) avoids relying on implicit
# object -> float downcasting and turns any unexpected label into NaN, which
# the assertion below reports immediately.
reward_by_class = {'p': POISONOUS_REWARD, 'e': EDIBLE_REWARD}
eat_rewards = mushroom_dataset['class'].map(reward_by_class)
assert eat_rewards.notna().all(), "unexpected value in the 'class' column"

# the features contain missing values (marked as ?)
# these are treated as a different class atm
feature_dummies = pd.get_dummies(mushroom_dataset.drop(['class'], axis=1))

train_features = torch.tensor(feature_dummies.values, dtype=torch.float)
train_labels = torch.tensor(eat_rewards.values)
trainset = torch.utils.data.TensorDataset(train_features, train_labels)
# num_workers=0: the dataset is already in memory, and a fresh iterator is
# created for every round played, so a worker process per iterator is pure
# overhead.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1,
                                          shuffle=True, num_workers=0)


In [0]:
class Agent(object):
  """Contextual-bandit agent that scores (context, action) pairs with a BayesianNN.

  The agent keeps the most recent AGENT_MEMORY_LEN plays (context, action,
  reward) in three parallel bounded buffers and refits the network on them
  when asked to.
  """

  def __init__(self):
    # Network input is the context with the action appended; one output value.
    self.value_estimates = BayesianNN(
        CONTEXT_SIZE, [100, 100, 1],
        [ActivationType.RELU, ActivationType.RELU, ActivationType.NONE])
    self.optimizer = optim.Adam(self.value_estimates.parameters(), lr=0.01)

    # Bounded FIFO buffers: once full, appending drops the oldest entry.
    # The three buffers are always appended to together, so index k in each
    # refers to the same play.
    self.past_plays_context = collections.deque(maxlen=AGENT_MEMORY_LEN)
    self.past_plays_action = collections.deque(maxlen=AGENT_MEMORY_LEN)
    self.past_plays_reward = collections.deque(maxlen=AGENT_MEMORY_LEN)

  def collected_data_count(self):
    """Returns the number of plays currently held in memory."""
    return len(self.past_plays_context)

  def select_action(self, context, logs=False):
    """Returns the action (0 or 1) with the highest averaged predicted reward.

    Args:
      context: 1-D float tensor describing the current mushroom.
      logs: when True, prints the predicted reward of every action.

    Returns:
      0 or 1. Ties go to the lower action index.
    """
    # NOTE(review): train mode is presumably what makes each forward pass draw
    # fresh weights -- confirm against core.BayesianNN.
    self.value_estimates.train()
    best_reward = None
    best_action = -1
    for action in range(2):
      # The network input does not change between samples, so build it once.
      context_and_action = torch.cat(
          [context, torch.tensor([action], dtype=torch.float)])
      expected_reward = 0
      for _ in range(SAMPLE_COUNT):
        expected_reward += self.value_estimates(context_and_action)
      expected_reward /= SAMPLE_COUNT
      if logs:
        print('Action {} - predicted reward: {}'.format(
            action, expected_reward))
      # No numeric floor: the first action is always a candidate, so a valid
      # action is returned even when every prediction is very negative.
      if best_reward is None or expected_reward > best_reward:
        best_reward = expected_reward
        best_action = action
    return best_action

  def update_memory(self, context, action, reward):
    """Records one play; the oldest play is dropped once memory is full."""
    self.past_plays_context.append(context)
    self.past_plays_action.append(action)
    self.past_plays_reward.append(reward)

  def update_variational_posterior(self, context, action, reward, logs=False):
    """Runs one pass of minibatch updates over all remembered plays.

    Args:
      context: unused; kept so existing callers keep working.
      action: unused; kept so existing callers keep working.
      reward: unused; kept so existing callers keep working.
      logs: when True, prints the loss of every minibatch.
    """
    if not self.past_plays_context:
      return  # nothing to learn from yet

    features = torch.stack([
        torch.cat([past_context,
                   torch.tensor([past_action], dtype=torch.float)])
        for past_context, past_action in zip(
            self.past_plays_context, self.past_plays_action)])

    # Rewards may be stored as Python numbers or one-element tensors.
    rewards = torch.tensor(
        [float(past_reward) for past_reward in self.past_plays_reward],
        dtype=torch.float)

    past_plays_set = torch.utils.data.TensorDataset(features, rewards)
    # num_workers=0: the data is already in memory and this loader is rebuilt
    # on every call, so a worker process would only add start-up cost.
    past_plays_loader = torch.utils.data.DataLoader(
        past_plays_set, batch_size=64, shuffle=True, num_workers=0)

    for i, (inputs, labels) in enumerate(past_plays_loader):
      # zero the parameter gradients
      self.optimizer.zero_grad()

      # forward + backward + optimize
      # NOTE(review): num_batches comes from the notebook-level `trainloader`,
      # not from `past_plays_loader` -- confirm which count cost_function
      # expects.
      loss, _, _ = self.value_estimates.cost_function(
          inputs, labels, num_samples=2, num_batches=len(trainloader))
      loss.backward()
      self.optimizer.step()

      if logs:
        print('{}. Loss: {}'.format(i, loss))

In [0]:
class Environment(object):
  """Mushroom bandit: serves one mushroom per round and tracks cumulative regret.

  Action 0 means "pass" (reward 0); any other action means "eat" (reward is
  the mushroom's label, EDIBLE_REWARD or POISONOUS_REWARD).
  """

  def __init__(self, agent, dataloader):
    """Stores the agent to play with and the loader that supplies mushrooms.

    Args:
      agent: object providing select_action, update_memory,
        collected_data_count and update_variational_posterior.
      dataloader: iterable yielding (context_batch, eat_reward_batch) pairs;
        only the first element of each batch is used.
    """
    self.agent = agent
    self.dataloader = dataloader
    self.cumulative_regret = 0

  def play_round(self, logs=False):
    """Plays one round: draw a mushroom, let the agent act, record the outcome.

    Regret grows by |EDIBLE_REWARD| when an edible mushroom is passed and by
    |POISONOUS_REWARD| when a poisonous one is eaten. After the play is stored,
    the agent is refit once it reports at least AGENT_MEMORY_LEN stored plays.

    Args:
      logs: when True, prints what happened and the running regret.
    """
    context_batch, eat_reward_batch = next(iter(self.dataloader))
    context = context_batch[0]
    # Plain float so comparisons and the stored reward are ordinary numbers.
    eat_reward = float(eat_reward_batch[0])

    # Use the agent this environment was built with, not a notebook global.
    selected_action = self.agent.select_action(context, logs)
    if selected_action == 0: #not eat
      if eat_reward == EDIBLE_REWARD:
        if logs:
          print('Mushroom is edible; agent chose to pass.')
        self.cumulative_regret += abs(EDIBLE_REWARD)
      else:
        if logs:
          print('Mushroom is poisonous; agent chose to pass.')
      reward = 0
    else: #eat
      if eat_reward == POISONOUS_REWARD:
        self.cumulative_regret += abs(POISONOUS_REWARD)
        if logs:
          print('Mushroom is poisonous; agent chose to eat.')
      else:
        if logs:
          print('Mushroom is edible; agent chose to eat.')
      reward = eat_reward
    if logs:
      print('Cumulative regret is {}'.format(self.cumulative_regret))

    self.agent.update_memory(context, selected_action, reward)
    if self.agent.collected_data_count() >= AGENT_MEMORY_LEN:
      self.agent.update_variational_posterior(context, selected_action, reward)

In [0]:
# Create a fresh agent and hand it, with the mushroom loader, to the environment.
agent = Agent()
env = Environment(agent=agent, dataloader=trainloader)

In [0]:
# Play 100000 rounds, printing a detailed report on every 100th one.
for i in range(100000):
  if i == AGENT_MEMORY_LEN:
    print('Started training')
  verbose_round = (i + 1) % 100 == 0
  if verbose_round:
    print(i)
  env.play_round(logs=verbose_round)

99
Action 0 - predicted reward: tensor([16.1083], grad_fn=<DivBackward0>)
Action 1 - predicted reward: tensor([23.0726], grad_fn=<DivBackward0>)
Mushroom is edible; agent chose to eat.
Cumulative regret is 1195.0
199
Action 0 - predicted reward: tensor([22.5331], grad_fn=<DivBackward0>)
Action 1 - predicted reward: tensor([18.3910], grad_fn=<DivBackward0>)
Mushroom is edible; agent chose to pass.
Cumulative regret is 2265.0
299
Action 0 - predicted reward: tensor([20.4924], grad_fn=<DivBackward0>)
Action 1 - predicted reward: tensor([21.5987], grad_fn=<DivBackward0>)
Mushroom is poisonous; agent chose to eat.
Cumulative regret is 3165.0
399
Action 0 - predicted reward: tensor([14.2115], grad_fn=<DivBackward0>)
Action 1 - predicted reward: tensor([19.1123], grad_fn=<DivBackward0>)
Mushroom is poisonous; agent chose to eat.
Cumulative regret is 4480.0
499
Action 0 - predicted reward: tensor([15.9786], grad_fn=<DivBackward0>)
Action 1 - predicted reward: tensor([20.6412], grad_fn=<DivBack