# Nightshade Data Poisoning Algorithm

Implements targeted misclassification using CLIP-guided adversarial perturbations in JAX.

## 1. Install Dependencies

Install JAX with CUDA 12 support and required libraries.

In [None]:
%pip install -q jax[cuda12] jaxlib

%pip install -q flax optax

%pip install -q pillow numpy tqdm

## 2. Mount Google Drive

Connect to Google Drive to load CLIP weights and save results.

In [None]:
import os
from google.colab import drive

drive.mount('/content/drive')

# Project folders on Drive: later cells read from checkpoints/ and write to exports/.
_hope_root = '/content/drive/MyDrive/hope-models'
folders = [os.path.join(_hope_root, sub) for sub in ('checkpoints', 'exports')]

for folder in folders:
    os.makedirs(folder, exist_ok=True)

print("Drive mounted and folders ready")

## 3. Import Libraries

Load JAX, Flax, and utility libraries.

In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit, random

import numpy as np
from PIL import Image
import pickle

from tqdm import tqdm
from flax import linen as nn
from functools import partial

# Report the JAX runtime: library version, visible devices and default backend.
for _label, _value in (
    ("JAX version", jax.__version__),
    ("Devices", jax.devices()),
    ("Backend", jax.default_backend()),
):
    print(f"{_label}: {_value}")

## 4. Helper Functions

Image loading, saving, and visualization utilities.

In [None]:
def load_image(path):
    """Load an image file as an RGB float array in [0, 1].

    Args:
        path: Path to any image file Pillow can open.

    Returns:
        jnp array of shape (H, W, 3) with values in [0.0, 1.0].
    """
    # The context manager closes the underlying file handle; convert() forces the
    # pixel data to be read before the file is closed.
    with Image.open(path) as source:
        rgb = source.convert('RGB')
    return jnp.array(rgb) / 255.0

def save_image(img_array, path):
    """Save a float image in [0, 1] to `path` as 8-bit RGB.

    Values are rounded to the nearest 8-bit level. A plain astype(uint8) truncates,
    which shifts every pixel down by up to 1/255. That is a large fraction of a
    perturbation bounded to a few /255.

    Args:
        img_array: Array-like of shape (H, W, 3), presumably float in [0, 1].
        path: Output path; Pillow picks the format from the extension.
            `quality=95` only matters for lossy formats such as JPEG.
    """
    scaled = np.rint(np.asarray(img_array, dtype=np.float64) * 255.0)
    pixels = np.clip(scaled, 0, 255).astype(np.uint8)
    Image.fromarray(pixels).save(path, quality=95)
    print(f"Saved: {path}")

def get_edge_mask(img):
    """Return a per-pixel weight in [0.3, 1.0] that grows with local gradient strength.

    Args:
        img: Array of shape (H, W, C).

    Returns:
        Array of shape (H, W, 1). Pixels at the weakest gradient get 0.3 and pixels at
        the strongest get ~1.0. The trailing axis lets it broadcast against (H, W, C).
    """
    luminance = jnp.mean(img, axis=-1)

    # Absolute forward differences along rows and along columns. Repeat the last
    # difference so both maps keep the full (H, W) shape.
    d_rows = jnp.pad(jnp.abs(jnp.diff(luminance, axis=0)), ((0, 1), (0, 0)), mode='edge')
    d_cols = jnp.pad(jnp.abs(jnp.diff(luminance, axis=1)), ((0, 0), (0, 1)), mode='edge')

    magnitude = jnp.sqrt(d_rows**2 + d_cols**2)
    lo = magnitude.min()
    hi = magnitude.max()
    normalized = (magnitude - lo) / (hi - lo + 1e-8)  # epsilon avoids 0/0 on flat images

    return jnp.expand_dims(0.3 + 0.7 * normalized, axis=-1)

print("Helper functions defined")

## 5. Load CLIP Data

Load pre-extracted CLIP weights and Nightshade target embeddings from `1_clip_to_jax.ipynb`.

In [None]:
import time

data_path = '/content/drive/MyDrive/hope-models/checkpoints/clip_data.pkl'

# The Drive mount can drop. On OSError, force a remount and retry a few times.
max_retries = 3
for attempt in range(max_retries):
    try:
        # NOTE: pickle can execute arbitrary code while loading. Only load a
        # checkpoint you produced yourself (here: the output of 1_clip_to_jax.ipynb).
        with open(data_path, 'rb') as f:
            clip_data = pickle.load(f)
        break
    except OSError:
        if attempt == max_retries - 1:
            raise  # bare raise keeps the original traceback
        print(f"Drive connection lost. Reconnecting (attempt {attempt + 1}/{max_retries})...")
        from google.colab import drive
        drive.mount('/content/drive', force_remount=True)
        time.sleep(2)

CLIP_MEAN = clip_data['clip_mean']
CLIP_STD = clip_data['clip_std']
CLIP_INPUT_SIZE = clip_data['clip_input_size']

# Optional key: it may be missing or None. The parameters cell substitutes
# hard-coded prompts in that case, so don't fail here with a KeyError.
NIGHTSHADE_TARGET_PROMPTS = clip_data.get('nightshade_target_prompts')
nightshade_target_embeddings_raw = clip_data['nightshade_target_embeddings_base']
nightshade_generic_emb_raw = clip_data['nightshade_generic_emb_base']

target_embeddings = {name: jnp.array(emb) for name, emb in nightshade_target_embeddings_raw.items()}
generic_emb = jnp.array(nightshade_generic_emb_raw)

print("Loaded CLIP data")
print(f"Mean: {CLIP_MEAN}")
print(f"Std: {CLIP_STD}")
print(f"Input size: {CLIP_INPUT_SIZE}")
print("\nNightshade target embeddings loaded:")
for name, emb in target_embeddings.items():
    print(f"  {name}: {emb.shape}")
print(f"  Generic (clear photo): {generic_emb.shape}")

## 6. Nightshade Algorithm Parameters

**Tuned for invisibility + maximum AI poisoning effect**

Key principles:
- **Low intensity (0.02-0.04)**: Imperceptible to humans
- **High iterations (400-600)**: Subtle but effective changes
- **High perceptual weight (1.5-2.0)**: Preserve visual appearance
- **Targeted misclassification**: Make AI learn wrong class

In [None]:
# Target prompts are used here only as labels for printing/export. The embeddings
# themselves were loaded from clip_data earlier.
NIGHTSHADE_TARGET_PROMPTS = clip_data.get('nightshade_target_prompts')
if NIGHTSHADE_TARGET_PROMPTS is None:
    # The checkpoint does not carry the prompts (key missing or None).
    NIGHTSHADE_TARGET_PROMPTS = {
        'Dog': "a photo of a dog",
        'Cat': "a photo of a cat",
        'Car': "a photo of a car",
        'Landscape': "a landscape photograph",
        'Person': "a photo of a person",
        'Building': "a photo of a building",
        'Food': "a photo of food",
        'Abstract': "abstract digital art",
    }
    print("Note: Using hardcoded target prompts (re-run 1_clip_to_jax.ipynb to save them)")

NIGHTSHADE_GENERIC_PROMPT = clip_data.get('nightshade_generic_prompt', "a clear photograph")

# PGD step size is epsilon / iterations * ALPHA_MULTIPLIER (see nightshade_protect_image).
# NOTE(review): LEARNING_RATE is only printed and exported in this notebook; the
# PGD step does not read it.
LEARNING_RATE = 0.01
ALPHA_MULTIPLIER = 2.5

# intensity         : L-inf bound (epsilon) per channel, in [0, 1] pixel units
# iterations        : number of PGD steps
# perceptual_weight : weight of the edge-aware MSE penalty that limits visibility
PARAMETER_PRESETS = {
    'subtle': {
        'intensity': 0.02,
        'iterations': 600,
        'perceptual_weight': 2.0,
        'description': 'Maximum invisibility, slower but more effective'
    },
    'balanced': {
        'intensity': 0.03,
        'iterations': 500,
        'perceptual_weight': 1.5,
        'description': 'Balance between invisibility and speed'
    },
    'strong': {
        'intensity': 0.04,
        'iterations': 400,
        'perceptual_weight': 1.0,
        'description': 'Stronger poisoning, slightly more visible'
    }
}

ACTIVE_PRESET = 'balanced'

# The active preset is the single source of truth for these module-level defaults.
preset = PARAMETER_PRESETS[ACTIVE_PRESET]
INTENSITY = preset['intensity']
ITERATIONS = preset['iterations']
PERCEPTUAL_WEIGHT = preset['perceptual_weight']

print(f"Active preset: {ACTIVE_PRESET}")
print(f"Description: {preset['description']}")
print("\nParameters:")
print(f"  Intensity: {INTENSITY} ({int(INTENSITY * 255)}/255 per channel)")
print(f"  Iterations: {ITERATIONS}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Perceptual weight: {PERCEPTUAL_WEIGHT}")
print(f"  Alpha multiplier: {ALPHA_MULTIPLIER}")
print(f"\nAvailable parameter presets: {list(PARAMETER_PRESETS.keys())}")
print(f"Available target classes: {list(NIGHTSHADE_TARGET_PROMPTS.keys())}")

## 7. Define JAX/Flax ViT Architecture

Reuses the same CLIP Vision Transformer definition as the previous notebooks.

In [None]:
class CLIPAttention(nn.Module):
    """Multi-head self-attention with a single fused QKV projection.

    Sub-module names `in_proj` / `out_proj` match the converted PyTorch
    `nn.MultiheadAttention` parameters. No attention mask is applied.
    Expects input of shape (batch, tokens, d_model).
    """
    num_heads: int

    @nn.compact
    def __call__(self, x):
        batch, tokens, d_model = x.shape

        fused = nn.Dense(3 * d_model, name="in_proj")(x)
        # (B, T, D) -> (B, heads, T, head_dim) for each of q, k, v.
        queries, keys, values = (
            part.reshape(batch, tokens, self.num_heads, -1).transpose(0, 2, 1, 3)
            for part in jnp.split(fused, 3, axis=-1)
        )

        head_dim = d_model // self.num_heads
        logits = (queries @ jnp.swapaxes(keys, -1, -2)) * head_dim ** -0.5
        weights = jax.nn.softmax(logits, axis=-1)

        # Merge heads back: (B, heads, T, head_dim) -> (B, T, D).
        merged = (weights @ values).transpose(0, 2, 1, 3).reshape(batch, tokens, -1)
        return nn.Dense(d_model, name="out_proj")(merged)

class MLP(nn.Module):
    """Transformer feed-forward block: width -> 4 * width -> width."""
    width: int

    @nn.compact
    def __call__(self, x):
        # NOTE(review): flax `nn.gelu` defaults to the tanh approximation. OpenAI CLIP
        # checkpoints use QuickGELU (x * sigmoid(1.702 * x)), and PyTorch nn.GELU is
        # exact by default. Confirm which activation the source weights were trained with.
        expanded = nn.Dense(4 * self.width, name="c_fc")(x)
        return nn.Dense(self.width, name="c_proj")(nn.gelu(expanded))

class ResidualAttentionBlock(nn.Module):
    """Pre-LayerNorm transformer block: x + attn(ln_1(x)), then x + mlp(ln_2(x))."""
    num_heads: int
    width: int

    @nn.compact
    def __call__(self, x):
        # The weights come from PyTorch, whose nn.LayerNorm uses eps=1e-5. Flax
        # defaults to 1e-6, so set it explicitly to reproduce the source model.
        x = x + CLIPAttention(self.num_heads, name="attn")(
            nn.LayerNorm(epsilon=1e-5, name="ln_1")(x)
        )
        x = x + MLP(self.width, name="mlp")(
            nn.LayerNorm(epsilon=1e-5, name="ln_2")(x)
        )

        return x

class VisionTransformer(nn.Module):
    """CLIP image tower: patch embedding, [CLS] token, transformer, projection.

    Input is NHWC (Flax convention), presumably already resized and normalized by
    the caller. Output is the un-normalized (batch, output_dim) embedding taken
    from the [CLS] token.
    """
    width: int = 768
    layers: int = 12
    heads: int = 12
    patch_size: int = 32
    output_dim: int = 512

    @nn.compact
    def __call__(self, x):
        # Non-overlapping patch embedding: (B, H, W, 3) -> (B, H/p, W/p, width).
        x = nn.Conv(self.width, (self.patch_size, self.patch_size),
                    strides=(self.patch_size, self.patch_size),
                    padding='VALID', use_bias=False, name="conv1")(x)
        x = x.reshape(x.shape[0], -1, x.shape[-1])  # (B, num_patches, width)

        # Positional table length = patches + [CLS]. This is 50 for 224px input with
        # 32px patches; deriving it avoids a hard-coded size in the init path.
        num_tokens = x.shape[1] + 1
        cls_token = self.param('class_embedding', lambda *args: jnp.zeros((self.width,)))
        pos_embed = self.param('positional_embedding', lambda *args: jnp.zeros((num_tokens, self.width)))

        x = jnp.concatenate([jnp.broadcast_to(cls_token, (x.shape[0], 1, self.width)), x], axis=1)
        x = x + pos_embed
        # PyTorch LayerNorm eps (1e-5) rather than Flax's default 1e-6.
        x = nn.LayerNorm(epsilon=1e-5, name="ln_pre")(x)
        for i in range(self.layers):
            x = ResidualAttentionBlock(self.heads, self.width, name=f"transformer_resblocks_{i}")(x)
        x = nn.LayerNorm(epsilon=1e-5, name="ln_post")(x[:, 0, :])

        return nn.Dense(self.output_dim, use_bias=False, name="proj")(x)

print("ViT architecture defined")

## 8. Weight Conversion Function

Convert PyTorch CLIP weights to Flax format (reused from previous notebooks).

In [None]:
def convert_clip_weights(pt_weights):
    """Convert a PyTorch CLIP visual-tower state dict into Flax `VisionTransformer` variables.

    Args:
        pt_weights: Mapping of PyTorch parameter names to array-likes. Keys may carry
            a ``visual.`` prefix (full CLIP state dict) or none (visual-only dict).

    Returns:
        ``{'params': nested}``, where ``nested`` mirrors the Flax module tree
        (e.g. ``params['transformer_resblocks_0']['attn']['in_proj']['kernel']``).

    Raises:
        ValueError: if the ``proj`` matrix has an unexpected shape.
    """
    prefix = 'visual.'
    # If any key is ``visual.``-prefixed, the unprefixed keys presumably belong to the
    # text tower or top-level CLIP params. Names like 'positional_embedding' and
    # 'transformer.resblocks.*' would otherwise collide with, and overwrite, the
    # visual weights once the prefix is stripped.
    has_visual_prefix = any(k.startswith(prefix) for k in pt_weights)

    flax_params = {}

    for key, value in pt_weights.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        elif has_visual_prefix:
            continue

        value = jnp.array(value)

        if key == 'class_embedding':
            flax_params['class_embedding'] = value
            continue
        elif key == 'positional_embedding':
            flax_params['positional_embedding'] = value
            continue
        elif key == 'proj':
            # Flax Dense kernel is (in_features, out_features) = (width, output_dim).
            if value.shape == (512, 768):
                flax_params['proj/kernel'] = value.T
            elif value.shape == (768, 512):
                flax_params['proj/kernel'] = value
            else:
                raise ValueError(f"Unexpected proj shape: {value.shape}")
            continue
        elif key == 'conv1.weight':
            # PyTorch OIHW -> Flax HWIO.
            flax_params['conv1/kernel'] = jnp.transpose(value, (2, 3, 1, 0))
            continue

        key = key.replace('transformer.resblocks.', 'transformer_resblocks_').replace('.', '/')

        if 'in_proj_weight' in key:
            flax_params[key.replace('in_proj_weight', 'in_proj/kernel')] = value.T
        elif 'in_proj_bias' in key:
            flax_params[key.replace('in_proj_bias', 'in_proj/bias')] = value
        elif 'weight' in key and 'ln' in key:
            flax_params[key.replace('weight', 'scale')] = value  # LayerNorm gamma
        elif 'bias' in key and 'ln' in key:
            flax_params[key] = value
        elif 'weight' in key:
            # PyTorch Linear stores (out, in); Flax Dense kernel is (in, out).
            flax_params[key.replace('weight', 'kernel')] = value.T
        else:
            flax_params[key] = value

    # Expand 'a/b/c' flat keys into nested dicts.
    nested_params = {}
    for key, value in flax_params.items():
        parts = key.split('/')
        curr = nested_params
        for p in parts[:-1]:
            curr = curr.setdefault(p, {})
        curr[parts[-1]] = value

    return {'params': nested_params}

print("Weight conversion function defined")

## 9. Create CLIP Encoder

Reusable encoder for image → feature extraction.

In [None]:
class CLIPEncoder:
    """Bundle the Flax ViT, its converted weights and CLIP preprocessing.

    `encode_image(img)` maps a float (H, W, 3) image to a unit-norm embedding.
    """

    def __init__(self, weights):
        self.model_vit = VisionTransformer()
        self.variables = convert_clip_weights(weights)
        # Use the checkpoint's normalization constants rather than duplicating the
        # literals. reshape(-1) accepts (3,), (3, 1, 1) or plain lists.
        self._mean = jnp.asarray(CLIP_MEAN, dtype=jnp.float32).reshape(-1)
        self._std = jnp.asarray(CLIP_STD, dtype=jnp.float32).reshape(-1)

        print("CLIP encoder initialized")

    @partial(jit, static_argnums=(0,))
    def _encode(self, variables, img):
        """Jitted core. `variables` is a traced argument, so the weights are not
        embedded in the compiled program as constants."""
        img_resized = jax.image.resize(img, (CLIP_INPUT_SIZE, CLIP_INPUT_SIZE, 3), method='bilinear')
        normalized = (img_resized - self._mean) / self._std
        features = self.model_vit.apply(variables, normalized[None, ...])[0]

        return features / jnp.linalg.norm(features)

    def encode_image(self, img):
        """Return the L2-normalized CLIP embedding of a float (H, W, 3) image."""
        return self._encode(self.variables, img)

clip_encoder = CLIPEncoder(clip_data['base_weights'])
print("Global CLIP encoder ready")

## 10. Nightshade Poisoning Loss Function

**Goal:** Minimize similarity to real class, maximize similarity to target (wrong) class.

**Example:**
- Real image: Dog
- Loss pushes AI to see it as: Cat

**Formula:**

$$\mathcal{L}_{\text{nightshade}} = \text{sim}(\text{img}, \text{source}) - \text{sim}(\text{img}, \text{target})$$

Where:
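- `sim(img, source)` = Cosine similarity to the source embedding. Conceptually this is the real class (e.g., "dog"); the pipeline below passes `'Generic'`, which uses the generic-prompt embedding (`generic_emb`).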
- `sim(img, target)` = Cosine similarity to target class (e.g., "cat")

Minimizing this makes the image:
- **Less dog-like** (in CLIP space)
- **More cat-like** (in CLIP space)

**Optimization:**
- Gradient descent pushes image features toward target class
- Perceptual loss keeps changes invisible to humans

In [None]:
@jit
def compute_image_features(img):
    """Return the unit-norm CLIP embedding of `img` using the global `clip_encoder`."""
    encoder = clip_encoder
    return encoder.encode_image(img)

@jit
def nightshade_poisoning_loss(img, source_emb, target_emb):
    """Standalone Nightshade objective: dot(img, source) - dot(img, target).

    The image embedding is unit-norm. The dot products are cosine similarities only
    if the text embeddings are unit-norm too (not checked here). Minimizing the value
    moves the image away from `source` and toward `target`.
    Note: the PGD step function defines its own inline copy of this loss.
    """
    feats = compute_image_features(img)
    toward_source, toward_target = (jnp.dot(feats, emb) for emb in (source_emb, target_emb))
    return toward_source - toward_target

print("Nightshade poisoning loss function defined")

## 11. Compiled PGD Step Function

In [None]:
@jit
def nightshade_step_compiled(
    current_img,
    original_img,
    source_emb,
    target_emb,
    variables,
    edge_weight,
    epsilon,
    alpha,
    perceptual_weight
):
    """One projected signed-gradient (PGD) step of the Nightshade objective.

    Args:
        current_img: Current perturbed image, float (H, W, 3).
        original_img: Clean image; the perturbation is bounded around it.
        source_emb / target_emb: Embeddings to move away from / toward.
        variables: Flax variables for `VisionTransformer`.
        edge_weight: Per-pixel weight, presumably (H, W, 1) from `get_edge_mask`.
            It scales both the step size and the perceptual penalty.
        epsilon: L-inf bound on |image - original_img|.
        alpha: Maximum per-pixel step size (reached where edge_weight == 1).
        perceptual_weight: Weight of the edge-aware MSE penalty.

    Returns:
        (next_img, loss_val). `next_img` is projected onto the epsilon-ball and [0, 1].
        `loss_val` is the loss evaluated at `current_img`, i.e. before the step.
    """
    model = VisionTransformer()
    # Checkpoint normalization constants; reshape(-1) accepts (3,) or (3, 1, 1).
    mean = jnp.asarray(CLIP_MEAN, dtype=jnp.float32).reshape(-1)
    std = jnp.asarray(CLIP_STD, dtype=jnp.float32).reshape(-1)

    def loss_fn(x):
        resized = jax.image.resize(x, (CLIP_INPUT_SIZE, CLIP_INPUT_SIZE, 3), method='bilinear')
        normalized = (resized - mean) / std
        features = model.apply(variables, normalized[None, ...])[0]
        features = features / jnp.linalg.norm(features)

        poison_loss = jnp.dot(features, source_emb) - jnp.dot(features, target_emb)
        # Smaller edge_weight (flat regions) gives a larger penalty (up to 1.5 - min weight).
        perceptual_loss = jnp.mean((x - original_img)**2 * (1.5 - edge_weight))

        return poison_loss + perceptual_weight * perceptual_loss * 100

    loss_val, grads = jax.value_and_grad(loss_fn)(current_img)
    # Edge weights from get_edge_mask are positive, so sign(grads * edge_weight) equals
    # sign(grads) and the mask would have no effect. Scale the step by the mask
    # outside sign() so that flat regions move less than textured ones.
    next_img = current_img - alpha * edge_weight * jnp.sign(grads)
    delta = jnp.clip(next_img - original_img, -epsilon, epsilon)

    return jnp.clip(original_img + delta, 0.0, 1.0), loss_val

print("Compiled PGD step function defined")

## 12. Nightshade Protection Algorithm

PGD (Projected Gradient Descent) with:
- **Targeted misclassification loss**: Push to wrong class
- **Perceptual loss**: Keep changes invisible to humans
- **Edge-aware perturbation**: Concentrate changes in textured areas

**Process:**
1. Compute the loss gradient with respect to the image
2. Take a small signed step against the gradient (descending the loss, which misleads the model)
3. Clip perturbation to ±epsilon (intensity limit)
4. Repeat for many iterations

In [None]:
def nightshade_protect_image(
    img,
    source_class_name,
    target_class_name,
    intensity,
    iterations,
    perceptual_weight=1.5
):
    """Run Nightshade PGD on a single image and return the perturbed copy.

    Args:
        img: Float image, presumably (H, W, 3) in [0, 1] as returned by `load_image`.
        source_class_name: Class to move away from. Names without a stored embedding
            (e.g. 'Generic') fall back to the generic-prompt embedding.
        target_class_name: Key of `target_embeddings` to move toward.
        intensity: L-inf bound (epsilon) on the per-channel perturbation.
        iterations: Number of PGD steps; must be >= 1.
        perceptual_weight: Weight of the edge-aware MSE penalty.

    Returns:
        The perturbed image, same shape as `img`, clipped to [0, 1].

    Raises:
        ValueError: if `iterations` < 1.
        KeyError: if `target_class_name` has no loaded embedding.
    """
    if iterations < 1:
        raise ValueError(f"iterations must be >= 1, got {iterations}")
    if target_class_name not in target_embeddings:
        raise KeyError(
            f"Unknown target class {target_class_name!r}; "
            f"available: {list(target_embeddings)}"
        )

    if source_class_name in target_embeddings:
        source_emb = target_embeddings[source_class_name]
        source_label = source_class_name
    else:
        source_emb = generic_emb
        source_label = "Generic image"
    target_emb = target_embeddings[target_class_name]

    epsilon = intensity
    # Sized so that `iterations` full-size steps add up to ALPHA_MULTIPLIER * epsilon.
    alpha = epsilon / iterations * ALPHA_MULTIPLIER

    print("\nApplying Nightshade poisoning...")
    print(f"Source (real): {source_label}")
    print(f"Target (AI sees): {target_class_name}")
    print(f"Epsilon: {epsilon:.4f} ({int(epsilon * 255)}/255 per channel)")
    print(f"Alpha: {alpha:.6f}")
    print(f"Iterations: {iterations}")
    print(f"Perceptual weight: {perceptual_weight}")
    print("\nComputing edge mask...")

    edge_weight = get_edge_mask(img)

    variables = clip_encoder.variables

    print("Compiling JIT (first run only)...")

    # The first step is run separately so the JIT compile time is reported on its own.
    current_img = img
    current_img, initial_loss = nightshade_step_compiled(
        current_img, img, source_emb, target_emb,
        variables, edge_weight, epsilon, alpha, perceptual_weight
    )
    print("JIT compilation complete!")
    print(f"Initial loss: {initial_loss:.4f}")

    losses = [float(initial_loss)]
    for _ in tqdm(range(1, iterations), desc="Poisoning", unit="iter"):
        current_img, loss_val = nightshade_step_compiled(
            current_img, img, source_emb, target_emb,
            variables, edge_weight, epsilon, alpha, perceptual_weight
        )
        losses.append(float(loss_val))

    perturbation = current_img - img
    max_change = float(jnp.max(jnp.abs(perturbation)))
    avg_change = float(jnp.mean(jnp.abs(perturbation)))

    print("\nNightshade poisoning complete!")
    print(f"Final loss: {losses[-1]:.4f}")
    print(f"Loss improvement: {losses[0] - losses[-1]:.4f}")
    print(f"Max pixel change: {max_change:.4f} ({int(max_change * 255)}/255)")
    print(f"Avg pixel change: {avg_change:.4f} ({int(avg_change * 255)}/255)")
    print(f"\nThis image will poison AI models to classify it as '{target_class_name}'")

    return current_img

print("Nightshade protection function defined")

## 13. Upload Test Image

Upload an image you want to poison.

**Example use cases:**
- Upload dog photo, poison as "Cat" → AI learns dogs are cats
- Upload artwork, poison as "Abstract" → AI misclassifies your style

In [None]:
from google.colab import files
import matplotlib.pyplot as plt

print("Upload the image to poison:")
uploaded = files.upload()

# Stays None unless an uploaded file loads successfully.
original_img = None
if not uploaded:
    print("No file uploaded")
else:
    test_image_path = next(iter(uploaded))  # only the first uploaded file is used
    try:
        original_img = load_image(test_image_path)
        print(f"Uploaded: {test_image_path}")
        print(f"Size: {original_img.shape}")
        print(f"Range: [{original_img.min():.3f}, {original_img.max():.3f}]")
    except Exception as e:
        print(f"Error loading image: {e}")
        original_img = None

## 14. Run Nightshade Poisoning

Choose target class from: `Dog`, `Cat`, `Car`, `Landscape`, `Person`, `Building`, `Food`, `Abstract`

**Example scenarios:**
- Real: Dog photo → Target: `Cat` → AI learns dog as cat
- Real: Person photo → Target: `Car` → AI learns person as car
- Real: Landscape → Target: `Abstract` → AI misclassifies landscapes

In [None]:
# Reset first, so a failed or skipped run cannot leave a stale result from an
# earlier run for the display cell to show.
poisoned_img = None

if original_img is None:
    print("No image loaded. Run Step 13 first.")
else:
    # Options: 'Dog', 'Cat', 'Car', 'Landscape', 'Person', 'Building', 'Food', 'Abstract'
    TARGET_CLASS = 'Cat'

    print(f"Available target classes: {list(target_embeddings.keys())}")
    print(f"Selected target: {TARGET_CLASS}")

    try:
        poisoned_img = nightshade_protect_image(
            original_img,
            source_class_name='Generic',
            target_class_name=TARGET_CLASS,
            intensity=INTENSITY,
            iterations=ITERATIONS,
            perceptual_weight=PERCEPTUAL_WEIGHT
        )

        output_path = f'/content/drive/MyDrive/hope-models/exports/nightshade_{TARGET_CLASS.lower()}.jpg'
        save_image(poisoned_img, output_path)

    except Exception as e:
        # Notebook boundary: report the failure with a traceback instead of crashing the cell.
        print(f"Poisoning failed: {e}")
        import traceback
        traceback.print_exc()

## 15. Display Results

Compare original and poisoned images side-by-side.

**What you should see:**
- Images look nearly identical to human eyes
- Max pixel change: 5-10 per channel (out of 255)
- AI will classify poisoned image as target class

In [None]:
# globals().get tolerates cells being run out of order (names not yet defined).
if globals().get('original_img') is None:
    print("No original image. Run Step 13 first.")
elif globals().get('poisoned_img') is None:
    print("No poisoned image. Run Step 14 first.")
else:
    original_np = np.asarray(original_img)
    poisoned_np = np.asarray(poisoned_img)
    abs_diff = np.abs(poisoned_np - original_np)

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    axes[0].imshow(original_np)
    axes[0].set_title('Original Image', fontsize=14)
    axes[0].axis('off')

    axes[1].imshow(poisoned_np)
    axes[1].set_title(f'Poisoned (Target: {TARGET_CLASS})', fontsize=14)
    axes[1].axis('off')

    # Amplify x10 for visibility. Clip to imshow's valid float range so large
    # intensities don't trigger matplotlib's clipping warning.
    axes[2].imshow(np.clip(abs_diff * 10, 0.0, 1.0))
    axes[2].set_title('Difference (x10 amplified)', fontsize=14)
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()

    max_diff = float(np.max(abs_diff))
    avg_diff = float(np.mean(abs_diff))
    print("\nChange statistics:")
    print(f"  Max: {max_diff:.4f} ({int(max_diff * 255)}/255)")
    print(f"  Avg: {avg_diff:.4f} ({int(avg_diff * 255)}/255)")
    print(f"  Human visible: {'Yes' if max_diff > 0.05 else 'No (< 13/255)'}")

## 16. Download Poisoned Image

Download the poisoned image for use in datasets.

**Warning:** This image will cause AI models trained on it to learn incorrect associations.

In [None]:
from google.colab import files

if 'TARGET_CLASS' in globals():
    # Same path the run cell writes to.
    output_path = f'/content/drive/MyDrive/hope-models/exports/nightshade_{TARGET_CLASS.lower()}.jpg'

    if not os.path.exists(output_path):
        print("File not found. Run Step 14 first.")
    else:
        print(f"Downloading: {output_path}")
        files.download(output_path)
else:
    print("TARGET_CLASS not defined. Run Step 14 first.")

## 17. Batch Protection Function

Poison multiple images at once with same target class.

In [None]:
def nightshade_protect_batch(
    image_paths,
    output_dir,
    target_class_name,
    intensity=0.03,
    iterations=500,
    perceptual_weight=1.5
):
    """Poison every image in `image_paths` toward the same target class.

    Each output is written to `output_dir` as
    ``<name>_nightshade_<target>.<original extension>``.

    Args:
        image_paths: Iterable of input image paths.
        output_dir: Destination directory; created if missing.
        target_class_name: Key of `target_embeddings` to move toward.
        intensity, iterations, perceptual_weight: Forwarded to `nightshade_protect_image`.

    Returns:
        List of written output paths, in input order (previously None).
    """
    os.makedirs(output_dir, exist_ok=True)
    separator = '=' * 60
    output_paths = []

    for img_path in image_paths:
        print(f"\n{separator}")
        print(f"Processing: {img_path}")

        img = load_image(img_path)

        poisoned = nightshade_protect_image(
            img,
            source_class_name='Generic',
            target_class_name=target_class_name,
            intensity=intensity,
            iterations=iterations,
            perceptual_weight=perceptual_weight
        )

        name, ext = os.path.splitext(os.path.basename(img_path))
        output_path = os.path.join(output_dir, f"{name}_nightshade_{target_class_name.lower()}{ext}")
        save_image(poisoned, output_path)
        output_paths.append(output_path)

    print(f"\n{separator}")
    print("Batch poisoning complete!")

    return output_paths

print("Batch protection function defined")

## 18. Export Model Parameters

Save model data for ONNX/TFLite conversion.

In [None]:
def save_nightshade_model_for_export():
    """Pickle everything the export notebook needs into the Drive exports folder.

    The payload bundles the raw ViT weights, the target/generic embeddings (as NumPy
    arrays), the prompts, the CLIP preprocessing constants, the hyper-parameter
    presets and a static architecture/algorithm description.

    Returns:
        Path of the written pickle file.
    """
    embedding_presets = {
        'target_embeddings': {label: np.array(vec) for label, vec in target_embeddings.items()},
        'generic_emb': np.array(generic_emb),
    }
    preprocessing = dict(
        clip_mean=CLIP_MEAN,
        clip_std=CLIP_STD,
        clip_input_size=CLIP_INPUT_SIZE,
    )
    defaults = dict(
        intensity=INTENSITY,
        iterations=ITERATIONS,
        learning_rate=LEARNING_RATE,
        perceptual_weight=PERCEPTUAL_WEIGHT,
        alpha_multiplier=ALPHA_MULTIPLIER,
    )
    architecture = dict(
        model_type='ViT-B/32',
        width=768,
        layers=12,
        heads=12,
        patch_size=32,
        output_dim=512,
    )
    algorithm = dict(
        name='Nightshade',
        type='data_poisoning',
        description='Targeted misclassification for AI training data poisoning',
        available_targets=list(NIGHTSHADE_TARGET_PROMPTS.keys()),
    )

    export_data = {
        'vit_weights': clip_data['base_weights'],
        'presets': embedding_presets,
        'target_prompts': NIGHTSHADE_TARGET_PROMPTS,
        'generic_prompt': NIGHTSHADE_GENERIC_PROMPT,
        'constants': preprocessing,
        'parameter_presets': PARAMETER_PRESETS,
        'default_params': defaults,
        'architecture': architecture,
        'algorithm': algorithm,
    }

    export_path = '/content/drive/MyDrive/hope-models/exports/nightshade_model_export.pkl'
    with open(export_path, 'wb') as fh:
        pickle.dump(export_data, fh)

    size_mb = os.path.getsize(export_path) / (1024**2)
    for line in (
        "Nightshade model export data saved",
        f"Path: {export_path}",
        f"Size: {size_mb:.2f} MB",
        f"Available targets: {list(NIGHTSHADE_TARGET_PROMPTS.keys())}",
        f"Parameter presets: {list(PARAMETER_PRESETS.keys())}",
    ):
        print(line)

    return export_path

export_path = save_nightshade_model_for_export()

## 19. Model Information

Display Nightshade model capabilities and parameters.

In [None]:
def print_nightshade_model_info():
    """Print a summary of the loaded encoder, embeddings, defaults and presets.

    Counts (parameters, target classes, perturbation bound) are derived from the
    loaded state rather than hard-coded, so they stay correct when the checkpoint
    or the active preset changes.
    """
    rule = "=" * 60
    num_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(clip_encoder.variables))

    print(rule)
    print("NIGHTSHADE POISONING MODEL INFO")
    print(rule)

    print("\nArchitecture:")
    print("  Model: ViT-B/32")
    print(f"  Parameters: {num_params / 1e6:.1f}M (loaded encoder weights)")
    print(f"  Input size: {CLIP_INPUT_SIZE}x{CLIP_INPUT_SIZE}")
    print("  Feature dim: 512")

    print("\nTarget Embeddings:")
    for name, emb in target_embeddings.items():
        print(f"  {name}: {emb.shape}")
    print(f"  Generic (source): {generic_emb.shape}")

    print("\nDefault Parameters:")
    print(f"  Intensity: {INTENSITY} ({int(INTENSITY * 255)}/255)")
    print(f"  Iterations: {ITERATIONS}")
    print(f"  Learning rate: {LEARNING_RATE}")
    print(f"  Perceptual weight: {PERCEPTUAL_WEIGHT}")

    print("\nParameter Presets:")
    for name, cfg in PARAMETER_PRESETS.items():
        marker = ">" if name == ACTIVE_PRESET else " "
        print(f"  {marker} {name}: {cfg['description']}")

    print("\nCapabilities:")
    print("  Single image poisoning")
    print("  Batch processing")
    print(f"  {len(target_embeddings)} target classes")
    # The L-inf bound of the active preset (e.g. 'strong' allows ~10.2/255).
    print(f"  Max perturbation: {INTENSITY * 255:.1f}/255 per channel")
    print("  GPU acceleration (JAX)")
    print("  JIT compilation")

    print("\nExport Ready:")
    print("  ONNX export compatible")
    print("  TFLite export compatible")
    print("  Preset embeddings saved")

    print("\n" + rule)

print_nightshade_model_info()

## Nightshade Algorithm Complete

**Implemented:**
- JAX-based data poisoning
- CLIP-guided targeted misclassification
- PGD optimization with perceptual loss
- Edge-aware adaptive perturbation
- 8 target classes

**Parameters (Optimized for Invisibility + Effectiveness):**
- Intensity: 0.03 (7/255 per channel - imperceptible)
- Iterations: 500 (subtle changes)
- Perceptual weight: 1.5 (strong invisibility)

**Target Classes:**
- Dog, Cat, Car, Landscape, Person, Building, Food, Abstract

**How it works:**
1. Human sees: **Normal image**
2. AI learns: **Wrong class** (e.g., dog → cat)
3. Result: **Poisoned training data**

**Saved to:** `/content/drive/MyDrive/hope-models/exports/`

**Exported:** `nightshade_model_export.pkl`

**Next:** Run `5_export_onnx.ipynb` to convert to ONNX/TFLite for Hope Tauri app