In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import FancyBboxPatch
import numpy as np

def plot_cnn_architecture():
    """
    Draw the CNN as a left-to-right chain of labelled, colour-coded boxes.

    Each layer is rendered as a rounded box with its name inside, its output
    shape underneath, and an arrow pointing to the next layer. A legend of the
    layer types and a short textual summary are added to the same axes.

    Returns the matplotlib Figure holding the diagram; nothing is saved or
    shown here.
    """
    fig, ax = plt.subplots(1, 1, figsize=(16, 10))

    # Fill colour per layer type
    palette = {
        'input': '#E8F4FD',
        'conv': '#4CAF50',
        'batch_norm': '#FF9800',
        'maxpool': '#2196F3',
        'flatten': '#9C27B0',
        'dense': '#F44336',
        'dropout': '#795548',
        'output': '#FF5722'
    }

    # One row per layer: (label, type, input_shape, output_shape, slot index)
    layers = [
        ('Input\n(600, 13, 1)', 'input', (600, 13, 1), (600, 13, 1), 0),
        ('Conv2D\n32 filters\n(3×3)', 'conv', (600, 13, 1), (598, 11, 32), 1),
        ('BatchNorm', 'batch_norm', (598, 11, 32), (598, 11, 32), 2),
        ('MaxPool2D\n(2×2)', 'maxpool', (598, 11, 32), (299, 5, 32), 3),
        ('Conv2D\n64 filters\n(3×3)', 'conv', (299, 5, 32), (297, 3, 64), 4),
        ('BatchNorm', 'batch_norm', (297, 3, 64), (297, 3, 64), 5),
        ('MaxPool2D\n(2×2)', 'maxpool', (297, 3, 64), (148, 1, 64), 6),
        ('Flatten', 'flatten', (148, 1, 64), (9472,), 7),
        ('Dense\n128 units\nReLU', 'dense', (9472,), (128,), 8),
        ('Dropout\n0.5', 'dropout', (128,), (128,), 9),
        ('Dense\n10 units\nSoftmax', 'output', (128,), (10,), 10)
    ]

    # Geometry of the chain (data units)
    box_width = 1.2
    box_height = 0.8
    spacing = 1.5
    start_x = 0.5
    y_center = 2

    # The two end points of the network get a heavier outline and more padding
    # than the hidden layers: (boxstyle, linewidth)
    endpoint_look = ("round,pad=0.1", 2)
    hidden_look = ("round,pad=0.05", 1)

    final_index = len(layers) - 1
    for idx, (label, kind, _in_shape, out_shape, slot) in enumerate(layers):
        x = start_x + slot * spacing

        if kind in ('input', 'output'):
            box_style, outline = endpoint_look
        else:
            box_style, outline = hidden_look

        ax.add_patch(
            FancyBboxPatch(
                (x - box_width/2, y_center - box_height/2),
                box_width, box_height,
                boxstyle=box_style,
                facecolor=palette[kind],
                edgecolor='black',
                linewidth=outline
            )
        )

        # Layer label, centred in the box
        ax.text(x, y_center, label, ha='center', va='center',
                fontsize=9, fontweight='bold', wrap=True)

        # Output shape, just under the box
        ax.text(x, y_center - 0.6, f"Output: {out_shape}", ha='center', va='center',
                fontsize=7, style='italic', color='darkblue')

        # Every layer except the last points at its right-hand neighbour
        if idx < final_index:
            tail_x = x + box_width/2
            head_x = start_x + (slot + 1) * spacing - box_width/2
            ax.annotate('',
                        xy=(head_x, y_center),
                        xytext=(tail_x, y_center),
                        arrowprops=dict(arrowstyle='->', lw=2, color='black'))

    ax.set_title('CONEqNet: CNN Architecture for Music Genre Classification',
                 fontsize=16, fontweight='bold', pad=20)

    # Legend: one swatch per layer type that actually occurs in the diagram,
    # listed in palette order
    kinds_in_use = {row[1] for row in layers}
    legend_elements = [
        patches.Patch(color=shade, label=kind.replace('_', ' ').title())
        for kind, shade in palette.items()
        if kind in kinds_in_use
    ]
    ax.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1, 0.95))

    # Summary panel anchored to the lower-left corner of the axes
    summary_text = "\n".join((
        "Architecture Summary:",
        "• Input: MFCC features (600 time steps × 13 coefficients)",
        "• Feature Extraction: 2 Convolutional blocks with BatchNorm",
        "• Classification: 2 Dense layers with Dropout",
        "• Output: 10 music genres (softmax)",
    ))
    ax.text(0.02, 0.02, summary_text, transform=ax.transAxes,
            fontsize=10, verticalalignment='bottom',
            bbox=dict(boxstyle="round,pad=0.5", facecolor='lightgray', alpha=0.8))

    # Frame the chain, keep boxes undistorted, and hide the axes themselves
    ax.set_xlim(-0.5, start_x + len(layers) * spacing)
    ax.set_ylim(0.5, 3.5)
    ax.set_aspect('equal')
    ax.axis('off')

    plt.tight_layout()
    return fig

def plot_detailed_architecture():
    """
    Create a detailed vertical diagram showing per-layer details and parameter counts.

    Layers are stacked top to bottom as rounded boxes. Each box shows the layer
    name, a short description and, for layers that have parameters, a
    "Parameters: N" badge on the right. The sum of all per-layer counts is
    printed below the stack.

    Returns the matplotlib Figure holding the diagram; nothing is saved or
    shown here.
    """
    fig, ax = plt.subplots(1, 1, figsize=(12, 14))

    # One row per layer: (name, description, parameter count, fill colour).
    # The numeric count drives both the badge and the total, so it must agree
    # with the description text.
    layer_details = [
        ('Input Layer', 'MFCC Features\n(600, 13, 1)', 0, '#E8F4FD'),
        # 3*3 kernel * 1 input channel * 32 filters + 32 biases = 320
        # (896 would be the count for a 3-channel input, which this is not)
        ('Conv2D Layer 1', '32 filters, 3×3 kernel\nReLU activation\nParams: ~320', 320, '#4CAF50'),
        # 4 values per channel * 32 channels; presumably Keras-style counting
        # (gamma, beta, moving mean, moving variance) — confirm against model.summary()
        ('BatchNormalization', 'Normalize activations\nParams: ~128', 128, '#FF9800'),
        ('MaxPooling2D', '2×2 pool size\nParams: 0', 0, '#2196F3'),
        # 3*3 kernel * 32 input channels * 64 filters + 64 biases = 18,496
        ('Conv2D Layer 2', '64 filters, 3×3 kernel\nReLU activation\nParams: ~18,496', 18496, '#4CAF50'),
        # 4 values per channel * 64 channels (same assumption as above)
        ('BatchNormalization', 'Normalize activations\nParams: ~256', 256, '#FF9800'),
        ('MaxPooling2D', '2×2 pool size\nParams: 0', 0, '#2196F3'),
        ('Flatten', 'Reshape to 1D\nParams: 0', 0, '#9C27B0'),
        # 9,472 flattened inputs * 128 units + 128 biases = 1,212,544
        ('Dense Layer 1', '128 neurons\nReLU activation\nParams: ~1,212,544', 1212544, '#F44336'),
        ('Dropout', 'Rate: 0.5\nParams: 0', 0, '#795548'),
        # 128 inputs * 10 units + 10 biases = 1,290
        ('Output Layer', '10 neurons (genres)\nSoftmax activation\nParams: ~1,290', 1290, '#FF5722')
    ]

    # Evenly spaced box centres from y=10 (first layer) down to y=0 (last layer)
    y_positions = np.linspace(10, 0, len(layer_details))
    box_height = 0.8
    box_width = 8

    total_params = sum(detail[2] for detail in layer_details)

    for i, (layer_name, description, params, color) in enumerate(layer_details):
        y = y_positions[i]

        # Draw layer box, horizontally centred on x=0
        box = FancyBboxPatch(
            (-box_width/2, y - box_height/2),
            box_width, box_height,
            boxstyle="round,pad=0.1",
            facecolor=color,
            edgecolor='black',
            linewidth=1.5
        )
        ax.add_patch(box)

        # Layer name (bold, larger), slightly above the box centre
        ax.text(-box_width/2 + 0.2, y + 0.1, layer_name,
                fontsize=12, fontweight='bold', va='center')

        # Description, slightly below the box centre
        ax.text(-box_width/2 + 0.2, y - 0.15, description,
                fontsize=10, va='center')

        # Parameter badge on the right; omitted for parameter-free layers
        if params > 0:
            ax.text(box_width/2 - 0.2, y, f"Parameters: {params:,}",
                    fontsize=10, ha='right', va='center',
                    bbox=dict(boxstyle="round,pad=0.2", facecolor='white', alpha=0.8))

        # Arrow from the bottom edge of this box to the top edge of the next
        if i < len(layer_details) - 1:
            ax.annotate('',
                        xy=(0, y_positions[i+1] + box_height/2),
                        xytext=(0, y - box_height/2),
                        arrowprops=dict(arrowstyle='->', lw=2, color='darkblue'))

    # Add title
    ax.set_title('CONEqNet: Detailed Layer Architecture',
                 fontsize=16, fontweight='bold', pad=20)

    # Total of all per-layer counts, placed below the last box
    ax.text(0, -1, f'Total Parameters: {total_params:,}',
            fontsize=14, fontweight='bold', ha='center',
            bbox=dict(boxstyle="round,pad=0.5", facecolor='lightblue', alpha=0.8))

    # Frame the stack and hide the axes themselves
    ax.set_xlim(-5, 5)
    ax.set_ylim(-1.5, 11)
    ax.axis('off')

    plt.tight_layout()
    return fig

if __name__ == "__main__":
    # Script entry point: build each diagram, write it to a PNG, then display it.
    print("Generating CNN architecture diagrams...")

    # (diagram builder, output file) — horizontal flow first, detailed vertical second
    diagram_jobs = (
        (plot_cnn_architecture, 'cnn_architecture_horizontal.png'),
        (plot_detailed_architecture, 'cnn_architecture_detailed.png'),
    )

    for build_diagram, png_name in diagram_jobs:
        figure = build_diagram()
        # Make this figure the current one so plt.savefig targets it
        plt.figure(figure.number)
        plt.savefig(png_name, dpi=300, bbox_inches='tight')
        plt.show()

    print("Diagrams saved as 'cnn_architecture_horizontal.png' and 'cnn_architecture_detailed.png'")

Model loaded from ./saved_models/CONEqNet.keras
You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model to work.
Model diagram saved to ./CONEqNet_schema.png
