In [1]:
import matplotlib
%matplotlib notebook
import matplotlib.pyplot as plt
import numpy as np
import os
import json

In [2]:
def visualize_case(files, title):
    plt.figure(figsize=(8,6))
        
    datas = []
    sizes = [2**(3+i) for i in range(len(files))]

    for ns in sizes:
        found = None
        for path in files:
            if 'train_cnt{}'.format(ns) in path:
                found = path
        with open(found, 'r') as f:
            data = json.load(f)
            datas.append(data)
   
    keys = sorted(datas[0].keys())
    legend = []
    
    show_methods = [
        'Ground Truth',
        # 'T-GLASSO (no reg)',
        'T-GLASSO',
#         'T-GLASSO (more iters)',
        'Linear CorEx (applied bucket-wise)',
#         'Linear CorEx (applied on whole data)',
        'T-Corex (W)',
#         'T-Corex + priors (W, method 1)',
#         'T-Corex + priors (W, method 2)',
#         'T-Corex + priors (W, method 2, weighted samples)',
        'T-Corex (W, weighted samples)',
#         'T-Corex (W, weighted samples, no init)',
    ]
    
    name_map = {
        # 'T-GLASSO (no reg)': 'GLASSO',
        # 'Linear CorEx (applied bucket-wise)': 'Linear CorEX',
        # 'T-Corex + priors (W, method 2)': 'T-CorEx + priors',
        # 'T-Corex (W, weighted samples)': 'T-CorEx + weights'
    }
    
    for k in show_methods:
        mas = [data[k]['mean'] for data in datas]
        legend.append(k)
        plt.plot(mas)
    
    legend = [name_map[x] if x in name_map else x for x in legend]
    plt.legend(legend)
    plt.xticks(range(len(sizes)), sizes)
    plt.title(title)
    plt.xlabel('number of samples')
    plt.ylabel('average negative log-likelihood')
    plt.show()

In [3]:
def bar_plot_baselines(jsonfile, title):
    with open(jsonfile) as f:
        data = json.load(f)
    keys = data.keys()
    keys = sorted(keys)
    plt.figure(figsize=(9,7))
    values = [data[k]['mean'] for k in keys]
    values = [0 if np.isinf(x) else x for x in values]
    plt.bar(range(len(keys)), values, width=0.6)
    plt.ylim(0.9*data['Ground Truth']['mean'])
    plt.xticks(range(len(keys)), keys, rotation=90)
    for i, v in enumerate(values):
        plt.text(i - 0.4, v+1, '{:.2f}'.format(v))
    plt.tight_layout()
    plt.title(title)
    plt.subplots_adjust(top=0.9)
    plt.show()
        

In [4]:
A1_C1 = [
    'results/syn_nglf_buckets.nt10.m8.bs16.train_cnt16.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_buckets.nt10.m8.bs16.train_cnt32.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_buckets.nt10.m8.bs16.train_cnt8.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_buckets.nt10.m8.bs16.train_cnt64.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_buckets.nt10.m8.bs16.train_cnt128.val_cnt16.test_cnt100.snr5.00.results.json'
]

In [5]:
A2_C1 = [
    'results/syn_nglf_buckets.nt10.m32.bs4.train_cnt16.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_buckets.nt10.m32.bs4.train_cnt32.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_buckets.nt10.m32.bs4.train_cnt8.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_buckets.nt10.m32.bs4.train_cnt64.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_buckets.nt10.m32.bs4.train_cnt128.val_cnt16.test_cnt100.snr5.00.results.json'
]

In [6]:
A3_C1 = [
    'results/syn_nglf_buckets.nt10.m4.bs32.train_cnt16.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_buckets.nt10.m4.bs32.train_cnt8.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_buckets.nt10.m4.bs32.train_cnt32.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_buckets.nt10.m4.bs32.train_cnt64.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_buckets.nt10.m4.bs32.train_cnt128.val_cnt16.test_cnt100.snr5.00.results.json'
]

In [7]:
A1_C3 = [
    'results/syn_nglf_ts.nt30.m8.bs16.train_cnt16.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_ts.nt30.m8.bs16.train_cnt32.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_ts.nt30.m8.bs16.train_cnt8.val_cnt16.test_cnt100.snr5.00.results.json'
]

In [8]:
A2_C3 = [
    'results/syn_nglf_ts.nt30.m32.bs4.train_cnt16.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_ts.nt30.m32.bs4.train_cnt32.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_ts.nt30.m32.bs4.train_cnt8.val_cnt16.test_cnt100.snr5.00.results.json'
]

In [9]:
syn_smooth_buckets_A1 = [
    'results/syn_nglf_buckets_smooth.nt10.m8.bs16.train_cnt16.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_buckets_smooth.nt10.m8.bs16.train_cnt32.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_buckets_smooth.nt10.m8.bs16.train_cnt8.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_buckets_smooth.nt10.m8.bs16.train_cnt64.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_buckets_smooth.nt10.m8.bs16.train_cnt128.val_cnt16.test_cnt100.snr5.00.results.json'
]

In [10]:
syn_smooth_buckets_A2 = [
    'results/syn_nglf_buckets_smooth.nt10.m32.bs4.train_cnt16.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_buckets_smooth.nt10.m32.bs4.train_cnt32.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_buckets_smooth.nt10.m32.bs4.train_cnt8.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_buckets_smooth.nt10.m32.bs4.train_cnt64.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_buckets_smooth.nt10.m32.bs4.train_cnt128.val_cnt16.test_cnt100.snr5.00.results.json'
]

In [11]:
syn_smooth_buckets_A3 = [
    'results/syn_nglf_buckets_smooth.nt10.m4.bs32.train_cnt16.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_buckets_smooth.nt10.m4.bs32.train_cnt8.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_buckets_smooth.nt10.m4.bs32.train_cnt32.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_buckets_smooth.nt10.m4.bs32.train_cnt64.val_cnt16.test_cnt100.snr5.00.results.json',
    'results/syn_nglf_buckets_smooth.nt10.m4.bs32.train_cnt128.val_cnt16.test_cnt100.snr5.00.results.json'
]

In [12]:
# visualize_case(A1_C1, 'm=8, bs=16 buckets')
visualize_case(A1_C1, 'm=8, bs=16')

<IPython.core.display.Javascript object>

In [13]:
# visualize_case(A2_C1, 'm=32, bs=4 buckets')
visualize_case(A2_C1, 'm=32, bs=4')

<IPython.core.display.Javascript object>

In [14]:
# visualize_case(A3_C1, 'm=4, bs=32 buckets')
visualize_case(A3_C1, 'm=4, bs=32')

<IPython.core.display.Javascript object>

In [15]:
visualize_case(A1_C3, 'm=8, bs=16 time-series')

<IPython.core.display.Javascript object>

KeyError: 'T-GLASSO'

In [16]:
visualize_case(A2_C3, 'm=32, bs=4 time-series')

<IPython.core.display.Javascript object>

KeyError: 'T-GLASSO'

In [17]:
visualize_case(syn_smooth_buckets_A1, 'm=8, bs=16 smooth buckets')

<IPython.core.display.Javascript object>

In [18]:
visualize_case(syn_smooth_buckets_A2, 'm=32, bs=4 smooth buckets')

<IPython.core.display.Javascript object>

In [19]:
visualize_case(syn_smooth_buckets_A3, 'm=4, bs=32 smooth buckets')

<IPython.core.display.Javascript object>

In [20]:
bar_plot_baselines(A1_C3[0], "m=8,bs=16 time-series")

<IPython.core.display.Javascript object>

In [39]:
bar_plot_baselines(A2_C3[1], "m=32,bs=4 time-series")

<IPython.core.display.Javascript object>

In [46]:
A1_C1_change_structure = [
    'results/change_structure.syn_nglf_buckets.nt10.m8.bs16.train_cnt16.val_cnt16.test_cnt100.min_cor0.60.max_cor1.00.results.json',
    'results/change_structure.syn_nglf_buckets.nt10.m8.bs16.train_cnt32.val_cnt16.test_cnt100.min_cor0.60.max_cor1.00.results.json',
    'results/change_structure.syn_nglf_buckets.nt10.m8.bs16.train_cnt8.val_cnt16.test_cnt100.min_cor0.60.max_cor1.00.results.json'
]

In [47]:
visualize_case(A1_C1_change_structure, 'm=8, bs=16 buckets (without fixing SNR)')

<IPython.core.display.Javascript object>

# draft

In [64]:
bar_plot_baselines('results/syn_nglf_buckets.nt10.m8.bs16.train_cnt8.val_cnt16.test_cnt100.snr5.00.results.json',
                   'm=8,bs=16 buckets')

<IPython.core.display.Javascript object>