In [1]:
import pickle
import sys
import numpy as np

In [2]:
def print_cv_table(data_name, share, zero_style):
    """
    Read accumulated cross-validation results from a local pickle file and
    print them to stdout as one markdown table row.

    The results are loaded from
    ``../ppl/cv_accumu_<data_name>_<zero_style>_<share>.pkl`` (relative to the
    current working directory). The loaded object is indexed with the keys
    ``test_<group>_<metric>`` for group in (con, dem, ratio) and metric in
    (auc, accu, cor); each cell of the row is rendered as ``mean (std)``
    using ``np.mean`` / ``np.std`` with three decimals.

    :param data_name: the name of current cv dataset
    :type data_name: str
    :param share: the % attention heads to be changed
    :type share: int
    :param zero_style: the style of zeroing attn heads, supporting 'random','first' and 'shuffle'
    :type zero_style: str
    :raises FileNotFoundError: if the pickle file for this combination does not exist
    :raises KeyError: if the loaded results lack one of the expected keys
    """
    pkl_file = "../ppl/cv_accumu_{}_{}_{}.pkl".format(data_name,
                                                   zero_style,
                                                   share)
    # NOTE(review): pickle.load can execute arbitrary code while unpickling;
    # only point this at result files produced by this project.
    with open(pkl_file, "rb") as handle:
        cv_dict = pickle.load(handle)
    # first column of the row: dataset-share-style
    row_label = data_name+"-"+str(share)+"-"+zero_style
    # Build every "mean (std)" cell with the SAME precision; the original
    # formatted the ratio accuracy mean with two decimals, unlike every
    # other column.
    cells = []
    for group in ("con", "dem", "ratio"):
        for metric in ("auc", "accu", "cor"):
            scores = cv_dict["test_{}_{}".format(group, metric)]
            cells.append("{:0.3f} ({:0.3f})".format(np.mean(scores),
                                                    np.std(scores)))
    # per group the cell spacing is " auc| accu | cor|", repeated three times
    row_template = "| {} |" + " {}| {} | {}|" * 3 + "\n"
    sys.stdout.write(row_template.format(row_label, *cells))

In [3]:
# zeroing attn weight matrix only
zero_style = "first"
# markdown table header and separator; every print_cv_table call below writes one data row
sys.stdout.write("| dataset | con AUC (SD)| con ACC (SD) | con r with MMSE (SD)| dem AUC (SD)| dem ACC (SD) | dem r with MMSE (SD)| ratio AUC (SD)| ratio ACC (SD) | ratio r with MMSE (SD)|\n")
sys.stdout.write("| - | - | - | - | - | - | - | - | - | - |\n")
for share in (25, 50, 75, 100):
    # "db_c" (not "db") is the dataset name that matches the rows recorded in
    # this cell's output and the name used by the "weight + bias" cell.
    for data_name in ("adr", "db_c", "ccc"):
        print_cv_table(data_name, share, zero_style)

| dataset | con AUC (SD)| con ACC (SD) | con r with MMSE (SD)| dem AUC (SD)| dem ACC (SD) | dem r with MMSE (SD)| ratio AUC (SD)| ratio ACC (SD) | ratio r with MMSE (SD)|
| - | - | - | - | - | - | - | - | - | - |
| adr-25-first | 0.614 (0.054)| 0.587 (0.061) | -0.247 (0.189)| 0.416 (0.061)| 0.413 (0.071) | 0.061 (0.243)| 0.769 (0.099)| 0.69 (0.126) | -0.509 (0.091)|
| db_c-25-first | 0.663 (0.016)| 0.599 (0.033) | -0.205 (0.158)| 0.419 (0.071)| 0.420 (0.051) | 0.090 (0.139)| 0.789 (0.056)| 0.73 (0.070) | -0.434 (0.091)|
| ccc-25-first | 0.734 (0.057)| 0.649 (0.040) | nan (nan)| 0.647 (0.064)| 0.628 (0.066) | nan (nan)| 0.712 (0.047)| 0.68 (0.038) | nan (nan)|
| adr-50-first | 0.614 (0.095)| 0.547 (0.102) | -0.247 (0.194)| 0.367 (0.110)| 0.413 (0.103) | 0.151 (0.267)| 0.794 (0.040)| 0.72 (0.043) | -0.567 (0.070)|
| db_c-50-first | 0.653 (0.081)| 0.606 (0.077) | -0.193 (0.144)| 0.398 (0.077)| 0.431 (0.074) | 0.140 (0.096)| 0.805 (0.068)| 0.71 (0.074) | -0.424 (0.118)|
| ccc-50-first | 0.

In [3]:
# zeroing attn weight + bias
zero_style = "first"
# emit the markdown header row and the separator row in one call
sys.stdout.writelines([
    "| dataset | con AUC (SD)| con ACC (SD) | con r with MMSE (SD)| dem AUC (SD)| dem ACC (SD) | dem r with MMSE (SD)| ratio AUC (SD)| ratio ACC (SD) | ratio r with MMSE (SD)|\n",
    "| - | - | - | - | - | - | - | - | - | - |\n",
])
# one table row per (share, dataset) pair, datasets in the order ccc, adr, db_c
for share in (25, 50, 75, 100):
    for data_name in ("ccc", "adr", "db_c"):
        print_cv_table(data_name, share, zero_style)

| dataset | con AUC (SD)| con ACC (SD) | con r with MMSE (SD)| dem AUC (SD)| dem ACC (SD) | dem r with MMSE (SD)| ratio AUC (SD)| ratio ACC (SD) | ratio r with MMSE (SD)|
| - | - | - | - | - | - | - | - | - | - |
| ccc-25-first | 0.581 (0.095)| 0.525 (0.057) | nan (nan)| 0.468 (0.107)| 0.418 (0.184) | nan (nan)| 0.714 (0.054)| 0.68 (0.035) | nan (nan)|
| adr-25-first | 0.622 (0.072)| 0.593 (0.117) | -0.226 (0.212)| 0.357 (0.090)| 0.387 (0.042) | 0.167 (0.160)| 0.799 (0.072)| 0.71 (0.122) | -0.528 (0.138)|
| db_c-25-first | 0.659 (0.090)| 0.643 (0.089) | -0.192 (0.118)| 0.367 (0.052)| 0.390 (0.051) | 0.165 (0.130)| 0.794 (0.054)| 0.74 (0.052) | -0.413 (0.164)|
| ccc-50-first | 0.609 (0.077)| 0.589 (0.031) | nan (nan)| 0.473 (0.084)| 0.468 (0.132) | nan (nan)| 0.720 (0.038)| 0.64 (0.092) | nan (nan)|
| adr-50-first | 0.632 (0.044)| 0.613 (0.081) | -0.199 (0.231)| 0.361 (0.063)| 0.464 (0.092) | 0.169 (0.282)| 0.792 (0.024)| 0.68 (0.025) | -0.511 (0.046)|
| db_c-50-first | 0.664 (0.072)| 0