In [1]:
from html import escape

import numpy as np
from datasets import load_dataset
from matplotlib.colors import LinearSegmentedColormap, rgb2hex
from IPython.display import display, HTML

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Input locations; edit here to point the notebook at a different run.
GRADIENTS_PATH = 'results/gradients_70B_quant.npz'
TOKENS_PATH = 'results/generated_tokens_70B_quant.npz'

# Per-example arrays, read below via keys of the form 'example_<i>'.
grads = np.load(GRADIENTS_PATH)
tokens = np.load(TOKENS_PATH)

# Source prompts; the 'harmful' split's 'Goal' column is displayed per example.
harmful_data = load_dataset('jailbreakBench/JBB-Behaviors', 'behaviors')

In [3]:
# Number of examples stored in each .npz file (keys 'example_0' .. 'example_{N-1}').
# NOTE(review): kept at the previously hard-coded 100; presumably the size of the
# JBB 'harmful' split — confirm against how the .npz files were produced.
N_EXAMPLES = 100

# Global extrema over every example, so colours are comparable across examples.
all_gradients = np.concatenate([grads[f'example_{i}'] for i in range(N_EXAMPLES)])
global_min = np.min(all_gradients)
global_max = np.max(all_gradients)

# Diverging colormap: red for negative, white for zero, green for positive.
cmap = LinearSegmentedColormap.from_list('rg', ["r", "w", "g"], N=256)


def normalize_gradients(gradient_values, global_min, global_max):
    """Map gradients into [-1, 1], scaling each sign independently.

    Positive values are divided by ``global_max`` and negative values by
    ``-global_min``; zeros (and NaNs) pass through unchanged, so zero always
    lands on the white midpoint of ``cmap``.

    Args:
        gradient_values: 1-D array-like of gradient values for one example.
        global_min: minimum gradient over all examples (expected <= every value).
        global_max: maximum gradient over all examples (expected >= every value).

    Returns:
        np.ndarray of the same length; values lie in [-1, 1] when the
        expectations on ``global_min``/``global_max`` hold.
    """
    values = np.asarray(gradient_values)
    # Per-element divisor: a positive value implies global_max > 0 and a negative
    # value implies global_min < 0, so the divisor is never zero; zeros use 1.
    scale = np.where(values > 0, global_max, np.where(values < 0, -global_min, 1))
    return values / scale


def render_token(token, raw_grad, norm_grad):
    """Return an HTML tile for one token, coloured by its normalized gradient.

    The hover tooltip shows the raw gradient and the normalized value.
    """
    hex_color = rgb2hex(cmap((norm_grad + 1) / 2)[:3])
    # Generated tokens can contain '<', '>' or '&' (code, HTML, ...); escape them
    # so they are shown literally instead of being parsed as markup.
    return (
        f'<div style="background-color: {hex_color}; color: black; padding: 5px; margin: 0; '
        'border: none; display: inline-block; text-align: center; box-sizing: border-box; '
        f'cursor: pointer;" title="Unnormalized gradient: {raw_grad}, '
        f'Normalized gradient: {norm_grad}">{escape(str(token))}</div>'
    )


# Hoisted out of the loop: depending on the `datasets` version, indexing a
# column may materialize the whole column on every access.
goals = harmful_data['harmful']['Goal']

for example_index in range(N_EXAMPLES):
    gradient_values = grads[f'example_{example_index}']
    normalized_gradients = normalize_gradients(gradient_values, global_min, global_max)
    token_values = tokens[f'example_{example_index}']
    prompt = goals[example_index]

    # NOTE(review): zip() silently truncates if token and gradient counts differ.
    tiles = ''.join(
        render_token(token, raw_grad, norm_grad)
        for token, raw_grad, norm_grad in zip(token_values, gradient_values, normalized_gradients)
    )
    display(HTML(
        f'<div style="margin-bottom: 20px;"><strong>Example {example_index + 1}:</strong> '
        f'{escape(str(prompt))}</div>'
        f'<div style="display: flex; flex-wrap: wrap;">{tiles}</div>'
    ))