In [1]:
# The point of this script is to assess the convergence of the traces

# Import packages
import hddm
import pandas as pd
import os
import re
import kabuki
import matplotlib.pyplot as plt
import seaborn as sns
import glob
import fnmatch

def obtain_traces(name):
    
    # Pre-allocate 
    models = []

    pattern = name + '_z_5percent_outlier_removal[1-5]*'
    exclude_pattern = '*.db'
    list_of_models = [f for f in os.listdir('.') if fnmatch.fnmatch(f, pattern) and not fnmatch.fnmatch(f, exclude_pattern)]

    # Load the models and save them in the list
    print(list_of_models)

    print(f"finished loading, we have {len(list_of_models)} files")
    
    models = [hddm.load(file) for file in list_of_models]

    # Get Gelman Ruby stats and save to a file
    gelman = kabuki.analyze.gelman_rubin(models)
    gelman_df = pd.DataFrame.from_dict(gelman, orient='index', columns=['Value'])
    gelman_df.to_csv("gelman_" + name + '.csv')

    # Check for any deviations in convergence
    if sum(gelman_df['Value'] < 0.99)!=0 | sum(gelman_df['Value'] > 1.01)!=0:
        print(f"Bad convergence for {name}")
    else:
        print(f"Good convergence for {name}")
        
    for model in models:
        traces = model.get_group_traces().reset_index()

    # Concatenate the models
    concat_model = kabuki.utils.concat_models(models)

    # Save traces to csv
    traces = concat_model.get_group_traces()
    traces['Paradigm'] = name
    traces.to_csv(name + '_group_traces.csv')


# List of paradigms     
names_list = ["m133_cyl", "m134_cyl", "m133_RDK", "m134_RDK"]

# Run functions for humans for all 4 tasks 
for paradigm in range(0, 4):
    name = names_list[paradigm]
    obtain_traces(name)

  warn("The `IPython.parallel` package has been deprecated since IPython 4.0. "


['m133_cyl_z_5percent_outlier_removal3', 'm133_cyl_z_5percent_outlier_removal2', 'm133_cyl_z_5percent_outlier_removal4', 'm133_cyl_z_5percent_outlier_removal5', 'm133_cyl_z_5percent_outlier_removal1']
finished loading, we have 5 files
Good convergence for m133_cyl
['m134_cyl_z_5percent_outlier_removal5', 'm134_cyl_z_5percent_outlier_removal3', 'm134_cyl_z_5percent_outlier_removal1', 'm134_cyl_z_5percent_outlier_removal4', 'm134_cyl_z_5percent_outlier_removal2']
finished loading, we have 5 files
Good convergence for m134_cyl
['m133_RDK_z_5percent_outlier_removal3', 'm133_RDK_z_5percent_outlier_removal2', 'm133_RDK_z_5percent_outlier_removal4', 'm133_RDK_z_5percent_outlier_removal1', 'm133_RDK_z_5percent_outlier_removal5']
finished loading, we have 5 files
Good convergence for m133_RDK
['m134_RDK_z_5percent_outlier_removal2', 'm134_RDK_z_5percent_outlier_removal3', 'm134_RDK_z_5percent_outlier_removal1', 'm134_RDK_z_5percent_outlier_removal5', 'm134_RDK_z_5percent_outlier_removal4']
fini