In [1]:
import os
import pandas as pd
import matplotlib.pyplot as plt

# Define the directory containing the datasets
# (relative path, so it is resolved against the current working directory)
data_dir = "../../data/"
dataset_path = os.path.join(data_dir, "analyzed/catHarmQA/combined_catqa.csv")

# Read the CSV into a DataFrame and display its (rows, columns) shape
data = pd.read_csv(dataset_path)
data.shape

(136400, 17)

In [4]:
# Display the column names of the loaded DataFrame
data.columns

Index(['category', 'subcategory', 'original_question',
       'original_question_safety', 'original_response',
       'original_response_safety', 'perturbation_level', 'perturbation_type',
       'perturbation_count', 'perturbed_question', 'perturbed_question_safety',
       'model', 'perturbed_response', 'perturbed_response_safety',
       'experiment', 'original_response_pre', 'perturbed_response_pre'],
      dtype='object')

The text in the "original_response" column starts with the question from the "original_question" column. That repeated question is not needed, so we have to remove it. We check whether the question from "original_question" matches the start of "original_response" and, if it does, remove that question from the response. The same is done for the "perturbed_response" column, which starts with the question from the "perturbed_question" column. Remember that the `data` variable is a pandas DataFrame, so we can access its columns by name.

In [252]:
def handle_eage_case(content):
    """Remove the stray space that precedes certain punctuation marks.

    Each pattern below is replaced by its whitespace-stripped form, in the
    listed order (e.g. ``"Hello , world !"`` becomes ``"Hello, world!"``).

    Args:
        content: Text to normalize.

    Returns:
        The text with every listed pattern collapsed to the bare punctuation.
    """
    # Order matters, and " '" is intentionally listed twice: the second pass
    # removes a space left behind when two spaces preceded the apostrophe.
    spaced_punctuation = (" ' ", " '", " !", " ,", " ?", " .", " '")

    for pattern in spaced_punctuation:
        if pattern not in content:
            continue
        content = content.replace(pattern, pattern.strip())

    return content


# Function to remove question from response
def remove_question_from_response(row, question_col, response_col, handle_edge_case_condition=False):
    """Return the response text with its leading copy of the question removed.

    The raw question/response pair is tried first. If the response does not
    start with the question, the function calls itself once more with both
    texts passed through ``handle_eage_case``. If that second attempt also
    fails, the row is printed and the (normalized) response is returned
    without any prefix removed.

    Args:
        row: Mapping-like row that is indexed by column name; ``row.name`` is
            read only when reporting a row that could not be matched.
        question_col: Name of the column holding the question.
        response_col: Name of the column holding the response.
        handle_edge_case_condition: When True, compare the normalized texts
            instead of the raw ones (this is the second, final attempt).

    Returns:
        The remainder of the response after the question, stripped of
        surrounding whitespace, or the response itself when no match is found.
    """
    # Identity on the first attempt, punctuation normalization on the second.
    normalize = handle_eage_case if handle_edge_case_condition else (lambda text: text)
    question = normalize(row[question_col])
    response = normalize(row[response_col])

    if response.startswith(question):
        return response[len(question) :].strip()

    if not handle_edge_case_condition:
        # First attempt failed: retry once with normalized texts.
        return remove_question_from_response(row, question_col, response_col, True)

    # Second attempt failed as well: report the row and give up.
    print(f"{row.name}th row")
    print(f"question: {question}")
    print(f"response: {response}", f"\n{'-'*120}\n")
    return response

In [253]:
# Strip the leading question from every response, row by row.
# The column names are forwarded to the function through `args`.
data["original_response_pre"] = data.apply(
    remove_question_from_response,
    axis=1,
    args=("original_question", "original_response"),
)
data["perturbed_response_pre"] = data.apply(
    remove_question_from_response,
    axis=1,
    args=("perturbed_question", "perturbed_response"),
)

Now let's test whether the question prefix has been removed from the "original_response" and "perturbed_response" columns. We do this in code: we take the difference between "original_response" and "original_response_pre", i.e. the text of "original_response" that is not in "original_response_pre", and this difference should be equal to "original_question". We perform this check over all rows and return the count of False results. We do the same for the "perturbed_response_pre" and "perturbed_response" columns.

In [254]:
def find_non_common_substring(A, B):
    """Return the part of ``A`` that precedes the first occurrence of ``B``.

    Args:
        A: Text to search in.
        B: Text to look for.

    Returns:
        ``A`` up to (not including) the first occurrence of ``B``, or all of
        ``A`` when ``B`` does not occur in it.
    """
    cut = A.find(B)
    return A if cut == -1 else A[:cut]

# Function to check if the prefix question is removed
def check_prefix_removed(row, question_col, response_pre_col, response_col, handle_edge_case_condition=False):
    """Check that the text removed from a response corresponds to its question.

    The text of the response that precedes the pre-processed response is
    compared (ignoring surrounding whitespace) with the question. The raw
    texts are tried first; on a mismatch the function calls itself once more
    with the question and response passed through ``handle_eage_case``. On
    that second attempt a mismatch falls back to a containment test.

    Note that the containment fallback is also True when the leftover text is
    empty, i.e. when nothing was removed from the response.

    Args:
        row: Mapping-like row that is indexed by column name.
        question_col: Name of the column holding the question.
        response_pre_col: Name of the column holding the pre-processed response.
        response_col: Name of the column holding the full response.
        handle_edge_case_condition: When True, compare the normalized texts
            instead of the raw ones (this is the second, final attempt).

    Returns:
        True when the removed text equals the question, or (second attempt
        only) when the removed text is contained in the question; otherwise
        False.
    """
    # Identity on the first attempt, punctuation normalization on the second.
    normalize = handle_eage_case if handle_edge_case_condition else (lambda text: text)
    question = normalize(row[question_col])
    response = normalize(row[response_col])
    response_pre = row[response_pre_col]

    # Whatever sits in front of the pre-processed response inside the response.
    removed_prefix = find_non_common_substring(response, response_pre).strip()

    if removed_prefix == question.strip():
        return True

    if not handle_edge_case_condition:
        # First attempt failed: retry once with normalized texts.
        return check_prefix_removed(row, question_col, response_pre_col, response_col, True)

    # Final attempt: accept when the removed text is a substring of the question.
    return removed_prefix in question

In [255]:
# Run the prefix check on every row for both the original and perturbed columns.
# The column names are forwarded to the function through `args`.
original_check = data.apply(
    check_prefix_removed,
    axis=1,
    args=("original_question", "original_response_pre", "original_response"),
)
perturbed_check = data.apply(
    check_prefix_removed,
    axis=1,
    args=("perturbed_question", "perturbed_response_pre", "perturbed_response"),
)

In [256]:
# True only when every row passed check_prefix_removed for the original columns.
# The check ignores surrounding whitespace and retries with normalized punctuation,
# so it is looser than strict `original_question + original_response_pre == original_response`.
original_check.all()

True

In [257]:
# True only when every row passed check_prefix_removed for the perturbed columns.
# The check ignores surrounding whitespace and retries with normalized punctuation,
# so it is looser than strict `perturbed_question + perturbed_response_pre == perturbed_response`.
perturbed_check.all()

True

In [258]:
# Number of rows whose perturbed prefix check came back False
perturbed_false_count = (~perturbed_check).sum()
print(f"Count of False for perturbed_response: {perturbed_false_count}")

Count of False for perturbed_response: 0


In [259]:
# Show only the rows that failed the perturbed prefix check
perturbed_check.loc[~perturbed_check]

Series([], dtype: bool)

In [260]:
# For the rows that failed the prefix check, count how many have a pre-processed
# response identical to the full response (i.e. nothing was removed).
# The expression is evaluated outside the f-string: reusing double quotes inside
# a double-quoted f-string is only valid syntax from Python 3.12 onwards.
failed_rows = perturbed_check[perturbed_check == False].index.to_series()
unchanged_counts = failed_rows.apply(
    lambda x: data.loc[x, "perturbed_response_pre"] == data.loc[x, "perturbed_response"]
).value_counts()
print(f"perturbed_response_pre == perturbed_response: \n{unchanged_counts}")

perturbed_response_pre == perturbed_response: 
Series([], Name: count, dtype: int64)


In [261]:
# For the rows that failed the perturbed prefix check, count occurrences of each "model" value
perturbed_check[perturbed_check == False].index.to_series().apply(lambda x: data.loc[x, "model"]).value_counts()

Series([], Name: count, dtype: int64)

In [262]:
# Dump every row that failed the perturbed prefix check for manual inspection:
# question, full response, pre-processed response, then whether the last two are equal.
for index in perturbed_check[perturbed_check == False].index:
    for column in ("perturbed_question", "perturbed_response", "perturbed_response_pre"):
        print("------ ", data.loc[index, column], "\n")
    print(data.loc[index, "perturbed_response_pre"] == data.loc[index, "perturbed_response"], "\n")
    print(f"{'-'*120}\n")

In [263]:
# Spot-check one row (position 139): does the raw perturbed response start with its question?
data.iloc[139]['perturbed_response'].startswith(data.iloc[139]['perturbed_question'])

True

In [264]:
# Display the column names again after the prefix-removal step
data.columns

Index(['category', 'subcategory', 'original_question',
       'original_question_safety', 'original_response',
       'original_response_safety', 'perturbation_level', 'perturbation_type',
       'perturbation_count', 'perturbed_question', 'perturbed_question_safety',
       'model', 'perturbed_response', 'perturbed_response_safety',
       'experiment', 'original_response_pre', 'perturbed_response_pre'],
      dtype='object')

In [267]:
# data.to_csv(dataset_path, index=False)

In [5]:
import torch
import sys

# Add the parent directory to the import path
# (presumably so the `config` import below resolves - confirm the project layout)
sys.path.append("../")

from config import envs, credentials
from transformers import AutoTokenizer, AutoModelForCausalLM

# Use CUDA when it is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

In [6]:
def moderate(model, tokenizer, texts):
    """Generate one model output per input text.

    Each text is wrapped as a single user turn, rendered with the tokenizer's
    chat template, and passed to ``model.generate``. Only the tokens produced
    after the prompt are decoded.

    Args:
        model: Object exposing ``generate(input_ids=..., max_new_tokens=...,
            pad_token_id=...)``.
        tokenizer: Object exposing ``apply_chat_template`` and ``decode``.
        texts: Iterable of texts to process one at a time.

    Returns:
        A list with one decoded output per text, in input order, or None if
        any exception was raised.

    NOTE(review): a single failing text discards every output collected so
    far, because the whole loop sits inside one ``try`` block. Callers should
    check for None before using the result.
    """
    try:
        generations = []
        for text in texts:
            conversation = [{"role": "user", "content": text}]
            prompt_ids = tokenizer.apply_chat_template(
                conversation, return_tensors="pt"
            ).to(device)
            generated = model.generate(
                input_ids=prompt_ids, max_new_tokens=100, pad_token_id=0
            )
            # The generated sequence starts with the prompt; skip past it.
            prompt_len = prompt_ids.shape[-1]
            completion_ids = generated[0][prompt_len:]
            generations.append(
                tokenizer.decode(completion_ids, skip_special_tokens=True)
            )
        return generations
    except Exception as e:
        print(f"Error generating text: {str(e)}")
        return None

In [None]:
model_id = "meta-llama/Llama-Guard-3-8B"

# Single definition of the model cache location, shared by tokenizer and model.
# NOTE(review): this is a machine-specific absolute path; adjust it (or make it
# configurable) before running the notebook on another machine.
model_cache_dir = "/Users/saurabh/AA/convergent/projects/llm-sensitivity/.env/models"

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_id,
    token=credentials.HF_TOKEN,
    cache_dir=model_cache_dir,
)

model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=model_id,
    token=credentials.HF_TOKEN,
    cache_dir=model_cache_dir,
).to(device)

# Columns whose texts are sent through the model; each gets a "<column>_safety" sibling.
question_col_list = ['original_response_pre', 'perturbed_response_pre']

for question_col in question_col_list:
    questions = data[question_col].to_list()

    output_texts = moderate(model, tokenizer, questions)
    if output_texts is None:
        # moderate() returns None when generation failed. Stop here instead of
        # storing an empty column and overwriting the dataset file with it.
        raise RuntimeError(
            f"Moderation failed for column '{question_col}'; dataset was not updated."
        )

    print("Storing response in dataframe")
    new_col_name = f"{question_col}_safety"
    if new_col_name in data.columns:
        # Re-running the cell: overwrite in place, since DataFrame.insert
        # raises ValueError when the column already exists.
        data[new_col_name] = output_texts
    else:
        # Place the new column immediately after the column it was derived from.
        data.insert(
            data.columns.get_loc(question_col) + 1,
            new_col_name,
            output_texts,
        )

    # Persist after each column so a completed column survives a later failure.
    data.to_csv(dataset_path, index=False)

tokenizer_config.json:   0%|          | 0.00/52.0k [00:00<?, ?B/s]

0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.


tokenizer.json:   0%|          | 0.00/9.08M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/73.0 [00:00<?, ?B/s]

ERROR:tornado.general:SEND Error: Host unreachable


config.json:   0%|          | 0.00/860 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]