In [None]:
import math
import random
from graphviz import Digraph

class Connect4:
    """Console Connect 4: a human player versus a minimax (alpha-beta) AI.

    The board is a list of ``ROWS`` rows, each a list of ``COLS`` cells.
    Row 0 is the TOP of the board; dropped pieces settle toward row
    ``ROWS - 1``. Empty cells hold a single space ``" "``; occupied cells hold
    the human's or the AI's emoji token.
    """

    ROWS = 6
    COLS = 7
    # Base magnitude of a won/lost position. The remaining search depth is
    # added on top of it so that faster wins and slower losses score better.
    WIN_SCORE = 1_000_000

    def __init__(self):
        self.board = [[" " for _ in range(self.COLS)] for _ in range(self.ROWS)]
        self.human = "🔴"
        self.ai = "🟡"
        # Numeric suffix of the next rendered tree image (minimax_tree_step_N).
        self.tree_counter = 0

    # --- Display Board ---
    def print_board(self):
        """Print the board top row first, followed by column indices."""
        print("\n")
        for row in self.board:
            print("| " + " | ".join(row) + " |")
        print("-" * (self.COLS * 4 + 1))
        print("  " + "   ".join(str(i) for i in range(self.COLS)))

    # --- Drop Piece ---
    def make_move(self, col, player):
        """Drop ``player``'s token into column ``col``.

        The token lands in the lowest empty cell of that column.

        Returns:
            True if the token was placed, False if the column was already full.
        """
        for row in reversed(self.board):
            if row[col] == " ":
                row[col] = player
                return True
        return False

    def undo_move(self, col):
        """Remove the topmost token from column ``col`` (no-op if it is empty)."""
        for r in range(self.ROWS):
            if self.board[r][col] != " ":
                self.board[r][col] = " "
                break

    def available_moves(self):
        """Return the indices of columns whose top cell is still empty."""
        return [c for c in range(self.COLS) if self.board[0][c] == " "]

    def check_winner(self, player):
        """Return True if ``player`` has four in a row in any direction."""
        # Horizontal
        for r in range(self.ROWS):
            for c in range(self.COLS - 3):
                if all(self.board[r][c+i] == player for i in range(4)):
                    return True
        # Vertical
        for r in range(self.ROWS - 3):
            for c in range(self.COLS):
                if all(self.board[r+i][c] == player for i in range(4)):
                    return True
        # Diagonal \ (down-right from (r, c); row 0 is the top of the board)
        for r in range(self.ROWS - 3):
            for c in range(self.COLS - 3):
                if all(self.board[r+i][c+i] == player for i in range(4)):
                    return True
        # Diagonal / (up-right from (r, c))
        for r in range(3, self.ROWS):
            for c in range(self.COLS - 3):
                if all(self.board[r-i][c+i] == player for i in range(4)):
                    return True
        return False

    def is_full(self):
        """Return True when every column's top cell is occupied."""
        return all(self.board[0][c] != " " for c in range(self.COLS))

    def evaluate_window(self, window, player):
        """Heuristically score a 4-cell ``window`` from ``player``'s viewpoint.

        +100 for four of ``player``, +5 for three plus one empty cell, +2 for
        two plus two empty cells, and -4 when the opponent has three plus one
        empty cell (an open threat).
        """
        opponent = self.human if player == self.ai else self.ai
        score = 0
        if window.count(player) == 4:
            score += 100
        elif window.count(player) == 3 and window.count(" ") == 1:
            score += 5
        elif window.count(player) == 2 and window.count(" ") == 2:
            score += 2
        if window.count(opponent) == 3 and window.count(" ") == 1:
            score -= 4
        return score

    def score_position(self, player):
        """Return the static heuristic value of the board for ``player``.

        Pieces in the centre column earn 3 points each; every horizontal,
        vertical and diagonal 4-cell window is added via ``evaluate_window``.
        """
        score = 0
        center_array = [self.board[r][self.COLS//2] for r in range(self.ROWS)]
        score += center_array.count(player) * 3

        # Horizontal
        for r in range(self.ROWS):
            row_array = self.board[r]
            for c in range(self.COLS - 3):
                window = row_array[c:c+4]
                score += self.evaluate_window(window, player)

        # Vertical
        for c in range(self.COLS):
            col_array = [self.board[r][c] for r in range(self.ROWS)]
            for r in range(self.ROWS - 3):
                window = col_array[r:r+4]
                score += self.evaluate_window(window, player)

        # Diagonal \ (down-right from (r, c))
        for r in range(self.ROWS - 3):
            for c in range(self.COLS - 3):
                window = [self.board[r+i][c+i] for i in range(4)]
                score += self.evaluate_window(window, player)

        # Diagonal / (starts at row r+3 and moves up-right)
        for r in range(self.ROWS - 3):
            for c in range(self.COLS - 3):
                window = [self.board[r+3-i][c+i] for i in range(4)]
                score += self.evaluate_window(window, player)

        return score

    # --- Minimax with Alpha-Beta Pruning and Tree Tracking ---
    def minimax(self, depth, alpha, beta, maximizing_player):
        """Search the game tree with minimax and alpha-beta pruning.

        The board is mutated during the search but every move is undone, so
        it is unchanged when this returns.

        Args:
            depth: Remaining plies to search.
            alpha: Best score the maximizer (AI) is already assured of.
            beta: Best score the minimizer (human) is already assured of.
            maximizing_player: True when the AI is to move.

        Returns:
            ``(best_col, score, node)``. ``best_col`` is None for terminal or
            depth-0 positions. ``node`` is a dict describing the explored
            subtree (``depth``, ``maximizing``, entry ``alpha``/``beta``,
            ``children`` and, after a cutoff, ``pruned``) for ``visualize_tree``.
        """
        valid_moves = self.available_moves()
        node = {"depth": depth, "maximizing": maximizing_player, "alpha": alpha, "beta": beta, "children": []}

        # Terminal positions: adding the remaining depth makes a win found
        # sooner outrank a later one, and makes the AI postpone a loss.
        if self.check_winner(self.ai):
            return (None, self.WIN_SCORE + depth, node)
        elif self.check_winner(self.human):
            return (None, -(self.WIN_SCORE + depth), node)
        elif self.is_full() or depth == 0:
            return (None, self.score_position(self.ai), node)

        # The board is not full here, so valid_moves is non-empty. Any finite
        # child score beats +/-inf, so best_col is always overwritten by the
        # first explored move; the initial value only keeps it defined.
        best_col = valid_moves[0]

        if maximizing_player:
            value = -math.inf
            for col in valid_moves:
                self.make_move(col, self.ai)
                _, new_score, child = self.minimax(depth - 1, alpha, beta, False)
                self.undo_move(col)
                node["children"].append({"move": col, "score": new_score, "tree": child})
                if new_score > value:
                    value = new_score
                    best_col = col
                alpha = max(alpha, value)
                if alpha >= beta:
                    node["pruned"] = True
                    break
            return best_col, value, node
        else:
            value = math.inf
            for col in valid_moves:
                self.make_move(col, self.human)
                _, new_score, child = self.minimax(depth - 1, alpha, beta, True)
                self.undo_move(col)
                node["children"].append({"move": col, "score": new_score, "tree": child})
                if new_score < value:
                    value = new_score
                    best_col = col
                beta = min(beta, value)
                if alpha >= beta:
                    node["pruned"] = True
                    break
            return best_col, value, node

    # --- Graphviz Tree Visualization ---
    def visualize_tree(self, node, graph=None, parent=None, node_id="0"):
        """Draw a ``minimax`` tree dict with Graphviz.

        Recurses over ``node["children"]``. Node ids are dotted paths
        ("0", "0.1", "0.1.3", ...). MAX nodes are blue, MIN nodes red, and
        nodes where a cutoff happened grey. Edges into pruned children are
        dashed. On the root call (``parent is None``), the graph is rendered
        to ``minimax_tree_step_<tree_counter>.png`` and ``tree_counter`` is
        incremented.
        """
        if graph is None:
            graph = Digraph(format="png")
            graph.attr(rankdir="LR", bgcolor="white")
            graph.attr("node", shape="box", style="filled,rounded", fontname="Helvetica", fontsize="10")

        # Color nodes based on type
        fillcolor = "#79b8ff" if node["maximizing"] else "#ff9999"
        # "\\n" is Graphviz's in-label line-break escape, not a Python newline.
        label = f"{'MAX' if node['maximizing'] else 'MIN'}\\nα={node['alpha']}\\nβ={node['beta']}"
        if node.get("pruned"):
            label += "\\n✂️ PRUNED"
            fillcolor = "#d0d0d0"
        graph.node(node_id, label, fillcolor=fillcolor)

        # Add child edges
        for i, child in enumerate(node.get("children", [])):
            child_id = f"{node_id}.{i}"
            edge_label = f"Move {child['move']}\\nScore {child['score']}"
            style = "dashed" if child["tree"].get("pruned") else "solid"
            graph.edge(node_id, child_id, label=edge_label, fontsize="8", fontcolor="#555555", style=style)
            self.visualize_tree(child["tree"], graph, node_id, child_id)

        # Render with unique filename for each step (root call only)
        if parent is None:
            filename = f"minimax_tree_step_{self.tree_counter}"
            graph.render(filename, cleanup=True)
            print(f"\n🌳 Minimax tree saved as '{filename}.png'")
            self.tree_counter += 1  # Increment counter for next tree

    # --- Main Game Loop ---
    def play_game(self, ai_depth=3):
        """Run the interactive game loop until someone wins or the board fills.

        The first mover is chosen at random. After each AI move the search
        tree is rendered with ``visualize_tree``.

        Args:
            ai_depth: Search depth in plies for the AI's minimax (default 3).
        """
        max_col = self.COLS - 1
        print("Welcome to Connect 4!")
        print("You are '🔴' and AI is '🟡'")
        ai_turn = random.choice([True, False])

        while True:
            self.print_board()

            if self.check_winner(self.human):
                print("\n🎉 You win!")
                break
            elif self.check_winner(self.ai):
                print("\n💻 AI wins! Better luck next time.")
                break
            elif self.is_full():
                print("\nIt's a tie!")
                break

            if ai_turn:
                print("\n🤖 AI is thinking...")
                col, score, tree = self.minimax(ai_depth, -math.inf, math.inf, True)
                self.make_move(col, self.ai)
                print(f"AI chooses column {col} with score {score}")
                self.visualize_tree(tree)
            else:
                valid = False
                while not valid:
                    try:
                        col = int(input(f"\nYour move (0–{max_col}): "))
                        if col in self.available_moves():
                            self.make_move(col, self.human)
                            valid = True
                        else:
                            print("❌ Column full or invalid. Try again.")
                    except ValueError:
                        print(f"⚠️ Enter a valid number between 0–{max_col}.")
            ai_turn = not ai_turn


# --- Run Game ---
if __name__ == "__main__":
    # Start an interactive human-vs-AI game when run as a script (a notebook
    # cell also runs with __name__ == "__main__").
    game = Connect4()
    game.play_game()

Welcome to Connect 4!
You are '🔴' and AI is '🟡'


|   |   |   |   |   |   |   |
|   |   |   |   |   |   |   |
|   |   |   |   |   |   |   |
|   |   |   |   |   |   |   |
|   |   |   |   |   |   |   |
|   |   |   |   |   |   |   |
-----------------------------
  0   1   2   3   4   5   6

Your move (0–6): 3


|   |   |   |   |   |   |   |
|   |   |   |   |   |   |   |
|   |   |   |   |   |   |   |
|   |   |   |   |   |   |   |
|   |   |   |   |   |   |   |
|   |   |   | 🔴 |   |   |   |
-----------------------------
  0   1   2   3   4   5   6

🤖 AI is thinking...
AI chooses column 3 with score 6

🌳 Minimax tree saved as 'minimax_tree_step_0.png'


|   |   |   |   |   |   |   |
|   |   |   |   |   |   |   |
|   |   |   |   |   |   |   |
|   |   |   |   |   |   |   |
|   |   |   | 🟡 |   |   |   |
|   |   |   | 🔴 |   |   |   |
-----------------------------
  0   1   2   3   4   5   6

Your move (0–6): 4


|   |   |   |   |   |   |   |
|   |   |   |   |   |   |   |
|   |   |   |   |   |   |