In [6]:
# load the visu dataset 
import os
from datasets import Dataset
import pickle
base_cache_path = "/mnt/ssd1/mary/Diffusion-Models-Embedding-Space-Defense/.cache/aimagelab___vi_su-text/default/0.0.0/9afabb85b5570fa883b7caa0561d8c8d71d84dcd"

print("Trying load_dataset with explicit arrow files...")
data_files = {
    "train": os.path.join(base_cache_path, "vi_su-text-train.arrow"),
    "test": os.path.join(base_cache_path, "vi_su-text-test.arrow"),
    "validation": os.path.join(base_cache_path, "vi_su-text-validation.arrow"),
}
train_ds = Dataset.from_file(data_files["train"])
test_ds = Dataset.from_file(data_files["test"])
validation_ds = Dataset.from_file(data_files["validation"])

Trying load_dataset with explicit arrow files...


In [7]:
print(train_ds.num_rows)

158700


In [8]:
# instance the hysac model
import torch
import HySAC
from HySAC.hysac.models import HySAC as HySACModel
from transformers import CLIPTokenizer

clip_backbone = 'openai/clip-vit-large-patch14'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
# Load model
print(f"Loading HySAC model on {device}...")
model = HySACModel.from_pretrained(repo_id="aimagelab/hysac", device=device).to(device)
text_encoder = model.textual

Loading HySAC model on cuda...


In [9]:
# Function to get last_hidden_state and attention_mask for a batch from a specific text_encoder
def get_model_outputs(
    texts_list, tokenizer_instance, text_encoder_instance, batch_size=32, device='cuda'
):
    all_hidden_states = []
    all_attention_masks = []
    num_prompts = len(texts_list)

    fixed_max_length = tokenizer_instance.model_max_length
    for i in range(0, num_prompts, batch_size):
        batch_texts = texts_list[i : i + batch_size]
        # Crucial Fix: Explicitly set max_length for consistent padding across all batches
        inputs = tokenizer_instance(
            batch_texts,
            return_tensors="pt",
            padding="max_length",  # Pad to max_length
            truncation=True,
            max_length=fixed_max_length,  # Ensure all batches are padded to this length
        )
        input_enc = inputs["input_ids"].to(device)
        attention_mask = inputs.get("attention_mask", None).to(device)
        text_encoder_instance.to(device)  # Ensure the model is on the correct device

        with torch.no_grad():
            outputs = text_encoder_instance(input_enc, attention_mask=attention_mask)

        all_hidden_states.append(outputs.last_hidden_state.cpu())
        all_attention_masks.append(
            attention_mask.cpu()
        )  # This should now also be consistent in size

    concatenated_hidden_states = torch.cat(all_hidden_states, dim=0)
    concatenated_attention_masks = torch.cat(all_attention_masks, dim=0)

    return concatenated_hidden_states, concatenated_attention_masks


In [10]:
# get the hidden state for the texts of the validation set
val_nsfw = validation_ds['nsfw']
text_encoder = model.textual



In [11]:
hysac_outputs_nsfw, _ = get_model_outputs(
    val_nsfw, tokenizer, text_encoder, batch_size=16, device=device
)
print(hysac_outputs_nsfw.shape)


torch.Size([5000, 77, 768])


In [13]:
from sklearn.decomposition import PCA
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import pandas as pd

# Assuming hysac_outputs is your tensor with shape [n, 77, 768]
# where n is the number of examples, 77 is the number of tokens, and 768 is the dimension size
avg_norms = []
num_dimensions = hysac_outputs_nsfw.shape[1]  # 77
for dim in range(num_dimensions):
    token_vectors = hysac_outputs_nsfw[:, dim, :].numpy()  

    # gett the average norm of each token vector
    norms = np.linalg.norm(token_vectors, axis=1)
    avg_norm = np.mean(norms)
    print(f"Token {dim}: Average Norm = {avg_norm}")
    avg_norms.append(avg_norm)

Token 0: Average Norm = 43.81230545043945
Token 1: Average Norm = 29.962177276611328
Token 2: Average Norm = 29.75455093383789
Token 3: Average Norm = 29.501474380493164
Token 4: Average Norm = 29.442506790161133
Token 5: Average Norm = 29.355737686157227
Token 6: Average Norm = 29.25745391845703
Token 7: Average Norm = 29.199317932128906
Token 8: Average Norm = 29.180171966552734
Token 9: Average Norm = 29.167524337768555
Token 10: Average Norm = 29.16300392150879
Token 11: Average Norm = 29.202537536621094
Token 12: Average Norm = 29.247262954711914
Token 13: Average Norm = 29.3497257232666
Token 14: Average Norm = 29.428049087524414
Token 15: Average Norm = 29.531469345092773
Token 16: Average Norm = 29.6516056060791
Token 17: Average Norm = 29.74106216430664
Token 18: Average Norm = 29.846744537353516
Token 19: Average Norm = 29.940710067749023
Token 20: Average Norm = 30.02996253967285
Token 21: Average Norm = 30.104867935180664
Token 22: Average Norm = 30.182708740234375
Token 23

In [14]:
# get the hidden state for the texts of the validation set
val_sfw = validation_ds['safe']
text_encoder = model.textual



In [15]:
hysac_outputs_sfw, _ = get_model_outputs(
    val_sfw, tokenizer, text_encoder, batch_size=16, device=device
)
print(hysac_outputs_sfw.shape)


torch.Size([5000, 77, 768])


In [17]:
from sklearn.decomposition import PCA
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import pandas as pd

# Assuming hysac_outputs is your tensor with shape [n, 77, 768]
# where n is the number of examples, 77 is the number of tokens, and 768 is the dimension size
sfw_avg_norms = []
num_dimensions = hysac_outputs_sfw.shape[1]  # 77
for dim in range(num_dimensions):
    token_vectors = hysac_outputs_sfw[:, dim, :].numpy()

    # gett the average norm of each token vector
    sfw_norms = np.linalg.norm(token_vectors, axis=1)
    sfw_avg_norm = np.mean(sfw_norms)
    print(f"Token {dim}: Average Norm = {sfw_avg_norm}")
    sfw_avg_norms.append(sfw_avg_norm)

Token 0: Average Norm = 43.81230545043945
Token 1: Average Norm = 29.961952209472656
Token 2: Average Norm = 29.782934188842773
Token 3: Average Norm = 29.487106323242188
Token 4: Average Norm = 29.4118595123291
Token 5: Average Norm = 29.347122192382812
Token 6: Average Norm = 29.24566650390625
Token 7: Average Norm = 29.159799575805664
Token 8: Average Norm = 29.150924682617188
Token 9: Average Norm = 29.10258674621582
Token 10: Average Norm = 28.855859756469727
Token 11: Average Norm = 28.527406692504883
Token 12: Average Norm = 28.200706481933594
Token 13: Average Norm = 27.951374053955078
Token 14: Average Norm = 27.79199981689453
Token 15: Average Norm = 27.696521759033203
Token 16: Average Norm = 27.63228416442871
Token 17: Average Norm = 27.590024948120117
Token 18: Average Norm = 27.574134826660156
Token 19: Average Norm = 27.564889907836914
Token 20: Average Norm = 27.556631088256836
Token 21: Average Norm = 27.553974151611328
Token 22: Average Norm = 27.55228042602539
Token 

In [18]:
# compare the average norms of each token between nsfw and sfw

for dim in range(num_dimensions):
    print(f"Token {dim}: NSFW Average Norm = {avg_norms[dim]}, SFW Average Norm = {sfw_avg_norms[dim]}")

Token 0: NSFW Average Norm = 43.81230545043945, SFW Average Norm = 43.81230545043945
Token 1: NSFW Average Norm = 29.962177276611328, SFW Average Norm = 29.961952209472656
Token 2: NSFW Average Norm = 29.75455093383789, SFW Average Norm = 29.782934188842773
Token 3: NSFW Average Norm = 29.501474380493164, SFW Average Norm = 29.487106323242188
Token 4: NSFW Average Norm = 29.442506790161133, SFW Average Norm = 29.4118595123291
Token 5: NSFW Average Norm = 29.355737686157227, SFW Average Norm = 29.347122192382812
Token 6: NSFW Average Norm = 29.25745391845703, SFW Average Norm = 29.24566650390625
Token 7: NSFW Average Norm = 29.199317932128906, SFW Average Norm = 29.159799575805664
Token 8: NSFW Average Norm = 29.180171966552734, SFW Average Norm = 29.150924682617188
Token 9: NSFW Average Norm = 29.167524337768555, SFW Average Norm = 29.10258674621582
Token 10: NSFW Average Norm = 29.16300392150879, SFW Average Norm = 28.855859756469727
Token 11: NSFW Average Norm = 29.202537536621094, S

In [25]:
attack_sentence = "text reads nudity and sex,  NSFW, naked, nude, porn, violence, blood, rape, harrasment, dick"
minus_sentence = "text, reads"
# compute the embeddings 
attack_emb, _ = get_model_outputs(
    [attack_sentence], tokenizer, text_encoder, batch_size=1, device=device
)
print(attack_emb.shape)
minus_emb, _ = get_model_outputs(
    [minus_sentence], tokenizer, text_encoder, batch_size=1, device=device
)
print(minus_emb.shape)

# get the first safe text from the validation set
safe_text = validation_ds['nsfw'][0]
print(safe_text)
safe_text_emb, _ = get_model_outputs(
    [safe_text], tokenizer, text_encoder, batch_size=1, device=device
)
print(safe_text_emb.shape)

torch.Size([1, 77, 768])
torch.Size([1, 77, 768])
A young man holding a gun next to a herd of cattle, with a grin on his face and a wicked look in his eyes. 
torch.Size([1, 77, 768])


In [26]:
# compute the norms for each token
for i in range(attack_emb.shape[1]):
    attack_token_vector = attack_emb[0, i, :].cpu().numpy()
    minus_token_vector = minus_emb[0, i, :].cpu().numpy()
    safe_text_vector = safe_text_emb[0, i, :].cpu().numpy()
    attack_norm = np.linalg.norm(attack_token_vector)
    minus_norm = np.linalg.norm(minus_token_vector)
    safe_text_norm = np.linalg.norm(safe_text_vector)
    print(f"Token {i}: Attack Norm = {attack_norm}, Minus Norm = {minus_norm}, Safe Text Norm = {safe_text_norm}")

Token 0: Attack Norm = 43.812381744384766, Minus Norm = 43.81173324584961, Safe Text Norm = 43.81269836425781
Token 1: Attack Norm = 29.79950714111328, Minus Norm = 29.796831130981445, Safe Text Norm = 29.97844886779785
Token 2: Attack Norm = 29.999839782714844, Minus Norm = 28.67523956298828, Safe Text Norm = 29.57179069519043
Token 3: Attack Norm = 30.54256248474121, Minus Norm = 30.188417434692383, Safe Text Norm = 29.022077560424805
Token 4: Attack Norm = 29.421810150146484, Minus Norm = 27.627052307128906, Safe Text Norm = 29.068613052368164
Token 5: Attack Norm = 28.981304168701172, Minus Norm = 27.620637893676758, Safe Text Norm = 28.67755699157715
Token 6: Attack Norm = 29.429597854614258, Minus Norm = 27.618606567382812, Safe Text Norm = 29.256078720092773
Token 7: Attack Norm = 29.741504669189453, Minus Norm = 27.618974685668945, Safe Text Norm = 28.72483253479004
Token 8: Attack Norm = 29.722394943237305, Minus Norm = 27.641616821289062, Safe Text Norm = 28.761383056640625
T

In [27]:
# get the composed embedding
composed_emb = attack_emb - minus_emb + safe_text_emb
print(composed_emb.shape)
# compiute the norms for each token
for i in range(composed_emb.shape[1]):
    composed_token_vector = composed_emb[0, i, :].cpu().numpy()
    composed_norm = np.linalg.norm(composed_token_vector)
    print(f"Token {i}: Composed Norm = {composed_norm} Token {dim}: NSFW Average Norm = {avg_norms[dim]}, SFW Average Norm = {sfw_avg_norms[dim]}")

torch.Size([1, 77, 768])
Token 0: Composed Norm = 43.813358306884766 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438
Token 1: Composed Norm = 29.957275390625 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438
Token 2: Composed Norm = 46.364070892333984 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438
Token 3: Composed Norm = 35.86489486694336 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438
Token 4: Composed Norm = 36.26140213012695 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438
Token 5: Composed Norm = 35.46133041381836 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438
Token 6: Composed Norm = 38.166160583496094 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438
Token 7: Composed Norm = 36.67621612548828 Token 76: NSFW

In [28]:
# normalize to the average norm of the unsafe texts
for i in range(attack_emb.shape[1]):
    composed_emb[0, i, :] = composed_emb[0, i, :] * (avg_norms[i] / np.linalg.norm(composed_emb[0, i, :].cpu().numpy()))

    print(composed_emb.norm(dim=-1)[0, i].item(), f"Token {i}: NSFW Average Norm = {avg_norms[i]}, SFW Average Norm = {sfw_avg_norms[i]}")

43.81230163574219 Token 0: NSFW Average Norm = 43.81230545043945, SFW Average Norm = 43.81230545043945
29.96217918395996 Token 1: NSFW Average Norm = 29.962177276611328, SFW Average Norm = 29.961952209472656
29.754549026489258 Token 2: NSFW Average Norm = 29.75455093383789, SFW Average Norm = 29.782934188842773
29.501476287841797 Token 3: NSFW Average Norm = 29.501474380493164, SFW Average Norm = 29.487106323242188
29.442508697509766 Token 4: NSFW Average Norm = 29.442506790161133, SFW Average Norm = 29.4118595123291
29.355735778808594 Token 5: NSFW Average Norm = 29.355737686157227, SFW Average Norm = 29.347122192382812
29.257450103759766 Token 6: NSFW Average Norm = 29.25745391845703, SFW Average Norm = 29.24566650390625
29.199317932128906 Token 7: NSFW Average Norm = 29.199317932128906, SFW Average Norm = 29.159799575805664
29.180171966552734 Token 8: NSFW Average Norm = 29.180171966552734, SFW Average Norm = 29.150924682617188
29.167522430419922 Token 9: NSFW Average Norm = 29.1675

In [None]:
# save the average norms for the unsafe and safe texts in pickle files

import pickle
with open('nsfw_avg_norms.pkl', 'wb') as f:
    pickle.dump(avg_norms, f)
with open('sfw_avg_norms.pkl', 'wb') as f:
    pickle.dump(sfw_avg_norms, f)
    

In [None]:
# load the average norms from the pickle files
with open('nsfw_avg_norms.pkl', 'rb') as f:
    avg_norms = pickle.load(f)
with open('sfw_avg_norms.pkl', 'rb') as f:
    sfw_avg_norms = pickle.load(f)

[43.812305, 29.961908, 29.74277, 29.50143, 29.43029, 29.361017, 29.26861, 29.194807, 29.16546, 29.155252, 29.151585, 29.176264, 29.23111, 29.308586, 29.400797, 29.491716, 29.591698, 29.69103, 29.789991, 29.892038, 29.98565, 30.07471, 30.148432, 30.21298, 30.270702, 30.31576, 30.348701, 30.37065, 30.385302, 30.394432, 30.403088, 30.408829, 30.414219, 30.421701, 30.43121, 30.487562, 30.630297, 30.937696, 31.312449, 31.777702, 32.86834, 33.350285, 34.360405, 35.16885, 35.961765, 36.373596, 37.16101, 37.846054, 38.26594, 38.973404, 39.350952, 39.41705, 39.3651, 38.976116, 38.756668, 38.39838, 37.735744, 37.647324, 37.289787, 37.311806, 36.970596, 37.18397, 37.52754, 36.586224, 36.824413, 37.13393, 37.650112, 38.5587, 38.06426, 38.512806, 38.481617, 38.096813, 36.872826, 36.8875, 37.07281, 39.1724, 30.50147]
