In [None]:
#%pip install ipympl
#import scipy

from connect4 import connect
from othello import othello
from mtcs import UTCSearch, MCPlay

import numpy as np

%matplotlib widget
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['toolbar'] = 'None'

from matplotlib import gridspec
import time

import asyncio

def replay(pos, s):
    """Rebuild a position by replaying the first ``s`` recorded moves.

    Starts from ``pos.reset()`` and applies ``pos.moves[:s]`` one by one via
    ``action``. An ``s`` beyond the end of the record simply replays every
    recorded move (ordinary slice semantics).
    """
    position = pos.reset()
    history = pos.moves[:s]
    for step in history:
        position = position.action(step)
    return position


def retract(pos):
    """Take back the last two plies of the game record.

    Called from ``Human`` on the ``'b'`` key, i.e. on the human's turn, so
    the two plies are the human's previous move and the opponent's reply.
    The position is rebuilt with ``replay`` from all but the final two
    entries of ``pos.moves``.

    The evaluation trace is cut to the plies still on the board. This
    assumes ``eval`` holds one entry per ply: ``Human`` appends one per move,
    and ``drawvals`` reads ``g.eval[:g.current]``.
    """
    g = replay(pos, max(0, len(pos.moves) - 2))
    # Slice at g.current, not g.current - 2. The old stop removed two extra
    # evaluations, and for short games it went negative, which kept the
    # evaluations of the retracted moves. Slicing also gives g its own list.
    g.eval = pos.eval[:g.current]
    return g


def back(pos):
    """Step one ply backwards while keeping the complete game record.

    The earlier position is rebuilt with ``replay``. The full move list and
    evaluation trace of ``pos`` are then reattached, so ``forward`` can step
    ahead again later.
    """
    target = pos.current - 1
    earlier = replay(pos, target if target > 0 else 0)
    earlier.moves, earlier.eval = pos.moves, pos.eval
    return earlier

def forward(pos):
    """Step one ply forwards through the recorded moves, keeping the record.

    The next position is rebuilt with ``replay`` and the full move list and
    evaluation trace of ``pos`` are reattached. Past the end of the record,
    ``replay``'s slice just replays every recorded move again.
    """
    later = replay(pos, pos.current + 1)
    later.moves, later.eval = pos.moves, pos.eval
    return later



def drawprobs(probs):
    """Plot the per-candidate statistics reported by the MCTS search.

    ``probs`` is a sequence of ``(a, n, w)`` triples. These are presumably
    (action, visit count, accumulated wins) for each root candidate; TODO
    confirm against ``mtcs.UTCSearch``, which receives this function as its
    callback.

    * ``ax2``: a bar of ``w / n`` per candidate, with guide lines at 0, 0.5
      and 1.
    * ``ax3``: a red bar of ``n`` per candidate.

    A candidate with ``n == 0`` is drawn with height 0 instead of producing
    a NaN and a division-by-zero RuntimeWarning. An empty ``probs`` just
    clears both panels (``zip(*[])`` cannot be unpacked into three names).
    """
    probs = list(probs)
    ax2.clear()
    ax3.clear()
    if probs:
        _, visits, wins = zip(*probs)
        count = len(visits)
        xs = np.arange(count)
        visits_arr = np.asarray(visits, dtype=float)
        wins_arr = np.asarray(wins, dtype=float)
        rates = np.divide(wins_arr, visits_arr,
                          out=np.zeros_like(wins_arr), where=visits_arr != 0)
        ax2.bar(xs, rates)
        for level in (0.5, 0, 1):
            ax2.plot([-1, count], [level, level], color='black', lw=0.5)
        ax3.bar(xs, visits, color='red')
    ax2.set_ylim(-0.1, 1.1)
    ax2.set_axis_off()
    ax3.set_axis_off()

    refresh()

def drawvals(g):
    """Redraw the evaluation history of both sides in the ``ax4`` panel.

    ``g.eval`` holds ``(side, value)`` pairs; only the first ``g.current``
    are shown. Side ``1`` is plotted in gray and side ``-1`` in black, over
    dotted guide lines at 0, 0.5 and 1. The x-span covers at least two
    points.
    """
    ax4.clear()
    ax4.set_ylim(-0.1, 1.1)
    ax4.set_axis_off()

    history = g.eval[:g.current]
    side_plus = [val for side, val in history if side == 1]
    side_minus = [val for side, val in history if side == -1]

    width = max(len(side_plus), len(side_minus), 2)
    span = [0, width - 1]
    for level in (0, 0.5, 1):
        ax4.plot(span, [level, level], color='black', lw=0.5, ls='dotted')

    ax4.plot(side_plus, '.-', markersize=10, color='gray')
    ax4.plot(side_minus, '.-', markersize=10, color='black')

def refresh():
    """Synchronously redraw the game figure's canvas."""
    canvas = fig.canvas
    canvas.draw()


async def waitCommand():
    """Wait until a UI callback posts a command, then return it.

    Clears the global ``command``, then polls it every 10 ms while yielding
    to the event loop. ``on_press`` and ``onclick`` assign to it.
    """
    global command
    command = None
    while True:
        if command is not None:
            return command
        await asyncio.sleep(0.01)

def on_press(event):
    """Key-press handler: publish the pressed key's name as the pending command."""
    global command
    key = event.key
    command = key

def onclick(event):
    """Mouse handler: turn a click on the board into a pending ``'move'``.

    Sets the globals ``move`` (the action index for the current game ``G``)
    and then ``command = 'move'``, which ``waitCommand`` is polling for.

    Clicks that are not on the board axes ``ax`` are ignored. Outside every
    axes ``xdata`` is ``None``, and on the side panels it is expressed in
    that panel's data coordinates, so neither can be a board move.
    """
    global command, move
    if event.inaxes is not ax or event.xdata is None or event.ydata is None:
        return
    col = int(round(event.xdata))
    if G == 'connect':
        move = col
    elif G == 'othello':
        # Row-major cell index on a square board of side games['othello'][1].
        move = col + int(round(event.ydata)) * games['othello'][1]
    else:
        # Unknown game: do not publish a 'move' that would use a stale index.
        return
    # Publish the command only once ``move`` is up to date.
    command = 'move'


async def Human(g):
    """Wait for the human player's decision on position ``g`` and apply it.

    Input arrives through ``waitCommand``, i.e. the globals written by the
    ``on_press``/``onclick`` callbacks:

    * ``'move'``: play the clicked ``move`` if it is in ``g.valid``,
      otherwise keep waiting.
    * ``'b'``: ``retract`` two plies and keep waiting.
    * ``'left'``: step ``back`` one ply through the record and keep waiting.
    * ``'right'``: step ``forward`` one ply and keep waiting.
    * ``'x'``: resign; set ``g.resigned`` and return.

    If exactly one move is legal, it is played automatically after a 3 s
    pause. Returns the position after the human's move, or the resigned
    position.
    """
    global move

    def _redraw(pos):
        # Show a navigated-to position together with its evaluation trace.
        pos.draw(ax)
        drawvals(pos)
        refresh()

    if len(g.valid) == 1:
        print('forced move')
        info('your turn (forced move)')
        # Pause without blocking the event loop. time.sleep here would freeze
        # the figure widget and every other task for the whole delay.
        await asyncio.sleep(3)
        g = g.action(g.valid[0])
        # Human moves have no search value. -1 presumably acts as a
        # "no evaluation" placeholder (it lies outside drawvals' y-range);
        # -g.turn is presumably the side that just moved.
        g.eval.append((-g.turn, -1))
        info('')
        return g

    navigation = {'b': retract, 'left': back, 'right': forward}

    info('your turn')
    while True:
        comm = await waitCommand()

        if comm == 'move':
            # Reject illegal squares/columns instead of passing them to
            # g.action. g.valid uses the same move encoding, as the
            # forced-move path above shows.
            if move not in g.valid:
                info('invalid move')
                continue
            g = g.action(move)
            g.eval.append((-g.turn, -1))
            info('')
            refresh()
            return g

        if comm in navigation:
            g = navigation[comm](g)
            _redraw(g)

        if comm == 'x':
            g.resigned = True
            g.draw(ax)
            refresh()
            return g




# The interactive (ipympl) game figure. Keyboard and mouse events are routed
# to the handlers that feed ``waitCommand``.
fig = plt.figure(figsize=(6, 6))
fig.canvas.mpl_connect('key_press_event', on_press)
fig.canvas.mpl_connect('button_press_event', onclick)

# Hide all widget chrome and do not hijack page scrolling.
fig.canvas.toolbar_visible = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
fig.canvas.capture_scroll = False

def info(msg):
    """Show ``msg`` as the game figure's title line and redraw it.

    Targets ``fig`` explicitly. ``plt.suptitle`` would write to whichever
    figure is current, which is not necessarily the one redrawn here.
    """
    fig.suptitle(msg)
    fig.canvas.draw()

# Layout on a 16x4 grid: the board fills rows 4-15, and rows 1-3 hold three
# small panels. Row 0 is left free, presumably for the suptitle.
gs = gridspec.GridSpec(16, 4)
ax  = plt.subplot(gs[4:16, 0:4])   # game board
ax2 = plt.subplot(gs[1:4, 0:1])    # w/n per candidate (drawprobs)
ax3 = plt.subplot(gs[1:4, 1:2])    # visit counts per candidate (drawprobs)
ax4 = plt.subplot(gs[1:4, 2:4])    # evaluation history (drawvals)
for _panel in (ax4, ax2, ax3):
    _panel.set_axis_off()
del _panel
plt.tight_layout(pad=0, h_pad=0, w_pad=0)

# Display names for the side values used in g.turn / g.winner.
color = {1: 'white', -1: 'black', 0: 'nobody'}

async def play():
    """Alternate the two players until the game ends or someone resigns.

    Works on the module-level position ``g`` and the player slots
    (``current_player``/``next_player`` plus their ``*_async`` flags). After
    every move the slots are swapped and the board and evaluation trace are
    redrawn; then the end-of-game conditions are checked.
    """
    global g, current_player, next_player, curr_async, next_async
    while True:
        outcome = current_player(g)
        g = (await outcome) if curr_async else outcome

        current_player, next_player = next_player, current_player
        curr_async, next_async = next_async, curr_async

        g.draw(ax)
        drawvals(g)
        refresh()

        if g.terminal:
            info(f'{color[g.winner]} ({engine[g.winner]}) wins!')
            return
        if g.resigned:
            info(f'{color[g.turn]} ({engine[g.turn]}) resigned')
            return

# Game constructor followed by its positional arguments.
games = {
    'othello': (othello, 8),
    'connect': (connect, 7, 6, 4),
}

# Player factories: each takes the position to move in and yields the
# resulting position (the 'human' one returns a coroutine to await).
# NOTE(review): the numeric arguments are search budgets/parameters of the
# mtcs module; their exact meaning is not visible here.
players = {
    'mcts': lambda g: UTCSearch(g, 10000, 10, drawprobs),
    'pure': lambda g: MCPlay(g, 100),
    'human': lambda g: Human(g),
}

# Whether play() must await the value returned by each player.
isasync = {
    'mcts': False,
    'pure': False,
    'human': True,
}

In [None]:
# Who plays white (+1) and black (-1), and which game.
W = 'mcts'
B = 'human'
G = 'othello'
white, black = players[W], players[B]
engine = {1: W, -1: B, 0: 'nobody'}

# White moves first.
current_player, next_player = white, black
curr_async, next_async = isasync[W], isasync[B]

# Build the initial position from the game's constructor and its arguments.
g = games[G][0](*games[G][1:])

g.draw(ax)
info('MTCS')

In [None]:
# Run the game loop as a background task on the kernel's event loop, so the
# figure widget keeps receiving clicks and keys while ``play`` awaits input.
loop = asyncio.get_event_loop()

# Re-running this cell must not leave two game loops fighting over the same
# globals, so cancel a previous game that is still running.
if 'play_task' in globals() and not play_task.done():
    play_task.cancel()

# Keep a strong reference: the event loop only holds weak references to its
# tasks, so a task whose handle is discarded can be garbage-collected mid-game.
play_task = loop.create_task(play())


def _report_play_failure(task):
    """Show an exception that ended ``play`` in the figure title.

    Without this, an error inside the background task never reaches the
    notebook output.
    """
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        info(f'game aborted: {exc!r}')


play_task.add_done_callback(_report_play_failure)