# Connect 4 vs Q-Network (PyCharm Version)
No input() calls - designed for PyCharm notebooks

In [None]:
import torch
import torch.nn as nn
import numpy as np
from board_processor import BoardProcessor
from feature_generator import FeatureGenerator

In [None]:
class QNetwork(nn.Module):
    """Fully connected Tanh network mapping a feature vector to one scalar.

    Every width in ``hidden_sizes`` contributes a ``Linear`` layer followed
    by ``Tanh``. A final ``Linear(..., 1)`` + ``Tanh`` head keeps the output
    within [-1, 1].
    """

    def __init__(self, input_dim, hidden_sizes=(256, 128, 64, 32, 16, 8)):
        super().__init__()
        # Consecutive pairs of this list are the (in, out) sizes of each Linear.
        widths = [input_dim, *hidden_sizes, 1]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.Tanh()]
        # The attribute name and layer order fix the state_dict keys
        # ("net.0.weight", ...), so saved checkpoints depend on them.
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the stack and drop the trailing size-1 dimension."""
        head = self.net(x)
        return head.squeeze(-1)

In [None]:
# Load model
MODEL_PATH = "qnet_mc_pretrained.pth"
feature_gen = FeatureGenerator()

# Calculate dimensions: featurise an empty 7-column board to learn how long
# a single-state feature vector is.
_, dummy_features = feature_gen.convolution_feature_gen([[] for _ in range(7)])
FEATURE_DIM = len(dummy_features) * 2  # State-action pairs

print(f"Single state: {len(dummy_features)} features")
print(f"State-action pair: {FEATURE_DIM} features")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = QNetwork(input_dim=FEATURE_DIM).to(device)

# The checkpoint stores a fitted scaler object next to the weights, so it
# cannot be read by the tensors-only loader that torch.load uses by default
# since PyTorch 2.6. weights_only=False restores the full-unpickle behaviour.
# SECURITY: full unpickling can execute arbitrary code - only load checkpoint
# files that come from a trusted source.
checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # inference mode; the network is only queried, never trained here
scaler = checkpoint['scaler']

print(f"\nModel loaded! Using {device}")

In [None]:
# Game functions
def get_q_value(state_list, move, player):
    """Return the network's scalar output for playing ``move`` from ``state_list``.

    The model input is the feature vector of the current position concatenated
    with the feature vector of the position after ``player`` is appended to
    column ``move``, passed through ``scaler.transform``. ``state_list`` itself
    is not modified; the move is applied to a per-column copy.
    """
    _, before_feats = feature_gen.convolution_feature_gen(state_list)

    # Copy every column so the caller's board stays untouched.
    after_move = [list(column) for column in state_list]
    after_move[move].append(player)
    _, after_feats = feature_gen.convolution_feature_gen(after_move)

    pair = np.concatenate([before_feats, after_feats])
    scaled_row = scaler.transform([pair])  # one-row batch

    batch = torch.FloatTensor(scaled_row).to(device)
    with torch.no_grad():
        prediction = model(batch)
    return prediction.item()

def get_ai_move(state_list):
    """Pick the AI's column: the legal move with the highest negated Q-value.

    Args:
        state_list: Board as 7 column lists; a column holding 6 entries is
            treated as full.

    Returns:
        ``(best_col, best_q)`` - the chosen column index and its negated
        Q-value. Ties go to the lowest column index.

    Raises:
        ValueError: If every column is full, so there is no move to choose.
    """
    valid = [c for c in range(7) if len(state_list[c]) < 6]
    if not valid:
        # Fail with a clear message instead of an IndexError from valid[0].
        raise ValueError("No valid moves: every column is full")

    best_col = valid[0]
    best_q = -float('inf')

    for col in valid:
        q = -get_q_value(state_list, col, -1)  # Negate for AI perspective
        if q > best_q:
            best_q = q
            best_col = col

    return best_col, best_q

In [None]:
# Initialize game
moves = []  # columns played so far, in order
board = BoardProcessor()

print("Game started! You are X, AI is O")
print("Columns: 0 1 2 3 4 5 6\n")
# Draw the empty grid: six identical rows of seven cells.
print("\n".join([". . . . . . ."] * 6))

In [None]:
def play(column):
    """Play one full turn: your move in ``column``, then the AI's reply.

    If ``column`` is not one of the columns with free space, a message listing
    the legal columns is printed and nothing else happens. No win or draw
    detection is performed here.
    """
    # Validate move
    valid = [c for c in range(7) if len(board.state_list[c]) < 6]
    if column not in valid:
        print(f"Invalid! Choose from: {valid}")
        return

    def _commit(col):
        # Record the move, then rebuild the board from the full move history.
        moves.append(col)
        board.generate_state_list(moves)

    # Human move
    _commit(column)
    print(f"\nYou played column {column}:")
    board.display_board()

    # AI move
    ai_col, q = get_ai_move(board.state_list)
    _commit(ai_col)
    print(f"\nAI played column {ai_col} (Q={q:.3f}):")
    board.display_board()

# To play: run play(3) to drop in column 3

In [None]:
# Example turn: drop your piece in column 6; play() then makes the AI's reply.
play(6)

In [None]:
# Analyze current position
def analyze():
    """Print the Q-value of every legal move for the human player (player 1)."""
    print("Q-values for your moves:")
    columns = board.state_list
    playable = [idx for idx in range(7) if len(columns[idx]) < 6]
    for col in playable:
        q = get_q_value(columns, col, 1)
        print(f"  Column {col}: {q:.4f}")