In [3]:
"""
CNN Performance Analysis Script - FINAL VERSION
Fixed text box with solid background and better positioning
"""

import numpy as np
from tensorflow import keras
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# ---- Load the dataset and the trained network ----
print("Loading data and model...")
# NOTE: despite the "_test" names, these two arrays hold every sample in the
# files; the held-out slice is carved off just below.
X_test = np.load('connect4_X_clean.npy').astype(np.float32)
Y_test = np.load('connect4_Y_clean.npy')

# Keep only the tail that follows the first 70% + 15% of the samples.
# NOTE(review): presumably this mirrors the (unshuffled) split used during
# training - confirm against the training script.
total_samples = X_test.shape[0]
train_size = int(0.70 * total_samples)
val_size = int(0.15 * total_samples)

test_start = train_size + val_size
X_test_split, Y_test_split = X_test[test_start:], Y_test[test_start:]

print(f"Test set size: {X_test_split.shape[0]}")

# ---- Restore the saved Keras model ----
model = keras.models.load_model('CNN_v2_deep_best.h5')
print("✓ Model loaded")

# ---- Run inference on the held-out slice ----
print("Generating predictions...")
predictions = model.predict(X_test_split, verbose=1)
# Per sample: index of the largest output and the value of that output.
predicted_labels = predictions.argmax(axis=1)
confidences = predictions.max(axis=1)

# ---- Split samples into hits and misses ----
correct_mask = Y_test_split == predicted_labels
incorrect_mask = ~correct_mask

overall_accuracy = correct_mask.mean()
print(f"\nOverall accuracy: {overall_accuracy:.4f}")

# Helper functions
def board_3d_to_2d(board_3d):
    """Collapse the first two planes of a board into one 6x7 signed grid.

    Cells where plane 0 equals 1 become 1, cells where plane 1 equals 1
    become -1, and everything else stays 0. If a cell is set in both
    planes, the plane-1 value (-1) wins because it is written last.
    """
    grid = np.zeros((6, 7))
    for plane, value in ((0, 1), (1, -1)):
        occupied = board_3d[:, :, plane] == 1
        grid[occupied] = value
    return grid

def count_pieces(board_3d):
    """Return the combined sum of planes 0 and 1 of the board, as an int."""
    plane0_total = board_3d[:, :, 0].sum()
    plane1_total = board_3d[:, :, 1].sum()
    return int(plane0_total + plane1_total)

def is_interesting_board(board_3d, *, min_pieces=6, max_pieces=25):
    """Return True when the board's piece count lies in an inclusive range.

    Args:
        board_3d: Board array accepted by ``count_pieces``.
        min_pieces: Smallest piece count that still qualifies (default 6).
        max_pieces: Largest piece count that still qualifies (default 25).

    Returns:
        bool: ``min_pieces <= count_pieces(board_3d) <= max_pieces``.
    """
    return min_pieces <= count_pieces(board_3d) <= max_pieces

# Find interesting boards
def _pick_examples(candidates, scores, descending=False, limit=6):
    """Filter, rank and truncate a set of candidate test-set indices.

    Keeps the candidates whose board passes ``is_interesting_board``, orders
    them by ``scores[i]`` (stable sort; largest first when ``descending``),
    and returns at most ``limit`` of them as a list.
    """
    kept = [i for i in candidates if is_interesting_board(X_test_split[i])]
    kept.sort(key=lambda i: scores[i], reverse=descending)
    return kept[:limit]


# Correct predictions with top probability above 0.8, most confident first.
high_conf_correct = np.where((confidences > 0.8) & correct_mask)[0]
interesting_high_conf = _pick_examples(
    high_conf_correct, confidences, descending=True
)

# Correct predictions with top probability strictly between 0.20 and 0.35,
# least confident first.
low_conf_correct = np.where(
    correct_mask & (confidences > 0.20) & (confidences < 0.35)
)[0]
interesting_low_conf_correct = _pick_examples(low_conf_correct, confidences)

# Wrong predictions with top probability above 0.5, most confident first.
high_conf_incorrect = np.where((confidences > 0.5) & incorrect_mask)[0]
interesting_high_conf_wrong = _pick_examples(
    high_conf_incorrect, confidences, descending=True
)

# Gap between the two largest probabilities of each sample. After the
# partition the last column is the row maximum and the one before it is the
# second-largest value.
top2_probs = np.partition(predictions, -2, axis=1)[:, -2:]
prob_diff = top2_probs[:, 1] - top2_probs[:, 0]

# Wrong predictions where that gap is below 0.15, smallest gap first.
close_call_incorrect = np.where((prob_diff < 0.15) & incorrect_mask)[0]
interesting_close_call = _pick_examples(close_call_incorrect, prob_diff)

# Visualization function with FIXED text box
def visualize_board(board_3d, true_label, pred_label, confidence,
                    title="", predictions_full=None, ax=None):
    """Draw one 6x7 board with the predicted and reference columns outlined.

    Args:
        board_3d: Board array passed to ``board_3d_to_2d``; cells that map to
            1 are drawn as red discs, cells that map to -1 as yellow discs.
        true_label: Column index outlined in solid orange (legend "MCTS").
        pred_label: Column index outlined in dashed green (legend "CNN").
        confidence: Number shown in the title, formatted as a percentage.
        title: First line of the subplot title.
        predictions_full: Optional 1-D sequence of per-column scores. When
            given, the three largest are listed in a text box in the top-left
            corner, with an arrow next to the entry equal to ``pred_label``.
        ax: Matplotlib axes to draw on. When omitted a new 8x7 inch figure
            with a single axes is created.

    Returns:
        The axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(8, 7))

    board_2d = board_3d_to_2d(board_3d)

    # Draw board (y axis inverted so row 0 is at the top)
    ax.set_xlim(-0.5, 6.5)
    ax.set_ylim(-0.5, 5.5)
    ax.set_aspect('equal')
    ax.invert_yaxis()

    # Draw grid
    for i in range(7):
        ax.axvline(i - 0.5, color='blue', linewidth=2)
    for i in range(6):
        ax.axhline(i - 0.5, color='blue', linewidth=2)

    # Draw pieces
    for row in range(6):
        for col in range(7):
            if board_2d[row, col] == 1:
                circle = plt.Circle((col, row), 0.4, color='red', zorder=10)
                ax.add_patch(circle)
            elif board_2d[row, col] == -1:
                circle = plt.Circle((col, row), 0.4, color='yellow', zorder=10)
                ax.add_patch(circle)

    # Highlight columns. The solid reference outline sits below the dashed
    # prediction outline so that, when both mark the same column, the green
    # dashes remain visible on top of the orange line instead of being
    # painted over.
    true_rect = patches.Rectangle((true_label - 0.5, -0.5), 1, 6,
                                  linewidth=3, edgecolor='orange',
                                  facecolor='none', zorder=5,
                                  label='MCTS')
    ax.add_patch(true_rect)

    pred_rect = patches.Rectangle((pred_label - 0.5, -0.5), 1, 6,
                                  linewidth=3, edgecolor='green',
                                  facecolor='none', zorder=6,
                                  linestyle='--', label='CNN')
    ax.add_patch(pred_rect)

    # Labels
    ax.set_xticks(range(7))
    ax.set_xticklabels(range(7), fontsize=11)
    ax.set_yticks([])
    ax.set_xlabel('Column', fontsize=11)

    # Title
    correct = "✓ CORRECT" if pred_label == true_label else "✗ INCORRECT"
    title_text = f"{title}\n{correct}\nConfidence: {confidence:.1%}"
    ax.set_title(title_text, fontsize=11, fontweight='bold', pad=10)

    # Text box at top-left with solid white background and black border
    if predictions_full is not None:
        top3_idx = np.argsort(predictions_full)[::-1][:3]
        text_lines = ["Top 3:"]
        for idx in top3_idx:
            marker = "→" if idx == pred_label else " "
            text_lines.append(f"{marker}Col {idx}: {predictions_full[idx]:.0%}")
        # Joined (not newline-terminated) so the box has no empty last line.
        prob_text = "\n".join(text_lines)

        ax.text(0.02, 0.98, prob_text, transform=ax.transAxes,
                fontsize=9, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='white',
                          edgecolor='black', linewidth=1.5, alpha=1.0),
                zorder=100)  # High zorder to draw on top

    # Explicit handles keep the legend order (CNN, then MCTS) independent of
    # the order in which the outlines were added.
    ax.legend(handles=[pred_rect, true_rect], loc='upper right', fontsize=9)

    return ax

# Create visualizations with FIXED text boxes
print("\nCreating visualizations...")


def _save_example_grid(example_indices, suptitle, filename):
    """Render up to six test-set examples in a 2x3 grid and save it as a PNG.

    Args:
        example_indices: Indices into ``X_test_split`` / ``Y_test_split`` /
            ``predictions`` to draw, in display order (at most six).
        suptitle: Figure-level title.
        filename: Output image path.
    """
    fig, axes = plt.subplots(2, 3, figsize=(16, 11))
    fig.suptitle(suptitle, fontsize=15, fontweight='bold', y=0.98)

    # Row-major flattening: position k is axes[k // 3, k % 3].
    flat_axes = axes.ravel()

    for idx, i in enumerate(example_indices):
        visualize_board(
            X_test_split[i],
            Y_test_split[i],
            predicted_labels[i],
            confidences[i],
            title=f"Example {idx+1}",
            predictions_full=predictions[i],
            ax=flat_axes[idx]
        )

    # With fewer than six examples the remaining subplots would otherwise be
    # saved as empty default axes.
    for unused_ax in flat_axes[len(example_indices):]:
        unused_ax.set_visible(False)

    fig.tight_layout(rect=[0, 0, 1, 0.97])
    fig.subplots_adjust(hspace=0.35, wspace=0.25)
    fig.savefig(filename, dpi=150, bbox_inches='tight')
    print(f"✓ Saved: {filename}")
    plt.close(fig)


# 1. High confidence correct
_save_example_grid(
    interesting_high_conf,
    'HIGH CONFIDENCE CORRECT - CNN is Very Sure and Right',
    'cnn_high_confidence_correct_final.png',
)

# 2. High confidence incorrect
_save_example_grid(
    interesting_high_conf_wrong,
    'HIGH CONFIDENCE INCORRECT - CNN is Very Sure but Wrong',
    'cnn_high_confidence_incorrect_final.png',
)

# 3. Low confidence correct
_save_example_grid(
    interesting_low_conf_correct,
    'LOW CONFIDENCE CORRECT - CNN is Uncertain but Still Right',
    'cnn_low_confidence_correct_final.png',
)

# 4. Close calls
_save_example_grid(
    interesting_close_call,
    'CLOSE CALLS - Multiple Good Moves Available',
    'cnn_close_calls_final.png',
)

print("\n" + "="*60)
print("COMPLETE - Final versions with fixed text boxes!")
print("="*60)

Loading data and model...
Test set size: 32297




✓ Model loaded
Generating predictions...
[1m1010/1010[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 3ms/step

Overall accuracy: 0.6654

Creating visualizations...
✓ Saved: cnn_high_confidence_correct_final.png
✓ Saved: cnn_high_confidence_incorrect_final.png
✓ Saved: cnn_low_confidence_correct_final.png
✓ Saved: cnn_close_calls_final.png

COMPLETE - Final versions with fixed text boxes!
