In [1]:
import pandas as pd
import numpy as np
from scipy.stats import ks_2samp, ttest_ind
import pickle

# Metric columns summed per query sample by get_member_ratio.
metrics = ['correctness', 'confidence', 'entropy']
# k values swept in the experiment cells (result tables label rows as k=100-k).
klist = list(range(0, 51, 10))

# NOTE(review): not referenced by any of the cells visible here.
mixup = False


In [2]:
def get_member_ratio(datadf, thre_cnt=1, skip=(), mode='ks', metric_cols=None, ref_size=10000):
    """Fraction of query samples flagged as members, plus a p-value against "all members".

    A sample (row) is flagged as a member when the sum of its metric columns,
    excluding those named in ``skip``, is at least ``thre_cnt``. The resulting
    boolean vector is then compared with an all-ones reference sample.

    Parameters
    ----------
    datadf : pandas.DataFrame
        Query set; must contain every column of ``metric_cols`` not in ``skip``.
    thre_cnt : int or float, default 1
        Minimum per-row metric sum for a row to count as a member.
    skip : iterable of str, default ()
        Metric names to leave out of the per-row sum. Names that are not in
        ``metric_cols`` are ignored.
    mode : {'ks', 't'}, default 'ks'
        'ks' uses ``scipy.stats.ks_2samp`` against ``ref_size`` ones;
        't' uses ``scipy.stats.ttest_ind`` (pooled variance) against as many
        ones as there are rows.
    metric_cols : sequence of str, optional
        Metric column names; defaults to the module-level ``metrics`` list.
        Duplicates are ignored.
    ref_size : int, default 10000
        Length of the all-ones reference sample in 'ks' mode.

    Returns
    -------
    tuple
        ``(ratio, pvalue)`` where ``ratio = n_members / len(datadf)``. In 't'
        mode, when every row is flagged, ``(1, 1)`` is returned directly (the
        t-test would compare two zero-variance samples).

    Raises
    ------
    ValueError
        If ``mode`` is not 'ks' or 't'.
    """
    if mode not in ('ks', 't'):
        raise ValueError(f"mode must be 'ks' or 't', got {mode!r}")
    if metric_cols is None:
        metric_cols = metrics

    skip_set = set(skip)
    # dict.fromkeys de-duplicates while keeping a deterministic column order.
    cols = [m for m in dict.fromkeys(metric_cols) if m not in skip_set]

    # skipna=False mirrors np.sum: a NaN metric makes the row sum NaN, and
    # NaN >= thre_cnt is False, so such rows are not flagged as members.
    sample2mem = datadf[cols].sum(axis=1, skipna=False).to_numpy()
    members_bool = sample2mem >= thre_cnt

    if mode == 'ks':
        _, pvalue = ks_2samp(members_bool, np.ones(ref_size), mode='asymp')
    else:  # mode == 't'
        if len(members_bool) > 0 and members_bool.all():
            return 1, 1
        _, pvalue = ttest_ind(members_bool, np.ones(len(members_bool)), equal_var=True, nan_policy='raise')

    # NOTE: an empty datadf yields nan here (0 / 0), with a numpy RuntimeWarning.
    return np.sum(members_bool) / len(datadf), pvalue

In [3]:
### Run EMA on the benchmark datasets, Table 2.b and Table 6
querydata = 'MNIST'
epoch = 50
EMA_columns = ['M1', 'M2', 'M3', 'M4', 'M5', 'M6', 'SVHN']
EMA_rows = [f'k={100-k}' for k in klist]
# Fold 0 is reported in the SVHN column and folds 1-6 in M1-M6.
# Folds 0 and 6 are the "query is not in base" cases; folds 1-5 are "query is in base".
fold_to_column = {0: 'SVHN', **{fold: f'M{fold}' for fold in range(1, 7)}}

for caldata in ['MNIST']:
    # The query files depend only on (k, fold), not on query size or test mode,
    # so each one is read once instead of once per query size.
    querysets = {}
    for k in klist:
        logname = f'caldata={caldata}_epoch={epoch}_k={k}_calsize=10000'
        for fold in range(7):
            querysets[k, fold] = pd.read_csv(
                f'./saves_new/EMA_{querydata}/query_set/binarized_{logname}_fold{fold}.csv')

    for size in [2000, 500, 200, 50, 20, 5]:
        for mode in ['t']:
            print(f'-----  Query_size:{size} \t Cal data:{caldata}  -----')
            # Fresh (object-dtype) table for every (size, mode) combination.
            EMA_res_table = pd.DataFrame(columns=EMA_columns, index=EMA_rows)
            for (k, fold), queryset in querysets.items():
                _, pv = get_member_ratio(queryset.iloc[:size], skip=['modified entropy'], mode=mode)
                # .loc writes into the table itself; chained indexing
                # (table[col][row] = v) may write to a temporary copy instead.
                EMA_res_table.loc[f'k={100-k}', fold_to_column[fold]] = np.around(pv, decimals=2)
            print(EMA_res_table, '\n')

-----  Query_size:2000 	 Cal data:MNIST  -----
      M1 M2 M3 M4 M5   M6 SVHN
k=100  1  1  1  1  1  0.0  0.0
k=90   1  1  1  1  1  0.0  0.0
k=80   1  1  1  1  1  0.0  0.0
k=70   1  1  1  1  1  0.0  0.0
k=60   1  1  1  1  1  0.0  0.0
k=50   1  1  1  1  1  0.0  0.0 

-----  Query_size:500 	 Cal data:MNIST  -----
      M1 M2 M3 M4 M5   M6 SVHN
k=100  1  1  1  1  1  0.0  0.0
k=90   1  1  1  1  1  0.0  0.0
k=80   1  1  1  1  1  0.0  0.0
k=70   1  1  1  1  1  0.0  0.0
k=60   1  1  1  1  1  0.0  0.0
k=50   1  1  1  1  1  0.0  0.0 

-----  Query_size:200 	 Cal data:MNIST  -----
      M1 M2 M3 M4 M5    M6 SVHN
k=100  1  1  1  1  1  0.01  0.0
k=90   1  1  1  1  1   0.0  0.0
k=80   1  1  1  1  1  0.01  0.0
k=70   1  1  1  1  1  0.08  0.0
k=60   1  1  1  1  1  0.32  0.0
k=50   1  1  1  1  1  0.01  0.0 

-----  Query_size:50 	 Cal data:MNIST  -----
      M1 M2 M3 M4 M5    M6 SVHN
k=100  1  1  1  1  1  0.16  0.0
k=90   1  1  1  1  1  0.16  0.0
k=80   1  1  1  1  1  0.32  0.0
k=70   1  1  1  1  1    

In [4]:
### Run EMA on chest X-ray datasets, Table 3.b and Table 7
querydata = 'COVIDx'
epoch = 30
EMA_columns = ['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'CXR']
EMA_rows = [f'k={100-k}' for k in klist]
# Fold 0 is reported in the CXR column and folds 1-6 in C1-C6.
# Folds 0 and 6 are the "query is not in base" cases; folds 1-5 are "query is in base".
fold_to_column = {0: 'CXR', **{fold: f'C{fold}' for fold in range(1, 7)}}

for caldata in ['COVIDx']:
    # The query files depend only on (k, fold), not on query size or test mode,
    # so each one is read once instead of once per query size.
    querysets = {}
    for k in klist:
        logname = f'caldata={caldata}_epoch={epoch}_k={k}_calsize=4000'
        for fold in range(7):
            querysets[k, fold] = pd.read_csv(
                f'./saves_new/EMA_{querydata}/query_set/binarized_{logname}_fold{fold}.csv')

    for size in [800, 500, 200, 50, 20, 5]:
        print(f'-----  Query_size:{size} \t Cal data:{caldata}  -----')
        for mode in ['t']:
            print(f'Using {mode}-test')
            # Fresh (object-dtype) table for every (size, mode) combination.
            EMA_res_table = pd.DataFrame(columns=EMA_columns, index=EMA_rows)
            for (k, fold), queryset in querysets.items():
                _, pv = get_member_ratio(queryset.iloc[:size], skip=['modified entropy'], mode=mode)
                # .loc writes into the table itself; chained indexing
                # (table[col][row] = v) may write to a temporary copy instead.
                EMA_res_table.loc[f'k={100-k}', fold_to_column[fold]] = np.around(pv, decimals=2)
            print(EMA_res_table, '\n')

-----  Query_size:800 	 Cal data:COVIDx  -----
Using t-test
      C1 C2 C3 C4 C5   C6  CXR
k=100  1  1  1  1  1  0.0  0.0
k=90   1  1  1  1  1  0.0  0.0
k=80   1  1  1  1  1  0.0  0.0
k=70   1  1  1  1  1  0.0  0.0
k=60   1  1  1  1  1  0.0  0.0
k=50   1  1  1  1  1  0.0  0.0 

-----  Query_size:500 	 Cal data:COVIDx  -----
Using t-test
      C1 C2 C3 C4 C5   C6  CXR
k=100  1  1  1  1  1  0.0  0.0
k=90   1  1  1  1  1  0.0  0.0
k=80   1  1  1  1  1  0.0  0.0
k=70   1  1  1  1  1  0.0  0.0
k=60   1  1  1  1  1  0.0  0.0
k=50   1  1  1  1  1  0.0  0.0 

-----  Query_size:200 	 Cal data:COVIDx  -----
Using t-test
      C1 C2 C3 C4 C5   C6  CXR
k=100  1  1  1  1  1  0.0  0.0
k=90   1  1  1  1  1  0.0  0.0
k=80   1  1  1  1  1  0.0  0.0
k=70   1  1  1  1  1  0.0  0.0
k=60   1  1  1  1  1  0.0  0.0
k=50   1  1  1  1  1  0.0  0.0 

-----  Query_size:50 	 Cal data:COVIDx  -----
Using t-test
      C1 C2 C3 C4 C5    C6   CXR
k=100  1  1  1  1  1  0.01  0.04
k=90   1  1  1  1  1   0.0   0.0
k=80 