In [None]:
from os import makedirs
import numpy as np
import torch
from matplotlib import pyplot as plt
import matplotlib as mpl
from lib.HH_model.parameters_info import parameters_initial0, parameters_range_bounds, parameters_lower_bound, parameters_upper_bound, parameter_names
from lib.drffit.uniform_sampler import uniform_around_sampler as uniform_sampler
from lib.HH_model.simulator import HH_simulator
import joblib
import pickle as pkl
from datetime import datetime
from lib.utils import *
from copy import deepcopy as dcp
# --- Global plotting / device configuration -------------------------------
# `set_mpl` is not imported by name; presumably it comes from the
# `from lib.utils import *` star import -- TODO confirm.
set_mpl()
device = 'cpu'
# Applied after set_mpl() so this font size takes precedence over whatever
# the helper configured.
mpl.rcParams['font.size'] = 12

# Run one short simulation with the initial parameter set. The outputs are
# not used by the visible cells; presumably this is a smoke test of the
# simulator -- TODO confirm.
V_test, I_test = HH_simulator(
    np.array([parameters_initial0]),
    length = 0.01,
    dt = 0.01,
)

# --- Parameter-space description used by the sampler ------------------------
# Centre of the admissible box: lower bound plus half of the range.
midpoint = parameters_lower_bound + parameters_range_bounds/2
theta_min = parameters_lower_bound
# Both names refer to the same range object.
theta_range = range_bounds = parameters_range_bounds

# --- Table of parameter index, name, lower bound and upper bound ------------
print('Parameters:\n')
for idx, name in enumerate(parameter_names):
    lower = parameters_lower_bound[idx]
    upper = parameters_upper_bound[idx]
    print('\t',idx,'\t',name,':    \t',lower,'   \t',upper)

# Simulate samples from a sample space

In [None]:
# --- Parallelism -------------------------------------------------------------
# Upper limit on the number of joblib workers.
max_njobs = 10

# --- Output / bookkeeping ----------------------------------------------------
data_path = '../Data/HH/targets/'
file_name = 'trial_1'
# When True, the raw simulations and their parameters are stored in the log.
save_sample_data = True

# --- Simulation settings -----------------------------------------------------
# NOTE(review): units of `dt` and `length` are whatever HH_simulator expects;
# they are not visible from this notebook -- TODO confirm.
dt = 0.01
length = 0.05
# Not referenced by any visible cell.
cutoff = 10000 # 100ms (resolution of 0.01ms)
# Samples are drawn in `num_chunks` batches of `chunk_size` parameter sets;
# each batch becomes one parallel job.
chunk_size = 50
num_chunks = 500

# --- Sampler settings ----------------------------------------------------------
search_width = 1.0
point = midpoint
sampler_fn = uniform_sampler(
    theta_min,
    theta_range = theta_range,
    sample_distribution='cube',
)
sampler_fn.set_state(point = point, width=search_width)

# --- Runtime tracking ----------------------------------------------------------
# Placeholder timedelta; the real runtime is computed after the simulations.
runtime = datetime.now()-datetime.now()
st_time = datetime.now()
st_time_string = st_time.strftime('%D, %H:%M:%S')
print(f'Start time: {st_time_string}')

# --- Draw the parameter samples ------------------------------------------------
# One sampler call per chunk; `all_samples` is the same data concatenated
# along dim 0 for inspection/plotting.
parameters_samples = [
    sampler_fn.sample((chunk_size,))
    for _ in range(num_chunks)
]
all_samples = torch.cat(parameters_samples, dim = 0)
print(len(parameters_samples))

# Never start more workers than there are chunks, nor more than the cap.
n_jobs = min(num_chunks, max_njobs)
print(file_name)

In [None]:
# Violin plot of every sampled parameter, one panel per parameter, with the
# y-axis fixed to that parameter's admissible range.
#
# The bounds come from `parameters_lower_bound` / `parameters_upper_bound`
# (imported at the top of the notebook). The previous version referenced
# `lower_bound` / `upper_bound`, which are not defined anywhere in the
# notebook and therefore only worked with leftover kernel state.
n_rows, n_cols = 2, 6
# Convert once instead of once per panel.
samples_np = ensure_numpy(all_samples)

fig = plt.figure(figsize=(14,6))
for i in range(n_rows * n_cols):
    ax = plt.subplot(n_rows, n_cols, i+1)
    ax.violinplot(samples_np[:,i])
    ax.set_ylim([parameters_lower_bound[i], parameters_upper_bound[i]])
    ax.set_title(f"{parameter_names[i]}")
plt.suptitle('Distribution of parameter values over the samples')
plt.tight_layout()
plt.show()

In [None]:
# A single seed is drawn per run and handed to every chunk's simulation; it is
# stored in the log below.
noise_seed = np.random.randint(0, 2**16)

# Simulate each chunk of parameter sets as one joblib task.
simulate_chunk = joblib.delayed(HH_simulator)
results = joblib.Parallel(n_jobs=n_jobs, verbose = 1)(
    simulate_chunk(chunk, length = length, dt=dt, noise_seed=noise_seed)
    for chunk in parameters_samples
)
# Group the simulations: each task returns a pair, split here into two tuples.
time_series_V, I = zip(*results)

# Concatenate the per-chunk outputs and their parameters along dim 0.
simulated_samples = torch.cat(time_series_V, dim = 0)
parameters_samples_good = torch.cat(parameters_samples, dim = 0)
makedirs(data_path, exist_ok = True)

# Define the log_info dict: run settings first, optional raw data afterwards.
log_info = {
    'dt': dt,
    'length': length,
    'chunk_size': chunk_size,
    'num_chunks': num_chunks,
    'total_simulations': chunk_size*num_chunks,
    'valid_simulations': simulated_samples.shape[0],
    'noise_seed': noise_seed,
    'search_width': search_width,
    'point': point,
}
if save_sample_data:
    log_info['data'] = {
        'x': simulated_samples,
        'theta': parameters_samples_good,
        'stats': None,
    }
log_info['message'] = 'Searching around known point with narrow range for "easy" test'
log_info['date'] = datetime.now().strftime('%D, %H:%M:%S')

# Wall-clock duration since `st_time`, which was set in the configuration cell.
runtime = datetime.now()-st_time
f_time_string = datetime.now().strftime('%D, %H:%M:%S')
runtime_string = str(runtime)

print(f'Finish time: {f_time_string}, Runtime: {runtime_string}')
print(f"Simulations shape: {simulated_samples.shape[0]}, {simulated_samples.shape[1]}")
print(f"Parameters shape: {parameters_samples_good.shape[0]}, {parameters_samples_good.shape[1]}\n")
print("Log info:")
data_info(log_info)

# Identify good targets

In [None]:
from sklearn.cluster import KMeans


def _stack_cluster_rows(rows):
    """Concatenate a list of (1, ...) tensors along dim 0.

    An empty list is returned unchanged, so a cluster that received no
    samples stays `[]` instead of raising inside `torch.cat`.
    """
    return torch.cat(rows, dim = 0) if rows else rows


# Cluster the simulated traces.
# NOTE(review): no `random_state` is passed, so the clustering can differ
# between runs.
n_clusters=25
kmeans = KMeans(n_clusters=n_clusters, tol = 1e-8).fit(simulated_samples)
# Cluster indices ordered by the maximum value of each centre, largest first.
sorted_cluster_id = np.flip(np.argsort(kmeans.cluster_centers_.max(1)))
cluster_ids = kmeans.predict(simulated_samples)

# Group every simulation, and the parameters that produced it, by cluster id.
clustered_x = [[] for _ in range(n_clusters)]
clustered_theta = [[] for _ in range(n_clusters)]
for sample_idx, cluster_id in enumerate(cluster_ids):
    clustered_x[cluster_id].append(simulated_samples[sample_idx].unsqueeze(0))
    clustered_theta[cluster_id].append(parameters_samples_good[sample_idx].unsqueeze(0))

# Turn each per-cluster list into one tensor. The empty case is handled
# explicitly rather than with a bare `except`, so genuine errors are no
# longer silently swallowed.
clustered_x = [_stack_cluster_rows(rows) for rows in clustered_x]
clustered_theta = [_stack_cluster_rows(rows) for rows in clustered_theta]

# Centres in the sorted order above. Note that `clustered_x` and
# `clustered_theta` stay indexed by the original (unsorted) cluster id.
centers = kmeans.cluster_centers_[sorted_cluster_id]

In [None]:
# Quick sanity check: print the shape of the entry stored for cluster id 2.
print(clustered_x[2].shape)

In [None]:
# Plot up to two example traces from every cluster, one panel per cluster.
num_col = 5
n_rows = len(clustered_x)//num_col + 1
f = plt.figure(figsize=(30,3.5*(len(centers)//num_col + 1)))
# Time axis for the plotted window; also reused by a later cell.
t = np.arange(int(55/dt))*dt
# Timeseries V
for i, cluster in enumerate(clustered_x):
    # A cluster that received no samples is left as an empty list by the
    # clustering cell; there is nothing to draw for it.
    if len(cluster) == 0:
        continue
    ax = plt.subplot(n_rows, num_col, 1+i)
    # Clip the window to the available samples so x and y always have the
    # same length, even if a trace is shorter than the requested window.
    n_steps = min(len(t), cluster.shape[1])
    ax.plot(t[:n_steps], cluster[0, :n_steps], label = "Item 0")
    # Explicit size check instead of a bare `except`: only clusters with a
    # second member get a second trace, and real plotting errors surface.
    if cluster.shape[0] > 1:
        ax.plot(t[:n_steps], cluster[1, :n_steps], label = "Item 1")
    ax.set_xlabel('Time [ms]')
    ax.set_ylabel('Voltage [mV]')
    ax.legend()
    ax.set_title(f'Cluster {i} with {cluster.shape[0]} items')
plt.suptitle('Time series of cluster centers\n')

plt.tight_layout()
plt.show()

In [None]:
# For every cluster, pick the members closest to and furthest from the cluster
# centre (as ranked by `correlation_loss_fn`) as fitting targets, together
# with the parameters that generated them.
n_select = 10  # members taken from each end of the ranking
targets = []
thetas = []
for i, cluster in enumerate(clustered_x):
    # A cluster that received no samples is left as an empty list by the
    # clustering cell; it has no members to select.
    if len(cluster) == 0:
        continue
    center = ensure_torch(kmeans.cluster_centers_[i]).view(1,-1)
    # assumes correlation_loss_fn returns one value per row of `cluster`
    # (a 1-D tensor) -- TODO confirm against lib.utils
    dist_from_center = correlation_loss_fn(center, cluster)
    sorted_by_dist = torch.argsort(dist_from_center)
    n_items = cluster.shape[0]
    closest = sorted_by_dist[:n_select]
    # Start the "furthest" slice no earlier than index n_select. For clusters
    # with at least 2*n_select members this equals sorted_by_dist[-n_select:];
    # for smaller clusters it prevents the same member from being selected
    # twice (the two slices used to overlap and produce duplicate targets).
    furthest = sorted_by_dist[max(n_select, n_items - n_select):]
    for picked in (closest, furthest):
        targets.append(cluster[picked])
        thetas.append(clustered_theta[i][picked])
targets = torch.cat(targets, dim = 0)
thetas = torch.cat(thetas, dim = 0)

In [None]:
# Show the final shapes of the selected targets and their parameters.
print(targets.shape, thetas.shape, sep='\n')

# Attach the clustering results to the run log (this replaces the message
# written by the simulation cell), then write the log to disk.
log_info.update({
    'kmeans': kmeans,
    'clusters': {'x': clustered_x, 'theta': clustered_theta},
    'targets': {'x': targets, 'theta': thetas},
    'message': 'Generating targets from extensive whole space search and selecting the 10 closest and 10 furthest of the 25 clusters of the simulations',
})
save_log(log_info, data_path, file_name)

In [None]:
# Store the selected targets on their own, together with a pointer to the
# full log they were derived from.
targets_only = {
    'targets': {'x': targets, 'theta': thetas},
    'source': {'path': data_path, 'file': file_name},
}
save_log(targets_only, data_path, f'{file_name}_targets')

In [None]:
# Overlay a handful of consecutive simulated traces, beginning at `start`.
start = 0
n_shown = 5
f = plt.figure(figsize=(20,10))
# Timeseries V
ax = plt.subplot(1,1,1)
# Build the time axis from the slice that is actually plotted. The previous
# version reused `t` from the cluster-plot cell, whose length (int(55/dt))
# only matches the y data when each trace has exactly that many samples.
n_steps = min(int(100/dt), simulated_samples.shape[1])
t_plot = np.arange(n_steps)*dt
for i in range(start, start + n_shown):
    ax.plot(t_plot, simulated_samples[i, :n_steps], label = f"Parameter set {i}")
plt.xlabel('Time [ms]')
plt.ylabel('Voltage')
plt.title('Time series V')
plt.legend()
plt.show()