In [None]:
import pickle
import fairness_metrics
import pandas as pd
from tqdm import tqdm
import utils

In [None]:
# Ideology-annotated comments: any label other than 'left'/'right' is blanked out,
# then rows with a missing value in ANY column are dropped (not only the label column).
ideology_df = pd.read_csv('./data/processed_annotated_comments.csv')
ideology_df = ideology_df.assign(
    label=ideology_df['label'].where(ideology_df['label'].isin(['left', 'right']))
).dropna()

main_df = pd.read_csv('./data/jigsaw/main.csv')

# Evaluation subsets and their display names, kept in parallel order.
dfs = [ideology_df, main_df]
names = ['ideology', 'jigsaw']

In [None]:
# Load the moderation decisions for the original (unperturbed) comments and turn them
# into per-method "gold" labels: {comment_text: 1 if flagged else 0}.
# NOTE(review): pickle.load can execute arbitrary code - only load result files you produced.
with open('./results/moderation_results.pkl', 'rb') as file:
    fairness_results = pickle.load(file)


def _binarize_gold(results, method):
    """Return {text: 0/1} for `method`, skipping texts whose decision is not boolean.

    Non-boolean decisions (e.g. None when the moderation call did not run) are dropped
    rather than silently counted as SAFE (0). `int(...)` is used instead of `v is True`
    so numpy booleans map to 1/0 correctly, matching how the perturbed results are
    converted in the robustness loop.
    """
    return {k: int(v[method]) for k, v in results.items() if v[method] in [True, False]}


gold1 = _binarize_gold(fairness_results, 'openai')
gold2 = _binarize_gold(fairness_results, 'perspective')
gold3 = _binarize_gold(fairness_results, 'google')
gold4 = _binarize_gold(fairness_results, 'clarifai')

In [None]:
def compute_percentage_change(main_list, list1, labels=('UNSAFE', 'SAFE')):
    """Return the label-transition matrix between two position-aligned label lists.

    ``result[k][subk]`` is the number of positions where ``main_list`` holds ``k``
    and ``list1`` holds ``subk``, divided by the number of positions where
    ``main_list`` holds ``k``. Values are fractions in [0, 1], not percentages,
    and each row sums to at most 1.

    Args:
        main_list: reference labels (in this notebook: decision on the original text).
        list1: labels compared against ``main_list`` position by position
            (in this notebook: decision on the perturbed text).
        labels: label values to tabulate. Their order fixes the key order of the
            returned dicts, which the plotting code relies on. Values outside
            ``labels`` are ignored.

    Returns:
        dict[label, dict[label, float]]. If ``main_list`` contains no ``k``, the row
        for ``k`` is all NaN instead of raising ZeroDivisionError.
    """
    row_totals = {k: 0 for k in labels}
    pair_counts = {k: {subk: 0 for subk in labels} for k in labels}

    # Denominator: every occurrence of each label in main_list.
    for before in main_list:
        if before in row_totals:
            row_totals[before] += 1

    # Numerator: aligned (before, after) pairs; single pass instead of the O(n^2) scan.
    for before, after in zip(main_list, list1):
        if before in pair_counts and after in pair_counts[before]:
            pair_counts[before][after] += 1

    return {
        k: {
            subk: pair_counts[k][subk] / row_totals[k] if row_totals[k] else float('nan')
            for subk in labels
        }
        for k in labels
    }

In [None]:
golds = [gold1, gold2, gold3, gold4]

methods = ['openai', 'perspective', 'google', 'clarifai']
perturbations = ['german', 'gpt_3.5_turbo']

# German back-translations are kept only when similarity to the original is above this
# threshold and not exactly 1.0 (presumably an unchanged back-translation - confirm).
GERMAN_SIMILARITY_THRESHOLD = 0.85

global_results = {method: {} for method in methods}

# Perturbation maps: original text -> perturbed version (for 'german', an entry with
# 'augmented' text and a similarity 'score').
perturbation_map = {}
with open('./results/comments_backtranslated_german_similarity.pkl', 'rb') as handle:
    perturbation_map[perturbations[0]] = pickle.load(handle)

with open('./results/comment_paraphrased_gpt-3.5_final.pkl', 'rb') as handle:
    perturbation_map[perturbations[1]] = pickle.load(handle)

for perturbation in perturbations:
    with open(f'./results/moderation_results_fairness_perturbed_{perturbation}.pkl', 'rb') as handle:
        fairness_results = pickle.load(handle)

    phrase_map = perturbation_map[perturbation]
    if perturbation == "german":
        phrase_map = {k: v['augmented'] for k, v in phrase_map.items()
                      if v['score'] > GERMAN_SIMILARITY_THRESHOLD and v['score'] != 1.0}

    for gold, method in zip(golds, methods):
        # Ignore perturbed phrases where this method's moderation did not run. Filter into a
        # per-method dict: reassigning `fairness_results` here made the filter cumulative,
        # so a phrase that failed for one method was also dropped for every later method.
        method_results = {k: v for k, v in fairness_results.items() if v[method] in [True, False]}

        # original text -> 0/1 decision on its perturbed version
        data = {k: int(method_results[v][method]) for k, v in phrase_map.items() if v in method_results}

        # keep only texts that have both a gold decision and a perturbed decision
        local_gold = {k: v for k, v in gold.items() if k in data and v in [True, False]}

        # per subset: transition matrix from original-label to perturbed-label
        global_results[method][perturbation] = {}
        for subset, name in zip(dfs, names):
            texts = subset.loc[subset['text'].isin(list(local_gold)), 'text'].tolist()
            before = ['UNSAFE' if local_gold[k] == 1 else 'SAFE' for k in texts]
            after = ['UNSAFE' if data[k] == 1 else 'SAFE' for k in texts]
            global_results[method][perturbation][name] = compute_percentage_change(before, after)


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

labels = ['UNSAFE', 'SAFE']

# Perturbation and evaluation subset shown in the figure.
PLOT_PERTURBATION = 'gpt_3.5_turbo'
PLOT_SUBSET = 'jigsaw'

# (key in global_results, panel title). Selecting each panel's data by method key keeps
# data and title paired. The previous positional reorder ([0, 3, 1, 2] of
# openai/perspective/google/clarifai) drew Clarifai under "Perspective", Perspective
# under "PaLM2" and PaLM2 under "Clarifai". Reorder this list to change panel order.
panels = [('openai', 'OpenAI'), ('perspective', 'Perspective'), ('google', 'PaLM2'), ('clarifai', 'Clarifai')]
titles = [title for _, title in panels]

# One 2x2 matrix per panel: rows = label on the original text, columns = label after
# perturbation (the nested-dict layout returned by compute_percentage_change).
heatmap_data = [
    np.array([[global_results[method][PLOT_PERTURBATION][PLOT_SUBSET][before][after] for after in labels]
              for before in labels], dtype=float)
    for method, _ in panels
]

# Shared colour scale across panels; nan-aware so an empty class (NaN row) does not break it.
global_min = np.nanmin(heatmap_data)
global_max = np.nanmax(heatmap_data)

fig, axes = plt.subplots(1, len(heatmap_data), figsize=(15, 3), sharey=True)
fig.subplots_adjust(wspace=0.01)
for ax, matrix, title in zip(axes, heatmap_data, titles):
    sns.heatmap(matrix, annot=True, fmt=".2f", cmap="crest", xticklabels=labels, yticklabels=labels,
                ax=ax, annot_kws={"size": 16}, vmin=global_min, vmax=global_max, cbar=False)
    ax.set_title(title, fontsize=20)
    ax.set_xlabel('Perturbed', fontsize=14)
    ax.xaxis.set_tick_params(labelsize=20)
    ax.yaxis.set_tick_params(labelsize=20)
    plt.setp(ax.get_xticklabels(), rotation=90, ha="right")
axes[0].set_ylabel('Original', fontsize=14)

# Cell values are fractions of each original-label row (0-1), not percentages.
cbar = fig.colorbar(axes[0].collections[0], ax=axes, orientation='vertical')
cbar.set_label('Fraction of row')
cbar.ax.tick_params(labelsize=12)
cbar.ax.yaxis.label.set_size(12)

plt.show()