# Figure 7 SM

In [2]:
import collections
import copy
import os
import random

import numpy as np
import plotly
import plotly.graph_objs as go

import autodisc as ad
plotly.offline.init_notebook_mode(connected=True)

In [3]:
# Default print properties shared by the figures of this notebook.
# Physical sizes are given in centimetres and converted to pixels here.
multiplier = 2  # scale factor applied to every pixel quantity below

pixel_cm_ration = 36.5  # pixels per centimetre, before scaling


def _cm_to_px(length_cm):
    """Convert a length in cm to pixels: truncate first, then scale."""
    return int(length_cm * pixel_cm_ration) * multiplier


# figure widths in pixel: full width (13.95 cm) and half of it
width_full = _cm_to_px(13.95)
width_half = _cm_to_px(13.95/2)

# the two standard figure heights in pixel (4.5 cm and 7 cm)
height_default_1 = _cm_to_px(4.5)
height_default_2 = _cm_to_px(7)

# margins in pixel
top_margin, right_margin = 0 * multiplier, 0 * multiplier
left_margin, bottom_margin = 10 * multiplier, 10 * multiplier

# text style
font_family = 'Times New Roman'
font_size = 8 * multiplier

line_width = 2 * multiplier

In [4]:
plotly.offline.init_notebook_mode(connected=True)

# Experiments that can be plotted, grouped by the list they belong to.
org_experiment_definitions = dict()

org_experiment_definitions['main_paper'] = [
    dict(id = '1',
         directory = '../experiments/IMGEP-VAE',
         name = 'IMGEP-VAE',
         is_default = True),

    dict(id = '2',
         directory = '../experiments/IMGEP-HOLMES',
         name = 'IMGEP-HOLMES',
         is_default = True),
]
repetition_ids = list(range(10))

# Template for the key under which an experiment is stored.
experiment_name_format = '<name>' # <id>, <name>

current_experiment_list = 'main_paper'

# Build the working definitions: directory, optional default flag and the
# identifier obtained by filling `experiment_name_format`.
experiment_definitions = []
for org_exp_def in org_experiment_definitions[current_experiment_list]:
    new_exp_def = dict()
    new_exp_def['directory'] = org_exp_def['directory']
    if 'is_default' in org_exp_def:
        new_exp_def['is_default'] = org_exp_def['is_default']

    # 'name' is optional; only pass it to the template when it is defined
    replacements = {'id': org_exp_def['id']}
    if 'name' in org_exp_def:
        replacements['name'] = org_exp_def['name']
    new_exp_def['id'] = ad.gui.jupyter.misc.replace_str_from_dict(experiment_name_format, replacements)

    experiment_definitions.append(new_exp_def)

# Load the statistics of every repetition:
#   experiment_statistics[<experiment id>][<repetition id>] -> loaded statistics
experiment_statistics = dict()
for experiment_definition in experiment_definitions:
    repetition_statistics = dict()
    for repetition_idx in repetition_ids:
        repetition_directory = os.path.join(experiment_definition['directory'], 'repetition_{:06d}'.format(repetition_idx))
        repetition_statistics[repetition_idx] = ad.gui.jupyter.misc.load_statistics(repetition_directory)
    experiment_statistics[experiment_definition['id']] = repetition_statistics
       

# Plotting

## RSA Matrix

In [5]:
def plot_goalspaces_RSAmatrix(RSA_matrix, config=None, **kwargs):
    """Show a square similarity matrix as a plotly heatmap in the notebook.

    Parameters
    ----------
    RSA_matrix
        Object indexable as ``RSA_matrix[i, j]``. Only indices smaller than
        ``len(config.space_names)`` are read.
    config
        Optional settings that override the defaults built below. The merged
        configuration must contain ``space_names`` (labels used for both
        axes); the defaults do not define it.
    **kwargs
        Additional overrides, passed to ``ad.config.set_default_config``
        together with ``config``.

    Returns
    -------
    dict
        The displayed figure description with the keys ``data`` and
        ``layout``.
    """

    def _axis_style():
        # A fresh dict per call, so the x and y axis share no nested objects.
        return dict(
            showline = True,
            linewidth = 1,
            zeroline=False,
            ticks = "",
            tickfont = dict(
                family=font_family,
                size=12,
            ),
        )

    # Layout defaults; sizes and margins come from the module-level print
    # properties and are pixel values.
    layout_defaults = dict(
        xaxis=_axis_style(),
        yaxis=_axis_style(),
        font = dict(
            family=font_family,
            size=font_size,
        ),
        width = width_half,
        height = height_default_1,
        margin = dict(
            l=left_margin,
            r=right_margin,
            b=bottom_margin,
            t=top_margin,
        ),
        title = "",
        hovermode='closest',
        showlegend = True,
    )

    default_config = dict(
        random_seed = 0,
        global_layout = layout_defaults,
        colorscale = 'Viridis',
        showscale = False,
    )

    config = ad.config.set_default_config(kwargs, config, default_config)

    # Seeds Python's global `random` module with the configured seed.
    random.seed(config.random_seed)

    labels = config.space_names
    n_goal_spaces = len(labels)
    x = np.array(labels)
    y = np.array(labels)

    # Heatmap row j / column i holds RSA_matrix[i, j].
    rows = []
    for col_idx in range(n_goal_spaces):
        rows.append([RSA_matrix[row_idx, col_idx] for row_idx in range(n_goal_spaces)])
    z = np.asarray(rows)

    heatmap = go.Heatmap(x=x, y=y, z=z, colorscale=config.colorscale, showscale=config.showscale)
    figure = dict(data=[heatmap], layout=config.global_layout)
    plotly.offline.iplot(figure)

    return figure


In [6]:
# Base configuration for the RSA similarity plots; sizes and margins are the
# pixel values taken from the print properties defined at the top.
default_config = {
    'plotly_format': 'svg',

    'layout': {
        'title': "RSA Similarity Before and After <br> Training Stage",
        'xaxis': {
            'range': [-1, 48],
            'title': 'training stage',
        },
        'yaxis': {
            'title': 'RSA similarity',
        },
        'font': {
            'family': font_family,
            'size': font_size,
        },
        'updatemenus': [],
        'width': width_half,         # in pixel
        'height': height_default_1,  # in pixel
        'margin': {
            'l': left_margin,    # left margin in pixel
            'r': right_margin,   # right margin in pixel
            'b': bottom_margin,  # bottom margin in pixel
            't': top_margin,     # top margin in pixel
        },
        'showlegend': False,
    },

    # style of the trace that shows the mean
    'default_mean_trace': {
        'line': {'width': line_width},
        'mode': 'lines+markers',
    },

    # how the standard deviation is displayed
    'std': {
        'style': 'errorbar',
        'visible': True,
    },
}

# IMGEP-VAE

In [7]:
# 'temporal_RSA' statistic of repetition 0 of the IMGEP-VAE experiment.
RSA_VAE = experiment_statistics['IMGEP-VAE'][0]['temporal_RSA']

config = dict(
    global_layout=dict(
        # NOTE(review): this cell used to assign margin = dict(t=20*multiplier)
        # first and then replace it with the dict below, so the top margin
        # override never took effect. The dead assignment was removed to keep
        # the rendered figure unchanged; add `t=20*multiplier` here if a top
        # margin is actually wanted.
        margin=dict(l=20*multiplier, b=20*multiplier),
        xaxis=dict(title="training stages"),
        yaxis=dict(title="training stages"),
        # square figure: the half width is used for both dimensions
        width=width_half,
        height=width_half,
    ),
    colorscale='Viridis',
    # axis labels "1" ... "49", one per row/column that is plotted
    space_names=[str(i) for i in range(1, 50)],
)

fig = plot_goalspaces_RSAmatrix(RSA_VAE, config=config)
#plotly.io.write_image(fig, 'main_figure_3_VAE.pdf')

In [8]:
# Work on a private copy: `default_config` is shared between cells, and
# modifying it in place would make this plot depend on which other cells
# were executed before (e.g. after re-running the notebook out of order).
config = copy.deepcopy(default_config)
config["layout"]["margin"] = dict(l=30*multiplier, r=0*multiplier, t=20*multiplier, b=30*multiplier)

n_stages = 49

# Tick every 5th position, labelled with the 1-based number (position 4 -> "5").
# Values and labels are derived from the same sequence so that both lists
# always have the same length; the previous value list contained an extra
# tick at 49, which has no label and lies outside the x-axis range [-1, 48].
tick_positions = list(range(4, n_stages - 1, 5))
config["layout"]["xaxis"]["tickmode"] = 'array'
config["layout"]["xaxis"]["tickvals"] = tick_positions
config["layout"]["xaxis"]["ticktext"] = [str(position + 1) for position in tick_positions]

# One row per repetition, one column per pair of consecutive stages:
# entry [r, s] is the RSA value between stage s+1 and stage s.
DeltaT_similarity = np.zeros((len(repetition_ids), n_stages-1))
# enumerate() gives the row position, so the repetition ids do not have to be 0..N-1
for row_idx, repetition_idx in enumerate(repetition_ids):
    cur_RSA_VAE = experiment_statistics['IMGEP-VAE'][repetition_idx]['temporal_RSA']
    for stage_idx in range(n_stages-1):
        DeltaT_similarity[row_idx, stage_idx] = cur_RSA_VAE[stage_idx+1, stage_idx]
fig = ad.gui.jupyter.plotly_meanstd_scatter(DeltaT_similarity, config=config)
#plotly.io.write_image(fig, 'sm_figure_7_VAE.pdf')


plotly.tools.make_subplots is deprecated, please use plotly.subplots.make_subplots instead



# IMGEP-HOLMES

In [9]:
# 'holmes_RSA' statistic of repetition 0 of the IMGEP-HOLMES experiment.
RSA_HOLMES = experiment_statistics['IMGEP-HOLMES'][0]['holmes_RSA']

# Module identifiers: `order_default` is presumably the row/column order of
# the loaded matrix (TODO confirm); `order_desired` is the order to display.
order_default = [
    '0', '00', '000', '0000', '00000', '00001', '0001', '00010', '00011',
    '001', '01', '010', '011', '0110', '01100', '01101', '0111', '01110',
    '01111', '011110', '0111100', '0111101', '011111',
]
order_desired = [
    '0', '00', '01', '010', '011', '000', '001', '0110', '0111', '0000',
    '0001', '01110', '01111', '011110', '011111', '01100', '01101', '00010',
    '00011', '00000', '00001', '0111100', '0111101',
]

# Note: this binds a second name to the same object, it does not copy the data.
RSA_HOLMES_copy = RSA_HOLMES

# position in `order_default` of every identifier of `order_desired`
permute_order = [order_default.index(module_id) for module_id in order_desired]

# reorder the rows first, then the columns, with the same permutation
RSA_HOLMES = RSA_HOLMES[permute_order, :][:, permute_order]

# From here on `order_desired` holds the axis labels (same order as above).
order_desired = [
    'BC 0', 'BC 00', 'BC 01', '<b>BC 010</b>', 'BC 011', 'BC 000',
    '<b>BC 001</b>', 'BC 0110', 'BC 0111', 'BC 0000', 'BC 0001',
    '<b>BC 01110</b>', 'BC 01111', 'BC 011110', '<b>BC 011111</b>',
    '<b>BC 01100</b>', '<b>BC 01101</b>', '<b>BC 00010</b>',
    '<b>BC 00011</b>', '<b>BC 00000</b>', '<b>BC 00001</b>',
    '<b>BC 0111100</b>', '<b>BC 0111101</b>',
]

config = {
    "global_layout": {
        "xaxis": dict(title="modules"),
        "yaxis": dict(title="modules"),
        "margin": dict(l=45*multiplier, b=42*multiplier, r=40*multiplier),
        # the extra width matches the right margin set above
        "width": width_half + 40*multiplier,
        "height": width_half,
    },
    "colorscale": 'Viridis',
    "space_names": order_desired,
    "showscale": True,
}

fig = plot_goalspaces_RSAmatrix(RSA_HOLMES, config=config)
#plotly.io.write_image(fig, 'main_figure_3_HOLMES.pdf')

In [10]:
# Work on a private copy: `default_config` is shared between cells, and
# modifying it in place would leak the histogram's axis range and titles
# into any other plot that is (re-)drawn from `default_config` afterwards.
config = copy.deepcopy(default_config)
config["layout"]["margin"] = dict(l=25*multiplier, r=0*multiplier, t=0*multiplier, b=20*multiplier)
config["layout"]["title"] = dict(text="RSA Similarity Between <br> Pairs of Goal Spaces", yanchor="top", y=0.9)
config["layout"]["height"] = height_default_1
config["layout"]["xaxis"]["range"] = [-0.01, 1.04]
config["layout"]["xaxis"]["title"] = "RSA similarity"
config["layout"]["yaxis"]["title"] = "# goal space pairs"

# ticks at 0.0, 0.1, ..., 1.0 with one-decimal labels
tick_values = np.linspace(0, 1, 11)
config["layout"]["xaxis"]["tickmode"] = 'array'
config["layout"]["xaxis"]["tickvals"] = list(tick_values)
config["layout"]["xaxis"]["ticktext"] = ["{:.1f}".format(value) for value in tick_values]

# Collect the entries above the main diagonal (every pair i < j once) of the
# 'holmes_RSA' matrix of each repetition.
all_vals = []
# NOTE(review): the `[:-1]` slice leaves out the last repetition. The reason
# is not visible here - confirm that this is intended.
for repetition_idx in repetition_ids[:-1]:
    cur_RSA = np.asarray(experiment_statistics['IMGEP-HOLMES'][repetition_idx]['holmes_RSA'])
    # same (i, j) pairs, in the same row-major order, as a nested
    # `for i in range(rows): for j in range(i+1, cols)` loop
    upper_rows, upper_cols = np.triu_indices(cur_RSA.shape[0], k=1, m=cur_RSA.shape[1])
    all_vals.extend(cur_RSA[upper_rows, upper_cols])

all_vals = np.asarray(all_vals)

fig = go.Figure(data=[go.Histogram(x=all_vals)], layout=config["layout"])
plotly.offline.iplot(fig)
#plotly.io.write_image(fig, 'sm_figure_7_HOLMES.pdf')