# Evaluating CNN Performance on RRAM-based In-Memory Computing Accelerators

In this notebook, you will use the [cross-sim](https://github.com/sandialabs/cross-sim) simulator to analyze how the accuracy of a Convolutional Neural Network (CNN) for MNIST digit recognition is affected when it is deployed on RRAM-based In-Memory Computing (IMC) accelerators.

You will:
- Train and evaluate a CNN in standard PyTorch (software-only baseline).
- Evaluate the same trained network using cross-sim to simulate RRAM hardware effects.
- Retrain the network using Hardware-Aware Training (HAT) with cross-sim, then evaluate its performance on simulated hardware.

Finally, you can draw your own digit and see how each network performs on your input!

## 1. Setup and Imports

Let's start by importing the necessary libraries and setting up the environment.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from ipycanvas import Canvas
from IPython.display import display
import cv2  # used to resize the canvas drawing in Section 9

import sys
sys.path.append("../../")
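# CrossSimModel and CrossSimConfig (used in Sections 6 and 7) are not provided by the
# imports below; uncomment this line only if your installation exposes them, otherwise
# adapt those cells to cross-sim's PyTorch interface.
# from cross_sim import CrossSimModel, CrossSimConfig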
from simulator import AnalogCore
from simulator import CrossSimParameters
from applications.mvm_params import set_params

## 2. Data Preparation

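We will use the MNIST dataset of 28x28 grayscale handwritten digits (60,000 training and 10,000 test images). Each image is converted to a tensor with values in [0, 1] and then normalized with the training-set mean (0.1307) and standard deviation (0.3081).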

In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

trainset = torchvision.datasets.MNIST(root='./MNIST', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./MNIST', train=False, download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=1000, shuffle=False)
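`transforms.Normalize` applies the same affine map to every pixel: `x -> (x - mean) / std`. The short NumPy sketch below applies that formula to a few raw pixel values after the `ToTensor`-style scaling to [0, 1]; the helper name `normalize` is illustrative and not part of torchvision.

```python
import numpy as np

MNIST_MEAN = 0.1307  # value passed to transforms.Normalize above
MNIST_STD = 0.3081   # value passed to transforms.Normalize above


def normalize(pixels: np.ndarray) -> np.ndarray:
    """Same per-pixel map as transforms.Normalize((0.1307,), (0.3081,))."""
    return (pixels - MNIST_MEAN) / MNIST_STD


# ToTensor() scales raw 0..255 values into [0, 1]: background is 0, full ink is 1.
raw = np.array([0, 128, 255], dtype=np.float64)
scaled = raw / 255.0
normalized = normalize(scaled)
print(normalized.round(4))
```

Background pixels therefore end up at about -0.42 and full-intensity strokes at about 2.82, which is the range the network sees during training.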

## 3. Define the CNN Model

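We will use a simple CNN suitable for MNIST: two 3x3 convolutional layers (32 and 64 filters) followed by two fully connected layers that produce one output per digit class.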

In [None]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)   # 28x28 -> 26x26
        self.conv2 = nn.Conv2d(32, 64, 3, 1)  # 26x26 -> 24x24
        self.pool = nn.MaxPool2d(2)           # 24x24 -> 12x12
        self.fc1 = nn.Linear(9216, 128)       # 64 * 12 * 12 = 9216 inputs
        self.fc2 = nn.Linear(128, 10)         # one logit per digit class

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.pool(x)                      # required for the 9216-feature flatten
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
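The input size of `fc1` must equal the number of features produced by `torch.flatten(x, 1)`. The calculation below traces the spatial size through the layers with the standard output-size formula for an unpadded convolution or pooling window. It shows that 9216 = 64 x 12 x 12 is reached only if a 2x2 max-pooling step (as in PyTorch's official MNIST example) halves the 24x24 feature maps after `conv2`; flattening directly after `conv2` would give 64 x 24 x 24 = 36864 features, which `nn.Linear(9216, 128)` would reject with a shape-mismatch error. The helper `conv2d_out` is illustrative.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output size along one spatial dimension for Conv2d / MaxPool2d (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                                     # MNIST images are 28x28
size = conv2d_out(size, kernel=3)             # conv1: 28 -> 26
size = conv2d_out(size, kernel=3)             # conv2: 26 -> 24
unpooled = 64 * size * size                   # features if flattened right after conv2
size = conv2d_out(size, kernel=2, stride=2)   # 2x2 max-pool: 24 -> 12
pooled = 64 * size * size                     # features after pooling

print(unpooled, pooled)
```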

## 4. Training and Evaluation Functions

Let's define helper functions for training and evaluating the model.

In [None]:
def train(model, device, trainloader, optimizer, criterion, epochs=1):
    model.train()
    for epoch in range(epochs):
        for data, target in trainloader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

def test(model, device, testloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in testloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)
    return 100. * correct / total
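The accuracy bookkeeping in `test` can be checked on a tiny hand-made batch. The sketch below uses hypothetical logits in place of `model(data)`. It also shows why `target.view_as(pred)` matters: `argmax(..., keepdim=True)` returns shape `(N, 1)`, and comparing that directly with a `(N,)` target broadcasts to an `N x N` comparison that overcounts.

```python
import torch

# Hypothetical logits for 3 images over 3 classes (stand-ins for model(data)).
output = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 1.5, 0.1],
                       [0.9, 0.2, 0.4]])
target = torch.tensor([0, 1, 2])

pred = output.argmax(dim=1, keepdim=True)              # shape (3, 1)
correct = pred.eq(target.view_as(pred)).sum().item()   # element-wise comparison
accuracy = 100.0 * correct / target.size(0)

# Without view_as, (3, 1) vs (3,) broadcasts to a 3x3 comparison.
broadcast_count = pred.eq(target).sum().item()

print(correct, round(accuracy, 2), broadcast_count)
```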

## 5. Case I: PyTorch Training and Inference

Train and evaluate the CNN using only PyTorch (software baseline).

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_pt = SimpleCNN().to(device)
optimizer = optim.Adam(model_pt.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

train(model_pt, device, trainloader, optimizer, criterion, epochs=2)
acc_pt = test(model_pt, device, testloader)
print(f'PyTorch (SW) Test Accuracy: {acc_pt:.2f}%')
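Before moving the trained weights onto simulated RRAM, it helps to know how large the weight matrices are. Assuming the common im2col (unrolled) mapping, which may differ in detail from how cross-sim maps convolutions, a `Conv2d` layer with `C_out` filters of size `C_in x k x k` becomes a `C_out x (C_in*k*k)` matrix, and a `Linear` layer keeps its `out_features x in_features` shape. The sketch counts how many crossbar tiles each layer would occupy for a hypothetical 256x256 array size, ignoring biases and any two-device (differential) weight encoding.

```python
import math

TILE_ROWS, TILE_COLS = 256, 256  # hypothetical crossbar array size

# (layer, outputs, inputs) for SimpleCNN; convolutions use the im2col view C_out x (C_in*k*k).
layers = [
    ("conv1", 32, 1 * 3 * 3),
    ("conv2", 64, 32 * 3 * 3),
    ("fc1", 128, 9216),
    ("fc2", 10, 128),
]


def tiles_needed(rows: int, cols: int) -> int:
    """Number of TILE_ROWS x TILE_COLS arrays needed to hold a rows x cols matrix."""
    return math.ceil(rows / TILE_ROWS) * math.ceil(cols / TILE_COLS)


total_weights = 0
total_tiles = 0
for name, rows, cols in layers:
    n_tiles = tiles_needed(rows, cols)
    total_weights += rows * cols
    total_tiles += n_tiles
    print(f"{name}: {rows} x {cols} matrix, {n_tiles} tile(s)")
print(f"total: {total_weights} weights on {total_tiles} tiles")
```

Under these assumptions `fc1` alone holds over 98% of the weights, so it dominates both the array count and the exposure to device errors.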

## 6. Case II: PyTorch Training, CrossSim Inference

Evaluate the PyTorch-trained model using cross-sim to simulate RRAM hardware effects.

In [None]:
# Configure cross-sim for inference
config = CrossSimConfig(device_type='RRAM', noise_std=0.05)
model_cs = CrossSimModel(model_pt, config)

acc_cs = test(model_cs, device, testloader)
print(f'PyTorch Training, CrossSim Inference (HW) Test Accuracy: {acc_cs:.2f}%')
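To build intuition for why Case II can lose accuracy, the NumPy sketch below mimics one non-ideality in a deliberately simplified way: every weight is perturbed once by Gaussian noise whose standard deviation is a fraction `rel_std` of the largest absolute weight, loosely like conductance programming error. This is a toy model and not cross-sim's device model; in particular, it is an assumption that `noise_std=0.05` above means anything like `rel_std=0.05` here. The weights and activations are random stand-ins.

```python
import numpy as np


def noisy_matvec(W: np.ndarray, x: np.ndarray, rel_std: float,
                 rng: np.random.Generator) -> np.ndarray:
    """Toy analog MVM: perturb W with Gaussian noise scaled to max|W|, then multiply."""
    sigma = rel_std * np.max(np.abs(W))
    W_programmed = W + rng.normal(0.0, sigma, size=W.shape)
    return W_programmed @ x


rng = np.random.default_rng(0)
W = rng.normal(0.0, 0.1, size=(10, 128))    # stand-in for fc2 weights (10 x 128)
x = np.maximum(rng.normal(size=128), 0.0)   # stand-in for ReLU activations

ideal = W @ x
for rel_std in (0.0, 0.01, 0.05, 0.10):
    analog = noisy_matvec(W, x, rel_std, rng)
    rel_err = np.linalg.norm(analog - ideal) / np.linalg.norm(ideal)
    kept = analog.argmax() == ideal.argmax()
    print(f"rel_std={rel_std:.2f}: relative output error {rel_err:.3f}, "
          f"argmax {'kept' if kept else 'changed'}")
```

Larger `rel_std` values push the layer output further from the ideal result, and once the error is comparable to the gap between the top two logits, the predicted class can flip.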

## 7. Case III: CrossSim Training and Inference (Hardware-Aware Training)

Train a fresh copy of the network with cross-sim in the loop, so that hardware effects are present during training (Hardware-Aware Training, HAT), then evaluate it on simulated hardware.

In [None]:
model_hat = SimpleCNN().to(device)
model_hat_cs = CrossSimModel(model_hat, config)
optimizer_hat = optim.Adam(model_hat.parameters(), lr=0.001)

# Hardware-Aware Training
train(model_hat_cs, device, trainloader, optimizer_hat, criterion, epochs=2)
acc_hat = test(model_hat_cs, device, testloader)
print(f'CrossSim Training & Inference (HAT) Test Accuracy: {acc_hat:.2f}%')
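The idea behind Hardware-Aware Training can be illustrated independently of cross-sim. One widely used variant is noise injection: each training forward pass uses weights perturbed by fresh random noise, while gradients update the clean weights, which pushes the optimizer toward weights whose predictions tolerate the perturbation. The `NoisyLinear` layer below is a hypothetical minimal PyTorch sketch of that technique, not how cross-sim implements training; the relative noise level `rel_std` is an assumed parameter.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NoisyLinear(nn.Linear):
    """Hypothetical HAT layer: Gaussian weight noise in training mode only."""

    def __init__(self, in_features: int, out_features: int, rel_std: float = 0.05):
        super().__init__(in_features, out_features)
        self.rel_std = rel_std

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.weight
        if self.training and self.rel_std > 0:
            # Noise scale tied to the current weight range; detached so it is a constant.
            sigma = self.rel_std * weight.detach().abs().max()
            weight = weight + sigma * torch.randn_like(weight)
        # The addition is differentiable, so gradients reach self.weight (the clean weights).
        return F.linear(x, weight, self.bias)


torch.manual_seed(0)
layer = NoisyLinear(128, 10, rel_std=0.05)
x = torch.randn(4, 128)

layer.train()
out_a, out_b = layer(x), layer(x)   # two noisy passes draw different noise
layer.eval()
out_clean = layer(x)                # eval mode: plain F.linear, no noise

layer.train()
loss = layer(x).pow(2).mean()
loss.backward()
print(torch.allclose(out_a, out_b), layer.weight.grad.shape)
```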

## 8. Summary Table

Fill in the accuracies printed by the cells in Sections 5 to 7:

| Case | Training | Inference | Test Accuracy (%)     |
|------|----------|-----------|-----------------------|
| I    | PyTorch  | PyTorch   | `acc_pt` (Section 5)  |
| II   | PyTorch  | CrossSim  | `acc_cs` (Section 6)  |
| III  | CrossSim | CrossSim  | `acc_hat` (Section 7) |

Compare Case II with Case I to see how much accuracy the simulated RRAM non-idealities cost, and Case III with Case II to see how much of that loss Hardware-Aware Training recovers.
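A Markdown table does not evaluate Python expressions, so the accuracies have to be inserted by code. The helper below is a hypothetical sketch that renders the summary with two-decimal accuracies; the call at the end uses arbitrary values only to show the formatting.

```python
def summary_table(acc_pt: float, acc_cs: float, acc_hat: float) -> str:
    """Render the three-case summary as a Markdown table with two-decimal accuracies."""
    rows = [
        ("I", "PyTorch", "PyTorch", acc_pt),
        ("II", "PyTorch", "CrossSim", acc_cs),
        ("III", "CrossSim", "CrossSim", acc_hat),
    ]
    lines = [
        "| Case | Training | Inference | Test Accuracy (%) |",
        "|------|----------|-----------|-------------------|",
    ]
    for case, training, inference, acc in rows:
        lines.append(f"| {case} | {training} | {inference} | {acc:.2f} |")
    return "\n".join(lines)


# Arbitrary values, only to demonstrate the output format.
table = summary_table(12.5, 50.0, 100.0)
print(table)
```

In the notebook, `from IPython.display import Markdown, display` followed by `display(Markdown(summary_table(acc_pt, acc_cs, acc_hat)))` renders the filled-in table.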

## 9. Draw Your Own Digit!

Use the canvas below to draw a digit (0-9). The drawing is downsampled to 28x28 pixels, normalized the same way as the MNIST images, and fed to all three models, so you can compare their predictions for your own handwriting.

In [None]:
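# sync_image_data=True is required for canvas.get_image_data() to return the drawing
canvas = Canvas(width=140, height=140, sync_image_data=True)
canvas.fill_style = 'black'
canvas.fill_rect(0, 0, canvas.width, canvas.height)  # dark background, like MNIST
canvas.stroke_style = 'white'
canvas.line_width = 12
canvas.line_cap = 'round'

# Freehand drawing: connect successive mouse positions while the button is held
drawing = {'active': False, 'last': None}

def on_mouse_down(x, y):
    drawing['active'] = True
    drawing['last'] = (x, y)

def on_mouse_move(x, y):
    if drawing['active']:
        canvas.stroke_line(*drawing['last'], x, y)
        drawing['last'] = (x, y)

def on_mouse_up(x, y):
    drawing['active'] = False

canvas.on_mouse_down(on_mouse_down)
canvas.on_mouse_move(on_mouse_move)
canvas.on_mouse_up(on_mouse_up)
display(canvas)

def preprocess_canvas(canvas):
    img = np.array(canvas.get_image_data())[..., 0]  # RGBA array; keep one channel
    img = img / 255.0  # scale to [0, 1] like transforms.ToTensor()
    # Light strokes on a dark background already match MNIST, so no inversion
    img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_AREA)
    img = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0)  # (1, 1, 28, 28)
    img = (img - 0.1307) / 0.3081  # same normalization as the training data
    return img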

def predict_digit(img):
    with torch.no_grad():
        img = img.to(device)
        out_pt = model_pt(img)
        out_cs = model_cs(img)
        out_hat = model_hat_cs(img)
        pred_pt = out_pt.argmax(dim=1).item()
        pred_cs = out_cs.argmax(dim=1).item()
        pred_hat = out_hat.argmax(dim=1).item()
    return pred_pt, pred_cs, pred_hat
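`preprocess_canvas` shrinks the 140x140 drawing by a factor of 5 in each direction. For an exact integer factor, area interpolation amounts to averaging each 5x5 block of pixels, so the step can be approximated with NumPy alone, which is handy for checking the pipeline without OpenCV. Because MNIST digits are light strokes on a dark background, a white-on-black drawing needs no inversion. The helper `downsample_and_normalize`, the synthetic stroke, and the use of block averaging in place of `cv2.resize` are assumptions of this sketch.

```python
import numpy as np

MNIST_MEAN, MNIST_STD = 0.1307, 0.3081


def downsample_and_normalize(channel: np.ndarray, factor: int = 5) -> np.ndarray:
    """Average factor x factor blocks of a 0..255 image, scale to [0, 1], MNIST-normalize."""
    h, w = channel.shape
    if h % factor or w % factor:
        raise ValueError("image size must be a multiple of factor")
    img = channel.astype(np.float64) / 255.0
    small = img.reshape(h // factor, factor, w // factor, factor).mean(axis=(1, 3))
    return (small - MNIST_MEAN) / MNIST_STD


# Synthetic 140x140 drawing: black background with a white vertical bar, like a "1".
canvas_red = np.zeros((140, 140), dtype=np.uint8)
canvas_red[20:120, 65:75] = 255

img = downsample_and_normalize(canvas_red)
print(img.shape, img.min().round(4), img.max().round(4))
```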


In [None]:

# Example usage after drawing:
# img = preprocess_canvas(canvas)
# pred_pt, pred_cs, pred_hat = predict_digit(img)
# print(f'PyTorch: {pred_pt}, CrossSim: {pred_cs}, HAT: {pred_hat}')
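`predict_digit` reports only the winning class, but two models can agree on the class while being very differently confident. Converting logits to probabilities with a softmax makes that visible. The NumPy sketch below uses hypothetical logits; inside `predict_digit`, the same conversion could be done with `torch.softmax(out_pt, dim=1)` and likewise for the other two outputs.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


# Hypothetical 10-class logits from two models for the same drawing.
logits_a = np.array([0.1, 0.2, 6.0, 0.3, 0.0, 0.1, 0.2, 1.0, 0.4, 0.1])
logits_b = np.array([0.1, 0.2, 1.2, 0.3, 0.0, 0.1, 0.2, 1.0, 0.4, 0.1])

results = {}
for name, logits in (("model A", logits_a), ("model B", logits_b)):
    probs = softmax(logits)
    results[name] = probs
    print(f"{name}: predicts {probs.argmax()} with probability {probs.max():.2f}")
```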