In [None]:
import pandas as pd
import seaborn as sns
import os
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter

# Work from the project root so the relative cc3m_experiments/ paths used below resolve.
# Set the WORKSPACE_DIR environment variable to run from a location other than /workspace.
os.chdir(os.environ.get('WORKSPACE_DIR', '/workspace'))
sns.set_style('darkgrid')

In [None]:
# Models evaluated in this notebook; each name is part of the prediction CSV and plot file names.
model_names_list = ['RN50', 'RN50x4', 'ViT-B-32']
# Training-set variants: "topN" is parsed below into N = number of training samples per person.
num_members_list = [f'top{n}' for n in (75, 50, 25, 10, 1)]

# Metrics vs. Number of Training Samples per Person

In [None]:
# Load the sampled prediction CSVs into a nested mapping:
# model name -> training-set variant ("topN") -> DataFrame (first CSV column used as index).
subsampled_dfs_per_model = {
    model_name: {
        num_members: pd.read_csv(
            f'cc3m_experiments/prediction_dfs/sampled_predictions_CC2M_{model_name}_{num_members}.csv',
            index_col=0,
        )
        for num_members in num_members_list
    }
    for model_name in model_names_list
}

In [None]:
# For every model, take the rows of the largest "Number of Samples" group from each training-set
# variant, add an accuracy column, and stack them into one frame indexed by the number of training
# samples per person. The metric columns are renamed to short labels for plotting.
rows_per_model = {}
for model_name in model_names_list:
    rows = []
    for num_members in num_members_list:
        df = subsampled_dfs_per_model[model_name][num_members]
        # .assign returns a new frame, so the loaded frames in subsampled_dfs_per_model are not mutated.
        df = df.assign(**{
            'Number of Training Samples per Person': int(num_members.replace("top", "")),
            'Accuracy': (df['tp'] + df['tn']) / (df['tp'] + df['tn'] + df['fp'] + df['fn']),
        })
        # Keep only the group with the largest number of attack samples (presumably 30) to
        # calculate mean and std; same rows, in the same order, as groupby(...).get_group(max key).
        in_largest_group = df['Number of Samples'] == df['Number of Samples'].max()
        rows.append(df[in_largest_group].set_index('Number of Training Samples per Person'))
    rows_per_model[model_name] = pd.concat(rows).rename(columns={
        'True Positive Rate': 'TPR',
        'False Negative Rate': 'FNR',
        'False Positive Rate': 'FPR',
        'True Negative Rate': 'TNR',
        'Accuracy': 'Acc'
    })
    display(rows_per_model[model_name].tail(3))

In [None]:
# One line plot per model: each metric (mean with SD band across the rows at each x value) against
# the number of training samples per person. Only the first model keeps its legend and y tick
# labels; the others get blank y labels.
metric_columns = ['TPR', 'FNR', 'FPR', 'TNR', 'Acc']
for model_name in model_names_list:
    metrics_df = rows_per_model[model_name][metric_columns]
    display(metrics_df.groupby('Number of Training Samples per Person').mean())

    # Explicit figure per model instead of clearing whichever figure happens to be current.
    fig, ax = plt.subplots()
    # NOTE(review): `ci` is deprecated since seaborn 0.12 (use errorbar='sd' there); kept for older seaborn.
    sns.lineplot(data=metrics_df, ci='sd', ax=ax)
    ax.set_xticks([int(x.replace("top", "")) for x in num_members_list][::-1])

    ax.set_xlabel("Number of Training Samples per Entity", weight="bold", size=16)
    ax.set_xticklabels(ax.get_xticks(), size=16)

    if model_name != model_names_list[0]:
        # remove the legend and blank the y axis labels on the other plots
        ax.legend_.remove()
        ax.set_ylabel(" ")
        ax.set_yticklabels([" " for _ in ax.get_yticklabels()], weight="bold", size=16)
    else:
        h, l = ax.get_legend_handles_labels()
        # tick_params sizes the labels without pinning a FixedFormatter that the
        # FormatStrFormatter below would replace anyway.
        ax.tick_params(axis='y', labelsize=16)
        ax.legend(h, l, ncol=2, loc='lower center', fontsize=16)
        ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    fig.tight_layout()
    fig.savefig(f'./cc3m_experiments/plots/num_training_samples_plot_CC2M_{model_name}.pdf')
    fig.savefig(f'./cc3m_experiments/plots/num_training_samples_plot_CC2M_{model_name}.png', dpi=100)
    print(model_name)
    plt.show()
    plt.close(fig)

# Heatmap: TPR by Number of Training Samples and Number of Attack Samples

In [None]:
# Re-read the sampled prediction CSVs from disk (rather than reusing subsampled_dfs_per_model)
# so this heatmap section starts from frames read straight from disk.
# Layout: model name -> training-set variant ("topN") -> DataFrame.
prediction_dfs_per_model = {model_name: {} for model_name in model_names_list}
for model_name, per_variant in prediction_dfs_per_model.items():
    for num_members in num_members_list:
        per_variant[num_members] = pd.read_csv(
            f'cc3m_experiments/prediction_dfs/sampled_predictions_CC2M_{model_name}_{num_members}.csv',
            index_col=0,
        )

In [None]:
# Spot-check one raw prediction frame; .head() keeps the saved notebook output small.
prediction_dfs_per_model['ViT-B-32']['top75'].head()

In [None]:
# Average every numeric column over the rows of each "Number of Samples" group, per training-set
# variant, and stack the variants into one frame per model. The group key is also copied into a
# column so the heatmap cell can pivot on it.
combined_df_per_model = {}
for model_name in model_names_list:
    combined_df = []
    for num_members, df in prediction_dfs_per_model[model_name].items():
        # numeric_only=True: pandas >= 2.0 raises TypeError on non-numeric columns otherwise;
        # this matches the pandas 1.x default of silently dropping them.
        new_df = df.groupby("Number of Samples").mean(numeric_only=True)
        combined_df.append(new_df.assign(**{
            'Number of Samples': new_df.index,
            'Number of Training Samples': int(num_members.replace("top", "")),
        }))
    combined_df_per_model[model_name] = pd.concat(combined_df)

In [None]:
# Sanity check: show the aggregated row for 1 attack sample and 1 training sample, per model.
for model_name in model_names_list:
    df = combined_df_per_model[model_name]
    single_sample_mask = df['Number of Samples'].eq(1) & df['Number of Training Samples'].eq(1)
    display(df[single_sample_mask])

In [None]:
# Commented-out variant: a single 1x3 figure (one heatmap per model) with a shared colorbar.
# The per-model heatmap cell below is what currently writes the heatmap plots.
# NOTE(review): if re-enabled, pass DataFrame.pivot arguments by keyword
# (index=..., columns=..., values=...); positional arguments are rejected by pandas >= 2.0.
# fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(13, 5), sharey=True)
# cbar_ax = fig.add_axes([.9, 0.16, .015, .79])
# for ax, (model_name, combined_df) in zip(axes.flat, combined_df_per_model.items()):
#     pivoted_df = combined_df.pivot("Number of Samples", "Number of Training Samples", 'True Positive Rate')
#     sns_ax = sns.heatmap(pivoted_df, yticklabels=2, vmin=0, vmax=0.8, cmap="Blues", cbar=model_name == model_names_list[-1], ax=ax, cbar_ax=cbar_ax)
#     sns_ax.invert_yaxis()
#     sns_ax.set_xlabel("Number of Training Samples", weight="bold")
#     sns_ax.set_ylabel("Number of Attack Samples", weight="bold")
#     if model_name != model_names_list[0]:
#         sns_ax.set(ylabel=None)
# fig.tight_layout(rect=[0, 0, .9, 1])
# fig.savefig(f'./cc3m_experiments/plots/heatmap_num_training_samples_num_samples_CC2M_combined.pdf')
# fig.savefig(f'./cc3m_experiments/plots/heatmap_num_training_samples_num_samples_CC2M_combined.png', dpi=100)

In [None]:
# One heatmap per model of the mean TPR by number of training samples (x) and number of attack
# samples (y). Only the last model draws a colorbar, so its figure is wider to make room for it;
# models after the first get blank y labels.
for model_name, combined_df in combined_df_per_model.items():
    is_first_model = model_name == model_names_list[0]
    is_last_model = model_name == model_names_list[-1]
    # Keyword arguments: DataFrame.pivot's arguments are keyword-only since pandas 2.0.
    pivoted_df = combined_df.pivot(
        index="Number of Samples",
        columns="Number of Training Samples",
        values="True Positive Rate",
    )
    # Exactly one figure per model, sized up front.
    fig, ax = plt.subplots(figsize=(6.2, 5) if is_last_model else (5, 5))
    ax = sns.heatmap(pivoted_df, yticklabels=2, vmin=0, vmax=0.8, cmap="Blues", cbar=is_last_model, ax=ax)
    if is_last_model:
        cbar_ax = ax.collections[0].colorbar.ax
        cbar_ax.set_ylabel('TPR', weight='bold', size=16)
        cbar_ax.tick_params(labelsize=16)
    ax.set_xlabel("Number of Training Samples", weight="bold", size=16)
    ax.set_ylabel("Number of Attack Samples", weight="bold", size=16)
    ax.set_yticklabels([x.get_text() for x in ax.get_yticklabels()], weight="bold", size=16)
    ax.set_xticklabels([x.get_text() for x in ax.get_xticklabels()], weight="bold", size=16)
    # Put the smallest number of attack samples at the bottom.
    ax.invert_yaxis()
    if not is_first_model:
        ax.set(yticklabels=[" " for _ in ax.get_yticklabels()], ylabel=" ")
    fig.tight_layout()
    fig.savefig(f'./cc3m_experiments/plots/heatmap_num_training_samples_num_samples_CC2M_{model_name}.pdf')
    fig.savefig(f'./cc3m_experiments/plots/heatmap_num_training_samples_num_samples_CC2M_{model_name}.png', dpi=100)
    print(model_name, "TPR")
    plt.show()
    plt.close(fig)