# Connect Four AI 
With Min-Max Trees

In [None]:
Here 

In [1]:
import numpy as np
import tabulate as tb
from random import randrange
from IPython.display import clear_output as clear
from IPython.display import display as display
from time import sleep
import session_info
session_info.show()

ModuleNotFoundError: No module named 'session_info'

In [3]:
init_state = {"board" : np.full(fill_value="_", shape=(8,8)), "turn" : True, "winner" : None}

In [4]:
def copy_state(state):
    return {"board" : state["board"].copy(), "turn" : state["turn"], "winner" : state["winner"]}

In [5]:
def print_state(state) :
    if state["winner"] is not None:
        statement_comp = "won" if state["winner"] else "lost"
        print(f"Mundobot {statement_comp}...")
        statement_player = "lost" if state["winner"] else "won"
        print(f"You {statement_player}!")
    table = tb.tabulate(state["board"], tablefmt="fancy_grid")
    print(table)

In [6]:
def turns_till_access(state, i, j):
    cd = 0
    while i+cd < 8:
        if state["board"][i+cd][j] != "_":
            return cd
        cd += 1
    return cd

In [7]:
def vertical_line(state, char, i, j):
    
    if char != state["board"][i][j]:
        return None
    
    length = 1
    accessible = False
    turn_time = float("inf")
    
    cd = 1
    while i+cd < 8:
        if char == state["board"][i+cd][j]:
            length += 1
            cd += 1
        else:
            if "_" == state["board"][i+cd][j]:
                accessible = True
                turn_time = min(turn_time, turns_till_access(state, i+cd, j))
            break
    
    cd = -1
    while i+cd >= 0:
        if char == state["board"][i+cd][j]:
            length += 1
            cd -= 1
        else:
            if "_" == state["board"][i+cd][j]:
                accessible = True
                turn_time = min(turn_time, turns_till_access(state, i+cd, j))
            break
            
    if accessible or length >= 4:
        return (length, turn_time)
    else:
        return None

In [8]:
def horizontal_line(state, char, i, j):
    
    if char != state["board"][i][j]:
        return None
    
    length = 1
    accessible = False
    turn_time = float("inf")
    
    cd = 1
    while j+cd < 8:
        if char == state["board"][i][j+cd]:
            length += 1
            cd += 1
        else:
            if "_" == state["board"][i][j+cd]:
                accessible = True
                turn_time = min(turn_time, turns_till_access(state, i, j+cd))
            break
    
    cd = -1
    while j+cd >= 0:
        if char == state["board"][i][j+cd]:
            length += 1
            cd -= 1
        else:
            if "_" == state["board"][i][j+cd]:
                accessible = True
                turn_time = min(turn_time, turns_till_access(state, i, j+cd))
            break
            
    if accessible or length >= 4:
        return (length, turn_time)
    else:
        return None

In [9]:
def up_slant_line(state, char, i, j):
     
    if char != state["board"][i][j]:
        return None
    
    length = 1
    accessible = False
    turn_time = float("inf")
    
    cd = 1
    while i+cd < 8 and j+cd < 8:
        if char == state["board"][i+cd][j+cd]:
            length += 1
            cd += 1
        else:
            if "_" == state["board"][i+cd][j+cd]:
                accessible = True
                turn_time = min(turn_time, turns_till_access(state, i+cd, j+cd))
            break
    
    cd = -1
    while i+cd >= 0 and j+cd >= 0:
        if char == state["board"][i+cd][j+cd]:
            length += 1
            cd -= 1
        else:
            if "_" == state["board"][i+cd][j+cd]:
                accessible = True
                turn_time = min(turn_time, turns_till_access(state, i+cd, j+cd))
            break
    
    if accessible or length >= 4:
        return (length, turn_time)
    else:
        return 0
    


In [10]:
def down_slant_line(state, char, i, j):
     
    if char != state["board"][i][j]:
        return None
    
    length = 1
    accessible = False
    turn_time = float("inf")
    
    cd = 1
    while i+cd < 8 and j-cd >= 0:
        if char == state["board"][i+cd][j-cd]:
            length += 1
            cd += 1
        else:
            if "_" == state["board"][i+cd][j-cd]:
                accessible = True
                turn_time = min(turn_time, turns_till_access(state, i+cd, j-cd))
            break
    
    cd = -1
    while i+cd >= 0 and j-cd < 8:
        if char == state["board"][i+cd][j-cd]:
            length += 1
            cd -= 1
        else:
            if "_" == state["board"][i+cd][j-cd]:
                accessible = True
                turn_time = min(turn_time, turns_till_access(state, i+cd, j-cd))
            break
            
    if accessible or length >= 4:
        return (length, turn_time)
    else:
        return None
    


In [11]:
LINE_FUNCS = [vertical_line, horizontal_line, up_slant_line, down_slant_line]

In [12]:
def play_state(state, col):
    if state["board"][0][col] != "_" or state["winner"] is not None:
        return False
    else:
        for i in range(7, -1, -1):
            if state["board"][i][col] == "_":
                char = "X" if state["turn"] else "O"
                state["board"][i][col] = char
                state["winner"] = None
                for func in LINE_FUNCS:
                    line_info = func(state, char, i, col)
                    if line_info:
                        if line_info[0] >= 4:
                            state["winner"] = state["turn"]
                            break               
                state["turn"] = not state["turn"]
                
                return True
                

In [13]:
# test = copy_state(init_state)
# for i in range(0, 30):
#     play_state(test, randrange(0,8))

# print_state(test)

def score_player(state, char): 
    
    line_dict = {"1":[], "2":[], "3":[], "+":[]}

    for i in range(0, 8):
        for j in range(0, 8):
            for func in LINE_FUNCS:
                line_info = func(state, char, i, j)
                if line_info:
                    length, turn_time = line_info
                    if length >= 4:
                        line_dict["+"] += [turn_time]
                    else:
                        line_dict[str(length)] += [turn_time]
    if line_dict["+"]:
        return float("inf")
    else:
        value = lambda i : (-1/8) * (i-1) + 1
        sum1 = sum([value(i) for i in line_dict["1"]])
        sum2 = sum([value(i) for i in line_dict["2"]])/2
        sum3 = sum([value(i) for i in line_dict["3"]])/3
        return 0.5*sum1 + 1*sum2 + 1.5*sum3

In [14]:
def minimax(state, depth, alpha, beta):
    
    if depth == 0 or state["winner"] is not None:
        return (score_player(state, "X") - score_player(state, "O"), -1)
    
    best_turn = 0
    if state["turn"]:
        curr_max = float("-inf")
        for col in range(0, 8):
            next_state = copy_state(state)
            is_playable = play_state(next_state, col)
            if is_playable:
                poss_max, next_best_turn = minimax(next_state, depth-1, alpha, beta)
                best_turn = best_turn if curr_max >= poss_max else col
                curr_max = max(curr_max, poss_max)
                alpha = max(alpha, poss_max)
                if beta <= alpha:
                    break
        return (curr_max, best_turn)
    else:
        curr_min = float("inf")
        for col in range(0, 8):
            next_state = copy_state(state)
            is_playable = play_state(next_state, col)
            if is_playable:
                poss_min, next_best_turn = minimax(next_state, depth-1, alpha, beta)
                best_turn = best_turn if curr_min <= poss_min else col
                curr_min = min(curr_min, poss_min)
                beta = min(beta, poss_min)
                if beta <= alpha:
                    break
        return (curr_min, best_turn)
        
                

In [15]:
def depth_prompt():
    try:
        res = input("What level would you like to challenge: ")
        if int(res) in range(1, 7):
            return int(res)
        else:
            raise Exception
    except Exception:
        print("Please input a number 1-6...")
        return depth_prompt()

def starting_prompt():
    res = input("Would you like to go first: ")
    if res == "yes":
        return True
    if res == "no":
        return False
    print("Please answer either 'yes' or 'no'...")
    return starting_prompt()
    
def turn_prompt():
    try:
        res = input("Your turn: ")
        if res == "stop":
            return -1
        if res == "restart":
            return -2
        if int(res) in range(1, 9):
            return int(res)-1
    except Exception:
        print("Please type in valid input — a number between 1 and 8, 'stop', or 'restart'.")
        turn_prompt()
        
def restart_prompt():
    res = input("Would you like to restart: ")
    if res == "yes":
        return True
    if res == "no":
        return False
    print("Please answer either 'yes' or 'no'...")
    return restart_prompt()
        

In [16]:
def main():
    
    curr_state = copy_state(init_state)
    
    print("Welcome to Mundobot's connect four!")
    print("You can play levels 1-6.")
    print("If you beat level 6 you will get Mundobot's hard earned respect. :/ ")
    
    print("")
    depth = depth_prompt()
    
    print("")
    if starting_prompt():
        curr_state["turn"] = False
    print("")
    
    print("Lets begin! :>")
    sleep(1)
    
    clear()
    
    while True:
        
        if curr_state["winner"] is not None:
            clear()
            print_state(curr_state)
            break
        
        elif curr_state["turn"]:
            print("Waiting for computer... *beep boop baap*")
            best_score, next_turn = minimax(curr_state, depth, float("-inf"), float("inf"))
            play_state(curr_state, next_turn)
        
        else:
            clear()
            print_state(curr_state)
            player_input = turn_prompt()
            if player_input == -1:
                print("Game ended.")
                break
            if player_input == -2:
                clear()
                main()
                break
            next_turn = player_input
            is_playable = play_state(curr_state, next_turn)
            if not is_playable:
                print("That column is full...")
                sleep(1)
                continue
            clear()
            print_state(curr_state)
    
    if (restart_prompt()):
        clear()
        main()
    else:
        print("GG! :D")
    

In [None]:
main()

╒═══╤═══╤═══╤═══╤═══╤═══╤═══╤═══╕
│ _ │ _ │ _ │ _ │ _ │ _ │ _ │ _ │
├───┼───┼───┼───┼───┼───┼───┼───┤
│ _ │ _ │ _ │ _ │ _ │ _ │ _ │ _ │
├───┼───┼───┼───┼───┼───┼───┼───┤
│ _ │ _ │ _ │ _ │ _ │ _ │ _ │ _ │
├───┼───┼───┼───┼───┼───┼───┼───┤
│ _ │ _ │ _ │ _ │ _ │ _ │ _ │ _ │
├───┼───┼───┼───┼───┼───┼───┼───┤
│ _ │ _ │ _ │ _ │ _ │ _ │ _ │ _ │
├───┼───┼───┼───┼───┼───┼───┼───┤
│ _ │ _ │ _ │ _ │ _ │ _ │ _ │ _ │
├───┼───┼───┼───┼───┼───┼───┼───┤
│ _ │ _ │ X │ _ │ _ │ _ │ _ │ _ │
├───┼───┼───┼───┼───┼───┼───┼───┤
│ _ │ _ │ O │ _ │ _ │ _ │ _ │ _ │
╘═══╧═══╧═══╧═══╧═══╧═══╧═══╧═══╛
Your turn: 
Please type in valid input — a number between 1 and 8, 'stop', or 'restart'.
