In [9]:
import pygame
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.lines import Line2D
import matplotlib.patches as patches
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy import ndimage
import threading
import time
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import tkinter as tk
import queue


class MinimalCNN(nn.Module):
    """Three-layer convolutional digit classifier that exposes its intermediates.

    Layers (no padding, stride 1): conv 5x5 (1 -> 4), 2x2 max-pool,
    conv 3x3 (4 -> 8), 2x2 max-pool, conv 4x4 (8 -> 10), dropout.  The ten
    class scores are the spatial sums of the ten conv3 channels.

    For an input of shape ``(N, 1, 28, 28)`` the spatial sizes are
    28 -> 24 -> 12 -> 10 -> 5 -> 2.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are the state-dict keys of the saved checkpoint;
        # do not rename them.
        self.conv1 = nn.Conv2d(1, 4, kernel_size=5, stride=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(4, 8, kernel_size=3, stride=1)
        self.conv3 = nn.Conv2d(8, 10, kernel_size=4, stride=1)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Run the network and return every intermediate activation.

        Returns:
            A 7-tuple ``(conv1, pool1, conv2, pool2, conv3, sum, out)`` where
            ``conv1``/``conv2`` are the pre-ReLU convolution outputs,
            ``pool1``/``pool2`` the pooled ReLU outputs, ``conv3`` the last
            convolution after dropout, ``sum`` its per-channel spatial sum
            and ``out`` that sum viewed as ``(-1, 10)``.
        """
        conv1 = self.conv1(x)
        pool1 = self.pool(F.relu(conv1))
        conv2 = self.conv2(pool1)
        pool2 = self.pool(F.relu(conv2))
        conv3 = self.dropout(self.conv3(pool2))
        # Named ``channel_sums`` rather than ``sum`` so the builtin is not shadowed.
        channel_sums = torch.sum(conv3, dim=(2, 3))
        out = channel_sums.view(-1, 10)
        return conv1, pool1, conv2, pool2, conv3, channel_sums, out


def load_model():
    """Return a ``MinimalCNN`` in eval mode with weights from the first checkpoint found.

    The candidate locations are tried in order; the first one that exists is
    loaded onto the CPU.

    Raises:
        FileNotFoundError: if none of the candidate paths exists.  The current
            working directory and its contents are printed first to help
            debugging.
    """
    import os

    # Candidate locations for the checkpoint, in priority order.
    # ('./mnist_cnn_4-8-10.pth' was dropped: it is the same file as the first entry.)
    possible_paths = [
        'mnist_cnn_4-8-10.pth',  # Current directory
        os.path.expanduser('~/fpga_acc/CNN_Accelerator/CNN/mnist_cnn_4-8-10.pth'),  # Full path
        'CNN/mnist_cnn_4-8-10.pth',  # Relative from parent directory
    ]

    for path in possible_paths:
        if os.path.exists(path):
            print(f"Loading model from: {path}")
            # Build the network only once a checkpoint is known to exist.
            model = MinimalCNN()
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            model.load_state_dict(torch.load(path, map_location='cpu'))
            model.eval()
            return model

    # If none found, print current directory info for debugging
    print(f"Current working directory: {os.getcwd()}")
    print(f"Files in current directory: {os.listdir('.')}")
    raise FileNotFoundError("Could not find model file in any expected location")


def center_and_normalize_digit(digit_array, target_size=20):
    """Crop the drawn digit to its bounding box and centre it on a 28x28 canvas.

    The bounding box of all pixels greater than zero is extracted.  If it is
    larger than ``target_size`` in either direction it is shrunk (aspect ratio
    preserved) with ``ndimage.zoom``; smaller digits are never enlarged.

    Args:
        digit_array: 2-D array holding the drawing.
        target_size: maximum side length of the digit inside the canvas.

    Returns:
        A new 28x28 float array with the digit centred, or ``digit_array``
        itself (unchanged) when it contains no positive pixel.
    """
    rows, cols = np.where(digit_array > 0)

    if len(rows) == 0:
        return digit_array

    min_row, max_row = rows.min(), rows.max()
    min_col, max_col = cols.min(), cols.max()

    digit_region = digit_array[min_row:max_row+1, min_col:max_col+1]

    height, width = digit_region.shape
    scale = min(target_size / height, target_size / width)

    if scale < 1.0:
        # A very thin stroke (e.g. 1 pixel high, 28 wide) would truncate to a
        # size of 0 and hand ndimage.zoom a zero factor; keep at least 1 pixel.
        new_height = max(1, int(height * scale))
        new_width = max(1, int(width * scale))
        digit_region = ndimage.zoom(digit_region, (new_height/height, new_width/width))

    centered_digit = np.zeros((28, 28))

    region_height, region_width = digit_region.shape
    start_row = (28 - region_height) // 2
    start_col = (28 - region_width) // 2

    centered_digit[start_row:start_row+region_height, start_col:start_col+region_width] = digit_region

    return centered_digit


def calculate_center_of_mass(digit_array):
    """Return the intensity-weighted ``(row, col)`` centroid of the positive pixels.

    Only pixels greater than zero contribute, each weighted by its value.
    When there is no such pixel the fixed pair ``(14, 14)`` is returned.
    """
    lit_rows, lit_cols = np.where(digit_array > 0)
    if len(lit_rows) == 0:
        return 14, 14

    intensities = digit_array[lit_rows, lit_cols]
    return (
        np.average(lit_rows, weights=intensities),
        np.average(lit_cols, weights=intensities),
    )


def fine_tune_centering(digit_array):
    """Shift the image by whole pixels so its centre of mass lands near (13.5, 13.5).

    The shift along each axis is the rounded distance from the centre of mass
    to 13.5, limited to the range [-10, 10].  Vacated pixels are filled with
    0.0.  The input is returned unchanged when no shift is needed.
    """
    target_center = 13.5
    max_shift = 10

    offsets = tuple(
        max(-max_shift, min(max_shift, int(round(target_center - coordinate))))
        for coordinate in calculate_center_of_mass(digit_array)
    )

    if not any(offsets):
        return digit_array

    return ndimage.shift(digit_array, offsets, cval=0.0)


def preprocess_drawing(digit_array):
    """Centre a drawing and convert it into a network input tensor.

    Returns:
        ``(input_tensor, final)`` where ``final`` is the centred array and
        ``input_tensor`` is ``(final - 0.5) / 0.5`` as a float tensor with two
        leading singleton dimensions added.
    """
    final = fine_tune_centering(center_and_normalize_digit(digit_array))

    pixels = torch.FloatTensor(final)
    # (x - 0.5) / 0.5 followed by two new leading axes (batch, channel).
    input_tensor = ((pixels - 0.5) / 0.5)[None, None]

    return input_tensor, final


def predict_digit_improved(model, digit_array):
    """Classify a drawing and return the prediction with all intermediate stages.

    Args:
        model: network whose call returns
            ``(conv1, pool1, conv2, pool2, conv3, sum, out)``.
        digit_array: 2-D array holding the drawing.

    Returns:
        ``(digit, confidence, probabilities, processed_array, stages)`` where
        ``probabilities`` is a length-10 array and ``stages`` is the list
        ``[conv1, pool1, conv2, pool2, conv3, sum]``.  For a blank drawing
        (pixel sum 0) the model is not run: digit 0, confidence 0.0, a uniform
        0.1 distribution, the input itself and all-zero placeholder stages are
        returned.
    """
    if np.sum(digit_array) == 0:
        # Per-sample activation shapes for a 28x28 input through unpadded
        # 5x5 conv -> 2x2 pool -> 3x3 conv -> 2x2 pool -> 4x4 conv
        # (28 -> 24 -> 12 -> 10 -> 5 -> 2), so the blank view has the same
        # geometry as a real prediction.
        placeholder_shapes = (
            (4, 24, 24),   # conv1
            (4, 12, 12),   # pool1
            (8, 10, 10),   # conv2
            (8, 5, 5),     # pool2
            (10, 2, 2),    # conv3
            (10,),         # per-class sums
        )
        blank_stages = [np.zeros(shape) for shape in placeholder_shapes]
        return 0, 0.0, np.ones(10) * 0.1, digit_array, blank_stages

    input_tensor, processed_array = preprocess_drawing(digit_array)

    with torch.no_grad():
        conv1, pool1, conv2, pool2, conv3, sum_out, outputs = model(input_tensor)
        stages = [conv1, pool1, conv2, pool2, conv3, sum_out]
        probabilities = F.softmax(outputs, dim=1)
        confidence, predicted = torch.max(probabilities, 1)

    return predicted.item(), confidence.item(), probabilities.squeeze().numpy(), processed_array, stages


def apply_convolution(num_kernels, num_channels, weights, input_tensor):
    """Convolve every input channel with every filter's matching kernel slice.

    Unlike a normal convolution layer the per-channel results are *not*
    summed and no bias is added, so each entry shows the contribution of a
    single input channel to a single filter.

    Args:
        num_kernels: number of filters (leading entries of ``weights``) to apply.
        num_channels: number of channel slots produced per filter.
        weights: filter bank indexed as ``[filter, channel, kh, kw]``.
        input_tensor: ``(C, H, W)`` or ``(N, C, H, W)`` tensor; a 3-D tensor
            gets a leading batch axis.  The two leading output axes are
            squeezed away, which yields a 2-D map when N is 1.

    Returns:
        ``output_maps[i][j]``: NumPy array with the stride-1, unpadded
        convolution of input channel ``j`` with ``weights[i, j]``.  Channel
        slots beyond the input's channel count hold a 1x1 zero array.
    """
    if len(input_tensor.shape) == 3:
        input_tensor = input_tensor.unsqueeze(0)

    input_tensor = input_tensor.float()
    weights = weights.float()
    available_channels = input_tensor.shape[1]

    output_maps = []
    for i in range(num_kernels):
        kernels = []
        for j in range(num_channels):
            if j < available_channels:
                conv_output = F.conv2d(
                    input_tensor[:, j:j+1, :, :],
                    weights[i:i+1, j:j+1, :, :],
                    stride=1,
                )
                kernels.append(conv_output.squeeze(0).squeeze(0).detach().cpu().numpy())
            else:
                # Placeholder for a channel the input does not have.
                kernels.append(np.zeros((1, 1)))
        output_maps.append(kernels)
    return output_maps


def apply_conv3_convolution(weights, input_tensor):
    """Per-channel convolution for the third layer (10 filters, 8 channels).

    Same contract as ``apply_convolution`` with the filter and channel counts
    fixed to 10 and 8; previously a line-for-line copy of that function.
    """
    return apply_convolution(10, 8, weights, input_tensor)


class CNNVisualizationWindow:
    def __init__(self, data_queue):
        """Create the Tk window, load the model and build the figure.

        Args:
            data_queue: queue object stored on the instance as ``data_queue``.

        If the model checkpoint cannot be found, a message is printed, the Tk
        root is destroyed and construction stops early (the remaining
        attributes are then not set).
        """
        self.data_queue = data_queue
        self.root = tk.Tk()
        self.root.title("Real-time CNN Visualization - Switchable Conv3 Kernels/Feature Maps")

        # Cross-platform window maximization.  The fallbacks are nested: an
        # exception raised inside an ``except`` handler is not caught by a
        # sibling ``except`` clause of the same ``try``.
        try:
            self.root.state('zoomed')  # Windows
        except tk.TclError:
            try:
                self.root.attributes('-zoomed', True)  # Linux
            except tk.TclError:
                # Fallback: set a large window size
                self.root.geometry('1400x800')

        # One flag per conv3 filter; toggled in _on_click.
        self.conv3_showing_features = [False] * 10

        try:
            self.model = load_model()
        except FileNotFoundError:
            print("Could not find 'mnist_cnn_4-8-10.pth' model file")
            self.root.destroy()
            return

        def kernels_per_filter(layer):
            """Return ``[[kernel ndarray per input channel] per filter]``."""
            return [
                [kernel.detach().cpu().numpy() for kernel in filter_weights]
                for filter_weights in layer.weight.data
            ]

        self.filters1 = kernels_per_filter(self.model.conv1)
        self.filters2 = kernels_per_filter(self.model.conv2)
        self.filters3 = kernels_per_filter(self.model.conv3)

        self.filter_bias1 = self.model.conv1.bias.data.detach().cpu().numpy()
        self.filter_bias2 = self.model.conv2.bias.data.detach().cpu().numpy()
        self.filter_bias3 = self.model.conv3.bias.data.detach().cpu().numpy()

        self.current_digit_array = np.zeros((28, 28))
        self.current_conv3_channels = None

        # Store references to dynamic elements that need to be removed/redrawn
        self.p2_lines = []
        self.p2_texts = []

        self.setup_matplotlib()
        self.start_update_loop()
        
    def setup_matplotlib(self):
        """Create the figure, embed it in the Tk window and build the plot.

        Interactive mode is switched off, a 34x10 figure is created and wrapped
        in a Tk canvas that fills the root window, mouse clicks are routed to
        ``_on_click`` and the axes are laid out.
        """
        plt.ioff()
        figure = plt.figure(figsize=(34, 10))
        self.fig = figure

        canvas = FigureCanvasTkAgg(figure, master=self.root)
        self.canvas = canvas

        widget = canvas.get_tk_widget()
        self.canvas_widget = widget
        widget.pack(fill=tk.BOTH, expand=True)

        canvas.mpl_connect('button_press_event', self._on_click)

        self.setup_plot_elements()
        
    def _on_click(self, event):
        if event.inaxes is None:
            return
            
        clicked_filter = self._get_conv3_filter_from_axis(event.inaxes)
        if clicked_filter is not None:
            self.conv3_showing_features[clicked_filter] = not self.conv3_showing_features[clicked_filter]
            self.update_visualization_data(self.current_digit_array)
    
    def _get_conv3_filter_from_axis(self, clicked_ax):
        for filter_idx, filter_axes in enumerate(self.axes['fc3_kernels']):
            if clicked_ax in filter_axes:
                return filter_idx
        return None

    def clear_p2_elements(self):
        """Clear P2 lines and text elements"""
        for line in self.p2_lines:
            if line in self.fig.lines:
                line.remove()
        for text in self.p2_texts:
            text.remove()
        self.p2_lines.clear()
        self.p2_texts.clear()

    def draw_p2_labels(self):
        """Redraw the short connector line and ``P2_n`` label left of each conv3 kernel axes.

        Previously drawn lines and labels are cleared first; the new artists are
        recorded in ``self.p2_lines`` / ``self.p2_texts``.
        """
        self.clear_p2_elements()

        for kernel_axes in self.axes['fc3_kernels']:
            for channel_no, ax in enumerate(kernel_axes, start=1):
                if channel_no > 8:
                    continue

                bbox = ax.get_position()
                mid_y = bbox.y0 + bbox.height / 2
                stub_start_x = bbox.x0 - 0.008  # Reduced from -0.015
                stub_end_x = bbox.x0

                # Connector stub in figure coordinates, kept for later removal
                stub = Line2D(
                    [stub_start_x, stub_end_x], [mid_y, mid_y],
                    color='purple', linewidth=1.5,
                    transform=self.fig.transFigure, zorder=0
                )
                self.fig.add_artist(stub)
                self.p2_lines.append(stub)

                # Label sits just left of the stub (offset reduced from -0.003)
                label = self.fig.text(
                    stub_start_x - 0.002, mid_y, f'$P2_{channel_no}$',
                    fontsize=8,
                    verticalalignment='center', horizontalalignment='right'
                )
                self.p2_texts.append(label)
        
    def plot_matrix(self, ax, matrix, title, cmap='Blues', value_format='.1f', show_grid=True, show_numbers=False, change_color=True, font=-1):
        im = ax.imshow(matrix, cmap=cmap, interpolation='nearest')
        ax.set_title(title, fontsize=9, fontweight='bold')

        if show_grid and matrix.shape[0] <= 10 and matrix.shape[1] <= 10:
            ax.set_xticks(np.arange(-.5, matrix.shape[1], 1), minor=True)
            ax.set_yticks(np.arange(-.5, matrix.shape[0], 1), minor=True)
            ax.grid(which='minor', color='w', linestyle='-', linewidth=0.5)
            ax.set_xticks([])
            ax.set_yticks([])
        else:
            ax.set_xticks([])
            ax.set_yticks([])

        if show_numbers and matrix.shape[0] <= 10 and matrix.shape[1] <= 10:
            max_dim = max(matrix.shape[0], matrix.shape[1])
            if font > 0:
                fontsize = font
            elif matrix.shape[0] == 1 and matrix.shape[1] == 1:
                fontsize = 8
            elif max_dim <= 3:
                fontsize = 6
            elif max_dim <= 5:
                fontsize = 5
            else:
                fontsize = 4
            for i in range(matrix.shape[0]):
                for j in range(matrix.shape[1]):
                    value = matrix[i, j]
                    text_color = 'white' if (abs(value) > 0.3 and change_color) else 'black'
                    if isinstance(value, (int, np.integer)):
                        value_str = str(value)
                    else:
                        value_str = "{:.1f}".format(value) if value_format == '.1f' else "{:.2f}".format(value) if value_format == '.2f' else "{:.3f}".format(value)
                    ax.text(j, i, value_str, ha='center', va='center',
                            color=text_color, fontsize=fontsize)
        return ax

    def add_straight_arrow(self, start_x, start_y, end_x, end_y, color='purple', width=3.0):
        """Add a straight ``->`` arrow between two points given in figure coordinates."""
        arrow_style = dict(
            arrowstyle='->',
            color=color,
            linewidth=width,
            connectionstyle='arc3,rad=0',  # zero curvature: a straight segment
            shrinkA=2,
            shrinkB=2,
            mutation_scale=20,
            transform=self.fig.transFigure,
            zorder=0,
        )
        tail = (start_x, start_y)
        head = (end_x, end_y)
        self.fig.add_artist(patches.FancyArrowPatch(tail, head, **arrow_style))

    def add_straight_line(self, start_x, start_y, end_x, end_y, color='purple', width=3.0):
        """Add a plain line segment between two points given in figure coordinates."""
        xs = [start_x, end_x]
        ys = [start_y, end_y]
        segment = Line2D(xs, ys, color=color, linewidth=width,
                         transform=self.fig.transFigure, zorder=0)
        self.fig.add_artist(segment)
        
    def setup_plot_elements(self):
        """Create every axes of the visualization and draw its initial content.

        Clears the figure, lays out one group of axes per stage (stored in the
        ``self.axes`` dict), then calls ``draw_static_elements`` and
        ``update_visualization_data`` with the current digit array.
        """
        self.fig.clear()

        self.axes = {}

        # Rectangles handed to ``add_axes`` are [left, bottom, width, height]
        # (fractions of the figure in matplotlib).  ``scale`` multiplies almost
        # every layout number below and ``vertical_shift`` is added to the
        # vertical positions.
        scale = 0.38
        vertical_shift = 0.34

        # Left edge of each column of axes, in left-to-right order.
        # ``input_left`` is the only one that is not multiplied by ``scale``.
        input_left = -0.015
        kernel1_base_left = 0.26 * scale
        feature1_base_left = 0.42 * scale  # Shortened gap from K1 to FM1
        pooling1_base_left = 0.56 * scale  # Moved left from P1 onwards
        kernels2_base_left = 0.76 * scale  # Moved left
        channels2_base_left = 1.01 * scale  # Moved left
        feature2_base_left = 1.24 * scale  # Moved left
        pooling2_base_left = 1.34 * scale  # Moved left
        fc3_base_left = 1.51 * scale       # Moved left
        feature3_base_left = 2.04 * scale  # Moved left
        out_base_left = 2.17 * scale       # Moved left
        prediction_left = 2.44 * scale     # Moved left

        # These two values only feed the vertical-centre computation right
        # below; the loops further down rebind sub_height / v_spacing.
        sub_height = 0.12 * scale
        v_spacing = 0.02 * scale
        top_edge = (0.85 * scale + sub_height) + vertical_shift
        bottom_edge = ((0.85 * scale - (3 * 0.3 * scale)) - (sub_height + v_spacing)) + vertical_shift
        true_center = (top_edge + bottom_edge) / 2

        # The input box (and the prediction box at the end of this method) is
        # vertically centred on ``true_center``.
        box_height = 0.25 * scale
        centered_bottom = true_center - box_height / 2

        input_width = 0.25 * scale
        self.axes['input'] = self.fig.add_axes([input_left, centered_bottom, input_width, box_height])

        # Every remaining stage is a list of axes (or a list of lists for the
        # grouped stages), filled by the loops below.
        self.axes['kernel1'] = []
        self.axes['feature1'] = []
        self.axes['pool1'] = []
        self.axes['kernel2_channels'] = []
        self.axes['channels2'] = []
        self.axes['feature2'] = []
        self.axes['pool2'] = []
        self.axes['fc3_kernels'] = []
        self.axes['feature3'] = []
        self.axes['sum'] = []

        # Four rows, each with a 'kernel1', a 'feature1' and a 'pool1' axes
        # side by side; consecutive rows are 0.3 * scale apart.
        for row in range(4):
            left_kernel = kernel1_base_left
            left_feature = feature1_base_left
            left_pooling = pooling1_base_left
            bottom = (0.77 * scale - (row * 0.3 * scale)) + vertical_shift

            ax_filter = self.fig.add_axes([left_kernel, bottom, 0.14 * scale, 0.14 * scale])
            self.axes['kernel1'].append(ax_filter)

            ax_output = self.fig.add_axes([left_feature, bottom, 0.14 * scale, 0.14 * scale])
            self.axes['feature1'].append(ax_output)

            ax_pooling = self.fig.add_axes([left_pooling, bottom, 0.14 * scale, 0.14 * scale])
            self.axes['pool1'].append(ax_pooling)

        # Eight blocks; each block holds a 2x2 grid of 'kernel2_channels' axes
        # and a matching 2x2 grid of 'channels2' axes.  Within a block the
        # axes are appended row by row (index = sub_row * 2 + sub_col).
        for block_num in range(8):
            block_kernel_axes = []
            block_output_axes = []
            block_top = (1.45 * scale - (block_num * 0.3 * scale)) + vertical_shift

            for sub_row in range(2):
                for sub_col in range(2):
                    # Loop-invariant sizes; note that this rebinds the
                    # sub_height / v_spacing names defined above.
                    sub_width = 0.08 * scale
                    sub_height = 0.08 * scale
                    h_spacing_kernels = 0.025 * scale
                    h_spacing_features = 0.005 * scale
                    v_spacing = 0.06 * scale

                    left_kernel = kernels2_base_left + sub_col * (sub_width + h_spacing_kernels)
                    left_feature = channels2_base_left + sub_col * (sub_width + h_spacing_features)
                    bottom = block_top - sub_row * (sub_height + v_spacing)

                    ax_filter = self.fig.add_axes([left_kernel, bottom, sub_width, sub_height])
                    block_kernel_axes.append(ax_filter)

                    ax_output = self.fig.add_axes([left_feature, bottom, sub_width, sub_height])
                    block_output_axes.append(ax_output)

            self.axes['kernel2_channels'].append(block_kernel_axes)
            self.axes['channels2'].append(block_output_axes)

        # Eight rows, each with one 'feature2' and one 'pool2' axes; same
        # 0.3 * scale row pitch as the blocks above.
        for row in range(8):
            bottom = (1.365 * scale - (row * 0.3 * scale)) + vertical_shift
            sub_height = 0.12 * scale
            left_feature = feature2_base_left
            left_pooling = pooling2_base_left

            ax_output = self.fig.add_axes([left_feature, bottom, 0.12 * scale, sub_height])
            self.axes['feature2'].append(ax_output)

            ax_pooling = self.fig.add_axes([left_pooling, bottom, 0.12 * scale, sub_height])
            self.axes['pool2'].append(ax_pooling)

        # Ten filter groups, each a 2-row x 4-column grid of axes stored as one
        # flat list of eight in 'fc3_kernels' (top row first).
        for filter_num in range(10):
            filter_kernel_axes = []

            set_spacing = 0.25 * scale
            filter_top = (1.52 * scale - (filter_num * set_spacing)) + vertical_shift

            sub_width = 0.10 * scale
            sub_height = 0.10 * scale
            h_spacing = 0.025 * scale
            v_spacing = 0.02 * scale

            for sub_row in range(2):
                for sub_col in range(4):
                    # NOTE(review): kernel_idx is computed but not used in this loop.
                    kernel_idx = sub_row * 4 + sub_col

                    left_kernel = fc3_base_left + sub_col * (sub_width + h_spacing)
                    bottom = filter_top - sub_row * (sub_height + v_spacing)

                    ax_kernel = self.fig.add_axes([left_kernel, bottom, sub_width, sub_height])
                    filter_kernel_axes.append(ax_kernel)

            self.axes['fc3_kernels'].append(filter_kernel_axes)

        # Ten rows with one 'feature3' and one 'sum' axes each.  The same
        # filter_top formula as above keeps row i level with filter group i.
        for row in range(10):
            set_spacing = 0.25 * scale
            filter_top = (1.52 * scale - (row * set_spacing)) + vertical_shift
            set_center_y = filter_top - 0.01 * scale

            sub_height = 0.12 * scale
            bottom = set_center_y - (sub_height / 2)

            left_feature = feature3_base_left

            ax_feature = self.fig.add_axes([left_feature, bottom, 0.12 * scale, sub_height])
            self.axes['feature3'].append(ax_feature)

            ax_out = self.fig.add_axes([out_base_left, bottom, 0.12 * scale, sub_height])
            self.axes['sum'].append(ax_out)

        # NOTE(review): figure_aspect_ratio is not used anywhere below in this
        # method; 34 / 10 matches the figsize given in setup_matplotlib.
        figure_aspect_ratio = 34 / 10
        prediction_bottom = centered_bottom
        prediction_width = box_height * 0.55  # Make it just a tiny bit thinner
        prediction_height = box_height
        self.axes['prediction'] = self.fig.add_axes([prediction_left, prediction_bottom, prediction_width, prediction_height])

        # Reset here; presumably assigned when a prediction arrow is drawn
        # (that code is not part of this method).
        self.prediction_arrow = None

        # Draw the static content first, then the data for the current digit.
        self.draw_static_elements()

        self.update_visualization_data(self.current_digit_array)
        
    def draw_static_elements(self):
        for i, ax in enumerate(self.axes['kernel1']):
            kernel_title = 'K1_' + str(i+1)
            self.plot_matrix(ax, self.filters1[i][0], kernel_title, cmap='RdBu', show_numbers=True)
        
        for block_num, block_axes in enumerate(self.axes['kernel2_channels']):
            for ch, ax in enumerate(block_axes):
                kernel_title = 'K' + str(block_num+1) + 'C' + str(ch+1)
                self.plot_matrix(ax, self.filters2[block_num][ch], kernel_title, cmap='RdBu', show_numbers=True)
        
        self.add_arrows_and_annotations()
        
    def add_arrows_and_annotations(self):
        """Draw the fixed arrows, connector stubs, labels and bias boxes.

        All positions are derived from the bounding boxes (``get_position``)
        of the axes in ``self.axes`` plus small hand-tuned offsets.  Each bias
        value is shown in its own small axes placed on the corresponding arrow.
        """
        # Arrows fan out from the right side of the input axes to the left
        # side of every conv1 kernel axes.
        input_bbox = self.axes['input'].get_position()
        input_right_x = input_bbox.x0 + input_bbox.width - 0.02
        input_center_y = input_bbox.y0 + input_bbox.height / 2

        for i, ax_kernel in enumerate(self.axes['kernel1']):
            kernel_bbox = ax_kernel.get_position()
            kernel_left_x = kernel_bbox.x0 - 0.012
            kernel_center_y = kernel_bbox.y0 + kernel_bbox.height / 2
            self.add_straight_arrow(input_right_x, input_center_y, kernel_left_x, kernel_center_y, color='purple', width=1.5)

        # conv1 kernel -> feature map arrow, with the conv1 bias box on it.
        for i, ax_filter in enumerate(self.axes['kernel1']):
            ax_output = self.axes['feature1'][i]
            kernel_right_x = ax_filter.get_position().x0 + ax_filter.get_position().width + 0.002
            kernel_center_y = ax_filter.get_position().y0 + ax_filter.get_position().height / 2
            feature_left_x = ax_output.get_position().x0 + 0.015  # Extended arrow end
            feature_center_y = ax_output.get_position().y0 + ax_output.get_position().height / 2
            self.add_straight_arrow(kernel_right_x, kernel_center_y, feature_left_x, feature_center_y, color='purple', width=1.5)

            # The bias is drawn as a 1x1 matrix in its own small axes.
            # Despite the names, bias_center_x / bias_center_y are the
            # lower-left corner handed to add_axes (arrow midpoint minus half
            # the box size).
            bias_value = self.filter_bias1[i]
            bias_matrix = np.array([[bias_value]])
            bias_width = 0.024
            bias_height = 0.02
            bias_center_x = ((kernel_right_x + feature_left_x) / 2) + 0.002 - bias_width / 2  # Moved slightly right
            bias_center_y = ((kernel_center_y + feature_center_y) / 2) - bias_height / 2
            ax_bias = self.fig.add_axes([bias_center_x, bias_center_y, bias_width, bias_height])
            self.plot_matrix(ax_bias, bias_matrix, '', show_numbers=True, value_format='.2f', show_grid=False, change_color=False, font=6)

        # feature map 1 -> pooled map 1 arrows.
        for i, ax_output in enumerate(self.axes['feature1']):
            ax_pooling = self.axes['pool1'][i]
            feature_right_x = ax_output.get_position().x0 + ax_output.get_position().width - 0.01
            feature_center_y = ax_output.get_position().y0 + ax_output.get_position().height / 2
            pool_left_x = ax_pooling.get_position().x0 + 0.01
            pool_center_y = ax_pooling.get_position().y0 + ax_pooling.get_position().height / 2
            self.add_straight_arrow(feature_right_x, feature_center_y, pool_left_x, pool_center_y, color='purple', width=1.5)

        # Short stub and 'P1_n' label to the left of every conv2 kernel axes;
        # n is the channel index within the block.
        for block_num, block_kernel_axes in enumerate(self.axes['kernel2_channels']):
            for ch, ax in enumerate(block_kernel_axes):
                line_y = ax.get_position().y0 + ax.get_position().height / 2
                line_x_start = ax.get_position().x0 - 0.01
                line_x_end = ax.get_position().x0
                self.add_straight_line(line_x_start, line_y, line_x_end, line_y, color='purple', width=1.5)

                text_x = line_x_start - 0.004
                text_y = line_y
                label_text = 'P1_' + str(ch+1)
                label_formatted = '$' + label_text + '$'
                self.fig.text(text_x, text_y, label_formatted, fontsize=8, verticalalignment='center', horizontalalignment='right')

        # One arrow per block from the kernel grid to the channel-output grid.
        # The y value is the midpoint between the bottom edge of the first
        # axes and the top edge of the fourth axes of the block.
        for block_num, block_kernel_axes in enumerate(self.axes['kernel2_channels']):
            ax_k0 = block_kernel_axes[0]
            ax_k3 = block_kernel_axes[3]
            kernel_block_x = ax_k3.get_position().x0 + ax_k3.get_position().width + 0.005
            kernel_block_y = (ax_k0.get_position().y0 + ax_k3.get_position().y0 + ax_k3.get_position().height) / 2

            block_output_axes = self.axes['channels2'][block_num]
            ax_f0 = block_output_axes[0]
            ax_f3 = block_output_axes[3]
            filter_block_x = ax_f0.get_position().x0 - 0.002
            filter_block_y = (ax_f0.get_position().y0 + ax_f3.get_position().y0 + ax_f3.get_position().height) / 2

            self.add_straight_arrow(kernel_block_x, kernel_block_y, filter_block_x, filter_block_y, color='purple', width=1.5)

        # Channel-output grid -> feature map 2 arrow with the conv2 bias box
        # and a sigma symbol above it.
        # NOTE(review): ax_c0 is assigned but not used in this loop.
        for block_num, block_output_axes in enumerate(self.axes['channels2']):
            ax_c0 = block_output_axes[0]
            ax_c3 = block_output_axes[3]
            channel_block_x = ax_c3.get_position().x0 + ax_c3.get_position().width - 0.005

            ax_feature = self.axes['feature2'][block_num]
            feature_block_x = ax_feature.get_position().x0 + 0.01
            feature_block_y = ax_feature.get_position().y0 + ax_feature.get_position().height / 2

            # Start at the feature map's height so the arrow is horizontal.
            channel_block_y = feature_block_y

            self.add_straight_arrow(channel_block_x, channel_block_y, feature_block_x, feature_block_y, color='purple', width=1.5)

            # As above, bias_center_x / bias_center_y are the box's lower-left corner.
            bias_value = self.filter_bias2[block_num]
            bias_matrix = np.array([[bias_value]])
            bias_width = 0.024
            bias_height = 0.02
            bias_center_x = ((channel_block_x + feature_block_x) / 2) - bias_width / 2
            bias_center_y = ((channel_block_y + feature_block_y) / 2) - bias_height / 2
            ax_bias = self.fig.add_axes([bias_center_x, bias_center_y, bias_width, bias_height])
            self.plot_matrix(ax_bias, bias_matrix, '', show_numbers=True, value_format='.2f', show_grid=False, change_color=False, font=6)

            sigma_x = ((channel_block_x + feature_block_x) / 2)
            sigma_y = ((channel_block_y + feature_block_y) / 2) + 0.012
            self.fig.text(sigma_x, sigma_y, 'Σ', fontsize=8, verticalalignment='bottom', horizontalalignment='center')

        # CHANGED: Reversed the direction of arrows between FM2_ and P2_
        for i, ax_output in enumerate(self.axes['feature2']):
            ax_pooling = self.axes['pool2'][i]
            feature_right_x = ax_output.get_position().x0 + ax_output.get_position().width + 0.005  # Right edge of FM2_
            feature_center_y = ax_output.get_position().y0 + ax_output.get_position().height / 2
            pool_left_x = ax_pooling.get_position().x0 - 0.005  # Left edge of P2_
            pool_center_y = ax_pooling.get_position().y0 + ax_pooling.get_position().height / 2
            # Reversed: now arrow goes from P2_ to FM2_ (pool_left_x to feature_right_x)
            self.add_straight_arrow(pool_left_x, pool_center_y, feature_right_x, feature_center_y, color='purple', width=1.5)

        # NOTE: P2 labels are now drawn in draw_p2_labels() method

        # conv3 kernel group -> feature map 3 arrow with the conv3 bias box
        # and a sigma symbol above it.
        for i, filter_kernel_axes in enumerate(self.axes['fc3_kernels']):
            ax_feature = self.axes['feature3'][i]

            # Index 3 is the last axes of the group's first row.
            rightmost_kernel = filter_kernel_axes[3]
            fc3_right_x = rightmost_kernel.get_position().x0 + rightmost_kernel.get_position().width - 0.002  # Extended arrow start further left

            feature_left_x = ax_feature.get_position().x0 - 0.008  # Closer to FM3 but still with small gap
            feature_center_y = ax_feature.get_position().y0 + ax_feature.get_position().height / 2
            # Start at the feature map's height so the arrow is horizontal.
            fc3_center_y = feature_center_y

            # Here bias_center_x really is the horizontal centre; the box's
            # lower-left corner is bias_box_x / bias_box_y.
            bias_center_x = (fc3_right_x + feature_left_x) / 2 + 0.005

            bias_value = self.filter_bias3[i]
            bias_matrix = np.array([[bias_value]])
            bias_width = 0.024
            bias_height = 0.02
            bias_box_x = bias_center_x - bias_width / 2
            bias_box_y = ((fc3_center_y + feature_center_y) / 2) - bias_height / 2
            ax_bias = self.fig.add_axes([bias_box_x, bias_box_y, bias_width, bias_height])
            self.plot_matrix(ax_bias, bias_matrix, '', show_numbers=True, value_format='.2f', show_grid=False, change_color=False, font=6)

            arrow_end_x = feature_left_x + 0.015  # Extended further to make arrowhead visible
            self.add_straight_arrow(fc3_right_x, fc3_center_y, arrow_end_x, feature_center_y, color='purple', width=1.5)

            sigma_x = bias_center_x
            sigma_y = ((fc3_center_y + feature_center_y) / 2) + 0.012
            self.fig.text(sigma_x, sigma_y, 'Σ', fontsize=8, verticalalignment='bottom', horizontalalignment='center')

        # feature map 3 -> sum arrows.
        for i, ax_feature in enumerate(self.axes['feature3']):
            ax_out = self.axes['sum'][i]
            feature_right_x = ax_feature.get_position().x0 + ax_feature.get_position().width - 0.01
            feature_center_y = ax_feature.get_position().y0 + ax_feature.get_position().height / 2
            sum_left_x = ax_out.get_position().x0 + 0.01
            sum_center_y = ax_out.get_position().y0 + ax_out.get_position().height / 2
            self.add_straight_arrow(feature_right_x, feature_center_y, sum_left_x, sum_center_y, color='purple', width=1.5)
        
    def update_visualization_data(self, digit_array):
        self.current_digit_array = digit_array.copy()
        
        predicted_digit, confidence, _, processed_array, stages = predict_digit_improved(self.model, digit_array)
        
        self.axes['input'].clear()
        self.plot_matrix(self.axes['input'], processed_array, 'Input\nDrawing\n(28×28)', show_numbers=False)
        self.axes['input'].set_title('Input\nDrawing\n(28×28)', fontsize=18, fontweight='bold')
        
        if len(stages) >= 6:
            def extract_tensor_data(tensor):
                if hasattr(tensor, 'detach'):
                    tensor = tensor.detach().cpu()
                if hasattr(tensor, 'numpy'):
                    tensor = tensor.numpy()
                if len(tensor.shape) > 3 and tensor.shape[0] == 1:
                    tensor = tensor[0]
                elif len(tensor.shape) > 2 and tensor.shape[0] == 1:
                    tensor = tensor[0]
                elif len(tensor.shape) > 1 and tensor.shape[0] == 1:
                    tensor = tensor[0]
                return tensor
            
            conv1_out = extract_tensor_data(stages[0])
            pool1_out = extract_tensor_data(stages[1])
            conv2_out = extract_tensor_data(stages[2])
            pool2_out = extract_tensor_data(stages[3])
            conv3_out = extract_tensor_data(stages[4])
            sum_out = extract_tensor_data(stages[5])
            
            pool1_tensor = stages[1]
            if not isinstance(pool1_tensor, torch.Tensor):
                pool1_tensor = torch.tensor(pool1_out).float()
            else:
                pool1_tensor = pool1_tensor.float()
            if len(pool1_tensor.shape) == 3:
                pool1_tensor = pool1_tensor.unsqueeze(0)
            
            try:
                channels2_data = apply_convolution(8, 4, self.model.conv2.weight.data, pool1_tensor)
            except Exception as e:
                print("Error in apply_convolution: " + str(e))
                channels2_data = [[np.zeros((8, 8)) for _ in range(4)] for _ in range(8)]
            
            pool2_tensor = stages[3]
            if not isinstance(pool2_tensor, torch.Tensor):
                pool2_tensor = torch.tensor(pool2_out).float()
            else:
                pool2_tensor = pool2_tensor.float()
            if len(pool2_tensor.shape) == 3:
                pool2_tensor = pool2_tensor.unsqueeze(0)
            
            try:
                self.current_conv3_channels = apply_conv3_convolution(self.model.conv3.weight.data, pool2_tensor)
            except Exception as e:
                print("Error in conv3 convolution: " + str(e))
                self.current_conv3_channels = [[np.zeros((1, 1)) for _ in range(8)] for _ in range(10)]
            
            for i in range(min(4, len(self.axes['feature1']))):
                self.axes['feature1'][i].clear()
                if len(conv1_out.shape) >= 3 and i < conv1_out.shape[0]:
                    feature_title = 'FM1_' + str(i+1)
                    self.plot_matrix(self.axes['feature1'][i], conv1_out[i], feature_title, show_numbers=False)
                
                self.axes['pool1'][i].clear()
                if len(pool1_out.shape) >= 3 and i < pool1_out.shape[0]:
                    pool_title = 'P1_' + str(i+1)
                    self.plot_matrix(self.axes['pool1'][i], pool1_out[i], pool_title, cmap='Blues', show_numbers=False)
            
            for block_num, block_axes in enumerate(self.axes['channels2']):
                if block_num < len(channels2_data):
                    for ch, ax in enumerate(block_axes):
                        ax.clear()
                        if ch < len(channels2_data[block_num]):
                            feature_title = 'FM' + str(block_num+1) + 'C' + str(ch+1)
                            self.plot_matrix(ax, channels2_data[block_num][ch], feature_title, show_numbers=False)
            
            for i in range(min(8, len(self.axes['feature2']))):
                self.axes['feature2'][i].clear()
                if len(conv2_out.shape) >= 3 and i < conv2_out.shape[0]:
                    feature_title = 'FM2_' + str(i+1)
                    self.plot_matrix(self.axes['feature2'][i], conv2_out[i], feature_title, show_numbers=False)
                
                self.axes['pool2'][i].clear()
                if len(pool2_out.shape) >= 3 and i < pool2_out.shape[0]:
                    pool_title = 'P2_' + str(i+1)
                    self.plot_matrix(self.axes['pool2'][i], pool2_out[i], pool_title, cmap='Blues', show_numbers=False)
            
            for filter_num, filter_axes in enumerate(self.axes['fc3_kernels']):
                for kernel_idx, ax in enumerate(filter_axes):
                    ax.clear()
                    if kernel_idx < 8:
                        if self.conv3_showing_features[filter_num]:
                            if (self.current_conv3_channels and 
                                filter_num < len(self.current_conv3_channels) and 
                                kernel_idx < len(self.current_conv3_channels[filter_num])):
                                feature_data = self.current_conv3_channels[filter_num][kernel_idx]
                                self.plot_matrix(ax, feature_data, '', 
                                               show_numbers=False, cmap='Blues')
                        else:
                            kernel_data = self.filters3[filter_num][kernel_idx]
                            self.plot_matrix(ax, kernel_data, '', cmap='RdBu', show_numbers=True, font=5)
            
            # IMPORTANT FIX: Redraw P2 labels after updating conv3 kernels
            self.draw_p2_labels()
            
            for i in range(min(10, len(self.axes['feature3']))):
                self.axes['feature3'][i].clear()
                if len(conv3_out.shape) >= 3 and i < conv3_out.shape[0]:
                    feature_title = 'FM3_' + str(i+1)
                    self.plot_matrix(self.axes['feature3'][i], conv3_out[i], feature_title)
                
                self.axes['sum'][i].clear()
                if len(sum_out.shape) >= 1 and i < len(sum_out):
                    value = sum_out[i] if len(sum_out.shape) >= 1 else sum_out
                    out_value = np.array([[value]])
                    sum_title = 'SUM_' + str(i)
                    self.plot_matrix(self.axes['sum'][i], out_value, sum_title, show_numbers=True, value_format='.3f', cmap='Blues', change_color=False)

        self.axes['prediction'].clear()
        self.axes['prediction'].text(0.5, 0.5, str(predicted_digit), ha='center', va='center',
                                   fontsize=55, fontweight='bold', transform=self.axes['prediction'].transAxes)
        self.axes['prediction'].set_xlim(0, 1)
        self.axes['prediction'].set_ylim(0, 1)
        self.axes['prediction'].set_xticks([])
        self.axes['prediction'].set_yticks([])
        self.axes['prediction'].set_title('Prediction', fontsize=18, fontweight='bold')

        for spine in self.axes['prediction'].spines.values():
            spine.set_visible(True)
            spine.set_linewidth(1)

        for text in self.fig.texts[:]:
            if '%' in str(text.get_text()):
                text.remove()
        
        confidence_percentage = confidence * 100
        pred_bbox = self.axes['prediction'].get_position()
        pred_center_x = (pred_bbox.x0 + pred_bbox.x1) / 2
        confidence_y1 = pred_bbox.y0 - 0.03
        confidence_y2 = pred_bbox.y0 - 0.06
        
        confidence_text = "{:.1f}%".format(confidence_percentage)
        self.fig.text(pred_center_x, confidence_y1, 'Confidence:',
                     ha='center', va='top', fontsize=14)
        self.fig.text(pred_center_x, confidence_y2, confidence_text,
                     ha='center', va='top', fontsize=14)
        
        if len(sum_out) > 0:
            if self.prediction_arrow is not None:
                self.prediction_arrow.remove()
                self.prediction_arrow = None
            
            max_index = np.argmax(sum_out) if len(sum_out.shape) >= 1 else 0
            if max_index < len(self.axes['sum']):
                ax_out = self.axes['sum'][max_index]
                line_y_start = ax_out.get_position().y0 + ax_out.get_position().height / 2
                line_y_end = pred_bbox.y0 + pred_bbox.height / 2
                line_x_start = ax_out.get_position().x0 + ax_out.get_position().width  
                line_x_end = pred_bbox.x0 - 0.005  
                
                self.prediction_arrow = patches.FancyArrowPatch(
                    (line_x_start, line_y_start), (line_x_end, line_y_end),
                    arrowstyle='->', color='purple', linewidth=1.5,
                    connectionstyle='arc3,rad=0',
                    shrinkA=2, shrinkB=2,  
                    mutation_scale=20,
                    transform=self.fig.transFigure, zorder=0
                )
                self.fig.add_artist(self.prediction_arrow)
        
        self.canvas.draw()
        
    def start_update_loop(self):
        """Start polling ``self.data_queue`` and rendering what arrives.

        Each poll drains every array currently waiting in the queue, passes
        each one to ``self.update_visualization_data`` and then re-arms
        itself through ``self.root.after`` with a 1 ms delay.  The first
        poll runs synchronously, before this method returns.

        A failure while rendering one array is printed and that array is
        skipped; the poll is re-armed regardless, so a single bad frame
        cannot permanently stop the updates.
        """
        def update():
            try:
                while True:
                    try:
                        digit_array = self.data_queue.get_nowait()
                    except queue.Empty:
                        break  # nothing left to render until the next poll

                    try:
                        self.update_visualization_data(digit_array)
                    except Exception as e:
                        # Report and keep going: letting the error escape
                        # would skip the re-arm below and end the polling.
                        print("Error updating visualization: " + str(e))
            finally:
                # Re-arm unconditionally so the loop outlives any failure.
                self.root.after(1, update)

        update()
        
    def run(self):
        """Hand control to the event loop of ``self.root`` via ``mainloop()``."""
        event_loop_owner = self.root
        event_loop_owner.mainloop()


class DrawingWindow:
    """Pygame window showing a 28x28 grid that is painted with the mouse.

    Whenever the painted grid changes, a copy of it (as a 28x28 float
    array) is published on ``data_queue`` after discarding anything still
    waiting there, so the queue only ever holds the most recent drawing.
    """

    GRID_SIZE = 28                    # cells per side of the drawing grid
    INK_VALUE = 0.9                   # array value written for a painted cell
    EMPTY_COLOR = (255, 255, 255)     # colour of an unpainted cell
    INK_COLOR = (0, 0, 0)             # colour of a painted cell
    GRID_LINE_COLOR = (200, 200, 200) # outline drawn around every cell

    def __init__(self, data_queue):
        """Initialise pygame, open the window and build the empty grid.

        Args:
            data_queue: queue on which digit arrays are published.
        """
        self.data_queue = data_queue

        pygame.init()
        self.WIDTH, self.HEIGHT = 560, 560
        self.BG_COLOR = (255, 255, 255)
        self.running = True
        # List of (pygame.Rect, colour) pairs, filled in by setup_grid().
        self.grid_array = []
        self.digit_array = np.zeros((self.GRID_SIZE, self.GRID_SIZE))

        self.window = pygame.display.set_mode((self.WIDTH, self.HEIGHT))
        pygame.display.set_caption("Draw Here - Click Conv3 Sets to Toggle Kernels/Features")
        self.window.fill(self.BG_COLOR)

        # Throttle for comparing/publishing the grid, in seconds.
        self.last_update_time = 0
        self.update_interval = 0.001

        self.setup_grid()

    def setup_grid(self):
        """Rebuild ``grid_array`` as GRID_SIZE x GRID_SIZE unpainted cells.

        Cells are stored column by column: the cell at pixel column ``col``
        and pixel row ``row`` lives at index ``col * GRID_SIZE + row``.
        """
        size = self.GRID_SIZE
        block_size = self.WIDTH // size
        self.grid_array = [
            (
                pygame.Rect(col * block_size, row * block_size, block_size, block_size),
                self.EMPTY_COLOR,
            )
            for col in range(size)
            for row in range(size)
        ]

    def get_digit_array(self):
        """Return the grid as a float array indexed ``[row, col]``.

        Painted cells hold ``INK_VALUE``; every other cell holds 0.
        """
        size = self.GRID_SIZE
        digit = np.zeros((size, size))

        for index, (_, color) in enumerate(self.grid_array):
            if color == self.INK_COLOR:
                # Undo the column-major layout produced by setup_grid().
                col, row = divmod(index, size)
                digit[row, col] = self.INK_VALUE

        return digit

    def clear_grid(self):
        """Reset every cell to the unpainted colour, keeping its rectangle."""
        # Slice assignment keeps the same list object, as the original
        # element-by-element update did.
        self.grid_array[:] = [(rect, self.EMPTY_COLOR) for rect, _ in self.grid_array]

    def draw_rectangles(self):
        """Draw every cell filled with its colour plus a thin outline."""
        for rect, color in self.grid_array:
            pygame.draw.rect(self.window, color, rect)
            pygame.draw.rect(self.window, self.GRID_LINE_COLOR, rect, 1)

    def _publish_digit_array(self):
        """Discard anything queued, then enqueue a copy of ``digit_array``."""
        while True:
            try:
                self.data_queue.get_nowait()
            except queue.Empty:
                break

        try:
            self.data_queue.put_nowait(self.digit_array.copy())
        except queue.Full:
            # Best effort: only reachable with a bounded queue that was
            # refilled in the meantime; this frame is simply dropped.
            pass

    def run(self):
        """Run the draw/publish loop until the window is closed or ESC is hit.

        SPACE clears the grid and publishes the empty array immediately.
        pygame is shut down when the loop ends, including when it ends
        because of an exception.
        """
        clock = pygame.time.Clock()

        print("- Click and drag to draw")
        print("- Press SPACE to clear")
        print("- Press ESC to exit")
        print("- Click on any Conv3 2x4 grid to toggle between kernels and feature maps")

        try:
            while self.running:
                current_time = time.time()

                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        self.running = False
                    elif event.type == pygame.KEYDOWN:
                        if event.key == pygame.K_ESCAPE:
                            self.running = False
                        elif event.key == pygame.K_SPACE:
                            self.clear_grid()
                            self.digit_array = np.zeros((self.GRID_SIZE, self.GRID_SIZE))
                            self._publish_digit_array()

                if pygame.mouse.get_pressed()[0]:
                    mouse_pos = pygame.mouse.get_pos()
                    for index, (rect, _) in enumerate(self.grid_array):
                        if rect.collidepoint(mouse_pos):
                            self.grid_array[index] = (rect, self.INK_COLOR)
                            # Cells tile the window without overlapping, so
                            # no other cell can contain the same point.
                            break

                if current_time - self.last_update_time > self.update_interval:
                    new_digit_array = self.get_digit_array()
                    # Publish only when the drawing actually changed.
                    if not np.array_equal(new_digit_array, self.digit_array):
                        self.digit_array = new_digit_array
                        self._publish_digit_array()
                    self.last_update_time = current_time

                self.window.fill(self.BG_COLOR)
                self.draw_rectangles()
                pygame.display.flip()
                clock.tick(120)
        finally:
            # Release pygame even if the loop above raised.
            pygame.quit()


def main():
    """Start the visualization window in a thread, then run the drawing window.

    Both windows share one queue: the drawing window is handed the same
    queue object as the visualization window.  The visualization window is
    created and run inside a daemon thread; the drawing window runs in the
    calling thread.  A missing model file and any other exception raised
    in the calling thread are reported with ``print`` instead of
    propagating.
    """
    try:
        shared_queue = queue.Queue()

        def run_viz_window():
            # Build and run the visualization window entirely in this thread.
            CNNVisualizationWindow(shared_queue).run()

        threading.Thread(target=run_viz_window, daemon=True).start()

        # Fixed pause before opening the drawing window (presumably to let
        # the visualization thread finish building its window - confirm).
        time.sleep(0.5)

        DrawingWindow(shared_queue).run()

    except FileNotFoundError:
        print("Could not find 'mnist_cnn_4-8-10.pth' model file")
        print("Make sure the CNN model is in the same directory")
    except Exception as exc:
        print(f"Error: {exc}")

# Entry point: start the application only when this file is executed as the
# main module, not when it is imported.
if __name__ == "__main__":
    main()

Loading model from: /home/antoni/fpga_acc/CNN_Accelerator/CNN/mnist_cnn_4-8-10.pth


Exception ignored in: <function WeakMethod.__new__.<locals>._cb at 0x7a606d523400>
Traceback (most recent call last):
  File "/usr/lib/python3.10/weakref.py", line 58, in _cb
    if self._alive:
AttributeError: 'NoneType' object has no attribute '_alive'


- Click and drag to draw
- Press SPACE to clear
- Press ESC to exit
- Click on any Conv3 2x4 grid to toggle between kernels and feature maps
