In [15]:
import random

import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch import autograd

import math

from util import *

In [32]:
def terminal(state, k_eps=1e-4):
    for i in state.flatten():
        if abs(i) <= k_eps:
            return True
    return False

class Node:
    def __init__(self, parent, state, n_children, value, depth=0):
        self.state = state
        self.parent = parent
        self.visits = 0
        self.depth = depth
        self.children = [None] * n_children
        self.is_terminal = terminal(self.state)
        # set end game state to best value?
        # todo figure out how to account for number of steps taken
        if self.is_terminal:
            self.value = 1
        else:
            self.value = value
        self.subtree_value = torch.zeros(1)

    def __str__(self):
        return ("State: " + str(self.state) + "; Value: " + str(self.value)
                + "; Subtree Value: " + str(self.subtree_value) + "; Visits:", str(self.visits))

    def is_leaf(self):
        for i in self.state:
            if i is not None:
                return False
        return True


In [None]:
class MCTS:
    def __init__(self, actions, C, weight, value_fn):
        self.actions = actions
        self.k_C = C
        self.k_weight = weight
        self.value_fn = value_fn
        self.max_depth = 0
        self.terminal = None    # None if no terminal state found; terminal Node if found

    def pick_child(self, node):
        # UCT
        t = []
        for i in node.children:
            if i is None:
                continue
            t.append(UCT_fn(i, self.k_C))

        t = torch.tensor(t)

        rvs = torch.squeeze(torch.argwhere(t == torch.max(t)), axis=1)
        if len(rvs) == 0:
            return random.randint(0, len(node.children)-1)
        return int(random.choice(rvs))

    def default_search(self, node):
        """
        If node is fully explored (neither child is None), return True
        Otherwise, initialize value of a random unexplored next state

        :param node: node to search from
        :return: if fully explored, True. Else, value of the random unexplored next state
        """
        possible = []
        for i in range(len(node.children)):
            if node.children[i] is None:
                possible.append(i)
        if len(possible) == 0:
            return True

        i = random.choice(possible)
        # if unexplored or non-terminal, get value
        state = torch.tensor(self.actions[i](node.state.flatten()), dtype=torch.float)
        state = state.reshape(node.state.shape)
        child_val = self.value_fn(state)
        child_val = child_val.flatten()[0]
        if child_val is torch.nan:
            pass    # debugging
        node.children[i] = Node(node, state, len(self.actions), value=child_val, depth=node.depth+1)

        # if new Node is terminal, take it as the tree's terminal if it takes less time to reach than current terminal
        if node.children[i].is_terminal and (self.terminal is None or node.children[i].depth < self.terminal.depth):
                self.terminal = node.children[i]

        if node.children[i].depth > self.max_depth:
            self.max_depth = node.children[i].depth
        return node.children[i]

    def tree_policy(self, node, computations):
        while node.is_terminal is False:
            explored = self.default_search(node)
            if explored is not True:
                return explored, computations + 1
            node = node.children[self.pick_child(node)]
            # node = random.choice(node.children)
        return node, computations + 1

    def mean_prop(self, node):
        """
        Backprop up from a leaf, where subtree_value is the average of a node's rewards and its subtree's rewards

        :param node: of subtree
        """
        node.subtree_value = torch.zeros(1)
        node.subtree_value += node.value
        valid_children = 0
        if not node.is_leaf():
            for i in node.children:
                if i is None:
                    continue
                node.subtree_value += self.k_weight * i.subtree_value
                valid_children += 1
        node.subtree_value /= valid_chlidren + 1
        node.visits += 1
        if node.parent is None:
            return
        self.mean_prop(node.parent)

    def run(self, root, comp_limit=10):
        """
        Shoutout "A Survey of MCTS Methods"
        :param root: the current state
        :param comp_limit: max number of possible future scenarios to compute (carries over)
        :return: index corresponding to best action
        """
        if root.is_terminal:
            return True
        comps = 0
        while comps < comp_limit:
            node, comps = self.tree_policy(root, comps)
            self.mean_prop(node)

        rv = self.pick_child(root)
        
        if False:
            print("root state:", root.state)
            print("child states: ",end="")
            for child in root.children:
                print(child.state, end=",")
            print()
        return rv

In [39]:
class Loss(nn.Module):
    def __init__(self):
        super(Loss, self).__init__()
        
    def forward(self, v_out, p_out, p_sampled, success, steps, init_state):
        """
        Loss function designed to reward successful game completion while taking the least amount of steps possible
        Adapted from:
            - "Mastering the game of Go without human knowledge" (Silver et al)
            - "Discovering faster matrix multiplication algorithms with reinforcement learning" (Fawzi et al)

        :param v_out: the value outputed for the state by NN
        :param p_out: the policy outputed for the state by NN
        :param success: if the game terminated in a success
        :param steps: number of steps taken
        :param init_state: initial state
        :return: loss
        """
        loss = steps
        if success:
            loss -= torch.linalg.norm(init_state)
        loss += torch.square(v_out - int(success)).sum()
        loss += torch.dot(p_sampled, torch.log(p_out))
        # if torch.sum(torch.isnan(loss)) >= 1:
        #     print("v_out:", v_out, "\tp_sampled:", p_sampled, "\tp_out:",p_out)
        return loss


        
class ValueNN(nn.Module):
    def __init__(self, n_actions):
        super(ValueNN, self).__init__()
        self.flatten = nn.Flatten()
        self.stack = nn.Sequential(
            nn.Linear(2, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Linear(16, 1+n_actions),
        )
        self.value_activation = nn.Sigmoid()
        self.policy_activation = nn.Softmax(dim=0)
        
    def forward(self, x):
        x = self.flatten(x)
        x = self.stack(x).flatten()
        value = self.value_activation(x[0:1]).unsqueeze(0)
        policy = torch.clamp(self.policy_activation(x[1:]).unsqueeze(0),min=1e-8,max=1-(1e-8))
        return torch.cat((value, policy), dim=1)


            
        
    

In [44]:
import math

k_state_upper_lim = 30 # arbitrary
k_C = 1 / math.sqrt(2)
k_comp_limit = int(k_state_upper_lim ** (3/2))

k_actions = (a_subtract, a_swap)

value_NN = ValueNN(len(k_actions))
loss_fn = Loss()
optimizer = optim.SGD(value_NN.parameters(), lr=0.001)




def gen_start_state_2a(limit):
    return torch.round(torch.rand((1, 2)) * limit + 1).float()

def train(epochs):
    for t in range(epochs):
        optimizer.zero_grad()
        # Repeat the following:
        # 1) run the NN on some random initial state
        # 2) update the NN based off performance in that game
        mcts = MCTS(k_actions, C=k_C, weight=1, value_fn=value_NN)
        start = gen_start_state_2a(k_state_upper_lim)

        rv = mcts.value_fn(start).flatten()
        value = rv[0]
        policy = rv[1:]

        start_node = Node(None, start, len(k_actions), value, 0)

        # play out a game
        mcts.run(start_node, comp_limit=k_comp_limit)

        # get attributes of game just played
        steps = mcts.max_depth
        success = mcts.terminal is not None
        if success:
            steps = mcts.terminal.depth

        visits = []
        for i in start_node.children:
            if i is None:
                visits.append(0)
            else:
                visits.append(i.visits)
        visits = torch.tensor(visits, dtype=torch.float)
        p_sampled = visits / torch.sum(visits)
        
        loss = loss_fn(start_node.subtree_value, policy, p_sampled, success, steps, start)
        loss.backward()
        optimizer.step()

        if (t+1) % 10 == 0:
            print("Epoch:", t+1,"\t\tLoss:",loss.item())
    torch.save(value_NN.state_dict(), "deep_mcts_weights.pth")

train(100)
            

Epoch: 10 		Loss: 6.527867317199707
Epoch: 20 		Loss: -23.659456253051758
Epoch: 30 		Loss: 6.172438621520996
Epoch: 40 		Loss: 6.452219009399414
Epoch: 50 		Loss: 6.379735946655273
Epoch: 60 		Loss: 5.2792158126831055
Epoch: 70 		Loss: -26.06336212158203
Epoch: 80 		Loss: 5.5235371589660645
Epoch: 90 		Loss: -0.22800970077514648
Epoch: 100 		Loss: -2.6788597106933594


In [45]:
def get_data(fname):
    x = torch.tensor(np.loadtxt(fname, delimiter=","), dtype=torch.float)
    return x[:,:-1], x[:,-1]

def plot_db(mcts, actions, comp_limit, ranges):
    X = ranges[0]
    Y = ranges[1]
    action_plot = [[] * actions]
    for i in X:
        for j in Y:
            result = mcts.run(
                Node(None, (i,j), n_children=len(actions),value=value_NN([i,j])), comp_limit=comp_limit)
            action_plot[result].append((i,j))
    for i in range(len(action_plot)):
        action = np.array(action_plot[i])
        plt.scatter(action[:,0], action[:,1], color=("C"+str(i)), label=action)
    plt.show()

In [46]:
def test(x, y, C, value_fn, weight=1., comp_limit=10, actions=(a_subtract, a_swap), zero_index=False, dbs=None):
    correct = 0
    mcts = MCTS(actions, C, weight, value_fn)
    guess_dist = [0] * len(actions)
    if zero_index:
        y = y - np.ones(len(y))
    for i in range(len(x)):
        state = torch.tensor(x[i]).unsqueeze(0)
        rv = value_NN(state).flatten()[1:]                      # take the move distribution given by NN
        rv = rv.multinomial(num_samples=1, replacement=True)    # sample from the move distribution
        if rv == y[i]:
            correct += 1
        guess_dist[rv] += 1
        # if (i+1) % 100 == 0:
            # print("epoch", i+1, ":", correct / (i+1))
    # todo fix
    # if dbs is not None:
    #     # graphing decision boundary
    #     plot_db(mcts, actions, comp_limit, ranges=dbs)
    return correct / len(x), guess_dist


def run_test(data_name, actions, C, value_fn, cases=100, lookahead=100, weight=1., zero_index=False, dbs=None):
    test_X, test_Y = get_data(data_name)
    test_Y.reshape(-1, 1)

    acc, guesses = test(test_X[:cases], test_Y[:cases],
                        C, value_fn, weight, comp_limit=lookahead, actions=actions, zero_index=zero_index, dbs=dbs)
    print("Test Accuracy:", acc)
    print("Guess Distribution:", guesses)

In [47]:
k_C = 1 / math.sqrt(2)  # satisfies Hoeffding Ineq (Kocsis and Szepesvari)
k_cases = 2000

k_dbound_size = 100

In [48]:
dual_file = "test_data/test_simple.csv"

db2 = np.linspace(2, k_dbound_size, k_dbound_size - 1)
two_dbs =[db2, db2]

run_test(dual_file, [a_subtract, a_swap], C=k_C, value_fn=value_NN, cases=k_cases, lookahead=10, dbs=two_dbs)
# ~90% accuracy

  state = torch.tensor(x[i]).unsqueeze(0)


Test Accuracy: 0.4895
Guess Distribution: [1, 1999]


In [None]:
# quad_file = "../Donald/four_step_euclidean/four_directions_cleaner_test.csv"     # thanks, donald

# k_C = 1 / math.sqrt(2)  # satisfies Hoeffding Ineq (Kocsis and Szepesvari)
# k_cases = 10

# k_dbound_size = 100

# db4 = np.linspace(-k_dbound_size/2, k_dbound_size/2, k_dbound_size+1)
# quad_dbs = [db4, db4]
# run_test(quad_file, [a_plsy, a_suby, a_plsx, a_subx], k_C, k_cases, lookahead=100, zero_index=False)
# # 8% accuracy on Donald test csv