In [None]:
!module load pytorchgpu/1

In [None]:
!pip install "accelerate>=0.21.0"

In [None]:
!pip install --upgrade "numpy>=1.22,<1.24" "tensorboard>=2.12,<2.13"

In [None]:
!pip install --upgrade protobuf==3.20.3

In [None]:
import torch
import random
import torch.nn.functional as F
import pandas as pd
import os
import gc
import pickle
import optuna
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

from transformers import BitsAndBytesConfig

import matplotlib.pyplot as plt


# ---------------------------------------------------------------------------
# Run configuration for this cell.
# ---------------------------------------------------------------------------

# When True the Optuna search below runs; otherwise a single evaluation pass.
TUNE = True

# When not tuning, reload steps/alpha from best_params.pkl if it exists.
USE_SAVED_PARAMS = True

# Prompt / attack shape.
loss_type = "Sentiment"   # "Topic" switches build_prompt to the topic template
num_shots = 4             # number of in-context demos per prompt
n_edit = 3                # tokens selected for editing inside each demo

# PGD hyper-parameters (overwritten by the tuner / saved parameters).
steps = 40                # PGD iterations per stage
alpha = 3                 # PGD step size

# Perturbation budget and discrete projection.
eps = 100.0               # L2 bound on the perturbation and candidate-ball radius
k_nn = 10                 # nearest candidates requested during projection

# Experiment bookkeeping.
trials = 3                # number of validation queries sampled
seed = 42
verbose = True

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Seed both RNGs used by the attack.
torch.manual_seed(seed)
random.seed(seed)

# Search space explored by the Optuna objective.
param_grid = dict(steps=[40, 80], alpha=[3])


def load_cached_model(model_id="facebook/opt-30b"):
    """Load the 4-bit quantized causal LM once and reuse it across cell runs.

    If the notebook globals already hold a model whose ``name_or_path`` equals
    ``model_id`` (and every derived global is present), those objects are
    returned as-is. Otherwise any previously loaded model is released and the
    requested one is loaded with an NF4 ``BitsAndBytesConfig``.

    Args:
        model_id: Hugging Face model identifier to load.

    Returns:
        Tuple ``(model, tok, E, V, d, pos_id, neg_id)`` where ``E`` is the
        input-embedding weight, ``V, d = E.shape`` and ``pos_id`` / ``neg_id``
        are the first token ids of "positive" / "negative". The same objects
        are also published as module globals.
    """
    global model, tok, E, V, d, pos_id, neg_id
    cache = ("model", "tok", "E", "V", "d", "pos_id", "neg_id")
    if getattr(globals().get("model", None), "name_or_path", None) == model_id \
       and all(v in globals() for v in cache):
        return model, tok, E, V, d, pos_id, neg_id

    # A different (or incomplete) model is resident: drop every cached global
    # and give the memory back before loading the new weights.
    if "model" in globals():
        del model
        for v in cache[1:]:
            globals().pop(v, None)
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    tok = AutoTokenizer.from_pretrained(model_id, use_auth_token=True)
    tok.pad_token = tok.eos_token

    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=quant_config,
        device_map="auto"
    ).eval()

    # Only input perturbations are optimised in this notebook. Freezing the
    # weights keeps each backward pass from accumulating .grad buffers on the
    # non-quantized parameters (embeddings, norms, biases).
    for param in model.parameters():
        param.requires_grad_(False)

    # Stamp the id so the cache check above can recognise this model.
    model.name_or_path = model_id

    E      = model.get_input_embeddings().weight
    V, d   = E.shape
    pos_id = tok("positive", add_special_tokens=False)["input_ids"][0]
    neg_id = tok("negative", add_special_tokens=False)["input_ids"][0]
    return model, tok, E, V, d, pos_id, neg_id

model, tok, E, V, d, pos_id, neg_id = load_cached_model()

# Label id -> word used in the prompt, for the sentiment and topic templates.
label_word = {1:"positive",0:"negative"}
topic_word = {0:"world",1:"sports",2:"business",3:"technology"}


train_ds = load_dataset("glue","sst2",split="train")
val_ds   = load_dataset("glue","sst2",split="validation")


# When True every validation example is attacked instead of a random sample.
traverse_full_val = False


if not TUNE and USE_SAVED_PARAMS:
    if os.path.exists("best_params.pkl"):
        # NOTE: pickle is only safe here because the file is written by this
        # notebook's own tuning run; never load it from an untrusted source.
        with open("best_params.pkl", "rb") as f:
            best = pickle.load(f)
        steps = best["steps"]
        alpha = best["alpha"]
    else:
        print("Warning: best_params.pkl not found, using defaults.")


if traverse_full_val:
    queries = list(val_ds)
    trials  = len(queries)
else:
    queries = random.sample(list(val_ds), trials)


def build_prompt(demo_df, q_sent, q_lbl):
    """Assemble the few-shot prompt for one query.

    Args:
        demo_df: DataFrame with "sentence" and "label" columns; the first
            ``num_shots`` rows become the in-context demos.
        q_sent: Query sentence (whitespace-stripped before use).
        q_lbl: Gold label of the query.

    Returns:
        ``(prompt, demo_sents, tgt)``: the prompt text, the stripped demo
        sentences in prompt order, and the token id of the query's gold label.
    """
    if loss_type == "Topic":
        instr = "Classify the topic of the last review. Here are several examples."
        tag, labmap = "\nTopic:", topic_word
        tgt = tok(labmap[q_lbl], add_special_tokens=False)["input_ids"][0]
    else:
        instr = (
            "Analyze the sentiment of the last review and respond with "
            "either positive or negative. Here are several examples."
        )
        tag, labmap = "\nSentiment:", label_word
        tgt = pos_id if q_lbl == 1 else neg_id

    demo_sents = []
    chunks = []
    shots = zip(demo_df["sentence"][:num_shots], demo_df["label"][:num_shots])
    for raw_sentence, lab in shots:
        cleaned = raw_sentence.strip()
        demo_sents.append(cleaned)
        chunks.append(f"\nReview: {cleaned}{tag}{labmap[lab]}")
    demos_str = "".join(chunks)

    # The query ends on the bare tag so the next token is the label word.
    q_stub = f"\nReview: {q_sent.strip()}{tag}"
    prompt = f"{instr}\n{demos_str}{q_stub}"
    return prompt, demo_sents, tgt

def classify_token(ids):
    """Predict positive vs. negative from the model's next-token logits.

    Args:
        ids: 1-D tensor of prompt token ids.

    Returns:
        ``(pred_id, prob)``: ``pos_id`` or ``neg_id`` (whichever has the larger
        logit at the last position; ties go to ``neg_id``) and its probability
        under a softmax restricted to those two logits.
    """
    # Pure inference: no autograd graph is needed for a scalar prediction.
    with torch.no_grad():
        lg = model(ids.unsqueeze(0)).logits[0,-1]
        pair = torch.softmax(lg[[neg_id,pos_id]],0)
        if lg[pos_id]>lg[neg_id]:
            return pos_id, pair[1].item()
        return neg_id, pair[0].item()

def _locate_demo_char_spans(prompt, demo_sents, labels, tag, labmap):
    """Return character ranges of each demo sentence and its label text.

    Demos are located left to right with a moving cursor by matching the whole
    chunk that build_prompt emits for a demo (newline, "Review: ", sentence,
    tag, label word). Searching for the bare sentence would pick up the first
    occurrence anywhere in the prompt, which is wrong for short phrases that
    also appear in the instruction or in an earlier demo.

    Returns:
        One ``((sent_start, sent_end), (label_start, label_end))`` per demo.

    Raises:
        ValueError: if a demo chunk cannot be found in ``prompt``.
    """
    prefix = "\nReview: "
    located = []
    cursor = 0
    for sent, lab in zip(demo_sents, labels):
        label_text = tag + labmap[lab]
        chunk_start = prompt.index(prefix + sent + label_text, cursor)
        sent_start = chunk_start + len(prefix)
        sent_end = sent_start + len(sent)
        label_end = sent_end + len(label_text)
        located.append(((sent_start, sent_end), (sent_end, label_end)))
        cursor = label_end
    return located


def run_alg1_with_query(train_ds, query, trial_idx):
    """Run the two-stage embedding attack on the demos of one query.

    Steps:
      1. Sample ``num_shots`` demos, build the prompt and record the token
         span of every demo sentence and of its label text.
      2. Stage 1: PGD over all demo tokens (per-demo L2 bound
         ``eps / num_shots``) to pick the ``n_edit`` positions with the
         largest perturbation norm in each demo.
      3. Stage 2: PGD restricted to those positions (global L2 bound ``eps``).
      4. Project each perturbed embedding to the nearest vocabulary embedding
         inside the ``eps`` ball around the original token.
      5. Re-classify, print metrics and pickle everything to
         ``perturb_info/ice_deltas_trial_<trial_idx>.pkl``.

    Args:
        train_ds: Indexable dataset of rows with "sentence" and "label".
        query: Mapping with "sentence" and "label" for the attacked query.
        trial_idx: Index used in log lines and in the output file name.

    Returns:
        ``(base_ok, adv_ok)``: whether the clean and the attacked prompt are
        classified as the query's gold label.
    """
    # random.sample picks positions from len(population) only, so sampling
    # indices draws the same rows as sampling list(train_ds) without copying
    # the whole split on every call.
    picked = random.sample(range(len(train_ds)), num_shots)
    demo_df = pd.DataFrame([train_ds[i] for i in picked])
    prompt, demo_sents, tgt_id = build_prompt(
        demo_df, query["sentence"], query["label"]
    )
    enc = tok(prompt,
              return_tensors="pt",
              return_offsets_mapping=True,
              add_special_tokens=False).to(device)
    ids  = enc.input_ids[0]
    mask = enc.attention_mask[0]
    offs = enc.offset_mapping[0].tolist()

    # Same tag / label map that build_prompt used for this loss_type.
    if loss_type=="Topic":
        tag="\nTopic:"; labmap=topic_word
    else:
        tag="\nSentiment:"; labmap=label_word

    # Map character ranges to token indices. Each span is [start, end, posL];
    # posL is filled with the positions chosen for editing after stage 1.
    spans=[]
    label_spans = []
    char_spans = _locate_demo_char_spans(
        prompt, demo_sents, demo_df["label"][:num_shots], tag, labmap
    )
    for (cs,ce),(ls,le) in char_spans:
        tok_s=next(i for i,(a,b) in enumerate(offs) if a<=cs<b)
        tok_e=next(i for i,(a,b) in enumerate(offs) if a<ce<=b)+1
        spans.append([tok_s,tok_e,[]])
        label_spans.append([i for i,(a,b) in enumerate(offs) if b>ls and a<le])

    pred_b,prob_b=classify_token(ids)
    base_ok=(pred_b==tgt_id)

    # Loop invariants. The clean embeddings are detached so gradients flow
    # only into the perturbation, not back into the embedding matrix lookup.
    base_emb=E[ids].detach()
    attn=mask.unsqueeze(0)
    target=torch.tensor([tgt_id],device=device)

    # ---- Stage 1: position selection -------------------------------------
    pre_eps=eps/num_shots
    delta_pre=torch.zeros_like(base_emb,requires_grad=True)
    for _ in range(steps):
        out=model(inputs_embeds=(base_emb+delta_pre).unsqueeze(0),
                  attention_mask=attn).logits[0,-1]
        loss=F.cross_entropy(out.unsqueeze(0),target)
        loss.backward()
        with torch.no_grad():
            for s,e,_posL in spans:
                # Gradient ascent on the gold-label loss, then clip the
                # demo's perturbation back onto its L2 ball.
                delta_pre[s:e]+=alpha*delta_pre.grad[s:e]
                if delta_pre[s:e].norm()>pre_eps:
                    delta_pre[s:e]*=pre_eps/delta_pre[s:e].norm()
            delta_pre.grad.zero_()

    for s,e,posL in spans:
        norms=delta_pre[s:e].norm(dim=1)
        topk =norms.topk(min(n_edit,e-s)).indices+s
        posL.extend(topk.tolist())

    # ---- Stage 2: PGD on the selected positions only ---------------------
    delta=torch.zeros_like(base_emb,requires_grad=True)
    trace=[]
    for _ in range(steps):
        out=model(inputs_embeds=(base_emb+delta).unsqueeze(0),
                  attention_mask=attn).logits[0,-1]
        loss=F.cross_entropy(out.unsqueeze(0),target)
        trace.append(loss.item()); loss.backward()
        with torch.no_grad():
            for _s,_e,posL in spans:
                for p in posL:
                    delta[p]+=alpha*delta.grad[p]
            if delta.norm()>eps: delta.mul_(eps/delta.norm())
            delta.grad.zero_()

    # ---- Projection back to discrete tokens ------------------------------
    ids_adv=ids.clone()
    save_dict = {
        "demo_df":    demo_df.to_dict(),
        "query":      {"sentence": query["sentence"], "label": query["label"]},
        "spans":      spans,
        "delta":      delta.detach().cpu().numpy(),
        "ice":        {},
        "label_spans": label_spans
    }
    for idx,(s,e,posL) in enumerate(spans,1):
        toks=[]
        for p in posL:
            orig=ids[p].item()
            best_id = None
            tgtv=E[orig]+delta[p]
            # Candidates: vocabulary embeddings within eps of the original.
            mask_ball=(E-E[orig]).norm(dim=1)<=eps
            cands=torch.where(mask_ball)[0]
            if cands.numel()>0:
                dists=(E[cands]-tgtv).norm(dim=1)
                best_id = int(cands[dists.topk(min(k_nn,cands.numel()),
                                              largest=False).indices[0]])
                ids_adv[p]=best_id
            toks.append({
              'pos':p,
              'orig_token':tok.convert_ids_to_tokens([orig])[0],
              'adv_token' :tok.convert_ids_to_tokens([best_id])[0] if best_id is not None else None,
              'orig_embedding':E[orig].detach().cpu().numpy(),
              'adv_embedding':E[best_id].detach().cpu().numpy() if best_id is not None else None,
              'delta':delta[p].detach().cpu().numpy()
            })
        save_dict["ice"][f"ICE_{idx}"] = {
          'pre_text':   tok.decode(ids[s:e],skip_special_tokens=True),
          'post_text':  tok.decode(ids_adv[s:e],skip_special_tokens=True),
          'token_info': toks
        }

    # ---- Metrics for this single query -----------------------------------
    pred_a,prob_a = classify_token(ids_adv)
    adv_ok        = (pred_a==tgt_id)
    clean_acc     = float(base_ok)
    adv_acc       = float(adv_ok)
    drop          = clean_acc - adv_acc
    asr           = 1.0 - (adv_acc / clean_acc) if clean_acc>0 else 0.0

    print(f"\n=== Metrics Trial {trial_idx} ===")
    print(f"  clean_acc: {clean_acc:.3f}")
    print(f"  adv_acc  : {adv_acc:.3f}")
    print(f"  drop     : {drop:.3f}")
    print(f"  ASR      : {asr:.3f}")

    save_dict["metrics"] = {
        "clean_acc": clean_acc,
        "adv_acc":   adv_acc,
        "drop":      drop,
        "ASR":       asr
    }

    os.makedirs("perturb_info",exist_ok=True)
    with open(f"perturb_info/ice_deltas_trial_{trial_idx}.pkl","wb") as f:
        pickle.dump(save_dict,f)

    if verbose:
        print(f"\n=== Alg1 Trial {trial_idx} ===")
        print("loss trace:",[f"{x:.3f}" for x in trace])
        print(f"Baseline    : {tok.convert_ids_to_tokens([pred_b])[0]} (p={prob_b:.2f})","✓" if base_ok else "✗")
        print(f"AfterAttack : {tok.convert_ids_to_tokens([pred_a])[0]} (p={prob_a:.2f})","✓" if adv_ok else "✗")
        print("\n--- ICE Replacement Details ---")
        for idx,(s,e,posL) in enumerate(spans,1):
            before=tok.decode(ids[s:e],skip_special_tokens=True)
            after =tok.decode(ids_adv[s:e],skip_special_tokens=True)
            repl  =[(tok.convert_ids_to_tokens([ids[p]])[0],
                     tok.convert_ids_to_tokens([ids_adv[p]])[0]) for p in posL]
            print(f"\nICE {idx}:")
            print("  Before:",before)
            print("  After: ",after)
            print("  Replaced tokens:",repl)

    return base_ok, adv_ok


if TUNE:

    # Early-stopping state shared with plateau_callback: best ASR seen so far
    # and the number of consecutive trials that failed to improve on it.
    _plateau = {"best": 0.0, "count": 0}
    _asr_history = []

    def plateau_callback(study, trial):
        """Track ASR per trial, stop after 5 stale trials, refresh the plot."""
        val = trial.value
        if val is None:
            # A trial without an objective value cannot be compared or plotted.
            return
        _asr_history.append(val)

        if val > _plateau["best"] + 1e-3:
            _plateau["best"] = val
            _plateau["count"] = 0
        else:
            _plateau["count"] += 1
        if _plateau["count"] >= 5:
            study.stop()

        plt.figure(figsize=(10, 6))
        plt.plot(_asr_history, marker='o')
        plt.title('ASR up')
        plt.xlabel('trial number')
        plt.ylabel('(ASR)')
        plt.grid(True)
        plt.tight_layout()
        plt.savefig('asr_progress.png')

        # Showing the figure is best-effort (e.g. no display backend); the
        # PNG above is the durable output.
        try:
            plt.show(block=False)
            plt.pause(0.1)
        except Exception:
            pass

        plt.close()

    def objective(trial):
        """Optuna objective: attack success rate over ``queries``."""
        # run_alg1_with_query reads steps/alpha from the module globals.
        global steps, alpha
        steps = trial.suggest_categorical("steps", param_grid['steps'])
        alpha = trial.suggest_categorical("alpha", param_grid['alpha'])

        base_cnt = adv_cnt = 0
        for t, query in enumerate(queries, start=1):
            b, a = run_alg1_with_query(train_ds, query, t)
            base_cnt += b
            adv_cnt  += a

        clean_acc = base_cnt / len(queries)
        adv_acc   = adv_cnt  / len(queries)
        asr       = 1.0 - (adv_acc / clean_acc) if clean_acc > 0 else 0.0

        print(f"[Trial {trial.number:2d}] steps={steps}, alpha={alpha}, ASR={asr:.4f}, clean_acc={clean_acc:.4f}, adv_acc={adv_acc:.4f}")
        return asr


    study = optuna.create_study(direction="maximize")
    try:
        study.optimize(objective, n_trials=30, callbacks=[plateau_callback])
    except KeyboardInterrupt:
        print("\nOptimization stopped manually!")
        print("Saving current best parameters...")
    finally:
        # Runs whether optimization finished or was interrupted. Optuna raises
        # ValueError from best_params/best_value when no trial has completed;
        # catching it keeps that from masking the original interruption.
        try:
            best_params = study.best_params
            best_value = study.best_value
        except ValueError:
            print("No completed trials; best_params.pkl was not written.")
        else:
            print(f"Best params: {best_params}")
            print(f"Best ASR   : {best_value}")
            print("ASR progress plot saved as 'asr_progress.png'")

            with open("best_params.pkl", "wb") as f:
                pickle.dump(best_params, f)

else:
    # Single evaluation pass with the current steps/alpha.
    base_cnt = adv_cnt = 0
    for t, query in enumerate(queries, start=1):
        b, a = run_alg1_with_query(train_ds, query, t)
        base_cnt += b
        adv_cnt  += a

    clean_acc = base_cnt / len(queries)
    adv_acc   = adv_cnt  / len(queries)
    asr       = 1.0 - (adv_acc / clean_acc) if clean_acc > 0 else 0.0
    print(f"Normal → clean_acc={clean_acc:.3f}, adv_acc={adv_acc:.3f}, ASR={asr:.3f}")

In [None]:
import torch
import random
import torch.nn.functional as F
import pandas as pd
import os
import gc
import pickle
import optuna
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

from transformers import BitsAndBytesConfig

import matplotlib.pyplot as plt


# ---------------------------------------------------------------------------
# Run configuration for this cell.
# ---------------------------------------------------------------------------

# When True the Optuna search below runs; otherwise a single evaluation pass.
TUNE = False

# When not tuning, reload steps/alpha from best_params.pkl if it exists.
USE_SAVED_PARAMS = True

# Prompt / attack shape.
loss_type = "Sentiment"   # "Topic" switches build_prompt to the topic template
num_shots = 4             # number of in-context demos per prompt
n_edit = 3                # tokens selected for editing inside each demo

# PGD hyper-parameters (overwritten by the tuner / saved parameters).
steps = 40                # PGD iterations per stage
alpha = 3                 # PGD step size

# Perturbation budget and discrete projection.
eps = 100.0               # L2 bound on the perturbation and candidate-ball radius
k_nn = 10                 # nearest candidates requested during projection

# Experiment bookkeeping.
trials = 20               # number of validation queries sampled
seed = 72
verbose = True

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Seed both RNGs used by the attack.
torch.manual_seed(seed)
random.seed(seed)

# Search space explored by the Optuna objective.
param_grid = dict(steps=[40, 80], alpha=[3])


def load_cached_model(model_id="facebook/opt-30b"):
    """Load the 4-bit quantized causal LM once and reuse it across cell runs.

    If the notebook globals already hold a model whose ``name_or_path`` equals
    ``model_id`` (and every derived global is present), those objects are
    returned as-is. Otherwise any previously loaded model is released and the
    requested one is loaded with an NF4 ``BitsAndBytesConfig``.

    Args:
        model_id: Hugging Face model identifier to load.

    Returns:
        Tuple ``(model, tok, E, V, d, pos_id, neg_id)`` where ``E`` is the
        input-embedding weight, ``V, d = E.shape`` and ``pos_id`` / ``neg_id``
        are the first token ids of "positive" / "negative". The same objects
        are also published as module globals.
    """
    global model, tok, E, V, d, pos_id, neg_id
    cache = ("model", "tok", "E", "V", "d", "pos_id", "neg_id")
    if getattr(globals().get("model", None), "name_or_path", None) == model_id \
       and all(v in globals() for v in cache):
        return model, tok, E, V, d, pos_id, neg_id

    # A different (or incomplete) model is resident: drop every cached global
    # and give the memory back before loading the new weights.
    if "model" in globals():
        del model
        for v in cache[1:]:
            globals().pop(v, None)
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    tok = AutoTokenizer.from_pretrained(model_id, use_auth_token=True)
    tok.pad_token = tok.eos_token

    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=quant_config,
        device_map="auto"
    ).eval()

    # Only input perturbations are optimised in this notebook. Freezing the
    # weights keeps each backward pass from accumulating .grad buffers on the
    # non-quantized parameters (embeddings, norms, biases).
    for param in model.parameters():
        param.requires_grad_(False)

    # Stamp the id so the cache check above can recognise this model.
    model.name_or_path = model_id

    E      = model.get_input_embeddings().weight
    V, d   = E.shape
    pos_id = tok("positive", add_special_tokens=False)["input_ids"][0]
    neg_id = tok("negative", add_special_tokens=False)["input_ids"][0]
    return model, tok, E, V, d, pos_id, neg_id

model, tok, E, V, d, pos_id, neg_id = load_cached_model()

# Label id -> word used in the prompt, for the sentiment and topic templates.
label_word = {1:"positive",0:"negative"}
topic_word = {0:"world",1:"sports",2:"business",3:"technology"}


train_ds = load_dataset("glue","sst2",split="train")
val_ds   = load_dataset("glue","sst2",split="validation")


# When True every validation example is attacked instead of a random sample.
traverse_full_val = False


if not TUNE and USE_SAVED_PARAMS:
    if os.path.exists("best_params.pkl"):
        # NOTE: pickle is only safe here because the file is written by this
        # notebook's own tuning run; never load it from an untrusted source.
        with open("best_params.pkl", "rb") as f:
            best = pickle.load(f)
        steps = best["steps"]
        alpha = best["alpha"]
    else:
        print("Warning: best_params.pkl not found, using defaults.")
    # Optional switch, currently disabled: traverse_full_val = True

if traverse_full_val:
    queries = list(val_ds)
    trials  = len(queries)
else:
    queries = random.sample(list(val_ds), trials)


def build_prompt(demo_df, q_sent, q_lbl):
    """Assemble the few-shot prompt for one query.

    Args:
        demo_df: DataFrame with "sentence" and "label" columns; the first
            ``num_shots`` rows become the in-context demos.
        q_sent: Query sentence (whitespace-stripped before use).
        q_lbl: Gold label of the query.

    Returns:
        ``(prompt, demo_sents, tgt)``: the prompt text, the stripped demo
        sentences in prompt order, and the token id of the query's gold label.
    """
    if loss_type == "Topic":
        instr = "Classify the topic of the last review. Here are several examples."
        tag, labmap = "\nTopic:", topic_word
        tgt = tok(labmap[q_lbl], add_special_tokens=False)["input_ids"][0]
    else:
        instr = (
            "Analyze the sentiment of the last review and respond with "
            "either positive or negative. Here are several examples."
        )
        tag, labmap = "\nSentiment:", label_word
        tgt = pos_id if q_lbl == 1 else neg_id

    demo_sents = []
    chunks = []
    shots = zip(demo_df["sentence"][:num_shots], demo_df["label"][:num_shots])
    for raw_sentence, lab in shots:
        cleaned = raw_sentence.strip()
        demo_sents.append(cleaned)
        chunks.append(f"\nReview: {cleaned}{tag}{labmap[lab]}")
    demos_str = "".join(chunks)

    # The query ends on the bare tag so the next token is the label word.
    q_stub = f"\nReview: {q_sent.strip()}{tag}"
    prompt = f"{instr}\n{demos_str}{q_stub}"
    return prompt, demo_sents, tgt

def classify_token(ids):
    """Predict positive vs. negative from the model's next-token logits.

    Args:
        ids: 1-D tensor of prompt token ids.

    Returns:
        ``(pred_id, prob)``: ``pos_id`` or ``neg_id`` (whichever has the larger
        logit at the last position; ties go to ``neg_id``) and its probability
        under a softmax restricted to those two logits.
    """
    # Pure inference: no autograd graph is needed for a scalar prediction.
    with torch.no_grad():
        lg = model(ids.unsqueeze(0)).logits[0,-1]
        pair = torch.softmax(lg[[neg_id,pos_id]],0)
        if lg[pos_id]>lg[neg_id]:
            return pos_id, pair[1].item()
        return neg_id, pair[0].item()

def _locate_demo_char_spans(prompt, demo_sents, labels, tag, labmap):
    """Return character ranges of each demo sentence and its label text.

    Demos are located left to right with a moving cursor by matching the whole
    chunk that build_prompt emits for a demo (newline, "Review: ", sentence,
    tag, label word). Searching for the bare sentence would pick up the first
    occurrence anywhere in the prompt, which is wrong for short phrases that
    also appear in the instruction or in an earlier demo.

    Returns:
        One ``((sent_start, sent_end), (label_start, label_end))`` per demo.

    Raises:
        ValueError: if a demo chunk cannot be found in ``prompt``.
    """
    prefix = "\nReview: "
    located = []
    cursor = 0
    for sent, lab in zip(demo_sents, labels):
        label_text = tag + labmap[lab]
        chunk_start = prompt.index(prefix + sent + label_text, cursor)
        sent_start = chunk_start + len(prefix)
        sent_end = sent_start + len(sent)
        label_end = sent_end + len(label_text)
        located.append(((sent_start, sent_end), (sent_end, label_end)))
        cursor = label_end
    return located


def run_alg1_with_query(train_ds, query, trial_idx):
    """Run the two-stage embedding attack on the demos of one query.

    Steps:
      1. Sample ``num_shots`` demos, build the prompt and record the token
         span of every demo sentence and of its label text.
      2. Stage 1: PGD over all demo tokens (per-demo L2 bound
         ``eps / num_shots``) to pick the ``n_edit`` positions with the
         largest perturbation norm in each demo.
      3. Stage 2: PGD restricted to those positions (global L2 bound ``eps``).
      4. Project each perturbed embedding to the nearest vocabulary embedding
         inside the ``eps`` ball around the original token.
      5. Re-classify, print metrics and pickle everything to
         ``perturb_info/ice_deltas_trial_<trial_idx>.pkl``.

    Args:
        train_ds: Indexable dataset of rows with "sentence" and "label".
        query: Mapping with "sentence" and "label" for the attacked query.
        trial_idx: Index used in log lines and in the output file name.

    Returns:
        ``(base_ok, adv_ok)``: whether the clean and the attacked prompt are
        classified as the query's gold label.
    """
    # random.sample picks positions from len(population) only, so sampling
    # indices draws the same rows as sampling list(train_ds) without copying
    # the whole split on every call.
    picked = random.sample(range(len(train_ds)), num_shots)
    demo_df = pd.DataFrame([train_ds[i] for i in picked])
    prompt, demo_sents, tgt_id = build_prompt(
        demo_df, query["sentence"], query["label"]
    )
    enc = tok(prompt,
              return_tensors="pt",
              return_offsets_mapping=True,
              add_special_tokens=False).to(device)
    ids  = enc.input_ids[0]
    mask = enc.attention_mask[0]
    offs = enc.offset_mapping[0].tolist()

    # Same tag / label map that build_prompt used for this loss_type.
    if loss_type=="Topic":
        tag="\nTopic:"; labmap=topic_word
    else:
        tag="\nSentiment:"; labmap=label_word

    # Map character ranges to token indices. Each span is [start, end, posL];
    # posL is filled with the positions chosen for editing after stage 1.
    spans=[]
    label_spans = []
    char_spans = _locate_demo_char_spans(
        prompt, demo_sents, demo_df["label"][:num_shots], tag, labmap
    )
    for (cs,ce),(ls,le) in char_spans:
        tok_s=next(i for i,(a,b) in enumerate(offs) if a<=cs<b)
        tok_e=next(i for i,(a,b) in enumerate(offs) if a<ce<=b)+1
        spans.append([tok_s,tok_e,[]])
        label_spans.append([i for i,(a,b) in enumerate(offs) if b>ls and a<le])

    pred_b,prob_b=classify_token(ids)
    base_ok=(pred_b==tgt_id)

    # Loop invariants. The clean embeddings are detached so gradients flow
    # only into the perturbation, not back into the embedding matrix lookup.
    base_emb=E[ids].detach()
    attn=mask.unsqueeze(0)
    target=torch.tensor([tgt_id],device=device)

    # ---- Stage 1: position selection -------------------------------------
    pre_eps=eps/num_shots
    delta_pre=torch.zeros_like(base_emb,requires_grad=True)
    for _ in range(steps):
        out=model(inputs_embeds=(base_emb+delta_pre).unsqueeze(0),
                  attention_mask=attn).logits[0,-1]
        loss=F.cross_entropy(out.unsqueeze(0),target)
        loss.backward()
        with torch.no_grad():
            for s,e,_posL in spans:
                # Gradient ascent on the gold-label loss, then clip the
                # demo's perturbation back onto its L2 ball.
                delta_pre[s:e]+=alpha*delta_pre.grad[s:e]
                if delta_pre[s:e].norm()>pre_eps:
                    delta_pre[s:e]*=pre_eps/delta_pre[s:e].norm()
            delta_pre.grad.zero_()

    for s,e,posL in spans:
        norms=delta_pre[s:e].norm(dim=1)
        topk =norms.topk(min(n_edit,e-s)).indices+s
        posL.extend(topk.tolist())

    # ---- Stage 2: PGD on the selected positions only ---------------------
    delta=torch.zeros_like(base_emb,requires_grad=True)
    trace=[]
    for _ in range(steps):
        out=model(inputs_embeds=(base_emb+delta).unsqueeze(0),
                  attention_mask=attn).logits[0,-1]
        loss=F.cross_entropy(out.unsqueeze(0),target)
        trace.append(loss.item()); loss.backward()
        with torch.no_grad():
            for _s,_e,posL in spans:
                for p in posL:
                    delta[p]+=alpha*delta.grad[p]
            if delta.norm()>eps: delta.mul_(eps/delta.norm())
            delta.grad.zero_()

    # ---- Projection back to discrete tokens ------------------------------
    ids_adv=ids.clone()
    save_dict = {
        "demo_df":    demo_df.to_dict(),
        "query":      {"sentence": query["sentence"], "label": query["label"]},
        "spans":      spans,
        "delta":      delta.detach().cpu().numpy(),
        "ice":        {},
        "label_spans": label_spans
    }
    for idx,(s,e,posL) in enumerate(spans,1):
        toks=[]
        for p in posL:
            orig=ids[p].item()
            best_id = None
            tgtv=E[orig]+delta[p]
            # Candidates: vocabulary embeddings within eps of the original.
            mask_ball=(E-E[orig]).norm(dim=1)<=eps
            cands=torch.where(mask_ball)[0]
            if cands.numel()>0:
                dists=(E[cands]-tgtv).norm(dim=1)
                best_id = int(cands[dists.topk(min(k_nn,cands.numel()),
                                              largest=False).indices[0]])
                ids_adv[p]=best_id
            toks.append({
              'pos':p,
              'orig_token':tok.convert_ids_to_tokens([orig])[0],
              'adv_token' :tok.convert_ids_to_tokens([best_id])[0] if best_id is not None else None,
              'orig_embedding':E[orig].detach().cpu().numpy(),
              'adv_embedding':E[best_id].detach().cpu().numpy() if best_id is not None else None,
              'delta':delta[p].detach().cpu().numpy()
            })
        save_dict["ice"][f"ICE_{idx}"] = {
          'pre_text':   tok.decode(ids[s:e],skip_special_tokens=True),
          'post_text':  tok.decode(ids_adv[s:e],skip_special_tokens=True),
          'token_info': toks
        }

    # ---- Metrics for this single query -----------------------------------
    pred_a,prob_a = classify_token(ids_adv)
    adv_ok        = (pred_a==tgt_id)
    clean_acc     = float(base_ok)
    adv_acc       = float(adv_ok)
    drop          = clean_acc - adv_acc
    asr           = 1.0 - (adv_acc / clean_acc) if clean_acc>0 else 0.0

    print(f"\n=== Metrics Trial {trial_idx} ===")
    print(f"  clean_acc: {clean_acc:.3f}")
    print(f"  adv_acc  : {adv_acc:.3f}")
    print(f"  drop     : {drop:.3f}")
    print(f"  ASR      : {asr:.3f}")

    save_dict["metrics"] = {
        "clean_acc": clean_acc,
        "adv_acc":   adv_acc,
        "drop":      drop,
        "ASR":       asr
    }

    os.makedirs("perturb_info",exist_ok=True)
    with open(f"perturb_info/ice_deltas_trial_{trial_idx}.pkl","wb") as f:
        pickle.dump(save_dict,f)

    if verbose:
        print(f"\n=== Alg1 Trial {trial_idx} ===")
        print("loss trace:",[f"{x:.3f}" for x in trace])
        print(f"Baseline    : {tok.convert_ids_to_tokens([pred_b])[0]} (p={prob_b:.2f})","✓" if base_ok else "✗")
        print(f"AfterAttack : {tok.convert_ids_to_tokens([pred_a])[0]} (p={prob_a:.2f})","✓" if adv_ok else "✗")
        print("\n--- ICE Replacement Details ---")
        for idx,(s,e,posL) in enumerate(spans,1):
            before=tok.decode(ids[s:e],skip_special_tokens=True)
            after =tok.decode(ids_adv[s:e],skip_special_tokens=True)
            repl  =[(tok.convert_ids_to_tokens([ids[p]])[0],
                     tok.convert_ids_to_tokens([ids_adv[p]])[0]) for p in posL]
            print(f"\nICE {idx}:")
            print("  Before:",before)
            print("  After: ",after)
            print("  Replaced tokens:",repl)

    return base_ok, adv_ok


if TUNE:

    # Early-stopping state shared with plateau_callback: best ASR seen so far
    # and the number of consecutive trials that failed to improve on it.
    _plateau = {"best": 0.0, "count": 0}
    _asr_history = []

    def plateau_callback(study, trial):
        """Track ASR per trial, stop after 5 stale trials, refresh the plot."""
        val = trial.value
        if val is None:
            # A trial without an objective value cannot be compared or plotted.
            return
        _asr_history.append(val)

        if val > _plateau["best"] + 1e-3:
            _plateau["best"] = val
            _plateau["count"] = 0
        else:
            _plateau["count"] += 1
        if _plateau["count"] >= 5:
            study.stop()

        plt.figure(figsize=(10, 6))
        plt.plot(_asr_history, marker='o')
        plt.title('ASR up')
        plt.xlabel('trial number')
        plt.ylabel('(ASR)')
        plt.grid(True)
        plt.tight_layout()
        plt.savefig('asr_progress.png')

        # Showing the figure is best-effort (e.g. no display backend); the
        # PNG above is the durable output.
        try:
            plt.show(block=False)
            plt.pause(0.1)
        except Exception:
            pass

        plt.close()

    def objective(trial):
        """Optuna objective: attack success rate over ``queries``."""
        # run_alg1_with_query reads steps/alpha from the module globals.
        global steps, alpha
        steps = trial.suggest_categorical("steps", param_grid['steps'])
        alpha = trial.suggest_categorical("alpha", param_grid['alpha'])

        base_cnt = adv_cnt = 0
        for t, query in enumerate(queries, start=1):
            b, a = run_alg1_with_query(train_ds, query, t)
            base_cnt += b
            adv_cnt  += a

        clean_acc = base_cnt / len(queries)
        adv_acc   = adv_cnt  / len(queries)
        asr       = 1.0 - (adv_acc / clean_acc) if clean_acc > 0 else 0.0

        print(f"[Trial {trial.number:2d}] steps={steps}, alpha={alpha}, ASR={asr:.4f}, clean_acc={clean_acc:.4f}, adv_acc={adv_acc:.4f}")
        return asr


    study = optuna.create_study(direction="maximize")
    try:
        study.optimize(objective, n_trials=30, callbacks=[plateau_callback])
    except KeyboardInterrupt:
        print("\nOptimization stopped manually!")
        print("Saving current best parameters...")
    finally:
        # Runs whether optimization finished or was interrupted. Optuna raises
        # ValueError from best_params/best_value when no trial has completed;
        # catching it keeps that from masking the original interruption.
        try:
            best_params = study.best_params
            best_value = study.best_value
        except ValueError:
            print("No completed trials; best_params.pkl was not written.")
        else:
            print(f"Best params: {best_params}")
            print(f"Best ASR   : {best_value}")
            print("ASR progress plot saved as 'asr_progress.png'")

            with open("best_params.pkl", "wb") as f:
                pickle.dump(best_params, f)

else:
    # Single evaluation pass with the current steps/alpha.
    base_cnt = adv_cnt = 0
    for t, query in enumerate(queries, start=1):
        b, a = run_alg1_with_query(train_ds, query, t)
        base_cnt += b
        adv_cnt  += a

    clean_acc = base_cnt / len(queries)
    adv_acc   = adv_cnt  / len(queries)
    asr       = 1.0 - (adv_acc / clean_acc) if clean_acc > 0 else 0.0
    print(f"Normal → clean_acc={clean_acc:.3f}, adv_acc={adv_acc:.3f}, ASR={asr:.3f}")

In [None]:
## profile

In [None]:
import os
import glob
import pickle
import numpy as np
import torch
from scipy.interpolate import interp1d


# Directory holding the per-trial pickles written by run_alg1_with_query.
# NOTE: pickle is only safe here because these files are produced by this
# notebook; never point input_dir at untrusted data.
input_dir = "perturb_info"


# The first trial fixes how many demos (ICE_1 … ICE_n) each file contains.
first_path = os.path.join(input_dir, "ice_deltas_trial_1.pkl")
with open(first_path, "rb") as f:
    first = pickle.load(f)
n = len(first["ice"])  # ICE_1 … ICE_n


all_delta_norms = []
valid_trials    = []


def _trial_index(path):
    """Extract the integer trial number from an ice_deltas_trial_<k>.pkl path."""
    return int(os.path.basename(path).split("_")[-1].split(".")[0])


# Sort numerically: a plain string sort would place trial 10 before trial 2.
paths = sorted(glob.glob(os.path.join(input_dir, "ice_deltas_trial_*.pkl")),
               key=_trial_index)
for path in paths:

    fname = os.path.basename(path)
    t = _trial_index(path)
    with open(path, "rb") as f:
        data = pickle.load(f)

    norms = []
    for i in range(1, n + 1):
        token_info = data["ice"][f"ICE_{i}"]["token_info"]

        # Upcast the stored deltas to float32 so the norm and the sum are not
        # computed in half precision on the CPU.
        deltas_i = torch.stack([
            torch.tensor(entry["delta"], dtype=torch.float32)
            for entry in token_info
        ])
        # Per-demo budget: sum of the L2 norms of its edited-token deltas.
        norms.append(deltas_i.norm(dim=1).sum().item())

    norms = np.array(norms, dtype=float)
    if np.isnan(norms).any():
        print(f"[Warning] trial {t} contain NaN, skip: {norms}")
        continue

    all_delta_norms.append(norms)
    valid_trials.append(t)

if len(all_delta_norms) == 0:
    raise RuntimeError("skip Budget Profile")

all_delta_norms = np.vstack(all_delta_norms)
print(f"use effective trials: {valid_trials}")
print("all trial ICEs of L2 norms:\n", all_delta_norms)


# Average over trials, then normalise so the profile sums to one.
mean_delta_norms = all_delta_norms.mean(axis=0)
print("average ICE L2:", mean_delta_norms)


total = mean_delta_norms.sum()
if total <= 0:
    raise RuntimeError("can not normalize")
gamma_disc = mean_delta_norms / total
print("Discretize Budget Profile γ:", gamma_disc.tolist())

def to_continuous_budget(gamma, N):
    """Resample a discrete budget profile onto ``N`` evenly spaced points.

    The profile is treated as samples at x = 1 … len(gamma), linearly
    interpolated at ``N`` points spanning the same range, clipped at zero and
    renormalised to sum to one.

    Args:
        gamma: Sequence of non-negative budget weights.
        N: Number of points in the resampled profile.

    Returns:
        1-D float array of length ``N`` that sums to one.

    Raises:
        ValueError: if the resampled profile sums to zero (for example
            ``N == 0``, or all sampled points fall on zero weights), because
            normalising would otherwise silently produce NaNs.
    """
    gamma  = np.array(gamma, dtype=float)
    x_orig = np.arange(1, gamma.size + 1)
    interp = interp1d(
        x_orig, gamma,
        kind="linear",
        bounds_error=False,
        fill_value="extrapolate"
    )
    x_new = np.linspace(1, gamma.size, N)
    vals  = np.clip(interp(x_new), 0, None)
    total = vals.sum()
    if total <= 0:
        raise ValueError("resampled budget profile sums to zero; cannot normalize")
    return vals / total

# Resample the discrete profile onto n points and report the result.
custom_gamma = to_continuous_budget(gamma_disc, N=n)
gamma_listing = custom_gamma.tolist()
print(f"\ncontinuous Budget Profile γ (length = {n}):\n{gamma_listing}")
gamma_total = custom_gamma.sum()
print("Sum check:", gamma_total)


In [None]:
##alg2-flat

In [None]:
import pickle
import glob
import pandas as pd
import torch
import torch.nn.functional as F

# Default number of local PGD steps when no tuned parameters are available.
steps_local = 2

# 1) If Alg1 tuned and wrote best_params.pkl, reuse those steps & alpha.
#    NOTE: pickle is only safe here because the file is written by this
#    notebook's own tuning run; never load it from an untrusted source.
try:
    with open("best_params.pkl", "rb") as f:
        best = pickle.load(f)
except FileNotFoundError:
    # No tuning run available: keep the default above (alpha is left as is).
    pass
else:
    steps_local = best["steps"]
    alpha       = best["alpha"]
    print(f"Using Alg1 best params: steps_local={steps_local}, alpha={alpha}")

def word_proj(ids_orig, pos_sel, delta_vec, eps_bound, tgt_q, mask):
    """Project perturbed embeddings at ``pos_sel`` back to vocabulary tokens.

    For each position the ``k_nn`` vocabulary embeddings closest to the
    perturbed embedding (restricted to the ``eps_bound`` ball around the
    original token) are tried one at a time in the original sequence; the
    candidate giving the highest cross-entropy on ``tgt_q`` is kept.

    Args:
        ids_orig: 1-D tensor of original token ids (not modified).
        pos_sel: Iterable of token positions to replace.
        delta_vec: Per-position embedding perturbation, indexed like ids_orig.
        eps_bound: L2 radius of the candidate ball around the original token.
        tgt_q: Token id of the query's gold label.
        mask: 1-D attention mask used when scoring candidates.

    Returns:
        A copy of ``ids_orig`` with the selected positions replaced. A
        position keeps its original token when the ball is empty or when no
        candidate loss exceeds the initial threshold.
    """
    ids_new = ids_orig.clone()
    # Loop invariants shared by every candidate evaluation.
    target = torch.tensor([tgt_q], device=device)
    attn = mask.unsqueeze(0)

    # Candidates are only scored, never differentiated, so skip autograd to
    # avoid building a graph for every forward pass.
    with torch.no_grad():
        for p in pos_sel:
            orig = ids_orig[p].item()
            tgtv = E[orig] + delta_vec[p]
            ball = (E - E[orig]).norm(dim=1) <= eps_bound
            cands = torch.arange(V, device=device)[ball]
            if cands.numel() == 0: continue
            dists = (E[cands] - tgtv).norm(dim=1)
            topk  = cands[dists.topk(min(k_nn,cands.numel()), largest=False).indices]
            best, best_loss = orig, -1e9
            for cid in topk:
                # Each candidate is scored against the original sequence,
                # independent of replacements chosen at other positions.
                tmp = ids_orig.clone(); tmp[p] = cid
                logits = model(tmp.unsqueeze(0), attention_mask=attn).logits[0,-1]
                loss = F.cross_entropy(logits.unsqueeze(0), target)
                if loss.item() > best_loss:
                    best_loss = loss.item(); best = cid.item()
            ids_new[p] = best
    return ids_new

def run_alg2_on(demo_df, query, spans, delta_global, label_spans, trial):
    """Run the sequential, flat-budget Alg2 attack on one trial and log its metrics.

    The prompt is the instruction header, every demonstration in ``demo_df``
    and the query. ICE spans are attacked one after another. While span ``i``
    is attacked, the attention mask hides the sentence and label tokens of all
    later spans; the perturbation on span ``i`` is updated by gradient ascent
    on the query's cross-entropy and rescaled onto an L2 ball of radius
    ``eps / num_shots``; the result is then projected back to tokens with
    ``word_proj``.

    Side effects: prints per-ICE and final metrics and writes
    ``perturb_info_flat/alg2_debug_trial_{trial}.pkl``.

    Args:
        demo_df: table with ``sentence`` and ``label`` columns, one row per ICE.
        query: mapping with ``sentence`` and ``label``; label 1 selects
            ``pos_id`` as the target token, anything else ``neg_id``.
        spans: one ``(start, end, positions)`` triple per ICE; ``positions``
            are the token indices that may be edited.
        delta_global: unused; kept so existing callers keep working.
        label_spans: for each ICE, the token indices of its label, masked
            together with the sentence while earlier ICEs are attacked.
        trial: trial number, used in log output and in the dump file name.

    Returns:
        ``([clean_acc], [adv_acc])`` for the full prompt, each 0.0 or 1.0.

    Raises:
        ValueError: if there are more spans than per-ICE budget entries.
    """
    # 1) Construct the complete prompt, including all ICEs + query
    head = ("Analyze the sentiment of the last review and respond with "
            "either positive or negative. Here are several examples.")
    tag  = "\nSentiment:"
    demos = list(zip(demo_df["sentence"], demo_df["label"]))
    query_tail = f"\nReview: {query['sentence'].strip()}\nSentiment:"

    def build_prompt(num_demos):
        # Header + the first num_demos demonstrations + the query.
        return head + "".join(
            f"\nReview: {s.strip()}{tag}{label_word[l]}" for s, l in demos[:num_demos]
        ) + query_tail

    enc = tok(build_prompt(len(demos)), return_tensors="pt", add_special_tokens=False).to(device)
    ids, full_mask = enc.input_ids[0], enc.attention_mask[0]

    tgt_q  = pos_id if query["label"]==1 else neg_id
    target = torch.tensor([tgt_q], device=device)

    # 2) Pre-calculate clean prediction on the full prompt
    pred_b, _ = classify_token(ids)
    base_ok = (pred_b == tgt_q)

    # 3) Sequential PGD with a flat budget: every ICE gets eps / num_shots
    gamma = [1.0]*num_shots
    gamma = [g/sum(gamma) for g in gamma]
    eps_i = [g*eps for g in gamma]
    if len(eps_i) < len(spans):
        # Fail before any expensive model call instead of with an IndexError mid-run.
        raise ValueError(
            f"budget profile has {len(eps_i)} entries but there are {len(spans)} ICE spans"
        )

    # Detached so gradients are only tracked for the perturbation itself.
    base_emb  = E[ids].detach()
    cum_delta = torch.zeros_like(base_emb, device=device)

    # Token ids after projecting every span attacked so far.
    ids_proj = ids.clone()
    metrics_per_ice = []

    for i,(_,_,posL) in enumerate(spans):
        # mask subsequent ICE_{i+1..n}
        mask_i = full_mask.clone()
        for j in range(i+1,len(spans)):
            s_j,e_j,_ = spans[j]
            mask_i[s_j:e_j] = 0
            for lp in label_spans[j]:
                mask_i[lp] = 0

        # 3.1) Local PGD for ICE_i
        delta_loc = cum_delta.clone().detach().requires_grad_(True)
        attn_i = mask_i.unsqueeze(0)
        for _ in range(steps_local):
            logits = model(inputs_embeds=(base_emb + delta_loc).unsqueeze(0),
                           attention_mask=attn_i).logits[0,-1]
            loss = F.cross_entropy(logits.unsqueeze(0), target)
            # Differentiate w.r.t. the perturbation only; loss.backward() would
            # also accumulate .grad on every trainable model parameter.
            (grad,) = torch.autograd.grad(loss, delta_loc)
            with torch.no_grad():
                for p in posL:
                    delta_loc[p] += alpha * grad[p]
                # Rescale the span's perturbation back onto its L2 budget.
                norm = delta_loc[posL].norm()
                if norm > eps_i[i]:
                    delta_loc[posL] *= eps_i[i]/norm
        with torch.no_grad():
            for p in posL:
                cum_delta[p] = delta_loc[p]

        # Clean accuracy on the prompt made of ICE_1..ICE_i + query.
        # NOTE(review): adv_acc_i below is measured on the full prompt (all
        # ICEs present, the first i perturbed), not on this prefix prompt --
        # confirm that comparing the two is intended.
        enc_clean = tok(build_prompt(i+1), return_tensors="pt", add_special_tokens=False).to(device)
        pred_c, _ = classify_token(enc_clean.input_ids[0])
        clean_acc_i = 1.0 if pred_c==tgt_q else 0.0

        # Cumulative word_proj for ICE_1..ICE_i. Only span i's rows of
        # cum_delta changed in this round, so projecting span i on top of the
        # running result gives the same tokens as re-projecting every earlier
        # span (assumes spans do not share positions and the forward pass is
        # deterministic).
        ids_proj = word_proj(ids_proj, posL, cum_delta, eps_i[i], tgt_q, full_mask)
        pred_i, _ = classify_token(ids_proj)

        # ASR counts a success only when the clean prediction was correct.
        adv_acc_i = 1.0 if pred_i == tgt_q else 0.0
        drop_i    = max(clean_acc_i - adv_acc_i, 0.0)
        ASR_i     = 1.0 if (clean_acc_i==1.0 and adv_acc_i==0.0) else 0.0

        print(f"=== ICE {i+1} Metrics === clean_acc: {clean_acc_i:.3f}, adv_acc: {adv_acc_i:.3f}, drop: {drop_i:.3f}, ASR: {ASR_i:.3f}")

        metrics_per_ice.append({
            "clean_acc": clean_acc_i,
            "adv_acc":   adv_acc_i,
            "drop":      drop_i,
            "ASR":       ASR_i
        })

    # 4) After the last span the running projection is the fully attacked prompt.
    ids_adv = ids_proj

    # 5) Last-ICE metrics (same definitions)
    pred_a, _ = classify_token(ids_adv)
    clean_acc = 1.0 if base_ok else 0.0
    adv_acc   = 1.0 if pred_a == tgt_q else 0.0
    drop      = max(clean_acc - adv_acc, 0.0)
    ASR       = 1.0 if (clean_acc==1.0 and adv_acc==0.0) else 0.0

    print(f"\n=== Alg2 Trial {trial} Metrics (last ICE) ===")
    print(f" clean_acc: {clean_acc:.3f}")
    print(f" adv_acc  : {adv_acc:.3f}")
    print(f" drop     : {drop:.3f}")
    print(f" ASR      : {ASR:.3f}")

    # 6) Save per-ICE metrics + texts, plus the full prompt before and after
    #    all perturbations.
    debug_info = {
        'trial':      trial,
        'full_pre':   tok.decode(ids,     skip_special_tokens=True),
        'full_post':  tok.decode(ids_adv, skip_special_tokens=True),
        'ice':        []
    }

    for i, (s_j, e_j, posL) in enumerate(spans):
        # Decode the contiguous range covering the edited positions; an ICE
        # with no editable positions yields an empty range.
        lo, hi = (posL[0], posL[-1]+1) if len(posL) > 0 else (0, 0)
        debug_info["ice"].append({
            "ice_index":      i + 1,
            "pos":            posL,
            "pre_text":       tok.decode(ids[lo:hi],       skip_special_tokens=True),
            "post_text":      tok.decode(ids_adv[lo:hi],   skip_special_tokens=True),
            "full_sent_pre":  tok.decode(ids[s_j:e_j],     skip_special_tokens=True),
            "full_sent_post": tok.decode(ids_adv[s_j:e_j], skip_special_tokens=True),
            **metrics_per_ice[i]
        })

    os.makedirs("perturb_info_flat", exist_ok=True)
    with open(f"perturb_info_flat/alg2_debug_trial_{trial}.pkl","wb") as f:
        pickle.dump(debug_info,f)

    # On the first trial, print the entire saved debug_info for inspection.
    if trial == 1:
        print("=== Full debug_info for trial 1 ===")
        print(debug_info)

    return [clean_acc], [adv_acc]

# # ─────────────────────────────────────────────────────────────────────────────
# # Batch load, run Alg2, and compute averages
# # ─────────────────────────────────────────────────────────────────────────────

# ─────────────────────────────────────────────────────────────────────────────
# Batch load from Alg1 saved PKLs, run Alg2 and collect metrics
# ─────────────────────────────────────────────────────────────────────────────
base2_sum = 0.0
adv2_sum  = 0.0
all_metrics = []

orig_paths = sorted(glob.glob("perturb_info/ice_deltas_trial_*.pkl"))
for t, orig_path in enumerate(orig_paths, start=1):

    data = pickle.load(open(orig_path, "rb"))
    demo_df      = pd.DataFrame.from_dict(data["demo_df"])
    query        = data["query"]
    spans        = data["spans"]
    delta_global = torch.tensor(data["delta"], device=device)
    label_spans  = data["label_spans"]


    b_vec, a_vec = run_alg2_on(demo_df, query, spans, delta_global, label_spans, t)
    base2_sum += b_vec[0]
    adv2_sum  += a_vec[0]


    flat_path = f"perturb_info_flat/alg2_debug_trial_{t}.pkl"
    flat_data = pickle.load(open(flat_path, "rb"))
    all_metrics.append(flat_data["ice"])

num_trials = len(all_metrics)

# overall (last ICE) averages
print("\n=== Overall Alg2 (last ICE) ===")
print(f"Avg clean_acc: {base2_sum/num_trials:.3f}")
print(f"Avg adv_acc  : {adv2_sum/num_trials:.3f}")
print(f"Avg drop     : {(base2_sum-adv2_sum)/num_trials:.3f}")
print(f"Avg ASR      : {1-(adv2_sum/base2_sum) if base2_sum>0 else 0.0:.3f}")

# per-ICE averages
print("\n=== Per-ICE Average Metrics Across All Trials ===")
num_ice = len(all_metrics[0])
sums = [ {'clean_acc':0,'adv_acc':0,'drop':0,'ASR':0} for _ in range(num_ice) ]
for trial_metrics in all_metrics:
    for i, m in enumerate(trial_metrics):
        sums[i]['clean_acc'] += m['clean_acc']
        sums[i]['adv_acc']   += m['adv_acc']
        sums[i]['drop']      += m['drop']
        sums[i]['ASR']       += m['ASR']
for i, s in enumerate(sums, start=1):
    print(f"ICE {i}: clean_acc={s['clean_acc']/num_trials:.3f}, "
          f"adv_acc={s['adv_acc']/num_trials:.3f}, "
          f"drop={s['drop']/num_trials:.3f}, "
          f"ASR={s['ASR']/num_trials:.3f}")

Using Alg1 best params: steps_local=80, alpha=3
=== ICE 1 Metrics === clean_acc: 1.000, adv_acc: 1.000, drop: 0.000, ASR: 0.000
=== ICE 2 Metrics === clean_acc: 1.000, adv_acc: 1.000, drop: 0.000, ASR: 0.000
=== ICE 3 Metrics === clean_acc: 1.000, adv_acc: 1.000, drop: 0.000, ASR: 0.000
=== ICE 4 Metrics === clean_acc: 1.000, adv_acc: 1.000, drop: 0.000, ASR: 0.000

=== Alg2 Trial 1 Metrics (last ICE) ===
 clean_acc: 1.000
 adv_acc  : 1.000
 drop     : 0.000
 ASR      : 0.000
=== Full debug_info for trial 1 ===
{'trial': 1, 'full_pre': "Analyze the sentiment of the last review and respond with either positive or negative. Here are several examples.\nReview: the hallmarks\nSentiment:positive\nReview: about something , one that attempts and often achieves a level of connection and concern\nSentiment:positive\nReview: a movie to forget\nSentiment:negative\nReview: better thriller\nSentiment:positive\nReview: among the year 's most intriguing explorations of alientation .\nSentiment:", 'fu

In [None]:
#alg2-budgeted

In [None]:
import pickle
import glob
import pandas as pd
import torch
import torch.nn.functional as F

# Default number of local PGD steps, used when no tuned parameters are on disk.
steps_local = 2

# If Alg1 tuned and wrote best_params.pkl, reuse its steps & alpha.
# Only the file access sits inside the try, so a malformed params file still
# raises instead of being mistaken for "file missing".
try:
    with open("best_params.pkl", "rb") as fh:
        best = pickle.load(fh)
except FileNotFoundError:
    print(f"best_params.pkl not found; keeping steps_local={steps_local} and the current alpha")
else:
    steps_local = best["steps"]
    alpha       = best["alpha"]
    print(f"Using Alg1 best params: steps_local={steps_local}, alpha={alpha}")

def word_proj(ids_orig, pos_sel, delta_vec, eps_bound, tgt_q, mask):
    """Project perturbed embeddings at ``pos_sel`` back onto vocabulary tokens.

    For each position ``p`` in ``pos_sel``:
      * candidates are the token ids whose embedding is within ``eps_bound``
        (L2) of the original token's embedding;
      * of those, the ``k_nn`` closest to ``E[orig] + delta_vec[p]`` are each
        substituted at ``p`` and scored by the cross-entropy of the
        last-position logits against ``tgt_q``;
      * the highest-loss candidate is written to the output.

    Args:
        ids_orig: 1-D tensor of token ids; it is not modified.
        pos_sel: iterable of positions that may be replaced.
        delta_vec: per-position embedding perturbation, indexed like ``ids_orig``.
        eps_bound: radius of the candidate ball around the original embedding.
        tgt_q: token id of the label the attack tries to move away from.
        mask: attention mask for ``ids_orig``.

    Returns:
        A new id tensor with the chosen substitutions applied.
    """
    ids_new = ids_orig.clone()
    # Loop invariants, built once.
    target = torch.tensor([tgt_q], device=device)
    attn = mask.unsqueeze(0)
    vocab_ids = torch.arange(V, device=device)

    # Candidate scoring is evaluation only: without no_grad every forward pass
    # would record an autograd graph that is never used.
    with torch.no_grad():
        for p in pos_sel:
            orig = ids_orig[p].item()
            tgtv = E[orig] + delta_vec[p]
            ball = (E - E[orig]).norm(dim=1) <= eps_bound
            cands = vocab_ids[ball]
            if cands.numel() == 0:
                continue
            dists = (E[cands] - tgtv).norm(dim=1)
            topk = cands[dists.topk(min(k_nn, cands.numel()), largest=False).indices]

            best, best_loss = orig, -1e9
            for cid in topk:
                # NOTE: each candidate is scored against ids_orig, i.e. without
                # the substitutions already chosen for earlier positions in
                # this call.
                tmp = ids_orig.clone()
                tmp[p] = cid
                logits = model(tmp.unsqueeze(0), attention_mask=attn).logits[0, -1]
                loss = F.cross_entropy(logits.unsqueeze(0), target).item()
                if loss > best_loss:
                    best_loss, best = loss, cid.item()
            ids_new[p] = best
    return ids_new

def run_alg2_on(demo_df, query, spans, delta_global, label_spans, trial):
    """Run the sequential, budget-profiled Alg2 attack on one trial and log its metrics.

    The prompt is the instruction header, every demonstration in ``demo_df``
    and the query. ICE spans are attacked one after another. While span ``i``
    is attacked, the attention mask hides the sentence and label tokens of all
    later spans; the perturbation on span ``i`` is updated by gradient ascent
    on the query's cross-entropy and rescaled onto an L2 ball of radius
    ``custom_gamma[i] * eps``; the result is then projected back to tokens
    with ``word_proj``.

    Side effects: prints per-ICE and final metrics and writes
    ``perturb_info_budget/alg2_debug_trial_{trial}.pkl``.

    Args:
        demo_df: table with ``sentence`` and ``label`` columns, one row per ICE.
        query: mapping with ``sentence`` and ``label``; label 1 selects
            ``pos_id`` as the target token, anything else ``neg_id``.
        spans: one ``(start, end, positions)`` triple per ICE; ``positions``
            are the token indices that may be edited.
        delta_global: unused; kept so existing callers keep working.
        label_spans: for each ICE, the token indices of its label, masked
            together with the sentence while earlier ICEs are attacked.
        trial: trial number, used in log output and in the dump file name.

    Returns:
        ``([clean_acc], [adv_acc])`` for the full prompt, each 0.0 or 1.0.

    Raises:
        ValueError: if ``custom_gamma`` has fewer entries than there are spans.
    """
    # 1) Construct the complete prompt, including all ICEs + query
    head = ("Analyze the sentiment of the last review and respond with "
            "either positive or negative. Here are several examples.")
    tag  = "\nSentiment:"
    demos = list(zip(demo_df["sentence"], demo_df["label"]))
    query_tail = f"\nReview: {query['sentence'].strip()}\nSentiment:"

    def build_prompt(num_demos):
        # Header + the first num_demos demonstrations + the query.
        return head + "".join(
            f"\nReview: {s.strip()}{tag}{label_word[l]}" for s, l in demos[:num_demos]
        ) + query_tail

    enc = tok(build_prompt(len(demos)), return_tensors="pt", add_special_tokens=False).to(device)
    ids, full_mask = enc.input_ids[0], enc.attention_mask[0]

    tgt_q  = pos_id if query["label"]==1 else neg_id
    target = torch.tensor([tgt_q], device=device)

    # 2) Pre-calculate clean prediction on the full prompt
    pred_b, _ = classify_token(ids)
    base_ok = (pred_b == tgt_q)

    # 3) Sequential PGD; ICE i gets the share custom_gamma[i] of eps
    gamma = custom_gamma.tolist()
    eps_i = [g*eps for g in gamma]
    if len(eps_i) < len(spans):
        # custom_gamma is resampled to a fixed length elsewhere; fail before
        # any expensive model call instead of with an IndexError mid-run.
        raise ValueError(
            f"budget profile has {len(eps_i)} entries but there are {len(spans)} ICE spans"
        )

    # Detached so gradients are only tracked for the perturbation itself.
    base_emb  = E[ids].detach()
    cum_delta = torch.zeros_like(base_emb, device=device)

    # Token ids after projecting every span attacked so far.
    ids_proj = ids.clone()
    metrics_per_ice = []

    for i,(_,_,posL) in enumerate(spans):
        # mask subsequent ICE_{i+1..n}
        mask_i = full_mask.clone()
        for j in range(i+1,len(spans)):
            s_j,e_j,_ = spans[j]
            mask_i[s_j:e_j] = 0
            for lp in label_spans[j]:
                mask_i[lp] = 0

        # 3.1) Local PGD for ICE_i
        delta_loc = cum_delta.clone().detach().requires_grad_(True)
        attn_i = mask_i.unsqueeze(0)
        for _ in range(steps_local):
            logits = model(inputs_embeds=(base_emb + delta_loc).unsqueeze(0),
                           attention_mask=attn_i).logits[0,-1]
            loss = F.cross_entropy(logits.unsqueeze(0), target)
            # Differentiate w.r.t. the perturbation only; loss.backward() would
            # also accumulate .grad on every trainable model parameter.
            (grad,) = torch.autograd.grad(loss, delta_loc)
            with torch.no_grad():
                for p in posL:
                    delta_loc[p] += alpha * grad[p]
                # Rescale the span's perturbation back onto its L2 budget.
                norm = delta_loc[posL].norm()
                if norm > eps_i[i]:
                    delta_loc[posL] *= eps_i[i]/norm
        with torch.no_grad():
            for p in posL:
                cum_delta[p] = delta_loc[p]

        # Clean accuracy on the prompt made of ICE_1..ICE_i + query.
        # NOTE(review): adv_acc_i below is measured on the full prompt (all
        # ICEs present, the first i perturbed), not on this prefix prompt --
        # confirm that comparing the two is intended.
        enc_clean = tok(build_prompt(i+1), return_tensors="pt", add_special_tokens=False).to(device)
        pred_c, _ = classify_token(enc_clean.input_ids[0])
        clean_acc_i = 1.0 if pred_c==tgt_q else 0.0

        # Cumulative word_proj for ICE_1..ICE_i. Only span i's rows of
        # cum_delta changed in this round, so projecting span i on top of the
        # running result gives the same tokens as re-projecting every earlier
        # span (assumes spans do not share positions and the forward pass is
        # deterministic).
        ids_proj = word_proj(ids_proj, posL, cum_delta, eps_i[i], tgt_q, full_mask)
        pred_i, _ = classify_token(ids_proj)

        # ASR counts a success only when the clean prediction was correct.
        adv_acc_i = 1.0 if pred_i == tgt_q else 0.0
        drop_i    = max(clean_acc_i - adv_acc_i, 0.0)
        ASR_i     = 1.0 if (clean_acc_i==1.0 and adv_acc_i==0.0) else 0.0

        print(f"=== ICE {i+1} Metrics === clean_acc: {clean_acc_i:.3f}, adv_acc: {adv_acc_i:.3f}, drop: {drop_i:.3f}, ASR: {ASR_i:.3f}")

        metrics_per_ice.append({
            "clean_acc": clean_acc_i,
            "adv_acc":   adv_acc_i,
            "drop":      drop_i,
            "ASR":       ASR_i
        })

    # 4) After the last span the running projection is the fully attacked prompt.
    ids_adv = ids_proj

    # 5) Last-ICE metrics (same definitions)
    pred_a, _ = classify_token(ids_adv)
    clean_acc = 1.0 if base_ok else 0.0
    adv_acc   = 1.0 if pred_a == tgt_q else 0.0
    drop      = max(clean_acc - adv_acc, 0.0)
    ASR       = 1.0 if (clean_acc==1.0 and adv_acc==0.0) else 0.0

    print(f"\n=== Alg2 Trial {trial} Metrics (last ICE) ===")
    print(f" clean_acc: {clean_acc:.3f}")
    print(f" adv_acc  : {adv_acc:.3f}")
    print(f" drop     : {drop:.3f}")
    print(f" ASR      : {ASR:.3f}")

    # 6) Save per-ICE metrics + texts, plus the full prompt before and after
    #    all perturbations.
    debug_info = {
        'trial':      trial,
        'full_pre':   tok.decode(ids,     skip_special_tokens=True),
        'full_post':  tok.decode(ids_adv, skip_special_tokens=True),
        'ice':        []
    }

    for i, (s_j, e_j, posL) in enumerate(spans):
        # Decode the contiguous range covering the edited positions; an ICE
        # with no editable positions yields an empty range.
        lo, hi = (posL[0], posL[-1]+1) if len(posL) > 0 else (0, 0)
        debug_info["ice"].append({
            "ice_index":      i + 1,
            "pos":            posL,
            "pre_text":       tok.decode(ids[lo:hi],       skip_special_tokens=True),
            "post_text":      tok.decode(ids_adv[lo:hi],   skip_special_tokens=True),
            "full_sent_pre":  tok.decode(ids[s_j:e_j],     skip_special_tokens=True),
            "full_sent_post": tok.decode(ids_adv[s_j:e_j], skip_special_tokens=True),
            **metrics_per_ice[i]
        })

    os.makedirs("perturb_info_budget", exist_ok=True)
    with open(f"perturb_info_budget/alg2_debug_trial_{trial}.pkl","wb") as f:
        pickle.dump(debug_info,f)

    # On the first trial, print the entire saved debug_info for inspection.
    if trial == 1:
        print("=== Full debug_info for trial 1 ===")
        print(debug_info)

    return [clean_acc], [adv_acc]

# # ─────────────────────────────────────────────────────────────────────────────
# # Batch load, run Alg2, and compute averages
# # ─────────────────────────────────────────────────────────────────────────────

# ─────────────────────────────────────────────────────────────────────────────
# Batch load from Alg1 saved PKLs, run Alg2 and collect metrics
# ─────────────────────────────────────────────────────────────────────────────
base2_sum = 0.0
adv2_sum  = 0.0
all_metrics = []

orig_paths = sorted(glob.glob("perturb_info/ice_deltas_trial_*.pkl"))
for t, orig_path in enumerate(orig_paths, start=1):

    data = pickle.load(open(orig_path, "rb"))
    demo_df      = pd.DataFrame.from_dict(data["demo_df"])
    query        = data["query"]
    spans        = data["spans"]
    delta_global = torch.tensor(data["delta"], device=device)
    label_spans  = data["label_spans"]


    b_vec, a_vec = run_alg2_on(demo_df, query, spans, delta_global, label_spans, t)
    base2_sum += b_vec[0]
    adv2_sum  += a_vec[0]


    flat_path = f"perturb_info_budget/alg2_debug_trial_{t}.pkl"
    flat_data = pickle.load(open(flat_path, "rb"))
    all_metrics.append(flat_data["ice"])

num_trials = len(all_metrics)

# overall (last ICE) averages
print("\n=== Overall Alg2 (last ICE) ===")
print(f"Avg clean_acc: {base2_sum/num_trials:.3f}")
print(f"Avg adv_acc  : {adv2_sum/num_trials:.3f}")
print(f"Avg drop     : {(base2_sum-adv2_sum)/num_trials:.3f}")
print(f"Avg ASR      : {1-(adv2_sum/base2_sum) if base2_sum>0 else 0.0:.3f}")

# per-ICE averages
print("\n=== Per-ICE Average Metrics Across All Trials ===")
num_ice = len(all_metrics[0])
sums = [ {'clean_acc':0,'adv_acc':0,'drop':0,'ASR':0} for _ in range(num_ice) ]
for trial_metrics in all_metrics:
    for i, m in enumerate(trial_metrics):
        sums[i]['clean_acc'] += m['clean_acc']
        sums[i]['adv_acc']   += m['adv_acc']
        sums[i]['drop']      += m['drop']
        sums[i]['ASR']       += m['ASR']
for i, s in enumerate(sums, start=1):
    print(f"ICE {i}: clean_acc={s['clean_acc']/num_trials:.3f}, "
          f"adv_acc={s['adv_acc']/num_trials:.3f}, "
          f"drop={s['drop']/num_trials:.3f}, "
          f"ASR={s['ASR']/num_trials:.3f}")

Using Alg1 best params: steps_local=80, alpha=3
=== ICE 1 Metrics === clean_acc: 1.000, adv_acc: 1.000, drop: 0.000, ASR: 0.000
=== ICE 2 Metrics === clean_acc: 1.000, adv_acc: 1.000, drop: 0.000, ASR: 0.000
=== ICE 3 Metrics === clean_acc: 1.000, adv_acc: 1.000, drop: 0.000, ASR: 0.000
=== ICE 4 Metrics === clean_acc: 1.000, adv_acc: 1.000, drop: 0.000, ASR: 0.000

=== Alg2 Trial 1 Metrics (last ICE) ===
 clean_acc: 1.000
 adv_acc  : 1.000
 drop     : 0.000
 ASR      : 0.000
=== Full debug_info for trial 1 ===
{'trial': 1, 'full_pre': "Analyze the sentiment of the last review and respond with either positive or negative. Here are several examples.\nReview: the hallmarks\nSentiment:positive\nReview: about something , one that attempts and often achieves a level of connection and concern\nSentiment:positive\nReview: a movie to forget\nSentiment:negative\nReview: better thriller\nSentiment:positive\nReview: among the year 's most intriguing explorations of alientation .\nSentiment:", 'fu

In [None]:
#progress

In [None]:
import os, glob, pickle, json, numpy as np

# ------------------------------------------------------------------
# configuration  – edit here if your paths or constants differ
# ------------------------------------------------------------------
# Input folders scanned by the aggregation below.
ALG1_DIR        = "perturb_info"          # ice_deltas_trial_*.pkl
ALG2_FLAT_DIR   = "perturb_info_flat"     # alg2_debug_trial_*.pkl
ALG2_BUD_DIR    = "perturb_info_budget"   # new budgeted trials
# Experiment descriptors: they are only copied into the JSON summary and its
# file name, never read back from the run itself, so keep them in sync by hand
# with the notebook settings (eps, model id, num_shots, n_edit).
DATASET_NAME    = "sst2"                  # or read from elsewhere
EPSILON         = 100
MODEL_NAME      = "OPT-30b"
N_CONTEXT       = 4                       # num_shots
N_EDIT          = 3
# ------------------------------------------------------------------
def mean_or_nan(vals):
    """Return the arithmetic mean of ``vals`` as a float, or NaN if it is empty."""
    if not vals:
        return float("nan")
    average = np.mean(vals)
    return float(average)

def aggregate_alg2(folder):
    """Average the Alg2 metrics stored in ``folder``'s debug dumps.

    Every ``alg2_debug_trial_*.pkl`` in ``folder`` is expected to hold a dict
    whose ``"ice"`` entry is a list of per-ICE records with the keys
    ``ice_index``, ``clean_acc``, ``adv_acc``, ``drop`` and ``ASR``.

    Args:
        folder: directory containing the debug dumps.

    Returns:
        ``(overall_mean, per_ice_mean)``. ``overall_mean`` averages the last
        ICE record of each trial; ``per_ice_mean`` maps ``str(ice_index)`` to
        the averages for that ICE. Both use the output keys ``clean_ACC``,
        ``ADV_ACC``, ``ACC_drop`` and ``ASR``; a metric with no samples is NaN.
    """
    # (output key, key inside the pickled record)
    key_map = (("clean_ACC", "clean_acc"),
               ("ADV_ACC",   "adv_acc"),
               ("ACC_drop",  "drop"),
               ("ASR",       "ASR"))

    overall = {out_key: [] for out_key, _ in key_map}
    per_ice = {}      # { "1": {...}, ... }

    # Sorted so the result does not depend on directory listing order.
    for p in sorted(glob.glob(os.path.join(folder, "alg2_debug_trial_*.pkl"))):
        with open(p, "rb") as fh:
            dbg = pickle.load(fh)["ice"]              # list per ICE
        if not dbg:
            # A trial with no ICE records has no "last ICE" to report.
            continue

        last = dbg[-1]                                # use last ICE for overall
        for out_key, src_key in key_map:
            overall[out_key].append(last[src_key])

        for entry in dbg:
            bucket = per_ice.setdefault(str(entry["ice_index"]),
                                        {out_key: [] for out_key, _ in key_map})
            for out_key, src_key in key_map:
                bucket[out_key].append(entry[src_key])

    overall_mean = {k: mean_or_nan(v) for k, v in overall.items()}
    per_ice_mean = {i: {k: mean_or_nan(v) for k, v in d.items()}
                    for i, d in per_ice.items()}
    return overall_mean, per_ice_mean

# ------------------------------------------------------------------
# Alg-1 aggregation
# ------------------------------------------------------------------
# Collect whatever Alg1 metrics were stored with each trial dump; a dump
# without a "metrics" entry contributes nothing.
alg1_raw = {"clean_ACC": [], "ADV_ACC": [], "ACC_drop": [], "ASR": []}
for p in sorted(glob.glob(os.path.join(ALG1_DIR, "ice_deltas_trial_*.pkl"))):
    with open(p, "rb") as fh:
        meta = pickle.load(fh).get("metrics", {})
    for k_csv, k_pkl in [("clean_ACC", "clean_acc"),
                         ("ADV_ACC",   "adv_acc"),
                         ("ACC_drop",  "drop"),
                         ("ASR",       "ASR")]:
        if k_pkl in meta:
            alg1_raw[k_csv].append(meta[k_pkl])
alg1_summary = {k: mean_or_nan(v) for k, v in alg1_raw.items()}

# ------------------------------------------------------------------
# Alg-2-flat & Alg-2-budget aggregation
# ------------------------------------------------------------------
flat_overall, flat_per_ice   = aggregate_alg2(ALG2_FLAT_DIR)
bud_overall,  bud_per_ice    = aggregate_alg2(ALG2_BUD_DIR)

# ------------------------------------------------------------------
# build final JSON object
# ------------------------------------------------------------------
summary = {
    "epsilon": EPSILON,
    "model":   MODEL_NAME,
    "n":       N_CONTEXT,
    "alg1":    alg1_summary,
    "alg2-flat": {
        "overall": flat_overall,
        "per_ICE": flat_per_ice
    },
    "alg2-budget": {
        "overall": bud_overall,
        "per_ICE": bud_per_ice
    }
}

# automatic file name
json_name = f"{DATASET_NAME}_{MODEL_NAME.replace(' ','-')}" \
            f"_eps{EPSILON}_n{N_CONTEXT}_edit{N_EDIT}.json"

# NOTE: metrics with no samples are NaN, which json.dump writes as the bare
# token NaN; Python reads it back, but strict JSON parsers will reject it.
with open(json_name, "w") as f:
    json.dump(summary, f, indent=2)
print(f"→ written {json_name}")