In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display

# Notebook display settings: show up to 180 rows / 75 columns of a DataFrame,
# and print NumPy arrays with 4 digits after the decimal point.
pd.set_option('display.max_rows', 180)
pd.set_option('display.max_columns', 75)
np.set_printoptions(precision=4)

def filter(df: pd.DataFrame, col: str, values: list[float], tol: float = 1e-6) -> pd.DataFrame:
  """Return the rows of ``df`` whose ``col`` value lies within ``tol`` of any of ``values``.

  A tolerance match is used instead of ``==`` because target grids built with
  ``np.arange`` may differ from stored values in the last floating-point bits.
  The test is strict (``|x - v| < tol``); NaN / missing values never match.
  The row index and the full column set of ``df`` are preserved, also when
  ``df`` is empty.

  Args:
    df: frame to filter.
    col: name of a numeric column of ``df``.
    values: target values; an empty list matches no row.
    tol: absolute tolerance.

  Returns:
    The matching rows of ``df`` in their original order.
  """
  column = df[col].to_numpy(dtype=float, na_value=np.nan)
  targets = np.asarray(values, dtype=float).reshape(-1)
  # (len(df), len(values)) table of distances; keep a row if any target is close.
  # A NumPy boolean mask always selects rows, even for an empty frame (a non-bool
  # empty Series would instead be read as a list of column labels).
  mask = (np.abs(column[:, None] - targets[None, :]) < tol).any(axis=1)
  return df[mask]

# Attack intensities evaluated for each attack. Each grid is np.arange starting
# at 0 with a stop half a step past the last wanted value, so that value is kept.
name2ratios = {
  attack: list(np.arange(0, stop, step))
  for attack, stop, step in (
    ('gauss_noise', 0.055, 0.01),
    ('alteration', 0.055, 0.01),
    ('sample_deletion', 0.55, 0.1),
    ('noise_deletion', 0.055, 0.01),
  )
}

# Experimental settings kept when filtering results: each list holds the
# allowed values of one result column.
all_datanames = ['beijing']
all_num_users = [1000]
all_watermark = ['pair_compare_one_pair']
all_num_classes = [256]
all_num_watermark_bits = [32]
all_min_hamming_dist = [7]
all_ratio_num_samples_per_class_interval = [-30000]
all_quality_loss = ['quad_random_init985']
all_embedding_models = ['orig']
all_quality_modes = ['average']
all_gen_code_losses = ['general_bfs']
all_tao_approximations = [0]
all_time_limits = [180]
all_error_rates = [0.001]
all_deletion_rates = [0]
all_dim_ratios = ['correct-pca-0.99']
all_upper_bounds = ['6stage_splus-0.01-0.01']
def f(name: str) -> pd.DataFrame:
  """Load, filter and aggregate the result rows for one attack.

  Reads ``../results_final/{name}.json`` (JSON Lines), then:
  keeps rows whose attack-ratio column matches ``name2ratios[name]``; for a
  single attack keeps rows where the two other attack knobs are zero, and for
  ``'noise_deletion'`` rows where ``gauss == alter`` and ``deletion_rate`` is
  10x ``alter``; restricts the experimental settings to the ``all_*`` lists;
  and aggregates ``correct`` (count, mean) and square-rooted ``quad_loss``
  (mean) per configuration over at most its first 100 rows.

  Args:
    name: one of ``'gauss_noise'``, ``'alteration'``, ``'sample_deletion'``,
      ``'noise_deletion'``.

  Returns:
    The aggregated frame with the group keys reset into columns. Its columns
    form a two-level index such as ``('correct', 'mean')``, and float columns
    are rounded to 3 decimals. The indexed frame is also shown via ``display``.

  Raises:
    KeyError: if ``name`` is not one of the four attacks.
  """
  df_ours = pd.read_json(f'../results_final/{name}.json', lines=True)
  # Column holding this attack's ratio (unknown names raise KeyError here).
  ratio_col = {
    'gauss_noise': 'gauss_noise_ratio',
    'alteration': 'alteration_ratio',
    'sample_deletion': 'sample_deletion_ratio',
    'noise_deletion': 'noise_ratio',
  }[name]
  df_ours = filter(df_ours, ratio_col, name2ratios[name])
  if name != 'noise_deletion':
    # Single attack: the two attack knobs not under test must both be zero.
    knob = {'gauss_noise': 'gauss', 'alteration': 'alter', 'sample_deletion': 'deletion_rate'}[name]
    other_a, other_b = [k for k in ('gauss', 'alter', 'deletion_rate') if k != knob]
    df_ours = df_ours[(df_ours[other_a] == 0) & (df_ours[other_b] == 0)]
  else:
    # Mixed attack: gauss equals alter, and deletion_rate is 10x alter (within 1e-3).
    df_ours = df_ours[(df_ours['gauss'] == df_ours['alter']) & ((df_ours['deletion_rate'] - 10 * df_ours['alter']).abs() < 1e-3)]
  # .copy(): this column subset is taken from already-filtered frames; assigning
  # into it below would otherwise trigger pandas' SettingWithCopyWarning.
  df_ours = df_ours[['dataname', 'num_users', 'watermark', 'num_classes', 'num_samples', 'num_watermark_bits', 'min_hamming_dist', 'ratio_num_samples_per_class_interval', 'classifier', 'quality_loss', 'embedding_model', 'correct', 'tao_approximation', 'num_tested_samples_per_class', 'loss', 'quad_loss', 'time_limit', 'min_gap', 'error_rate', 'gap', 'deletion_rate', ratio_col, 'dim_ratio', 'gen_code_loss', 'num_samples_per_class_upper_bound', 'gauss', 'alter']].copy()
  df_ours['quad_loss'] = df_ours['quad_loss'] ** 0.5
  # Only the first four characters of the model name are compared with all_embedding_models.
  df_ours['embedding_model'] = df_ours['embedding_model'].apply(lambda x: x[:4])

  # Keep only the experimental settings selected by the all_* lists.
  allowed_values = {
    'embedding_model': all_embedding_models,
    'time_limit': all_time_limits,
    'dataname': all_datanames,
    'num_users': all_num_users,
    'watermark': all_watermark,
    'num_classes': all_num_classes,
    'num_watermark_bits': all_num_watermark_bits,
    'min_hamming_dist': all_min_hamming_dist,
    'ratio_num_samples_per_class_interval': all_ratio_num_samples_per_class_interval,
    'tao_approximation': all_tao_approximations,
    'error_rate': all_error_rates,
    'quality_loss': all_quality_loss,
    'dim_ratio': all_dim_ratios,
    'num_samples_per_class_upper_bound': all_upper_bounds,
    'gen_code_loss': all_gen_code_losses,
  }
  for col, values in allowed_values.items():
    df_ours = df_ours[df_ours[col].isin(values)]
  df_ours = df_ours[df_ours['classifier'] == 'nn']
  # Applied last so the ordered comparison only sees rows that passed every other filter.
  df_ours = df_ours[df_ours['num_samples'] < 50000]

  df_ours = df_ours.rename(columns={
    'ratio_num_samples_per_class_interval': 'ratio',
    'num_watermark_bits': 'bits',
    'min_hamming_dist': 'hamming_distance'
  })

  # One group per configuration; note 'ratio' (renamed column) differs from ratio_col.
  group_keys = ['num_users', 'bits', 'hamming_distance', 'ratio', 'gen_code_loss', 'error_rate', 'dataname', 'gauss', 'alter', 'deletion_rate', ratio_col]
  # Cap each configuration at its first 100 rows, then aggregate.
  df_ours = df_ours.groupby(by=group_keys).head(100).groupby(by=group_keys)[['correct', 'quad_loss']].agg({
    'correct': ['count', 'mean'],
    'quad_loss': 'mean',
  })
  for col in df_ours.columns:
    if df_ours[col].dtype in [np.float32, np.float64, float]:
      df_ours[col] = df_ours[col].round(3)
  display(f'our_{name}', df_ours)
  df_ours = df_ours.reset_index()
  return df_ours

# One colour per intensity curve (indexed by position in name2ratios[name]),
# going from the lightest to the darkest blue.
colors = ['#D4EBFF', '#A8D3FF', '#7CBAFF', '#4F9EFF', '#2A75E6', '#0A3D8C']
# One marker per intensity curve, indexed the same way as `colors`.
markers = list('osD^x*')
# Captions printed under the four subplots.
labels = [f'({letter})' for letter in 'abcd']

# Per attack: the aggregated ('quad_loss', 'mean') value for each intensity in
# name2ratios[name], appended in plotting order.
name2rmse = {'gauss_noise': [], 'alteration': [], 'sample_deletion': [], 'noise_deletion': []}

# Column of f(name) holding the applied intensity that selects one curve.
intensity_cols = {'gauss_noise': 'gauss', 'alteration': 'alter', 'sample_deletion': 'deletion_rate', 'noise_deletion': 'gauss'}
# Column of f(name) plotted on the x axis.
ratio_cols = {'gauss_noise': 'gauss_noise_ratio', 'alteration': 'alteration_ratio', 'sample_deletion': 'sample_deletion_ratio', 'noise_deletion': 'noise_ratio'}
# Mathtext symbol used in the legend entries.
legend_symbols = {'gauss_noise': 'I_{per}', 'alteration': 'I_{alt}', 'sample_deletion': 'I_{del}', 'noise_deletion': 'I_{mix}'}

# Create exactly one figure: plt.subplots opens its own figure, so a preceding
# plt.figure(dpi=...) would only leave an empty extra figure behind. For a
# high-resolution export, pass dpi= to fig.savefig instead.
fig, axs = plt.subplots(2, 2, figsize=(12, 9))
fig.subplots_adjust(wspace=0.25, hspace=0.3)
for i, name in enumerate(['gauss_noise', 'alteration', 'sample_deletion', 'noise_deletion']):
  ax = axs[i // 2][i % 2]
  ratios = name2ratios[name]
  # Always show a tick at 0, without duplicating it when the grid already starts there.
  ticks = ratios if 0 in ratios else [0] + ratios
  ax.set_xticks(ticks)
  ax.set_xticklabels([f'{t:g}' for t in ticks], fontsize=17)
  ax.set_xlabel('Actual Intensity', fontsize=18)
  ax.set_ylabel('Traceability Accuracy', fontsize=18)
  ax.set_yticks([0, 0.5, 1])
  ax.set_yticklabels(['0', '0.5', '1.0'], fontsize=17)
  ax.set_ylim(-0.1, 1.1)

  df_ours = f(name)
  intensity_col = intensity_cols[name]
  ratio_col = ratio_cols[name]
  for j, ratio in enumerate(ratios):
    # Tolerance match (same 1e-6 as `filter`) instead of `==`: np.arange grid
    # values can differ from the stored intensities in the last bits.
    selected = (df_ours[intensity_col] - ratio).abs() < 1e-6
    x = df_ours.loc[selected, ratio_col].tolist()
    y = df_ours.loc[selected, ('correct', 'mean')].tolist()
    ax.plot(x, y, label=f'${legend_symbols[name]}: {ratio:.2f}$', marker=markers[j], markersize=6, linewidth=1, color=colors[j], markerfacecolor=colors[j])
    name2rmse[name].append(df_ours.loc[selected, ('quad_loss', 'mean')].tolist()[0])

  ax.legend(fontsize=16, loc='lower left')
  ax.text(0.5, -0.21, labels[i], transform=ax.transAxes,
          fontsize=16, fontweight='bold', ha='center', va='top')

plt.show()
display(name2rmse)
