## Context-Setting

### Step 1
In the first notebook, we collected prompts from various sources and created a column called `safe_prompt`.

### Step 2
In the second notebook, we generated responses from a target model to create on-policy training data. This was stored in a column called `generated_response`.

### Step 3
In this notebook, we use a safety judge model to classify each generated response as safe or unsafe, and then filter out the unsafe responses so that we do not train on them.

In [5]:
import pandas as pd
import os
from tqdm import tqdm
import json
from typing import Tuple, List
from transformers import AutoTokenizer
from pandarallel import pandarallel

from utils import moderate_chat_with_nim

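NOTE: The `moderate_chat_with_nim` function assumes you have an Aegis v2 judge server running and reachable at the host and port configured in `inference_args` below (`localhost:8002` by default).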

### Initialize Helper Functions

In [6]:
class Args:
    """Judge-server and prompt settings handed to ``moderate_chat_with_nim``.

    The class-level values act as fallbacks. ``__init__`` overwrites the
    per-server fields on each instance. ``prompt_cats_type`` and
    ``response_format_type`` are never set per instance, so every instance
    reads them from the class.
    """

    # Per-instance fields: the class values below are only fallbacks.
    prompt_format_type = "aegis_v2"
    model_name = "aegis_v2"
    host = "localhost"
    port = 8002
    # NOTE(review): __init__ defaults adapter_path to None, so an instance
    # built without that argument does NOT see "permissive". Confirm which
    # value the judge server expects.
    adapter_path = "permissive"

    # Shared, fixed fields.
    prompt_cats_type = "catlistexpanded"
    response_format_type = "json"

    def __init__(self, prompt_format_type, model_name, nim_host, nim_port, adapter_path=None):
        # The NIM connection parameters are stored under the shorter names host/port.
        self.host, self.port = nim_host, nim_port
        self.prompt_format_type, self.model_name = prompt_format_type, model_name
        self.adapter_path = adapter_path


inference_args = Args("aegis_v2", "aegis_v2", nim_host="localhost", nim_port=8002)

In [7]:

def process_reasoning_traces(results_df):
    """Add reasoning-trace helper columns to ``results_df`` (in place) and return it.

    Adds three columns:

    * ``starts_with_think``: the response begins with ``"<think>"``.
    * ``contains_end_of_thought``: the response contains ``"</think>"``.
    * ``generated_response_stripped``: the text after the first ``"</think>"``,
      whitespace-stripped. Responses without a closing tag are copied unchanged.

    Missing (None/NaN) responses count as having no tags and pass through as-is.
    """
    responses = results_df["generated_response"]
    # na=False keeps the masks strictly boolean. A NaN in the mask would make
    # the selection below raise.
    results_df["starts_with_think"] = responses.str.startswith("<think>", na=False)
    has_end_tag = responses.str.contains("</think>", regex=False, na=False)
    results_df["contains_end_of_thought"] = has_end_tag
    # n=1 splits only at the FIRST closing tag. A later literal "</think>"
    # inside the answer no longer truncates the remaining text.
    after_trace = responses.str.split("</think>", n=1).str[1].str.strip()
    results_df["generated_response_stripped"] = responses.where(~has_end_tag, after_trace)
    return results_df


def _parse_aegis_v2_response(response: str) -> Tuple[str, List[str]]:
    try:
        # try parsing it as json
        parsed_json_result = json.loads(response)
        # if "Response Safety" in parsed_json_result:
        response_safety = parsed_json_result["Response Safety"].lower()
        if "Safety Categories" in parsed_json_result and response_safety == "unsafe":
            safety_categories = [cat.strip() for cat in parsed_json_result["Safety Categories"].split(",")]
        else:
            safety_categories = []
    except Exception as e:
        # If there is an error, and we can't parse the response, we return unsafe
        # assuming this is a potential jailbreaking attempt
        response_safety = "error: parsing failed"
        safety_categories = ["JSON parsing failed"]
    return response_safety, safety_categories


def parse_aegis_v2_response(response: str) -> str:
    """Return just the verdict component of ``_parse_aegis_v2_response(response)``."""
    verdict, _unused_categories = _parse_aegis_v2_response(response)
    return verdict


def parse_aegis_v2_response_categories(response: str) -> List[str]:
    """Return just the category list from ``_parse_aegis_v2_response(response)``."""
    return _parse_aegis_v2_response(response)[1]

In [8]:
# Configure pandarallel's parallel_apply to use 100 workers with a progress bar,
# and register tqdm's progress_apply on pandas objects.
pandarallel.initialize(progress_bar=True, nb_workers=100)
tqdm.pandas()

INFO: Pandarallel will run on 100 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


In [None]:
# Tokenizer passed to moderate_chat_with_nim together with each prompt/response pair.
tokenizer_id = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

### Load Step 2 Data

In [None]:
# Load the Step 2 output (JSON Lines) from the current working directory.
nemo_safety_blend_df = pd.read_json(
    os.path.join(os.getcwd(), "nemo_safety_blend_v0.2.2_step_2_with_generated_responses.jsonl"),
    lines=True,
)

In [None]:
# Reasoning models emit a <think>...</think> trace before their answer. The
# safety judge should see only the final answer, so derive a column with the
# trace removed before judging.
is_reasoning_model = True  # Set to False for models that do not emit thinking traces.
nemo_safety_blend_df = (
    process_reasoning_traces(nemo_safety_blend_df) if is_reasoning_model else nemo_safety_blend_df
)

In [None]:
# Initialize pandarallel (the same settings as the earlier setup cell).
pandarallel.initialize(nb_workers=100, progress_bar=True)

# Columns fed to the safety judge. For reasoning models the judge sees the
# answer with its thinking trace stripped.
prompt_column = "prompt"
response_column = "generated_response_stripped" if is_reasoning_model else "generated_response"


In [None]:
# Ask the judge about every (prompt, response) pair in parallel. response_column
# points at the trace-stripped answer when is_reasoning_model is set.
nemo_safety_blend_df["safety_judge_raw_generation"] = nemo_safety_blend_df.parallel_apply(
    lambda row: moderate_chat_with_nim(
        row[prompt_column],
        row[response_column],
        tokenizer=tokenizer,
        args=inference_args,
    ),
    axis=1,
)

# Split the raw judge output into a verdict column and a category-list column.
judge_raw = nemo_safety_blend_df["safety_judge_raw_generation"]
nemo_safety_blend_df["safety_judge_prediction"] = judge_raw.apply(parse_aegis_v2_response)
nemo_safety_blend_df["safety_judge_prediction_categories"] = judge_raw.apply(
    parse_aegis_v2_response_categories
)

In [None]:
# Keep only rows the judge labelled "safe". This drops rows labelled "unsafe"
# and rows whose judge output could not be parsed ("error: parsing failed"),
# so unverified responses do not reach the training data.
# The result must be assigned back; a bare boolean-mask expression discards it.
nemo_safety_blend_df = nemo_safety_blend_df[nemo_safety_blend_df["safety_judge_prediction"] == "safe"]

### Save the Final Curated Training Data

In [None]:
# Only the prompt and the generated response are needed for training. The other
# columns are kept for reference and for pivoting/categorizing results.
columns_to_keep = [
    "prompt",
    "generated_response",
    "id",
    "source_l1",
    "source_l2",
    "categories",
    "subcategories",
    "adversarial_persona",
    "prompt_label",
]

In [None]:
# Write the curated columns as JSON Lines (one record per row).
# pandas only accepts lines=True together with orient="records"; omitting
# orient raises ValueError. Selecting columns_to_keep raises KeyError if any
# expected column is missing, rather than silently writing a different schema.
nemo_safety_blend_df[columns_to_keep].to_json(
    os.path.join(os.getcwd(), "nemo_safety_blend_v0.2.2.jsonl"),
    orient="records",
    lines=True,
)