In [None]:
import pickle
import os
import glob
from pathlib import Path
import matplotlib.pyplot as plt

import evaluator

In [None]:
# Model under evaluation; its files are expected in ./<model_name>_model.
model_name = 'cultured'
model_folder = Path(f"{model_name}_model").absolute()

# Not used in the cells shown here; kept for cells that may reference it.
probe_type = ""

In [None]:
# Build the evaluator for the selected model using its "soma" feature set.
# The feature file is derived from model_folder (set in the config cell) so it
# always follows model_name instead of a hard-coded "cultured" path.
eva = evaluator.create_evaluator(
    model_name=model_name,
    feature_set="soma",
    sample_id=None,
    feature_file=str(model_folder / "features.json"),
)

eva.fitness_protocols

In [None]:
# Load every optimisation checkpoint and re-simulate its hall-of-fame entry 0
# (stored as 'best_params') so the resulting traces can be plotted below.
# NOTE(review): pickle.load can execute arbitrary code; only load checkpoints
# produced by your own optimisation runs.
checkpoint_pattern = "./optimization_results/checkpoints/*.tmp"
runs = []

# sorted() makes the run order (and hence plot order) reproducible; glob's
# order depends on the filesystem.
for path in sorted(glob.glob(checkpoint_pattern)):
    print(path)

    # Context manager closes the checkpoint file instead of leaking the handle.
    with open(path, "rb") as checkpoint_file:
        run = pickle.load(checkpoint_file)

    run['best_params'] = eva.param_dict(run['halloffame'][0])
    run['responses'] = eva.run_protocols(
        protocols=eva.fitness_protocols.values(),
        param_values=run['best_params'],
    )

    runs.append(run)

if not runs:
    # Without this, every downstream cell silently produces no output.
    print(f"No checkpoints matched {checkpoint_pattern}")

In [None]:
def plot_responses(responses, figsize=None, apwaveform_xlim=(450, 700)):
    """Plot each response trace in its own vertically stacked panel.

    Parameters
    ----------
    responses : dict
        Maps a response name to a mapping with ``'time'`` and ``'voltage'``
        sequences (e.g. the output of ``eva.run_protocols``). Panels are drawn
        in sorted-name order. An empty mapping plots nothing.
    figsize : tuple of float, optional
        Figure size in inches. By default the figure is 8 wide and
        2.5 inches tall per panel (at least 5), so many panels are not
        squeezed into a fixed height.
    apwaveform_xlim : tuple of float or None, optional
        x-axis limits for panels whose name contains ``"APWaveform"``.
        ``None`` leaves those panels auto-scaled.
    """
    if not responses:
        # plt.subplots(0) raises, so there is nothing sensible to draw.
        return

    n_panels = len(responses)
    if figsize is None:
        figsize = (8, max(5, 2.5 * n_panels))

    fig, axes = plt.subplots(n_panels, figsize=figsize, squeeze=False)
    for index, resp_name in enumerate(sorted(responses)):
        response = responses[resp_name]
        ax = axes[index, 0]
        ax.plot(response['time'], response['voltage'])
        ax.set_title(resp_name)
        ax.set_ylabel('voltage')

        # Zoom in on the relevant window for action-potential waveform protocols.
        if apwaveform_xlim is not None and "APWaveform" in resp_name:
            ax.set_xlim(*apwaveform_xlim)

    # Panels share a layout, so only the bottom one needs an x label.
    axes[-1, 0].set_xlabel('time')

    fig.tight_layout()
    plt.show()

In [None]:
import pprint

# Set to True to also print the best parameter set of each run.
show_params = False

# Objective names come from the evaluator and are the same for every run,
# so look them up once instead of on every iteration.
feature_names = [obj.name for obj in eva.fitness_calculator.objectives]

for run in runs:
    fitness_values = run['halloffame'][0].fitness.values

    # zip() would silently truncate and mislabel scores if the checkpoint's
    # fitness vector does not match the evaluator's objective list.
    assert len(fitness_values) == len(feature_names), (
        f"checkpoint has {len(fitness_values)} fitness values but the "
        f"evaluator defines {len(feature_names)} objectives"
    )
    scores = dict(zip(feature_names, fitness_values))

    print()
    print("\nFitness: {}".format(sum(fitness_values)))
    print("\nScores: ")
    pprint.pprint(scores)

    if show_params:
        print("Parameters:")
        pprint.pprint(run['best_params'])

    plot_responses(run['responses'])

In [None]:
import numpy
from scipy.ndimage import gaussian_filter1d

# Set to a positive number (e.g. 2) to smooth each curve with a Gaussian of
# that standard deviation, measured in logbook entries; None plots the raw values.
smooth_sigma = None

fig, ax = plt.subplots()
for run_index, run in enumerate(runs):
    # x: running total of the logbook's "nevals" column; y: its "min" column.
    evaluations = numpy.cumsum(run['logbook'].select("nevals"))
    min_fitness = run['logbook'].select("min")
    if smooth_sigma:
        # axis=0 smooths along the logbook entries even if "min" is per-objective.
        min_fitness = gaussian_filter1d(
            numpy.asarray(min_fitness, dtype=float), smooth_sigma, axis=0
        )
    ax.plot(evaluations, min_fitness, label=f"run {run_index}")

ax.set_title("Optimisation convergence")
ax.set_xlabel("cumulative evaluations (sum of nevals)")
ax.set_ylabel("min fitness")
if runs:
    # legend() warns when there are no labelled lines.
    ax.legend()
plt.show()

In [None]:
import logging

# Name used to recognise the handler this cell installs, so re-running the
# cell replaces it rather than stacking duplicates.
LOG_HANDLER_NAME = "notebook-debug-stream"

# Send DEBUG-and-above records from the root logger to a stream handler.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# Without this, every re-execution would add another handler and every log
# record would be printed once per run of the cell.
for existing in [h for h in logger.handlers if h.get_name() == LOG_HANDLER_NAME]:
    logger.removeHandler(existing)

ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
ch.set_name(LOG_HANDLER_NAME)
logger.addHandler(ch)