# Adversarial Noise Protection Algorithm

Implements CLIP-guided adversarial perturbations in JAX

## 1. Install Dependencies

In [None]:
%pip install -q jax[cuda12] jaxlib
%pip install -q flax optax
%pip install -q pillow numpy tqdm

## 2. Mount Google Drive

In [None]:
import os
from google.colab import drive

# Mount Google Drive so checkpoints and exports are read from / written to Drive.
drive.mount('/content/drive')

# Working directories used by the rest of the notebook.
folders = [
    '/content/drive/MyDrive/hope-models/' + leaf
    for leaf in ('checkpoints', 'exports')
]

# exist_ok=True makes re-running this cell harmless.
for folder in folders:
    os.makedirs(folder, exist_ok=True)

print("Drive mounted and folders ready")

## 3. Import Libraries

In [None]:
import jax
import jax.numpy as jnp

from jax import grad, jit, random
import numpy as np

from PIL import Image
import pickle
from tqdm import tqdm

# Report the installed JAX version, the visible devices and the active backend.
print(
    f"JAX version: {jax.__version__}",
    f"Devices: {jax.devices()}",
    f"Backend: {jax.default_backend()}",
    sep="\n",
)

## 4. Helper Functions

In [None]:
def load_image(path):
    """Read an image file and return it as a JAX array scaled to [0, 1].

    The image is converted to RGB first, so the result always has three
    channels in its last axis.
    """
    rgb = Image.open(path).convert('RGB')
    pixels = jnp.array(rgb)
    return pixels / 255.0

def save_image(img_array, path):
    """Write a [0, 1]-scaled image array to ``path`` as an 8-bit image.

    Args:
        img_array: Array-like image with values nominally in [0, 1].
        path: Destination file path; the format is chosen by PIL from the
            file extension.

    Values are scaled to 0..255, clipped, and rounded to the nearest integer.
    A plain ``astype(np.uint8)`` truncates toward zero, which shifts every
    pixel down by half a level on average (e.g. 254.7 would become 254).
    """
    scaled = np.clip(np.asarray(img_array) * 255, 0, 255)
    pixels = np.rint(scaled).astype(np.uint8)

    # NOTE(review): `quality` only affects lossy formats such as JPEG; lossy
    # compression may alter a small perturbation. Confirm the intended format.
    Image.fromarray(pixels).save(path, quality=95)

    print(f"Saved: {path}")

## 5. Load CLIP Data

Load pre-extracted weights and embeddings

In [None]:
# Pickle file holding the pre-extracted CLIP weights, constants and embeddings.
data_path = '/content/drive/MyDrive/hope-models/checkpoints/clip_data.pkl'

# NOTE: unpickling can execute arbitrary code, so only load a file you created
# or otherwise trust.
with open(data_path, 'rb') as f:
    clip_data = pickle.load(f)

# Preprocessing constants stored alongside the weights.
CLIP_MEAN = clip_data['clip_mean']
CLIP_STD = clip_data['clip_std']
CLIP_INPUT_SIZE = clip_data['clip_input_size']

# Reference embeddings used by the semantic loss.
chaos_embeddings = jnp.array(clip_data['chaos_embeddings_base'])
normal_embeddings = jnp.array(clip_data['normal_embeddings_base'])

print(
    "Loaded CLIP data",
    f"Mean: {CLIP_MEAN}",
    f"Std: {CLIP_STD}",
    f"Input size: {CLIP_INPUT_SIZE}",
    f"Chaos embeddings: {chaos_embeddings.shape}",
    f"Normal embeddings: {normal_embeddings.shape}",
    sep="\n",
)

## 6. Base Algorithm Parameters

Based on the original Hope implementation

In [None]:
# Per-channel perturbation bound on the [0, 1] pixel scale (passed to
# protect_image as `intensity`).
INTENSITY = 0.06
# Number of PGD steps.
ITERATIONS = 200
# Only reported in this notebook; protect_image derives its step size from
# the intensity and the iteration count instead.
LEARNING_RATE = 0.01
# Weight of the perceptual penalty in the combined loss.
PERCEPTUAL_WEIGHT = 0.5

print(
    f"Intensity: {INTENSITY} ({int(INTENSITY * 255)}/255 per channel)",
    f"Iterations: {ITERATIONS}",
    f"Learning rate: {LEARNING_RATE}",
    f"Perceptual weight: {PERCEPTUAL_WEIGHT}",
    sep="\n",
)

## 7. Define JAX/Flax ViT Architecture

Flax/Linen implementation of CLIP's Vision Transformer (ViT-B/32)

In [None]:
from flax import linen as nn

class CLIPAttention(nn.Module):
    """Multi-head self-attention with a fused q/k/v input projection.

    The sub-module names ("in_proj", "out_proj") are the ones the weight
    converter in this notebook writes to.
    """
    num_heads: int

    @nn.compact
    def __call__(self, x):
        batch, seq_len = x.shape[0], x.shape[1]
        d_model = x.shape[-1]
        head_dim = d_model // self.num_heads

        # One dense layer produces q, k and v concatenated on the last axis.
        fused = nn.Dense(3 * d_model, name="in_proj")(x)
        query, key, value = jnp.split(fused, 3, axis=-1)

        def to_heads(t):
            # (batch, seq, d_model) -> (batch, heads, seq, head_dim)
            per_head = t.reshape(t.shape[0], t.shape[1], self.num_heads, -1)
            return jnp.transpose(per_head, (0, 2, 1, 3))

        query, key, value = to_heads(query), to_heads(key), to_heads(value)

        # Scaled dot-product attention.
        scores = (query @ jnp.transpose(key, (0, 1, 3, 2))) * head_dim ** -0.5
        probs = jax.nn.softmax(scores, axis=-1)

        # Merge the heads back into a single feature axis.
        context = jnp.transpose(probs @ value, (0, 2, 1, 3))
        merged = context.reshape(batch, seq_len, -1)
        return nn.Dense(d_model, name="out_proj")(merged)

class MLP(nn.Module):
    """Two-layer feed-forward network: width -> 4 * width -> width."""
    width: int

    @nn.compact
    def __call__(self, x):
        hidden = nn.Dense(4 * self.width, name="c_fc")(x)
        # NOTE(review): nn.gelu defaults to the tanh approximation. If the
        # checkpoint was trained with a different GELU variant (exact GELU or
        # QuickGELU), the features will differ slightly - confirm against the
        # source model.
        activated = nn.gelu(hidden)
        return nn.Dense(self.width, name="c_proj")(activated)

class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual.

    Attributes:
        num_heads: Number of attention heads.
        width: Feature width passed to the MLP.
    """
    num_heads: int
    width: int

    @nn.compact
    def __call__(self, x):
        # The weights loaded into this model come from PyTorch, whose
        # LayerNorm uses eps=1e-5. Flax defaults to 1e-6, so set it explicitly
        # to reproduce the original normalisation.
        attn_in = nn.LayerNorm(epsilon=1e-5, name="ln_1")(x)
        x = x + CLIPAttention(self.num_heads, name="attn")(attn_in)

        mlp_in = nn.LayerNorm(epsilon=1e-5, name="ln_2")(x)
        x = x + MLP(self.width, name="mlp")(mlp_in)
        return x

class VisionTransformer(nn.Module):
    """Flax ViT image encoder with CLIP-style parameter names.

    Attributes:
        width: Transformer feature width.
        layers: Number of residual attention blocks.
        heads: Attention heads per block.
        patch_size: Side length of the square patches.
        output_dim: Dimension of the returned embedding.

    ``__call__`` takes a channels-last image batch (as ``nn.Conv`` expects)
    and returns one ``output_dim`` vector per image, taken from the class token.
    """
    width: int = 768
    layers: int = 12
    heads: int = 12
    patch_size: int = 32
    output_dim: int = 512

    @nn.compact
    def __call__(self, x):
        patch = (self.patch_size, self.patch_size)

        # Non-overlapping patch embedding, then flatten the patch grid.
        x = nn.Conv(self.width, patch, strides=patch,
                    padding='VALID', use_bias=False, name="conv1")(x)
        x = x.reshape(x.shape[0], -1, x.shape[-1])

        # One extra token for the class embedding. Deriving the count from
        # the input (instead of hard-coding 50) keeps the initialiser correct
        # for other input or patch sizes.
        num_tokens = x.shape[1] + 1

        cls_token = self.param('class_embedding', lambda *args: jnp.zeros((self.width,)))
        pos_embed = self.param('positional_embedding',
                               lambda *args: jnp.zeros((num_tokens, self.width)))

        cls = jnp.broadcast_to(cls_token, (x.shape[0], 1, self.width))
        x = jnp.concatenate([cls, x], axis=1)
        x = x + pos_embed

        # epsilon=1e-5 matches PyTorch's LayerNorm default (Flax uses 1e-6),
        # which matters because the weights are ported from PyTorch.
        x = nn.LayerNorm(epsilon=1e-5, name="ln_pre")(x)
        for i in range(self.layers):
            x = ResidualAttentionBlock(self.heads, self.width, name=f"transformer_resblocks_{i}")(x)

        # Only the class token feeds the output projection.
        x = nn.LayerNorm(epsilon=1e-5, name="ln_post")(x[:, 0, :])
        return nn.Dense(self.output_dim, use_bias=False, name="proj")(x)

print("ViT architecture defined (CLIPAttention, MLP, ResidualAttentionBlock, VisionTransformer)")

## 8. Weight Conversion Function

Convert PyTorch CLIP weights to Flax parameter format

In [None]:
def convert_clip_weights(pt_weights):
    """Convert a PyTorch-layout CLIP visual state dict to Flax variables.

    Args:
        pt_weights: Mapping from parameter name to array-like. Names may carry
            a ``visual.`` prefix. If any do, the mapping is treated as a full
            CLIP checkpoint and unprefixed entries are skipped, because their
            names would otherwise collide with the visual tower's after the
            prefix is stripped.

    Returns:
        ``{'params': nested_dict}`` suitable for ``VisionTransformer.apply``.

    Raises:
        ValueError: If ``proj`` has a shape other than (512, 768) or (768, 512).
    """
    prefix = 'visual.'
    only_prefixed = any(name.startswith(prefix) for name in pt_weights)

    nested_params = {}

    def store(path, array):
        # Insert `array` at the nested location described by `path`.
        node = nested_params
        for part in path[:-1]:
            node = node.setdefault(part, {})
        node[path[-1]] = array

    for name, raw in pt_weights.items():
        if name.startswith(prefix):
            key = name[len(prefix):]
        elif only_prefixed:
            continue
        else:
            key = name

        value = jnp.array(raw)

        if key in ('class_embedding', 'positional_embedding'):
            store([key], value)
            continue

        if key == 'proj':
            # Flax Dense kernels are (in, out) = (768, 512).
            if value.shape == (512, 768):
                kernel = value.T
            elif value.shape == (768, 512):
                kernel = value
            else:
                raise ValueError(f"Unexpected proj shape: {value.shape}")
            store(['proj', 'kernel'], kernel)
            # Report from `value`: the original key may have been prefixed,
            # so indexing pt_weights['proj'] is not safe.
            print(f"proj: input shape {value.shape} -> output shape {kernel.shape}")
            continue

        if key == 'conv1.weight':
            # PyTorch conv (out, in, h, w) -> Flax conv (h, w, in, out).
            store(['conv1', 'kernel'], jnp.transpose(value, (2, 3, 1, 0)))
            continue

        path = key.replace('transformer.resblocks.', 'transformer_resblocks_').split('.')
        leaf = path[-1]
        parent = path[:-1]
        # Decide on the owning module name rather than a substring of the
        # whole key, so an unrelated "ln" elsewhere cannot misclassify it.
        is_layer_norm = bool(parent) and parent[-1].startswith('ln_')

        if leaf == 'in_proj_weight':
            store(parent + ['in_proj', 'kernel'], value.T)
        elif leaf == 'in_proj_bias':
            store(parent + ['in_proj', 'bias'], value)
        elif is_layer_norm and leaf == 'weight':
            # LayerNorm gain is called "scale" in Flax and is not transposed.
            store(parent + ['scale'], value)
        elif leaf == 'weight':
            # Linear weights are (out, in) in PyTorch, (in, out) in Flax.
            store(parent + ['kernel'], value.T)
        else:
            store(path, value)

    return {'params': nested_params}

print("Weight conversion function defined (convert_clip_weights)")

## 9. Create Reusable CLIP Encoder

Encoder class for use across all algorithms

In [None]:
from functools import partial

class CLIPEncoder:
    """Image encoder wrapping the Flax ViT and its converted weights.

    Args:
        weights: PyTorch-layout CLIP visual weights, as accepted by
            ``convert_clip_weights``.
    """

    def __init__(self, weights):
        self.model_vit = VisionTransformer()
        self.variables = convert_clip_weights(weights)

        print("CLIP encoder initialized")

    # `self` is a static argument: the compiled function is specialised per
    # encoder instance and sees `self.variables` as trace-time constants.
    @partial(jit, static_argnums=(0,))
    def encode_image(self, img):
        """Return the unit-norm embedding of one channels-last RGB image.

        The image is resized to ``CLIP_INPUT_SIZE`` and normalised with the
        mean/std stored in the checkpoint (``CLIP_MEAN`` / ``CLIP_STD``)
        instead of a second hard-coded copy. The export cell treats the
        checkpoint values as the model's constants, so preprocessing stays
        consistent with them.
        """
        target_shape = (CLIP_INPUT_SIZE, CLIP_INPUT_SIZE, 3)
        img_resized = jax.image.resize(img, target_shape, method='bilinear')

        # Flatten to (3,) so the constants broadcast over the channel axis
        # whatever container or shape they were stored in.
        mean = jnp.asarray(np.asarray(CLIP_MEAN, dtype=np.float32).reshape(-1))
        std = jnp.asarray(np.asarray(CLIP_STD, dtype=np.float32).reshape(-1))

        normalized = (img_resized - mean) / std
        features = self.model_vit.apply(self.variables, normalized[None, ...])

        embedding = features[0]
        return embedding / jnp.linalg.norm(embedding)

clip_encoder = CLIPEncoder(clip_data['base_weights'])
print("Global CLIP encoder ready")

## 10. CLIP Image Encoding (Simplified)

CLIP image encoding with the Flax ViT, scored against the pre-computed chaos/normal embeddings

In [None]:
# Reuse the encoder's model and converted variables. Converting the weights a
# second time would only duplicate the full parameter set in memory.
model_vit = clip_encoder.model_vit
variables = clip_encoder.variables

@jit
def compute_image_features(img):
    """Return the embedding of ``img`` from the global ``clip_encoder``."""
    features = clip_encoder.encode_image(img)
    return features

@jit
def semantic_loss_noise(img, chaos_emb, normal_emb):
    """Loss that falls as ``img`` moves toward the chaos embeddings.

    Returns the mean dot product with ``normal_emb`` minus the mean dot
    product with ``chaos_emb``, both taken against the image embedding.
    """
    embedding = compute_image_features(img)
    chaos_score = jnp.mean(jnp.dot(chaos_emb, embedding))
    normal_score = jnp.mean(jnp.dot(normal_emb, embedding))

    # Minimising this raises the chaos score and lowers the normal score.
    return normal_score - chaos_score

# `semantic_loss` is the name protect_image calls; bind it to the noise loss.
semantic_loss = semantic_loss_noise

print("Noise loss function defined")

## 11. Protection Algorithm

PGD-based adversarial perturbation with perceptual optimization

In [None]:
from functools import partial

def protect_image(img, intensity, iterations, chaos_emb, normal_emb, perceptual_weight=0.5):
    """Perturb ``img`` with signed-gradient PGD steps on the combined loss.

    Args:
        img: Image array on the [0, 1] scale; the edge map assumes a
            channels-last layout with two spatial axes.
        intensity: Maximum absolute per-pixel change (epsilon).
        iterations: Number of PGD steps; must be at least 1.
        chaos_emb: Embeddings the image is pushed toward.
        normal_emb: Embeddings the image is pushed away from.
        perceptual_weight: Weight of the perceptual penalty in the loss.

    Returns:
        The perturbed image, within ``intensity`` of ``img`` and clipped to
        [0, 1].

    Raises:
        ValueError: If ``iterations`` is smaller than 1.
    """
    # Zero iterations would divide by zero below; a negative count would flip
    # the sign of the step and turn descent into ascent.
    if iterations < 1:
        raise ValueError(f"iterations must be >= 1, got {iterations}")

    epsilon = intensity
    alpha = epsilon / iterations * 2.5

    print(f"\nProtecting image...")
    print(f"Epsilon: {epsilon:.4f} ({int(epsilon * 255)}/255 per channel)")
    print(f"Alpha: {alpha:.6f}")
    print(f"Iterations: {iterations}")
    print(f"Perceptual weight: {perceptual_weight}")
    print(f"\nCompiling JIT (this may take a few minutes on first run)...")

    def compute_edge_weight(image):
        # Per-pixel weight in [0.3, 1.0]: larger where the greyscale image has
        # stronger local gradients.
        gray = jnp.mean(image, axis=-1)

        gx = jnp.abs(gray[1:, :] - gray[:-1, :])
        gy = jnp.abs(gray[:, 1:] - gray[:, :-1])

        # Pad the forward differences back to the image size.
        gx = jnp.pad(gx, ((0, 1), (0, 0)), mode='edge')
        gy = jnp.pad(gy, ((0, 0), (0, 1)), mode='edge')

        edges = jnp.sqrt(gx**2 + gy**2)
        edges = (edges - edges.min()) / (edges.max() - edges.min() + 1e-8)

        weight = 0.3 + 0.7 * edges
        return weight[..., None]

    edge_weight = compute_edge_weight(img)

    def compute_perceptual_loss(perturbed, original):
        # Squared change, penalised more in smooth (low edge weight) regions.
        diff = perturbed - original
        return jnp.mean(diff**2 * (1.5 - edge_weight))

    def combined_loss(current_img, chaos_emb, normal_emb, original_img):
        semantic = semantic_loss(current_img, chaos_emb, normal_emb)
        perceptual = compute_perceptual_loss(current_img, original_img)
        return semantic + perceptual_weight * perceptual * 100

    # Only the step is jitted; the helpers above are traced as part of it.
    # The closure is rebuilt on every call, so each call compiles once.
    @jit
    def pgd_step(current_img, original_img):
        loss_val, grads = jax.value_and_grad(combined_loss)(
            current_img, chaos_emb, normal_emb, original_img
        )

        # Scale the signed step by the edge weight. Multiplying the gradient
        # by a strictly positive weight before taking the sign leaves the
        # sign unchanged, so the weight has to act on the step itself.
        step = alpha * edge_weight * jnp.sign(grads)

        next_img = current_img - step
        # Project back into the epsilon ball, then into the valid pixel range.
        delta = jnp.clip(next_img - original_img, -epsilon, epsilon)
        next_img = jnp.clip(original_img + delta, 0.0, 1.0)

        return next_img, loss_val

    current_img = img
    current_img, initial_loss = pgd_step(current_img, img)
    print(f"JIT compilation complete!")
    print(f"Initial loss: {initial_loss:.4f}")

    # Keep losses as device values during the loop; converting each one to a
    # Python float would force a host sync on every iteration.
    loss_history = [initial_loss]
    for i in tqdm(range(1, iterations), desc="Optimizing", unit="iter"):
        current_img, loss_val = pgd_step(current_img, img)
        loss_history.append(loss_val)
    losses = [float(value) for value in loss_history]

    perturbation = current_img - img
    max_change = float(jnp.max(jnp.abs(perturbation)))
    avg_change = float(jnp.mean(jnp.abs(perturbation)))

    print(f"\nProtection complete!")
    print(f"Final loss: {losses[-1]:.4f}")
    print(f"Loss improvement: {losses[0] - losses[-1]:.4f}")
    print(f"Max pixel change: {max_change:.4f} ({int(max_change * 255)}/255)")
    print(f"Avg pixel change: {avg_change:.4f} ({int(avg_change * 255)}/255)")

    return current_img

print("Protection function defined (PGD/Noise with perceptual optimization)")

## 12. Upload image

Upload a test image; protection is applied in the next step

In [None]:
from google.colab import files
import matplotlib.pyplot as plt

print("Upload the image sample:")
uploaded = files.upload()

# Stays None unless a file was uploaded and loaded successfully.
original_img = None

if not uploaded:
    print("No file uploaded")
else:
    # Only the first uploaded file is used.
    test_image_path = next(iter(uploaded))
    try:
        original_img = load_image(test_image_path)

        print(f"Uploaded: {test_image_path}")
        print(f"Size: {original_img.shape}")
        print(f"Range: [{original_img.min():.3f}, {original_img.max():.3f}]")
    except Exception as e:
        print(f"Error loading image: {e}")
        original_img = None

## 13. Run Protection

In [None]:
# Bind the result name up front so later cells can test it for None even when
# no image was loaded or protection raised.
protected_img = None

if original_img is None:
    print("No image loaded. Please run Step 12 (Upload image) first.")
else:
    print("Starting image protection (PGD/Noise Algorithm with perceptual loss)...")

    try:
        protected_img = protect_image(
            original_img,
            intensity=INTENSITY,
            iterations=ITERATIONS,
            chaos_emb=chaos_embeddings,
            normal_emb=normal_embeddings,
            perceptual_weight=PERCEPTUAL_WEIGHT
        )

        # NOTE(review): JPEG is lossy and may alter a small perturbation;
        # confirm whether a lossless format is wanted here.
        output_path = '/content/drive/MyDrive/hope-models/exports/protected_test.jpg'
        save_image(protected_img, output_path)

    except Exception as e:
        # Notebook boundary: report the failure with a full traceback rather
        # than stopping the session.
        print(f"Protection failed: {e}")
        import traceback
        traceback.print_exc()

## 14. Displaying and Comparing

In [None]:
# Look both images up defensively: either name may be unbound if its cell was
# skipped or failed, and a bare reference would raise NameError.
_original = globals().get('original_img')
_protected = globals().get('protected_img')

if _original is None or _protected is None:
    print("No images to display. Run steps 12 and 13 first.")
else:
    fig, axes = plt.subplots(1, 2, figsize=(15, 7))

    panels = (
        (_original, 'Original Image'),
        (_protected, f'Protected (Intensity: {INTENSITY})'),
    )
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(np.array(image))
        ax.set_title(title, fontsize=14)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    print(f"Display complete")

## 15. Download results

In [None]:
from google.colab import files

# File written by the "Run Protection" step.
output_path = '/content/drive/MyDrive/hope-models/exports/protected_test.jpg'

if os.path.exists(output_path):
    print(f"Preparing to download the file: {output_path}")
    files.download(output_path)
else:
    # Point at the step that actually writes the file.
    print("Can't find file, run step 13 (Run Protection) first.")

## 16. Batch Protection Function

Protect multiple images at once

In [None]:
def protect_batch(image_paths, output_dir, intensity=0.30, iterations=150,
                  perceptual_weight=0.5):
    """Protect every image in ``image_paths`` and save the results.

    Args:
        image_paths: Iterable of input image paths.
        output_dir: Directory for the outputs; created if missing.
        intensity: Perturbation bound passed to ``protect_image``.
        iterations: Number of PGD steps passed to ``protect_image``.
        perceptual_weight: Perceptual penalty weight passed to
            ``protect_image``. Defaults to 0.5, the value that was implicitly
            used before this parameter existed.

    Each output is written as ``<name>_protected<ext>`` in ``output_dir``, so
    inputs that share a base name overwrite each other. An error on any image
    propagates and stops the batch.

    NOTE(review): the default intensity (0.30) and iterations (150) differ
    from the notebook-level INTENSITY (0.06) and ITERATIONS (200). They are
    left unchanged here; confirm which values are intended.
    """
    os.makedirs(output_dir, exist_ok=True)

    separator = '=' * 60

    for img_path in image_paths:
        print(f"\n{separator}")
        print(f"Processing: {img_path}")

        img = load_image(img_path)

        protected = protect_image(
            img, intensity, iterations,
            chaos_embeddings, normal_embeddings,
            perceptual_weight=perceptual_weight,
        )

        name, ext = os.path.splitext(os.path.basename(img_path))
        output_path = os.path.join(output_dir, f"{name}_protected{ext}")
        save_image(protected, output_path)

    print(f"\n{separator}")
    print(f"Batch protection complete!")

print("Batch protection function defined")

## 17. Export Model Parameters

Save the model data in a format compatible with ONNX/TFLite export

In [None]:
def save_model_for_export():
    """Pickle the weights, embeddings, constants and defaults for later export.

    Reads the notebook globals (``clip_data``, the embeddings, the CLIP
    constants and the default parameters), writes them to a fixed path on
    Drive, prints a short report and returns that path.
    """
    export_data = {
        'vit_weights': clip_data['base_weights'],

        'presets': {
            'noise_chaos': chaos_embeddings,
            'noise_normal': normal_embeddings,
        },

        'constants': {
            'clip_mean': CLIP_MEAN,
            'clip_std': CLIP_STD,
            'clip_input_size': CLIP_INPUT_SIZE,
        },

        'default_params': {
            'intensity': INTENSITY,
            'iterations': ITERATIONS,
            # Must match the factor protect_image uses for its step size
            # (alpha = epsilon / iterations * 2.5).
            'alpha_multiplier': 2.5,
        },

        'architecture': {
            'model_type': 'ViT-B/32',
            'width': 768,
            'layers': 12,
            'heads': 12,
            'patch_size': 32,
            'output_dim': 512,
        }
    }

    export_path = '/content/drive/MyDrive/hope-models/exports/noise_model_export.pkl'

    with open(export_path, 'wb') as f:
        pickle.dump(export_data, f)

    file_size = os.path.getsize(export_path) / (1024**2)
    print(f"Model export data saved")
    print(f"Path: {export_path}")
    print(f"Size: {file_size:.2f} MB")
    print(f"Presets: {len(export_data['presets'])}")

    return export_path

export_path = save_model_for_export()

## 18. Model Information

Display model statistics and capabilities

In [None]:
def print_model_info():
    """Print a sectioned summary of the model, embeddings and defaults.

    Values are read from the notebook globals; the remaining lines are fixed
    descriptive text.
    """
    rule = "=" * 60

    sections = (
        ("Architecture", (
            "Model: ViT-B/32",
            "Parameters: ~150M",
            f"Input size: {CLIP_INPUT_SIZE}x{CLIP_INPUT_SIZE}",
            "Feature dim: 512",
        )),
        ("Embeddings", (
            f"Chaos prompts: {chaos_embeddings.shape[0]}",
            f"Normal prompts: {normal_embeddings.shape[0]}",
            f"Embedding dim: {chaos_embeddings.shape[1]}",
        )),
        ("Default Parameters", (
            f"Intensity: {INTENSITY}",
            f"Iterations: {ITERATIONS}",
            f"Learning rate: {LEARNING_RATE}",
        )),
        ("Capabilities", (
            "Single image protection",
            "Batch processing",
            "Adjustable intensity",
            "GPU acceleration (JAX)",
            "JIT compilation",
        )),
        ("Export Ready", (
            "ONNX export compatible",
            "TFLite export compatible",
            "Preset embeddings saved",
        )),
    )

    print(rule)
    print("NOISE PROTECTION MODEL INFO")
    print(rule)

    # Each section: a blank line, the title, then its entries.
    for title, entries in sections:
        print(f"\n{title}:")
        for entry in entries:
            print(entry)

    print("\n" + rule)

print_model_info()

## Algorithm Complete

**Implemented:**
- JAX-based adversarial perturbations
- CLIP-guided semantic loss
- PGD optimization
- Batch processing support

**Parameters:**
- Intensity: 0.06 (adjustable)
- Iterations: 200 (adjustable)
- Chaos prompts: 8 prompts
- Normal prompts: 3 prompts

**Saved to:** `/content/drive/MyDrive/hope-models/exports/`

**Next:** Run `3_glaze_algorithm.ipynb` or `5_export_onnx.ipynb`