# In [None]:
import pygame
import numpy as np
import platform

# Fix matplotlib backend for macOS compatibility
import matplotlib
matplotlib.use('TkAgg')  # Explicitly set backend before importing pyplot

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.lines import Line2D
import matplotlib.patches as patches
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy import ndimage
import threading
import time
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import tkinter as tk
import queue


class MinimalCNN(nn.Module):
    """Three-layer convolutional digit classifier that exposes its intermediate activations.

    Sizes for a (N, 1, 28, 28) input: conv1 -> (N, 4, 24, 24), pool ->
    (N, 4, 12, 12), conv2 -> (N, 8, 10, 10), pool -> (N, 8, 5, 5),
    conv3 -> (N, 10, 2, 2). The 10 conv3 maps are summed over their spatial
    axes to give one logit per class; there is no fully-connected layer.

    The attribute names (conv1, conv2, conv3, pool, dropout) must not change:
    the parameterised ones are the keys of the saved state dict.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, kernel_size=5, stride=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(4, 8, kernel_size=3, stride=1)
        self.conv3 = nn.Conv2d(8, 10, kernel_size=4, stride=1)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Run the network and return every intermediate stage.

        Args:
            x: input batch; the sizes in the class docstring assume (N, 1, 28, 28).

        Returns:
            Tuple ``(conv1, pool1, conv2, pool2, conv3, summed, out)``:
            ``conv1``/``conv2`` are pre-ReLU convolution outputs, ``pool1``/
            ``pool2`` are max-pooled ReLU activations, ``conv3`` is the conv3
            output after dropout, ``summed`` is ``conv3`` summed over its two
            spatial axes (N, 10), and ``out`` is the same values viewed as
            (N, 10) logits.
        """
        conv1 = self.conv1(x)
        pool1 = self.pool(F.relu(conv1))
        conv2 = self.conv2(pool1)
        pool2 = self.pool(F.relu(conv2))
        conv3 = self.conv3(pool2)
        # Dropout only acts in training mode; load_model() switches to eval().
        conv3 = self.dropout(conv3)
        # Named `summed` (not `sum`) so the builtin is not shadowed.
        summed = torch.sum(conv3, dim=(2, 3))
        out = summed.view(-1, 10)
        return conv1, pool1, conv2, pool2, conv3, summed, out


def load_model(path='mnist_cnn_4-8-10.pth'):
    """Build a MinimalCNN, load its weights from ``path`` and put it in eval mode.

    Args:
        path: checkpoint holding a MinimalCNN state dict. Defaults to the file
            name the rest of this script expects in the working directory.

    Returns:
        The model with weights mapped to CPU, in eval mode.

    Raises:
        FileNotFoundError: if ``path`` does not exist (callers catch this).
    """
    model = MinimalCNN()
    # NOTE(review): torch.load unpickles the file, so only load checkpoints you
    # trust; on torch versions that support it, weights_only=True is safer.
    state_dict = torch.load(path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    return model


def center_and_normalize_digit(digit_array, target_size=20):
    """Crop the ink, shrink it to fit ``target_size`` and center it on a 28x28 canvas.

    The bounding box of all pixels > 0 is cut out. If its longer side exceeds
    ``target_size`` it is downscaled (never upscaled) with ``ndimage.zoom``.
    The result is then pasted in the middle of a blank 28x28 array.

    Args:
        digit_array: 2-D array; pixels > 0 count as ink.
        target_size: maximum side length of the digit after scaling.

    Returns:
        A new 28x28 float array, or ``digit_array`` itself when it has no ink.
    """
    rows, cols = np.where(digit_array > 0)
    if len(rows) == 0:
        return digit_array

    digit_region = digit_array[rows.min():rows.max() + 1, cols.min():cols.max() + 1]

    height, width = digit_region.shape
    scale = min(target_size / height, target_size / width)

    if scale < 1.0:
        # Round instead of truncating so float error in `side * scale` cannot
        # drop a pixel, and keep every side >= 1: a one-pixel-wide stroke
        # would otherwise get a 0 zoom factor and be erased entirely.
        new_height = max(1, int(round(height * scale)))
        new_width = max(1, int(round(width * scale)))
        digit_region = ndimage.zoom(digit_region, (new_height / height, new_width / width))

    centered_digit = np.zeros((28, 28))

    region_height, region_width = digit_region.shape
    start_row = (28 - region_height) // 2
    start_col = (28 - region_width) // 2

    centered_digit[start_row:start_row + region_height, start_col:start_col + region_width] = digit_region

    return centered_digit


def calculate_center_of_mass(digit_array):
    """Return the intensity-weighted (row, col) centroid of the positive pixels.

    Only pixels > 0 contribute, each weighted by its value. An image without
    any positive pixel yields the fixed point (14, 14).
    """
    ys, xs = np.nonzero(digit_array > 0)
    if ys.size == 0:
        return 14, 14

    mass = digit_array[ys, xs]
    return np.average(ys, weights=mass), np.average(xs, weights=mass)


def fine_tune_centering(digit_array):
    """Shift the image by whole pixels so its centroid lands on (13.5, 13.5).

    The per-axis offset is rounded to an integer and clamped to [-10, 10].
    If both offsets are zero the input array is returned as-is. Otherwise
    ``ndimage.shift`` builds a new array whose vacated pixels are filled with 0.
    """
    target_center = 13.5
    max_shift = 10

    def _axis_shift(center):
        offset = int(round(target_center - center))
        return min(max_shift, max(-max_shift, offset))

    center_row, center_col = calculate_center_of_mass(digit_array)
    shift = (_axis_shift(center_row), _axis_shift(center_col))

    if shift == (0, 0):
        return digit_array
    return ndimage.shift(digit_array, shift, cval=0.0)


def preprocess_drawing(digit_array):
    """Turn a raw drawing into a model-ready tensor.

    The drawing is cropped/scaled/centered (``center_and_normalize_digit``)
    and then centroid-aligned (``fine_tune_centering``). It is converted to a
    float32 tensor and each pixel is mapped through ``(v - 0.5) / 0.5``.

    Returns:
        ``(input_tensor, final)``: the tensor with two leading singleton axes
        (batch, channel) added, and the processed numpy image before the
        value mapping.
    """
    final = fine_tune_centering(center_and_normalize_digit(digit_array))

    pixels = torch.FloatTensor(final)
    scaled = pixels.sub(0.5).div(0.5)

    return scaled[None, None], final


def predict_digit_improved(model, digit_array):
    """Classify a drawing and return the prediction together with every CNN stage.

    Args:
        model: a MinimalCNN (anything whose forward returns the same 7-tuple).
            It is not called when the drawing is empty.
        digit_array: raw 28x28 drawing; an all-zero sum means "empty".

    Returns:
        ``(predicted, confidence, probabilities, processed_array, stages)``.

        For an empty drawing the result is ``(0, 0.0, ten probabilities of
        0.1, digit_array, zero-filled stages)``. The placeholder stages use
        the per-sample shapes MinimalCNN produces for a 28x28 input, so the
        visualization lays out the same way as for real input.

        Otherwise ``stages`` holds the model's batched outputs
        ``[conv1, pool1, conv2, pool2, conv3, sum]``.
    """
    if np.sum(digit_array) == 0:
        # Per-sample shapes of conv1, pool1, conv2, pool2, conv3 and the spatial sum.
        placeholder_shapes = [(4, 24, 24), (4, 12, 12), (8, 10, 10), (8, 5, 5), (10, 2, 2), (10,)]
        empty_stages = [np.zeros(shape) for shape in placeholder_shapes]
        return 0, 0.0, np.ones(10) * 0.1, digit_array, empty_stages

    input_tensor, processed_array = preprocess_drawing(digit_array)

    with torch.no_grad():
        conv1, pool1, conv2, pool2, conv3, sum_out, outputs = model(input_tensor)
        probabilities = F.softmax(outputs, dim=1)
        confidence, predicted = torch.max(probabilities, 1)
    stages = [conv1, pool1, conv2, pool2, conv3, sum_out]

    return predicted.item(), confidence.item(), probabilities.squeeze().numpy(), processed_array, stages


def apply_convolution(num_kernels, num_channels, weights, input_tensor):
    """Convolve each input channel separately with the matching slice of each kernel.

    This exposes the per-channel contributions that a multi-channel conv
    layer sums together. No bias is added.

    Args:
        num_kernels: number of output kernels (first axis of ``weights``) to use.
        num_channels: number of input channels to visualize per kernel.
        weights: conv weight tensor indexed as ``weights[kernel, channel]``.
        input_tensor: 3-D (C, H, W) or 4-D (1, C, H, W) tensor.

    Returns:
        Nested list ``maps[kernel][channel]`` of 2-D numpy arrays. A channel
        index beyond the input's channel count gets a 1x1 zero array.
    """
    batch = input_tensor.unsqueeze(0) if input_tensor.dim() == 3 else input_tensor
    batch = batch.float()
    kernel_bank = weights.float()
    available_channels = batch.shape[1]

    def _channel_response(kernel_idx, channel_idx):
        if channel_idx >= available_channels:
            return np.zeros((1, 1))
        response = F.conv2d(
            batch[:, channel_idx:channel_idx + 1],
            kernel_bank[kernel_idx:kernel_idx + 1, channel_idx:channel_idx + 1],
            stride=1,
        )
        return response.squeeze(0).squeeze(0).detach().cpu().numpy()

    return [
        [_channel_response(k, c) for c in range(num_channels)]
        for k in range(num_kernels)
    ]


class CNNVisualizationWindow:
    """Tk window that renders every stage of the CNN for the most recent drawing.

    Drawings (28x28 numpy arrays) are polled from ``data_queue`` inside the Tk
    event loop. The newest one is run through the model and drawn into one
    matplotlib figure embedded in the window.
    """

    # Window size used on macOS, and as the fallback when maximizing fails.
    DEFAULT_WIDTH = 1400
    DEFAULT_HEIGHT = 800

    def __init__(self, data_queue):
        """Create the window, load the model and start polling ``data_queue``.

        If the model file is missing, a message is printed and the Tk root is
        destroyed. The instance is then left without any plotting state.
        """
        self.data_queue = data_queue
        self.root = tk.Tk()
        self.root.title("Real-time CNN Visualization - Full Layout")

        self._configure_window()

        # Load model and get filters
        try:
            self.model = load_model()
        except FileNotFoundError:
            print("Could not find 'mnist_cnn_4-8-10.pth' model file")
            self.root.destroy()
            return

        # Kernel weights as nested lists: filtersN[kernel][input_channel] -> 2-D array.
        self.filters1 = self._extract_filters(self.model.conv1.weight.data)
        self.filters2 = self._extract_filters(self.model.conv2.weight.data)
        self.filters3 = self._extract_filters(self.model.conv3.weight.data)

        # Get bias values
        self.filter_bias1 = self.model.conv1.bias.data.detach().cpu().numpy()
        self.filter_bias2 = self.model.conv2.bias.data.detach().cpu().numpy()
        self.filter_bias3 = self.model.conv3.bias.data.detach().cpu().numpy()

        self.setup_matplotlib()
        self.start_update_loop()

    def _configure_window(self):
        """Maximize the root window (Windows/Linux) or size and center it (macOS)."""
        fallback_geometry = f"{self.DEFAULT_WIDTH}x{self.DEFAULT_HEIGHT}"
        try:
            system = platform.system()
            if system == "Windows":
                self.root.state('zoomed')  # Windows maximize
            elif system == "Darwin":  # macOS
                # Use a large centered window instead of fullscreen. The size is
                # taken from the constants, not winfo_width()/winfo_height():
                # Tk only reports the real size once the window has been
                # arranged, and before that it can be a 1-pixel placeholder.
                width, height = self.DEFAULT_WIDTH, self.DEFAULT_HEIGHT
                x = max(0, (self.root.winfo_screenwidth() // 2) - (width // 2))
                y = max(0, (self.root.winfo_screenheight() // 2) - (height // 2))
                self.root.geometry(f'{width}x{height}+{x}+{y}')
            else:  # Linux
                try:
                    self.root.attributes('-zoomed', True)
                except tk.TclError:
                    self.root.geometry(fallback_geometry)
        except Exception as e:
            print(f"Window maximization failed: {e}")
            # Fallback to large window
            self.root.geometry(fallback_geometry)

    @staticmethod
    def _extract_filters(weight):
        """Split a conv weight tensor into ``[kernel][input_channel]`` numpy arrays."""
        return [
            [weight[i][j].detach().cpu().numpy() for j in range(weight.size(1))]
            for i in range(weight.size(0))
        ]

    def setup_matplotlib(self):
        """Create the figure, embed it in the Tk root and build all panels."""
        plt.ioff()  # Turn off interactive mode
        self.fig = plt.figure(figsize=(22, 10))

        # Create tkinter canvas
        self.canvas = FigureCanvasTkAgg(self.fig, master=self.root)
        self.canvas_widget = self.canvas.get_tk_widget()
        self.canvas_widget.pack(fill=tk.BOTH, expand=True)

        # Initialize plot elements
        self.setup_plot_elements()

    def plot_matrix(self, ax, matrix, title, cmap='Blues', value_format='.1f', show_grid=True, show_numbers=False, change_color=True, font=-1):
        """Draw a 2-D ``matrix`` as an image on ``ax`` with a small bold title.

        For matrices at most 10x10, ``show_grid`` adds white cell borders and
        ``show_numbers`` writes each value into its cell. The font size comes
        from ``font`` when > 0, otherwise from the matrix size. With
        ``change_color`` set, values whose magnitude exceeds 0.3 are written
        in white. Tick labels are always hidden. Returns ``ax``.
        """
        ax.imshow(matrix, cmap=cmap, interpolation='nearest')
        ax.set_title(title, fontsize=8, fontweight='bold')

        small = matrix.shape[0] <= 10 and matrix.shape[1] <= 10

        if show_grid and small:
            # Minor ticks sit on cell edges so the grid outlines every cell.
            ax.set_xticks(np.arange(-.5, matrix.shape[1], 1), minor=True)
            ax.set_yticks(np.arange(-.5, matrix.shape[0], 1), minor=True)
            ax.grid(which='minor', color='w', linestyle='-', linewidth=0.5)
        ax.set_xticks([])
        ax.set_yticks([])

        if show_numbers and small:
            max_dim = max(matrix.shape[0], matrix.shape[1])
            if font > 0:
                fontsize = font
            elif matrix.shape[0] == 1 and matrix.shape[1] == 1:
                fontsize = 6
            elif max_dim <= 3:
                fontsize = 5
            elif max_dim <= 5:
                fontsize = 4
            else:
                fontsize = 3
            for i in range(matrix.shape[0]):
                for j in range(matrix.shape[1]):
                    value = matrix[i, j]
                    text_color = 'white' if (abs(value) > 0.3 and change_color) else 'black'
                    if isinstance(value, (int, np.integer)):
                        value_str = f"{value}"
                    else:
                        value_str = f"{value:{value_format}}"
                    ax.text(j, i, value_str, ha='center', va='center',
                            color=text_color, fontsize=fontsize)
        return ax

    def setup_plot_elements(self):
        """Create every axes of the figure and draw the initial (empty) state.

        Positions are in figure coordinates and derived from ``scale`` and
        ``vertical_shift``. Axes are stored in ``self.axes`` for later updates.
        """
        self.fig.clear()

        # Store axes for updates
        self.axes = {}

        scale = 0.4
        vertical_shift = 0.35

        input_left = 0.01
        kernel1_base_left = 0.35 * scale
        feature1_base_left = 0.55 * scale
        pooling1_base_left = 0.75 * scale
        kernels2_base_left = 1.0 * scale
        channels2_base_left = 1.30 * scale
        feature2_base_left = 1.50 * scale
        pooling2_base_left = 1.60 * scale
        feature3_base_left = 1.80 * scale
        out_base_left = 2.00 * scale
        prediction_left = 2.20 * scale

        # Calculate true center position
        sub_height = 0.12 * scale
        v_spacing = 0.02 * scale
        top_edge = (0.85 * scale + sub_height) + vertical_shift
        bottom_edge = ((0.85 * scale - (3 * 0.3 * scale)) - (sub_height + v_spacing)) + vertical_shift
        true_center = (top_edge + bottom_edge) / 2

        # Center input and prediction boxes
        box_height = 0.25 * scale
        centered_bottom = true_center - box_height / 2

        # Create input axis
        input_width = 0.25 * scale
        self.axes['input'] = self.fig.add_axes([input_left, centered_bottom, input_width, box_height])

        for key in ('kernel1', 'feature1', 'pool1', 'kernel2_channels', 'channels2',
                    'feature2', 'pool2', 'feature3', 'sum'):
            self.axes[key] = []

        # Layer 1: 4 rows of kernel / feature map / pooled map
        layer1_size = 0.14 * scale
        for row in range(4):
            bottom = (0.77 * scale - (row * 0.3 * scale)) + vertical_shift
            self.axes['kernel1'].append(self.fig.add_axes([kernel1_base_left, bottom, layer1_size, layer1_size]))
            self.axes['feature1'].append(self.fig.add_axes([feature1_base_left, bottom, layer1_size, layer1_size]))
            self.axes['pool1'].append(self.fig.add_axes([pooling1_base_left, bottom, layer1_size, layer1_size]))

        # Layer 2: 8 vertical blocks, one per conv2 kernel. Each block has a 2x2
        # grid of its 4 channel slices and a 2x2 grid of per-channel feature maps.
        block_cell = 0.08 * scale
        h_spacing_kernels = 0.04 * scale
        h_spacing_features = 0.01 * scale
        block_v_spacing = 0.06 * scale
        for block_num in range(8):
            block_kernel_axes = []
            block_output_axes = []
            block_top = (1.45 * scale - (block_num * 0.3 * scale)) + vertical_shift

            for sub_row in range(2):
                for sub_col in range(2):
                    left_kernel = kernels2_base_left + sub_col * (block_cell + h_spacing_kernels)
                    left_feature = channels2_base_left + sub_col * (block_cell + h_spacing_features)
                    bottom = block_top - sub_row * (block_cell + block_v_spacing)
                    block_kernel_axes.append(self.fig.add_axes([left_kernel, bottom, block_cell, block_cell]))
                    block_output_axes.append(self.fig.add_axes([left_feature, bottom, block_cell, block_cell]))

            self.axes['kernel2_channels'].append(block_kernel_axes)
            self.axes['channels2'].append(block_output_axes)

        # Layer 2: summed feature maps and pooling in 1 column of 8
        small_cell = 0.12 * scale
        for row in range(8):
            bottom = (1.375 * scale - (row * 0.3 * scale)) + vertical_shift
            self.axes['feature2'].append(self.fig.add_axes([feature2_base_left, bottom, small_cell, small_cell]))
            self.axes['pool2'].append(self.fig.add_axes([pooling2_base_left, bottom, small_cell, small_cell]))

        # Layer 3: conv3 output and its spatial sum for each of the 10 classes
        for row in range(10):
            bottom = (1.23 * scale - (row * 0.2 * scale)) + vertical_shift
            self.axes['feature3'].append(self.fig.add_axes([feature3_base_left, bottom, small_cell, small_cell]))
            self.axes['sum'].append(self.fig.add_axes([out_base_left, bottom, small_cell, small_cell]))

        # Prediction box; the width is divided by the figure's 22:10 aspect ratio
        # (figsize in setup_matplotlib) to compensate for the wide figure.
        figure_aspect_ratio = 22 / 10
        prediction_width = (box_height / figure_aspect_ratio) * 1.2
        self.axes['prediction'] = self.fig.add_axes([prediction_left, centered_bottom, prediction_width, box_height])

        # Draw the static kernel panels once
        self.draw_static_elements()

        # Initialize with empty data
        self.update_visualization_data(np.zeros((28, 28)))

    def draw_static_elements(self):
        """Draw the panels that never change: the conv1 and conv2 kernel weights."""
        # Draw Layer 1 kernels
        for i, ax in enumerate(self.axes['kernel1']):
            self.plot_matrix(ax, self.filters1[i][0], f'K1_{i+1}', cmap='RdBu', show_numbers=True)

        # Draw Layer 2 kernels
        for block_num, block_axes in enumerate(self.axes['kernel2_channels']):
            for ch, ax in enumerate(block_axes):
                self.plot_matrix(ax, self.filters2[block_num][ch], f'K{block_num+1}C{ch+1}', cmap='RdBu', show_numbers=True)

    @staticmethod
    def _to_numpy(tensor):
        """Convert a stage output to numpy and drop a leading size-1 (batch) axis."""
        if hasattr(tensor, 'detach'):
            tensor = tensor.detach().cpu()
        if hasattr(tensor, 'numpy'):
            tensor = tensor.numpy()
        # Single check equivalent to the former >3 / >2 / >1 branch chain,
        # which all performed the same slice.
        if len(tensor.shape) > 1 and tensor.shape[0] == 1:
            tensor = tensor[0]
        return tensor

    def _redraw_map(self, ax, maps, index, title, **plot_kwargs):
        """Clear ``ax`` and plot ``maps[index]`` on it if ``maps`` has that 2-D map."""
        ax.clear()
        if len(maps.shape) >= 3 and index < maps.shape[0]:
            self.plot_matrix(ax, maps[index], title, **plot_kwargs)

    def _draw_stages(self, stages):
        """Redraw all feature-map panels from the six model stages."""
        conv1_out, pool1_out, conv2_out, pool2_out, conv3_out, sum_out = (
            self._to_numpy(stage) for stage in stages[:6]
        )

        # Per-channel conv2 contributions are recomputed from the pool1 output.
        pool1_tensor = stages[1]
        if isinstance(pool1_tensor, torch.Tensor):
            pool1_tensor = pool1_tensor.float()
        else:
            pool1_tensor = torch.tensor(pool1_out).float()
        if len(pool1_tensor.shape) == 3:
            pool1_tensor = pool1_tensor.unsqueeze(0)  # Add batch dimension

        try:
            channels2_data = apply_convolution(8, 4, self.model.conv2.weight.data, pool1_tensor)
        except Exception as e:
            print(f"Error in apply_convolution: {e}")
            # Blank maps sized like the real per-channel conv2 output (12 - 3 + 1 = 10).
            channels2_data = [[np.zeros((10, 10)) for _ in range(4)] for _ in range(8)]

        # Layer 1 feature maps and pooling
        for i in range(min(4, len(self.axes['feature1']))):
            self._redraw_map(self.axes['feature1'][i], conv1_out, i, f'FM1_{i+1}', show_numbers=False)
            self._redraw_map(self.axes['pool1'][i], pool1_out, i, f'P1_{i+1}', cmap='Blues', show_numbers=False)

        # Layer 2 per-channel feature maps
        for block_num, block_axes in enumerate(self.axes['channels2']):
            if block_num < len(channels2_data):
                for ch, ax in enumerate(block_axes):
                    ax.clear()
                    if ch < len(channels2_data[block_num]):
                        self.plot_matrix(ax, channels2_data[block_num][ch], f'FM{block_num+1}C{ch+1}', show_numbers=False)

        # Layer 2 summed feature maps and pooling
        for i in range(min(8, len(self.axes['feature2']))):
            self._redraw_map(self.axes['feature2'][i], conv2_out, i, f'FM2_{i+1}', show_numbers=False)
            self._redraw_map(self.axes['pool2'][i], pool2_out, i, f'P2_{i+1}', cmap='Blues', show_numbers=False)

        # Layer 3 maps and their per-class sums
        for i in range(min(10, len(self.axes['feature3']))):
            self._redraw_map(self.axes['feature3'][i], conv3_out, i, f'FM3_{i+1}')

            sum_ax = self.axes['sum'][i]
            sum_ax.clear()
            if len(sum_out.shape) >= 1 and i < len(sum_out):
                out_value = np.array([[sum_out[i]]])
                self.plot_matrix(sum_ax, out_value, f'SUM_{i}', show_numbers=True, value_format='.3f', cmap='Blues', change_color=False)

    def _draw_prediction(self, predicted_digit, confidence):
        """Redraw the prediction box and the confidence label beneath it."""
        pred_ax = self.axes['prediction']
        pred_ax.clear()
        pred_ax.text(0.5, 0.5, str(predicted_digit), ha='center', va='center',
                     fontsize=50, fontweight='bold', transform=pred_ax.transAxes)
        pred_ax.set_xlim(0, 1)
        pred_ax.set_ylim(0, 1)
        pred_ax.set_xticks([])
        pred_ax.set_yticks([])
        pred_ax.set_title('FCN Prediction', fontsize=16, fontweight='bold')

        # Add border
        for spine in pred_ax.spines.values():
            spine.set_visible(True)
            spine.set_linewidth(1)

        # Remove the previous confidence label: it is the only figure-level
        # text this class creates, and the only one containing '%'.
        for text in self.fig.texts[:]:
            if '%' in str(text.get_text()):
                text.remove()

        pred_bbox = pred_ax.get_position()
        pred_center_x = (pred_bbox.x0 + pred_bbox.x1) / 2
        confidence_y = pred_bbox.y0 - 0.05

        self.fig.text(pred_center_x, confidence_y, f'Confidence: {confidence * 100:.1f}%',
                      ha='center', va='top', fontsize=12)

    def update_visualization_data(self, digit_array):
        """Run the model on ``digit_array`` and redraw every dynamic panel."""
        predicted_digit, confidence, _, processed_array, stages = predict_digit_improved(self.model, digit_array)

        # Update input display
        input_ax = self.axes['input']
        input_ax.clear()
        self.plot_matrix(input_ax, processed_array, 'Input Drawing\n(28×28)', show_numbers=False)
        input_ax.set_title('Input Drawing\n(28×28)', fontsize=16, fontweight='bold')

        if len(stages) >= 6:
            self._draw_stages(stages)

        self._draw_prediction(predicted_digit, confidence)

        # Redraw
        self.canvas.draw()

    def start_update_loop(self):
        """Poll ``data_queue`` from the Tk event loop and render the newest drawing.

        All pending items are drained but only the last one is rendered; the
        earlier ones would be overwritten on screen immediately. The next poll
        is always scheduled (``finally``), so one failed render does not
        silently stop live updates.
        """
        # Platform-appropriate update intervals
        if platform.system() == "Darwin":  # macOS
            update_interval = 16  # ~60 FPS, gentler on macOS
        else:
            update_interval = 1   # Ultra-fast for other platforms

        def update():
            try:
                latest = None
                while True:
                    try:
                        latest = self.data_queue.get_nowait()
                    except queue.Empty:
                        break
                if latest is not None:
                    self.update_visualization_data(latest)
            except Exception as e:
                print(f"Visualization update failed: {e}")
            finally:
                self.root.after(update_interval, update)

        update()

    def run(self):
        """Enter the Tk main loop; blocks until the window is closed."""
        self.root.mainloop()


class DrawingWindow:
    """Pygame window with a 28x28 grid in which the user paints a digit.

    Painted cells become 0.9 in a 28x28 float array (0 elsewhere). Whenever
    the drawing changes, that array is pushed to ``data_queue`` after
    discarding any pending item, so the consumer only sees the latest drawing.
    """

    def __init__(self, data_queue):
        self.data_queue = data_queue

        # Initialize pygame
        pygame.init()
        self.WIDTH, self.HEIGHT = 560, 560
        self.BG_COLOR = (255, 255, 255)
        self.running = True
        self.grid_array = []
        self.digit_array = np.zeros((28, 28))

        # Initialize pygame window
        self.window = pygame.display.set_mode((self.WIDTH, self.HEIGHT))
        pygame.display.set_caption("Draw Here - Cross-Platform CNN Visualization")
        self.window.fill(self.BG_COLOR)

        # Platform-appropriate update intervals (seconds). -inf makes the
        # very first check in run() send an update.
        self.last_update_time = float('-inf')
        if platform.system() == "Darwin":  # macOS
            self.update_interval = 0.016  # ~60 FPS for macOS stability
        else:
            self.update_interval = 0.001  # Ultra-fast for other platforms

        self.setup_grid()

    def setup_grid(self):
        """(Re)build the grid of ``(rect, color)`` cells, all white.

        Cells are stored column-major: index ``col * 28 + row`` is the cell at
        x = col, y = row. ``get_digit_array`` and ``_cell_index`` rely on this.
        """
        self.block_size = int(self.WIDTH / 28)
        size = self.block_size
        self.grid_array = [
            (pygame.Rect(col * size, row * size, size, size), (255, 255, 255))
            for col in range(28)
            for row in range(28)
        ]

    def get_digit_array(self):
        """Return a 28x28 array indexed [row, col]: 0.9 where painted black, else 0."""
        painted = np.array([color == (0, 0, 0) for _, color in self.grid_array], dtype=bool)
        # grid_array is column-major, so reshape gives [col, row]; transpose to [row, col].
        return np.where(painted.reshape(28, 28).T, 0.9, 0.0)

    def _cell_index(self, pos):
        """Map a window pixel position to its grid index, or None if off-grid."""
        x, y = pos
        col, row = x // self.block_size, y // self.block_size
        if 0 <= col < 28 and 0 <= row < 28:
            return col * 28 + row
        return None

    def _publish(self, digit_array):
        """Replace any pending items in the queue with a copy of ``digit_array``."""
        # Drop stale updates so the visualizer only renders the latest drawing.
        while True:
            try:
                self.data_queue.get_nowait()
            except queue.Empty:
                break
        try:
            self.data_queue.put_nowait(digit_array.copy())
        except queue.Full:
            # Only possible with a bounded queue; skip this update.
            pass

    def clear_grid(self):
        """Reset every cell to white."""
        for index, (rect, _) in enumerate(self.grid_array):
            self.grid_array[index] = (rect, (255, 255, 255))

    def draw_rectangles(self):
        """Draw every cell with its fill color and a light-grey outline."""
        for rect, color in self.grid_array:
            pygame.draw.rect(self.window, color, rect)
            pygame.draw.rect(self.window, (200, 200, 200), rect, 1)

    def run(self):
        """Run the event/draw loop until the window is closed or ESC is pressed."""
        # Platform-appropriate FPS
        fps = 120 if platform.system() != "Darwin" else 60
        clock = pygame.time.Clock()

        print("Cross-Platform CNN Visualization")
        print(f"Running on: {platform.system()}")
        print("- Click and drag to draw")
        print("- Press SPACE to clear")
        print("- Press ESC to exit")
        print("- Live visualization updates in other window")

        try:
            while self.running:
                # monotonic() is not affected by wall-clock adjustments.
                current_time = time.monotonic()

                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        self.running = False

                    if event.type == pygame.KEYDOWN:
                        if event.key == pygame.K_ESCAPE:
                            self.running = False
                        elif event.key == pygame.K_SPACE:
                            self.clear_grid()
                            self.digit_array = np.zeros((28, 28))
                            # Send cleared array immediately
                            self._publish(self.digit_array)

                # Paint the cell under the cursor while the left button is held.
                if pygame.mouse.get_pressed()[0]:
                    index = self._cell_index(pygame.mouse.get_pos())
                    if index is not None:
                        rect, _ = self.grid_array[index]
                        self.grid_array[index] = (rect, (0, 0, 0))

                # Send updates at platform-appropriate intervals
                if current_time - self.last_update_time > self.update_interval:
                    new_digit_array = self.get_digit_array()
                    if not np.array_equal(new_digit_array, self.digit_array):
                        self.digit_array = new_digit_array
                        self._publish(self.digit_array)
                    self.last_update_time = current_time

                # Draw pygame window
                self.window.fill(self.BG_COLOR)
                self.draw_rectangles()
                pygame.display.flip()
                clock.tick(fps)
        finally:
            # Release the display even if the loop raised.
            pygame.quit()


def main():
    """Start the drawing window on a worker thread and the visualizer on the main thread.

    Blocks until the visualization window is closed. The drawing thread is a
    daemon, so it does not keep the interpreter alive after that.
    """
    try:
        print(f"Starting Cross-Platform CNN Visualization on {platform.system()}")

        # Drawings travel from the pygame window to the Tk visualizer through this queue.
        data_queue = queue.Queue()

        # NOTE(review): SDL/pygame on macOS generally expects its window to be
        # created on the main thread; confirm this worker-thread setup works there.
        drawing_thread = threading.Thread(
            target=lambda: DrawingWindow(data_queue).run(),
            daemon=True,
        )
        drawing_thread.start()

        # Short pause, presumably to let pygame initialize before Tk starts.
        time.sleep(0.5)

        # Tk runs on the main thread (required for tkinter on macOS).
        CNNVisualizationWindow(data_queue).run()

    except FileNotFoundError:
        print("Could not find 'mnist_cnn_4-8-10.pth' model file")
        print("Make sure the CNN model is in the same directory")
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()