In [3]:
"""
attack_harness_wesad_full.py

Projected-PGD attack on TEXT embedding for WESAD with semantic interpretations.
Goals:
- High original similarity (orig_sim)
- Lower adversarial similarity (adv_sim)
- Loss after attack > loss before

Features:
- Works with HuggingFace transformers if installed; otherwise uses a hashing fallback.
- Normalizes embeddings for stable cosine similarity.
- PGD uses ε-scaled step size and random start within ε-ball.
- Selects best restart by (1) lowest adv_sim, then (2) largest (loss_after - loss_before).
- Writes a CSV of results with diagnostics.

"""

import csv
import importlib
import math
import os
import random
import sys
import zlib
from typing import List

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

# ------------------------- CONFIG -------------------------
# Dimension of the shared embedding space produced by both SensorEncoder and TextEncoder.
EMBEDDING_DIM = 64
# Cosine-similarity threshold. It is passed to ContrastiveSimilarityLoss as its margin and
# is also the success / early-stop threshold of the attack (adv_sim < MARGIN).
MARGIN = 0.5
# HuggingFace checkpoint loaded by HFTextEncoder when `transformers` is installed.
TEXT_MODEL_NAME = "distilbert-base-uncased"

# ---- WESAD settings ----
# CSV column holding the free-text description whose embedding is attacked.
TEXT_COLUMN_NAME = "semantic_interpretation"
ACTIVITY_COLUMN_NAME = "label_name"   # e.g., baseline / stress / amusement

# Use 6-D sensors: chest + wrist accelerometer
# NOTE(review): the captured output of a previous run reports a sensor checkpoint whose
# first layer expects 14 input features, so it failed to load with these 6 columns and a
# randomly-initialised sensor encoder was used instead -- confirm this column list.
SENSOR_COLUMNS = [
    "chest_acc_x", "chest_acc_y", "chest_acc_z",
    "wrist_acc_x", "wrist_acc_y", "wrist_acc_z",
]
SENSOR_FEATURES_COUNT = len(SENSOR_COLUMNS)  # input_dim of SensorEncoder

# Paths
SENSOR_MODEL_PATH = "sensor_encoder_wesad.pth"  # state_dict for SensorEncoder
TEXT_MODEL_PATH = "text_encoder_wesad.pth"  # projection.* keys (HF) or full state_dict (fallback)
DATA_FILE = "./data/WESAD_with_semantic_interpretation.csv"

# Attack defaults
DEFAULT_ALPHA = 0.005                  # relative step: each PGD step moves alpha * epsilon (capped at epsilon)
# Perturbation budgets swept per sample. The text embedding is L2-normalised before the
# attack, so every coordinate has magnitude <= 1; an L-inf epsilon >= 1 is at least as
# large as any coordinate of the clean embedding.
EPSILON_SWEEP = [0.5, 0.6, 0.7, 0.8, 0.9, 1,2,3,4,5,6,7,8,9,10]
DEFAULT_STEPS = 1000  # maximum PGD iterations per restart
MODE = "linf"                         # "linf" or "l2"
RANDOM_RESTARTS = 3  # attacks per epsilon; the best one (lowest adv_sim) is reported
DBG_INTERVAL = 50  # print attack diagnostics every N steps (0 disables)
EARLY_STOP = True  # stop a restart as soon as adv_sim drops below MARGIN

TARGET_INDICES = [0,  4000]  # CSV row indices to attack (out-of-range ones are skipped)
OUTPUT_CSV = "attack_results_WESAD_New.csv"
SEED = 0
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ------------------------- HELPERS & TEXT ENCODERS -------------------------
def set_seed(seed):
    """Seed Python's ``random`` module and PyTorch (CPU plus every CUDA device, if any)."""
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def _transformers_available():
    return importlib.util.find_spec("transformers") is not None

class HFTextEncoder(nn.Module):
    """HuggingFace transformer backbone plus a trainable linear projection.

    Texts are tokenized (padded / truncated), encoded by ``AutoModel``,
    mean-pooled over real (non-padding) tokens and projected to ``output_dim``.
    """
    def __init__(self, model_name, output_dim, device="cpu"):
        super().__init__()
        # Imported here so this module can be imported without transformers installed.
        from transformers import AutoTokenizer, AutoModel
        self.device = torch.device(device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        self.projection = nn.Linear(self.model.config.hidden_size, output_dim).to(self.device)

    def forward(self, texts: List[str]):
        """Return a (B, output_dim) embedding; a single str is treated as a batch of one."""
        if isinstance(texts, str):
            texts = [texts]
        enc = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        # Move tokenizer outputs to wherever the parameters live.
        enc = {k: v.to(self.projection.weight.device) for k, v in enc.items()}
        hidden = self.model(**enc).last_hidden_state  # (B, T, H)
        mask = enc.get("attention_mask")
        if mask is None:
            pooled = hidden.mean(dim=1)
        else:
            # Average only over real tokens so padding added for other texts in
            # the batch does not leak into this text's embedding. For a single
            # unpadded text this equals the plain mean over the sequence.
            # NOTE(review): if the projection checkpoint was trained with unmasked
            # pooling on padded batches, confirm it still fits this pooling.
            weights = mask.unsqueeze(-1).to(hidden.dtype)  # (B, T, 1)
            pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-9)
        return self.projection(pooled)  # (B, D)

class HashingTextEncoder(nn.Module):
    """Dependency-free fallback: hashed bag-of-words + trainable linear projection.

    Each text is lower-cased and split on whitespace. Every token increments one
    of ``hash_dim`` buckets, the count vector is L2-normalised, and the result is
    passed through ``nn.Linear(hash_dim, output_dim)``. The output keeps gradients
    with respect to the projection.
    """
    def __init__(self, output_dim, hash_dim=2048, device="cpu"):
        super().__init__()
        self.hash_dim = hash_dim
        self.projection = nn.Linear(hash_dim, output_dim)
        # Kept for parity with HFTextEncoder; forward() follows the projection's device.
        self.device = torch.device(device)

    @staticmethod
    def _tokenize(s: str):
        """Lower-case *s* and split it on whitespace."""
        return s.lower().split()

    def _bucket(self, tok: str) -> int:
        """Map *tok* to a bucket index that is stable across interpreter runs.

        The built-in ``hash()`` of ``str`` is salted per process (PYTHONHASHSEED),
        which would make bucket assignment, and any saved projection weights,
        irreproducible. CRC-32 of the UTF-8 bytes is deterministic.
        """
        return zlib.crc32(tok.encode("utf-8")) % self.hash_dim

    def _hash_bow(self, texts):
        """Return an L2-normalised float32 tensor of shape (len(texts), hash_dim)."""
        if not texts:
            return torch.zeros((0, self.hash_dim), dtype=torch.float32)
        rows = []
        for s in texts:
            idx = torch.tensor([self._bucket(tok) for tok in self._tokenize(s)], dtype=torch.long)
            vec = torch.bincount(idx, minlength=self.hash_dim).to(torch.float32)
            n = torch.norm(vec, p=2)
            if n > 0:  # empty / whitespace-only text stays the zero vector
                vec = vec / n
            rows.append(vec)
        return torch.stack(rows, dim=0)

    def forward(self, texts: List[str]):
        """Return a (B, output_dim) embedding; a single str is treated as a batch of one."""
        if isinstance(texts, str):
            texts = [texts]
        bows = self._hash_bow(texts).to(self.projection.weight.device)
        return self.projection(bows)

class TextEncoder(nn.Module):
    """Facade that picks a concrete text backend once, at construction time.

    If the ``transformers`` package can be found, ``self.impl`` is an
    ``HFTextEncoder``; otherwise it is the dependency-free ``HashingTextEncoder``.
    """
    def __init__(self, model_name, output_dim, device="cpu"):
        super().__init__()
        if not _transformers_available():
            print("[WARN] transformers not found; using hashing-based fallback TextEncoder.")
            self.impl = HashingTextEncoder(output_dim, device=device)
            return
        print("[INFO] Using HuggingFace transformers for TextEncoder.")
        self.impl = HFTextEncoder(model_name, output_dim, device=device)

    def forward(self, texts: List[str]):
        """Delegate encoding of *texts* to the selected backend."""
        embeddings = self.impl(texts)
        return embeddings

# ------------------------- SENSOR MODEL -------------------------
class SensorEncoder(nn.Module):
    """Two-layer MLP: Linear(input_dim, 128) -> ReLU -> Linear(128, output_dim).

    The layers are held in ``self.encoder`` (an ``nn.Sequential``), so the
    state-dict keys are ``encoder.0.*`` and ``encoder.2.*``.
    """
    def __init__(self, input_dim, output_dim):
        super().__init__()
        hidden = 128
        layers = [
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, output_dim),
        ]
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Map sensor features *x* (last dim = input_dim) to embeddings."""
        embedded = self.encoder(x)
        return embedded

# ------------------------- LOSS & UTILS -------------------------
class ContrastiveSimilarityLoss(nn.Module):
    """
    Cosine-similarity contrastive loss, averaged over the batch.

    With s = cos(z1, z2) computed on L2-normalised inputs:
      label == 1 (positive pair): 1 - s
      label == 0 (negative pair): max(0, s - margin)
    A label tensor with more than one dimension is flattened first.
    """
    def __init__(self, margin=0.5):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        unit1, unit2 = (F.normalize(t, p=2, dim=1, eps=1e-8) for t in (output1, output2))
        cos = F.cosine_similarity(unit1, unit2, dim=1, eps=1e-8)
        target = label.view(-1) if label.dim() > 1 else label
        positive_term = target * (1.0 - cos)
        negative_term = (1.0 - target) * torch.max(torch.zeros_like(cos), cos - self.margin)
        return (positive_term + negative_term).mean()

def l2_project(eta, epsilon):
    """Project each sample of *eta* onto the L2 ball of radius *epsilon*.

    Dimension 0 is the batch; all other dimensions are flattened to compute one
    norm per sample. Samples already inside the ball are returned unchanged and
    the rest are rescaled to norm ``epsilon``. ``reshape`` (not ``view``) is
    used so non-contiguous inputs, such as transposed tensors, are accepted.
    """
    flat = eta.reshape(eta.size(0), -1)
    norms = torch.norm(flat, p=2, dim=1, keepdim=True).clamp(min=1e-12)
    factor = torch.clamp(epsilon / norms, max=1.0)
    return (flat * factor).reshape(eta.shape)

def compute_norms(perturbation):
    """Return ``(linf, l2)``: batch means of per-sample L-inf and L2 norms.

    Dimension 0 is the batch; the remaining dimensions are flattened per sample.
    The input is detached and moved to CPU first. ``reshape`` is used so
    non-contiguous tensors are accepted.
    """
    flat = perturbation.detach().cpu().reshape(perturbation.size(0), -1)
    linf = flat.abs().amax(dim=1).mean().item()
    l2 = torch.norm(flat, p=2, dim=1).mean().item()
    return linf, l2

# ------------------------- ATTACK (PGD on text embedding) -------------------------
def _pgd_random_start(z, epsilon, mode):
    """Return random noise shaped like *z* that lies inside the epsilon-ball of *mode*.

    linf: each coordinate uniform in [-epsilon, epsilon].
    l2:   random Gaussian direction scaled to a radius drawn uniformly in [0, epsilon]
          (uniform in radius, not uniform in volume).
    """
    if mode == "linf":
        return torch.empty_like(z).uniform_(-epsilon, epsilon)
    noise = torch.randn_like(z)
    flat = noise.reshape(noise.size(0), -1)
    n = torch.norm(flat, p=2, dim=1, keepdim=True).clamp(min=1e-12)
    return ((flat / n) * epsilon * torch.rand_like(n)).reshape(z.shape)


def _pgd_project(eta, epsilon, mode):
    """Project the perturbation *eta* back onto the epsilon-ball of *mode*."""
    if mode == "linf":
        return torch.clamp(eta, -epsilon, epsilon)
    return l2_project(eta, epsilon)


def _pgd_summary(criterion, z_s, z_t_adv, z_t_base, label, orig_sim, loss_before, n_steps):
    """Build the attack result dict for the current adversarial embedding."""
    with torch.no_grad():
        adv_sim = float(F.cosine_similarity(z_s, F.normalize(z_t_adv, p=2, dim=1), dim=1).item())
        loss_after = float(criterion(z_s, z_t_adv, label).item())
        linf_val, l2_val = compute_norms((z_t_adv - z_t_base).cpu())
    return {
        "orig_sim": orig_sim, "adv_sim": adv_sim,
        "loss_before": loss_before, "loss_after": loss_after,
        "linf": linf_val, "l2": l2_val,
        "steps": n_steps, "orig_below_margin": False
    }


def pgd_projected_attack(
    sensor_encoder, text_encoder, criterion,
    x_s, text_str, label,
    alpha, epsilon, steps, device,
    mode="linf", dbg_interval=50, early_stop=True,
    random_start=True
):
    """Projected gradient ascent on the text embedding of one (sensor, text) pair.

    Both embeddings are L2-normalised. The normalised text embedding is the
    centre of an epsilon-ball (L-inf or L2), and the attack climbs ``criterion``
    inside that ball. For a positive pair this lowers cosine similarity.

    Args:
        sensor_encoder, text_encoder: encoders; both are switched to eval mode.
        criterion: loss ``criterion(z1, z2, label)``. Its ``margin`` attribute,
            if present, is the skip / early-stop threshold; otherwise MARGIN.
        x_s: 1-D sensor feature tensor for a single sample.
        text_str: the text paired with ``x_s``.
        label: pair label (1.0 = positive).
        alpha: relative step size; each step moves ``alpha * epsilon``, capped at
            ``epsilon`` and floored at 1e-6.
        epsilon: perturbation budget.
        steps: maximum number of PGD iterations.
        device: torch device (or its name) to run on.
        mode: ``"linf"`` (sign-gradient steps) or ``"l2"`` (normalised-gradient steps).
        dbg_interval: print diagnostics every N steps and on the last step; 0 disables.
        early_stop: return as soon as the similarity drops below the margin.
        random_start: start from a random point inside the epsilon-ball.

    Returns:
        dict with keys orig_sim, adv_sim, loss_before, loss_after, linf, l2,
        steps (iterations run) and orig_below_margin.

    Raises:
        ValueError: if ``mode`` is neither ``"linf"`` nor ``"l2"``.
    """
    if mode not in ("linf", "l2"):
        raise ValueError(f"mode must be 'linf' or 'l2', got {mode!r}")
    device = torch.device(device)
    # Use the loss's own margin so the thresholds match the objective being maximised.
    margin = float(getattr(criterion, "margin", MARGIN))

    x_s = x_s.unsqueeze(0).to(device)                     # (1, F)
    label = torch.tensor([label], dtype=torch.float32).to(device)

    sensor_encoder.eval()
    text_encoder.eval()

    with torch.no_grad():
        z_s = F.normalize(sensor_encoder(x_s), p=2, dim=1, eps=1e-8)
        z_t_orig = F.normalize(text_encoder([text_str]), p=2, dim=1, eps=1e-8)
        loss_before = float(criterion(z_s, z_t_orig, label).item())
        orig_sim = float(F.cosine_similarity(z_s, z_t_orig, dim=1).item())

    # Already below the margin: nothing to attack. No random numbers are drawn on this path.
    if orig_sim < margin:
        return {
            "orig_sim": orig_sim, "adv_sim": orig_sim,
            "loss_before": loss_before, "loss_after": loss_before,
            "linf": 0.0, "l2": 0.0, "steps": 0, "orig_below_margin": True
        }

    z_t_base = z_t_orig.detach().clone().to(device)       # (1, D) centre of the epsilon-ball
    z_t_adv = z_t_base.clone()
    if random_start and epsilon > 0:
        z_t_adv = z_t_adv + _pgd_random_start(z_t_adv, epsilon, mode)
    z_t_adv = z_t_adv.detach().requires_grad_(True)

    # Step scales with epsilon so larger budgets move proportionally faster.
    step_size = max(1e-6, float(alpha) * (float(epsilon) if epsilon > 0 else 1.0))
    if epsilon > 0:
        step_size = min(step_size, float(epsilon))

    for step in range(steps):
        if z_t_adv.grad is not None:
            z_t_adv.grad.zero_()

        loss = criterion(z_s, z_t_adv, label)   # criterion normalises its inputs
        loss.backward()
        grad = z_t_adv.grad.detach()

        # Ascent direction: increase the loss (lower similarity for positive pairs).
        if mode == "linf":
            step_vec = step_size * torch.sign(grad)
        else:
            g = grad.reshape(grad.size(0), -1)
            g_norm = torch.norm(g, p=2, dim=1, keepdim=True).clamp(min=1e-12)
            step_vec = (step_size * (g / g_norm)).reshape(grad.shape)

        with torch.no_grad():
            eta = _pgd_project(z_t_adv + step_vec - z_t_base, epsilon, mode)
            z_t_adv.copy_(z_t_base + eta)

        if dbg_interval and (step % dbg_interval == 0 or step == steps - 1):
            with torch.no_grad():
                linf_now = (z_t_adv - z_t_base).abs().amax().item()
                grad_norm = torch.norm(grad.reshape(grad.size(0), -1), p=2, dim=1).mean().item()
                sim_now = float(F.cosine_similarity(z_s, F.normalize(z_t_adv, p=2, dim=1), dim=1).item())
            # `loss` is the value before this step; `sim` is after it.
            print(f"[attack] step {step+1}/{steps} loss={loss.item():.6f} sim={sim_now:.6f} linfΔ={linf_now:.6f} ‖∇‖={grad_norm:.6e}")

        if early_stop:
            with torch.no_grad():
                sim_val = float(F.cosine_similarity(z_s, F.normalize(z_t_adv, p=2, dim=1), dim=1).item())
            if sim_val < margin:
                return _pgd_summary(criterion, z_s, z_t_adv, z_t_base, label,
                                    orig_sim, loss_before, step + 1)

    return _pgd_summary(criterion, z_s, z_t_adv, z_t_base, label, orig_sim, loss_before, steps)

# ------------------------- RUNNER (CSV) -------------------------
def _load_sensor_encoder(sensor_path, device):
    """Build a SensorEncoder on *device*, loading weights from *sensor_path* when possible."""
    sensor_encoder = SensorEncoder(SENSOR_FEATURES_COUNT, EMBEDDING_DIM).to(device)
    if not os.path.exists(sensor_path):
        print("[WARN] sensor model checkpoint not found; using random-initialized sensor encoder.")
        return sensor_encoder
    try:
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        sensor_encoder.load_state_dict(torch.load(sensor_path, map_location=device))
        print("[INFO] Loaded sensor encoder weights.")
    except Exception as e:
        # A shape mismatch (checkpoint trained on different sensor columns than
        # SENSOR_COLUMNS) lands here. The attack then runs against an untrained
        # sensor encoder, so the reported similarities are not meaningful.
        print(f"[WARN] Could not load sensor state_dict cleanly: {e}. Using fresh sensor encoder.")
    return sensor_encoder


def _load_text_encoder(text_path, device):
    """Build a TextEncoder on *device*, loading checkpoint weights from *text_path* when possible.

    HF backend: only the ``projection.*`` keys of a flat state dict are loaded
    (the transformer keeps its pretrained weights). Fallback backend: the whole
    checkpoint is loaded with ``strict=False``. Load failures are reported, not raised.
    """
    text_encoder = TextEncoder(TEXT_MODEL_NAME, EMBEDDING_DIM, device=str(device)).to(device)
    if not os.path.exists(text_path):
        print("[INFO] No text checkpoint found; using base text encoder (HF or fallback).")
        return text_encoder
    prefix = "projection."
    try:
        ckpt = torch.load(text_path, map_location=device)
        if isinstance(text_encoder.impl, HFTextEncoder):
            if isinstance(ckpt, dict) and ("projection.weight" in ckpt or "projection.bias" in ckpt):
                proj_sd = {k[len(prefix):]: v for k, v in ckpt.items() if k.startswith(prefix)}
                text_encoder.impl.projection.load_state_dict(proj_sd, strict=False)
                print("[INFO] Loaded text encoder projection weights (strict=False).")
            else:
                print("[WARN] text checkpoint has no 'projection.*' keys; using untrained projection.")
        else:
            try:
                text_encoder.impl.load_state_dict(ckpt, strict=False)
                print("[INFO] Loaded fallback text encoder checkpoint (strict=False).")
            except Exception:
                print("[WARN] Could not load checkpoint into fallback encoder.")
    except Exception as e:
        print(f"[WARN] could not load text checkpoint: {e}. Proceeding with base encoder.")
    return text_encoder


def _missing_column_error(df):
    """Return the error message for the first required column absent from *df*, else None."""
    for c in SENSOR_COLUMNS:
        if c not in df.columns:
            return f"[ERROR] Missing required WESAD column: {c}"
    if TEXT_COLUMN_NAME not in df.columns:
        return f"[ERROR] Missing text column: {TEXT_COLUMN_NAME}"
    if ACTIVITY_COLUMN_NAME not in df.columns:
        return f"[ERROR] Missing activity column: {ACTIVITY_COLUMN_NAME}"
    return None


def _is_better_restart(res, best):
    """Return True if *res* should replace *best*.

    Lower adv_sim wins. When adv_sim values are equal (``math.isclose``), the
    larger loss increase wins.
    """
    if best is None or res["adv_sim"] < best["adv_sim"]:
        return True
    return (math.isclose(res["adv_sim"], best["adv_sim"])
            and (res["loss_after"] - res["loss_before"])
            > (best["loss_after"] - best["loss_before"]))


def run_and_save(
    sensor_path=SENSOR_MODEL_PATH,
    text_path=TEXT_MODEL_PATH,
    data_file=DATA_FILE,
    target_indices=TARGET_INDICES,
    eps_sweep=EPSILON_SWEEP,
    alpha=DEFAULT_ALPHA,
    steps=DEFAULT_STEPS,
    mode=MODE,
    restarts=RANDOM_RESTARTS,
    dbg_interval=DBG_INTERVAL,
    output_csv=OUTPUT_CSV,
    device=DEVICE
):
    """Attack every (target index, epsilon) pair and write one CSV row per pair.

    Args:
        sensor_path, text_path: checkpoint paths (missing / incompatible files fall
            back to untrained weights with a warning).
        data_file: WESAD CSV containing SENSOR_COLUMNS, TEXT_COLUMN_NAME and
            ACTIVITY_COLUMN_NAME.
        target_indices: row indices to attack; out-of-range ones are skipped.
        eps_sweep: epsilon budgets to try for each row.
        alpha, steps, mode, dbg_interval: forwarded to ``pgd_projected_attack``.
        restarts: random restarts per epsilon; the best result is kept. At least
            one attack is always run.
        output_csv: destination CSV (overwritten).
        device: torch device name.

    Returns None. On a model, CSV or column error an ``[ERROR]`` line is printed
    and the function returns early.
    """
    set_seed(SEED)
    device = torch.device(device)
    print(f"Device: {device}\n")

    try:
        sensor_encoder = _load_sensor_encoder(sensor_path, device)
        text_encoder = _load_text_encoder(text_path, device)
        criterion = ContrastiveSimilarityLoss(margin=MARGIN)
    except Exception as e:
        print(f"[ERROR] loading models: {e}")
        return

    try:
        df = pd.read_csv(data_file)
    except Exception as e:
        print(f"[ERROR] loading CSV: {e}")
        return

    error = _missing_column_error(df)
    if error is not None:
        print(error)
        return

    sensor_values = df[SENSOR_COLUMNS].values
    texts = df[TEXT_COLUMN_NAME].astype(str).tolist()
    activities = df[ACTIVITY_COLUMN_NAME]

    header = [
        "index", "activity", "epsilon", "alpha", "steps", "mode", "restarts",
        "orig_sim", "adv_sim", "loss_before", "loss_after",
        "linf", "l2", "pgd_steps", "success", "note"
    ]

    # restarts <= 0 would otherwise leave no result to report.
    n_runs = max(1, int(restarts))

    with open(output_csv, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(header)

        for idx in target_indices:
            if idx < 0 or idx >= len(df):
                print(f"[WARN] index {idx} out of range, skipping")
                continue

            x_s = torch.tensor(sensor_values[idx], dtype=torch.float32)
            text_str = texts[idx]
            label = 1.0  # positive pair
            activity = activities.iloc[idx]

            print("-------------------------------------------------------------")
            print(f"ATTACK SAMPLE index={idx} activity={activity}")
            print(f"text preview: '{text_str[:120]}...'")
            print("-------------------------------------------------------------")

            for eps in eps_sweep:
                best_result = None
                note = ""
                for _ in range(n_runs):
                    res = pgd_projected_attack(
                        sensor_encoder, text_encoder, criterion,
                        x_s, text_str, label,
                        alpha, eps, steps, device,
                        mode=mode, dbg_interval=dbg_interval, early_stop=EARLY_STOP,
                        random_start=True,
                    )
                    if _is_better_restart(res, best_result):
                        best_result = res
                    if res.get("orig_below_margin", False):
                        # That path returns before any random draw, so further
                        # restarts would reproduce exactly the same result.
                        break

                success = int(best_result["adv_sim"] < MARGIN)
                if best_result.get("orig_below_margin", False):
                    note = "orig_sim_below_margin"
                elif best_result["loss_after"] <= best_result["loss_before"]:
                    note = "loss_not_increased"

                row = [
                    idx, activity, eps, alpha, steps, mode, restarts,
                    round(best_result["orig_sim"], 6),
                    round(best_result["adv_sim"], 6),
                    round(best_result["loss_before"], 6),
                    round(best_result["loss_after"], 6),
                    round(best_result["linf"], 6),
                    round(best_result["l2"], 6),
                    best_result["steps"],
                    success,
                    note
                ]
                writer.writerow(row)
                f.flush()  # keep partial results if a long sweep is interrupted
                print(f"EPS={eps}  orig_sim={row[7]:.4f} adv_sim={row[8]:.4f} "
                      f"loss_before={row[9]:.6f} loss_after={row[10]:.6f} "
                      f"linf={row[11]:.6f} l2={row[12]:.6f} success={row[14]} note={note}")
            print("-------------------------------------------------------------\n")

    print(f"All done. Results saved to {output_csv}")

# ------------------------- ENTRYPOINT -------------------------
if __name__ == "__main__":
    # Fail fast with a readable message rather than a pandas traceback when the
    # dataset path is wrong.
    if os.path.exists(DATA_FILE):
        run_and_save()
    else:
        print(f"[ERROR] Data file not found: {DATA_FILE}")
        print("Update DATA_FILE and rerun.")
        sys.exit(1)


Device: cuda

[WARN] Could not load sensor state_dict cleanly: Error(s) in loading state_dict for SensorEncoder:
	size mismatch for encoder.0.weight: copying a param with shape torch.Size([128, 14]) from checkpoint, the shape in current model is torch.Size([128, 6]).. Using fresh sensor encoder.
[INFO] Using HuggingFace transformers for TextEncoder.
[INFO] Loaded text encoder projection weights (strict=False).
-------------------------------------------------------------
ATTACK SAMPLE index=0 activity=Stress
text preview: '**General description:** Non‑stress periods display relaxed autonomic balance and smoother movement control.
**Accelerom...'
-------------------------------------------------------------
[attack] step 1/1000 loss=0.787763 sim=0.205347 linfΔ=0.495171 ‖∇‖=4.052399e-01
[attack] step 1/1000 loss=0.551056 sim=0.443607 linfΔ=0.496124 ‖∇‖=3.230856e-01
[attack] step 1/1000 loss=0.946168 sim=0.046258 linfΔ=0.483051 ‖∇‖=4.542835e-01
EPS=0.5  orig_sim=0.8516 adv_sim=0.0463 loss