In [None]:
import warnings
import os
# os.chdir(f'/mnt/primary/exposure-fairness/')
import config
warnings.filterwarnings('ignore')
import pyterrier as pt
if not pt.started():
    pt.init()

import pandas as pd
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
pd.set_option('display.max_colwidth', 100)
import matplotlib.pyplot as plt
from matplotlib.pyplot import MultipleLocator
import fair_utils, topical_eval

In [None]:
def compute_all_groupings_gini(modelname, kmeans_vec):
    """Collect min/mean/max Gini for one model across every grouping granularity.

    For each entry of ``config.num_clusters`` the model's run file is mapped
    onto the matching clustered-queries CSV with
    ``topical_eval.mapping_to_cluster`` and summarised with
    ``topical_eval.calc_topical_gini``.  The resulting table is written to
    ``{config.data_dir}/{modelname}_all_groupings_gini_{kmeans_vec}.csv`` and
    also returned.

    Args:
        modelname: Model identifier used in the run-file and output-file names.
        kmeans_vec: Identifier embedded in the clustered-queries and output
            file names (e.g. ``'scikit_dense'``).

    Returns:
        pandas.DataFrame with columns ``modelname``, ``granu``, ``min_gini``,
        ``mean_gini`` and ``max_gini``; one row per granularity.
    """
    run_file = f'{config.data_dir}/{modelname}_{config.dataset_name}_{config.topics_name}_{config.retrieve_num}.res'
    header = ['modelname', 'granu', 'min_gini', 'mean_gini', 'max_gini']

    def _gini_row(granu):
        # Builds the single result row for one granularity.
        print(f'computing {granu} --> {modelname}')
        clusters_csv = f'{config.prog_dir}/grouped_queries/clustered_dev_queries_by_{granu}_{kmeans_vec}.csv'
        per_topic = topical_eval.mapping_to_cluster(run_file, clusters_csv)
        lowest, average, highest = topical_eval.calc_topical_gini(per_topic)
        return [modelname, granu, lowest, average, highest]

    rows = [_gini_row(granu) for granu in config.num_clusters]

    print(f'save all ginis of {modelname}')
    result_df = pd.DataFrame(rows, columns=header)

    destination = f'{config.data_dir}/{modelname}_all_groupings_gini_{kmeans_vec}.csv'
    print(f'saving into {destination}')
    result_df.to_csv(destination, index=False)
    print('done')

    return result_df

In [None]:
def sub_plot(plt, title, col, ylabel, kmeans_vec=None):
    """Draw one Gini statistic for all five retrieval pipelines on a single Axes.

    Args:
        plt: The matplotlib ``Axes`` to draw on.  Despite its name this is NOT
            the ``matplotlib.pyplot`` module; inside this function the name
            shadows the module-level ``plt``.
        title: Axes title (``None`` is passed through to ``set_title``).
        col: Column of the per-model result frame to plot
            (``'min_gini'``, ``'mean_gini'`` or ``'max_gini'``).
        ylabel: Label for the y axis.
        kmeans_vec: Forwarded to ``compute_all_groupings_gini``.

    Returns:
        Tuple of the five ``Line2D`` handles, in the order BM25, SPLADE,
        TCT-ColBERT, BM25>>TCT-ColBERT, BM25>>Mono-T5.

    Note:
        Every call recomputes (and re-saves) the full min/mean/max table for
        all five models, even though only ``col`` is plotted.
    """
    ax = plt  # readable alias; the parameter name is kept for compatibility

    # (model name used for file lookup, legend label)
    pipelines = (
        ("bm25", 'BM25'),
        ("splade", 'SPLADE'),
        ("tctcolbert", 'TCT-ColBERT'),
        ("bm25_tctcolbert", 'BM25>>TCT-ColBERT'),
        ("bm25_monot5", 'BM25>>Mono-T5'),
    )

    # Compute every table before drawing anything, so a failure in one model
    # does not leave a half-drawn panel behind.
    results = [compute_all_groupings_gini(modelname, kmeans_vec) for modelname, _ in pipelines]

    # Granularities are placed at evenly spaced positions 1..n and labelled with
    # their actual value.  Deriving both from the data (rather than hard-coding
    # five positions and labels) keeps the plot in sync with config.num_clusters.
    granularities = results[0]['granu']
    x = list(range(1, len(granularities) + 1))
    tick_labels = [str(granu) for granu in granularities]

    lines = []
    for (_, legend_label), result in zip(pipelines, results):
        line, = ax.plot(x, result[col], label=legend_label, marker='.', markersize=12)
        lines.append(line)

    ax.set_title(title, fontsize=16)
    ax.set_xlabel(r'#Groups ($K$)', fontsize=16)
    ax.set_ylabel(ylabel, fontsize=16)
    ax.set_xticks(ticks=x, labels=tick_labels, fontsize=16)
    return tuple(lines)

In [None]:
# Three-panel figure (min / mean / max Gini vs. number of groups) for the
# 'scikit_dense' clustering.  Leaves fig, ax0-ax2, g0-g2 and kmeans_vec in the
# notebook namespace for later cells.
kmeans_vec = 'scikit_dense'

fig, (ax0, ax1, ax2) = plt.subplots(nrows=1, ncols=3, figsize=(15, 4))
plt.subplots_adjust(hspace=1, wspace=0.2, left=0, right=1, bottom=0)

# One entry per panel: (axes, result column, y label, y tick spacing, y limits).
panel_specs = [
    (ax0, 'min_gini', 'Minimum Gini', 0.01, (0.155, 0.23)),
    (ax1, 'mean_gini', 'Average Gini', 0.01, (0.26, 0.33)),
    (ax2, 'max_gini', 'Maximum Gini', 0.05, (0.48, 0.77)),
]

panel_handles = []
for panel_ax, col, ylabel, y_step, y_limits in panel_specs:
    panel_ax.xaxis.set_major_locator(MultipleLocator(1))
    panel_ax.yaxis.set_major_locator(MultipleLocator(y_step))
    # Limits are fixed before plotting so the drawn data does not rescale the axis.
    panel_ax.set_ylim(*y_limits)
    panel_handles.append(sub_plot(panel_ax, None, col, ylabel, kmeans_vec=kmeans_vec))

g0, g1, g2 = panel_handles

# All panels show the same five series, so the first panel's handles serve
# as the shared figure legend.
fig.legend(handles=g0, loc='upper center', bbox_to_anchor=(0.5, 1.12), ncol=10, fontsize=16)
plt.savefig(f'{config.prog_dir}/aggr_gini_{kmeans_vec}.pdf', format="pdf", bbox_inches="tight", pad_inches=0)
plt.show()

In [None]:
# Three-panel figure (min / mean / max Gini vs. number of groups) for the
# 'scikit_tfidf' clustering.  Rebinds fig, ax0-ax2, g0-g2 and kmeans_vec in the
# notebook namespace, replacing those of any earlier figure cell.
kmeans_vec = 'scikit_tfidf'

fig, (ax0, ax1, ax2) = plt.subplots(nrows=1, ncols=3, figsize=(15, 4))
plt.subplots_adjust(hspace=1, wspace=0.2, left=0, right=1, bottom=0)

# One entry per panel: (axes, result column, y label, y tick spacing, y limits).
panel_specs = [
    (ax0, 'min_gini', 'Minimum Gini', 0.01, (0.16, 0.19)),
    (ax1, 'mean_gini', 'Average Gini', 0.03, (0.265, 0.33)),
    (ax2, 'max_gini', 'Maximum Gini', 0.05, (0.48, 0.80)),
]

panel_handles = []
for panel_ax, col, ylabel, y_step, y_limits in panel_specs:
    panel_ax.xaxis.set_major_locator(MultipleLocator(1))
    panel_ax.yaxis.set_major_locator(MultipleLocator(y_step))
    # Limits are fixed before plotting so the drawn data does not rescale the axis.
    panel_ax.set_ylim(*y_limits)
    panel_handles.append(sub_plot(panel_ax, None, col, ylabel, kmeans_vec=kmeans_vec))

g0, g1, g2 = panel_handles

# All panels show the same five series, so the first panel's handles serve
# as the shared figure legend.
fig.legend(handles=g0, loc='upper center', bbox_to_anchor=(0.5, 1.12), ncol=10, fontsize=16)
plt.savefig(f'{config.prog_dir}/aggr_gini_{kmeans_vec}.pdf', format="pdf", bbox_inches="tight", pad_inches=0)
plt.show()

In [None]:
# Re-draw each panel of the most recently built three-panel figure as its own
# stand-alone PDF.
# NOTE: this cell reads ax0-ax2, g0-g2 and kmeans_vec from the notebook
# namespace, so it exports whichever figure cell was executed last.
import copy  # local import: only this cell needs it (to duplicate tick locators)

axes = [ax0, ax1, ax2]
names = ["min", "mean", "max"]
handles_list = [g0, g1, g2]

fixed_size = (5, 4)
for ax, name, handles in zip(axes, names, handles_list):
    fig_single, ax_single = plt.subplots(figsize=fixed_size)

    # Copy every plotted series, keeping its data, label, colour and line style.
    for line in ax.get_lines():
        ax_single.plot(line.get_xdata(), line.get_ydata(),label=line.get_label(),color=line.get_color(),linestyle=line.get_linestyle(), marker='.', markersize=12)

    # Reproduce the source panel's view and stop later calls from rescaling it.
    ax_single.set_xlim(ax.get_xlim())
    ax_single.set_ylim(ax.get_ylim())
    ax_single.autoscale(enable=False)

    x = [1,2,3,4,5]
    ax_single.set_xticks(ticks = x, labels = ['500', '1000', '2000', '5000', '10000'], fontsize=16)
    # A Locator is bound to the axis it is installed on, so handing the source
    # panel's locator objects to the new axes would silently rebind them and
    # detach them from the original figure.  Install shallow copies instead.
    ax_single.xaxis.set_major_locator(copy.copy(ax.xaxis.get_major_locator()))
    ax_single.yaxis.set_major_locator(copy.copy(ax.yaxis.get_major_locator()))
    ax_single.set_xlabel(ax.get_xlabel(), fontsize=16)
    # Same font size as the x label (and as the source panel's y label).
    ax_single.set_ylabel(ax.get_ylabel(), fontsize=16)

    ax_single.legend(handles=handles, loc='best', fontsize=10,frameon=True, facecolor="white", framealpha=0.8)

    fig_single.savefig(
        f"{config.prog_dir}/aggr_gini_{name}_{kmeans_vec}.pdf",
        format="pdf",
        bbox_inches='tight',
        pad_inches=0.2
    )
    # Close each temporary figure so the loop does not accumulate open figures.
    plt.close(fig_single)
