In [1]:
import random

import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch import autograd
import threading
import concurrent

import math

from util import *

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

import sys
print(f"nogil={getattr(sys.flags, 'nogil', False)}")

"""
NOTE: RUN https://github.com/colesbury/nogil FOR MULTITHREADED (performance not necessarily better)
"""

nogil=False


'\nNOTE: RUN https://github.com/colesbury/nogil FOR MULTITHREADED (performance not necessarily better)\n'

In [None]:
# Memory for better batching

class Memory:
    def __init__(self, width) -> None:
        self.mem_ = None
        self.len_ = 0
    def record(self, obs):
        if self.len_ == 0:
            self.mem_ = obs
        else:
            self.mem_ = torch.cat((self.mem_, obs), dim=0)
        self.len_ += 1
    def recall(self, n_samples):
        if self.len_ == 0:
            return None
        des_len = min(n_samples, self.len_)
        indices = torch.ones(self.mem_.shape[0]).multinomial(des_len, replacement=False)
        return self.mem_[indices]
    def size(self):
        return self.len_
    def clear(self):
        self.len_ = 0

In [3]:
def get_test_data(fname):
    x = torch.tensor(np.loadtxt(fname, delimiter=","), dtype=torch.float)
    return x[:,:-1], x[:,-1]

def plot_db(policy_fn, actions, ranges):
    X = ranges[0]
    Y = ranges[1]
    action_plot = []
    for i in actions:
        action_plot.append([])
    for i in X:
        for j in Y:
            rv = policy_fn(torch.tensor([i,j],dtype=torch.float).unsqueeze(0)).flatten().to(device)
            action_plot[torch.argmax(rv)].append((i.cpu(),j.cpu()))
    for i in range(len(action_plot)):
        action = np.array(action_plot[i])
        if len(action) == 0:
            continue
        plt.scatter(action[:,0], action[:,1], color=("C"+str(i)), label=action)
    plt.show()

def test(x, y, policy_fn, actions=k_2actions, dbs=None):
    correct = 0
    guess_dist = [0] * len(actions)
    for i in range(len(x)):
        state = torch.tensor(x[i]).unsqueeze(0).to(device)
        rv = policy_fn(state).flatten()                      # take the move distribution given by NN

        # todo pick one way to select
        # rv = rv.multinomial(num_samples=1, replacement=True)    # sample from the move distribution
        rv = torch.argmax(rv)

        if rv == y[i]:
            correct += 1
        guess_dist[rv] += 1
    # todo fix
    if dbs is not None:
        # graphing decision boundary
        plot_db(policy_fn, actions, ranges=dbs)
    return correct / len(x), guess_dist


def run_test(data_name, actions, policy_fn, cases=100, dbs=None):
    test_X, test_Y = get_test_data(data_name)
    test_X = test_X.to(device)
    test_Y.reshape(-1, 1)
    test_Y = test_Y.to(device)

    acc, guesses = test(x=test_X[:cases], y=test_Y[:cases],
                        policy_fn=policy_fn, actions=actions, dbs=dbs)
    print("Test Accuracy:", acc)
    print("Guess Distribution:", guesses)
    return acc, guesses


In [4]:
k_cases = 2000
k_dbound_size = 100

dual_file = "test_data/test_simple.csv"

db2 = torch.linspace(2, k_dbound_size, k_dbound_size - 1).to(device)
two_dbs = [db2, db2]

def alt_policy(state, value_fn):
    x = [i(state) for i in k_2actions]
    x = [value_fn(i) for i in x]
    x = torch.tensor(x)
    return x

In [5]:
k_C = 1 / math.sqrt(2)
# k_C = 0.1
k_thread_count_limit = 20
k_core_limit = 5

# def get_train_data(fname):
#     x = np.loadtxt(fname, delimiter=",")
#     return torch.tensor([x[:,2], x[:,2:]], dtype=torch.float)

# def get_nonterm_rwd(mcts):
#     return -mcts.max_depth

# def get_terminal_rwd(terminal_depth, start):
#     return -terminal_depth + torch.linalg.norm(start)

def train_sv(epochs, actions, policy_fn, value_fn, optimizers, fname, batch_size=10):
    k_mem_width = 4    # statex,statey,action,subtree_value
    memory = Memory(k_mem_width)
    loss_fn = Loss()
    # load data into memory
    with open(fname, 'r') as f:
        for line in f.readlines():
            entry = torch.tensor(list(map(int, line.split(','))))
            tree = MCTS(actions, C=k_C, weight=1, value_fn=value_fn)
            g_nodes = tree.generate(entry[0:2].unsqueeze(0).float(), entry[2:])
            for i in range(len(entry) - 2): # go by actions (so we disregard the terminal node)
                memory.record(torch.cat((g_nodes[i].state.reshape(1,2), 
                                         entry[2+i].reshape((1,1)), 
                                         g_nodes[i].subtree_value.reshape(1,1)), dim=1))
    # train off memory
    for t in range(epochs):
        for o in optimizers:
            o.zero_grad()
        batch = memory.recall(batch_size)
        v_out = value_fn(batch[:,:2])
        p_out = policy_fn(batch[:,:2])

        # one-hot encode actions; e.g. convert 3 -> (0,0,0,1)
        action_indices = batch[:,2:-1].to(torch.int64)
        p_target = oh_encode(action_indices, len(actions))
        # p_target = torch.zeros(action_indices.shape[0],len(actions))
        # p_target.scatter_(1, action_indices,1)

        v_target = batch[:,-1]

        loss = loss_fn(v_out.view(v_target.shape), v_target, p_out.view(p_target.shape), p_target)
        loss.backward()

        for o in optimizers:
            o.step()
    
        if (t+1) % 10 == 0:
            print("Epoch:", t+1,"\t\tLoss:",loss.item())


def train_play(epochs, actions, policy_fn, value_fn, optimizers, rand_start_state_fn, comp_limit, batch_size=16):
    history = Memory(3+4)    # [stateX,stateY,value] (probs are sampled probs)
    loss_fn = Loss()
    tot_loss = 0
    for t in range(epochs):

        for o in optimizers:
            o.zero_grad()
        # Repeat the following:
        # 1) run the NN on some random initial state
        # 2) update the NN based off performance in that game
            
        # play out some games
        k_comp_limit = comp_limit(t / epochs)

        # NOTE: ProcessPoolExecutor requires placing called functions/data structs in separate imported file
        payload = [(rand_start_state_fn(), actions, value_fn, k_comp_limit, k_C, k_thread_count_limit) for i in range(batch_size)]

        with concurrent.futures.ProcessPoolExecutor(max_workers=k_core_limit) as executor:
            for rv in zip(executor.map(one_batch, iter(payload))):
                history.record(rv[0])
        # for p in payload:
        #     history.record(one_batch(p))

        # train NN on games just played
        batch = history.recall(batch_size)
        batch_states = batch[:,:2]
        batch_vsampled = batch[:,2]

        loss = loss_fn(value_fn(batch_states).view(batch_vsampled.shape), batch_vsampled)
        loss.backward()
        tot_loss += loss.item() / batch_size

        history.clear()

        for o in optimizers:
            o.step()


        if (t+1) % 10 == 0:
            print("Epoch:", t+1,"\t\tLoss:",tot_loss/10)
            tot_loss = 0
            run_test(dual_file, k_2actions, policy_fn=lambda a: alt_policy(a, value_fn), cases=k_cases, dbs=None)




In [6]:
k_state_upper_lim = 10 # arbitrary
k_comp_limit = int(k_state_upper_lim ** (2))
k_min_comps = int(k_state_upper_lim ** (1.5))

value_fn_2 = ValueNN(2).to(device)
value_optim = optim.SGD(value_fn_2.parameters(), lr=0.00005, momentum=0.9)

def gen_start_state_2a():
    limit = k_state_upper_lim
    return torch.round(torch.rand((1, 2)) * limit + 1).float()

def adaptive_comp_limit(frac_epochs):
    # linearly decrease computation limit as model becomes better over time
    rv = k_comp_limit - (k_comp_limit - k_min_comps) * frac_epochs
    return int(rv)



In [None]:
train_play(epochs=150, actions=k_2actions, policy_fn=None, value_fn=value_fn_2, 
           optimizers=[value_optim], rand_start_state_fn=gen_start_state_2a, 
           comp_limit=adaptive_comp_limit, batch_size=16)

In [None]:
save = True
if save:
    torch.save(value_fn_2.state_dict(), "trained_weights/deep_mcts_2_v_weights.pth")
    # torch.save(policy_fn_2.state_dict(), "trained_weights/deep_mcts_2_p_weights.pth")

In [None]:
k_state_upper_lim = 30 # arbitrary
k_comp_limit = int(k_state_upper_lim ** (7/2))
value_fn_4 = ValueNN(2).to(device)
policy_fn_4 = PolicyNN(2,len(k_4actions)).to(device)
value_optim_4 = optim.Adam(value_fn_4.parameters(), lr=0.00005)
policy_optim_4 = optim.Adam(policy_fn_4.parameters(), lr=0.000005)

def gen_start_state_4a():
    limit = k_state_upper_lim
    return torch.round( (torch.rand((1, 2)) - 0.5) * 2 * limit).float()


In [None]:
train_sv(250, k_4actions, policy_fn=policy_fn_4, value_fn=value_fn_4, optimizers=[value_optim_4,policy_optim_4], fname='train_data/train_mcts.csv', batch_size=16)

Epoch: 10 		Loss: 2.1033401489257812
Epoch: 20 		Loss: 1.556889295578003
Epoch: 30 		Loss: 1.3582837581634521
Epoch: 40 		Loss: 1.4088467359542847
Epoch: 50 		Loss: 1.436020016670227
Epoch: 60 		Loss: 1.379128336906433
Epoch: 70 		Loss: 1.4841262102127075
Epoch: 80 		Loss: 1.3614771366119385
Epoch: 90 		Loss: 1.3785827159881592
Epoch: 100 		Loss: 1.3600499629974365
Epoch: 110 		Loss: 1.389901041984558
Epoch: 120 		Loss: 1.383531928062439
Epoch: 130 		Loss: 1.3663393259048462
Epoch: 140 		Loss: 1.3774583339691162
Epoch: 150 		Loss: 1.2760998010635376
Epoch: 160 		Loss: 1.3536250591278076
Epoch: 170 		Loss: 1.3584808111190796
Epoch: 180 		Loss: 1.3749842643737793
Epoch: 190 		Loss: 1.3323438167572021
Epoch: 200 		Loss: 1.3520457744598389
Epoch: 210 		Loss: 1.310624599456787
Epoch: 220 		Loss: 1.3285794258117676
Epoch: 230 		Loss: 1.3432512283325195
Epoch: 240 		Loss: 1.3051739931106567
Epoch: 250 		Loss: 1.3170291185379028


In [None]:
init_policy_fp = "start_weights/deep_mcts_4_p_weights.pth"
init_value_fp = "start_weights/deep_mcts_4_v_weights.pth"
trained_policy_fp = "trained_weights/deep_mcts_4_p_weights_f.pth"
trained_value_fp = "trained_weights/deep_mcts_4_v_weights_f.pth"

In [None]:

if False:
    torch.save(value_fn_4.state_dict(), init_value_fp)
    torch.save(policy_fn_4.state_dict(), init_policy_fp)

In [None]:
train_play(epochs=100, actions=k_4actions, policy_fn=policy_fn_4, value_fn=value_fn_4, optimizers=[value_optim_4, policy_optim_4], rand_start_state_fn=gen_start_state_4a, comp_limit=k_comp_limit)

In [None]:
# Test trained weights


# value_fn_4.load_state_dict(torch.load(trained_value_fp))
# policy_fn_4.load_state_dict(torch.load(trained_policy_fp))


In [None]:
# test rollout (not good metric at the moment)
def rollout(start, actions, value_fn, comp_limit, k_C, k_thread_count_limit, nogil):
    mcts = MCTS(actions, C=k_C, weight=1, value_fn=value_fn)

    value = mcts.value_fn(start).flatten().to(device)

    start_node = Node(None, start, len(actions), value, 0)

    mcts.run(start_node, comp_limit=comp_limit, max_threads=k_thread_count_limit, nogil=nogil)
    # choose best action
    best = 0
    for i in range(len(start_node.children)):
        if start_node.children[i] is None:
            continue
        if start_node.children[i].subtree_value > start_node.children[best].subtree_value:
            best = i
    return torch.tensor(best)


In [None]:
run_test(dual_file, k_2actions, policy_fn=lambda a: alt_policy(a, value_fn_2), cases=k_cases, dbs=two_dbs)
# ~99% accuracy

In [None]:
quad_file = "test_data/four_directions_cleaner_test.csv"     # thanks, donald

k_cases = 1000

k_dbound_size = 200

db4 = torch.linspace(-k_dbound_size/2, k_dbound_size/2, k_dbound_size+1).to(device)
quad_dbs = [db4, db4]
run_test(quad_file, k_4actions, policy_fn=lambda a: torch.argmax(value_fn_4(a)), cases=k_cases)
            
# run_test(quad_file, k_4actions, policy_fn=lambda a: oh_encode(torch.tensor(determine_action(a.flatten())).view((1,1)),4), cases=k_cases, dbs=quad_dbs)
# # 8% accuracy on Donald test csv

Test Accuracy: 0.261
Guess Distribution: [1000, 0, 0, 0]


  state = torch.tensor(x[i]).unsqueeze(0).to(device)


In [None]:
for i in range(0,200, 10):
    epoch = i + 1
    value_fn_4.load_state_dict(torch.load(trained_value_fp + "_" + str(epoch)))
    acc, guesses = run_test(quad_file, k_4actions, policy_fn=lambda a: alt_policy(a, value_fn_4), cases=k_cases, dbs=quad_dbs)
    print("Epoch {epoch} accuracy: " + str(acc))
