In [1]:
# Load packages
import tensorflow as tf
from tensorflow import keras
import numpy as np
import pandas as pd
import os
import pickle
import time
import scipy as scp
import scipy.stats as scps
from scipy.optimize import differential_evolution
from scipy.optimize import minimize
from datetime import datetime
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression

# Load my own functions
import dnnregressor_train_eval_keras as dnnk
import make_data_wfpt as mdw
from kde_training_utilities import kde_load_data
import ddm_data_simulation as ddm_sim
import boundary_functions as bf

In [2]:
# Load the trained Keras model, then restore the weights of its final checkpoint.
# Both artifacts live in the same training-run directory.
model_dir = '/home/afengler/git_repos/nn_likelihoods/keras_models/dnnregressor_kde_ddm_weibull_06_05_19_14_07_16'
model_path = model_dir + '/model_0'
ckpt_path = model_dir + '/ckpt_0_final'
model = keras.models.load_model(model_path)
model.load_weights(ckpt_path)

<tensorflow.python.training.checkpointable.util.CheckpointLoadStatus at 0x7f1d5334a278>

In [8]:
# Initializations -----
n_runs = 300        # number of simulate-then-fit repetitions
n_samples = 2500    # simulated trials per run


# NOTE PARAMETERS: WEIBULL: [v, a, w, node, shape, scale]
# One (low, high) pair per parameter. These bounds are used both to sample the
# ground truth and as the optimizer's search box, so they must follow the
# order of `parameter_names` below.
param_bounds = [(-2.5, 2.5), (0.5, 4), (0.15, 0.85), (0, 5), (1.01, 49), (0.01, 9.9)]

# Get parameter names in correct ordering from the network's training features.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('/home/afengler/git_repos/nn_likelihoods/data_storage/kde/weibull/train_test_data/test_features.pickle', 'rb') as feature_file:
    dat = pickle.load(feature_file)

parameter_names = list(dat.keys())[:-2] # drop the last two keys ('rt' and 'choice')

# log_p and the main loop index parameters by position, so the bounds must
# line up one-to-one with the network's parameter inputs.
assert len(param_bounds) == len(parameter_names), \
    'param_bounds has {} entries but the network expects {} parameters'.format(len(param_bounds), len(parameter_names))

# Make columns for optimizer result table: simulated values, MLEs, n_samples
p_sim = [parameter_name + '_sim' for parameter_name in parameter_names]
p_mle = [parameter_name + '_mle' for parameter_name in parameter_names]
my_optim_columns = p_sim + p_mle + ['n_samples']

# Initialize the data frame in which to store optimizer results
optim_results = pd.DataFrame(np.zeros((n_runs, len(my_optim_columns))), columns = my_optim_columns)
optim_results.iloc[:, 2 * len(parameter_names)] = n_samples

# define boundary
boundary = bf.weibull_bnd
boundary_multiplicative = False

# Define the likelihood function
def log_p(params = [0, 1, 0.9], model = [], data = [], parameter_names = []):
    """Negative summed log-likelihood of (rt, choice) data under the network.

    Builds one feature row per observation. Each row holds the first
    ``len(parameter_names)`` entries of ``params``, then the rt, then the
    choice. The matrix is passed to ``model.predict``. Predictions below 1e-29
    are floored to 1e-29 so the log stays finite, and ``-sum(log(prediction))``
    is returned.

    params: parameter values, ordered like ``parameter_names``.
    model: object exposing ``predict(feature_array)``.
    data: indexable with ``data[0]`` = rts and ``data[1]`` = choices; each is
        raveled into one feature column.
    parameter_names: names of the parameter inputs (only their count is used).
    """
    n_params = len(parameter_names)
    rts = data[0]
    n_obs = rts.shape[0]

    # Parameter columns first, then rt and choice
    features = np.zeros((n_obs, n_params + 2))
    for col in range(n_params):
        features[:, col] = params[col]
    features[:, n_params] = rts.ravel()
    features[:, n_params + 1] = data[1].ravel()

    likelihoods = model.predict(features)

    # Floor tiny / zero likelihoods so np.log never yields -inf
    likelihoods[likelihoods < 1e-29] = 1e-29

    neg_log_lik = -np.sum(np.log(likelihoods))
    return neg_log_lik

def make_params(param_bounds = []):
    """Draw one parameter vector uniformly at random within the given bounds.

    param_bounds: sequence of (low, high) pairs, one per parameter.

    Returns a 1-D float numpy array with one value per pair. The values are
    drawn with the global numpy RNG via ``np.random.uniform``. An empty
    ``param_bounds`` yields an empty array.
    """
    params = np.zeros(len(param_bounds))

    for i, bound in enumerate(param_bounds):
        params[i] = np.random.uniform(low = bound[0], high = bound[1])

    # Callers index into the sampled vector, so it must actually be returned
    return params
# ---------------------

In [None]:
# Main loop ----------- TD: Parallelize
# Each run takes minutes, so the result table is written to disk after every
# run. An interrupted sweep then keeps everything finished so far.
results_dir = os.path.join(os.getcwd(), 'experiments', 'kde_ddm_weibull_mle')
os.makedirs(results_dir, exist_ok = True)
results_file = os.path.join(results_dir, 'optim_results.csv')

for i in range(n_runs):

    # Get start time
    start_time = time.time()

    # Sample ground-truth parameters uniformly within param_bounds
    # (order: v, a, w, node, shape, scale)
    tmp_params = make_params(param_bounds = param_bounds)

    # Store in output table
    optim_results.iloc[i, :len(parameter_names)] = tmp_params

    # Print some info on run
    print('Parameters for run ' + str(i) + ': ')
    print([float(p) for p in tmp_params])

    # Run model simulations; entries 3-5 parameterize the Weibull boundary
    ddm_dat_tmp = ddm_sim.ddm_flexbound_simulate(v = tmp_params[0],
                                                 a = tmp_params[1],
                                                 w = tmp_params[2],
                                                 s = 1,
                                                 delta_t = 0.001,
                                                 max_t = 20,
                                                 n_samples = n_samples,
                                                 boundary_fun = boundary, # function of t (and potentially other parameters) that takes in (t, *args)
                                                 boundary_multiplicative = boundary_multiplicative, # CAREFUL: CHECK IF BOUND
                                                 boundary_params = {'node': tmp_params[3],
                                                                    'shape': tmp_params[4],
                                                                    'scale': tmp_params[5]})

    # Print some info on run
    print('Mean rt for current run: ')
    print(np.mean(ddm_dat_tmp[0]))

    # Run optimizer over the same box the ground truth was sampled from
    out = differential_evolution(log_p,
                                 bounds = param_bounds,
                                 args = (model, ddm_dat_tmp, parameter_names),
                                 popsize = 30,
                                 disp = True)

    # Print some info
    print('Solution vector of current run: ')
    print(out.x)

    print('The run took: ')
    elapsed_time = time.time() - start_time
    print(time.strftime("%H:%M:%S", time.gmtime(elapsed_time)))

    # Store result and checkpoint the table to disk
    optim_results.iloc[i, len(parameter_names):(2*len(parameter_names))] = out.x
    optim_results.to_csv(results_file)
# -----------------------

# Save optimization results to file (also covers n_runs == 0)
optim_results.to_csv(results_file)

Parameters for run 0: 
[-0.7830763824720766, 1.5490168447634234, 0.19819589306034507, 0.48161376192033867, 6.756848157144321, 5.01930509634567]
0  datapoints sampled
1000  datapoints sampled
2000  datapoints sampled
Mean rt for current run: 
0.7186211999999851
differential_evolution step 1: f(x)= 1732.9
differential_evolution step 2: f(x)= 1717.6
differential_evolution step 3: f(x)= 1637.73
differential_evolution step 4: f(x)= 1637.73
differential_evolution step 5: f(x)= 1621.76
differential_evolution step 6: f(x)= 1621.76
differential_evolution step 7: f(x)= 1607.92
differential_evolution step 8: f(x)= 1599.18
differential_evolution step 9: f(x)= 1485.39
differential_evolution step 10: f(x)= 1485.39
differential_evolution step 11: f(x)= 1485.39
differential_evolution step 12: f(x)= 1485.39
differential_evolution step 13: f(x)= 1442.56
differential_evolution step 14: f(x)= 1442.56
differential_evolution step 15: f(x)= 1442.56
differential_evolution step 16: f(x)= 1442.56
differential_e

differential_evolution step 23: f(x)= -1756.3
differential_evolution step 24: f(x)= -1756.3
differential_evolution step 25: f(x)= -1756.3
differential_evolution step 26: f(x)= -1850.62
differential_evolution step 27: f(x)= -1850.62
differential_evolution step 28: f(x)= -1923.64
differential_evolution step 29: f(x)= -1923.64
differential_evolution step 30: f(x)= -1923.64
differential_evolution step 31: f(x)= -1923.64
differential_evolution step 32: f(x)= -1923.64
differential_evolution step 33: f(x)= -1923.64
differential_evolution step 34: f(x)= -2185.28
differential_evolution step 35: f(x)= -2185.28
differential_evolution step 36: f(x)= -2185.28
differential_evolution step 37: f(x)= -2185.28
differential_evolution step 38: f(x)= -2185.28
differential_evolution step 39: f(x)= -2185.28
differential_evolution step 40: f(x)= -2185.28
differential_evolution step 41: f(x)= -2185.28
differential_evolution step 42: f(x)= -2185.28
differential_evolution step 43: f(x)= -2210.2
differential_evol

differential_evolution step 5: f(x)= 5539.94
differential_evolution step 6: f(x)= 5539.94
differential_evolution step 7: f(x)= 5539.94
differential_evolution step 8: f(x)= 5539.94
differential_evolution step 9: f(x)= 5276.33
differential_evolution step 10: f(x)= 5276.33
differential_evolution step 11: f(x)= 5276.33
differential_evolution step 12: f(x)= 5276.33
differential_evolution step 13: f(x)= 5246.02
differential_evolution step 14: f(x)= 5246.02
differential_evolution step 15: f(x)= 5246.02
differential_evolution step 16: f(x)= 5246.02
differential_evolution step 17: f(x)= 5215.33
differential_evolution step 18: f(x)= 5215.33
differential_evolution step 19: f(x)= 5215.33
differential_evolution step 20: f(x)= 5215.33
differential_evolution step 21: f(x)= 5213.47
differential_evolution step 22: f(x)= 5213.47
differential_evolution step 23: f(x)= 5213.47
differential_evolution step 24: f(x)= 5192.03
differential_evolution step 25: f(x)= 5179.86
differential_evolution step 26: f(x)= 5

differential_evolution step 19: f(x)= 6313.78
differential_evolution step 20: f(x)= 6313.78
differential_evolution step 21: f(x)= 6313.78
differential_evolution step 22: f(x)= 6313.78
differential_evolution step 23: f(x)= 6313.78
differential_evolution step 24: f(x)= 6313.78
differential_evolution step 25: f(x)= 6313.78
differential_evolution step 26: f(x)= 6313.78
differential_evolution step 27: f(x)= 6313.78
differential_evolution step 28: f(x)= 6313.78
differential_evolution step 29: f(x)= 6313.78
differential_evolution step 30: f(x)= 6313.78
Solution vector of current run: 
[ 0.46895685  2.4385509   0.5743943   4.55674727 48.26313149  3.99761062]
The run took: 
00:07:08
Parameters for run 9: 
[-1.1313853591120977, 3.3642806424246063, 0.4378170405920663, 0.010048813190230987, 12.579603245471427, 7.47555053347391]
0  datapoints sampled
1000  datapoints sampled
2000  datapoints sampled
Mean rt for current run: 
2.6337307999999005
differential_evolution step 1: f(x)= 3868.06
differenti

differential_evolution step 89: f(x)= 572.086
differential_evolution step 90: f(x)= 572.086
differential_evolution step 91: f(x)= 572.086
differential_evolution step 92: f(x)= 570.215
differential_evolution step 93: f(x)= 570.215
differential_evolution step 94: f(x)= 570.215
differential_evolution step 95: f(x)= 570.215
differential_evolution step 96: f(x)= 570.215
Solution vector of current run: 
[-2.23360401  2.689136    0.40255245  0.          1.11125081  0.11115329]
The run took: 
00:12:22
Parameters for run 11: 
[0.8036506318681691, 3.0895432719830143, 0.8235703748561655, 4.9603481899049005, 45.27265242964932, 2.055496867925082]
0  datapoints sampled
1000  datapoints sampled
2000  datapoints sampled
Mean rt for current run: 
1.359040799999969
differential_evolution step 1: f(x)= 2896.91
differential_evolution step 2: f(x)= 2888.08
differential_evolution step 3: f(x)= 2861.25
differential_evolution step 4: f(x)= 2697.72
differential_evolution step 5: f(x)= 2697.72
differential_evol

differential_evolution step 79: f(x)= 2335.66
differential_evolution step 80: f(x)= 2335.66
differential_evolution step 81: f(x)= 2335.66
differential_evolution step 82: f(x)= 2335.66
differential_evolution step 83: f(x)= 2335.66
differential_evolution step 84: f(x)= 2335.66
differential_evolution step 85: f(x)= 2335.66
differential_evolution step 86: f(x)= 2335.66
differential_evolution step 87: f(x)= 2335.66
differential_evolution step 88: f(x)= 2334.89
Solution vector of current run: 
[-1.29766878  1.72155459  0.53983242  0.00457687  1.11758388  0.13237019]
The run took: 
00:11:52
Parameters for run 13: 
[0.4025859148344031, 2.044230030224166, 0.5966915884154042, 1.589312151257121, 39.077639904940234, 5.02264046718278]
0  datapoints sampled
1000  datapoints sampled
2000  datapoints sampled
Mean rt for current run: 
3.0150963999999525
differential_evolution step 1: f(x)= 6341
differential_evolution step 2: f(x)= 6144.71
differential_evolution step 3: f(x)= 6123.63
differential_evolut

0  datapoints sampled
1000  datapoints sampled
2000  datapoints sampled
Mean rt for current run: 
0.33310239999999947
differential_evolution step 1: f(x)= -123.215
differential_evolution step 2: f(x)= -123.215
differential_evolution step 3: f(x)= -123.215
differential_evolution step 4: f(x)= -123.215
differential_evolution step 5: f(x)= -127.025
differential_evolution step 6: f(x)= -127.025
differential_evolution step 7: f(x)= -127.025
differential_evolution step 8: f(x)= -127.025
differential_evolution step 9: f(x)= -326.449
differential_evolution step 10: f(x)= -326.449
differential_evolution step 11: f(x)= -326.449
differential_evolution step 12: f(x)= -492.355
differential_evolution step 13: f(x)= -492.355
differential_evolution step 14: f(x)= -536.666
differential_evolution step 15: f(x)= -736.37
differential_evolution step 16: f(x)= -752.631
differential_evolution step 17: f(x)= -762.282
differential_evolution step 18: f(x)= -762.282
differential_evolution step 19: f(x)= -787.195

differential_evolution step 46: f(x)= 2438.37
differential_evolution step 47: f(x)= 2438.37
differential_evolution step 48: f(x)= 2438.37
differential_evolution step 49: f(x)= 2432.05
differential_evolution step 50: f(x)= 2432.05
differential_evolution step 51: f(x)= 2432.05
differential_evolution step 52: f(x)= 2432.05
differential_evolution step 53: f(x)= 2432.05
Solution vector of current run: 
[0.68757305 0.8250717  0.21107893 0.00269538 1.12497057 0.13065326]
The run took: 
00:07:34
Parameters for run 20: 
[-0.007831967113009597, 3.118972076359025, 0.6055396285989055, 1.1994108340304133, 12.835839015739035, 8.200486233628403]
0  datapoints sampled
1000  datapoints sampled
2000  datapoints sampled
Mean rt for current run: 
8.798052799999827
differential_evolution step 1: f(x)= 9509.65
differential_evolution step 2: f(x)= 9356.69
differential_evolution step 3: f(x)= 9340.84
differential_evolution step 4: f(x)= 9278.89
differential_evolution step 5: f(x)= 9278.89
differential_evoluti

differential_evolution step 17: f(x)= 2215.94
differential_evolution step 18: f(x)= 2215.94
differential_evolution step 19: f(x)= 2215.94
differential_evolution step 20: f(x)= 2215.94
differential_evolution step 21: f(x)= 2215.94
differential_evolution step 22: f(x)= 2215.94
differential_evolution step 23: f(x)= 2215.94
differential_evolution step 24: f(x)= 2215.94
differential_evolution step 25: f(x)= 2215.94
differential_evolution step 26: f(x)= 2215.94
differential_evolution step 27: f(x)= 2215.94
differential_evolution step 28: f(x)= 2203.7
differential_evolution step 29: f(x)= 2203.7
differential_evolution step 30: f(x)= 2203.7
differential_evolution step 31: f(x)= 2203.7
differential_evolution step 32: f(x)= 2174.55
differential_evolution step 33: f(x)= 2174.55
differential_evolution step 34: f(x)= 2174.55
differential_evolution step 35: f(x)= 2174.55
differential_evolution step 36: f(x)= 2174.55
differential_evolution step 37: f(x)= 2174.55
differential_evolution step 38: f(x)= 

differential_evolution step 7: f(x)= 4937.48
differential_evolution step 8: f(x)= 4937.48
differential_evolution step 9: f(x)= 4937.48
differential_evolution step 10: f(x)= 4937.48
differential_evolution step 11: f(x)= 4937.48
differential_evolution step 12: f(x)= 4937.48
differential_evolution step 13: f(x)= 4937.48
differential_evolution step 14: f(x)= 4937.48
differential_evolution step 15: f(x)= 4934.05
differential_evolution step 16: f(x)= 4934.05
differential_evolution step 17: f(x)= 4934.05
differential_evolution step 18: f(x)= 4934.05
differential_evolution step 19: f(x)= 4934.05
differential_evolution step 20: f(x)= 4922.02
differential_evolution step 21: f(x)= 4922.02
differential_evolution step 22: f(x)= 4922.02
differential_evolution step 23: f(x)= 4922.02
differential_evolution step 24: f(x)= 4909.3
differential_evolution step 25: f(x)= 4909.3
differential_evolution step 26: f(x)= 4904
differential_evolution step 27: f(x)= 4904
differential_evolution step 28: f(x)= 4904
di

differential_evolution step 57: f(x)= 18.8927
differential_evolution step 58: f(x)= 18.8927
differential_evolution step 59: f(x)= 18.8927
differential_evolution step 60: f(x)= 18.8927
differential_evolution step 61: f(x)= 18.8012
differential_evolution step 62: f(x)= 10.6182
differential_evolution step 63: f(x)= 10.6182
differential_evolution step 64: f(x)= 9.91019
differential_evolution step 65: f(x)= 9.91019
differential_evolution step 66: f(x)= 9.91019
differential_evolution step 67: f(x)= 9.91019
differential_evolution step 68: f(x)= 9.91019
differential_evolution step 69: f(x)= 9.91019
differential_evolution step 70: f(x)= 1.69285
differential_evolution step 71: f(x)= 1.69285
differential_evolution step 72: f(x)= 1.69285
differential_evolution step 73: f(x)= 1.69285
differential_evolution step 74: f(x)= 1.69285
differential_evolution step 75: f(x)= 1.69285
differential_evolution step 76: f(x)= -3.04965
differential_evolution step 77: f(x)= -3.04965
differential_evolution step 78: 

0  datapoints sampled
1000  datapoints sampled
2000  datapoints sampled
Mean rt for current run: 
2.367046799999898
differential_evolution step 1: f(x)= 4601.54
differential_evolution step 2: f(x)= 4426.03
differential_evolution step 3: f(x)= 4426.03
differential_evolution step 4: f(x)= 4426.03
differential_evolution step 5: f(x)= 4071.22
differential_evolution step 6: f(x)= 4071.22
differential_evolution step 7: f(x)= 4071.22
differential_evolution step 8: f(x)= 4046.12
differential_evolution step 9: f(x)= 3998.8
differential_evolution step 10: f(x)= 3998.8
differential_evolution step 11: f(x)= 3998.8
differential_evolution step 12: f(x)= 3998.8
differential_evolution step 13: f(x)= 3998.8
differential_evolution step 14: f(x)= 3998.8
differential_evolution step 15: f(x)= 3998.8
differential_evolution step 16: f(x)= 3998.8
differential_evolution step 17: f(x)= 3998.8
differential_evolution step 18: f(x)= 3998.8
differential_evolution step 19: f(x)= 3998.8
differential_evolution step 20

differential_evolution step 10: f(x)= 1662.67
differential_evolution step 11: f(x)= 1662.67
differential_evolution step 12: f(x)= 1662.21
differential_evolution step 13: f(x)= 1662.21
differential_evolution step 14: f(x)= 1661.4
differential_evolution step 15: f(x)= 1658.12
differential_evolution step 16: f(x)= 1476.26
differential_evolution step 17: f(x)= 1343.49
differential_evolution step 18: f(x)= 1343.49
differential_evolution step 19: f(x)= 1343.49
differential_evolution step 20: f(x)= 1343.49
differential_evolution step 21: f(x)= 1343.49
differential_evolution step 22: f(x)= 1343.49
differential_evolution step 23: f(x)= 1342.33
differential_evolution step 24: f(x)= 1342.33
differential_evolution step 25: f(x)= 1342.33
differential_evolution step 26: f(x)= 1342.33
differential_evolution step 27: f(x)= 1342.33
differential_evolution step 28: f(x)= 1342.33
differential_evolution step 29: f(x)= 1342.33
differential_evolution step 30: f(x)= 1342.33
differential_evolution step 31: f(x

differential_evolution step 43: f(x)= 2937.31
differential_evolution step 44: f(x)= 2937.31
differential_evolution step 45: f(x)= 2927.6
differential_evolution step 46: f(x)= 2922.71
differential_evolution step 47: f(x)= 2922.71
differential_evolution step 48: f(x)= 2922.71
differential_evolution step 49: f(x)= 2920.31
differential_evolution step 50: f(x)= 2918.04
differential_evolution step 51: f(x)= 2916.1
differential_evolution step 52: f(x)= 2913.65
differential_evolution step 53: f(x)= 2913.65
differential_evolution step 54: f(x)= 2913.65
differential_evolution step 55: f(x)= 2913.65
differential_evolution step 56: f(x)= 2913.65
Solution vector of current run: 
[-1.6035572   3.96324367  0.55745464  0.01292363  1.14478234  0.11792747]
The run took: 
00:09:38
Parameters for run 40: 
[-1.3887596776615423, 3.995904392269515, 0.25768710503440895, 2.7123735885062916, 25.079899356045054, 4.756673431132509]
0  datapoints sampled
1000  datapoints sampled
2000  datapoints sampled
Mean rt fo

differential_evolution step 50: f(x)= 995.905
differential_evolution step 51: f(x)= 989.604
differential_evolution step 52: f(x)= 989.604
differential_evolution step 53: f(x)= 989.604
differential_evolution step 54: f(x)= 989.604
differential_evolution step 55: f(x)= 989.604
differential_evolution step 56: f(x)= 989.604
differential_evolution step 57: f(x)= 989.604
differential_evolution step 58: f(x)= 989.604
differential_evolution step 59: f(x)= 989.604
differential_evolution step 60: f(x)= 989.604
differential_evolution step 61: f(x)= 989.604
differential_evolution step 62: f(x)= 989.604
differential_evolution step 63: f(x)= 989.604
differential_evolution step 64: f(x)= 989.604
differential_evolution step 65: f(x)= 989.604
differential_evolution step 66: f(x)= 989.604
differential_evolution step 67: f(x)= 989.604
differential_evolution step 68: f(x)= 989.604
differential_evolution step 69: f(x)= 989.604
differential_evolution step 70: f(x)= 989.022
differential_evolution step 71: f(

In [None]:
# Read in results
# NOTE(review): this reads a different experiment directory
# ('ddm_flexbound_..._arange_2_3') than 'kde_ddm_weibull_mle', which the main
# loop above writes. The cells below reference c1/c2 columns, which the
# Weibull result table (v, a, w, node, shape, scale) does not have. Confirm
# which experiment should be analysed here.
results_csv = os.getcwd() + '/experiments/ddm_flexbound_kde_mle_fix_v_0_c1_0_w_unbiased_arange_2_3' + '/optim_results.csv'
optim_results = pd.read_csv(results_csv)

In [None]:
# Recovery of v: simulated (x) vs. estimated (y), coloured by the c2 estimate
plt.scatter(optim_results['v_sim'], optim_results['v_mle'], c = optim_results['c2_mle'])
plt.colorbar(label = 'c2_mle')
plt.xlabel('v_sim')
plt.ylabel('v_mle')
plt.title('Recovery of v');

In [None]:
# Regression for v: fit simulated values on the MLEs and report the R^2 score
x_mle = optim_results[['v_mle']].values
y_sim = optim_results[['v_sim']].values
reg = LinearRegression().fit(x_mle, y_sim)
reg.score(x_mle, y_sim)

In [None]:
# Recovery of a: simulated (x) vs. estimated (y), coloured by the c2 estimate
plt.scatter(optim_results['a_sim'], optim_results['a_mle'], c = optim_results['c2_mle'])
plt.colorbar(label = 'c2_mle')
plt.xlabel('a_sim')
plt.ylabel('a_mle')
plt.title('Recovery of a');

In [None]:
# Regression for a: fit simulated values on the MLEs and report the R^2 score
x_mle = optim_results[['a_mle']].values
y_sim = optim_results[['a_sim']].values
reg = LinearRegression().fit(x_mle, y_sim)
reg.score(x_mle, y_sim)

In [None]:
# Recovery of w: simulated (x) vs. estimated (y)
plt.scatter(optim_results['w_sim'], optim_results['w_mle'])
plt.xlabel('w_sim')
plt.ylabel('w_mle')
plt.title('Recovery of w');

In [None]:
# Regression for w: fit simulated values on the MLEs and report the R^2 score
x_mle = optim_results[['w_mle']].values
y_sim = optim_results[['w_sim']].values
reg = LinearRegression().fit(x_mle, y_sim)
reg.score(x_mle, y_sim)

In [None]:
# Recovery of c1: simulated (x) vs. estimated (y)
plt.scatter(optim_results['c1_sim'], optim_results['c1_mle'])
plt.xlabel('c1_sim')
plt.ylabel('c1_mle')
plt.title('Recovery of c1');

In [None]:
# Regression for c1: fit simulated values on the MLEs and report the R^2 score
x_mle = optim_results[['c1_mle']].values
y_sim = optim_results[['c1_sim']].values
reg = LinearRegression().fit(x_mle, y_sim)
reg.score(x_mle, y_sim)

In [None]:
# Recovery of c2: simulated (x) vs. estimated (y), coloured by the a estimate
plt.scatter(optim_results['c2_sim'], optim_results['c2_mle'], c = optim_results['a_mle'])
plt.colorbar(label = 'a_mle')
plt.xlabel('c2_sim')
plt.ylabel('c2_mle')
plt.title('Recovery of c2');

In [None]:
# Regression for c2: fit simulated values on the MLEs and report the R^2 score
x_mle = optim_results[['c2_mle']].values
y_sim = optim_results[['c2_sim']].values
reg = LinearRegression().fit(x_mle, y_sim)
reg.score(x_mle, y_sim)