In [1]:
import logging
#logging.basicConfig(level=logging.DEBUG)

import joblib
from models.feature_extractor import FeatureExtractor
from models.win_rate_models import WinRateSciKitModel
from models.mcts_draft import AllPickDraft, CaptainsModeDraft
from mcts import mcts

In [2]:
feature_extractor = joblib.load('input/feature_extractor.joblib')
linear_svc = joblib.load('input/linear_svc.joblib')

In [3]:
win_rate_model = WinRateSciKitModel(feature_extractor, linear_svc)

In [4]:
x = win_rate_model.prepare_input_vector(([6,7,8,9,10],[1,2,3,4,5]))

In [5]:
initial_state = AllPickDraft(win_rate_model)

In [6]:
print(initial_state.bans, initial_state.radiant_pick, initial_state.dire_pick)
print(initial_state.isTerminal())
pa = initial_state.getPossibleActions()
print(pa)
initial_state = initial_state.takeAction(pa[1])
print(initial_state.bans, initial_state.radiant_pick, initial_state.dire_pick)

frozenset() set() set()
False
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 119, 120, 121, 126, 128, 129]
frozenset() {2} set()


In [7]:
def getOrderedMoves(mcts):
    bestNodes = []
    for action, child in mcts.root.children.items():
        nodeValue = child.totalReward / child.numVisits
        # bestNodes.append((k, nodeValue, child.numVisits))
        bestNodes.append((action, child.numVisits, nodeValue))
    bestNodes.sort(key = lambda x: -x[1])
    return bestNodes

In [8]:
initial_state = AllPickDraft(win_rate_model)

In [9]:
mcts_object = mcts(timeLimit=20_000)
mcts_object.search(initial_state)

42

In [10]:
mcts_object.iterations

16397

In [11]:
getOrderedMoves(mcts_object)

[(42, 1485, 0.8175084175084175),
 (38, 1220, 0.809016393442623),
 (84, 946, 0.7970401691331924),
 (80, 839, 0.7902264600715138),
 (67, 800, 0.7875),
 (92, 790, 0.7873417721518987),
 (108, 778, 0.7866323907455013),
 (70, 765, 0.7856209150326797),
 (32, 743, 0.7833109017496636),
 (36, 671, 0.7779433681073026),
 (77, 338, 0.727810650887574),
 (48, 282, 0.7127659574468085),
 (37, 251, 0.701195219123506),
 (51, 251, 0.701195219123506),
 (4, 206, 0.6796116504854369),
 (61, 178, 0.6629213483146067),
 (1, 174, 0.6609195402298851),
 (22, 166, 0.6566265060240963),
 (110, 161, 0.6521739130434783),
 (60, 155, 0.6516129032258065),
 (3, 134, 0.6268656716417911),
 (5, 114, 0.6052631578947368),
 (73, 114, 0.6052631578947368),
 (82, 109, 0.5963302752293578),
 (29, 102, 0.5882352941176471),
 (64, 102, 0.5882352941176471),
 (62, 100, 0.58),
 (6, 98, 0.5816326530612245),
 (49, 97, 0.5773195876288659),
 (56, 94, 0.574468085106383),
 (55, 91, 0.5714285714285714),
 (95, 90, 0.5666666666666667),
 (9, 87, 0.56