In [None]:
# Import libraries and load MNIST model
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
from ml_dtypes import bfloat16
import aie.iron as iron
from aie.iron.graph import capture_graph

# Add the mnist module to path.
# NOTE(review): this path (and "mnist_weights" below) is relative to the kernel's
# current working directory, so the notebook presumably has to be started from
# the directory that contains both — TODO confirm.
sys.path.append('programming_examples/ml/mnist')

# Import our MNIST model (resolvable only after the sys.path tweak above)
from mnist import MNISTModel

print("📚 Loading MNIST model...")

# Initialize the bfloat16 model on the NPU device
model = MNISTModel(dtype=bfloat16, device="npu")
    
# Load weights from the local "mnist_weights" directory (the drawing cell's
# PyTorch reference model reads its fc*_weight.npy files from the same place)
model.load_weights_from_dir("mnist_weights")
print("✅ Model loaded successfully!")



# Persistent NPU input buffer: 128 rows x 768 features, bfloat16.
# Later inputs are written into this tensor in place and the captured graph is
# replayed (see the usage notes after the capture block).
# NOTE(review): 128 is presumably the batch size the compiled graph is built
# for; the drawing cell only fills row 0 — TODO confirm.
input_image = iron.tensor(np.zeros((128, 768), dtype=bfloat16), dtype=bfloat16, device="npu")
# Redundant: row 0 is already all zeros from the np.zeros initialiser above.
input_image[0] = np.zeros((768), dtype=bfloat16)

# NOTE(review): presumably switches the model to inference mode, mirroring torch — TODO confirm
model.eval()
# Record one forward pass into graph `g` and compile it. Subsequent inferences
# call g.replay() instead of model.forward(); the `output` bound here is not
# used by the drawing cell.
with capture_graph() as g:
    output = model.forward(input_image)
    g.compile()

# Alias for the captured graph; the drawing cell references `g` directly.
graph = g

# Now:
# Set the input_image values: input_image[0][:] = ...
# Then replay: output = g.replay()
    
# Model info.
# These numbers are hard-coded, not read from `model`; keep them in sync with
# MNISTModel (the drawing cell's reference model uses the same 768->128->64->32
# bias-free layout).
# NOTE(review): MNIST has only 10 digit classes; the 32 outputs are presumably
# padding for the NPU kernels — TODO confirm.
print(f"Model Architecture:")
print(f"  Input: 768 features")
print(f"  Hidden 1: 128 neurons")
print(f"  Hidden 2: 64 neurons") 
print(f"  Output: 32 classes")
# Assumes MNISTModel exposes a torch-style parameters() whose items have numel() — TODO confirm
print(f"  Total parameters: {sum(p.numel() for p in model.parameters())}")


In [None]:
from ipycanvas import Canvas
import ipywidgets as widgets
from IPython.display import display
import numpy as np
import time

def _average_ms(fn, num_runs):
    """Call ``fn()`` ``num_runs`` times and return ``(mean_ms, last_result)``.

    Uses ``time.perf_counter`` (monotonic, high resolution) instead of
    ``time.time``, which can jump and is too coarse for sub-millisecond calls.
    """
    total_ms = 0.0
    result = None
    for _ in range(num_runs):
        start = time.perf_counter()
        result = fn()
        total_ms += (time.perf_counter() - start) * 1000
    return total_ms / num_runs, result


def _preprocess_canvas_png(path, n_features=768, mean=0.1307, std=0.3081,
                           debug_prefix="debug_step"):
    """Turn a saved canvas PNG into the flat, normalized feature vector for the model.

    Pipeline: grayscale -> resize to 28x28 -> scale to [0, 1] -> invert (the canvas
    is black ink on white, MNIST is white ink on black) -> normalize with the MNIST
    mean/std -> flatten and keep the first ``n_features`` pixels.

    Args:
        path: PNG file written by ``Canvas.to_file``.
        n_features: number of leading pixels kept (the model input width).
        mean, std: normalization constants, applied as ``(x - mean) / std``.
        debug_prefix: if not None, intermediate images are written to
            ``{debug_prefix}1_resized.png`` ... ``{debug_prefix}4_final.png``.

    Returns:
        1-D float64 array of length ``min(n_features, 784)``.
    """
    from PIL import Image

    # Context manager closes the file; convert() returns a fully loaded copy.
    with Image.open(path) as img:
        gray = img.convert('L')
    resized = np.asarray(gray.resize((28, 28)), dtype=np.float64) / 255.0
    inverted = 1.0 - resized
    normalized = (inverted - mean) / std
    features = normalized.flatten()[:n_features]

    if debug_prefix is not None:
        def save(values, suffix):
            # Clip to [0, 1] and store as an 8-bit grayscale image.
            pixels = (np.clip(values, 0.0, 1.0) * 255).astype(np.uint8)
            Image.fromarray(pixels).save(f"{debug_prefix}{suffix}.png")

        save(resized, "1_resized")       # before inversion
        save(inverted, "2_inverted")
        # Undo the normalization (x * std + mean) to get back to [0, 1] intensities.
        save(normalized * std + mean, "3_normalized")
        # Pad the cropped features back to 28x28; dropped pixels render as black.
        final = np.zeros(28 * 28)
        final[:features.size] = features * std + mean
        save(final.reshape(28, 28), "4_final")

    return features


def _build_reference_model(weights_dir="mnist_weights"):
    """Build the float32 PyTorch reference MLP (768 -> 128 -> 64 -> 32, no biases, ReLU).

    Weights are read from ``{weights_dir}/fc1_weight.npy`` .. ``fc3_weight.npy``.
    The model lives on PyTorch's default device (CPU unless configured otherwise).
    """
    import torch
    import torch.nn as nn

    model = nn.Sequential(
        nn.Linear(768, 128, bias=False), nn.ReLU(),
        nn.Linear(128, 64, bias=False), nn.ReLU(),
        nn.Linear(64, 32, bias=False),
    )
    # Linear layers sit at even indices; odd indices are the ReLUs.
    for layer, name in zip(model[0::2], ("fc1", "fc2", "fc3")):
        layer.weight.data = torch.tensor(
            np.load(f"{weights_dir}/{name}_weight.npy"), dtype=torch.float32)
    model.eval()
    return model


def create_drawing_canvas(num_runs=100, weights_dir="mnist_weights", num_classes=10):
    """Show a drawing pad that classifies the drawn digit on the NPU and on a CPU reference.

    On "Predict" the canvas is preprocessed MNIST-style and written into row 0 of the
    notebook global ``input_image``. The captured graph ``g`` is then replayed, and the
    result is compared against a float32 PyTorch reference model. Both the global
    ``input_image`` and the global ``g`` come from the model-setup cell.

    Args:
        num_runs: inference repetitions used to average the reported timings (>= 1).
        weights_dir: directory holding the reference ``fc*_weight.npy`` files.
        num_classes: only the first ``num_classes`` logits are ranked. The model emits
            32 outputs, but only indices 0-9 correspond to digits (the rest are
            presumably padding), so ranking all of them could report a non-digit.

    Returns:
        The ipycanvas ``Canvas`` widget.

    Raises:
        ValueError: if ``num_runs < 1`` or ``num_classes < 2``.
    """
    import html

    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")
    if num_classes < 2:
        raise ValueError(f"num_classes must be >= 2 to report a top-2, got {num_classes}")

    size = 280
    prompt = "Draw a digit and click Predict"

    # sync_image_data is required for Canvas.to_file to see the drawn pixels.
    canvas = Canvas(width=size, height=size, sync_image_data=True)
    canvas.fill_style = 'white'
    canvas.fill_rect(0, 0, size, size)

    # Result panel, absolutely positioned at the top right of the container below.
    dynamic_label = widgets.HTML(
        value=prompt,
        layout=widgets.Layout(
            position='absolute',
            top='10px',
            right='10px',
            width='300px',
            background='rgba(240, 240, 240, 0.9)',
            padding='10px',
            border_radius='5px',
            border='1px solid #ccc'
        )
    )

    drawing = False   # True while the mouse button is held down on the canvas
    ref_model = None  # PyTorch reference model, built on the first prediction

    def on_mouse_down(x, y):
        nonlocal drawing
        drawing = True
        canvas.fill_style = 'black'
        canvas.fill_circle(x, y, 4)

    def on_mouse_move(x, y):
        if drawing:
            canvas.fill_style = 'black'
            canvas.fill_circle(x, y, 4)

    def stop_drawing(x, y):
        # Also bound to mouse-out: a button released outside the canvas never
        # fires mouse-up here, which would otherwise leave the brush stuck down.
        nonlocal drawing
        drawing = False

    def clear_canvas():
        canvas.fill_style = 'white'
        canvas.fill_rect(0, 0, size, size)
        dynamic_label.value = prompt

    def on_predict(b):
        nonlocal ref_model
        dynamic_label.value = "Processing..."
        try:
            # Round-trip through a PNG (also kept on disk for debugging).
            canvas.to_file("debug_canvas.png")
            features = _preprocess_canvas_png("debug_canvas.png")

            # Write into row 0 of the persistent NPU input buffer, then replay the graph.
            input_image[0][:] = features.astype(bfloat16)
            aie_time, output = _average_ms(g.replay, num_runs)
            aie_logits = output.numpy()[0][:num_classes]
            aie_top2 = np.argsort(aie_logits)[-2:][::-1]  # top 2, highest first

            import torch
            if ref_model is None:
                ref_model = _build_reference_model(weights_dir)
            ref_input = torch.tensor(features, dtype=torch.float32).unsqueeze(0)
            # no_grad: skip autograd bookkeeping so the timing reflects inference only.
            with torch.no_grad():
                ref_time, ref_output = _average_ms(lambda: ref_model(ref_input), num_runs)
            ref_logits = ref_output.numpy()[0][:num_classes]
            ref_top2 = np.argsort(ref_logits)[-2:][::-1]  # top 2, highest first
        except Exception as exc:
            # Don't leave the panel stuck on "Processing..."; re-raise so the
            # traceback still reaches the widget log.
            dynamic_label.value = (
                '<p style="margin: 5px 0; color: #b00;"><strong>Prediction failed:</strong> '
                f"{html.escape(str(exc))}</p>"
            )
            raise

        dynamic_label.value = f"""
        <div>
            <h4 style="margin: 0 0 10px 0; color: #333;">Prediction Results</h4>
            <p style="margin: 5px 0;"><strong>Top 2 predictions:</strong></p>
            <p style="margin: 5px 0;">NPU: {aie_top2[0]} → {aie_top2[1]} <span style="color: #666;">({aie_time:.1f}ms)</span></p>
            <p style="margin: 5px 0;">CPU (PyTorch): {ref_top2[0]} → {ref_top2[1]} <span style="color: #666;">({ref_time:.1f}ms)</span></p>
        </div>
        """

    canvas.on_mouse_down(on_mouse_down)
    canvas.on_mouse_move(on_mouse_move)
    canvas.on_mouse_up(stop_drawing)
    canvas.on_mouse_out(stop_drawing)

    clear_button = widgets.Button(description='Clear', button_style='warning')
    predict_button = widgets.Button(description='Predict', button_style='success')
    clear_button.on_click(lambda b: clear_canvas())
    predict_button.on_click(on_predict)

    # Relative positioning on the outer box anchors the absolutely positioned label.
    container = widgets.HBox([
        widgets.VBox([
            widgets.HTML("<h3>Draw a digit (0-9):</h3>"),
            canvas,
            widgets.HBox([clear_button, predict_button])
        ], layout=widgets.Layout(
            justify_content='center',
            align_items='center',
            width='100%'
        )),
        dynamic_label
    ], layout=widgets.Layout(
        position='relative',
        width='100%'
    ))

    display(container)

    return canvas

# Create the drawing canvas. Requires the model-setup cell to have run first:
# prediction writes into `input_image` and replays the captured graph `g`.
drawing_canvas = create_drawing_canvas()