In [4]:
# Updated attack harness (drop-in replacement)
import os
import csv
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from transformers import AutoTokenizer, AutoModel

# ------------------------- CONFIG -------------------------
EMBEDDING_DIM = 64
MARGIN = 0.5
TEXT_MODEL_NAME = "distilbert-base-uncased"

SENSOR_FEATURES_COUNT = 12
TEXT_COLUMN_NAME = "Semantic_Interpretation"
EXCLUDE_COLS = ["Label", TEXT_COLUMN_NAME]

SENSOR_MODEL_PATH = "sensor_encoder_motionsense_12col.pth"
TEXT_MODEL_PATH = "text_encoder_motionsense_12col.pth"  # optional; safe load
DATA_FILE = "./data/Motionsese_with_semantic_interpretation.csv"

# Attack defaults
DEFAULT_ALPHA = 0.005           # base step scalar (used to scale with epsilon)
EPSILON_SWEEP = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1,2,3,4,5,6,7,8,9,10]
DEFAULT_STEPS = 1000             # reduced for speed; you can increase
MODE = "linf"                   # "linf" or "l2"
RANDOM_RESTARTS = 1
DBG_INTERVAL = 50
EARLY_STOP = True

TARGET_INDICES = [0, 1000, 2000, 3000, 4000]
OUTPUT_CSV = "attack_results_Motionsense_new1.csv"
SEED = 0

# ------------------------- MODELS -------------------------
class SensorEncoder(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, x):
        return self.encoder(x)

class TextEncoder(nn.Module):
    """
    Returns the projected embedding directly: shape [batch, EMBEDDING_DIM].
    Loading state is attempted for 'projection' and optionally the transformer.
    """
    def __init__(self, model_name, output_dim, device="cpu"):
        super().__init__()
        self.device = torch.device(device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.projection = nn.Linear(self.model.config.hidden_size, output_dim)

    def forward(self, texts):
        # texts: list[str] or single str
        if isinstance(texts, str):
            texts = [texts]
        encoded = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        encoded = {k: v.to(self.device) for k, v in encoded.items()}
        output = self.model(**encoded)
        # mean-pool last hidden state across seq dimension (B, L, H) -> (B, H)
        pooled = output.last_hidden_state.mean(dim=1)
        z_t = self.projection(pooled)   # (B, EMBEDDING_DIM)
        return z_t

# ------------------------- LOSS & UTILS -------------------------
class ContrastiveSimilarityLoss(nn.Module):
    def __init__(self, margin=0.5):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        # output1, output2: (B, D); label: (B, 1) or scalar (1->positive,0->neg)
        sim = F.cosine_similarity(output1, output2, dim=1, eps=1e-8)  # (B,)
        if label.dim() > 1:
            label = label.view(-1)
        loss_pos = label * (1.0 - sim)
        zero = torch.zeros_like(sim)
        loss_neg = (1.0 - label) * torch.max(zero, sim - self.margin)
        return torch.mean(loss_pos + loss_neg)

def set_seed(seed):
    torch.manual_seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def l2_project(eta, epsilon):
    flat = eta.view(eta.size(0), -1)
    norms = torch.norm(flat, p=2, dim=1, keepdim=True).clamp(min=1e-12)
    factor = torch.clamp(epsilon / norms, max=1.0)
    return (flat * factor).view_as(eta)

def compute_norms(perturbation):
    p = perturbation.detach().cpu()
    linf = p.view(p.size(0), -1).abs().amax(dim=1).mean().item()
    l2 = torch.norm(p.view(p.size(0), -1), p=2, dim=1).mean().item()
    return linf, l2

# ------------------------- ATTACK FUNCTION -------------------------
def pgd_projected_attack(
    sensor_encoder, text_encoder, criterion,
    x_s, text_str, label,
    alpha, epsilon, steps, device,
    mode="linf", dbg_interval=50, early_stop=True
):
    device = torch.device(device)
    x_s = x_s.unsqueeze(0).to(device)                     # (1, F)
    label = torch.tensor([label], dtype=torch.float32).to(device)  # (1,)

    sensor_encoder.eval()
    text_encoder.eval()

    with torch.no_grad():
        z_s = sensor_encoder(x_s)                         # (1, D)
        z_t_orig = text_encoder([text_str])               # (1, D)
        loss_before = float(criterion(z_s, z_t_orig, label).item())
        orig_sim = float(F.cosine_similarity(z_s, z_t_orig, dim=1).item())

    # If already below margin, nothing to do (report)
    orig_below = orig_sim < MARGIN
    if orig_below:
        return {
            "orig_sim": orig_sim,
            "adv_sim": orig_sim,
            "loss_before": loss_before,
            "loss_after": loss_before,
            "linf": 0.0,
            "l2": 0.0,
            "steps": 0,
            "orig_below_margin": True
        }

    # initialize adversarial vector in projected space
    z_t_base = z_t_orig.detach().clone().to(device)
    z_t_adv = z_t_base.clone().detach().requires_grad_(True)

    # STEP SIZE: scale with epsilon so larger eps => larger per-step move
    # baseline step: either alpha * epsilon (if alpha is small), or epsilon/steps
    step_size = float(max(alpha * float(epsilon), float(epsilon) / max(1, steps)))
    # cap step_size reasonably to avoid explosion
    step_size = min(step_size, float(epsilon) if epsilon > 0 else step_size)

    for step in range(steps):
        if z_t_adv.grad is not None:
            z_t_adv.grad.zero_()

        loss = criterion(z_s, z_t_adv, label)
        loss.backward()
        grad = z_t_adv.grad.data

        if mode == "linf":
            step_vec = step_size * torch.sign(grad)
        else:
            g_flat = grad.view(grad.size(0), -1)
            g_norm = torch.norm(g_flat, p=2, dim=1, keepdim=True).clamp(min=1e-12)
            step_vec = (step_size * (g_flat / g_norm)).view_as(grad)

        # update (in-place)
        z_t_adv.data = z_t_adv.data + step_vec

        # project back to epsilon-ball around base
        eta = z_t_adv.data - z_t_base.data
        if mode == "linf":
            eta = torch.clamp(eta, -epsilon, epsilon)
        else:
            eta = l2_project(eta, epsilon)
        z_t_adv.data = z_t_base.data + eta

        if dbg_interval is not None and (step % dbg_interval == 0 or step == steps - 1):
            with torch.no_grad():
                cur_eta = (z_t_adv.data - z_t_base.data).detach().cpu()
                linf_now = cur_eta.abs().amax().item()
                sim_now = float(F.cosine_similarity(z_s, z_t_adv, dim=1).item())
                print(f"[attack] step {step+1}/{steps}  loss={loss.item():.6f}  linf_delta={linf_now:.6f}  sim={sim_now:.6f}")

        if early_stop:
            with torch.no_grad():
                sim_val = float(F.cosine_similarity(z_s, z_t_adv, dim=1).item())
            if sim_val < MARGIN:
                perturbation = (z_t_adv.detach() - z_t_base.detach()).cpu()
                linf_val, l2_val = compute_norms(perturbation)
                loss_after = float(criterion(z_s, z_t_adv, label).item())
                return {
                    "orig_sim": orig_sim,
                    "adv_sim": sim_val,
                    "loss_before": loss_before,
                    "loss_after": loss_after,
                    "linf": linf_val,
                    "l2": l2_val,
                    "steps": step + 1,
                    "orig_below_margin": False
                }

    # finished all steps
    with torch.no_grad():
        final_sim = float(F.cosine_similarity(z_s, z_t_adv, dim=1).item())
        loss_after = float(criterion(z_s, z_t_adv, label).item())
        perturbation = (z_t_adv.detach() - z_t_base.detach()).cpu()
        linf_val, l2_val = compute_norms(perturbation)

    return {
        "orig_sim": orig_sim,
        "adv_sim": final_sim,
        "loss_before": loss_before,
        "loss_after": loss_after,
        "linf": linf_val,
        "l2": l2_val,
        "steps": steps,
        "orig_below_margin": False
    }

# ------------------------- HARNESS (CSV) -------------------------
def run_and_save(
    sensor_path=SENSOR_MODEL_PATH,
    text_path=TEXT_MODEL_PATH,
    data_file=DATA_FILE,
    target_indices=TARGET_INDICES,
    eps_sweep=EPSILON_SWEEP,
    alpha=DEFAULT_ALPHA,
    steps=DEFAULT_STEPS,
    mode=MODE,
    restarts=RANDOM_RESTARTS,
    dbg_interval=DBG_INTERVAL,
    output_csv=OUTPUT_CSV
):
    set_seed(SEED)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}\n")

    # Load models
    try:
        sensor_encoder = SensorEncoder(SENSOR_FEATURES_COUNT, EMBEDDING_DIM).to(device)
        if os.path.exists(sensor_path):
            try:
                sensor_encoder.load_state_dict(torch.load(sensor_path, map_location=device))
                print("[INFO] Loaded sensor encoder weights.")
            except Exception as e:
                print(f"[WARN] Could not load sensor state_dict cleanly: {e}. Using fresh sensor encoder.")
        else:
            print("[WARN] sensor model checkpoint not found, using random-initialized sensor encoder.")

        text_encoder = TextEncoder(TEXT_MODEL_NAME, EMBEDDING_DIM, device=str(device)).to(device)
        if os.path.exists(text_path):
            try:
                ckpt = torch.load(text_path, map_location=device)
                # try to load projection layer only (safe)
                if "projection.weight" in ckpt or "projection.bias" in ckpt:
                    text_encoder.load_state_dict(ckpt, strict=False)
                    print("[INFO] Loaded available text encoder checkpoint (strict=False).")
                else:
                    # attempt full load but catch errors
                    try:
                        text_encoder.load_state_dict(ckpt)
                        print("[INFO] Loaded full text encoder checkpoint.")
                    except Exception:
                        print("[WARN] Could not fully load text encoder ckpt; proceeding with base model + projection.")
            except Exception as e:
                print(f"[WARN] could not load text checkpoint: {e}. Using pretrained transformer + fresh projection.")
        else:
            print("[INFO] No text checkpoint found; using pretrained model + fresh projection.")

        criterion = ContrastiveSimilarityLoss(margin=MARGIN)
    except Exception as e:
        print(f"[ERROR] loading models: {e}")
        return

    # Load data
    try:
        df = pd.read_csv(data_file)
    except Exception as e:
        print(f"[ERROR] loading CSV: {e}")
        return

    sensor_cols = [c for c in df.columns if c not in EXCLUDE_COLS]
    if len(sensor_cols) != SENSOR_FEATURES_COUNT:
        print(f"[ERROR] sensor column mismatch: found {len(sensor_cols)} expected {SENSOR_FEATURES_COUNT}")
        print("Detected:", sensor_cols)
        return

    sensor_values = df[sensor_cols].values
    texts = df[TEXT_COLUMN_NAME].astype(str).tolist()

    header = [
        "index", "label", "epsilon", "alpha", "steps", "mode", "restarts",
        "orig_sim", "adv_sim", "loss_before", "loss_after", "linf", "l2", "pgd_steps", "success", "note"
    ]

    with open(output_csv, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(header)

        for idx in target_indices:
            if idx < 0 or idx >= len(df):
                print(f"[WARN] index {idx} out of range, skipping")
                continue

            x_s = torch.tensor(sensor_values[idx], dtype=torch.float32)
            text_str = texts[idx]
            label = 1.0
            label_str = df["Label"].iloc[idx] if "Label" in df.columns else "NA"

            print("-------------------------------------------------------------")
            print(f"ATTACK SAMPLE index={idx} label={label_str}")
            print(f"text preview: '{text_str[:120]}...'")
            print("-------------------------------------------------------------")

            for eps in eps_sweep:
                best_result = None
                note = ""
                for rr in range(restarts):
                    res = pgd_projected_attack(
                        sensor_encoder, text_encoder, criterion,
                        x_s, text_str, label,
                        alpha, eps, steps, device,
                        mode=mode, dbg_interval=dbg_interval, early_stop=EARLY_STOP
                    )
                    if best_result is None or (res["adv_sim"] < best_result["adv_sim"]):
                        best_result = res

                success = int(best_result["adv_sim"] < MARGIN)
                if best_result.get("orig_below_margin", False):
                    note = "orig_sim_below_margin"

                row = [
                    idx, label_str, eps, alpha, steps, mode, restarts,
                    round(best_result["orig_sim"], 6),
                    round(best_result["adv_sim"], 6),
                    round(best_result["loss_before"], 6),
                    round(best_result["loss_after"], 6),
                    round(best_result["linf"], 6),
                    round(best_result["l2"], 6),
                    best_result["steps"],
                    success,
                    note
                ]
                writer.writerow(row)
                f.flush()
                print(f"EPS={eps}  orig_sim={row[7]:.4f} adv_sim={row[8]:.4f} linf={row[10]:.6f} l2={row[11]:.6f} success={row[13]} note={note}")
            print("-------------------------------------------------------------\n")

    print(f"All done. Results saved to {output_csv}")

# ------------------------- ENTRYPOINT -------------------------
if __name__ == "__main__":
    run_and_save()


Device: cuda

[INFO] Loaded sensor encoder weights.
[INFO] Loaded available text encoder checkpoint (strict=False).
-------------------------------------------------------------
ATTACK SAMPLE index=0 label=Sitting
text preview: '**General description:** Sitting is a static posture with minimal whole-body movement. **Accelerometer patterns:** near-...'
-------------------------------------------------------------
[attack] step 1/1000  loss=0.497635  linf_delta=0.000500  sim=0.502321
[attack] step 51/1000  loss=0.499808  linf_delta=0.025510  sim=0.500148
EPS=0.1  orig_sim=0.5024 adv_sim=0.5000 linf=0.500026 l2=0.027511 success=55 note=
[attack] step 1/1000  loss=0.497635  linf_delta=0.001000  sim=0.502278
EPS=0.2  orig_sim=0.5024 adv_sim=0.4999 linf=0.500069 l2=0.028011 success=28 note=
[attack] step 1/1000  loss=0.497635  linf_delta=0.001500  sim=0.502235
EPS=0.3  orig_sim=0.5024 adv_sim=0.4999 linf=0.500113 l2=0.028502 success=19 note=
[attack] step 1/1000  loss=0.497635  linf_delta=0.