In [381]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass
from typing import List, Tuple, Optional, Literal, Dict
from random import choice
from copy import deepcopy
import math
from time import time

from check_submission import check_submission
from game_mechanics import (
    State,
    TronEnv,
    choose_move_randomly,
    choose_move_square,
    human_player,
    is_terminal,
    play_tron,
    reward_function,
    rules_rollout,
    transition_function,
    ARENA_WIDTH
)

In [334]:
state.opponent.alive

True

In [358]:
def neighbouring(a, b) -> bool:
    '''
    returns whether two squares are neighbours.
    '''
    return abs(a[0] - b[0]) + abs(a[1] - b[1]) == 1 # and (a < b)

def rollout(state):
    if state.player.alive and not state.opponent.alive:
        return 1
    if not state.player.alive and state.opponent.alive:
        return -1
    if not state.player.alive and not state.opponent.alive:
        return 0
    tiles = [(i,j) for i in range(ARENA_WIDTH) for j in range(ARENA_WIDTH)]
    neighbours = {a: [b for b in tiles if neighbouring(a, b)] for a in tiles}
    
    # create initial snake zones: just around the heads
    heads = [bike.positions[0] for bike in state.bikes]
    snake_zones = {tile: head for head in heads for tile in neighbours[head]}
    
    ## remove initial snake bodies from the neighbour graph
    for bike in state.bikes:
        for tile in bike.positions:
            neighbours_tile = neighbours[tile]
            for neighbouring_tile in neighbours_tile:
                neighbours[neighbouring_tile].remove(tile)
            neighbours.pop(tile)
    
    ## expand the snake zones
    def expand_snake_zones(snake_zones):
        new_snake_zones = deepcopy(snake_zones)
        for tile in neighbours:
            possible_assignments = [snake_zones.get(n, None) for n in neighbours.get(tile, [])]
            possible_assignments = [pa for pa in possible_assignments if pa is not None]
            if len(set(possible_assignments)) == 1:
                new_snake_zones[tile] = possible_assignments[0]
        return new_snake_zones

    new_snake_zones = 
    for _ in range(10):
        snake_zones = expand_snake_zones(snake_zones)
        
    snake_zones = {head: [tile for tile in snake_zones if snake_zones[tile] == head] for head in heads }
    areas = [len(tiles)/ARENA_WIDTH**2 for tiles in snake_zones.values()]
    return (areas[0] - areas[1]) / 4

In [363]:
# class Node:
#     def __init__(self, state: State, last_action):
#         self.state = state
#         self.last_action = last_action
# #         self.is_terminal = is_terminal(last_action, state) if last_action is not None else False
#         # No guarantee that these NODES exist in the MCTS TREE!
#         self.child_states = self._get_possible_children()
#         self.visit_count = 0
#         self.total_reward = 0
            
#     @property
#     def is_terminal(self) -> bool:
#         return len(self.state.bikes) != 2

#     def _get_possible_children(self) -> Dict[int, State]:
#         """Gets the possible children of this node."""
#         if self.is_terminal:
#             return {}
#         children = {}
#         for action in [1,2,3]:
#             state = transition_function(self.state, action)
#             children[action] = state
#         return children
    
class UCTNode():
    def __init__(self, game_state, parent=None, prior=0, value_estimate = 0):
        self.game_state = game_state
        self.is_expanded = False
        self.parent = parent  # Optional[UCTNode]
        self.children = {}  # Dict[move, UCTNode]
        self.prior = prior  # float
        self.total_value = 0  # float
        self.number_visits = 0  # int
        self.value_estimate = value_estimate # float
        
    @property
    def is_terminal(self) -> bool:
        return len(self.game_state.bikes) != 2
    
    @property
    def is_player_to_move(self) -> bool:
        return self.game_state.player_move is None
    
    def _get_possible_children(self) -> Dict[int, State]:
        """Gets the possible children of this node."""
        if self.is_terminal:
            return {}
        children = {}
        for action in [1,2,3]:
            state = transition_function(self.game_state, action)
            children[action] = state
        return children
        
    def Q(self) -> float:
        return self.total_value / (1 + self.number_visits)

    def U(self) -> float:
        return (math.sqrt(self.parent.number_visits)
            * self.value_estimate / (1 + self.number_visits))

    def best_child(self):# -> UCTNode:
        return max(self.children.values(),
                   key=lambda node: node.Q() + node.U())

    def select_leaf(self):# -> UCTNode:
        current = self
        while current.is_expanded:
            current = current.best_child()
        return current
    
    def expand(self):
        self.is_expanded = True
        for move, state in self._get_possible_children().items():
            self.add_child(move, state)
            
    def add_child(self, move, state):
        value_estimate = -1*rollout(state)*self.is_player_to_move
        self.children[move] = UCTNode(
            state, parent=self, value_estimate=value_estimate)
        
    def backup(self, value_estimate):
        current = self
        while current.parent is not None:
            current.number_visits += 1
            current.total_value += (value_estimate *
                self.is_player_to_move)
            current = current.parent

In [364]:
env = TronEnv(choose_move_randomly)
state, reward, done, info = env.reset()

In [365]:
root = UCTNode(state)

In [391]:
time() - start

7.740873098373413

In [393]:
max_time = 0.45
start = time()
while time() - start < max_time:
    leaf = root.select_leaf()
    leaf.expand()
    value_estimate = rollout(leaf.game_state)
    leaf.backup(value_estimate)
print(max(root.children.items(),
        key=lambda item: item[1].number_visits))

(1, <__main__.UCTNode object at 0x7f77c65d66a0>)


In [372]:
root.children

{1: <__main__.UCTNode at 0x7f77c65d66a0>,
 2: <__main__.UCTNode at 0x7f77c59eba30>,
 3: <__main__.UCTNode at 0x7f77c66f2670>}

In [376]:
[c.value_estimate for c in root.children.values()]

[0.10555555555555557, 0.10555555555555557, 0.10555555555555557]

In [None]:
def UCT_search(game_state, num_reads):
    root = UCTNode(game_state)
    for _ in range(num_reads):
        leaf = root.select_leaf()
        leaf.expand()
        leaf.backup(value_estimate)
    return max(root.children.items(),
               key=lambda item: item[1].number_visits)

In [317]:
root = Node(state, None)

In [399]:
env = TronEnv(choose_move_randomly)
state, reward, done, info = env.reset()

In [285]:
possible_moves = [1,2,3]

In [None]:
new_state.ter

In [302]:
new_state = transition_function(state, 1)

In [403]:
new_state = transition_function(state, 1)
new_state = transition_function(new_state, 2)

In [398]:
state, reward, done, info = env.step(1)
if done: print(f'finito, {reward}')

In [291]:
areas(state)

(3, 11)
(3, 12)
(7, 3)
(7, 2)


-0.017777777777777767

In [281]:
new_state.bikes[1].positions

[(3, 11), (2, 11)]

In [282]:
state.bikes[1].positions

[(3, 11), (2, 11)]

In [292]:
areas(new_state)

(3, 10)
(3, 11)
(3, 12)
(6, 3)
(7, 3)
(7, 2)


0.0

In [406]:
tiles = [(i,j) for i in range(ARENA_WIDTH) for j in range(ARENA_WIDTH)]

def neighbouring(a, b) -> bool:
    '''
    returns whether two squares are neighbours.
    '''
    return abs(a[0] - b[0]) + abs(a[1] - b[1]) == 1 # and (a < b)

neighbours = {a: [b for b in tiles if neighbouring(a, b)] for a in tiles}

env = TronEnv(choose_move_randomly)
state, reward, done, info = env.reset()

# create initial snake zones: just around the heads
heads = [bike.positions[0] for bike in state.bikes]
snake_zones = {tile: head for head in heads for tile in neighbours[head]}

start = time()
## remove initial snake bodies from the neighbour graph
for bike in state.bikes:
    for tile in bike.positions:
        print(tile)
        neighbours_tile = neighbours[tile]
        for neighbouring_tile in neighbours_tile:
            neighbours[neighbouring_tile].remove(tile)
        neighbours.pop(tile)
print(f'Removed snake bodies in {time() - start} seconds')
        
        
## expand the snake zones
def expand_snake_zones(snake_zones):
    new_snake_zones = deepcopy(snake_zones)
    for tile in neighbours:
        possible_assignments = [snake_zones.get(n, None) for n in neighbours.get(tile, [])]
        possible_assignments = [pa for pa in possible_assignments if pa is not None]
        if len(set(possible_assignments)) == 1:
            new_snake_zones[tile] = possible_assignments[0]
    return new_snake_zones

start = time()
for _ in range(10):
#     img = np.zeros((15, 15))
#     for tile in snake_zones:
#         if snake_zones[tile] == heads[0]:
#             img[tile] = 0.5
#         else:
#             img[tile] = -0.5
#     plt.imshow(img)
#     plt.show()
    
    snake_zones = expand_snake_zones(snake_zones)
print(f'expanded snake zones in {time() - start} seconds')
    
# reverse from {tile: head} to {head: [tiles]}
snake_zones = {head: [tile for tile in snake_zones if snake_zones[tile] == head] for head in heads }
areas = {head: len(tiles) for head, tiles in snake_zones.items()}

(11, 3)
(10, 3)
(7, 11)
(8, 11)
Removed snake bodies in 0.0004067420959472656 seconds
expanded snake zones in 0.0076904296875 seconds


In [189]:
snake_zones

{(6, 11): (7, 11),
 (7, 10): (7, 11),
 (7, 12): (7, 11),
 (8, 11): (7, 11),
 (2, 3): (3, 3),
 (3, 2): (3, 3),
 (3, 4): (3, 3),
 (4, 3): (3, 3),
 (1, 3): (3, 3),
 (2, 2): (3, 3),
 (2, 4): (3, 3),
 (3, 5): (3, 3),
 (4, 2): (3, 3),
 (4, 4): (3, 3),
 (5, 3): (3, 3),
 (6, 10): (7, 11),
 (6, 12): (7, 11),
 (7, 9): (7, 11),
 (7, 13): (7, 11),
 (8, 10): (7, 11),
 (8, 12): (7, 11),
 (9, 11): (7, 11),
 (0, 3): (3, 3),
 (1, 2): (3, 3),
 (1, 4): (3, 3),
 (2, 1): (3, 3),
 (2, 5): (3, 3),
 (3, 6): (3, 3),
 (4, 1): (3, 3),
 (4, 5): (3, 3),
 (5, 2): (3, 3),
 (5, 4): (3, 3),
 (5, 10): (7, 11),
 (5, 12): (7, 11),
 (6, 3): (3, 3),
 (6, 9): (7, 11),
 (6, 13): (7, 11),
 (7, 8): (7, 11),
 (7, 14): (7, 11),
 (8, 9): (7, 11),
 (8, 13): (7, 11),
 (9, 10): (7, 11),
 (9, 12): (7, 11),
 (10, 11): (7, 11),
 (0, 2): (3, 3),
 (0, 4): (3, 3),
 (1, 1): (3, 3),
 (1, 5): (3, 3),
 (2, 0): (3, 3),
 (2, 6): (3, 3),
 (3, 1): (3, 3),
 (3, 7): (3, 3),
 (4, 0): (3, 3),
 (4, 6): (3, 3),
 (4, 10): (7, 11),
 (4, 12): (7, 11),
 (5

In [150]:
areas

{(7, 7): 64, (3, 7): 44}

In [151]:
len(neighbours)

221

In [119]:
def expand_snake_zones(snake_zones):
    new_snake_zones = deepcopy(snake_zones)
    occupied_tiles = snake_zones.values()[0] | snake_zones.values()[1]
    
    for head in snake_zones:
#         opponent_zone = [v for (k, v) in snake_zones if k != head]
#         print(opponent_zone)
        zone = snake_zones[head]
        for tile in zone:
            new_tiles = {new_tile for new_tile in neighbours[tile] if tile not in occupied_tiles}
            new_snake_zones[head] |= new_tiles
      
    return new_snake_zones

In [116]:
snake_zones

{(11, 3): {(10, 3), (11, 2), (11, 4), (12, 3)},
 (7, 3): {(6, 3), (7, 2), (7, 4), (8, 3)}}

In [120]:
expand_snake_zones(snake_zones)

TypeError: 'dict_values' object is not subscriptable

In [31]:
state, reward, done, info = env.step(1)
snake_zones = get_immeadiate_head_neighbours(state)

## remove new heads of snakes from neighbourhood graph after step
for bike in state.bikes:
    new_head = bike.positions[0]
    neighbours_tile = neighbours[new_head]
    for neighbouring_tile in neighbours_tile:
        neighbours[neighbouring_tile].remove(new_head)
    neighbours.pop(new_head)

In [24]:
for bike in state.bikes:
    print(bike.positions)

[(12, 7), (11, 7), (10, 7)]
[(2, 7), (3, 7), (4, 7)]


In [32]:
snake_zones

{(2, 3): {(1, 3), (2, 2), (2, 4)}, (7, 2): {(6, 2), (7, 1), (8, 2)}}

In [None]:
## assign tiles by distance to every snake:
heads = [bike.positions[0] for bike in state.bikes]
unassigned_tiles = neighbours.keys()
while len(unassigned_tiles) > 0:
    for tile in unassigned_tiles:
        neighbouring_tiles = neighbours[tile]
        ## if neighbouring both zones of control:

{(12, 7): [(12, 6), (12, 8), (13, 7)], (2, 7): [(1, 7), (2, 6), (2, 8)]}

In [28]:
a = {1,2,3}

In [97]:
snake_zones = expand_snake_zones(snake_zones)

[3, 2]
[3, 2]


In [105]:
len(snake_zones[(2,3)])

219

In [None]:


while sum([len(_) for _ in next_step_dfs.values()]) != 0:
    next_step_dfs = {head: set([neighbours[tile] for tile in next_step_dfs[head]]) for head in heads}
    next_step_dfs = {}

In [None]:
def 

In [13]:
env.__dict__

{'opponent_choose_move': <function game_mechanics.choose_move_randomly(state: game_mechanics.State) -> int>,
 '_render': False,
 'verbose': False,
 'game_speed_multiplier': 1.0,
 'starting_positions': [(7, 7),
  (7, 11),
  (3, 3),
  (7, 3),
  (11, 11),
  (11, 7),
  (3, 7),
  (11, 3),
  (3, 11)],
 'score': 0,
 'num_steps_taken': 0,
 'state': State(player=Bike player, opponent=Bike opponent, player_move=None),
 'dead_bikes': [],
 'color_lookup': {'player': (237, 0, 3), 'opponent': (53, 0, 255)}}