In [1]:
import os
import json
import pandas as pd
import numpy as np; np.random.seed(0)
import seaborn as sns; sns.set_theme()
import matplotlib.pyplot as plt

from pyitlib import discrete_random_variable as drv

# Unzip runs.zip in NASLib/run/results/correlation before running the code

In [2]:
BENCHMARKS = {
#     'nasbench101': ['cifar10'],
    'nasbench201': ['cifar10', 'cifar100', 'ImageNet16-120'],
    'nasbench301': ['cifar10'],
    'transbench101_micro': ['jigsaw', 'class_scene', 'class_object', 'autoencoder', 'normal', 'room_layout', 'segmentsemantic'],
    'transbench101_macro': ['jigsaw', 'class_scene', 'class_object', 'autoencoder', 'normal', 'room_layout', 'segmentsemantic']
}

LABELS = {
    'nasbench101': 'NB101',
    'nasbench201': 'NB201',
    'nasbench301': 'NB301',
    'transbench101_micro': 'TNB101_MICRO',
    'transbench101_macro': 'TNB101_MACRO',
    'cifar10': 'CF10',
    'cifar100': 'CF100',
    'ImageNet16-120': 'IMGNT',
    'jigsaw': 'JIGSAW',
    'class_scene': 'SCENE',
    'class_object': 'OBJECT',
    'autoencoder': 'AUTOENC',
    'normal': 'NORMAL',
    'room_layout': 'ROOM',
    'segmentsemantic': 'SEGMENT',
}

START_SEED=9000
END_SEED=9005

In [3]:
def get_all_files(root_folder='../run/results/correlation', filename='scores.json'):
    all_files = []
    for root, dirs, files in os.walk(root_folder):
        for file in files:
            if file.endswith(filename):
                #print(os.path.join(root, file))
                all_files.append(os.path.join(root, file))

    return all_files


def get_scores_as_list_of_dict(files):
    data = []

    for file in files:
        file_components = file.split('/')
        search_space, dataset, predictor, seed = file_components[-5], file_components[-4], file_components[-3], file_components[-2]
        
        with open(file, 'r') as f:
            scores = json.load(f)[1]

        record = {
            'search_space': search_space,
            'dataset': dataset,
            'predictor': predictor,
            'seed': seed,
            'kendalltau': scores['kendalltau'],
            'pearson': scores['pearson'],
            'spearman': scores['spearman'],
            'preds': scores['full_testpred'],
            'ground_truth': scores['full_ytest'],
        }

        data.append(record)

    return data

def sort_by_mean(df):
    df = df.reindex(df.transpose().mean().sort_values().index)
    df = df.reindex(df.mean().sort_values().index, axis=1)
    return df

def plot_heatmap(df, figsize=(20, 10), rotation=0, title='', cmap='viridis_r', savetitle='zcp_corr', savedir='.'):
    plt.figure(figsize=figsize, dpi=200)
    plt.title(title, fontsize=20)
    heatmap = sns.heatmap(df, annot=True, cmap=cmap, fmt='.2f', vmin=-1, vmax=1, cbar=False)
    heatmap.set_xticklabels(heatmap.get_xticklabels(), rotation=rotation, fontsize=12)
    heatmap.set(xlabel=None)
    heatmap.set(ylabel=None)
    plt.tight_layout()
    if savedir != '.':
        os.makedirs(savedir, exist_ok=True)
    plt.savefig('{}/{}.pdf'.format(savedir, savetitle), bbox_inches='tight')    


def make_df(files):
    data = get_scores_as_list_of_dict(files)
    df = pd.DataFrame(data)
    return df

def make_clean_df(preds_to_drop=None):
    files = get_all_files()
    print(files[0])
    df = make_df(files)
    return df
    # Colin remove the rest of this. Do we need it for mutual info experiments?

    wrong_len_indexes = []
    wrong_type_indexes = []
    all_preds_same_indexes = []
    minus_100_indexes = []

    if preds_to_drop is not None:
        for predictor in preds_to_drop:
            df = df[df['predictor'] != predictor]
    
    for idx in df.index:
        preds = df.loc[idx]['preds']
        if not isinstance(preds, list):
            wrong_type_indexes.append(idx)
        elif len(preds) != 200:
            wrong_len_indexes.append(idx)
        elif len(set(preds)) == 1:
            all_preds_same_indexes.append(idx)
        elif -100000000.0 in preds:
            minus_100_indexes.append(idx)

    wrong_len_indexes, wrong_type_indexes, len(all_preds_same_indexes), len(minus_100_indexes)
    bad_indices = wrong_len_indexes + wrong_type_indexes + all_preds_same_indexes

    print(df.loc[wrong_type_indexes])
    df = df.drop(index=bad_indices)
    return df

Note: there are many other useful methods in NASLib/plotting/PlotCorrelations.ipynb

Note: to compute the mutual information and/or entropy of a random variable (a zero-cost proxy)
if we only have samples of that random variable (e.g. 500 samples, or 15625 samples): we need to discretize the random variable. That is, make a set of bins with equal number of ZC proxy values in each bin. For example, make 10 or 100 bins, based on how many samples we have.

If we want to compute the conditional entropy conditioned on two ZC proxies, we need to create 5x5=25 bins (or 10x10=100, etc), since we need to discretize both.

In [6]:
def discretize(values, bins=10):
    bin_edges = np.histogram_bin_edges(values, bins=bins)
    bin_edges[-1] += 1
    return np.digitize(values, bin_edges)

def compute_mutual_info(df, predictors):
    corr = np.zeros((len(predictors), len(predictors)))
    for i in range(len(predictors)):
        for j in range(len(predictors)):
            zc_1 = predictors[i]
            zc_2 = predictors[j]

            preds_1 = np.array(df[df['predictor'] == zc_1]['preds'].to_list()[0])
            preds_2 = np.array(df[df['predictor'] == zc_2]['preds'].to_list()[0])
            truths = np.array(df[df['predictor'] == zc_1]['ground_truth'].to_list()[0])

            preds_1_disc = discretize(preds_1)
            preds_2_disc = discretize(preds_2)
            truths_disc = discretize(truths)

            try:
                length = len(preds_1)
                if length != len(preds_2) or length != len(truths):
                    print(zc_1, zc_2, len(zc_1), len(zc_2), 'prediction sizes do not match')
                    continue
            except:
                print(zc_1, preds_1)
                print(zc_2, preds_2)
            
            ent_1 = drv.entropy_conditional(preds_1_disc, truths_disc)
            ent_2 = drv.entropy_conditional(preds_2_disc, truths_disc)
            info = drv.information_mutual(preds_1_disc, preds_2_disc)

            corr = np.corrcoef(preds_1, preds_2)[0, 1]
            rank_1 = np.corrcoef(preds_1, truths)[0, 1]
            rank_2 = np.corrcoef(preds_2, truths)[0, 1]
            
            print(ent_1, ent_2, info, corr, rank_1, rank_2)

            # todo: compute H(truths | preds_1, preds_2)
            # bin preds_1 and preds_2 into two bins using this:
            # https://stackoverflow.com/questions/31635265/two-dimensional-np-digitize
            # e.g. bin by 5 and by 5, and then bin truths by 10

    # todo: return everything
    return corr

def plot_mutual_information(df, search_space, dataset):
    all_corrs = []
    for seed in [str(i) for i in range(START_SEED, END_SEED)]:
    
        df_ = df[(df['search_space'] == search_space) & (df['seed'] == seed) & (df['dataset'] == dataset)]
        
        predictors = tuple(set(df_['predictor'].to_list()))

        # todo: update everything starting here:
        correlation = compute_mutual_info(df_, predictors)
        all_corrs.append(correlation)

    print('all_corrs ::', np.array(all_corrs))
    corr = np.mean(np.array(all_corrs), axis=0)
    print('mean corrs ::', corr.shape)
    corr_df = pd.DataFrame(corr)

    predictors_map = {i: predictors[i] for i in range(len(predictors))}
    corr_df = corr_df.rename(columns=predictors_map, index=predictors_map)
    corr_df = sort_by_mean(corr_df)
    plot_heatmap(corr_df, title=f'{search_space}-{dataset}', figsize=(10,8),
                 savedir='correlation_between_zcs', savetitle=f'{search_space}-{dataset}')

In [7]:
df = make_clean_df()

# work in progress. This is not finished so it will have an error.
plot_mutual_information(df, 'nasbench201', 'cifar100')

../run/results/correlation/transbench101_micro/jigsaw/fisher/9000/scores.json
0.7633129027651413 0.7633129027651413 0.8413708617028743 1.0 0.1948839492532715 0.1948839492532715
0.7633129027651413 2.7339552187729756 0.11559218459442189 -0.040369496253427825 0.1948839492532715 0.006978605236854598
0.7633129027651413 2.195867581382206 0.18367221094354136 0.19580759943086798 0.1948839492532715 0.5291580890945606
0.7633129027651413 0.0 0.0012182038095069903 0.03879339631089701 0.1948839492532715 0.5030424840023606
0.7633129027651413 2.080884847583891 0.2116040979498579 0.20006448426885368 0.1948839492532715 0.537248637232189
0.7633129027651413 2.0249537054995583 0.2812946488590642 0.37352769469510394 0.1948839492532715 0.606842930301323
0.7633129027651413 0.0 0.0012182038095069903 0.038795732091742 0.1948839492532715 0.5030579142492154
0.7633129027651413 2.352752202787988 0.2377727703861483 0.1963394193397081 0.1948839492532715 -0.11651046732397231
0.7633129027651413 0.93651907967284 0.7269

2.352752202787988 0.0 0.01378834728601408 0.07171518502997325 -0.11651046732397231 0.5030424840023606
2.352752202787988 2.080884847583891 0.29373461795413425 -0.01669106967659032 -0.11651046732397231 0.537248637232189
2.352752202787988 2.0249537054995583 0.20693942933302267 -0.14282304461905013 -0.11651046732397231 0.606842930301323
2.352752202787988 0.0 0.01378834728601408 0.07171380044990146 -0.11651046732397231 0.5030579142492154
2.352752202787988 2.352752202787988 2.513052446359902 1.0 -0.11651046732397231 -0.11651046732397231
2.352752202787988 0.93651907967284 0.2655600569035417 0.1336469749544053 -0.11651046732397231 0.21940221934493473
2.352752202787988 0.17665580547438164 0.12993079610482328 0.32906247709842973 -0.11651046732397231 0.026118617988270868
2.352752202787988 0.2905023077239757 0.20781027975503052 -0.24318137679156848 -0.11651046732397231 0.09524006280596747
2.352752202787988 0.0 0.01378834728601408 0.07171469717978045 -0.11651046732397231 0.5030430324612416
0.936519

0.0 2.417157286641478 0.020585681315575775 0.13602206881634865 0.4165424619083835 0.7130744061610191
0.0 0.0 0.0454146923337941 1.0 0.4165424619083835 0.4165424619083835
0.0 2.316337859639688 0.010201780698645822 0.07791224672122508 0.4165424619083835 -0.1730905701746484
0.0 0.5909229538050067 0.0009245907862602473 0.03569129123123159 0.4165424619083835 0.20440885174384021
0.0 0.15288035046853277 0.0001461010067604837 0.008657173732373357 0.4165424619083835 0.03332603119880162
0.0 0.15288035046853232 0.0001461010067604837 0.0018428235235483405 0.4165424619083835 0.039547251754163064
0.0 0.0 0.0454146923337941 0.9999999988977317 0.4165424619083835 0.41651338800072635
2.316337859639688 0.48941364348995475 0.1546590682454818 0.27682470125319936 -0.1730905701746484 0.2054619403658783
2.316337859639688 2.6994699727797054 0.2640923426573041 -0.19379165511207047 -0.1730905701746484 0.011447312202517655
2.316337859639688 1.9905450477425422 0.24764431622948901 -0.1595991118872168 -0.17309057017

1.8590887210043618 0.16436832424514813 0.09021573094311686 0.03557243161430148 0.515783405014566 0.030397177463108056
1.8590887210043618 0.2630028360058305 0.11972132412702274 0.1367446944184946 0.515783405014566 0.07825110061695395
1.8590887210043618 2.680960120194447 0.23349500033058002 0.13314080495913105 0.515783405014566 0.363228234041411
1.7474774409260823 0.5336139960321287 0.3157517182429177 0.3375051815837034 0.6597174824836312 0.19272555087933235
1.7474774409260823 2.5970590982862563 0.19518484995755347 0.0006518269738367931 0.6597174824836312 0.01976385499285811
1.7474774409260823 1.9290527283492374 0.7911099243306543 0.6774124344824447 0.6597174824836312 0.5081834020216545
1.7474774409260823 1.8611733575183855 1.744274152442355 0.8564181679197711 0.6597174824836312 0.760432008408456
1.7474774409260823 1.8590887210043618 0.8664779353698853 0.6809680426146881 0.6597174824836312 0.515783405014566
1.7474774409260823 1.7474774409260823 2.4292548601214534 1.0 0.6597174824836312 0

0.2630028360058305 0.16436832424514813 0.19670913386885427 -0.5948950618554782 0.07825110061695395 0.030397177463108056
0.2630028360058305 0.2630028360058305 0.33229218908241476 1.0 0.07825110061695395 0.07825110061695395
0.2630028360058305 2.680960120194447 0.11538864080899758 0.0586477885940445 0.07825110061695395 0.363228234041411
2.680960120194447 0.5336139960321287 0.17771535377010794 0.033981316529650134 0.363228234041411 0.19272555087933235
2.680960120194447 2.5970590982862563 0.32056041317342565 -0.025442816822308208 0.363228234041411 0.01976385499285811
2.680960120194447 1.9290527283492374 0.2527316959524786 0.12333531484355195 0.363228234041411 0.5081834020216545
2.680960120194447 1.8611733575183855 0.34318868995313956 0.22152264024139956 0.363228234041411 0.760432008408456
2.680960120194447 1.8590887210043618 0.23349500033057913 0.13314080495913105 0.363228234041411 0.515783405014566
2.680960120194447 1.7474774409260823 0.27244717547444175 0.17636382409419005 0.3632282340414

0.25049187646827153 0.0 0.0007769345608970024 0.019274668955845092 0.09235466951513707 0.6420894865230224
0.25049187646827153 1.826242972023564 0.09101776510157122 0.18133096631758117 0.09235466951513707 0.5524037337490912
0.25049187646827153 2.215918863790807 0.0964678404511744 0.18125711132001968 0.09235466951513707 0.701084928514325
0.25049187646827153 0.0 0.0007769345608970024 0.019275193431592556 0.09235466951513707 0.6420963073101411
0.25049187646827153 2.4418467819087253 0.14687663594184908 0.08762739175852544 0.09235466951513707 -0.034081492601386786
0.25049187646827153 0.638523728562665 0.23094657807779673 0.8497544455843998 0.09235466951513707 0.2471857673665776
0.25049187646827153 0.2908155797912264 0.21446680387122252 0.8953994899336786 0.09235466951513707 0.11300660951693255
0.25049187646827153 0.25049187646827153 0.2833607113002098 1.0 0.09235466951513707 0.09235466951513707
0.25049187646827153 0.0 0.0007769345608970024 0.019274603414940715 0.09235466951513707 0.642089909

0.6351902280656607 2.618954708069075 0.17336788620570576 0.057394497796402565 0.23771250923486875 0.057538276574227464
0.6351902280656607 1.924204854789708 0.21904871013713412 0.2613596944154537 0.23771250923486875 0.5318157362242893
0.6351902280656607 1.8472072278784308 0.2929159888583841 0.367598160111236 0.23771250923486875 0.799454406993457
0.6351902280656607 1.7797857983650118 0.23529151278915283 0.2652323608148383 0.23771250923486875 0.5425751843886817
0.6351902280656607 1.7895549225019698 0.30733832106944026 0.4316530081931665 0.23771250923486875 0.7034860764052091
0.6351902280656607 0.2860524201344772 0.021253198024697828 0.14230228715499268 0.23771250923486875 0.710083338251127
0.6351902280656607 2.2627829544220175 0.27520180302074304 -0.10971979389613518 0.23771250923486875 -0.23487524782502942
0.6351902280656607 0.6351902280656607 0.7772865693692348 1.0 0.23771250923486875 0.23771250923486875
0.6351902280656607 0.2020293374199631 0.22439185783157622 0.8678878084684485 0.2377

ValueError: DataFrame constructor not properly called!

In [None]:


x = np.array((1, 2, 1, 2, 1, 2, 1, 2))
y = np.array((1, 2, 2, 2, 1, 2, 2, 2))
z = np.array((1, 2, 2, 2, 1, 2, 1, 1))
arr = np.array([x, y, z])
#drv.entropy(x)
drv.entropy_conditional(arr)
#drv.information_mutual(x, y)