# Detoxification of the 'hh-rlhf_eval' dataset

As stated in the report for M2, we found that many datapoints in the 'hh-rlhf_eval' dataset are severely toxic, so in this notebook we detoxify the processed dataset. To do so, we score the text of every datapoint with the 'Detoxify' model and discard any datapoint whose toxicity score exceeds 0.1 (on a 0–1 scale). This threshold is deliberately strict because we want to avoid introducing toxic datapoints at all costs.

In [None]:
!pip install detoxify

In [None]:
import json
from detoxify import Detoxify
import pandas as pd
from tqdm import tqdm

# Substitute this name with the training split to detoxify the training dataset
json_to_detoxify = 'hh-rlhf_eval'
json_detoxified = json_to_detoxify + '_detoxified.json'
json_toxic = json_to_detoxify + '_toxic.json'
json_to_detoxify = json_to_detoxify + '.json'

# A datapoint is discarded if ANY text it contains has a Detoxify 'toxicity'
# score above this value. Deliberately strict: we would rather drop some clean
# datapoints than keep a toxic one. Adjust to change the "toxicity level".
TOXICITY_THRESHOLD = 0.1


def _iter_strings(value):
    """Yield every string contained in *value*.

    Descends into dict values, lists and tuples so that text nested inside a
    datapoint (e.g. a list of conversation turns) is also checked; any other
    type (numbers, booleans, None) is ignored.
    """
    if isinstance(value, str):
        yield value
    elif isinstance(value, dict):
        for item in value.values():
            yield from _iter_strings(item)
    elif isinstance(value, (list, tuple)):
        for item in value:
            yield from _iter_strings(item)


def _is_toxic(datapoint, model, threshold=TOXICITY_THRESHOLD):
    """Return True if any string in *datapoint* scores above *threshold*.

    *model* is a Detoxify instance; ``model.predict(text)`` is called once per
    string and its ``'toxicity'`` score is compared to *threshold*. ``any``
    short-circuits, so scoring stops at the first toxic string.
    """
    return any(
        model.predict(text)['toxicity'] > threshold
        for text in _iter_strings(datapoint)
    )


# Load the JSON file (explicit UTF-8: the locale default may not decode the text)
with open(json_to_detoxify, encoding='utf-8') as file:
    data = json.load(file)

# Initialize Detoxify
model = Detoxify('original')

# Split the datapoints into non-toxic (kept) and toxic (set aside for inspection)
non_toxic_data = []
toxic_data = []
for datapoint in tqdm(data):
    if _is_toxic(datapoint, model):
        toxic_data.append(datapoint)
    else:
        non_toxic_data.append(datapoint)

count = len(non_toxic_data)
count_tox = len(toxic_data)

# Write non-toxic data points to a new JSON file
with open(json_detoxified, 'w', encoding='utf-8') as file:
    json.dump(non_toxic_data, file)

# Write toxic data points to a new JSON file
with open(json_toxic, 'w', encoding='utf-8') as file:
    json.dump(toxic_data, file)

print('Number of non-toxic data points: ', count)
print('Number of toxic data points: ', count_tox)