# MLP Inference with TT-NN

This tutorial demonstrates how to leverage TT-NN with PyTorch for neural network inference tasks. We'll explore two practical applications:

1. **Regression**: Training a neural network to approximate the `sin(x)` function
2. **Classification**: Applying our knowledge to classify handwritten digits from the MNIST dataset

Through these examples, you'll learn how to effectively use TT-NN for tensor operations and accelerated model inference on Tenstorrent hardware.

## Setting Up the Environment

### Required Libraries

To run MLP inference on Tenstorrent devices, we'll import several key libraries:

- **PyTorch (`torch`)**: Core framework for tensor operations and data handling
- **TorchVision (`torchvision`)**: Provides access to the MNIST dataset and image preprocessing utilities
- **Matplotlib (`matplotlib`)**: Visualization tool for plotting regression results
- **TT-NN**: Tenstorrent's neural network library that enables:
  - Hardware-accelerated tensor operations
  - Efficient data layout transformations
  - Optimized layer computations (linear, ReLU, etc.)
- **Loguru (`loguru`)**: Enhanced logging for tracking execution progress and results

```python
import math
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import ttnn
from loguru import logger
```

## Open the Device

Open the first Tenstorrent device (`device_id=0`). Every TT-NN tensor and operation in this tutorial runs on this device handle.

```python
# Open Tenstorrent device 0
device = ttnn.open_device(device_id=0)
```

## Sine Function Regression

### Overview

In this first task, we'll demonstrate how to:
- Load pre-trained model weights
- Perform inference using TT-NN hardware acceleration
- Validate results against expected outputs

We'll use a neural network trained to approximate the sine function, a classic regression problem that shows whether a model can learn a smooth non-linear mapping.

### Model Architecture

The pre-trained weights in `mlp_sin.pt` correspond to a simple 3-layer MLP with the following architecture:

```python
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(1, 64)       # Input layer: 1 → 64 neurons
        self.relu1 = nn.ReLU()           # Activation function
        self.l2 = nn.Linear(64, 64)      # Hidden layer: 64 → 64 neurons
        self.relu2 = nn.ReLU()           # Activation function
        self.l3 = nn.Linear(64, 1)       # Output layer: 64 → 1 neuron
    
    def forward(self, x):
        x = self.l1(x)
        x = self.relu1(x)
        x = self.l2(x)
        x = self.relu2(x)
        x = self.l3(x)
        return x
```
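For reference, `nn.Linear` computes `x @ W.T + b`, where `W` is stored with shape `[out_features, in_features]`. The NumPy sketch below reproduces the forward pass of this MLP on the CPU and checks that pre-transposing each weight to `[in_features, out_features]` and reshaping each bias to `[1, out_features]`, as the TT-NN code later in this tutorial does, gives the same result. The random weights are placeholders, not the contents of `mlp_sin.pt`.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def mlp_forward_torch_layout(x, params):
    """Forward pass with weights in PyTorch nn.Linear layout [out, in]: y = x @ W.T + b."""
    h = relu(x @ params["l1.weight"].T + params["l1.bias"])
    h = relu(h @ params["l2.weight"].T + params["l2.bias"])
    return h @ params["l3.weight"].T + params["l3.bias"]

def mlp_forward_transposed(x, params):
    """Same computation with weights pre-transposed to [in, out] and biases reshaped to [1, out]."""
    w = [params[f"l{i}.weight"].T for i in (1, 2, 3)]
    b = [params[f"l{i}.bias"].reshape(1, -1) for i in (1, 2, 3)]
    h = relu(x @ w[0] + b[0])
    h = relu(h @ w[1] + b[1])
    return h @ w[2] + b[2]

# Placeholder parameters with the same shapes as the MLP above (1 -> 64 -> 64 -> 1)
rng = np.random.default_rng(0)
params = {
    "l1.weight": rng.standard_normal((64, 1)), "l1.bias": rng.standard_normal(64),
    "l2.weight": rng.standard_normal((64, 64)), "l2.bias": rng.standard_normal(64),
    "l3.weight": rng.standard_normal((1, 64)), "l3.bias": rng.standard_normal(1),
}

x_demo = np.linspace(-2 * np.pi, 2 * np.pi, 5).reshape(-1, 1)  # [batch, 1]
y_ref = mlp_forward_torch_layout(x_demo, params)
y_tt_style = mlp_forward_transposed(x_demo, params)
print(y_ref.shape)                       # (5, 1)
print(np.allclose(y_ref, y_tt_style))    # True
```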

### Training Details

This model was trained on input values ranging from `-2π` to `2π`, learning to map each input value to its corresponding sine output.

### Implementation Steps

1. **Generate test data**: Create input values within the trained range
2. **Load weights**: Transfer the pre-trained PyTorch weights to TT-NN layers
3. **Run inference**: Execute the model on Tenstorrent hardware
4. **Visualize results**: Plot predictions against ground truth to verify accuracy

Let's begin by implementing each of these steps:

```python
# Create test data for sin(x) approximation: 200 points in [-2π, 2π], shape [200, 1]
x_test = torch.linspace(-2 * math.pi, 2 * math.pi, 200).unsqueeze(1)

# Load pretrained weights for the MLP model
sin_weights = torch.load("mlp_sin.pt")
sin_weight_1 = ttnn.from_torch(sin_weights["l1.weight"], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
sin_bias_1 = ttnn.from_torch(sin_weights["l1.bias"], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
sin_weight_2 = ttnn.from_torch(sin_weights["l2.weight"], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
sin_bias_2 = ttnn.from_torch(sin_weights["l2.bias"], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
sin_weight_3 = ttnn.from_torch(sin_weights["l3.weight"], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
sin_bias_3 = ttnn.from_torch(sin_weights["l3.bias"], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
logger.info("Loaded pretrained weights from mlp_sin.pt")

# Prepare weights and biases for TT-NN linear layers:
# nn.Linear stores weights as [out_features, in_features], while ttnn.linear
# computes input @ weight, so each weight is transposed to [in_features, out_features].
# Biases are reshaped to [1, out_features] so they broadcast across rows.
sin_weight_1_final = ttnn.transpose(sin_weight_1, -2, -1)
sin_bias_1_final = ttnn.reshape(sin_bias_1, [1, -1])
sin_weight_2_final = ttnn.transpose(sin_weight_2, -2, -1)
sin_bias_2_final = ttnn.reshape(sin_bias_2, [1, -1])
sin_weight_3_final = ttnn.transpose(sin_weight_3, -2, -1)
sin_bias_3_final = ttnn.reshape(sin_bias_3, [1, -1])

# Run inference on test data using TT-NN, one sample at a time
y_pred = []

for i, x in enumerate(x_test):
    if i % 50 == 0:
        logger.info(f"Processing sample {i+1}/{len(x_test)}")
    # Convert the [1] input to a [1, 1] TT-NN tensor
    x_tt = ttnn.from_torch(x.unsqueeze(0), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

    # Layer 1: Linear + ReLU
    out1 = ttnn.linear(x_tt, sin_weight_1_final, bias=sin_bias_1_final)
    out1 = ttnn.relu(out1)

    # Layer 2: Linear + ReLU
    out2 = ttnn.linear(out1, sin_weight_2_final, bias=sin_bias_2_final)
    out2 = ttnn.relu(out2)

    # Layer 3: Linear (output, no activation)
    out3 = ttnn.linear(out2, sin_weight_3_final, bias=sin_bias_3_final)

    # Convert TT-NN output back to a PyTorch tensor and store the scalar prediction
    prediction = ttnn.to_torch(out3).float()
    y_pred.append(prediction.flatten()[0].item())
```
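The loop above converts each of the 200 inputs separately, so every sample pays for its own host-to-device transfer and its own set of layer calls. Because each row of a matrix product depends only on the matching input row, the same predictions could in principle come from a single batched pass over the whole `[200, 1]` tensor; in the TT-NN code that would amount to calling `ttnn.from_torch` once on `x_test` instead of once per row (not tested here). The NumPy sketch below only checks the row-independence argument on the CPU for one linear + ReLU layer with placeholder weights; whether batching is faster on your hardware is something to measure rather than assume.

```python
import numpy as np

rng = np.random.default_rng(1)
w_t = rng.standard_normal((1, 64))    # placeholder pre-transposed weight: [in, out]
b = rng.standard_normal((1, 64))      # placeholder bias reshaped to [1, out]
x_all = np.linspace(-2 * np.pi, 2 * np.pi, 200).reshape(-1, 1)  # [200, 1]

def layer(x):
    """One linear + ReLU layer: relu(x @ W + b)."""
    return np.maximum(x @ w_t + b, 0.0)

# Per-sample, as in the tutorial loop: each input is a [1, 1] matrix
looped = np.vstack([layer(x.reshape(1, 1)) for x in x_all])

# Batched: all 200 inputs in a single matrix product
batched = layer(x_all)

print(looped.shape, batched.shape)     # (200, 64) (200, 64)
print(np.allclose(looped, batched))    # True
```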

## Evaluating Model Performance

### Visualizing Regression Results

We've successfully completed inference on our entire test dataset. Now let's assess the model's accuracy by visualizing how well our predictions match the true sine function values.

By plotting both the ground truth (actual sine values) and our model's predictions, we can:
- Visually inspect the quality of the approximation
- Identify any regions where the model struggles
- Confirm that our TT-NN implementation produces accurate results

Let's create a comparison plot to evaluate our regression model:

```python
# Plot true sin(x) and MLP predictions
plt.plot(x_test.numpy(), torch.sin(x_test).numpy(), label="True sin(x)")
plt.plot(x_test.numpy(), y_pred, label="MLP Prediction")
plt.legend()
plt.title("MLP Approximation of sin(x)")
plt.show()
```
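A plot shows the overall shape of the fit, but a number makes it easier to compare runs. The pure-Python sketch below computes the mean squared error and the maximum absolute error between two sequences; in this notebook it could be called as `regression_errors(torch.sin(x_test).flatten().tolist(), y_pred)`. Because TT-NN computes in bfloat16 here, expect small nonzero errors even where the curves overlap visually; no specific error values are claimed for this model.

```python
import math

def regression_errors(y_true, y_pred):
    """Return (mean squared error, max absolute error) for two equal-length sequences."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    diffs = [p - t for t, p in zip(y_true, y_pred)]
    mse = sum(d * d for d in diffs) / len(diffs)
    max_abs = max(abs(d) for d in diffs)
    return mse, max_abs

# Toy example: predictions that are off by 0.1 at every point
xs = [i * 0.1 for i in range(10)]
truth = [math.sin(x) for x in xs]
preds = [t + 0.1 for t in truth]
mse, max_abs = regression_errors(truth, preds)
print(round(mse, 6), round(max_abs, 6))   # 0.01 0.1
```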

## MNIST Digit Classification

### Overview

Now we'll apply our TT-NN knowledge to a more complex task: classifying handwritten digits from the famous MNIST dataset. This demonstrates how TT-NN handles real-world classification problems with image data.

### Dataset Preparation

The MNIST dataset contains 28×28 grayscale images of handwritten digits (0-9). To use this data with our neural network, we need to:

1. **Load the dataset**: Download MNIST images and labels
2. **Apply transformations**: 
   - Convert images to tensors
   - Normalize pixel values from `[0, 1]` to `[-1, 1]` using mean 0.5 and standard deviation 0.5
3. **Create a DataLoader**: Enable efficient batch processing and iteration through the dataset

This pipeline yields normalized float tensors of shape `[1, 1, 28, 28]` (batch, channel, height, width), one image per batch; they are flattened and converted to TT-NN tensors inside the inference loop.

```python
# Load MNIST data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
testset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False)
```
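`transforms.ToTensor()` scales 8-bit pixel values to `[0, 1]`, and `transforms.Normalize((0.5,), (0.5,))` then applies `(x - 0.5) / 0.5` per channel, mapping them to `[-1, 1]`. The NumPy sketch below reproduces that arithmetic on a few raw pixel values. The pretrained MNIST weights were presumably trained with this same normalization, which is why inference must apply it too.

```python
import numpy as np

def mnist_normalize(pixels_uint8, mean=0.5, std=0.5):
    """Mimic ToTensor() followed by Normalize((mean,), (std,)) for single-channel pixels."""
    x = np.asarray(pixels_uint8, dtype=np.float32) / 255.0  # ToTensor: [0, 255] -> [0, 1]
    return (x - mean) / std                                 # Normalize: [0, 1] -> [-1, 1]

print(mnist_normalize([0, 255]))   # [-1.  1.]
```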

## Load Pretrained MLP Weights

Load the pretrained MNIST MLP weights from `mlp_mnist_weights.pt`. If the file does not exist yet, generate it first by running the `train_and_export_mlp.py` script.

```python
# Load pretrained weights
weights = torch.load("mlp_mnist_weights.pt")
weight_1 = ttnn.from_torch(weights["W1"], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
bias_1 = ttnn.from_torch(weights["b1"], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
weight_2 = ttnn.from_torch(weights["W2"], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
bias_2 = ttnn.from_torch(weights["b2"], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
weight_3 = ttnn.from_torch(weights["W3"], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
bias_3 = ttnn.from_torch(weights["b3"], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
logger.info("Loaded pretrained weights from mlp_mnist_weights.pt")
```
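Because `train_and_export_mlp.py` is not shown here, this tutorial does not state the hidden-layer sizes stored in `mlp_mnist_weights.pt`. A quick shape check before converting the tensors can catch a mismatch early. The sketch below assumes, consistently with the transpose step in the exercise, that each `W` is stored in PyTorch's `[out_features, in_features]` layout and each `b` as a 1-D `[out_features]` vector; the helper name and the hidden sizes 128 and 64 in the example are hypothetical.

```python
def check_mlp_shapes(shapes, in_features=784, num_classes=10, num_layers=3):
    """Verify that W1..Wn / b1..bn chain from in_features to num_classes.

    `shapes` maps names like "W1" / "b1" to shape tuples, with each W in
    [out_features, in_features] layout (an assumption about the export script).
    Returns the list of layer widths, e.g. [784, 128, 64, 10].
    """
    widths = [in_features]
    for i in range(1, num_layers + 1):
        w_out, w_in = shapes[f"W{i}"]
        if w_in != widths[-1]:
            raise ValueError(f"W{i} expects {w_in} inputs but previous layer gives {widths[-1]}")
        if tuple(shapes[f"b{i}"]) != (w_out,):
            raise ValueError(f"b{i} has shape {shapes[f'b{i}']}, expected ({w_out},)")
        widths.append(w_out)
    if widths[-1] != num_classes:
        raise ValueError(f"final layer has {widths[-1]} outputs, expected {num_classes}")
    return widths

# Hypothetical shapes; in the notebook: {k: tuple(v.shape) for k, v in weights.items()}
example = {"W1": (128, 784), "b1": (128,), "W2": (64, 128), "b2": (64,), "W3": (10, 64), "b3": (10,)}
print(check_mlp_shapes(example))   # [784, 128, 64, 10]
```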

## Running MNIST Inference

### Implementation Overview

Now we'll implement the inference pipeline for MNIST digit classification using TT-NN. This section demonstrates how to:

- Process image data through our MLP model on Tenstorrent hardware
- Track prediction accuracy
- Handle data format conversions between PyTorch and TT-NN

### Key Steps in the Inference Pipeline

1. **Data Preparation**
   - Flatten 28×28 images into 784-dimensional vectors
   - Convert PyTorch tensors to TT-NN format (bfloat16 precision, TILE_LAYOUT)

2. **Model Execution**
   - Pass data through three fully connected layers
   - Apply ReLU activation after the first two layers
   - Generate logits for 10 digit classes (0-9)

3. **Weight Handling**
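   - Transpose weight matrices from PyTorch's `[out_features, in_features]` layout to the `[in_features, out_features]` layout that `ttnn.linear` expects
   - Reshape bias vectors to `[1, out_features]` so they broadcast across the batch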

4. **Results Processing**
   - Convert outputs back to PyTorch tensors
   - Extract predictions using argmax
   - Compare with ground truth labels
   - Log results and calculate accuracy
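Two details of the conversion in step 1 are worth seeing concretely. TT-NN's `TILE_LAYOUT` stores data in 32×32 tiles, so the last two dimensions are padded up to multiples of 32 on the device, while TT-NN keeps track of the original (logical) shape. bfloat16 keeps float32's 8-bit exponent but only 8 bits of significand precision, so values are rounded to roughly 2 to 3 significant decimal digits. The pure-Python sketch below computes the padded shape and emulates float32-to-bfloat16 rounding (round-half-to-even on the upper 16 bits, ignoring NaN); it illustrates the formats and is not TT-NN's own conversion code.

```python
import math
import struct

TILE = 32  # TT-NN tiles are 32 x 32 elements

def tile_padded_shape(shape, tile=TILE):
    """Pad the last two dimensions of `shape` up to multiples of `tile`."""
    def pad(n):
        return math.ceil(n / tile) * tile
    *batch, rows, cols = shape
    return (*batch, pad(rows), pad(cols))

def to_bfloat16(value):
    """Round a float to the nearest bfloat16 value (round-half-to-even; NaN not handled)."""
    bits = struct.unpack("<I", struct.pack("<f", value))[0]   # float32 bit pattern
    lsb = (bits >> 16) & 1                                     # lowest bit that survives
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000                  # round, then drop low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

print(tile_padded_shape((1, 784)))   # (32, 800): one flattened image fills a row of 25 tiles
print(to_bfloat16(1.0))              # 1.0
print(to_bfloat16(math.pi))          # 3.140625
```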

### Exercise: Complete the Implementation

Based on the sine regression example, fill in the `TODO` sections to create a working MNIST classifier. Your implementation should:
- Process the first five test samples
- Display predicted vs. actual digit values
- Report the accuracy over those samples

Let's build the inference loop:

```python
correct = 0
total = 0

# Prepare weights and biases for TT-NN linear layers:
# transpose each weight from [out_features, in_features] to [in_features, out_features]
# and reshape each bias to [1, out_features] for broadcasting.
# TODO: define weight_1_final, bias_1_final, weight_2_final, bias_2_final,
#       weight_3_final and bias_3_final, following the sine example.
weight_1_final = None  # TODO: Add your code here

for i, (image, label) in enumerate(testloader):
    if i >= 5:
        break

    # Flatten the [1, 1, 28, 28] image into a [1, 784] row vector (float32)
    image = image.view(1, -1).to(torch.float32)

    # Convert PyTorch tensor to TT-NN tensor (bfloat16, TILE_LAYOUT, on device)
    image_tt = ttnn.from_torch(image, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

    # Layer 1: Linear + ReLU
    # TODO: Add your code here

    # Layer 2: Linear + ReLU
    # TODO: Add your code here

    # Layer 3: Linear (output logits, no activation)
    out3 = None  # TODO: Add your code here

    # Convert TT-NN output logits back to a PyTorch tensor
    prediction = ttnn.to_torch(out3)
    predicted_label = torch.argmax(prediction, dim=1).item()

    # Update accuracy counters
    correct += int(predicted_label == label.item())
    total += 1

    logger.info(f"Sample {i+1}: Predicted={predicted_label}, Actual={label.item()}")

logger.info(f"\nTT-NN MLP Inference Accuracy: {correct}/{total} = {100.0 * correct / total:.2f}%")
```
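Once the exercise runs, one way to check it is to compare the TT-NN logits against a float32 CPU reference computed from the same weights. The NumPy sketch below assumes the weight layout described above (`W` as `[out, in]`, `b` as `[out]`); the helper names, the hidden sizes, and the random placeholder weights are hypothetical. Because TT-NN computes in bfloat16, logits will not match exactly, so the check compares predicted classes rather than raw values. In the notebook, you could pass `image.numpy()` and `{k: v.float().numpy() for k, v in weights.items()}` as inputs.

```python
import numpy as np

def reference_logits(image_row, w):
    """float32 CPU forward pass: two Linear+ReLU layers, then a Linear output layer.

    Assumes w["W1"], w["W2"], w["W3"] are [out, in] and w["b1"], w["b2"], w["b3"] are [out].
    """
    h = np.maximum(image_row @ w["W1"].T + w["b1"], 0.0)
    h = np.maximum(h @ w["W2"].T + w["b2"], 0.0)
    return h @ w["W3"].T + w["b3"]

def same_prediction(tt_logits, ref_logits):
    """True if the TT-NN and reference logits agree on the predicted digit."""
    return int(np.argmax(tt_logits)) == int(np.argmax(ref_logits))

# Placeholder weights with hypothetical hidden sizes 128 and 64
rng = np.random.default_rng(0)
demo_weights = {
    "W1": rng.standard_normal((128, 784), dtype=np.float32) * 0.05, "b1": np.zeros(128, np.float32),
    "W2": rng.standard_normal((64, 128), dtype=np.float32) * 0.1, "b2": np.zeros(64, np.float32),
    "W3": rng.standard_normal((10, 64), dtype=np.float32) * 0.1, "b3": np.zeros(10, np.float32),
}
demo_image = rng.uniform(-1.0, 1.0, size=(1, 784)).astype(np.float32)
ref = reference_logits(demo_image, demo_weights)
print(ref.shape)                          # (1, 10)
print(same_prediction(ref + 1e-3, ref))   # True: a uniform shift does not change the argmax
```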


## Conclusion

🎉 **Congratulations!** You've successfully completed two fundamental neural network tasks using TT-NN:

- **Regression**: Approximating the sine function over `-2π` to `2π`
- **Classification**: Recognizing handwritten digits from the MNIST dataset

Through these examples, you've learned how to:
- Load pre-trained weights into TT-NN models
- Convert data between PyTorch and TT-NN formats
- Execute inference on Tenstorrent hardware
- Evaluate model performance

### Cleaning Up

Before finishing, it's important to properly close the Tenstorrent device to release hardware resources:

```python
ttnn.close_device(device)
```