In [3]:
# Load the custom "connect_four" bitstream and take the HDMI input/output
# endpoints exposed under its `video` attribute.
# NOTE(review): PIXEL_RGBA is presumably provided by the star import from
# pynq.lib.video -- confirm before narrowing that import.
from pynq import Overlay, Xlnk
from pynq.overlays.base import BaseOverlay
from pynq.lib.video import *

allocator = Xlnk()
ol = Overlay("connect_four.bit")


hdmi_in = ol.video.hdmi_in
hdmi_out = ol.video.hdmi_out

# Both directions use the RGBA pixel format; the output is configured with the
# mode reported by the input.
hdmi_in.configure(PIXEL_RGBA)
hdmi_out.configure(hdmi_in.mode, PIXEL_RGBA)

<contextlib._GeneratorContextManager at 0x30ec2b70>

In [4]:
# Start the HDMI input and then the HDMI output configured in the previous cell.
hdmi_in.start()
hdmi_out.start()

<contextlib._GeneratorContextManager at 0x2ffdfab0>

In [5]:
# Tie the input to the output (per the PYNQ video API this mirrors the input on
# the output channel) so the camera image is visible before detection starts.
hdmi_in.tie(hdmi_out)

In [6]:
# Standard-library helpers used by the minimax agent
from random import randint
import time # used to measure how long searches and frames take

# Memory-mapped handle used by the detection functions below to write
# parameters to, start, poll and read results from the FPGA block.
# MMIO arguments are (base address, length in bytes).
from pynq import MMIO
corner_detection_reference_axilite = MMIO(0x83C30000, 0x10000)

In [7]:
class States_cache:
    """Memo table that maps a hashable board representation to its evaluation."""

    def __init__(self):
        # key: hashable grid (callers pass a tuple), value: evaluation score
        self.cached_states = dict()

    def cache_state(self, state_grid, evaluation):
        """Remember ``evaluation`` for ``state_grid``, replacing any older entry."""
        self.cached_states.update({state_grid: evaluation})

    def already_cached(self, state_grid):
        """Return True when an evaluation is stored for ``state_grid``."""
        return state_grid in self.cached_states.keys()

    def get_cached_value(self, state_grid):
        """Return the stored evaluation; raises KeyError for an unknown grid."""
        stored = self.cached_states
        return stored[state_grid]

    def clear(self):
        """Forget every stored evaluation (the dictionary object itself is kept)."""
        self.cached_states.clear()


class Game_state:
    """Connect Four board: 6 rows x 7 columns kept in a flat, row-major list.

    A cell holds 0 when empty and a non-zero player id otherwise (the printing
    helpers render 1 as 'X' and 2 as 'O').  Row 0 is the row printed first, so
    coins are stacked from row 5 towards row 0.  ``heights[col]`` is the number
    of rows between the bottom of the board and the top-most coin of ``col``.
    """

    def __init__(self, grid = None):
        """Create an empty board, or a board initialised from ``grid``.

        grid: optional sequence of 42 cell values in row-major order.  It is
              copied, so inserting/removing coins on this object never changes
              the caller's sequence.
        """
        if grid is None:
            self.grid    = [0 for _ in range(42)]
            self.heights = [0 for _ in range(7)]
        else:
            # Copy: the search mutates the board in place and the caller keeps
            # using its own list (e.g. to compare successive detections).
            self.grid = list(grid)
            self.heights = [0 for _ in range(7)]
            for col in range(7):
                # scan from the first row downwards; the first coin met fixes the height
                for row in range(6):
                    if self.grid[self.ij_to_idx(row, col)] != 0:
                        self.heights[col] = (6 - row)
                        break

    def col_to_idx(self, col):
        """Flat index of the top-most coin of ``col`` (needs heights[col] >= 1)."""
        return (7 * (6 - self.heights[col])) + col

    def ij_to_idx(self, i, j):
        """Flat index of row ``i``, column ``j``."""
        return (i * 7) + j

    def idx_to_ij(self, idx):
        """(row, column) tuple for the flat index ``idx``."""
        return (idx // 7, idx % 7)

    def can_insert_coin(self, col):
        """True while ``col`` holds fewer than 6 coins."""
        return (self.heights[col] <= 5)

    def insert_coin(self, pos, color):
        """Drop a coin of ``color`` into column ``pos`` (the column must not be full)."""
        # The height is bumped first because col_to_idx addresses the top-most coin.
        self.heights[pos] += 1
        self.grid[self.col_to_idx(pos)] = color

    def remove_coin(self, pos):
        """Take the top-most coin out of column ``pos``."""
        # Clear the cell while the height still points at it, then lower the height.
        self.grid[self.col_to_idx(pos)] = 0
        self.heights[pos] -= 1

    def check_connect4(self, f1, f2, i, j):
        """True when the 3 cells after (i, j) along (f1, f2) equal the cell at (i, j).

        No bounds check is done here; the caller guarantees the 4th cell exists.
        """
        color = self.grid[self.ij_to_idx(i, j)]
        for k in range(1, 4): # [1,2,3]
            if self.grid[self.ij_to_idx(i + k*f1, j + k*f2)] != color:
                return False
        return True

    def is_win(self):
        """True when any four equal, non-empty cells are aligned on the board."""
        # horizontal, vertical, and the two diagonals, always scanned towards higher columns
        directions = ((0, 1), (1, 0), (1, 1), (-1, 1))
        for i in range(6):
            for j in range(7):
                if self.grid[self.ij_to_idx(i, j)] == 0:
                    continue
                for f1, f2 in directions:
                    # scan only when the 4th cell of the line is still on the board
                    if 0 <= i + 3*f1 < 6 and j + 3*f2 < 7 and self.check_connect4(f1, f2, i, j):
                        return True
        return False

    def count_connected(self, f1, f2, i, j):
        """Count consecutive cells equal to (i, j) along (f1, f2), at most 3, start excluded."""
        tmp = 0
        for k in range(1, 4): # [1,2,3]
            if i + k*f1 >= 0 and i + k*f1 < 6 and j + k*f2 >= 0 and j + k*f2 < 7 and \
               self.grid[self.ij_to_idx( i + k*f1, j + k*f2 )] == self.grid[self.ij_to_idx(i,j)]:
                tmp += 1
            else:
                break
        return tmp

    def is_win_fast(self, col):
        """True when the top-most coin of ``col`` is part of a line of four or more.

        Only lines through that coin are examined, so this is meant to be called
        right after insert_coin(col, ...).
        """
        i, j = self.idx_to_ij(self.col_to_idx(col))
        # horizontal
        if(self.count_connected(0, 1, i, j) + self.count_connected(0, -1, i, j) + 1 >= 4):
            return True
        # vertical
        if(self.count_connected(1, 0, i, j) + self.count_connected(-1, 0, i, j) + 1 >= 4):
            return True
        # the two diagonals
        if(self.count_connected(1, 1, i, j) + self.count_connected(-1, -1, i, j) + 1 >= 4):
            return True
        if(self.count_connected(-1, 1, i, j) + self.count_connected(1, -1, i, j) + 1 >= 4):
            return True
        return False

    def count_connect3(self, f1, f2, i, j):
        """Score the 3-cell line starting at (i, j) along (f1, f2).

        Returns +1 when all three cells hold player 1, -1 when all three hold the
        same other player, 0 otherwise.  The caller guarantees the line is on the board.
        """
        color = self.grid[self.ij_to_idx(i, j)]
        tmp = 1
        for k in range(1, 3): # [1, 2]
            cell = self.grid[self.ij_to_idx(i + k*f1, j + k*f2)]
            if cell == color:
                tmp += 1
            elif cell != 0:
                # a coin of another colour breaks the line
                tmp = 0
                break

        if tmp == 3:
            if color == 1: # player 1
                return 1
            else:          # player 2
                return -1
        return 0

    def evaluate(self):
        """Heuristic score: (# of player-1 three-in-a-rows) - (# of other-player ones)."""
        tot_count = 0
        for i in range(6):
            for j in range(7):
                if self.grid[self.ij_to_idx(i,j)] != 0:
                    # horizontal
                    if j + 2 < 7:
                        tot_count += self.count_connect3(0, 1, i, j)
                    # vertical
                    if i + 2 < 6:
                        tot_count += self.count_connect3(1, 0, i, j)
                    # diagonal towards higher rows
                    if i + 2 < 6 and j + 2 < 7:
                        tot_count += self.count_connect3(1, 1, i, j)
                    # diagonal towards lower rows
                    if i - 2 >= 0 and j + 2 < 7:
                        tot_count += self.count_connect3(-1, 1, i, j)
        return tot_count

    def print(self):
        """Print this board followed by the column numbers."""
        self.print_from_grid(self.grid)

    def print_from_grid(self, grid):
        """Print an arbitrary 42-cell ``grid`` followed by the column numbers.

        0, 1 and 2 are shown as '-', 'X' and 'O'; any other value is shown as '?'
        instead of aborting the printout.
        """
        char = ['-', 'X', 'O']
        idx = 0
        for _ in range(6):
            for __ in range(7):
                value = grid[idx]
                print(char[value] if 0 <= value < len(char) else '?', end=' ')
                idx += 1
            print()
        print("_____________")
        for i in range(7):
            print(i, end=" ")
        print("\n")


class Minimax_agent:
    """Depth-limited minimax player with alpha-beta cut-offs and a per-search cache.

    The agent maximises for PLAYER1 and assumes PLAYER2 minimises.  The board
    object passed in is modified in place during the search (insert_coin /
    remove_coin) and is restored before each method returns normally.
    """

    def __init__(self, max_depths, default_depth):
        """
        max_depths:    dict mapping a turn number to the search depth for that turn.
        default_depth: depth used for turns that are not in ``max_depths``.
        """
        self.INF = 99999999
        self.WIN = 10000
        self.CONNECT3 = 50
        self.DISCOUNT = 1
        self.PLAYER1 = 1
        self.PLAYER2 = 2
        # columns are tried from the centre outwards
        self.columns = [3, 4, 2, 5, 1, 6, 0]
        self.turn = 0
        self.max_depths = max_depths
        self.default_depth = default_depth
        # Depth limit read by max_node/min_node.  get_move overwrites it per turn;
        # it is initialised here so the node functions can also be called directly.
        self.max_depth = default_depth
        self.turn_times = []
        self.cache = States_cache()

    def print_turn_times(self):
        """Print the duration of every recorded search with a depth next to it."""
        # NOTE(review): the depth shown is get_max_depth(i + 1); this only matches
        # the depth really used if the i-th timing belongs to turn i + 1 -- confirm.
        for i, elapsed in enumerate(self.turn_times):
            print("%d. time: %.4f\t\t | depth: %d" % (i+1, elapsed, self.get_max_depth(i+1)))

    def discount(self, val, depth):
        """Move ``val`` towards zero by DISCOUNT * depth; zero stays zero."""
        if val > 0:
            return -self.DISCOUNT * depth + val
        elif val < 0:
            return self.DISCOUNT * depth + val
        else:
            return val

    def get_max_depth(self, turn, print_val=False):
        """Return the search depth configured for ``turn`` (default_depth if absent)."""
        max_depth = self.max_depths.get(turn, self.default_depth)
        if print_val:
            print("[with depth ", str(max_depth) + "]")
        return max_depth

    def _winning_column(self, state, player):
        """First column (in self.columns order) where ``player`` wins at once, else -1."""
        for col in self.columns:
            if not state.can_insert_coin(col):
                continue
            state.insert_coin(col, player)
            wins = state.is_win_fast(col)
            state.remove_coin(col)
            if wins:
                return col
        return -1

    def find_trivial_move(self, state):
        """Return a column that wins immediately, else one that blocks an
        immediate PLAYER2 win, else -1."""
        # first check for win
        col = self._winning_column(state, self.PLAYER1)
        if col != -1:
            return col
        # then check whether the opponent would win in one move; -1 if nothing found
        return self._winning_column(state, self.PLAYER2)

    def max_node(self, state, depth, alpha, beta):
        """Value of ``state`` with PLAYER1 to move, ``depth`` plies below the root."""
        # NOTE(review): cached values are reused regardless of the alpha/beta window
        # they were computed with, and values returned by a cut-off are stored like
        # exact ones -- confirm this approximation is intended.
        state_grid = tuple(state.grid)
        if self.cache.already_cached(state_grid):
            return self.cache.get_cached_value(state_grid)

        if(depth == self.max_depth):
            return self.discount(self.CONNECT3 * state.evaluate(), depth)

        max_val = -self.INF
        for col in self.columns:
            if not state.can_insert_coin(col):
                continue

            state.insert_coin(col, self.PLAYER1) # insert coin and go down with recursion
            move_val = 0
            # check if this move leads to a win
            if state.is_win_fast(col):
                # NOTE(review): discounted with depth + 1 here but with depth in
                # min_node -- confirm the asymmetry is intended.
                move_val = self.discount(self.WIN, depth + 1)
            else: # if not, we have to calculate everything normally
                move_val = self.min_node(state, depth + 1, alpha, beta)
            max_val = max(max_val, move_val)
            # the key is the grid *after* the move, i.e. the child position
            self.cache.cache_state(tuple(state.grid), move_val)

            state.remove_coin(col) # remove coin to get the previous state

            if max_val > beta:
                return max_val
            alpha = max(alpha, max_val)

        if(max_val == -self.INF): # board full: no move was possible, count it as a tie
            return 0
        return self.discount(max_val, depth)

    def min_node(self, state, depth, alpha, beta):
        """Value of ``state`` with PLAYER2 to move, ``depth`` plies below the root."""
        state_grid = tuple(state.grid)
        if self.cache.already_cached(state_grid):
            return self.cache.get_cached_value(state_grid)

        if(depth == self.max_depth):
            return self.discount(self.CONNECT3 * state.evaluate(), depth)

        min_val = self.INF
        for col in self.columns:
            if not state.can_insert_coin(col):
                continue

            state.insert_coin(col, self.PLAYER2) # insert coin and go on with recursion

            move_val = 0
            # check if this move leads to a win
            if state.is_win_fast(col):
                move_val = self.discount(-self.WIN, depth)
            else: # if not, we have to calculate everything normally
                move_val = self.max_node(state, depth + 1, alpha, beta)
            min_val = min(move_val, min_val)
            # the key is the grid *after* the move, i.e. the child position
            self.cache.cache_state(tuple(state.grid), move_val)

            state.remove_coin(col) # remove coin to get the previous state

            if min_val < alpha:
                return min_val
            beta = min(beta, min_val)

        if(min_val == self.INF): # board full: no move was possible, count it as a tie
            return 0
        return self.discount(min_val, depth)

    def get_move(self, state):
        """Choose the column PLAYER1 should play on ``state``.

        Increments ``self.turn``, tries find_trivial_move first and otherwise runs
        the minimax search at the depth configured for the turn, picking randomly
        among equally valued columns.  The search time is appended to turn_times.

        Raises ValueError when no column can accept a coin.
        """
        self.turn += 1

        # try trivial moves
        move = self.find_trivial_move(state)
        if move != -1:
            print("trivial move: ", move)
            return move

        # minimax
        print("getting values for moves: ", end="")

        # get max depth
        self.max_depth = self.get_max_depth(self.turn, True)

        start = time.time()

        max_val = -self.INF
        alpha = -self.INF
        best_moves = []
        try:
            for col in self.columns:
                if not state.can_insert_coin(col):
                    continue
                tmp = 0
                state.insert_coin(col, self.PLAYER1)
                if state.is_win_fast(col):
                    tmp = self.WIN
                else:
                    tmp = self.min_node(state, 1, alpha, self.INF)
                state.remove_coin(col)

                print(tmp, end=" ", flush = True)

                if tmp > max_val:
                    best_moves.clear()
                    max_val = tmp
                    alpha = tmp
                    best_moves.append(col)
                elif tmp == max_val:
                    best_moves.append(col)

            end = time.time()
            print("\nTime elapsed:", end - start)
            self.turn_times.append(end - start)
        finally:
            # Cached values belong to this search only; drop them even when the
            # search is interrupted so they cannot leak into the next one.
            self.cache.clear()

        if not best_moves:
            raise ValueError("no legal move: every column is full")
        return best_moves[randint(0, len(best_moves)-1)]

In [32]:
# Module-level state assigned by setup() and read by the detection functions below.
centers = []
corners = []

def detect_corners():
    """Run the FPGA block in mode 0 on one HDMI frame and return the corners found.

    A frame is read from hdmi_in, the block writes its result image into a new
    output frame that is then shown on hdmi_out, and four result registers are
    read back.  Returns them as a tuple (r1, c1, r2, c2).
    """
    t_begin = time.time()

    mmio = corner_detection_reference_axilite
    frame_in = hdmi_in.readframe()
    frame_out = hdmi_out.newframe()

    # (register offset, value) pairs, written in this exact order
    parameters = (
        (0x10, frame_in.physical_address),   # in_data
        (0x18, frame_out.physical_address),  # out_data
        (0x20, 1280),                        # w
        (0x28, 720),                         # h
        (0x30, 0),                           # mode
        (0x78, 90*150),                      # black threshold
    )
    for offset, value in parameters:
        mmio.write(offset, value)

    mmio.write(0x00, 0x01) # start
    # busy-wait until bit 2 of the control register is set
    while not (mmio.read(0) & 0x4):
        pass
    hdmi_out.writeframe(frame_out)

    t_finish = time.time()

    # the four coordinates are read from 0x40, 0x50, 0x60 and 0x70, in that order
    found = tuple(mmio.read(offset) for offset in (0x40, 0x50, 0x60, 0x70))

    print("Time elapsed for a frame: ", t_finish - t_begin)
    return found

def setup_corners():
    """Detect the two board corners until the user accepts them.

    Each round stores the detection in the global ``corners``, shows a fresh
    frame with a 10x10 green square on both corners and asks for confirmation.
    Returns the accepted (r1, c1, r2, c2) tuple.
    """
    global corners
    while True:
        corners = detect_corners()
        print(corners)
        # display square
        f = hdmi_in.readframe()
        for i in range(2):
            r = corners[2*i]
            c = corners[2*i + 1]
            print("corner:",r,c)
            # Rows r-5..r+4 and columns c-5..c+4.  The bounds are clamped at 0 so a
            # corner close to the top/left edge does not wrap around to the other
            # side of the frame; slicing already clips at the bottom/right edge.
            rows = slice(max(r - 5, 0), max(r + 5, 0))
            cols = slice(max(c - 5, 0), max(c + 5, 0))
            # overwrite the first three channels with green, keep the fourth as it is
            f[rows, cols, :3] = (0, 255, 0)
        hdmi_out.writeframe(f)

        ready = input("Are the corners ok? Y/N")
        if ready in ["y", "Y"]:
            return corners
        print("Detecting again")
            
def calculate_centers():
    """Compute the centre of every board cell from the global ``corners``.

    ``corners`` is read as (r1, c1, r2, c2).  The rectangle between the two
    corners is split into 6 rows and 7 columns of equal size.

    Returns a flat list of 84 numbers: for each cell in row-major order, the
    centre's column coordinate followed by its row coordinate.
    """
    global corners
    r1, c1, r2, c2 = corners
    # half the height / width of one cell
    delta_y = (r2 - r1) / 12
    delta_x = (c2 - c1) / 14
    centers = []
    for i in range(6):
        for j in range(7):
            centers.append(2 * delta_x * j + delta_x + c1)
            centers.append(2 * delta_y * i + delta_y + r1)
    # the list has to be handed back: setup() assigns this function's result
    return centers
    
def get_grid_state():
    """Run the FPGA block in mode 1 on one HDMI frame and return the 42 cell colours.

    The global ``corners`` (r1, c1, r2, c2) is passed to the block.  The result
    is a list of 42 register values, one per board cell:
        0 -> blank
        1 -> red   (computer)
        2 -> green (human)
        3 -> error (not used)
    """
    global corners
    t_begin = time.time()

    mmio = corner_detection_reference_axilite
    frame_in = hdmi_in.readframe()
    frame_out = hdmi_out.newframe()

    # (register offset, value) pairs, written in this exact order
    parameters = (
        (0x10, frame_in.physical_address),   # in_data
        (0x18, frame_out.physical_address),  # out_data
        (0x20, 1280),                        # w
        (0x28, 720),                         # h
        (0x30, 1),                           # mode
        (0x1d0, 110),                        # colour threshold
    )
    for offset, value in parameters:
        mmio.write(offset, value)

    # the four corner coordinates go to 0x38, 0x48, 0x58 and 0x68
    for position, offset in enumerate((0x38, 0x48, 0x58, 0x68)):
        mmio.write(offset, corners[position])

    mmio.write(0x00, 0x01) # start
    # busy-wait until bit 2 of the control register is set
    while not (mmio.read(0) & 0x4):
        pass
    hdmi_out.writeframe(frame_out)

    t_finish = time.time()  # timing is taken but currently not reported

    # one result word per cell, 8 bytes apart, starting at 0x80
    colours = [mmio.read(0x80 + 0x8 * cell) for cell in range(42)]
    return colours

def get_stable_state(min_same_states=40):
    """Poll get_grid_state() until the reading stops changing.

    min_same_states: number of consecutive readings that must equal the
                     current candidate before it is accepted (default 40).

    Returns the accepted reading.  Any differing reading becomes the new
    candidate and restarts the count.
    """
    stable_state = get_grid_state()
    count = 0
    while count < min_same_states:
        new_state = get_grid_state()
        if new_state == stable_state:
            count += 1
        else:
            # the board changed: start counting again from the new reading
            stable_state = new_state
            count = 0
    return stable_state

def show_move_on_screen(col, curr_state, centers):
    """Highlight for about 2 seconds the cell where a coin dropped in ``col`` lands.

    col:        column index (0..6) of the move to show.
    curr_state: board before the move; only ``heights[col]`` is read.
    centers:    flat list with, for each cell in row-major order (7 cells per
                row), the centre's column coordinate followed by its row coordinate.

    Raises ValueError when the column is already full.
    """
    height = curr_state.heights[col]
    if height >= 6:
        raise ValueError("column %d is already full" % col)
    # Coins stack from row 5 upwards, so the next coin lands in row 5 - height.
    row = 5 - height
    cell = row * 7 + col
    # centres may be fractional; pixel indices have to be integers
    center_col = int(round(centers[2 * cell]))
    center_row = int(round(centers[2 * cell + 1]))
    print("center:",center_row,center_col)

    # 16x16 square around the centre; lower bounds are clamped at 0 so the square
    # never wraps to the other side of the frame (slicing clips the upper bounds).
    rows = slice(max(center_row - 8, 0), max(center_row + 8, 0))
    cols = slice(max(center_col - 8, 0), max(center_col + 8, 0))

    # display square
    start = time.time()
    while time.time() - start < 2:
        f = hdmi_in.readframe()
        f[rows, cols] = (0,255,0,1)
        hdmi_out.writeframe(f)
    

# main functions

def setup():
    global centers
    global corners
    corners = setup_corners()
    centers = calculate_centers()
    #return corners
    #return centers
    
def play():
    """Endless game loop: watch the board, answer every human move with a computer move.

    A new stable board that differs from the last accepted one is taken as the
    human's move.  The agent's column is printed and the loop then waits until
    the board changes again, i.e. until somebody has physically inserted the
    computer's coin.  After a win the loop keeps polling.
    """
    # search depth per agent turn; turns that are not listed use depth 12
    depth_schedule = {1: 1, 2: 4, 3: 5, 4: 5, 5: 5, 6: 5, 7: 5, 8: 6, 9: 6, 10: 6, 11: 6, 12: 7, 13: 7, 14: 8, 15: 9}
    agent = Minimax_agent(depth_schedule, 12)
    previous_stable_state = [0] * 42

    while True:
        # waiting for human move
        print("Waiting for a stable state (HUMAN move)")
        detected_state = get_stable_state()
        print("Detected a stable state")
        curr_state = Game_state(detected_state)

        if detected_state == previous_stable_state:
            print("Same state as before, ignored")
            continue

        # the board changed: the human did the move
        curr_state.print()
        previous_stable_state = detected_state
        if curr_state.is_win():
            input("Human won")
            continue

        # NOTE(review): get_move() increments agent.turn as well, so the counter
        # advances by two per computer move -- confirm this is intended.
        agent.turn += 1
        pc_move = agent.get_move(curr_state)

        print("\n===========================\nInsert coin at column: %s\n===========================\n" % str(pc_move))

        # wait until there is a new stable state: the human inserts the coin for the computer
        while detected_state == previous_stable_state:
            print("Waiting for the COMPUTER coin to be inserted")
            detected_state = get_stable_state()
            print("Detected a stable state")
            curr_state = Game_state(detected_state)

        print("Accepted as new state, coin inserted")
        previous_stable_state = detected_state
        curr_state = Game_state(detected_state)
        curr_state.print()
        if curr_state.is_win():
            input("Computer won")
            continue

In [33]:
# Confirm the board corners interactively, then enter the endless game loop.
setup()
play()

Time elapsed for a frame:  0.5904459953308105
(48, 342, 626, 964)
corner: 48 342
corner: 626 964
Are the corners ok? Y/Ny
Waiting for a stable state (HUMAN move)
Detected a stable state
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
_____________
0 1 2 3 4 5 6 

Same state as before, ignored
Waiting for a stable state (HUMAN move)
Detected a stable state
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
_____________
0 1 2 3 4 5 6 

Same state as before, ignored
Waiting for a stable state (HUMAN move)
Detected a stable state
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
_____________
0 1 2 3 4 5 6 

Same state as before, ignored
Waiting for a stable state (HUMAN move)
Detected a stable state
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - X - - - 
_____________
0 1 2 3 4 5 6 

getting values for moves: [with depth  4]
0 40 40 0 

KeyboardInterrupt: 

In [None]:
# close connections
# close connections: run this after stopping play() to close both HDMI endpoints
hdmi_in.close()
hdmi_out.close()

In [None]:
"""
SETUP
import the overlay and do the setup work
start HDMI IO

call the corner detection function -> 2 corners
highlight corners and ask if it is ok, if not, call it again

we have the 2 corners
calculate all the centers -> 42 * 2 values


GAME
call the color detection function -> 42 colors array
repeat until the same state has been read many times in a row (currently 40)
process the state and wait until there is 1 extra opponent's coin

check whether the opponent won

pass the state to minimax and get the move

display the column where we want to put the coin
detect the colors -> 42 colors
"""