# Connect 4 vs Q-Network (PyCharm Version)
No input() calls - designed for PyCharm notebooks

In [1]:
import torch
import torch.nn as nn
import numpy as np
from board_processor import BoardProcessor
from feature_generator import FeatureGenerator
from typing import Optional, List

In [2]:
class QNetwork(nn.Module):
    def __init__(self, input_dim, hidden_sizes=(256, 128, 64, 32, 16, 8)):
        super().__init__()
        layers = []
        last_dim = input_dim
        for h in hidden_sizes:
            layers.append(nn.Linear(last_dim, h))
            layers.append(nn.Tanh())
            last_dim = h
        layers.append(nn.Linear(last_dim, 1))
        layers.append(nn.Tanh())
        self.net = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.net(x).squeeze(-1)

In [3]:
# Load model
MODEL_PATH = "qnet_mc_pretrained.pth"
feature_gen = FeatureGenerator()

# Calculate dimensions
_, dummy_features = feature_gen.convolution_feature_gen([[] for _ in range(7)])
FEATURE_DIM = len(dummy_features) * 2  # State-action pairs

print(f"Single state: {len(dummy_features)} features")
print(f"State-action pair: {FEATURE_DIM} features")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = QNetwork(input_dim=FEATURE_DIM).to(device)

checkpoint = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
scaler = checkpoint['scaler']

print(f"\nModel loaded! Using {device}")

Single state: 69 features
State-action pair: 138 features

Model loaded! Using cpu


In [4]:
board: Optional[BoardProcessor] = None
moves: Optional[List[int]] = None

In [5]:
def init_game(ai_moves_first=False):
    # global board, moves
    print(f"New game! {'AI' if ai_moves_first else 'You'} play{'s' if ai_moves_first else ''} first.")
    print("Columns: 0 1 2 3 4 5 6\n")

    if ai_moves_first:
        # AI makes first move as player 1
        ai_col, q, _ = get_ai_move(board.state_list, ai_player=1, epsilon=0.1)
        moves.append(ai_col)
        board.generate_state_list(moves)
        print(f"AI starts with column {ai_col} (Q={q:.3f}):")
        display_colored()
    else:
        # Show empty board
        for _ in range(6):
            print(". . . . . . .")

In [6]:
def restart():
    global board, moves
    """Initialize a new game"""
    board = BoardProcessor()
    moves = []
    # init_game(ai_moves_first=False)

In [7]:
def check_win(state_list):
   """Check for win using convolution features. Returns 1, -1, or 0"""
   _, features = feature_gen.convolution_feature_gen(state_list)
   if 4 in features:
       return 1
   elif -4 in features:
       return -1
   return 0

In [8]:

# Color codes for Yellow-Green-Blue gradient (7 colors)
COLORS = [
    '\033[93m',   # Yellow (worst)
    '\033[33m',   # Orange-yellow
    '\033[32m',   # Yellow-green
    '\033[92m',   # Green
    '\033[36m',   # Green-cyan
    '\033[96m',   # Cyan
    '\033[94m',   # Blue (best)
]
RED = '\033[91m'
YELLOW = '\033[93m'
RESET = '\033[0m'

In [9]:
def get_q_value(state_list, move, player):
    _, curr_feats = feature_gen.convolution_feature_gen(state_list)
    next_state = [col[:] for col in state_list]
    next_state[move].append(player)
    _, next_feats = feature_gen.convolution_feature_gen(next_state)

    features = np.concatenate([curr_feats, next_feats])
    scaled = scaler.transform([features])

    with torch.no_grad():
        return model(torch.FloatTensor(scaled).to(device)).item()

In [10]:

def get_ai_move(state_list, epsilon=0.1, ai_player = -1):
    """
    Select AI move using epsilon-greedy strategy.
    Returns: (selected_column, best_q_value, q_values_dict)
    """
    valid = [c for c in range(7) if len(state_list[c]) < 6]
    q_values = {}

    # Calculate Q-values for all valid moves
    for col in valid:
        q = get_q_value(state_list, col, ai_player) * ai_player # Negate for AI perspective
        q_values[col] = q

    # Epsilon-greedy selection (structured for easy modification)
    if np.random.random() < epsilon:
        # Exploration: random move
        selected_col = np.random.choice(valid)
    else:
        # Exploitation: best move
        selected_col = max(q_values.keys(), key=lambda k: q_values[k])

    best_q = q_values[selected_col]
    return selected_col, best_q, q_values

In [11]:

def display_colored(q_values=None):
    """Display colored board with Q-value indicators for empty spots"""
    if not board.board_history:
        print("[Board is empty]")
        return

    board_matrix = board.board_history[-1]

    # Map Q-values to colors if provided
    color_map = {}
    if q_values:
        valid_cols = list(q_values.keys())
        if valid_cols:
            # Sort columns by Q-value
            sorted_cols = sorted(valid_cols, key=lambda d: q_values[d])
            # Assign colors based on ranking
            for i, col in enumerate(sorted_cols):
                color_idx = int(i * (len(COLORS) - 1) / max(1, len(sorted_cols) - 1))
                color_map[col] = COLORS[color_idx]

    # Print board
    for r in range(5, -1, -1):
        row_parts = []
        for c in range(7):
            if board_matrix[r, c] == 1:
                row_parts.append(f"{RED}X{RESET}")
            elif board_matrix[r, c] == -1:
                row_parts.append(f"{YELLOW}O{RESET}")
            elif c in color_map:
                # Empty spot with Q-value color
                row_parts.append(f"{color_map[c]}•{RESET}")
            else:
                # Empty spot with no valid move
                row_parts.append(".")
        print(" ".join(row_parts))

In [12]:
def check_game_over():
    """Check for win/draw. Returns (is_over, message)"""
    winner = check_win(board.state_list)
    if winner == 1:
        return True, f"\n🎉 {RED}YOU WIN!{RESET} 🎉\nGame code: {board.moves_code()}"
    elif winner == -1:
        return True, f"\n🤖 {YELLOW}AI WINS!{RESET}\nGame code: {board.moves_code()}"
    elif len(moves) >= 42:
        return True, f"\n🤝 DRAW!\nGame code: {board.moves_code()}"
    return False, None

def play(column):
    """Make a move in the specified column"""
    global moves, board
    move_num = len (moves)
    ai_player = -((move_num % 2) * 2 - 1)
    # Validate move
    valid = [c for c in range(7) if len(board.state_list[c]) < 6]
    if column not in valid:
        print(f"Invalid! Choose from: {valid}")
        return

    # Human move
    moves.append(column)
    board.generate_state_list(moves)
    print(f"\nYou played column {column}:")

    # Get AI's evaluation before showing board
    _, _, q_values = get_ai_move(board.state_list, epsilon=0.0, ai_player=ai_player)  # No exploration for display
    display_colored(q_values)

    # Check for win
    game_over, message = check_game_over()
    if game_over:
        print(message)
        return

    # AI move (with exploration)
    ai_col, q, _ = get_ai_move(board.state_list, epsilon=0.0, ai_player=ai_player)
    moves.append(ai_col)
    board.generate_state_list(moves)
    print(f"\nAI played column {ai_col} (Q={q:.3f}):")
    display_colored()  # No Q-values shown after AI moves

    # Check for AI win
    game_over, message = check_game_over()
    if game_over:
        print(message)
        return

    # Show valid moves
    valid = [c for c in range(7) if len(board.state_list[c]) < 6]
    if valid:
        print(f"\nYour turn. Valid moves: {valid}")

In [13]:
restart()

In [14]:
init_game(True)

New game! AI plays first.
Columns: 0 1 2 3 4 5 6

AI starts with column 3 (Q=0.052):
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . [91mX[0m . . .


In [15]:
play(6)


You played column 6:
[33m•[0m [36m•[0m [96m•[0m [94m•[0m [92m•[0m [93m•[0m [32m•[0m
[33m•[0m [36m•[0m [96m•[0m [94m•[0m [92m•[0m [93m•[0m [32m•[0m
[33m•[0m [36m•[0m [96m•[0m [94m•[0m [92m•[0m [93m•[0m [32m•[0m
[33m•[0m [36m•[0m [96m•[0m [94m•[0m [92m•[0m [93m•[0m [32m•[0m
[33m•[0m [36m•[0m [96m•[0m [94m•[0m [92m•[0m [93m•[0m [32m•[0m
[33m•[0m [36m•[0m [96m•[0m [91mX[0m [92m•[0m [93m•[0m [93mO[0m

AI played column 3 (Q=0.105):
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . [91mX[0m . . .
. . . [91mX[0m . . [93mO[0m

Your turn. Valid moves: [0, 1, 2, 3, 4, 5, 6]


In [16]:
play(5)


You played column 5:
[33m•[0m [32m•[0m [92m•[0m [96m•[0m [93m•[0m [94m•[0m [36m•[0m
[33m•[0m [32m•[0m [92m•[0m [96m•[0m [93m•[0m [94m•[0m [36m•[0m
[33m•[0m [32m•[0m [92m•[0m [96m•[0m [93m•[0m [94m•[0m [36m•[0m
[33m•[0m [32m•[0m [92m•[0m [96m•[0m [93m•[0m [94m•[0m [36m•[0m
[33m•[0m [32m•[0m [92m•[0m [91mX[0m [93m•[0m [94m•[0m [36m•[0m
[33m•[0m [32m•[0m [92m•[0m [91mX[0m [93m•[0m [93mO[0m [93mO[0m

AI played column 5 (Q=0.009):
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . [91mX[0m . [91mX[0m .
. . . [91mX[0m . [93mO[0m [93mO[0m

Your turn. Valid moves: [0, 1, 2, 3, 4, 5, 6]


In [17]:
play(5)


You played column 5:
[32m•[0m [36m•[0m [92m•[0m [94m•[0m [33m•[0m [93m•[0m [96m•[0m
[32m•[0m [36m•[0m [92m•[0m [94m•[0m [33m•[0m [93m•[0m [96m•[0m
[32m•[0m [36m•[0m [92m•[0m [94m•[0m [33m•[0m [93m•[0m [96m•[0m
[32m•[0m [36m•[0m [92m•[0m [94m•[0m [33m•[0m [93mO[0m [96m•[0m
[32m•[0m [36m•[0m [92m•[0m [91mX[0m [33m•[0m [91mX[0m [96m•[0m
[32m•[0m [36m•[0m [92m•[0m [91mX[0m [33m•[0m [93mO[0m [93mO[0m

AI played column 3 (Q=-0.039):
. . . . . . .
. . . . . . .
. . . . . . .
. . . [91mX[0m . [93mO[0m .
. . . [91mX[0m . [91mX[0m .
. . . [91mX[0m . [93mO[0m [93mO[0m

Your turn. Valid moves: [0, 1, 2, 3, 4, 5, 6]
