In [None]:
import json
from collections import Counter
from itertools import chain

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from transformers import AutoTokenizer

In [None]:
# --- Notebook configuration -------------------------------------------------

# Model identifier string (presumably the checkpoint handed to AutoTokenizer;
# it is not referenced in the cells visible here).
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"

# Key of the token that closes the prompt.
# NOTE(review): "ÄŠ" looks like a UTF-8 "Ċ" that was decoded as cp1252 --
# confirm this literal against the keys actually present in the results file.
PROMPT_END_TOKEN = ")ÄŠ_317"

# Position of the prompt-end token; positions start from 1. Used as a slice
# start below, so the first 317 token entries of every run are skipped.
PROMPT_END_POSITION = 317

# Width (in token IDs) of each histogram bin.
BINS_WIDTH = 5_000

In [None]:
# Load the Shapley results into `data`.
# The encoding is given explicitly: JSON text is UTF-8, while open() without
# `encoding` falls back to the platform locale (e.g. cp1252 on Windows), which
# silently garbles non-ASCII token text such as "Ċ".
# NOTE(review): any key literal that was copied from a mis-decoded read must be
# re-checked once the file is read as UTF-8.
with open('./results/all_shapley.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

In [None]:
# Gather the token IDs of the adversarial suffix from every run.
# For each run, the 'tokens' mapping is read in insertion order and its first
# PROMPT_END_POSITION entries are discarded; only the remainder is kept.
suffix_token_ids = [
    entry['token_id']
    for run_record in data.values()
    for entry in list(run_record['tokens'].values())[PROMPT_END_POSITION:]
]

# Flat 1-D tensor with one element per collected token.
adv_tokens = torch.tensor(suffix_token_ids)
print(f"Total number of tokens: {adv_tokens.shape[0]} ")

In [None]:
# Histogram of the adversarial-suffix token IDs in BINS_WIDTH-wide bins.
# `adv_tokens` is the 1-D torch tensor built in an earlier cell.

# Tensor.max() reduces inside torch; the builtin max() would instead iterate
# the tensor one element at a time in Python.
max_token_id = int(adv_tokens.max())
# The stop is padded by one bin width so the largest ID falls inside a bin.
bin_edges = range(0, max_token_id + BINS_WIDTH, BINS_WIDTH)

fig, ax = plt.subplots(figsize=(18, 6))
ax.hist(
    adv_tokens.numpy(),  # hand matplotlib a NumPy array rather than a tensor
    bins=bin_edges,
    edgecolor='black',
    density=True,  # normalise so that the bar areas integrate to 1
)
ax.set_xlabel('Token ID')
ax.set_ylabel('Density')
ax.set_title('Distribution of Token IDs in the Adversarial Suffixes')
plt.show()



In [None]:
# Shapley value of every adversarial-suffix token, as a 1-D float tensor.
# The runs and their 'tokens' entries are traversed exactly as when the token
# IDs are collected (same order, same PROMPT_END_POSITION slice), so element k
# here belongs to the k-th collected token ID.
# Building the tensor straight from the traversal avoids sizing a buffer from
# `adv_tokens` and keeping a hand-maintained write index.
shapley_values = [
    token['shapley_value']
    for run in data.values()
    for token in list(run['tokens'].values())[PROMPT_END_POSITION:]
]
tokens_shap_relevances = torch.tensor(shapley_values, dtype=torch.float)

In [None]:
# Histogram of adversarial-suffix token IDs where each token contributes its
# Shapley value (instead of 1) to the bin it falls in.
# `adv_tokens` and `tokens_shap_relevances` are the aligned 1-D torch tensors
# built in earlier cells.

# Tensor.max() reduces inside torch; the builtin max() would instead iterate
# the tensor one element at a time in Python.
max_token_id = int(adv_tokens.max())
# The stop is padded by one bin width so the largest ID falls inside a bin.
bin_edges = range(0, max_token_id + BINS_WIDTH, BINS_WIDTH)

fig, ax = plt.subplots(figsize=(16, 6))
# NOTE(review): if Shapley values can be negative, bins may get negative
# heights and the density normalisation is harder to interpret -- confirm
# that this is the intended presentation.
ax.hist(
    adv_tokens.numpy(),  # hand matplotlib NumPy arrays rather than tensors
    bins=bin_edges,
    weights=tokens_shap_relevances.numpy(),
    edgecolor="black",
    density=True,
)
ax.set_xlabel('Token ID')
ax.set_ylabel('Density')
ax.set_title('Distribution of Token IDs in the Adversarial Suffixes weighted by their Shapley Value')
plt.show()

In [None]:
# How many iterations each run needed -- normal execution.
normal_run_info = pd.read_csv("./suffix_results/normal_run.csv")
# Largest 'iteration' value recorded for each 'run'.
max_iterations = normal_run_info.groupby('run')['iteration'].max()

# Summary statistics are computed once and shared by the plot and the printout.
mean_iterations = max_iterations.mean()
median_iterations = max_iterations.median()
std_iterations = max_iterations.std()

fig, ax = plt.subplots(figsize=(18, 6))
# NOTE(review): only the values 1..70 get a bar; a run outside that range would
# be missing from the plot while still counting towards the statistics.
sns.countplot(x=max_iterations, order=range(1, 71), ax=ax)
ax.set_title('Frequency of Iterations Needed - Normal execution')
ax.set_xlabel('Number of Iterations')
ax.set_ylabel('Frequency')
# The bars sit at axis positions 0..69 and are labelled 1..70.
ax.set_xticks(range(0, 70))
ax.set_xticklabels(range(1, 71))
# Shift by -1 to convert an iteration count into its axis position.
ax.axvline(x=mean_iterations - 1, color='red', linestyle='--', label=f'Mean: {mean_iterations:.2f}')
ax.axvline(x=median_iterations - 1, color='blue', linestyle='--', label=f'Median: {median_iterations:.2f}')
ax.legend()
plt.show()

print(f"Mean: {mean_iterations:.2f}")
print(f"Median: {median_iterations:.2f}")
print(f"Std: {std_iterations:.2f}")


In [None]:
# How many iterations each run needed -- guilty tokens removed.
safe_run_info = pd.read_csv("./suffix_results/safe_run.csv")
# Largest 'iteration' value recorded for each 'run'.
max_iterations = safe_run_info.groupby('run')['iteration'].max()

# Summary statistics are computed once and shared by the plot and the printout.
mean_iterations = max_iterations.mean()
median_iterations = max_iterations.median()
std_iterations = max_iterations.std()

fig, ax = plt.subplots(figsize=(18, 6))
# NOTE(review): only the values 1..70 get a bar; a run outside that range would
# be missing from the plot while still counting towards the statistics.
sns.countplot(x=max_iterations, order=range(1, 71), ax=ax)
ax.set_title('Frequency of Iterations Needed - Guilty Tokens Removed')
ax.set_xlabel('Number of Iterations')
ax.set_ylabel('Frequency')
# The bars sit at axis positions 0..69 and are labelled 1..70.
ax.set_xticks(range(0, 70))
ax.set_xticklabels(range(1, 71))
# Shift by -1 to convert an iteration count into its axis position.
ax.axvline(x=mean_iterations - 1, color='red', linestyle='--', label=f'Mean: {mean_iterations:.2f}')
ax.axvline(x=median_iterations - 1, color='blue', linestyle='--', label=f'Median: {median_iterations:.2f}')
ax.legend()
plt.show()

print(f"Mean: {mean_iterations:.2f}")
print(f"Median: {median_iterations:.2f}")
print(f"Std: {std_iterations:.2f}")

In [None]:
# How many iterations each run needed -- only guilty tokens.
# The frame gets its own name so that loading the unsafe-run results does not
# overwrite `normal_run_info`, which holds the normal-run results.
unsafe_run_info = pd.read_csv("./suffix_results/unsafe_run.csv")
# Largest 'iteration' value recorded for each 'run'.
max_iterations = unsafe_run_info.groupby('run')['iteration'].max()

# Summary statistics are computed once and shared by the plot and the printout.
mean_iterations = max_iterations.mean()
median_iterations = max_iterations.median()
std_iterations = max_iterations.std()

fig, ax = plt.subplots(figsize=(18, 6))
# NOTE(review): only the values 1..70 get a bar; a run outside that range would
# be missing from the plot while still counting towards the statistics.
sns.countplot(x=max_iterations, order=range(1, 71), ax=ax)
ax.set_title('Frequency of Iterations Needed - Only Guilty Tokens')
ax.set_xlabel('Number of Iterations')
ax.set_ylabel('Frequency')
# The bars sit at axis positions 0..69 and are labelled 1..70.
ax.set_xticks(range(0, 70))
ax.set_xticklabels(range(1, 71))
# Shift by -1 to convert an iteration count into its axis position.
ax.axvline(x=mean_iterations - 1, color='red', linestyle='--', label=f'Mean: {mean_iterations:.2f}')
ax.axvline(x=median_iterations - 1, color='blue', linestyle='--', label=f'Median: {median_iterations:.2f}')
ax.legend()
plt.show()

print(f"Mean: {mean_iterations:.2f}")
print(f"Median: {median_iterations:.2f}")
print(f"Std: {std_iterations:.2f}")
