# Optical Transformer Transform Pass

This tutorial provides minimal documentation for the Optical Neural Network (ONN) transform pass and layer classes in MASE.

The optical transformer implementation is based on the [Optical Transformers paper](https://arxiv.org/abs/2302.10360).

## Overview

The ONN transform pass replaces standard PyTorch modules with their optical transformer equivalents:

| Original Module | Optical Equivalent |
|-----------------|--------------------|
| `torch.nn.Linear` | `OtLinear` |
| `LlamaAttention` | `OtLlamaAttention` |

## Requirements

The `mase-triton` package is required for ONN transforms:

```bash
pip install mase-triton
```
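Before running the cells below, you can check that the dependency is importable. This sketch assumes the package's import name is `mase_triton`, following the usual hyphen-to-underscore mapping. `is_installed` is a hypothetical helper and is not part of MASE.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if `module_name` can be found by the import system."""
    return importlib.util.find_spec(module_name) is not None


# Assumption: `pip install mase-triton` provides the top-level module `mase_triton`.
if not is_installed("mase_triton"):
    print("mase-triton not found; run `pip install mase-triton` first.")
```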

```python
import torch
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaConfig

from chop.passes.module.transforms.onn.transform import (
    OtLinear,
    OtLlamaAttention,
    OtTransformConfig,
    optical_transformer_module_transform_pass,
)
```



## Configuration

Use `OtTransformConfig` to configure the optical transform parameters:

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `q_levels` | int | 256 | Number of quantization levels, $2^n$ for n-bit quantization. |
| `q_lut_min` | float | 0.020040 | Minimum LUT value for quantization |
| `q_smooth_factor` | float | 0.9 | Smoothing factor for statistics updates in training mode |
| `q_init_seed` | int | 0 | Random seed for initialization (used only by the Triton kernels) |
| `q_bypass` | bool | False | If True, bypass optical quantization |
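The following sketch illustrates what `q_levels` means using generic uniform fake quantization. A value is clamped to a `[x_min, x_max]` range and snapped to one of `q_levels` evenly spaced values. This is an illustrative assumption and not the mase-triton kernel: it ignores `q_lut_min` and `q_init_seed`, and the real kernel may compute ranges and rounding differently.

```python
import math


def fake_quantize(x: float, x_min: float, x_max: float, q_levels: int = 256) -> float:
    """Snap x onto one of q_levels evenly spaced values in [x_min, x_max]."""
    step = (x_max - x_min) / (q_levels - 1)
    x_clamped = min(max(x, x_min), x_max)
    code = round((x_clamped - x_min) / step)  # integer code in [0, q_levels - 1]
    return x_min + code * step


# q_levels = 2**n corresponds to n-bit codes.
bits = int(math.log2(256))
print(bits)  # 8
print(fake_quantize(0.3, -1.0, 1.0))  # nearest of the 256 levels to 0.3
print(fake_quantize(5.0, -1.0, 1.0))  # out-of-range values are clamped to x_max
```

Doubling `q_levels` halves the step size, which halves the worst-case rounding error of `step / 2` inside the range.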

```python
# Create default configuration
onn_config = OtTransformConfig.create_default()
print("Default ONN config:", onn_config)

# Customize configuration
onn_config["q_levels"] = 256  # 8-bit quantization
onn_config["q_smooth_factor"] = 0.1
print("Modified ONN config:", onn_config)
```

```text
Default ONN config: {'q_levels': 256, 'q_lut_min': 0.02004, 'q_smooth_factor': 0.9, 'q_init_seed': 0, 'q_bypass': False}
Modified ONN config: {'q_levels': 256, 'q_lut_min': 0.02004, 'q_smooth_factor': 0.1, 'q_init_seed': 0, 'q_bypass': False}
```
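The configuration table describes `q_smooth_factor` as the smoothing factor for statistics updates in training mode. One common form of such an update is an exponential moving average of running min/max statistics, sketched below. Two details are assumptions: which term the factor multiplies (here, the previous running value) and the helper name `update_min_max`. The `[inf, -inf]` starting value mirrors the `x_min_max=tensor([inf, -inf])` buffers printed for the transformed layers later in this tutorial.

```python
import math


def update_min_max(running, batch, smooth_factor):
    """Hypothetical EMA update of a [min, max] statistic.

    Assumes smooth_factor weights the previous running value, so 0.9 adapts
    slowly and 0.1 mostly follows the latest batch.
    """
    batch_min, batch_max = min(batch), max(batch)
    run_min, run_max = running
    if math.isinf(run_min) or math.isinf(run_max):
        # No history yet: adopt the first batch's statistics directly.
        return [batch_min, batch_max]
    return [
        smooth_factor * run_min + (1 - smooth_factor) * batch_min,
        smooth_factor * run_max + (1 - smooth_factor) * batch_max,
    ]


stats = update_min_max([math.inf, -math.inf], [-1.0, 2.0], 0.9)
print(stats)  # [-1.0, 2.0]
print(update_min_max(stats, [-3.0, 4.0], 0.9))  # close to [-1.2, 2.2]
print(update_min_max(stats, [-3.0, 4.0], 0.1))  # close to [-2.8, 3.8]
```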


## OtLinear: Optical Linear Layer

`OtLinear` is the optical equivalent of `torch.nn.Linear`. It performs a quantized matrix multiplication that simulates the limited precision of optical hardware, so its outputs differ slightly from those of the original layer.

```python
# Create a standard linear layer
linear = torch.nn.Linear(in_features=32, out_features=64)

# Convert to an optical linear layer
onn_config = OtTransformConfig.create_default()
linear_onn = OtLinear.from_linear(linear, **onn_config)

# Compare outputs
x = torch.randn(2, 32)
y = linear(x)
y_onn = linear_onn(x)

print(f"Original output shape: {y.shape}")
print(f"Optical output shape: {y_onn.shape}")
print(f"Max absolute difference: {(y - y_onn).abs().max().item():.6f}")
```

```text
Original output shape: torch.Size([2, 64])
Optical output shape: torch.Size([2, 64])
Max absolute difference: 0.035205
```
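The nonzero difference above comes from quantization. The pure-Python sketch below offers a rough mental model: it fake-quantizes the input, the weight, and the matmul output to `q_levels` levels using per-tensor min/max ranges. Quantizing at these three points is an assumption, suggested by the `x_min_max`, `w_min_max`, and `out_min_max` buffers in the transformed network's printed repr. The real layer also uses running statistics, quantiles, and `q_lut_min`, all of which this sketch omits. Every name here is hypothetical.

```python
import random


def quantize_list(values, q_levels):
    """Per-tensor uniform fake quantization using the list's own min/max."""
    lo, hi = min(values), max(values)
    if hi == lo:
        return list(values)
    step = (hi - lo) / (q_levels - 1)
    return [lo + round((v - lo) / step) * step for v in values]


def matvec(weight, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def quantized_linear(weight, x, q_levels=256):
    """Hypothetical optical-style linear layer (no bias)."""
    x_q = quantize_list(x, q_levels)
    n_in = len(x)
    flat_w_q = quantize_list([w for row in weight for w in row], q_levels)
    w_q = [flat_w_q[i:i + n_in] for i in range(0, len(flat_w_q), n_in)]
    return quantize_list(matvec(w_q, x_q), q_levels)


random.seed(0)
weight = [[random.uniform(-1, 1) for _ in range(8)] for _ in range(4)]
x = [random.uniform(-1, 1) for _ in range(8)]
y = matvec(weight, x)
y_q = quantized_linear(weight, x)
print(max(abs(a - b) for a, b in zip(y, y_q)))
```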


## OtLlamaAttention: Optical Llama Attention

`OtLlamaAttention` replaces the HuggingFace `LlamaAttention` with an optical-aware implementation that uses quantized scaled dot-product attention.

```python
# Set up the Llama configuration
model_name = "AICrossSim/clm-60m"
hf_config = LlamaConfig.from_pretrained(model_name)

batch_size = 1
seq_len = 16
head_dim = hf_config.hidden_size // hf_config.num_attention_heads

# Create a standard attention layer
attn = LlamaAttention(config=hf_config, layer_idx=0)

# Convert to optical attention
onn_config = OtTransformConfig.create_default()
onn_config["q_levels"] = 512
attn_onn = OtLlamaAttention.from_pretrained(attn, layer_idx=0, **onn_config)

# Dummy rotary embeddings: LlamaAttention takes a (cos, sin) tuple,
# each of shape (batch_size, seq_len, head_dim)
pos_emb = torch.ones(batch_size, seq_len, head_dim)
x = 3 * torch.randn(batch_size, seq_len, hf_config.hidden_size)

# Test the forward pass (no attention mask)
y, _ = attn(x, (pos_emb, pos_emb), None)
attn_onn.train()  # Enable statistics updates
y_onn, _ = attn_onn(x, (pos_emb, pos_emb), None)

print(f"Original output shape: {y.shape}")
print(f"Optical output shape: {y_onn.shape}")
```

```text
Original output shape: torch.Size([1, 16, 384])
Optical output shape: torch.Size([1, 16, 384])
```
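`OtLlamaAttention` is described as using quantized scaled dot-product attention. Below is a single-head, pure-Python sketch of that idea. It computes $\mathrm{softmax}(QK^\top/\sqrt{d})V$ and, when `q_levels` is given, fake-quantizes Q, K, V and the attention probabilities per tensor. Where the real kernel applies quantization is an assumption here. Rotary embeddings, masking, and multi-head projections are all omitted.

```python
import math
import random


def quantize_matrix(m, q_levels):
    """Per-tensor uniform fake quantization of a list-of-lists matrix."""
    flat = [v for row in m for v in row]
    lo, hi = min(flat), max(flat)
    if hi == lo:
        return [row[:] for row in m]
    step = (hi - lo) / (q_levels - 1)
    return [[lo + round((v - lo) / step) * step for v in row] for row in m]


def transpose(m):
    return [list(col) for col in zip(*m)]


def matmul(a, b):
    """(n, k) @ (k, m) for list-of-lists matrices."""
    b_cols = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in b_cols] for row in a]


def softmax_rows(m):
    out = []
    for row in m:
        peak = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(v - peak) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def sdpa_sketch(q, k, v, q_levels=None):
    """softmax(q k^T / sqrt(d)) v; fake-quantizes operands if q_levels is set."""
    if q_levels is not None:
        q, k, v = (quantize_matrix(t, q_levels) for t in (q, k, v))
    scale = 1.0 / math.sqrt(len(q[0]))
    scores = [[s * scale for s in row] for row in matmul(q, transpose(k))]
    probs = softmax_rows(scores)
    if q_levels is not None:
        probs = quantize_matrix(probs, q_levels)
    return matmul(probs, v)


random.seed(0)
seq_len, head_dim = 4, 8
q, k, v = ([[random.uniform(-1, 1) for _ in range(head_dim)] for _ in range(seq_len)] for _ in range(3))
ref = sdpa_sketch(q, k, v)
approx = sdpa_sketch(q, k, v, q_levels=512)
print(max(abs(a - b) for ra, rb in zip(ref, approx) for a, b in zip(ra, rb)))
```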


## Transform Pass: Network-Level Transformation

Use `optical_transformer_module_transform_pass` to transform an entire network. The pass replaces each module whose name matches a key in `pass_args` with its optical equivalent, configured by the dict stored under that key.

### Pass Arguments

| Key | Description |
|-----|-------------|
| `by` | Matching mode: `"name"` (exact) or `"regex_name"` (regex pattern) |
| `<layer_name>` | Configuration dict for layers matching the name/pattern |
| `default` | Fallback configuration if no pattern matches |
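The lookup described by this table can be sketched as follows. This is an assumed reconstruction, not the pass's source. It treats keys as exact names in `"name"` mode and as `re.fullmatch` patterns in `"regex_name"` mode. It returns the config for the first matching key in insertion order and otherwise falls back to `default`, or `None` if there is no default. The real pass may use a different regex function or precedence. This matters because with `re.match` or `re.search`, a pattern such as `attn` would also match `attn.q_proj`.

```python
import re


def find_layer_config(name, pass_args):
    """Hypothetical config lookup for one module name."""
    mode = pass_args["by"]
    for key, config in pass_args.items():
        if key in ("by", "default"):
            continue
        if mode == "name" and name == key:
            return config
        if mode == "regex_name" and re.fullmatch(key, name):
            return config
    return pass_args.get("default")


pass_args = {
    "by": "regex_name",
    "attn": "attention config",
    "linear": "linear config",
    r"attn\.(q|k|v|o)_proj": "projection config",
}
for module_name in ["attn", "attn.q_proj", "linear", "lm_head"]:
    print(module_name, "->", find_layer_config(module_name, pass_args))
```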

```python
# Define a simple network with attention and linear layers
class SimpleNetwork(torch.nn.Module):
    def __init__(self, hf_config):
        super().__init__()
        self.attn = LlamaAttention(config=hf_config, layer_idx=0)
        self.linear = torch.nn.Linear(
            in_features=hf_config.hidden_size,
            out_features=hf_config.hidden_size,
        )

    def forward(self, x, pos_emb):
        attn_output, _ = self.attn(x, (pos_emb, pos_emb), None)
        output = self.linear(attn_output)
        return output


network = SimpleNetwork(hf_config)
print("Original network:")
print(network)
```

```text
Original network:
SimpleNetwork(
  (attn): LlamaAttention(
    (q_proj): Linear(in_features=384, out_features=384, bias=False)
    (k_proj): Linear(in_features=384, out_features=128, bias=False)
    (v_proj): Linear(in_features=384, out_features=128, bias=False)
    (o_proj): Linear(in_features=384, out_features=384, bias=False)
  )
  (linear): Linear(in_features=384, out_features=384, bias=True)
)
```


```python
# Configure the transform pass with regex patterns
onn_config = OtTransformConfig.create_default()
onn_config["q_levels"] = 512

pass_args = {
    "by": "regex_name",  # Use regex matching
    "attn": onn_config,  # Transform the attention layer
    "linear": onn_config,  # Transform the linear layer
    r"attn\.(q|k|v|o)_proj": onn_config,  # Transform Q/K/V/O projections inside attention
}

# Apply the transform
network_onn = optical_transformer_module_transform_pass(network, pass_args)

print("\nTransformed network:")
print(network_onn)
```




```text
Transformed network:
SimpleNetwork(
  (attn): OtLlamaAttention(
    (q_proj): OpticalTransformerLinear(q_bypass=False, q_levels=512, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)
    (k_proj): OpticalTransformerLinear(q_bypass=False, q_levels=512, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)
    (v_proj): OpticalTransformerLinear(q_bypass=False, q_levels=512, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)
    (o_proj): OpticalTransformerLinear(q_bypass=False, q_levels=512, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tenso
```
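Each matched module must then be swapped into its parent. The sketch below shows one way to do this, assuming replacement by dotted name with `getattr`/`setattr`. This also works on `torch.nn.Module`, because submodules are attributes. `replace_submodule` is a hypothetical helper, and `SimpleNamespace` stands in for modules so the example runs without PyTorch. Because the parent is resolved at replacement time, `attn.q_proj` is looked up on whatever object currently sits at `attn`, including a freshly substituted attention module.

```python
from types import SimpleNamespace


def replace_submodule(root, dotted_name, new_module):
    """Replace the attribute at `dotted_name` (e.g. "attn.q_proj") under root."""
    parent_path, _, attr = dotted_name.rpartition(".")
    parent = root
    parts = parent_path.split(".") if parent_path else []
    for part in parts:
        parent = getattr(parent, part)
    setattr(parent, attr, new_module)


toy_net = SimpleNamespace(attn=SimpleNamespace(q_proj="Linear"), linear="Linear")
# Replace the parent first (keeping its child), then the child inside the new parent.
replace_submodule(toy_net, "attn", SimpleNamespace(q_proj=toy_net.attn.q_proj))
replace_submodule(toy_net, "attn.q_proj", "OtLinear")
replace_submodule(toy_net, "linear", "OtLinear")
print(toy_net)
```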

```python
# Verify the transformation
print("Verification:")
print(f"  attn is OtLlamaAttention: {isinstance(network_onn.attn, OtLlamaAttention)}")
print(f"  linear is OtLinear: {isinstance(network_onn.linear, OtLinear)}")
print(f"  attn.q_proj is OtLinear: {isinstance(network_onn.attn.q_proj, OtLinear)}")
print(f"  attn.k_proj is OtLinear: {isinstance(network_onn.attn.k_proj, OtLinear)}")
print(f"  attn.v_proj is OtLinear: {isinstance(network_onn.attn.v_proj, OtLinear)}")
print(f"  attn.o_proj is OtLinear: {isinstance(network_onn.attn.o_proj, OtLinear)}")
```

```text
Verification:
  attn is OtLlamaAttention: True
  linear is OtLinear: True
  attn.q_proj is OtLinear: True
  attn.k_proj is OtLinear: True
  attn.v_proj is OtLinear: True
  attn.o_proj is OtLinear: True
```
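The six checks above can be written as one table-driven helper. This is a sketch: `get_by_path` and `check_types` are hypothetical names, and the toy classes stand in for `OtLinear` and `OtLlamaAttention` so the example runs on its own.

```python
from functools import reduce
from types import SimpleNamespace


def get_by_path(root, dotted_name):
    """Follow a dotted attribute path such as "attn.q_proj"."""
    return reduce(getattr, dotted_name.split("."), root)


def check_types(root, expected):
    """Map each dotted name to whether its object is an instance of the expected class."""
    return {name: isinstance(get_by_path(root, name), cls) for name, cls in expected.items()}


class ToyOpticalLinear:  # stand-in for OtLinear
    pass


class ToyOpticalAttention:  # stand-in for OtLlamaAttention
    def __init__(self):
        self.q_proj = ToyOpticalLinear()


toy_net = SimpleNamespace(attn=ToyOpticalAttention(), linear=object())
print(check_types(toy_net, {
    "attn": ToyOpticalAttention,
    "attn.q_proj": ToyOpticalLinear,
    "linear": ToyOpticalLinear,
}))  # {'attn': True, 'attn.q_proj': True, 'linear': False}
```

In this tutorial, the equivalent call would be `check_types(network_onn, {"attn": OtLlamaAttention, "linear": OtLinear, "attn.q_proj": OtLinear})`, extended with the other three projections.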


```python
# Test the transformed network
network_onn.train()  # Enable statistics updates

pos_emb = torch.ones(batch_size, seq_len, head_dim)
x = 3 * torch.randn(batch_size, seq_len, hf_config.hidden_size)

y = network(x, pos_emb)
y_onn = network_onn(x, pos_emb)
print(f"Output shape: {y_onn.shape}")
print(f"Max output error: {(y - y_onn).abs().max().item():.6f}")
print(f"Output is finite: {y_onn.isfinite().all().item()}")
```

```text
Output shape: torch.Size([1, 16, 384])
Max output error: 0.029137
Output is finite: True
```
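The max absolute error scales with the output magnitude (the input here is multiplied by 3). A scale-free relative error can therefore be easier to compare across layers or `q_levels` settings. The sketch below computes it for flat lists in pure Python; with tensors, the same quantity is `(y - y_onn).norm() / y.norm()`. This is one possible metric, chosen here for illustration.

```python
import math


def max_abs_error(ref, approx):
    return max(abs(r - a) for r, a in zip(ref, approx))


def relative_l2_error(ref, approx):
    """||ref - approx||_2 / ||ref||_2; returns 0.0 when both are all zeros."""
    diff = math.sqrt(sum((r - a) ** 2 for r, a in zip(ref, approx)))
    scale = math.sqrt(sum(r * r for r in ref))
    if scale == 0.0:
        return 0.0 if diff == 0.0 else math.inf
    return diff / scale


ref = [3.0, 4.0]
approx = [3.0, 4.5]
print(max_abs_error(ref, approx))  # 0.5
print(relative_l2_error(ref, approx))  # 0.1
```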
