[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/real-itu/modern-ai-course/blob/master/lecture-02/lab.ipynb)


# Lab 2 - Adversarial Search

[Connect 4](https://en.wikipedia.org/wiki/Connect_Four) is a classic board game in which two players alternate placing markers in columns, and the goal is to get 4 in a row, either horizontally, vertically or diagonally. See the short video below.

In [None]:
# Show the video mentioned above; as the cell's last expression, the
# YouTubeVideo object is what the notebook renders as output.
from IPython.display import YouTubeVideo
YouTubeVideo("ylZBRUJi3UQ")

The game is implemented below. It will play a game where both players take random (legal) actions. The MAX player is represented with an X and the MIN player with an O. The MAX player starts. Execute the code.

In [None]:
import random
from copy import deepcopy
from datetime import datetime
from typing import Iterator, Sequence

NONE = '.'  # marker stored in an empty cell
MAX = 'X'  # marker of the MAX player, who moves on even move counts (i.e. first)
MIN = 'O'  # marker of the MIN player
COLS = 7  # number of columns on the board
ROWS = 6  # number of cells in each column
N_WIN = 4  # length of the run of equal markers that wins the game


class ArrayState:
    """A Connect 4 position stored as plain Python lists.

    Attributes are set directly from the constructor arguments:

    * ``board``   -- list of columns, indexed ``board[column][row]``.
    * ``heights`` -- number of pieces already placed in each column.
    * ``n_moves`` -- number of moves played so far.
    """

    def __init__(self, board, heights, n_moves):
        self.board, self.heights, self.n_moves = board, heights, n_moves

    @staticmethod
    def init():
        """Return the starting position: every cell empty and no moves played."""
        empty_board = [[NONE for _ in range(ROWS)] for _ in range(COLS)]
        empty_heights = [0 for _ in range(COLS)]
        return ArrayState(empty_board, empty_heights, 0)


def result(state: ArrayState, action: int) -> ArrayState:
    """Return the state reached by dropping the next marker into a column.

    The input state is left untouched; a new ``ArrayState`` is returned.

    Args:
        state: position to move from.
        action: index of the column to insert into.

    Raises:
        AssertionError: if ``action`` is not in ``range(COLS)`` (only while
            assertions are enabled).
        Exception: if the chosen column is already full.
    """
    assert 0 <= action < COLS, "action must be a column number"

    filled = state.heights[action]
    if filled >= ROWS:
        raise Exception('Column is full')

    # MAX moves on even move counts, MIN on odd ones.
    player = MAX if state.n_moves % 2 == 0 else MIN

    # Cells are immutable strings, so copying each column yields a fully
    # independent board. This runs for every node the search expands, where
    # a generic deepcopy is needlessly slow.
    board = [list(column) for column in state.board]
    # Columns fill from the highest row index towards index 0.
    board[action][ROWS - filled - 1] = player

    heights = list(state.heights)
    heights[action] += 1

    return ArrayState(board, heights, state.n_moves + 1)


def actions(state: ArrayState) -> Sequence[int]:
    """Return the columns that still have room, centre columns first.

    The order is derived from ``COLS`` instead of being written out for a
    7-column board; for ``COLS == 7`` it is ``[3, 2, 4, 1, 5, 0, 6]``.
    """
    centre = COLS // 2
    # sorted() is stable, so of two columns equally far from the centre the
    # lower index comes first.
    order = sorted(range(COLS), key=lambda col: abs(col - centre))
    return [col for col in order if state.heights[col] < ROWS]

def flipLine(line) -> Iterator[str]:
    """Yield the cells of ``line`` with the two players' markers swapped.

    ``MIN`` cells come out as ``MAX`` and vice versa; any other cell is
    yielded unchanged. This is a generator, so the result is an iterator
    over cells rather than a string.
    """
    swapped = {MIN: MAX, MAX: MIN}
    for cell in line:
        yield swapped.get(cell, cell)

def flipLines(lines):
    """Yield each ``(line, val)`` pair of ``lines`` with the line's markers swapped.

    The first element of every yielded pair is the iterator returned by
    ``flipLine``; the second element is passed through untouched.
    """
    for pair in lines:
        cells, val = pair
        yield (flipLine(cells), val)

def heuristic(lines: list[list[str]]) -> float:
    """Estimate how good a position is for MAX from its lines of cells.

    Every window of ``N_WIN`` consecutive cells in every line is inspected.
    A window holding no MIN marker and more than one MAX marker adds its
    number of MAX markers to MAX's score, and symmetrically for MIN. Both
    scores are divided by the largest total any side could reach
    (``N_WIN`` per window), so the result lies between -1 and 1.

    Args:
        lines: sequences of cell markers (columns, rows, diagonals, ...).

    Returns:
        MAX's normalised score minus MIN's, or 0.0 when no line is long
        enough to contain a single window.
    """
    maxScore = 0
    minScore = 0
    windows = 0
    for line in lines:
        # range() is empty for lines shorter than a window.
        for start in range(len(line) - N_WIN + 1):
            windows += 1
            section = line[start:start + N_WIN]
            maxCount = section.count(MAX)
            minCount = section.count(MIN)
            if minCount == 0 and maxCount > 1:
                maxScore += maxCount
            if maxCount == 0 and minCount > 1:
                minScore += minCount

    if windows == 0:
        # Nothing to score; avoid dividing by zero.
        return 0.0

    highestVal = windows * N_WIN
    return (maxScore / highestVal) - (minScore / highestVal)





def utility(state: ArrayState) -> float:
    """Score a position from MAX's point of view.

    Returns 1 if a line contains ``N_WIN`` consecutive MAX markers, -1 if a
    line contains ``N_WIN`` consecutive MIN markers (lines are checked in
    order: columns, rows, negative diagonals, positive diagonals), and
    otherwise the value of ``heuristic`` over all lines.
    """
    grid = state.board

    def onBoard(col, row):
        """Tell whether (col, row) is a valid cell coordinate."""
        return 0 <= col < COLS and 0 <= row < ROWS

    def diagonalsPos():
        """Yield the diagonals on which column index + row index is constant."""
        for total in range(COLS + ROWS - 1):
            cells = []
            for col in range(COLS):
                row = total - col
                if onBoard(col, row):
                    cells.append(grid[col][row])
            yield cells

    def diagonalsNeg():
        """Yield the diagonals on which row index - column index is constant."""
        for k in range(COLS + ROWS - 1):
            shift = k - COLS + 1
            cells = []
            for col in range(COLS):
                row = shift + col
                if onBoard(col, row):
                    cells.append(grid[col][row])
            yield cells

    rows = list(zip(*grid))
    lines = grid + rows + list(diagonalsNeg()) + list(diagonalsPos())

    max_run = MAX * N_WIN
    min_run = MIN * N_WIN
    for line in lines:
        text = "".join(line)
        if max_run in text:
            return 1
        if min_run in text:
            return -1

    # Nobody has won: fall back to the heuristic estimate.
    return heuristic(lines)


def terminal_test(state: ArrayState) -> bool:
    """Tell whether the game is over in ``state``.

    The game ends when every cell has been played or when ``utility``
    reports a win for either side (exactly 1 or -1).
    """
    if state.n_moves >= COLS * ROWS:
        return True
    # utility() scans the whole board, so evaluate it once only.
    return utility(state) in (-1, 1)


def printBoard(state: ArrayState):
    """Print the board.

    Writes a header line with the column numbers, then one line per row
    index (row 0 first) with that row's cell of every column, then a blank
    line. Entries are separated by two spaces.
    """
    board = state.board
    print('  '.join(map(str, range(COLS))))
    for y in range(ROWS):
        print('  '.join(str(board[x][y]) for x in range(COLS)))
    print()

def playerTurn(state: ArrayState):
    """Return the marker of the player to move: MAX on even move counts, MIN on odd."""
    if state.n_moves % 2:
        return MIN
    return MAX


if __name__ == '__main__':
    # Play one game: MAX is always the minimax AI; MIN is either the minimax
    # AI or a human typing column numbers.
    s = ArrayState.init()
    # True: MIN is played by minimaxDecisionMin. False: MIN's moves are read
    # from standard input.
    playRandom = True
    while not terminal_test(s):
        if playerTurn(s) == MIN:
            if playRandom:
                # For a purely random opponent use random.choice(actions(s)).
                a = minimaxDecisionMin(s)
            else:
                a = None
                while a is None:
                    try:
                        choice = int(input("Please Input a column!"))
                    except ValueError:
                        # Non-numeric input is treated like any other
                        # invalid column instead of ending the game.
                        choice = None
                    if choice in actions(s):
                        a = choice
                    else:
                        print("Not a valid column. Try again!")
        else:
            # Time MAX's decision so the search depth can be tuned.
            start = datetime.now()
            a = minimaxDecisionMax(s)
            print("time" + str(datetime.now() - start))
        s = result(s, a)
        printBoard(s)
    print(utility(s))


The last number 0, -1 or 1 is the utility or score of the game. 0 means it was a draw, 1 means MAX player won and -1 means MIN player won.

### Exercise 1

Modify the code so that you can play manually as the MIN player against the random AI.

### Exercise 2

Implement standard minimax with a fixed depth search. Modify the utility function to handle non-terminal positions using heuristics. Find a value for the depth such that moves don't take longer than approx. 1s to evaluate. See if you can beat your connect4 AI.

In [None]:
# Search cutoff: maxValue/minValue stop recursing and return utility(state)
# once their depth argument reaches this value.
maxDepth = 9
# One cache per searching side: table1 is passed down by minimaxDecisionMax,
# table2 by minimaxDecisionMin.
# NOTE(review): TranspositionTable is defined in a later cell of this
# notebook, so that cell has to be run before this one.
table1 = TranspositionTable()
table2 = TranspositionTable()

def minimaxDecisionMax(state: ArrayState) -> int:
    """Choose MAX's move: the action whose resulting position scores highest.

    Every legal action is scored with ``minValue`` using a full
    (-inf, +inf) window and ``table1``. The list of scores is printed, and
    the first action with the strictly highest score is returned (``None``
    if there is no legal action).
    """
    candidates = list(actions(state))
    scores = [
        minValue(result(state, col), 0, -float("inf"), float("inf"), table1)
        for col in candidates
    ]
    print(scores)

    choice, best = None, -float("inf")
    for col, score in zip(candidates, scores):
        if score > best:
            choice, best = col, score
    return choice

def minimaxDecisionMin(state: ArrayState) -> int:
    """Choose MIN's move: the action whose resulting position scores lowest.

    Every legal action is scored with ``maxValue`` using a full
    (-inf, +inf) window and ``table2``. The list of scores is printed, and
    the first action with the strictly lowest score is returned (``None``
    if there is no legal action).
    """
    candidates = list(actions(state))
    scores = [
        maxValue(result(state, col), 0, -float("inf"), float("inf"), table2)
        for col in candidates
    ]
    print(scores)

    choice, best = None, float("inf")
    for col, score in zip(candidates, scores):
        if score < best:
            choice, best = col, score
    return choice
            
# How a cached value relates to the true minimax value of its position.
_TT_EXACT = 0  # the cached value is the true value
_TT_LOWER = 1  # the true value is at least the cached value
_TT_UPPER = 2  # the true value is at most the cached value


def _lookupValue(table, state, remaining, alpha, beta):
    """Return a cached value for ``state`` that is safe to use, else None.

    An entry is used only if it was searched at least ``remaining`` plies
    deep and is either exact or a bound that lies outside the current
    (alpha, beta) window on the side it bounds, i.e. it would cause the
    same cutoff a fresh search would.
    """
    entry = table.get(state)
    if entry is None:
        return None
    value, bound, searched = entry
    if searched < remaining:
        # Came from a shallower search (e.g. an earlier move); search again.
        return None
    if bound == _TT_EXACT:
        return value
    if bound == _TT_LOWER and value >= beta:
        return value
    if bound == _TT_UPPER and value <= alpha:
        return value
    return None


def _storeValue(table, state, value, remaining, alpha, beta):
    """Cache ``value`` for ``state`` together with what kind of value it is.

    A value at or outside the (alpha, beta) window the search was run with
    is only a bound, because alpha-beta stops as soon as it leaves the
    window; it is tagged accordingly.
    """
    if value <= alpha:
        bound = _TT_UPPER
    elif value >= beta:
        bound = _TT_LOWER
    else:
        bound = _TT_EXACT
    table.put(state, (value, bound, remaining))


def maxValue(state: ArrayState, depth: int, alpha, beta, table) -> float:
    """Alpha-beta value of ``state`` with MAX to move.

    Args:
        state: position to evaluate.
        depth: plies already searched below the root.
        alpha: best value MAX can already guarantee on the current path.
        beta: best value MIN can already guarantee on the current path.
        table: transposition table shared by this search.

    Returns:
        ``utility(state)`` at terminal states and at the depth limit,
        otherwise the largest child value found; the search stops early
        once that value reaches ``beta``.
    """
    if terminal_test(state) or depth >= maxDepth:
        return utility(state)
    # Plies a child search may still descend before hitting the depth limit.
    remaining = maxDepth - depth - 1
    v = -float("inf")
    for a in actions(state):
        newState = result(state, a)
        newV = _lookupValue(table, newState, remaining, alpha, beta)
        if newV is None:
            newV = minValue(newState, depth + 1, alpha, beta, table)
            _storeValue(table, newState, newV, remaining, alpha, beta)
        v = max(v, newV)
        if v >= beta:
            return v
        alpha = max(alpha, v)
    return v


def minValue(state: ArrayState, depth: int, alpha, beta, table) -> float:
    """Alpha-beta value of ``state`` with MIN to move.

    Args:
        state: position to evaluate.
        depth: plies already searched below the root.
        alpha: best value MAX can already guarantee on the current path.
        beta: best value MIN can already guarantee on the current path.
        table: transposition table shared by this search.

    Returns:
        ``utility(state)`` at terminal states and at the depth limit,
        otherwise the smallest child value found; the search stops early
        once that value drops to ``alpha``.
    """
    if terminal_test(state) or depth >= maxDepth:
        return utility(state)
    # Plies a child search may still descend before hitting the depth limit.
    remaining = maxDepth - depth - 1
    v = float("inf")
    for a in actions(state):
        newState = result(state, a)
        newV = _lookupValue(table, newState, remaining, alpha, beta)
        if newV is None:
            newV = maxValue(newState, depth + 1, alpha, beta, table)
            _storeValue(table, newState, newV, remaining, alpha, beta)
        v = min(v, newV)
        if v <= alpha:
            return v
        beta = min(beta, v)
    return v

### Exercise 3

Add alpha/beta pruning to your minimax. Change your depth so that moves still take approx. 1 second to evaluate. How much deeper can you search? See if you can beat your connect4 AI.

### Exercise 4

Add move ordering. The middle columns are often "better" since there are more winning positions that contain them. Evaluate the moves in this order: [3,2,4,1,5,0,6]. How much deeper can you search now? See if you can beat your connect4 AI.

### Exercise 5 - Optional

Improve your AI somehow. Consider:


* Better heuristics
* Faster board representations (look up bitboards)
* Adding a transposition table (see class below)
* Better move ordering

In [None]:
class TranspositionTable:
    """Fixed-size cache from board positions to values.

    A position is keyed by the string of all its cells. Each key maps to
    one of ``size`` slots via ``hash``; writing to a slot replaces whatever
    was stored there, and a read only succeeds if the slot holds the very
    same key.
    """

    def __init__(self, size=1_000_000):
        self.size = size
        self.vals = [None] * size

    def board_str(self, state: ArrayState):
        """Return the board's cells joined into one string, column by column."""
        return ''.join(''.join(column) for column in state.board)

    def _slot(self, key):
        """Return the index of the slot that ``key`` hashes to."""
        return hash(key) % self.size

    def put(self, state: ArrayState, utility: float):
        """Store ``utility`` for ``state``, overwriting the slot's previous entry."""
        key = self.board_str(state)
        self.vals[self._slot(key)] = (key, utility)

    def get(self, state: ArrayState):
        """Return the value stored for ``state``, or None if it is not cached."""
        key = self.board_str(state)
        entry = self.vals[self._slot(key)]
        if entry is not None and entry[0] == key:
            return entry[1]
        return None