In [None]:
from lib.HH_model.utils import get_targets
from os import makedirs
import numpy as np
import torch
from matplotlib import pyplot as plt
import matplotlib as mpl
from lib.HH_model.parameters_info import parameters_initial0, parameters_range_bounds, parameters_lower_bound, parameter_names
from lib.drffit.uniform_sampler import uniform_around_sampler as uniform_sampler
from lib.HH_model.simulator import HH_simulator
import joblib
import pickle as pkl
from datetime import datetime
from lib.utils import *
from copy import deepcopy as dcp
# Computation device for the loss helpers. A GPU alternative is kept for reference:
#   torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = 'cpu'
mpl.rcParams['font.size'] = 12

# Very short simulator run on the initial parameter set (smoke test).
V_test, I_test = HH_simulator(np.array([parameters_initial0]), length=0.01, dt=0.01)

# Sampling box per parameter: [lower, lower + range].
parameters_upper_bound = parameters_lower_bound + parameters_range_bounds
midpoint = parameters_lower_bound + parameters_range_bounds / 2
upper_bound = dcp(parameters_upper_bound)
lower_bound = dcp(parameters_lower_bound)

# Box handed to the sampler: its origin and its width along each parameter.
theta_min = lower_bound
theta_range = upper_bound - lower_bound
range_bounds = upper_bound - lower_bound

print('Parameters:\n')
for idx, name in enumerate(parameter_names):
    lo_value = parameters_lower_bound[idx]
    hi_value = parameters_upper_bound[idx]
    print('\t', idx, '\t', name, ':    \t', lo_value, '   \t', hi_value)

# Run counter appended to output file names; incremented after each save.
count = 0

In [None]:
max_njobs = 10
# Info settings
targets_file = 'trial_1'
glob_data_path = '../Data/HH/initialization/'

save_sample_data = True

# Simulation settings
dt = 0.01
length = 0.05
cutoff = 10000 # 100ms (resolution of 0.01ms); NOTE(review): not used in the cells shown here
chunk_size, num_chunks = 25, 400
sample_distribution = 'cube'
# Sampler settings
search_width = 1.0
point = midpoint
# Set to an int to seed NumPy's and torch's global RNGs before sampling (None keeps the
# previous unseeded behaviour). The noise seed drawn later via np.random.randint comes from
# NumPy's global RNG; whether the sampler draws from these RNGs is not visible here — TODO confirm.
sampling_seed = None
number_of_simulations = chunk_size * num_chunks
sampler_fn = uniform_sampler(theta_min, theta_range = theta_range, sample_distribution=sample_distribution)
sampler_fn.set_state(point = point, width=search_width)
file_name = f'{sample_distribution}{number_of_simulations}_{count}'

# To keep track of runtime. `st_time - st_time` is an exact zero timedelta; the previous
# `datetime.now() - datetime.now()` subtracted a later time from an earlier one and was
# typically a tiny negative duration.
st_time = datetime.now()
runtime = st_time - st_time
st_time_string = st_time.strftime('%D, %H:%M:%S')
print(f'Start time: {st_time_string}')

if sampling_seed is not None:
    np.random.seed(sampling_seed)
    torch.manual_seed(sampling_seed)

# Produce samples (one tensor per chunk) and keep a concatenated copy for plotting.
parameters_samples = [sampler_fn.sample((chunk_size,)) for _ in range(num_chunks)]
all_samples = torch.cat(parameters_samples, dim = 0)
print(len(parameters_samples))

# One joblib worker per chunk, capped at max_njobs.
n_jobs = min(num_chunks, max_njobs)
print(file_name)

In [None]:
# Violin plot of each sampled parameter, with the y-axis fixed to that parameter's bounds.
# Convert once instead of once per subplot.
samples_np = ensure_numpy(all_samples)
n_params = samples_np.shape[1]
n_cols = -(-n_params // 2)  # ceil(n_params / 2): two rows (2 x 6 for the 12-parameter model)
fig, axes = plt.subplots(2, n_cols, figsize=(14, 6), squeeze=False)
for i, ax in enumerate(axes.flat):
    if i >= n_params:
        # Odd parameter count: hide the unused trailing panel.
        ax.set_visible(False)
        continue
    ax.violinplot(samples_np[:, i])
    ax.set_ylim([lower_bound[i], upper_bound[i]])
    ax.set_title(f"{parameter_names[i]}")
fig.suptitle('Distribution of parameter values over the samples')
fig.tight_layout()
plt.show()

In [None]:
# Simulate every chunk in parallel. NOTE(review): every chunk receives the SAME noise_seed;
# if HH_simulator seeds its noise from it, chunks share noise realizations — confirm intended.
noise_seed = np.random.randint(0, 2**16)
results = joblib.Parallel(n_jobs=n_jobs, verbose=1)(
    joblib.delayed(HH_simulator)(parameters, length=length, dt=dt, noise_seed=noise_seed)
    for parameters in parameters_samples
)
# Group the simulations: HH_simulator returns a (V, I) pair per chunk.
time_series_V, I = zip(*results)

simulated_samples = torch.cat(time_series_V, dim=0)
# No filtering happens here, so these are all sampled parameters, in chunk order.
parameters_samples_good = torch.cat(parameters_samples, dim=0)

# Take one finish timestamp so the logged date, the runtime and the printed finish time agree.
finish_time = datetime.now()
f_time_string = finish_time.strftime('%D, %H:%M:%S')
runtime = finish_time - st_time
runtime_string = str(runtime)

# Define the log_info dict (key order kept stable for data_info's printout).
log_info = {
    'dt': dt,
    'length': length,
    'chunk_size': chunk_size,
    'num_chunks': num_chunks,
    'total_simulations': chunk_size * num_chunks,
    'valid_simulations': simulated_samples.shape[0],
    'noise_seed': noise_seed,
    'search_width': search_width,
    'point': point,
}
if save_sample_data:
    log_info['data'] = {'x': simulated_samples, 'theta': parameters_samples_good, 'stats': None}
# NOTE(review): confirm this message still matches the sampler settings (point, search_width).
log_info['message'] = 'Searching around known point with narrow range for "easy" test'
log_info['date'] = f_time_string

print(f'Finish time: {f_time_string}, Runtime: {runtime_string}')
print(f"Simulations shape: {simulated_samples.shape[0]}, {simulated_samples.shape[1]}")
print(f"Parameters shape: {parameters_samples_good.shape[0]}, {parameters_samples_good.shape[1]}\n")
print("Log info:")
data_info(log_info)

In [None]:
def mse_loss_fn_batch(target, candidates, batch = None, device = 'cuda'):
    """Compute the `mse` loss between one target trace and every candidate trace.

    Args:
        target: target trace; reshaped to (1, -1) so it broadcasts against each candidate row.
        candidates: candidate traces, one per row along dim 0.
        batch: None evaluates all candidates in a single call on their current device.
            Otherwise this is the number of chunks to split the candidates into. Each chunk
            (and the target) is moved to `device` before evaluation, and there is at least
            one row per chunk.
        device: device used in chunked mode; ignored when `batch` is None.

    Returns:
        The result of `mse` for all candidates. In chunked mode the per-chunk results are
        concatenated along dim 0 (assumes `mse` returns one value per candidate row —
        TODO confirm in lib.utils).
    """
    target = target.view(1, -1)
    loss_f = mse
    if batch is None:
        return loss_f(candidates, target)

    target = target.to(device)
    num_candidates = candidates.shape[0]
    if num_candidates == 0:
        # Nothing to chunk; avoid torch.cat on an empty list.
        return loss_f(candidates.to(device), target)

    # Ceil division yields at most `batch` chunks. The clamp keeps the step >= 1: the previous
    # floor division gave a step of 0 (ValueError from range) when batch > num_candidates.
    batch_size = max(1, -(-num_candidates // batch))
    loss = [
        loss_f(candidates[start:start + batch_size].to(device), target)
        for start in range(0, num_candidates, batch_size)
    ]
    return torch.cat(loss)

In [None]:
# Load the target traces (and their generating parameters) to be matched against the simulations.
target_x, target_theta = get_targets(file = targets_file)
print(f'Target shape: {target_x.shape},\t Candidates shape: {simulated_samples.shape}')
# Compare over the time steps both sides have. Previously only the targets were truncated,
# so a target shorter than the simulations broke the row-wise MSE broadcast.
common_length = min(target_x.shape[1], simulated_samples.shape[1])
target_x = target_x[:, :common_length]
if common_length == simulated_samples.shape[1]:
    train_x = simulated_samples
else:
    train_x = simulated_samples[:, :common_length]
train_theta = parameters_samples_good

In [None]:
from tqdm import tqdm
# For each target, find the simulated candidate with the lowest MSE, then rank targets from
# worst best-fit error to best.
all_errs = []
initial_points = []
loss_fn = 'mse'
sorted_target = []
for i in tqdm(range(target_x.shape[0])):
    # Pass the notebook's device explicitly: the helper defaults to 'cuda', which fails on
    # CPU-only machines even though this notebook selects device = 'cpu'.
    error = mse_loss_fn_batch(target_x[i].view(1,-1), train_x, batch = 1, device = device)
    err_index = torch.argmin(error).cpu()
    sorted_target.append({
            'theta':train_theta[err_index].view(1,-1),
            'x': train_x[err_index].cpu().view(1,-1),
            'error':error[err_index].cpu().view(1,-1),
            'target':target_x[i].cpu().view(1,-1),
            'target_theta':target_theta[i].view(1,-1),
            'real_index':torch.tensor(i).view(1,-1)
    })
    all_errs.append(error[err_index].item())
all_errs = torch.tensor(all_errs)
# Descending order: index 0 is the target whose best fit is worst.
sorted_errors = torch.argsort(all_errs, descending = True)
sorted_target_info = [sorted_target[i] for i in sorted_errors]

target_log = {
    'original': {'x': target_x.cpu(),'theta':target_theta},
    'sorted':{
        'data':{'source':glob_data_path, 'x': train_x.cpu(), 'theta':train_theta},
        'loss_fn': loss_fn,
        'indices':sorted_errors,
    }
}

# Stack the per-target records into one tensor per field (rows follow target order / error order).
fit_keys = ('theta', 'x', 'error', 'target', 'target_theta')
target_log['original']['fits'] = {
    key: torch.cat([entry[key] for entry in sorted_target], dim = 0)
    for key in fit_keys
}
target_log['sorted']['worst_info_all'] = {
    key: torch.cat([entry[key] for entry in sorted_target_info], dim = 0)
    for key in fit_keys + ('real_index',)
}
# Permutation that maps the error-sorted rows back to the original target order.
target_log['original']['reorder'] = torch.argsort(target_log['sorted']['worst_info_all']['real_index'][:,0])

In [None]:
worst_fits = target_log['sorted']['worst_info_all']
worst_errors = worst_fits['error']
print(worst_fits['x'].shape)
print(target_log['sorted']['data']['x'].shape)
print()
# Square roots of the stored best-fit losses (RMSE, assuming `mse` returns mean squared errors).
# Rows are sorted worst-first, so [:100] are the 100 worst-fitted targets.
print(torch.sqrt(worst_errors.mean()))
print(torch.sqrt(worst_errors[:100].mean()))
print(torch.sqrt(worst_errors[:100].min()))
print()
print(torch.sqrt(worst_errors.max()))
print(torch.sqrt(worst_errors.min()))
print()
# Sanity check: undoing the sort with 'reorder' must reproduce the original targets.
targets_round_trip = (worst_fits['target'][target_log['original']['reorder']] == target_log['original']['x']).all()
print(targets_round_trip)
assert targets_round_trip, "reorder does not map sorted targets back to the original order"
# Show the directory and name actually passed to save_log in the next cell (the old print
# pointed at a 'mse_loss/' subfolder that is never written).
print(glob_data_path, f'{file_name}_fits')

print()

In [None]:
# Persist the fit log under the data directory (creating it if needed), then advance the run counter.
makedirs(glob_data_path, exist_ok=True)
save_log(target_log, glob_data_path, f'{file_name}_fits', enforce_replace=False)
count = count + 1