# Export Hope Protection Algorithms to ONNX

This notebook converts the pre-trained JAX models from the previous notebooks into ONNX format for deployment in the Hope Tauri application.

## Data Flow

```
2_noise_algorithm.ipynb     -> noise_model_export.pkl     -> noise_algorithm.onnx
3_glaze_algorithm.ipynb     -> glaze_model_export.pkl     -> glaze_algorithm.onnx
4_nightshade_algorithm.ipynb -> nightshade_model_export.pkl -> nightshade_algorithm.onnx
```

## Input Files

The following files must exist in Google Drive before running this notebook:

```
/content/drive/MyDrive/hope-models/exports/
|-- noise_model_export.pkl
|-- glaze_model_export.pkl
|-- nightshade_model_export.pkl
```

## Output Files

```
/content/drive/MyDrive/hope-models/onnx/
|-- noise_algorithm.onnx
|-- glaze_algorithm.onnx
|-- nightshade_algorithm.onnx
|-- hope_config.json
```

## Prerequisites

Run these notebooks in order before this one:

1. `0_setup_colab.ipynb` - Environment setup
2. `1_clip_to_jax.ipynb` - Extract CLIP weights and embeddings
3. `2_noise_algorithm.ipynb` - Train and export noise model
4. `3_glaze_algorithm.ipynb` - Train and export glaze model
5. `4_nightshade_algorithm.ipynb` - Train and export nightshade model

Each algorithm notebook saves an export file containing:

- ViT weights (PyTorch format)
- Pre-computed embeddings
- Algorithm parameters
- Architecture configuration

## 1. Install Dependencies

Install JAX with CUDA 12 support for GPU acceleration, Flax for neural network definitions, jax2onnx for ONNX conversion, and onnxruntime-gpu for validation on the T4 GPU.

In [None]:
# Quote specifiers containing ">" or "[": unquoted, the shell treats ">" as a redirect
%pip install -q --upgrade "jax[cuda12]" jaxlib

%pip install -q "flax>=0.12.2"

%pip install -q "jax2onnx>=0.11.2" onnx onnxsim

%pip install -q onnxruntime-gpu

## 2. Mount Google Drive

Connect to Google Drive to access the exported model files and save ONNX outputs.

In [None]:
import os
from google.colab import drive

drive.mount('/content/drive')

EXPORT_DIR = '/content/drive/MyDrive/hope-models/exports'
RAW_ONNX_DIR = os.path.join(EXPORT_DIR, 'onnx-raw')
ONNX_DIR = '/content/drive/MyDrive/hope-models/onnx'

os.makedirs(RAW_ONNX_DIR, exist_ok=True)
os.makedirs(ONNX_DIR, exist_ok=True)

print(f"Export directory: {EXPORT_DIR}")
print(f"Raw ONNX directory: {RAW_ONNX_DIR}")
print(f"ONNX output directory: {ONNX_DIR}")

## 3. Import Libraries

Import JAX for array operations, Flax Linen for model architecture, and ONNX tools for export and validation.

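Note: `jax_enable_x64` is enabled to allow the float64 operations required by jax2onnx during ONNX conversion. The float64 truncation warnings emitted by jax2onnx's internal plugin validation are harmless and are suppressed with `warnings.filterwarnings`. Because x64 mode makes JAX default to float64 (for example, `jnp.zeros(shape)` returns a float64 array), inputs that must be float32 should pass `dtype=jnp.float32` explicitly.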

In [None]:
import jax
jax.config.update("jax_enable_x64", True)

import warnings
warnings.filterwarnings("ignore", message="Explicitly requested dtype float64", category=UserWarning)

import jax.numpy as jnp

import numpy as np
import pickle
import json

from flax import linen as nn
from typing import Dict, Any, Tuple, Callable

from jax2onnx import to_onnx
import onnx
import onnxruntime as ort

print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
print(f"Backend: {jax.default_backend()}")

## 4. Verify GPU Availability

Confirm that both JAX and ONNX Runtime have access to the T4 GPU for accelerated computation.

In [None]:
print("JAX Devices:")
for device in jax.devices():
    print(f"  {device}")

print(f"\nJAX Backend: {jax.default_backend()}")

print("\nONNX Runtime Providers:")
available_providers = ort.get_available_providers()
for provider in available_providers:
    print(f"  {provider}")

if 'CUDAExecutionProvider' in available_providers:
    print("\nGPU acceleration: ENABLED")
else:
    print("\nGPU acceleration: NOT AVAILABLE")

## 5. Verify Export Files

Check that all required export files from previous notebooks exist before proceeding.

In [None]:
required_files = [
    'noise_model_export.pkl',
    'glaze_model_export.pkl',
    'nightshade_model_export.pkl'
]

print("Checking export files:")
all_exist = True

for filename in required_files:
    filepath = os.path.join(EXPORT_DIR, filename)
    exists = os.path.exists(filepath)
    status = "FOUND" if exists else "MISSING"

    if exists:
        size = os.path.getsize(filepath) / (1024 ** 2)
        print(f"  {filename}: {status} ({size:.2f} MB)")
    else:
        print(f"  {filename}: {status}")
        all_exist = False

if not all_exist:
    raise FileNotFoundError("Missing export files. Run previous notebooks first.")

print("\nAll export files found.")

## 6. Define Vision Transformer Architecture

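The Vision Transformer (ViT-B/32) architecture must match the one used in the training notebooks exactly. It is implemented with Flax Linen, and the submodule names (`conv1`, `ln_pre`, `transformer_resblocks_{i}`, `attn`, `mlp`, `ln_post`, `proj`) mirror the PyTorch CLIP state-dict keys so that the weight conversion in Section 7 can map parameters by name.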

### Architecture Parameters

| Parameter | Value | Description |
|-----------|-------|-------------|
| hidden_dim | 768 | Transformer hidden dimension |
| num_layers | 12 | Number of transformer blocks |
| num_heads | 12 | Number of attention heads |
| patch_size | 32 | Size of image patches |
| output_dim | 512 | Final embedding dimension |

### Attention Mechanism

The scaled dot-product attention is computed as:

$$\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}}\right)\mathbf{V}$$

Where $d_k = \frac{768}{12} = 64$ is the dimension per attention head.

### 6.1 Multi-Head Attention

The multi-head attention splits the input into multiple heads, applies attention independently, and concatenates the results.

In [None]:
class MultiHeadAttention(nn.Module):
    num_heads: int

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        batch_size, seq_len, d_model = x.shape
        head_dim = d_model // self.num_heads

        qkv = nn.Dense(3 * d_model, name="in_proj")(x)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, head_dim)
        qkv = jnp.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]

        scale = head_dim ** -0.5
        attn = jnp.einsum('bhqd,bhkd->bhqk', q, k) * scale
        attn = jax.nn.softmax(attn, axis=-1)

        out = jnp.einsum('bhqk,bhkd->bhqd', attn, v)
        out = jnp.transpose(out, (0, 2, 1, 3))
        out = out.reshape(batch_size, seq_len, d_model)

        return nn.Dense(d_model, name="out_proj")(out)
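
The single `reshape(batch, seq, 3, heads, head_dim)` plus `transpose` above only works if it splits the packed `in_proj` output the same way PyTorch's `nn.MultiheadAttention` does (chunk into Q | K | V first, then cut each chunk into contiguous heads). A small NumPy sketch, with function names that are illustrative only, checks that the two orderings agree and that the softmax rows of the attention formula sum to 1:

```python
import numpy as np


def split_heads_flax_style(qkv: np.ndarray, num_heads: int):
    """Mirror of MultiHeadAttention: reshape the packed projection, then transpose."""
    b, s, three_d = qkv.shape
    head_dim = three_d // 3 // num_heads
    x = qkv.reshape(b, s, 3, num_heads, head_dim)
    x = np.transpose(x, (2, 0, 3, 1, 4))  # (3, batch, heads, seq, head_dim)
    return x[0], x[1], x[2]


def split_heads_torch_style(qkv: np.ndarray, num_heads: int):
    """PyTorch-style: chunk into Q | K | V first, then split each chunk into heads."""
    b, s, three_d = qkv.shape
    head_dim = three_d // 3 // num_heads

    def to_heads(t):
        return t.reshape(b, s, num_heads, head_dim).transpose(0, 2, 1, 3)

    q, k, v = np.split(qkv, 3, axis=-1)
    return to_heads(q), to_heads(k), to_heads(v)


def scaled_dot_product_attention(q, k, v):
    logits = np.einsum('bhqd,bhkd->bhqk', q, k) * q.shape[-1] ** -0.5
    logits -= logits.max(axis=-1, keepdims=True)  # stabilize exp
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return np.einsum('bhqk,bhkd->bhqd', weights, v), weights


rng = np.random.default_rng(0)
qkv = rng.standard_normal((1, 50, 3 * 768)).astype(np.float32)  # (batch, seq, 3 * d_model)
flax_heads = split_heads_flax_style(qkv, num_heads=12)
torch_heads = split_heads_torch_style(qkv, num_heads=12)
out, weights = scaled_dot_product_attention(*flax_heads)
print(out.shape)
```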

### 6.2 Feed-Forward Network

The feed-forward network expands the dimension by 4x, applies GELU activation, then projects back.

$$\text{FFN}(\mathbf{x}) = \text{GELU}(\mathbf{x}\mathbf{W}_1 + \mathbf{b}_1)\mathbf{W}_2 + \mathbf{b}_2$$

In [None]:
class FeedForward(nn.Module):
    hidden_dim: int

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        x = nn.Dense(self.hidden_dim * 4, name="c_fc")(x)
        x = nn.gelu(x)
        x = nn.Dense(self.hidden_dim, name="c_proj")(x)
        return x
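
`nn.gelu` here is `jax.nn.gelu`, which defaults to the tanh approximation. OpenAI's original PyTorch CLIP code uses QuickGELU, `x * sigmoid(1.702 * x)`, in its residual blocks. This notebook does not state which activation the training notebooks used, so, since the architecture must match exactly, this NumPy sketch just measures how far the three common definitions drift apart:

```python
import math

import numpy as np


def gelu_exact(x):
    # 0.5 * x * (1 + erf(x / sqrt(2)))
    erf = np.vectorize(math.erf)
    return 0.5 * x * (1.0 + erf(x / math.sqrt(2.0)))


def gelu_tanh(x):
    # Tanh approximation (the jax.nn.gelu default)
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + np.tanh(c * (x + 0.044715 * x ** 3)))


def quick_gelu(x):
    # x * sigmoid(1.702 * x), as in OpenAI's CLIP implementation
    return x / (1.0 + np.exp(-1.702 * x))


xs = np.linspace(-6.0, 6.0, 2001)
tanh_err = np.max(np.abs(gelu_tanh(xs) - gelu_exact(xs)))
quick_err = np.max(np.abs(quick_gelu(xs) - gelu_exact(xs)))
print(f"tanh approx max |err|: {tanh_err:.1e}")
print(f"QuickGELU   max |err|: {quick_err:.1e}")
```

The tanh form stays very close to exact GELU, while QuickGELU differs by a visibly larger margin, so it is worth confirming that `FeedForward` uses the same activation as the training notebooks.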

### 6.3 Transformer Block

Each transformer block uses pre-norm residual connections: layer normalization is applied to the input of the attention and feed-forward sublayers, and each sublayer's output is added back to its input.

$$\mathbf{x}' = \mathbf{x} + \text{Attention}(\text{LN}(\mathbf{x}))$$

$$\mathbf{x}'' = \mathbf{x}' + \text{FFN}(\text{LN}(\mathbf{x}'))$$

In [None]:
class TransformerBlock(nn.Module):
    num_heads: int
    hidden_dim: int

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        x = x + MultiHeadAttention(self.num_heads, name="attn")(nn.LayerNorm(name="ln_1")(x))
        x = x + FeedForward(self.hidden_dim, name="mlp")(nn.LayerNorm(name="ln_2")(x))
        return x

### 6.4 Vision Transformer

The complete ViT model processes images through patch embedding, transformer blocks, and final projection.

For a 224x224 input with patch size 32:

$$n_{patches} = \left\lfloor\frac{224}{32}\right\rfloor^2 = 7^2 = 49$$

The sequence length is $n_{patches} + 1 = 50$ (including the prepended class token).
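
The patch count can be checked by cutting the image into patches directly. A NumPy sketch (the `patchify` helper is hypothetical, not part of the notebook) that also shows why the stride-32 `VALID` convolution in `conv1` is equivalent to a matrix multiply on row-major flattened patches:

```python
import numpy as np


def patchify(images: np.ndarray, patch_size: int) -> np.ndarray:
    """(batch, H, W, C) -> (batch, num_patches, patch_size*patch_size*C), row-major patch order."""
    b, h, w, c = images.shape
    gh, gw = h // patch_size, w // patch_size  # floor division, as in the formula
    x = images[:, : gh * patch_size, : gw * patch_size, :]
    x = x.reshape(b, gh, patch_size, gw, patch_size, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (batch, gh, gw, p, p, C)
    return x.reshape(b, gh * gw, patch_size * patch_size * c)


rng = np.random.default_rng(0)
image = rng.random((1, 224, 224, 3))
patches = patchify(image, 32)
seq_len = patches.shape[1] + 1  # +1 for the prepended class token
print(patches.shape, seq_len)  # (1, 49, 3072) 50

# A stride-32 VALID conv is a matmul on these patches with the Flax kernel
# (kh, kw, in, out) flattened to (kh * kw * in, out).
kernel = rng.standard_normal((32, 32, 3, 8))
via_patches = patches @ kernel.reshape(-1, 8)  # (1, 49, 8)
direct = np.einsum('hwc,hwco->o', image[0, 32:64, 64:96, :], kernel)  # patch row 1, col 2
```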

### Forward Pass Pipeline

$$\mathbf{x} \in \mathbb{R}^{1 \times 224 \times 224 \times 3} \xrightarrow{\text{Conv}_{32 \times 32}} \mathbb{R}^{1 \times 49 \times 768} \xrightarrow{[\texttt{CLS}; \cdot] + \mathbf{E}_{pos}} \mathbb{R}^{1 \times 50 \times 768} \xrightarrow{\text{Transformer} \times 12} \mathbb{R}^{1 \times 50 \times 768} \xrightarrow{\texttt{CLS} \to \text{proj}} \mathbb{R}^{1 \times 512}$$

In [None]:
class VisionTransformer(nn.Module):
    hidden_dim: int = 768
    num_layers: int = 12
    num_heads: int = 12
    patch_size: int = 32
    output_dim: int = 512

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        batch_size = x.shape[0]

        x = nn.Conv(
            self.hidden_dim,
            kernel_size=(self.patch_size, self.patch_size),
            strides=(self.patch_size, self.patch_size),
            padding='VALID',
            use_bias=False,
            name="conv1"
        )(x)

        x = x.reshape(batch_size, -1, self.hidden_dim)
        num_patches = x.shape[1]

        class_embedding = self.param(
            'class_embedding',
            nn.initializers.zeros,
            (self.hidden_dim,)
        )
        class_tokens = jnp.broadcast_to(
            class_embedding[None, None, :],
            (batch_size, 1, self.hidden_dim)
        )
        x = jnp.concatenate([class_tokens, x], axis=1)

        positional_embedding = self.param(
            'positional_embedding',
            nn.initializers.zeros,
            (num_patches + 1, self.hidden_dim)
        )
        x = x + positional_embedding[None, :, :]

        x = nn.LayerNorm(name="ln_pre")(x)

        for i in range(self.num_layers):
            x = TransformerBlock(
                self.num_heads,
                self.hidden_dim,
                name=f"transformer_resblocks_{i}"
            )(x)

        x = nn.LayerNorm(name="ln_post")(x[:, 0, :])
        x = nn.Dense(self.output_dim, use_bias=False, name="proj")(x)

        return x

## 7. Weight Conversion Function

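Convert PyTorch CLIP weights to the Flax parameter format. The export files store weights in PyTorch layout: `nn.Linear` weights are `(out_features, in_features)` and `nn.Conv2d` weights are `(out, in, kh, kw)`, whereas Flax `Dense` kernels are `(in, out)` and `Conv` kernels are `(kh, kw, in, out)`.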

### Transformation Rules

| PyTorch Format | Flax Format | Transformation |
|----------------|-------------|----------------|
| Linear weight | kernel | Transpose |
| Conv weight | kernel | Permute (2,3,1,0) |
| LayerNorm weight | scale | Direct copy |
| bias | bias | Direct copy |
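
The two non-trivial rows of the table can be checked numerically. This NumPy sketch uses small random tensors (shapes chosen only for illustration) to confirm that the transposed Linear weight and the `(2, 3, 1, 0)` conv permutation reproduce the PyTorch outputs, and that a plain `.T` on a 4-D conv weight does not, because it swaps `kh` and `kw`:

```python
import numpy as np

rng = np.random.default_rng(0)

# Linear: PyTorch weight is (out, in) and computes x @ W.T + b
torch_weight = rng.standard_normal((4, 6))
bias = rng.standard_normal(4)
x = rng.standard_normal((2, 6))
torch_linear = x @ torch_weight.T + bias
flax_kernel = torch_weight.T  # Flax Dense kernel: (in, out)
flax_linear = x @ flax_kernel + bias

# Conv: PyTorch weight is (out, in, kh, kw); Flax Conv kernel is (kh, kw, in, out)
torch_conv = rng.standard_normal((5, 3, 2, 2))
flax_conv = np.transpose(torch_conv, (2, 3, 1, 0))
patch_chw = rng.standard_normal((3, 2, 2))      # one input patch, (in, kh, kw)
patch_hwc = np.transpose(patch_chw, (1, 2, 0))  # same patch, (kh, kw, in)
torch_conv_out = np.einsum('oikl,ikl->o', torch_conv, patch_chw)
flax_conv_out = np.einsum('klio,kli->o', flax_conv, patch_hwc)

# A plain .T reverses every axis -> (kw, kh, in, out): kh and kw end up swapped
wrong_conv = torch_conv.T
wrong_conv_out = np.einsum('klio,kli->o', wrong_conv, patch_hwc)
```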

In [None]:
def flatten_key(key: str) -> str:
    # PyTorch "a.b.c" -> Flax path "a/b/c"; the resblock index is folded into the module name
    key = key.replace('transformer.resblocks.', 'transformer_resblocks_')
    return key.replace('.', '/')


def transform_param(key: str, value: np.ndarray) -> Tuple[str, jnp.ndarray]:
    # Strip the "visual." prefix first so the special cases below match whether or
    # not the export kept the full CLIP state-dict names. (Previously the prefix was
    # only stripped afterwards, so "visual.conv1.weight" fell through to a plain .T,
    # which swaps kh and kw, and "visual.proj" was never renamed to "proj/kernel".)
    if key.startswith('visual.'):
        key = key[len('visual.'):]

    # Store every parameter as float32 so the exported graph uses a single float dtype
    value = jnp.asarray(value, dtype=jnp.float32)

    if key in ('class_embedding', 'positional_embedding'):
        return key, value

    if key == 'proj':
        # CLIP applies x @ proj with proj of shape (768, 512); transpose if stored the other way
        return 'proj/kernel', (value.T if value.shape[0] == 512 else value)

    if key == 'conv1.weight':
        # (out, in, kh, kw) -> (kh, kw, in, out)
        return 'conv1/kernel', jnp.transpose(value, (2, 3, 1, 0))

    key = flatten_key(key)

    # Packed attention projection: (3*d, d) weight -> (d, 3*d) kernel
    if key.endswith('in_proj_weight'):
        return key.replace('in_proj_weight', 'in_proj/kernel'), value.T
    if key.endswith('in_proj_bias'):
        return key.replace('in_proj_bias', 'in_proj/bias'), value

    # LayerNorms (ln_1, ln_2, ln_pre, ln_post): weight -> scale, no transpose
    if '/ln_' in f'/{key}' and key.endswith('/weight'):
        return key[:-len('weight')] + 'scale', value

    # Remaining Linear layers (out_proj, c_fc, c_proj): weight -> transposed kernel
    if key.endswith('/weight'):
        return key[:-len('weight')] + 'kernel', value.T

    # Biases map one-to-one
    return key, value


def nest_params(flat_params: Dict[str, jnp.ndarray]) -> Dict[str, Any]:
    nested = {}
    for key, value in flat_params.items():
        parts = key.split('/')
        current = nested
        for part in parts[:-1]:
            if part not in current:
                current[part] = {}
            current = current[part]
        current[parts[-1]] = value
    return nested


def convert_pytorch_weights(pytorch_weights: Dict[str, np.ndarray]) -> Dict[str, Any]:
    flat_params = {}

    for key, value in pytorch_weights.items():
        new_key, new_value = transform_param(key, value)
        flat_params[new_key] = new_value

    nested = nest_params(flat_params)
    return {'params': nested}

### 7.1 Debug Weight Conversion

Verify that the weight conversion produces the expected parameter structure.

In [None]:
def debug_weight_structure(params: Dict[str, Any], prefix: str = "") -> None:
    for key, value in params.items():
        full_key = f"{prefix}/{key}" if prefix else key
        if isinstance(value, dict):
            debug_weight_structure(value, full_key)
        else:
            print(f"  {full_key}: {value.shape}")


def verify_model_params(weights: Dict[str, Any]) -> bool:
    model = VisionTransformer()
    # Explicit float32: with jax_enable_x64 on, jnp.zeros would default to float64
    dummy_input = jnp.zeros((1, 224, 224, 3), dtype=jnp.float32)

    try:
        output = model.apply(weights, dummy_input)
        print("Model forward pass successful")
        print(f"Output shape: {output.shape}")
        return True
    except Exception as e:
        print(f"Model forward pass failed: {e}")
        return False
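
`debug_weight_structure` prints shapes, but nothing compares them against what `VisionTransformer` expects. A sketch of that comparison, assuming the ViT-B/32 defaults from Section 6 and Flax's standard parameter names (`kernel`/`bias` for `Dense` and `Conv`, `scale`/`bias` for `LayerNorm`); the helper names are hypothetical:

```python
from typing import Dict, Tuple

import numpy as np


def expected_vit_shapes(hidden: int = 768, layers: int = 12, patch: int = 32,
                        channels: int = 3, output_dim: int = 512,
                        num_patches: int = 49) -> Dict[str, Tuple[int, ...]]:
    """Flat 'a/b/c' paths -> shapes that VisionTransformer's Flax params should have."""
    shapes = {
        'conv1/kernel': (patch, patch, channels, hidden),
        'class_embedding': (hidden,),
        'positional_embedding': (num_patches + 1, hidden),
        'ln_pre/scale': (hidden,), 'ln_pre/bias': (hidden,),
        'ln_post/scale': (hidden,), 'ln_post/bias': (hidden,),
        'proj/kernel': (hidden, output_dim),
    }
    for i in range(layers):
        p = f'transformer_resblocks_{i}'
        shapes.update({
            f'{p}/attn/in_proj/kernel': (hidden, 3 * hidden),
            f'{p}/attn/in_proj/bias': (3 * hidden,),
            f'{p}/attn/out_proj/kernel': (hidden, hidden),
            f'{p}/attn/out_proj/bias': (hidden,),
            f'{p}/ln_1/scale': (hidden,), f'{p}/ln_1/bias': (hidden,),
            f'{p}/ln_2/scale': (hidden,), f'{p}/ln_2/bias': (hidden,),
            f'{p}/mlp/c_fc/kernel': (hidden, 4 * hidden),
            f'{p}/mlp/c_fc/bias': (4 * hidden,),
            f'{p}/mlp/c_proj/kernel': (4 * hidden, hidden),
            f'{p}/mlp/c_proj/bias': (hidden,),
        })
    return shapes


def flat_shapes(tree: dict, prefix: str = '') -> Dict[str, Tuple[int, ...]]:
    out = {}
    for key, value in tree.items():
        path = f'{prefix}/{key}' if prefix else key
        if isinstance(value, dict):
            out.update(flat_shapes(value, path))
        else:
            out[path] = tuple(value.shape)
    return out


def compare_shapes(params: dict, expected: Dict[str, Tuple[int, ...]]):
    """Return (missing, unexpected, mismatched) parameter paths, each sorted."""
    actual = flat_shapes(params)
    missing = sorted(set(expected) - set(actual))
    unexpected = sorted(set(actual) - set(expected))
    mismatched = sorted(k for k in set(expected) & set(actual) if expected[k] != actual[k])
    return missing, unexpected, mismatched


expected = expected_vit_shapes()
total_params = sum(int(np.prod(shape)) for shape in expected.values())
print(f"{len(expected)} tensors, {total_params:,} parameters")
```

In the notebook, `compare_shapes(noise_weights['params'], expected_vit_shapes())` should return three empty lists. The total of 87,849,216 parameters is about 351 MB as float32.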

## 8. Export File Loading

Define a helper function to load and validate export files from previous notebooks.

In [None]:
def load_export_file(filename: str) -> Dict[str, Any]:
    filepath = os.path.join(EXPORT_DIR, filename)

    with open(filepath, 'rb') as f:
        data = pickle.load(f)

    print(f"Loaded: {filename}")
    print(f"  Keys: {list(data.keys())}")

    if 'constants' in data:
        print(f"  Input size: {data['constants']['clip_input_size']}")

    if 'default_params' in data:
        print(f"  Parameters: {data['default_params']}")

    return data

## 9. Noise Algorithm Export

### 9.1 Load Noise Model Export

Load the noise model export file created by `2_noise_algorithm.ipynb`.

### Export File Structure

```python
noise_model_export.pkl = {
    'vit_weights': dict,
    'presets': {
        'noise_chaos': jnp.ndarray,    # (8, 512)
        'noise_normal': jnp.ndarray    # (3, 512)
    },
    'constants': {
        'clip_mean': np.ndarray,
        'clip_std': np.ndarray,
        'clip_input_size': int
    },
    'default_params': {
        'intensity': float,
        'iterations': int,
        'alpha_multiplier': float
    },
    'architecture': dict
}
```

In [None]:
noise_export = load_export_file('noise_model_export.pkl')

### 9.2 Extract Noise Model Components

Extract weights, embeddings, and parameters from the export file.

In [None]:
noise_weights = convert_pytorch_weights(noise_export['vit_weights'])

noise_constants = noise_export['constants']
noise_mean = jnp.array(noise_constants['clip_mean'])
noise_std = jnp.array(noise_constants['clip_std'])
noise_input_size = noise_constants['clip_input_size']

chaos_embeddings = jnp.array(noise_export['presets']['noise_chaos'])
normal_embeddings = jnp.array(noise_export['presets']['noise_normal'])

noise_params = noise_export['default_params']

print(f"Input size: {noise_input_size}")
print(f"Chaos embeddings shape: {chaos_embeddings.shape}")
print(f"Normal embeddings shape: {normal_embeddings.shape}")
print(f"Parameters: {noise_params}")

### 9.2.1 Validate Noise Model Weights

Verify the converted weights work with the model before proceeding.

In [None]:
print("Validating noise model weights...")
print("\nParameter structure:")
debug_weight_structure(noise_weights['params'])

print("\nTesting forward pass:")
is_valid = verify_model_params(noise_weights)

if not is_valid:
    raise ValueError("Weight conversion failed for noise model")

### 9.3 Define Noise Loss Function

The noise protection algorithm maximizes similarity to chaos embeddings while minimizing similarity to normal embeddings.

### Loss Function

Given an L2-normalized image embedding $\hat{\mathbf{z}} = \frac{\mathbf{z}}{\|\mathbf{z}\|_2}$, the loss is:

$$\mathcal{L}_{noise}(\mathbf{x}) = -\frac{1}{|C|}\sum_{i=1}^{|C|} \hat{\mathbf{z}}^T \mathbf{e}_i^{chaos} + \frac{1}{|N|}\sum_{j=1}^{|N|} \hat{\mathbf{z}}^T \mathbf{e}_j^{normal}$$

Where:
- $\hat{\mathbf{z}} \in \mathbb{R}^{512}$ is the L2-normalized ViT embedding of the input image
- $\mathbf{e}_i^{chaos} \in \mathbb{R}^{512}$ are chaos text embeddings ($|C| = 8$)
- $\mathbf{e}_j^{normal} \in \mathbb{R}^{512}$ are normal text embeddings ($|N| = 3$)

Minimizing $\mathcal{L}_{noise}$ simultaneously increases $\hat{\mathbf{z}}^T \mathbf{e}_i^{chaos}$ (moves toward chaotic semantics) and decreases $\hat{\mathbf{z}}^T \mathbf{e}_j^{normal}$ (moves away from normal semantics).
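
Since the loss depends on the image only through its embedding, its sign behavior can be checked without the ViT. A minimal NumPy sketch on toy 4-D embeddings (hypothetical values, not the exported presets), where taking the mean of the `(1, |C|)` similarity matrix is exactly the $\frac{1}{|C|}\sum$ term:

```python
import numpy as np


def l2_normalize(v, eps=1e-8):
    return v / (np.linalg.norm(v, axis=-1, keepdims=True) + eps)


def noise_loss_from_embedding(z, chaos_emb, normal_emb):
    z_hat = l2_normalize(z)                     # (1, D)
    sim_chaos = np.mean(z_hat @ chaos_emb.T)    # average over the |C| chaos prompts
    sim_normal = np.mean(z_hat @ normal_emb.T)  # average over the |N| normal prompts
    return -sim_chaos + sim_normal


# Toy unit-length embeddings (hypothetical)
chaos = np.array([[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]])
normal = np.array([[0.0, 1.0, 0.0, 0.0]])

toward_chaos = noise_loss_from_embedding(np.array([[3.0, 0.0, 0.0, 0.0]]), chaos, normal)
toward_normal = noise_loss_from_embedding(np.array([[0.0, 2.0, 0.0, 0.0]]), chaos, normal)
print(toward_chaos, toward_normal)
```

An embedding aligned with the chaos prompts gives a loss near -1, and one aligned with the normal prompts gives a loss near +1. The embedding's scale does not matter because of the normalization.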

### CLIP Preprocessing

Images are normalized using CLIP statistics before passing through the ViT:

$$\mathbf{x}_{norm} = \frac{\mathbf{x} - \boldsymbol{\mu}}{\boldsymbol{\sigma}}$$

Where $\boldsymbol{\mu} = (0.48145466, 0.4578275, 0.40821073)$ and $\boldsymbol{\sigma} = (0.26862954, 0.26130258, 0.27577711)$.

This normalization is baked into each ONNX model, so the Rust application only needs to provide float32 pixel values in $[0, 1]$, laid out as a `(1, 224, 224, 3)` (NHWC) tensor.
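
What the application still has to do is turn an RGB image into that input layout. This NumPy sketch shows what the Rust side would need to reproduce. It assumes the image has already been resized to 224x224 and arrives as 8-bit RGB; both are assumptions, since this notebook does not cover resizing. `clip_normalize` is included only as a reference for what the graph does internally. Applying it on the caller side as well would normalize twice.

```python
import numpy as np

CLIP_MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
CLIP_STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)


def to_model_input(rgb_uint8: np.ndarray) -> np.ndarray:
    """(224, 224, 3) uint8 RGB -> (1, 224, 224, 3) float32 in [0, 1] (NHWC)."""
    if rgb_uint8.shape != (224, 224, 3) or rgb_uint8.dtype != np.uint8:
        raise ValueError(f"expected (224, 224, 3) uint8, got {rgb_uint8.shape} {rgb_uint8.dtype}")
    return (rgb_uint8.astype(np.float32) / 255.0)[None, ...]


def clip_normalize(x: np.ndarray) -> np.ndarray:
    # Reference only: the exported graph already applies this step
    return (x - CLIP_MEAN) / CLIP_STD
```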

In [None]:
def create_noise_loss_fn(
    params: Dict[str, Any],
    mean: jnp.ndarray,
    std: jnp.ndarray,
    chaos_emb: jnp.ndarray,
    normal_emb: jnp.ndarray
) -> Callable:

    model = VisionTransformer()

    def loss_fn(image: jnp.ndarray) -> jnp.ndarray:
        normalized = (image - mean) / std
        features = model.apply(params, normalized)

        features = features / (jnp.linalg.norm(features, axis=-1, keepdims=True) + 1e-8)

        sim_chaos = jnp.mean(jnp.matmul(features, chaos_emb.T))
        sim_normal = jnp.mean(jnp.matmul(features, normal_emb.T))

        return -sim_chaos + sim_normal

    return loss_fn


noise_loss_fn = create_noise_loss_fn(
    noise_weights,
    noise_mean,
    noise_std,
    chaos_embeddings,
    normal_embeddings
)

### 9.4 Test Noise Loss Function

Run the loss function on a uniform gray image (all pixels 0.5) to confirm that it executes and returns a scalar loss before ONNX export.

In [None]:
test_image = jnp.ones((1, noise_input_size, noise_input_size, 3), dtype=jnp.float32) * 0.5
test_loss = noise_loss_fn(test_image)

print(f"Test image shape: {test_image.shape}")
print(f"Test loss value: {test_loss}")

### 9.5 Export Noise Algorithm to ONNX

Convert the noise loss function to ONNX format using jax2onnx.

In [None]:
try:
    sample_input = jnp.zeros((1, noise_input_size, noise_input_size, 3), dtype=jnp.float32)

    test_output = noise_loss_fn(sample_input)
    print(f"Pre-export test passed, loss value: {test_output}")

    noise_onnx_path = os.path.join(RAW_ONNX_DIR, 'noise_algorithm.onnx')

    # ShapeDtypeStruct (as in the glaze/nightshade exports) pins the input dtype to float32
    to_onnx(
        noise_loss_fn,
        [jax.ShapeDtypeStruct((1, noise_input_size, noise_input_size, 3), jnp.float32)],
        return_mode="file",
        output_path=noise_onnx_path
    )

    file_size = os.path.getsize(noise_onnx_path) / (1024 ** 2)
    print(f"Exported: {noise_onnx_path}")
    print(f"Size: {file_size:.2f} MB")

except Exception as e:
    # Re-raise: Section 12 cannot consolidate a model that was never written
    print(f"Export failed: {e}")
    raise

## 10. Glaze Algorithm Export

### 10.1 Load Glaze Model Export

Load the glaze model export file created by `3_glaze_algorithm.ipynb`.

### Export File Structure

```python
glaze_model_export.pkl = {
    'vit_weights': dict,
    'presets': {
        'style_embeddings': dict,      # {style_name: (512,)}
        'source_style_emb': np.ndarray # (512,)
    },
    'style_prompts': dict,
    'source_prompt': str,
    'constants': dict,
    'parameter_presets': dict,
    'default_params': dict,
    'architecture': dict,
    'algorithm': dict
}
```

In [None]:
glaze_export = load_export_file('glaze_model_export.pkl')

### 10.2 Extract Glaze Model Components

Extract weights, style embeddings, and parameters.

In [None]:
glaze_weights = convert_pytorch_weights(glaze_export['vit_weights'])

glaze_constants = glaze_export['constants']
glaze_mean = jnp.array(glaze_constants['clip_mean'])
glaze_std = jnp.array(glaze_constants['clip_std'])
glaze_input_size = glaze_constants['clip_input_size']

style_emb_dict = glaze_export['presets']['style_embeddings']
style_names = list(style_emb_dict.keys())
style_embeddings = jnp.stack([jnp.array(style_emb_dict[name]) for name in style_names])
source_embedding = jnp.array(glaze_export['presets']['source_style_emb'])

glaze_params = glaze_export['default_params']
glaze_presets = glaze_export.get('parameter_presets', {})

print(f"Input size: {glaze_input_size}")
print(f"Style names: {style_names}")
print(f"Style embeddings shape: {style_embeddings.shape}")
print(f"Source embedding shape: {source_embedding.shape}")
print(f"Parameters: {glaze_params}")

### 10.3 Define Glaze Loss Function

The Glaze algorithm performs style cloaking by pushing the image embedding away from the source (realistic photo) toward a target artistic style.

### Loss Function

Given an L2-normalized image embedding $\hat{\mathbf{z}} = \frac{\mathbf{z}}{\|\mathbf{z}\|_2}$, the loss is:

$$\mathcal{L}_{glaze}(\mathbf{x}, i) = \hat{\mathbf{z}}^T \mathbf{e}_{source} - \hat{\mathbf{z}}^T \mathbf{e}_{style}^{(i)}$$

Where:
- $\hat{\mathbf{z}} \in \mathbb{R}^{512}$ is the L2-normalized ViT embedding of the input image
- $\mathbf{e}_{source} \in \mathbb{R}^{512}$ is the CLIP text embedding for "realistic photograph with natural lighting"
- $\mathbf{e}_{style}^{(i)} \in \mathbb{R}^{512}$ is the CLIP text embedding for style $i$
- $i \in \{0, 1, 2, 3, 4\}$ is the style index

Minimizing $\mathcal{L}_{glaze}$ simultaneously decreases $\hat{\mathbf{z}}^T \mathbf{e}_{source}$ (moves away from realistic photo) and increases $\hat{\mathbf{z}}^T \mathbf{e}_{style}^{(i)}$ (moves toward artistic style).

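### Style Index Mapping

The index selects a row of `style_embeddings`, whose row order follows the key order of `presets['style_embeddings']` in the export file (printed as `style_names` in Section 10.2). The table assumes that order.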

| Index | Style |
|-------|-------|
| 0 | Abstract |
| 1 | Impressionist |
| 2 | Cubist |
| 3 | Sketch |
| 4 | Watercolor |
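
The style index is a runtime input, and JAX clamps out-of-bounds indices on retrieval instead of raising (as described in JAX's "Sharp Bits" documentation), so an invalid index would silently select an edge style in the JAX function. The exported graph may behave differently. A NumPy sketch of the same loss on a precomputed embedding, with a hypothetical caller-side range check and toy one-hot embeddings (not the real presets):

```python
import numpy as np


def glaze_loss_from_embedding(z, source_emb, style_emb, style_index):
    """Same arithmetic as glaze loss_fn, but on a precomputed (1, D) embedding."""
    idx = int(style_index[0])
    if not 0 <= idx < style_emb.shape[0]:
        # Validate explicitly rather than relying on index clamping
        raise IndexError(f"style_index {idx} outside [0, {style_emb.shape[0]})")
    z_hat = z[0] / (np.linalg.norm(z[0]) + 1e-8)
    return float(z_hat @ source_emb - z_hat @ style_emb[idx])


# Toy 6-D embeddings (hypothetical): source on axis 0, five styles on axes 1..5
eye = np.eye(6)
source, styles = eye[0], eye[1:6]

# Style index 2 (Cubist in the table) is axis 3 here
photo_like = glaze_loss_from_embedding(eye[0:1], source, styles, np.array([2]))
cubist_like = glaze_loss_from_embedding(eye[3:4], source, styles, np.array([2]))
print(photo_like, cubist_like)
```

A photo-like embedding scores about +1 and one already aligned with the chosen style scores about -1, which is the direction the optimization pushes toward.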

In [None]:
def create_glaze_loss_fn(
    params: Dict[str, Any],
    mean: jnp.ndarray,
    std: jnp.ndarray,
    source_emb: jnp.ndarray,
    style_emb: jnp.ndarray
) -> Callable:

    model = VisionTransformer()

    def loss_fn(image: jnp.ndarray, style_index: jnp.ndarray) -> jnp.ndarray:
        normalized = (image - mean) / std
        features = model.apply(params, normalized)

        features = features / (jnp.linalg.norm(features, axis=-1, keepdims=True) + 1e-8)
        features = features[0]

        target_style = style_emb[style_index[0]]

        sim_source = jnp.dot(features, source_emb)
        sim_style = jnp.dot(features, target_style)

        return sim_source - sim_style

    return loss_fn


glaze_loss_fn = create_glaze_loss_fn(
    glaze_weights,
    glaze_mean,
    glaze_std,
    source_embedding,
    style_embeddings
)

### 10.4 Test Glaze Loss Function

In [None]:
test_image = jnp.ones((1, glaze_input_size, glaze_input_size, 3), dtype=jnp.float32) * 0.5
test_style_idx = jnp.array([0], dtype=jnp.int32)
test_loss = glaze_loss_fn(test_image, test_style_idx)

print(f"Test image shape: {test_image.shape}")
print(f"Test style: {style_names[0]}")
print(f"Test loss value: {test_loss}")

### 10.5 Export Glaze Algorithm to ONNX

In [None]:
try:
    sample_image = jnp.zeros((1, glaze_input_size, glaze_input_size, 3), dtype=jnp.float32)
    sample_style_idx = jnp.array([0], dtype=jnp.int32)

    test_output = glaze_loss_fn(sample_image, sample_style_idx)
    print(f"Pre-export test passed, loss value: {test_output}")

    glaze_onnx_path = os.path.join(RAW_ONNX_DIR, 'glaze_algorithm.onnx')

    to_onnx(
        glaze_loss_fn,
        [
            jax.ShapeDtypeStruct((1, glaze_input_size, glaze_input_size, 3), jnp.float32),
            jax.ShapeDtypeStruct((1,), jnp.int32)
        ],
        return_mode="file",
        output_path=glaze_onnx_path
    )

    file_size = os.path.getsize(glaze_onnx_path) / (1024 ** 2)
    print(f"Exported: {glaze_onnx_path}")
    print(f"Size: {file_size:.2f} MB")

except Exception as e:
    # Re-raise: Section 12 cannot consolidate a model that was never written
    print(f"Export failed: {e}")
    raise

## 11. Nightshade Algorithm Export

### 11.1 Load Nightshade Model Export

Load the nightshade model export file created by `4_nightshade_algorithm.ipynb`.

### Export File Structure

```python
nightshade_model_export.pkl = {
    'vit_weights': dict,
    'presets': {
        'target_embeddings': dict,  # {target_name: (512,)}
        'generic_emb': np.ndarray   # (512,)
    },
    'target_prompts': dict,
    'generic_prompt': str,
    'constants': dict,
    'parameter_presets': dict,
    'default_params': dict,
    'architecture': dict,
    'algorithm': dict
}
```

In [None]:
nightshade_export = load_export_file('nightshade_model_export.pkl')

### 11.2 Extract Nightshade Model Components

In [None]:
nightshade_weights = convert_pytorch_weights(nightshade_export['vit_weights'])

nightshade_constants = nightshade_export['constants']
nightshade_mean = jnp.array(nightshade_constants['clip_mean'])
nightshade_std = jnp.array(nightshade_constants['clip_std'])
nightshade_input_size = nightshade_constants['clip_input_size']

target_emb_dict = nightshade_export['presets']['target_embeddings']
target_names = list(target_emb_dict.keys())
target_embeddings = jnp.stack([jnp.array(target_emb_dict[name]) for name in target_names])
generic_embedding = jnp.array(nightshade_export['presets']['generic_emb'])

nightshade_params = nightshade_export['default_params']
nightshade_presets = nightshade_export.get('parameter_presets', {})

print(f"Input size: {nightshade_input_size}")
print(f"Target names: {target_names}")
print(f"Target embeddings shape: {target_embeddings.shape}")
print(f"Generic embedding shape: {generic_embedding.shape}")
print(f"Parameters: {nightshade_params}")

### 11.3 Define Nightshade Loss Function

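The Nightshade algorithm performs targeted data poisoning: it shifts the image's CLIP embedding away from a generic photo and toward a different concept, so that models trained on the image are pushed to associate it with the wrong category.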

### Loss Function

Given an L2-normalized image embedding $\hat{\mathbf{z}} = \frac{\mathbf{z}}{\|\mathbf{z}\|_2}$, the loss is:

$$\mathcal{L}_{nightshade}(\mathbf{x}, j) = \hat{\mathbf{z}}^T \mathbf{e}_{generic} - \hat{\mathbf{z}}^T \mathbf{e}_{target}^{(j)}$$

Where:
- $\hat{\mathbf{z}} \in \mathbb{R}^{512}$ is the L2-normalized ViT embedding of the input image
- $\mathbf{e}_{generic} \in \mathbb{R}^{512}$ is the CLIP text embedding for "a clear photograph"
- $\mathbf{e}_{target}^{(j)} \in \mathbb{R}^{512}$ is the CLIP text embedding for poison target $j$
- $j \in \{0, 1, \ldots, 7\}$ is the target index

Minimizing $\mathcal{L}_{nightshade}$ simultaneously decreases $\hat{\mathbf{z}}^T \mathbf{e}_{generic}$ (moves away from generic photo) and increases $\hat{\mathbf{z}}^T \mathbf{e}_{target}^{(j)}$ (moves toward the poison target concept).

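### Target Index Mapping

The index selects a row of `target_embeddings`, whose row order follows the key order of `presets['target_embeddings']` in the export file (printed as `target_names` in Section 11.2). The table assumes that order.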

| Index | Target |
|-------|--------|
| 0 | Dog |
| 1 | Cat |
| 2 | Car |
| 3 | Landscape |
| 4 | Person |
| 5 | Building |
| 6 | Food |
| 7 | Abstract |
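
Because the target index only selects one row of the target matrix, the loss for all eight targets can be computed in a single matrix-vector product. This is a cheap way to see, for example, which target concept an image is already closest to. A NumPy sketch on toy one-hot embeddings (hypothetical values). The `TARGET_NAMES` order is assumed to match the table above.

```python
import numpy as np

# Assumed to match the Target Index Mapping table (export-file key order)
TARGET_NAMES = ['Dog', 'Cat', 'Car', 'Landscape', 'Person', 'Building', 'Food', 'Abstract']


def nightshade_losses_all_targets(z, generic_emb, target_emb):
    """Entry j equals the nightshade loss for target_index = [j]."""
    z_hat = z[0] / (np.linalg.norm(z[0]) + 1e-8)
    return z_hat @ generic_emb - target_emb @ z_hat  # shape (num_targets,)


# Toy 9-D embeddings (hypothetical): generic on axis 0, targets on axes 1..8
eye = np.eye(9)
generic, targets = eye[0], eye[1:9]
z = 0.6 * eye[0:1] + 0.8 * eye[2:3]  # mostly target 1 ("Cat"), partly generic

losses = nightshade_losses_all_targets(z, generic, targets)
closest = TARGET_NAMES[int(np.argmin(losses))]
print(losses, closest)
```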

In [None]:
def create_nightshade_loss_fn(
    params: Dict[str, Any],
    mean: jnp.ndarray,
    std: jnp.ndarray,
    generic_emb: jnp.ndarray,
    target_emb: jnp.ndarray
) -> Callable:

    model = VisionTransformer()

    def loss_fn(image: jnp.ndarray, target_index: jnp.ndarray) -> jnp.ndarray:
        normalized = (image - mean) / std
        features = model.apply(params, normalized)

        features = features / (jnp.linalg.norm(features, axis=-1, keepdims=True) + 1e-8)
        features = features[0]

        target = target_emb[target_index[0]]

        sim_generic = jnp.dot(features, generic_emb)
        sim_target = jnp.dot(features, target)

        return sim_generic - sim_target

    return loss_fn


nightshade_loss_fn = create_nightshade_loss_fn(
    nightshade_weights,
    nightshade_mean,
    nightshade_std,
    generic_embedding,
    target_embeddings
)

### 11.4 Test Nightshade Loss Function

In [None]:
test_image = jnp.ones((1, nightshade_input_size, nightshade_input_size, 3), dtype=jnp.float32) * 0.5
test_target_idx = jnp.array([1], dtype=jnp.int32)
test_loss = nightshade_loss_fn(test_image, test_target_idx)

print(f"Test image shape: {test_image.shape}")
print(f"Test target: {target_names[1]}")
print(f"Test loss value: {test_loss}")

### 11.5 Export Nightshade Algorithm to ONNX

In [None]:
try:
    sample_image = jnp.zeros((1, nightshade_input_size, nightshade_input_size, 3), dtype=jnp.float32)
    sample_target_idx = jnp.array([0], dtype=jnp.int32)

    test_output = nightshade_loss_fn(sample_image, sample_target_idx)
    print(f"Pre-export test passed, loss value: {test_output}")

    nightshade_onnx_path = os.path.join(RAW_ONNX_DIR, 'nightshade_algorithm.onnx')

    to_onnx(
        nightshade_loss_fn,
        [
            jax.ShapeDtypeStruct((1, nightshade_input_size, nightshade_input_size, 3), jnp.float32),
            jax.ShapeDtypeStruct((1,), jnp.int32)
        ],
        return_mode="file",
        output_path=nightshade_onnx_path
    )

    file_size = os.path.getsize(nightshade_onnx_path) / (1024 ** 2)
    print(f"Exported: {nightshade_onnx_path}")
    print(f"Size: {file_size:.2f} MB")

except Exception as e:
    # Re-raise: Section 12 cannot consolidate a model that was never written
    print(f"Export failed: {e}")
    raise

## 12. Consolidate ONNX Models

jax2onnx spills tensors larger than 1 MB to external `.data` files. Since each ViT-B/32 model is ~350 MB (well under protobuf's 2 GB limit), we re-inline all weights into a single `.onnx` file per model for simpler deployment.

Additionally, jax2onnx may export some weight initializers as float16 while the graph operations expect float32, causing `MatMul` type mismatches. We fix these during consolidation.

Raw exports from `onnx-raw/` are consolidated into `onnx/`, then the raw directory is deleted.

In [None]:
from onnx import numpy_helper, TensorProto
import shutil


def fix_float16_initializers(model: onnx.ModelProto) -> int:
    fixed = 0
    for initializer in model.graph.initializer:
        if initializer.data_type == TensorProto.FLOAT16:
            arr = numpy_helper.to_array(initializer).astype(np.float32)
            new_tensor = numpy_helper.from_array(arr, name=initializer.name)
            initializer.CopyFrom(new_tensor)
            fixed += 1

    if fixed > 0:
        del model.graph.value_info[:]

    return fixed


def rename_model_inputs(model: onnx.ModelProto, name_map: Dict[str, str]) -> int:
    renamed = 0
    for inp in model.graph.input:
        if inp.name in name_map:
            old_name = inp.name
            new_name = name_map[old_name]
            for node in model.graph.node:
                for i, node_input in enumerate(node.input):
                    if node_input == old_name:
                        node.input[i] = new_name
            inp.name = new_name
            renamed += 1
            print(f"  Renamed input: {old_name} -> {new_name}")
    return renamed


NOISE_INPUT_NAMES = {'in_0': 'input'}
GLAZE_INPUT_NAMES = {'in_0': 'input', 'in_1': 'style_index'}
NIGHTSHADE_INPUT_NAMES = {'in_0': 'input', 'in_1': 'target_index'}
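
After renaming, callers feed the models by these names. A sketch of a feed-dict builder for ONNX Runtime, following the input names in `NOISE_INPUT_NAMES`, `GLAZE_INPUT_NAMES` and `NIGHTSHADE_INPUT_NAMES` and the float32 / int32 `(1,)` dtypes used at export time. `build_feeds` is a hypothetical helper, and the session call in the trailing comment assumes the graph has a single output (the loss).

```python
from typing import Optional

import numpy as np

# Index input name per algorithm, after rename_model_inputs
INDEX_INPUT = {'glaze': 'style_index', 'nightshade': 'target_index'}


def build_feeds(algorithm: str, image: np.ndarray, index: Optional[int] = None) -> dict:
    """Input dict for InferenceSession.run; image is (1, H, W, 3) float32 in [0, 1]."""
    image = np.ascontiguousarray(image, dtype=np.float32)
    if image.ndim != 4 or image.shape[0] != 1 or image.shape[-1] != 3:
        raise ValueError(f"expected a (1, H, W, 3) image, got {image.shape}")
    feeds = {'input': image}
    if algorithm == 'noise':
        return feeds
    if algorithm not in INDEX_INPUT or index is None:
        raise ValueError(f"{algorithm!r} is unknown or needs an index")
    feeds[INDEX_INPUT[algorithm]] = np.array([index], dtype=np.int32)  # (1,) int32, as exported
    return feeds


# Usage against a consolidated model (not executed here):
#   sess = ort.InferenceSession(glaze_onnx_path,
#                               providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
#   (loss,) = sess.run(None, build_feeds('glaze', image, index=2))
```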


def consolidate_onnx_model(raw_path: str, output_path: str, input_names: Dict[str, str] = None) -> None:
    model = onnx.load(raw_path, load_external_data=True)

    fixed = fix_float16_initializers(model)
    if fixed > 0:
        print(f"  Fixed {fixed} float16 initializers -> float32")

    if input_names:
        rename_model_inputs(model, input_names)

    onnx.save_model(model, output_path)

    size = os.path.getsize(output_path) / (1024 ** 2)
    print(f"Consolidated: {os.path.basename(output_path)} ({size:.2f} MB)")


print("Consolidating ONNX models into", ONNX_DIR, "...\n")

noise_onnx_path = os.path.join(ONNX_DIR, 'noise_algorithm.onnx')
consolidate_onnx_model(os.path.join(RAW_ONNX_DIR, 'noise_algorithm.onnx'), noise_onnx_path, NOISE_INPUT_NAMES)

glaze_onnx_path = os.path.join(ONNX_DIR, 'glaze_algorithm.onnx')
consolidate_onnx_model(os.path.join(RAW_ONNX_DIR, 'glaze_algorithm.onnx'), glaze_onnx_path, GLAZE_INPUT_NAMES)

nightshade_onnx_path = os.path.join(ONNX_DIR, 'nightshade_algorithm.onnx')
consolidate_onnx_model(os.path.join(RAW_ONNX_DIR, 'nightshade_algorithm.onnx'), nightshade_onnx_path, NIGHTSHADE_INPUT_NAMES)

shutil.rmtree(RAW_ONNX_DIR)
print(f"\nCleaned up raw exports: {RAW_ONNX_DIR}")

## 13. Simplify ONNX Models

Optimize each ONNX model with onnxsim, which folds constants and eliminates redundant operations. This can shrink the graph and reduce inference latency. The cell below reports the actual size change for each model.
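Here is a toy illustration of what constant folding means. It is not how onnxsim is implemented. The sketch walks a topologically sorted list of hypothetical two-input `Add`/`Mul` nodes and replaces every node whose inputs are all known constants with its precomputed value. Only nodes that depend on a runtime input (here `x`) survive. `Node`, `OPS`, and `fold_constants` are illustrative names.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class Node:
    op: str                  # "Add" or "Mul" in this toy graph
    inputs: Tuple[str, str]
    output: str


OPS = {"Add": lambda a, b: a + b, "Mul": lambda a, b: a * b}


def fold_constants(nodes: List[Node], constants: Dict[str, float]) -> Tuple[List[Node], Dict[str, float]]:
    """Evaluate nodes whose inputs are all constants; return remaining nodes and the enlarged constant table."""
    consts = dict(constants)
    remaining: List[Node] = []
    for node in nodes:  # assumes topological order, as ONNX graphs store their nodes
        a, b = node.inputs
        if a in consts and b in consts:
            consts[node.output] = OPS[node.op](consts[a], consts[b])
        else:
            remaining.append(node)
    return remaining, consts


# scale = 2 * 3 depends only on constants; y = x * scale depends on the runtime input x.
graph = [Node("Mul", ("two", "three"), "scale"), Node("Mul", ("x", "scale"), "y")]
remaining, consts = fold_constants(graph, {"two": 2.0, "three": 3.0})
print([n.output for n in remaining], consts["scale"])
```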

In [None]:
from onnxsim import simplify

def simplify_onnx_model(model_path: str) -> None:
    original_size = os.path.getsize(model_path) / (1024 ** 2)

    model = onnx.load(model_path)
    # Drop cached shape/type annotations so onnxsim re-runs shape inference from scratch.
    del model.graph.value_info[:]
    simplified_model, check = simplify(model)

    if check:
        onnx.save(simplified_model, model_path)
        new_size = os.path.getsize(model_path) / (1024 ** 2)
        reduction = (1 - new_size / original_size) * 100
        print(f"Simplified: {os.path.basename(model_path)}")
        print(f"  Size: {original_size:.2f} MB -> {new_size:.2f} MB")
        print(f"  Reduction: {reduction:.1f}%")
    else:
        # The consolidated model on disk is left untouched when onnxsim's check fails.
        print(f"Simplification check failed, keeping original: {os.path.basename(model_path)}")


print("Simplifying ONNX models...\n")
simplify_onnx_model(noise_onnx_path)
print()
simplify_onnx_model(glaze_onnx_path)
print()
simplify_onnx_model(nightshade_onnx_path)

## 14. Validate ONNX Models

Verify that each exported model reproduces the loss computed by the original JAX implementation on the same inputs.

### Validation Criteria

The maximum absolute difference between JAX and ONNX outputs should be within numerical tolerance:

$$|\mathcal{L}_{jax} - \mathcal{L}_{onnx}| < 10^{-4}$$
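An absolute tolerance fits these models well. Each loss is built from cosine similarities in $[-1, 1]$ (see the loss definitions in the summary). Assuming the pre-computed embeddings are unit-normalized, every loss therefore lies roughly in $[-2, 2]$, and $10^{-4}$ is a small fraction of that range.

The sketch below mirrors the check in `validate_onnx_model`. It uses `np.broadcast_arrays` so that a 0-d JAX scalar can be compared with an ONNX output of shape `(1,)`. The helper names are illustrative.

```python
import numpy as np


def max_abs_diff(jax_out, onnx_out) -> float:
    """Largest element-wise |jax - onnx| gap, computed in float64."""
    a, b = np.broadcast_arrays(
        np.asarray(jax_out, dtype=np.float64),
        np.asarray(onnx_out, dtype=np.float64),
    )
    return float(np.max(np.abs(a - b)))


def within_tolerance(jax_out, onnx_out, atol: float = 1e-4) -> bool:
    """True when max |L_jax - L_onnx| < atol, the criterion stated above."""
    return max_abs_diff(jax_out, onnx_out) < atol


# A 0-d scalar compared against a shape-(1,) output.
print(within_tolerance(np.float32(0.25), np.array([0.25003], dtype=np.float32)))
```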

In [None]:
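def validate_onnx_model(
    model_path: str,
    jax_fn: Callable,
    sample_inputs: list
) -> bool:
    """Check ONNX structure, run the model with ONNX Runtime, and compare its output to jax_fn(*sample_inputs)."""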
    print(f"Validating: {os.path.basename(model_path)}")

    model = onnx.load(model_path)
    onnx.checker.check_model(model)
    print("  ONNX structure check: PASSED")

    providers = [
        ('CUDAExecutionProvider', {
            'device_id': 0,
            'arena_extend_strategy': 'kNextPowerOfTwo',
            'cudnn_conv_algo_search': 'EXHAUSTIVE',
        }),
        'CPUExecutionProvider'
    ]

    session = ort.InferenceSession(model_path, providers=providers)
    active_provider = session.get_providers()[0]
    print(f"  Execution provider: {active_provider}")

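    # Feed sample_inputs positionally, in the order the graph declares its inputs.
    model_inputs = session.get_inputs()
    if len(model_inputs) != len(sample_inputs):
        raise ValueError(f"Model expects {len(model_inputs)} inputs, got {len(sample_inputs)}")
    ort_inputs = {inp.name: np.array(value) for inp, value in zip(model_inputs, sample_inputs)}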

    onnx_output = session.run(None, ort_inputs)[0]
    jax_output = np.array(jax_fn(*sample_inputs))

    if jax_output.shape == ():
        jax_output = jax_output.reshape(onnx_output.shape)

    max_diff = np.abs(jax_output - onnx_output).max()
    is_valid = max_diff < 1e-4

    print(f"  JAX output: {float(jax_output.flatten()[0]):.6f}")
    print(f"  ONNX output: {float(onnx_output.flatten()[0]):.6f}")
    print(f"  Max difference: {max_diff:.2e}")
    print(f"  Validation: {'PASSED' if is_valid else 'FAILED'}")

    return is_valid

### 14.1 Validate All Models

In [None]:
test_image = jnp.ones((1, 224, 224, 3), dtype=jnp.float32) * 0.5

print("=" * 60)
print("ONNX MODEL VALIDATION")
print("=" * 60)

print("\nNoise Algorithm:")
noise_valid = validate_onnx_model(
    noise_onnx_path,
    noise_loss_fn,
    [test_image]
)

print("\nGlaze Algorithm:")
glaze_valid = validate_onnx_model(
    glaze_onnx_path,
    glaze_loss_fn,
    [test_image, jnp.array([2], dtype=jnp.int32)]
)

print("\nNightshade Algorithm:")
nightshade_valid = validate_onnx_model(
    nightshade_onnx_path,
    nightshade_loss_fn,
    [test_image, jnp.array([3], dtype=jnp.int32)]
)

print("\n" + "=" * 60)
all_valid = noise_valid and glaze_valid and nightshade_valid
print(f"Overall validation: {'PASSED' if all_valid else 'FAILED'}")
print("=" * 60)

## 15. Performance Benchmark

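Compare inference speed between the CPU and CUDA execution providers to verify that GPU acceleration is working. Each timing is the mean over 100 runs after 10 warm-up runs. It is measured end to end through `session.run`, so it includes copying inputs to the device and outputs back.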
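`benchmark_onnx_model` reports only the mean over the whole loop. A variant that records per-run samples can also report the median, which is less sensitive to occasional slow iterations such as those caused by other work on a shared Colab VM.

The helper below is a hypothetical, framework-agnostic sketch that times any zero-argument callable, for example `lambda: session_gpu.run(None, sample_inputs)`.

```python
import statistics
import time
from typing import Callable, Dict, List


def time_callable(fn: Callable[[], object], warmup: int = 10, runs: int = 100) -> Dict[str, float]:
    """Time fn() in milliseconds after discarding `warmup` calls; returns mean, median, and min."""
    for _ in range(warmup):
        fn()  # absorbs one-off costs such as kernel selection or memory-arena growth
    samples: List[float] = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000.0)
    return {
        "mean_ms": statistics.fmean(samples),
        "median_ms": statistics.median(samples),
        "min_ms": min(samples),
    }


stats = time_callable(lambda: sum(range(10_000)), warmup=2, runs=20)
print({k: round(v, 4) for k, v in stats.items()})
```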

In [None]:
import time

def benchmark_onnx_model(model_path: str, sample_inputs: dict, num_runs: int = 100) -> None:
    print(f"Benchmarking: {os.path.basename(model_path)}")

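    session_gpu = ort.InferenceSession(
        model_path,
        providers=['CUDAExecutionProvider']
    )
    # ONNX Runtime falls back to CPU if the CUDA provider cannot be loaded, which
    # would make the "GPU" timing below misleading; surface that case.
    gpu_provider = session_gpu.get_providers()[0]
    if gpu_provider != 'CUDAExecutionProvider':
        print(f"  WARNING: CUDA provider unavailable, 'GPU' run uses {gpu_provider}")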

    for _ in range(10):
        session_gpu.run(None, sample_inputs)

    start = time.perf_counter()
    for _ in range(num_runs):
        session_gpu.run(None, sample_inputs)
    gpu_time = (time.perf_counter() - start) / num_runs * 1000

    session_cpu = ort.InferenceSession(
        model_path,
        providers=['CPUExecutionProvider']
    )

    for _ in range(10):
        session_cpu.run(None, sample_inputs)

    start = time.perf_counter()
    for _ in range(num_runs):
        session_cpu.run(None, sample_inputs)
    cpu_time = (time.perf_counter() - start) / num_runs * 1000

    speedup = cpu_time / gpu_time

    print(f"  GPU (CUDA): {gpu_time:.2f} ms/inference")
    print(f"  CPU: {cpu_time:.2f} ms/inference")
    print(f"  Speedup: {speedup:.1f}x")

In [None]:
test_image_np = np.ones((1, 224, 224, 3), dtype=np.float32) * 0.5

noise_session = ort.InferenceSession(noise_onnx_path, providers=["CPUExecutionProvider"])
noise_input_name = noise_session.get_inputs()[0].name

glaze_session = ort.InferenceSession(glaze_onnx_path, providers=["CPUExecutionProvider"])
glaze_inputs = {
    glaze_session.get_inputs()[0].name: test_image_np,
    glaze_session.get_inputs()[1].name: np.array([0], dtype=np.int32)
}

nightshade_session = ort.InferenceSession(nightshade_onnx_path, providers=["CPUExecutionProvider"])
nightshade_inputs = {
    nightshade_session.get_inputs()[0].name: test_image_np,
    nightshade_session.get_inputs()[1].name: np.array([0], dtype=np.int32)
}

print("=" * 60)
print("PERFORMANCE BENCHMARK")
print("=" * 60)
print()

benchmark_onnx_model(noise_onnx_path, {noise_input_name: test_image_np})
print()
benchmark_onnx_model(glaze_onnx_path, glaze_inputs)
print()
benchmark_onnx_model(nightshade_onnx_path, nightshade_inputs)

## 16. Export Configuration

Save algorithm parameters and metadata as a JSON configuration file for use by the Rust application.
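The configuration lets a client build model feeds without hard-coding input names. The function below is a hedged Python sketch of that lookup (the Rust side may do it differently). It turns a style name into the Glaze input dictionary.

It assumes, without this notebook confirming it, that `style_index` is the position of the style in the `styles` list. It uses a trimmed, hypothetical config with made-up style names, and `build_glaze_feeds` is an illustrative name.

```python
import json

import numpy as np


def build_glaze_feeds(config: dict, image: np.ndarray, style_name: str) -> dict:
    """Map a style name to ONNX Runtime feeds using the input specs in hope_config.json.

    Assumption: style_index is the position of style_name in config["glaze"]["styles"].
    """
    glaze = config["glaze"]
    style_index = glaze["styles"].index(style_name)  # raises ValueError for unknown styles
    image_spec, index_spec = glaze["inputs"]
    return {
        image_spec["name"]: image.astype(image_spec["dtype"]),
        index_spec["name"]: np.array([style_index], dtype=index_spec["dtype"]),
    }


# Trimmed, hypothetical config; the real file is written by the cell above.
config_text = """
{"glaze": {"file": "glaze_algorithm.onnx",
           "inputs": [{"name": "input", "shape": [1, 224, 224, 3], "dtype": "float32"},
                      {"name": "style_index", "shape": [1], "dtype": "int32"}],
           "styles": ["style_a", "style_b", "style_c"]}}
"""
feeds = build_glaze_feeds(json.loads(config_text), np.zeros((1, 224, 224, 3)), "style_c")
print({k: (v.shape, v.dtype) for k, v in feeds.items()})
```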

In [None]:
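def serialize_value(value):
    """Convert NumPy/JAX values, including ones nested in dicts/lists/tuples, into JSON-compatible types."""
    if isinstance(value, dict):
        return {k: serialize_value(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [serialize_value(v) for v in value]
    if hasattr(value, "tolist"):
        # np.ndarray, np.generic scalars, and JAX arrays all provide tolist();
        # 0-d arrays and NumPy scalars become plain Python numbers.
        return value.tolist()
    return value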


config = {
    'input': {
        'size': int(noise_input_size),
        'format': 'NHWC',
        'dtype': 'float32',
        'range': [0.0, 1.0],
        'preprocessing': 'included_in_model'
    },
    'output': {
        'type': 'scalar',
        'dtype': 'float32',
        'description': 'loss value to minimize via PGD'
    },
    'noise': {
        'file': 'noise_algorithm.onnx',
        'inputs': [
            {'name': 'input', 'shape': [1, 224, 224, 3], 'dtype': 'float32'}
        ],
        'parameters': {k: serialize_value(v) for k, v in noise_params.items()}
    },
    'glaze': {
        'file': 'glaze_algorithm.onnx',
        'inputs': [
            {'name': 'input', 'shape': [1, 224, 224, 3], 'dtype': 'float32'},
            {'name': 'style_index', 'shape': [1], 'dtype': 'int32'}
        ],
        'styles': style_names,
        'parameters': {k: serialize_value(v) for k, v in glaze_params.items()},
        'presets': {k: {pk: serialize_value(pv) for pk, pv in v.items()} for k, v in glaze_presets.items()}
    },
    'nightshade': {
        'file': 'nightshade_algorithm.onnx',
        'inputs': [
            {'name': 'input', 'shape': [1, 224, 224, 3], 'dtype': 'float32'},
            {'name': 'target_index', 'shape': [1], 'dtype': 'int32'}
        ],
        'targets': target_names,
        'parameters': {k: serialize_value(v) for k, v in nightshade_params.items()},
        'presets': {k: {pk: serialize_value(pv) for pk, pv in v.items()} for k, v in nightshade_presets.items()}
    }
}

config_path = os.path.join(ONNX_DIR, 'hope_config.json')
with open(config_path, 'w') as f:
    json.dump(config, f, indent=2)

print(f"Configuration saved: {config_path}")

## 17. Final Verification

List all exported files and verify completeness.

In [None]:
print("=" * 60)
print("EXPORT SUMMARY")
print("=" * 60)

expected_files = [
    'noise_algorithm.onnx',
    'glaze_algorithm.onnx',
    'nightshade_algorithm.onnx',
    'hope_config.json'
]

total_size = 0

for filename in expected_files:
    filepath = os.path.join(ONNX_DIR, filename)

    if os.path.exists(filepath):
        size = os.path.getsize(filepath)
        total_size += size

        if size > 1024 * 1024:
            size_str = f"{size / (1024 ** 2):.2f} MB"
        else:
            size_str = f"{size / 1024:.2f} KB"

        print(f"  {filename}: {size_str}")
    else:
        print(f"  {filename}: MISSING")

print()
print(f"Total size: {total_size / (1024 ** 2):.2f} MB")
print()
print("=" * 60)
print("Export complete")
print("=" * 60)

## 18. Summary

### Exported Files

| File | Inputs | Output | Description |
|------|--------|--------|-------------|
| `noise_algorithm.onnx` | image `(1,224,224,3)` float32 | loss scalar | Chaos/normal similarity loss |
| `glaze_algorithm.onnx` | image `(1,224,224,3)` float32, style\_index `(1,)` int32 | loss scalar | Style cloaking loss |
| `nightshade_algorithm.onnx` | image `(1,224,224,3)` float32, target\_index `(1,)` int32 | loss scalar | Targeted data poisoning loss |
| `hope_config.json` | - | - | Algorithm parameters and metadata |

### Image Input Specification

| Property | Value |
|----------|-------|
| Shape | `(1, 224, 224, 3)` |
| Format | NHWC (batch, height, width, channels) |
| Data type | float32 |
| Range | $[0.0, 1.0]$ |
| Preprocessing | CLIP normalization included in model |
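CLIP normalization lives inside the graph, so a caller only needs to scale pixels to $[0, 1]$ and add a batch axis. The sketch below assumes an 8-bit RGB array that is already 224x224; resizing and cropping policy are left to the application. `to_model_input` is a hypothetical helper name.

```python
import numpy as np


def to_model_input(rgb: np.ndarray) -> np.ndarray:
    """Convert a (224, 224, 3) uint8 RGB image into a (1, 224, 224, 3) float32 tensor in [0, 1].

    CLIP mean/std normalization is intentionally NOT applied: the exported graph includes it.
    """
    if rgb.shape != (224, 224, 3) or rgb.dtype != np.uint8:
        raise ValueError(f"expected (224, 224, 3) uint8, got {rgb.shape} {rgb.dtype}")
    x = rgb.astype(np.float32) / 255.0   # [0, 255] -> [0.0, 1.0]
    return x[np.newaxis, ...]            # add the batch axis: NHWC with N = 1


white = np.full((224, 224, 3), 255, dtype=np.uint8)
batch = to_model_input(white)
print(batch.shape, batch.dtype, float(batch.max()))
```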

### Loss Functions

All three algorithms share the same backbone (ViT-B/32 CLIP visual encoder) and produce a scalar loss. The Rust/Tauri application minimizes these losses via PGD to generate adversarial perturbations.

**Noise:**

$$\mathcal{L}_{noise} = -\frac{1}{|C|}\sum_{i=1}^{|C|} \hat{\mathbf{z}}^T \mathbf{e}_i^{chaos} + \frac{1}{|N|}\sum_{j=1}^{|N|} \hat{\mathbf{z}}^T \mathbf{e}_j^{normal}$$

**Glaze:**

$$\mathcal{L}_{glaze} = \hat{\mathbf{z}}^T \mathbf{e}_{source} - \hat{\mathbf{z}}^T \mathbf{e}_{style}^{(i)}$$

**Nightshade:**

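$$\mathcal{L}_{nightshade} = \hat{\mathbf{z}}^T \mathbf{e}_{generic} - \hat{\mathbf{z}}^T \mathbf{e}_{target}^{(j)}$$

Here $\hat{\mathbf{z}}$ is the normalized image embedding (see Embedding Normalization below), and $C$ and $N$ are the sets of pre-computed chaos and normal embeddings. The superscripts $(i)$ and $(j)$ select the embedding given by the `style_index` and `target_index` inputs.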
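The three losses translate directly into NumPy. This sketch assumes that $\hat{\mathbf{z}}$ and every embedding row are already unit-normalized. The pre-computed embeddings are stacked as matrices: `chaos` is `(|C|, D)`, and `styles` and `targets` have one row per index. The function names and toy vectors are illustrative and are not the notebook's JAX implementation.

```python
import numpy as np


def noise_loss(z_hat, chaos, normal):
    # -mean similarity to chaos embeddings + mean similarity to normal embeddings
    return -np.mean(chaos @ z_hat) + np.mean(normal @ z_hat)


def glaze_loss(z_hat, e_source, styles, style_index):
    # similarity to the source minus similarity to the selected style
    return z_hat @ e_source - z_hat @ styles[style_index]


def nightshade_loss(z_hat, e_generic, targets, target_index):
    # similarity to the generic concept minus similarity to the selected target
    return z_hat @ e_generic - z_hat @ targets[target_index]


# Toy 3-D unit vectors standing in for CLIP embeddings.
z_hat = np.array([1.0, 0.0, 0.0])
chaos = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
normal = np.array([[0.0, 0.0, 1.0]])
styles = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])

print(noise_loss(z_hat, chaos, normal))                 # -(1 + 0)/2 + 0 = -0.5
print(glaze_loss(z_hat, z_hat, styles, style_index=0))  # 1 - 0 = 1.0
```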

### Projected Gradient Descent (PGD)

The adversarial perturbation is computed iteratively:

$$\delta_{t+1} = \Pi_{\epsilon}\left(\delta_t - \alpha \cdot \text{sign}(\nabla_{\mathbf{x}} \mathcal{L}(\mathbf{x} + \delta_t))\right)$$

$$\mathbf{x}_{adv} = \text{clip}(\mathbf{x} + \delta_T, 0, 1)$$

Where:
- $\delta_t$ is the perturbation at step $t$
- $\Pi_{\epsilon}$ projects onto the $\ell_\infty$ ball $\|\delta\|_\infty \leq \epsilon$, i.e. clips each element of $\delta$ to $[-\epsilon, \epsilon]$
- $\epsilon$ is the intensity (perturbation budget) from the configuration
- $\alpha = \frac{\epsilon \cdot k}{T}$ is the step size, where $k$ is the alpha multiplier and $T$ is the total iterations
- $\text{sign}(\cdot)$ is the element-wise sign function

### Embedding Normalization

Image embeddings are L2-normalized before computing cosine similarity (which reduces to dot product for unit vectors):

$$\hat{\mathbf{z}} = \frac{\mathbf{z}}{\|\mathbf{z}\|_2 + \varepsilon} \quad \Rightarrow \quad \text{sim}(\hat{\mathbf{a}}, \hat{\mathbf{b}}) = \hat{\mathbf{a}}^T \hat{\mathbf{b}} = \cos\theta$$

where $\varepsilon = 10^{-8}$ prevents division by zero (this $\varepsilon$ is unrelated to the PGD budget $\epsilon$).
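A NumPy version of the normalization shows both properties on hypothetical 2-D vectors. Unit vectors turn the dot product into $\cos\theta$, and the $\varepsilon$ term keeps an all-zero embedding from producing `NaN` (without it the division would be 0/0). `l2_normalize` is an illustrative name.

```python
import numpy as np


def l2_normalize(z, eps=1e-8):
    """z / (||z||_2 + eps) along the last axis."""
    return z / (np.linalg.norm(z, axis=-1, keepdims=True) + eps)


a = np.array([3.0, 4.0])
b = np.array([4.0, 3.0])
cos_ab = float(l2_normalize(a) @ l2_normalize(b))  # (3*4 + 4*3) / (5*5) = 0.96
zero = l2_normalize(np.zeros(2))                   # stays finite (all zeros) thanks to eps
print(cos_ab, zero)
```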

## 19. Next Steps

The ONNX models are now ready for integration into the Hope Tauri application.

### Files Location

All exported files are saved to: `/content/drive/MyDrive/hope-models/onnx/`

### Download Files

To download the files for local development:

In [None]:
from google.colab import files

print("Files available for download:")
for filename in expected_files:
    filepath = os.path.join(ONNX_DIR, filename)
    if os.path.exists(filepath):
        print(f"  {filepath}")