In [1]:
import numpy as np
import pandas as pd
import pickle
import os
from tqdm import tqdm
from MINE.gtm import GTM

In [2]:
def load_neighbours(path, n_seeds=50):
    """Index the neighbour-experiment result files found in ``path``.

    Each file name is split on ``_``; the last three fields are ignored and
    the remaining six are read as ``gamma, n, c, h, k, seed``.

    Only the ``n_seeds`` seeds that own the most files are kept. A short
    completeness report is printed: the number of files per ``n``, the total
    number of files missing with respect to ``n_seeds`` files per
    ``(gamma, n, c, h, k)`` group, and the index of the incomplete groups.

    Parameters
    ----------
    path : str
        Directory that holds the result files. A trailing separator is
        optional because paths are built with ``os.path.join``.
    n_seeds : int, default 50
        Number of seeds to keep, also used as the expected number of files
        per parameter group in the completeness report.

    Returns
    -------
    pandas.DataFrame
        Columns ``gamma`` (float), ``n``, ``c``, ``h``, ``k``, ``seed`` (int)
        and ``file`` (path to the file), sorted by ``seed``.
    """
    rows = [
        name.split('_')[:-3] + [os.path.join(path, name)]
        for name in os.listdir(path)
    ]
    files = pd.DataFrame(
        rows, columns=['gamma', 'n', 'c', 'h', 'k', 'seed', 'file']
    ).astype({'gamma': float, 'n': int, 'c': int, 'h': int, 'k': int, 'seed': int})

    # Keep the seeds with the largest number of result files.
    seeds = files.groupby('seed').size().sort_values(ascending=False).head(n_seeds).index
    files = files.loc[files.seed.isin(seeds)]

    print(files.groupby('n').size())

    # Completeness report: how many files each incomplete group is short of n_seeds.
    sizes = files.groupby(['gamma', 'n', 'c', 'h', 'k']).size()
    sizes = n_seeds - sizes.loc[sizes != n_seeds]
    print(sizes.sum())
    print(sizes.index)

    return files.sort_values('seed')

def load_diffs(path, n_seeds=50):
    """Index the diff-experiment result files found in ``path``.

    Each file name is split on ``_``; the last three fields are ignored and
    the remaining five are read as ``gamma, n, c, h, seed``.

    Only the ``n_seeds`` seeds that own the most files are kept. A short
    completeness report is printed: the number of files per ``n``, the total
    number of files missing with respect to ``n_seeds`` files per
    ``(gamma, n, c, h)`` group, and the index of the incomplete groups.

    Parameters
    ----------
    path : str
        Directory that holds the result files. A trailing separator is
        optional because paths are built with ``os.path.join``.
    n_seeds : int, default 50
        Number of seeds to keep, also used as the expected number of files
        per parameter group in the completeness report.

    Returns
    -------
    pandas.DataFrame
        Columns ``gamma`` (float), ``n``, ``c``, ``h``, ``seed`` (int) and
        ``file`` (path to the file), sorted by ``seed``.
    """
    rows = [
        name.split('_')[:-3] + [os.path.join(path, name)]
        for name in os.listdir(path)
    ]
    files = pd.DataFrame(
        rows, columns=['gamma', 'n', 'c', 'h', 'seed', 'file']
    ).astype({'gamma': float, 'n': int, 'c': int, 'h': int, 'seed': int})

    # Keep the seeds with the largest number of result files.
    seeds = files.groupby('seed').size().sort_values(ascending=False).head(n_seeds).index
    files = files.loc[files.seed.isin(seeds)]

    print(files.groupby('n').size())

    # Completeness report: how many files each incomplete group is short of n_seeds.
    sizes = files.groupby(['gamma', 'n', 'c', 'h']).size()
    sizes = n_seeds - sizes.loc[sizes != n_seeds]
    print(sizes.sum())
    print(sizes.index)

    return files.sort_values('seed')

def load_selection(path, expected_per_group=50):
    """Index the selection-experiment result files found in ``path``.

    Each file name is split on ``_``; the first field and the last three
    fields are ignored and the remaining four are read as
    ``gamma, n, h, seed``.

    A completeness report is printed: the summed difference between
    ``expected_per_group`` and the actual number of files of every
    ``(gamma, n, h)`` group whose size differs from it (negative when groups
    hold more files than expected), followed by the index of those groups.

    Parameters
    ----------
    path : str
        Directory that holds the result files. A trailing separator is
        optional because paths are built with ``os.path.join``.
    expected_per_group : int, default 50
        Number of files each ``(gamma, n, h)`` group is expected to contain.

    Returns
    -------
    pandas.DataFrame
        Columns ``gamma`` (float), ``n``, ``h``, ``seed`` (int) and ``file``
        (path to the file), sorted by ``seed``.
    """
    rows = [
        name.split('_')[1:-3] + [os.path.join(path, name)]
        for name in os.listdir(path)
    ]
    files = pd.DataFrame(
        rows, columns=['gamma', 'n', 'h', 'seed', 'file']
    ).astype({'gamma': float, 'n': int, 'h': int, 'seed': int})

    sizes = files.groupby(['gamma', 'n', 'h']).size()
    sizes = expected_per_group - sizes.loc[sizes != expected_per_group]
    print(sizes.sum())
    print(sizes.index)

    return files.sort_values('seed')

def load_initial(path):
    """Index the initial-experiment result files found in ``path``.

    Each file name is split on ``_`` and its first two fields are read as
    ``n`` and ``seed``.

    Parameters
    ----------
    path : str
        Directory that holds the result files. A trailing separator is
        optional because paths are built with ``os.path.join``.

    Returns
    -------
    pandas.DataFrame
        Columns ``n`` (int), ``seed`` (int) and ``file`` (path to the file),
        sorted by ``seed``.
    """
    rows = [
        name.split('_')[:2] + [os.path.join(path, name)]
        for name in os.listdir(path)
    ]
    files = pd.DataFrame(rows, columns=['n', 'seed', 'file']).astype(
        {'n': int, 'seed': int}
    )

    return files.sort_values('seed')

# Loading files

In [3]:
# Index the neighbour result files under results/neighbours_final/
# (load_neighbours also prints a completeness report).
files_neighbours = load_neighbours('results/neighbours_final/')

n
1000     15000
5000     15000
10000    15000
dtype: int64
0
MultiIndex([], names=['gamma', 'n', 'c', 'h', 'k'])


In [4]:
# Index the diff result files under results/diff/
# (load_diffs also prints a completeness report).
files_diffs = load_diffs('results/diff/')

n
1000     1500
5000     1500
10000    1500
dtype: int64
0
MultiIndex([], names=['gamma', 'n', 'c', 'h'])


In [5]:
# Index the neighbour result files of the second result directory, results2/.
files_neighbours2 = load_neighbours('results2/neighbours_final/')

n
1000     15000
5000     15000
10000    15000
dtype: int64
0
MultiIndex([], names=['gamma', 'n', 'c', 'h', 'k'])


In [6]:
# Index the diff result files of the second result directory, results2/.
files_diffs2 = load_diffs('results2/diff/')

n
1000     1500
5000     1500
10000    1500
dtype: int64
0
MultiIndex([], names=['gamma', 'n', 'c', 'h'])


In [7]:
# Merge the file indexes of both result directories and order them by seed.
# drop_duplicates(subset='file') keeps this cell idempotent: without it,
# re-running the cell would append the second index again and every file of
# results2/ would be aggregated twice. On a first run no path repeats, so
# nothing is dropped.
files_neighbours = (
    pd.concat((files_neighbours, files_neighbours2))
    .drop_duplicates(subset='file')
    .sort_values('seed')
)

files_diffs = (
    pd.concat((files_diffs, files_diffs2))
    .drop_duplicates(subset='file')
    .sort_values('seed')
)

In [8]:
# Inspect the merged neighbour file index.
files_neighbours

Unnamed: 0,gamma,n,c,h,k,seed,file
27494,0.75,5000,2,64,9,419578,results/neighbours_final/0.75_5000_2_64_9_4195...
12108,0.60,5000,2,64,2,419578,results/neighbours_final/0.6_5000_2_64_2_41957...
12169,0.60,5000,2,64,3,419578,results/neighbours_final/0.6_5000_2_64_3_41957...
8850,0.60,1000,6,64,4,419578,results/neighbours_final/0.6_1000_6_64_4_41957...
12232,0.60,5000,2,64,4,419578,results/neighbours_final/0.6_5000_2_64_4_41957...
...,...,...,...,...,...,...,...
43721,0.90,5000,5,64,3,997480673,results/neighbours_final/0.9_5000_5_64_3_99748...
43671,0.90,5000,5,64,2,997480673,results/neighbours_final/0.9_5000_5_64_2_99748...
42521,0.90,5000,2,64,9,997480673,results/neighbours_final/0.9_5000_2_64_9_99748...
35721,0.90,10000,9,64,3,997480673,results/neighbours_final/0.9_10000_9_64_3_9974...


In [9]:
# Inspect the merged diff file index.
files_diffs

Unnamed: 0,gamma,n,c,h,seed,file
1322,0.60,5000,6,64,419578,results/diff/0.6_5000_6_64_419578_2022-10-06T1...
672,0.60,1000,3,64,419578,results/diff/0.6_1000_3_64_419578_2022-10-06T1...
2822,0.75,5000,6,64,419578,results/diff/0.75_5000_6_64_419578_2022-10-06T...
3722,0.90,1000,4,64,419578,results/diff/0.9_1000_4_64_419578_2022-10-06T1...
2772,0.75,5000,5,64,419578,results/diff/0.75_5000_5_64_419578_2022-10-06T...
...,...,...,...,...,...,...
4199,0.90,5000,3,64,997480673,results/diff/0.9_5000_3_64_997480673_2022-10-0...
4249,0.90,5000,4,64,997480673,results/diff/0.9_5000_4_64_997480673_2022-10-0...
4299,0.90,5000,5,64,997480673,results/diff/0.9_5000_5_64_997480673_2022-10-0...
3349,0.90,10000,6,64,997480673,results/diff/0.9_10000_6_64_997480673_2022-10-...


In [3]:
# Index the selection result files under results_selection/.
# NOTE(review): load_selection's completeness check assumes 50 files per
# (gamma, n, h) group; a negative printed total means the groups hold more
# files than that — confirm the expected count for this experiment.
files_sel = load_selection('results_selection/')

-1350
MultiIndex([( 0.6,  1000, 64),
            ( 0.6,  5000, 64),
            ( 0.6, 10000, 64),
            (0.75,  1000, 64),
            (0.75,  5000, 64),
            (0.75, 10000, 64),
            ( 0.9,  1000, 64),
            ( 0.9,  5000, 64),
            ( 0.9, 10000, 64)],
           names=['gamma', 'n', 'h'])


In [3]:
# Index the two sets of initial-experiment result files; they are read and
# combined in the "Aggregating initial" section.
files_initial = load_initial('results_initial/results_retry/')
files_initial2 = load_initial('results_initial/results_retry_only_hd/')

In [4]:
# Inspect the first initial-experiment file index.
files_initial

Unnamed: 0,n,seed,file
237,100,419578,results_initial/results_retry/100_419578_2022-...
37,10000,419578,results_initial/results_retry/10000_419578_202...
137,1000,419578,results_initial/results_retry/1000_419578_2022...
15,10000,24885728,results_initial/results_retry/10000_24885728_2...
115,1000,24885728,results_initial/results_retry/1000_24885728_20...
...,...,...,...
97,10000,987605913,results_initial/results_retry/10000_987605913_...
297,100,987605913,results_initial/results_retry/100_987605913_20...
298,100,997480673,results_initial/results_retry/100_997480673_20...
98,10000,997480673,results_initial/results_retry/10000_997480673_...


# True CMI

In [4]:
# cmis_true[gamma][c] holds GTM(11, gamma).mi(c) for c = 1..10; the aggregation
# cells append it to each stored record as the reference value.
# NOTE(review): the meaning of the constant 11 is defined by MINE.gtm.GTM — confirm.
cmis_true = {}
for gamma in (0.6, 0.75, 0.9):
    per_c = {}
    for c in range(1, 11):
        # A fresh GTM is built for every c, exactly as before (GTM is opaque here).
        gtm = GTM(11, gamma)
        per_c[c] = gtm.mi(c)
    cmis_true[gamma] = per_c

# Aggregating neighbours

In [11]:
# results[gamma][n][c][model][k] maps seed -> stored record.
results = {
    gamma: {
        n: {
            c: {
                model: {k: {} for k in range(1, 11)}
                for model in ['classif', 'opt', 'classif_b', 'opt_b', 'classif_b_avg', 'opt_b_avg']
            }
            for c in range(1, 11)
        }
        for n in [1000, 5000, 10_000]
    }
    for gamma in [0.6, 0.75, 0.9]
}

# iterrows() has no length, so pass total= to get a real progress bar and ETA.
for i, file in tqdm(files_neighbours.iterrows(), total=len(files_neighbours)):
    # NOTE: pickle.load can execute arbitrary code; only read trusted result files.
    with open(file['file'], 'rb') as fd:
        d = pickle.load(fd)
    for model, r in d[file['gamma']][file['c']][file['h']].items():
        try:
            if file['k'] == 1 and model in ('classif', 'opt'):
                # For k == 1 the 'classif' and 'opt' entries are stored as
                # placeholders: (None, None, true CMI).
                record = (None, None, cmis_true[file['gamma']][file['c']])
            else:
                record = list(r[file['k']]) + [cmis_true[file['gamma']][file['c']]]
            results[file['gamma']][file['n']][file['c']][model][file['k']][file['seed']] = record
        except Exception as exc:
            # `except Exception` (not a bare except) so that interrupting this
            # long loop surfaces as KeyboardInterrupt, not as a RuntimeError.
            raise RuntimeError(
                f"failed to aggregate model {model!r} from {file['file']}"
            ) from exc

4577it [01:01, 74.27it/s]  


KeyboardInterrupt: 

In [None]:
# NOTE(review): the saved output of the aggregation cell above ends in a
# KeyboardInterrupt; make sure that cell ran to completion before saving.
# open(..., 'wb') does not create missing directories, so create it first.
os.makedirs('clean_results', exist_ok=True)
with open('clean_results/results_neighbours.pkl', 'wb') as fd:
    pickle.dump(results, fd)

# Aggregating diffs

In [11]:
# results[gamma][n][c][model] maps seed -> list(r) followed by the true CMI.
results = {
    gamma: {
        n: {
            c: {model: {} for model in ['diff_classif', 'diff_opt']}
            for c in range(1, 11)
        }
        for n in [1000, 5000, 10_000]
    }
    for gamma in [0.6, 0.75, 0.9]
}

for i, file in tqdm(files_diffs.iterrows()):
    gamma_key, n_key, c_key = file['gamma'], file['n'], file['c']
    with open(file['file'], 'rb') as fd:
        d = pickle.load(fd)
    for model, r in d[gamma_key][c_key][file['h']].items():
        record = list(r)
        record.append(cmis_true[gamma_key][c_key])
        results[gamma_key][n_key][c_key][model][file['seed']] = record

9000it [00:03, 2376.25it/s]


In [13]:
# open(..., 'wb') does not create missing directories, so create it first.
os.makedirs('clean_results', exist_ok=True)
with open('clean_results/results_diffs.pkl', 'wb') as fd:
    pickle.dump(results, fd)

# Aggregating selection

In [4]:
# results[gamma][n][model] maps seed -> np.array of the stored values.
results = {
    gamma: {
        n: {
            model: {}
            for model in ['diff_opt', 'diff_classif', 'classif', 'classif_b', 'classif_b_avg', 'opt', 'opt_b', 'opt_b_avg']
        }
        for n in [1000, 5000, 10_000]
    }
    for gamma in [0.6, 0.75, 0.9]
}

for i, file in tqdm(files_sel.iterrows()):
    gamma_key, n_key = file['gamma'], file['n']
    with open(file['file'], 'rb') as fd:
        d = pickle.load(fd)
    # A later file with the same (gamma, n, model, seed) overwrites an earlier one.
    for model, r in d[gamma_key][file['h']].items():
        results[gamma_key][n_key][model][file['seed']] = np.array(r)

1800it [00:00, 2769.86it/s]


In [5]:
# open(..., 'wb') does not create missing directories, so create it first.
os.makedirs('clean_results', exist_ok=True)
with open('clean_results/results_sel.pkl', 'wb') as fd:
    pickle.dump(results, fd)

# Aggregating initial

In [20]:
# Read every pickled record set listed in files_initial into a DataFrame and
# stack them into a single frame `d`.
initial_columns = ['dataset', 'dataset_ver', 'n', 'method', 'approx', 'reg', 'seed', 'mi', 'history']
frames = []
for i, row in files_initial.iterrows():
    with open(row['file'], 'rb') as fd:
        records = pickle.load(fd)
    frames.append(pd.DataFrame(records, columns=initial_columns))
d = pd.concat(frames)

In [21]:
# Drop the 'norm_hd' and 'norm_hd_aug' datasets from d.
# NOTE(review): presumably these are replaced by the rows loaded into d2 — confirm.
d = d.loc[~d.dataset.isin(['norm_hd', 'norm_hd_aug'])]

In [22]:
# Read every pickled record set listed in files_initial2 into a DataFrame and
# stack them into a single frame `d2`.
initial2_columns = ['dataset', 'dataset_ver', 'n', 'method', 'approx', 'reg', 'seed', 'mi', 'history']
frames2 = []
for i, row in files_initial2.iterrows():
    with open(row['file'], 'rb') as fd:
        records2 = pickle.load(fd)
    frames2.append(pd.DataFrame(records2, columns=initial2_columns))
d2 = pd.concat(frames2)

In [23]:
# Load the KSG estimates from a CSV file in the working directory; they are
# appended to the neural-estimator results below.
ksg = pd.read_csv('KSG MI.csv')

In [24]:
# Row counts of the three tables before they are concatenated.
len(d), len(d2), len(ksg)

(49500, 16500, 1200)

In [25]:
# Append d2 and the KSG table to d (original row labels are kept, so the index
# contains repeated values).
# NOTE(review): this cell is not idempotent — re-running it concatenates d2 and
# ksg onto the already-extended d and duplicates their rows. Rebuild d from the
# cells above before re-running.
d = pd.concat((d, d2, ksg))

In [26]:
# Inspect the combined table.
d

Unnamed: 0,dataset,dataset_ver,n,method,approx,reg,seed,mi,history
0,uni,orig,100,opt,DV,0.0,419578,1.910404,"{'mi': [-0.028482038527727127, -0.033513527363..."
1,uni,orig,100,opt,DV,0.1,419578,1.941296,"{'mi': [-0.031074389815330505, -0.028770625591..."
2,uni,orig,100,opt,FD,0.0,419578,1.764967,"{'mi': [-0.384560227394104, -0.363346546888351..."
3,uni,orig,100,opt,FD,0.1,419578,1.935687,"{'mi': [-0.31775134801864624, -0.2978672385215..."
4,uni,orig,100,classif,DV,0.0,419578,2.225209,"{'mi': [-0.030394213274121284, -0.021599419414..."
...,...,...,...,...,...,...,...,...,...
1195,norm_hd,orig,10000,ksg,,,936405604,1.682759,
1196,uni,orig,10000,ksg,,,506355475,,
1197,norm_not_corr,orig,10000,ksg,,,506355475,-0.000036,
1198,norm_corr,orig,10000,ksg,,,506355475,0.179058,


In [27]:
# open(..., 'wb') does not create missing directories, so create it first.
os.makedirs('clean_results', exist_ok=True)
with open('clean_results/results_initial.pkl', 'wb') as fd:
    pickle.dump(d, fd)