In [1]:
import os
from collections import defaultdict

import matplotlib.pyplot as plt
import matrix_fact
import numpy as np
import pandas as pd
import seaborn as sns
from joblib import Parallel, delayed

from datasets.dataset import AdvBehaviorsConfig, Dataset

In [2]:
d = Dataset.from_name("adv_behaviors")(
    AdvBehaviorsConfig(
        "adv_behaviors",
        messages_path="/nfs/staff-ssd/beyer/llm-quick-check/data/behavior_datasets/harmbench_behaviors_text_all.csv",
        targets_path="/nfs/staff-ssd/beyer/llm-quick-check/data/optimizer_targets/harmbench_targets_text.json",
        seed=0,
        categories=[
            "chemical_biological",
            "illegal",
            "misinformation_disinformation",
            "harmful",
            "harassment_bullying",
            "cybercrime_intrusion",
        ],
    )
)
n_prompts = 20
judged_prompts = [p[0]['content'] for p in d][:n_prompts]

In [3]:
def process_file(path):
    try:
        data = pd.read_json(path)
        return [data.iloc[i] for i in range(len(data))]
    except ValueError as e:
        raise ValueError(f'Error in {path} with {e}')

data_path = '../outputs'

paths = []
for root, dirs, files in os.walk(data_path):
    for file in files:
        if file.endswith('run.json'):
            paths.append(os.path.join(root, file))

paths = sorted(paths, reverse=True)

runs = Parallel(n_jobs=16)(delayed(process_file)(path) for path in paths)
attack_runs = [r for run in runs for r in run]


print(f'Found {len(attack_runs)} runs.')
print(f'In total {sum(len(run["attacks"]) for run in attack_runs)} prompt attacks.')

Found 17619 runs.
In total 21876 prompt attacks.


# Remove Duplicates (?)

keep only most recent

In [4]:
existing = set()
n_duplicates = 0

attack_runs_new = []
for run in attack_runs:
    config = run['config']
    model = config['model']
    dataset = config['dataset']
    if 'successes_cais' not in run or isinstance(run['successes_cais'], float):
        continue

    for attack, loss, prompt, completion, success in zip(run['attacks'], run['losses'], run['prompts'], run['completions'], run['successes_cais']):
        key = (prompt['content'], config['attack'], model)
        if key not in existing:
            existing.add(key)
            attack_runs_new.append(run)
        else:
            n_duplicates += 1
print(f'Found {len(attack_runs_new)} unique runs out of {len(attack_runs)}.')
attack_runs = attack_runs_new


Found 20246 unique runs out of 17619.


In [None]:
prompts = []
for run in attack_runs:
    prompts.extend([p['content'] for p in run["prompts"]])
prompts = set(prompts)

# for each prompt, we will store all attacks that applied to it
indexed_attack_runs: dict[str, list[dict]] = defaultdict(list)
models = set()
attacks = set()
n_unjudged = 0
for run in attack_runs:
    config = run['config']
    model = config['model']
    if 'vicuna' in model:
        continue
    models.add(model)
    dataset = config['dataset']
    attacks.add(config['attack'])
    if 'successes_cais' not in run or isinstance(run['successes_cais'], float):
        print(f"WARNING: Skipping unjudged run: {run['config']['model'], run['config']['attack'], run['config']['dataset'], run['config']['dataset_params']['idx']}")
        n_unjudged += len(run['attacks'])
        continue
    for attack, loss, prompt, completion, success in zip(run['attacks'], run['losses'], run['prompts'], run['completions'], run['successes_cais']):
        indexed_attack_runs[prompt['content']].append({
            'config': config,
            'model': model,
            'attack': attack,
            'dataset': dataset,
            'loss': loss,
            'completion': completion,
            'success': [s == 'Yes' for s in success]
        })

models = sorted(list(models))
attacks = sorted(list(attacks))
print(f'Found {len(indexed_attack_runs)} prompts.')
print(f'Total judged attacks: {sum([len(v) for v in indexed_attack_runs.values()])}, unjudged attacks: {n_unjudged}.')

1 human_jailbreaks
1 human_jailbreaks
1 human_jailbreaks
1 human_jailbreaks
1 human_jailbreaks
1 human_jailbreaks
1 human_jailbreaks
1 human_jailbreaks
1 human_jailbreaks
1 human_jailbreaks
1 human_jailbreaks
1 human_jailbreaks
1 human_jailbreaks
1 human_jailbreaks
1 human_jailbreaks
1 human_jailbreaks
1 human_jailbreaks
1 human_jailbreaks
1 human_jailbreaks
1 human_jailbreaks
1 autodan
1 autodan
1 autodan
1 autodan
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 ample_gcg
1 am

## Visualize

In [6]:
from functools import lru_cache

@lru_cache
def get_color_for_run(model_type: str, attack_type: str):
    # Define a base color for each model (7 options)

    model_colors = {
        k: v
        for k, v in zip(
            set([run["config"]["model"] for run in attack_runs]),
            (
                (255, 0, 0),
                (255, 135, 0),
                (255, 211, 0),
                (222, 255, 10),
                (161, 255, 10),
                (10, 255, 153),
                (10, 239, 255),
                (20, 125, 245),
                (88, 10, 255),
                (190, 10, 255),
                (255, 10, 235),
                (255, 10, 60),
                (255, 10, 10),
                (10, 10, 255),
            ),
        )
    }

    # Define a variation for each attack type
    attack_variation = {
        "autodan": 0.4,
        "human_jailbreaks": 0.5,
        "ample_gcg": 0.6,
        "pgd": 0.7,
        "pair": 0.8,
        "pgd_one_hot": 0.85,
        "direct": 0.9,
        "gcg": 1.0,
    }

    # Get the base color for the model
    base_color = np.array(model_colors[model_type])

    # Adjust the color based on the attack type (by scaling the brightness)
    adjusted_color = np.clip(base_color * attack_variation[attack_type], 0, 255) / 255.0

    return adjusted_color


def ragged_mean(sequences):
    # Find the maximum length of the sequences
    max_len = max(len(seq) for seq in sequences)

    # Pad sequences with NaN to make them the same length
    padded_sequences = np.full((len(sequences), max_len), np.nan)

    for i, seq in enumerate(sequences):
        padded_sequences[i, : len(seq)] = seq

    return np.nanmean(padded_sequences, axis=0)


def filter_runs(runs, model_name=None, attack_type=None):
    print(len(runs))
    filtered_runs = []
    for run in runs:
        if model_name is None:
            pass
        elif (
            isinstance(model_name, str) and model_name.lower() in run["config"]["model"].lower()
        ):
            pass
        elif isinstance(model_name, list) and any(
            m.lower() in run["config"]["model"].lower() for m in model_name
        ):
            pass
        else:
            continue
        if attack_type is None:
            pass
        elif (
            isinstance(attack_type, str)
            and attack_type.lower() == run["config"]["attack"].lower()
        ):
            pass
        elif isinstance(attack_type, list) and any(
            a.lower() == run["config"]["attack"].lower() for a in attack_type
        ):
            pass
        else:
            continue
        filtered_runs.append(run)
    return filtered_runs

def filter_runs(runs, model_name=None, attack_type=None, exact_match=False):
    def matches(value, target):
        comp_func = str.__eq__ if exact_match else str.__contains__
        if target is None:
            return True
        if isinstance(target, str):
            return comp_func(value.lower(), target.lower())
        return any(comp_func(value.lower(), t.lower()) for t in target)

    filtered_runs = [
        run for run in runs
        if matches(run["config"]["model"], model_name) and matches(run["config"]["attack"], attack_type)
    ]
    return filtered_runs

In [None]:
def rank_by_worst_asr(prompts, model_name: None|str|list[str] = None):
    attack_type = ['gcg', 'ample_gcg', 'human_jailbreaks', 'direct', 'pair']
    if not isinstance(prompts, list):
        prompts = [prompts]
    n = 0
    successful_attacks = {}
    for prompt in prompts:
        prompt_runs = filter_runs(indexed_attack_runs[prompt], model_name, attack_type)
        for run in prompt_runs:
            key = (prompt, run['config']['model'], run['config']['attack'])
            s = any(success for success in run['success'])
            successful_attacks[

    y = []
    labels = []
    for key, values in to_plot.items():
        model, attack = key
        n += len(values)
        if attack in ('ample_gcg', 'human_jailbreaks', 'pair'):
            y.append(ragged_mean(values)[0])
        else:
            y.append(ragged_mean(values)[0]),
        labels.append(f"{model} | {attack}")
    y, labels = zip(*sorted(zip(y, labels), reverse=False))
    return y, labels

rank_by_worst_asr(judged_prompts, model_name=None)

20
200
4
14
6
20
6
20
20
20
20
20
20
112
112
112
112
200
112
200
200
200
200
200
112
112
200
112
200
112
112
112
112
113
3
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1

ValueError: not enough values to unpack (expected 2, got 0)