In [None]:
from os import mkdir, makedirs
from datetime import datetime
from copy import deepcopy as dcp
import numpy as np
import torch
import joblib
import pickle as pkl
from lib.utils import *
from lib.Wilson_Cowan.simulators import WC_stochastic_heun_PSD
from lib.Wilson_Cowan.utils import *
from lib.Wilson_Cowan.parameters_info import parameters_alpha_peak, parameters_range_bounds, parameters_lower_bound,parameters_upper_bound, parameter_names, parameters_original
from lib.Wilson_Cowan.simulators import WC_stochastic_heun_PSD
from lib.drffit.uniform_sampler import uniform_around_sampler as uniform_sampler
from copy import deepcopy as dcp
make_cluster_reals = True

# List every model parameter together with its sampling bounds.
print('Parameters:\n')
for idx, name in enumerate(parameter_names):
    lo = parameters_lower_bound[idx]
    hi = parameters_upper_bound[idx]
    print('\t', idx, '\t', name, ':    \t', lo, '\t', hi)

# Deep copies, so the imported bound objects are not shared with this notebook.
lower_bound, upper_bound = dcp(parameters_lower_bound), dcp(parameters_upper_bound)
theta_min = lower_bound  # alias of lower_bound (same object), not a copy
theta_range = upper_bound - lower_bound
range_bounds = upper_bound - lower_bound

# Precompile the simulator with a tiny throw-away run before the real batch.
_, _, _ = WC_stochastic_heun_PSD(
    np.array([parameters_alpha_peak]),
    length=6,
    dt=1,
    noise_seed=12,
    get_psd_I=False,
    remove_bad=True,
)

from matplotlib import pyplot as plt
import matplotlib as mpl
from sklearn.cluster import KMeans

real = get_real_individual_PSD_scaled()
set_mpl()

In [None]:
# Run counter embedded in output file names (f'cube_{num_sim}_{count}');
# the final cell increments it after saving.
count = int()

In [None]:
# Output directory for the generated dataset and the fit logs; created if missing.
glob_data_path = '../Data/WC/initialization/'
makedirs(name=glob_data_path, exist_ok=True)

In [None]:
max_njobs = 10
message = 'Generating a dataset as init and to train DRFFIT full space'

# General simulation settings
dt = 1.
length = 302
cutoff = [0,200] # 100Hz (resolution of 0.5Hz)
# The noise seed is drawn at random, but it is printed and stored in the log so the run can be repeated.
noise_seed = np.random.randint(0,2**16)
print(noise_seed)
chunk_size, num_chunks = 10, 1000
print(num_chunks)
# One (2, chunk_size) array of initial conditions per simulation chunk
# (presumably one row per population — confirm against the simulator).
initial_conditions = [np.random.rand(2, chunk_size) for _ in range(num_chunks)]
num_sim = chunk_size * num_chunks
file_name = f'cube_{num_sim}_{count}'

# Sampler settings
search_width = 1.0
point = theta_min + (theta_range/2)  # centre of the parameter box
sampler = uniform_sampler(theta_min, theta_range = theta_range,sample_distribution='cube')
sampler.set_state(point = point, width = search_width)
# One batch of chunk_size parameter vectors per simulation chunk.
parameters_samples = [sampler.sample((chunk_size,)) for _ in range(num_chunks)]

all_samples = torch.cat(parameters_samples,dim = 0)
print(len(parameters_samples))

# Never request more parallel workers than there are chunks to simulate.
n_jobs = min(num_chunks, max_njobs)
print(file_name)

In [None]:
# Sanity check of the sampler: one violin per parameter, with the y-axis pinned
# to that parameter's [lower_bound, upper_bound] range.
samples_np = ensure_numpy(all_samples)  # convert once, not once per panel
n_params = samples_np.shape[1]
n_cols = 8
n_rows = -(-n_params // n_cols)  # ceiling division: 24 params -> 3 rows
fig, axes = plt.subplots(n_rows, n_cols, figsize=(24,10), squeeze=False)
for i, ax in enumerate(axes.flat):
    if i >= n_params:
        # Hide grid cells left over when n_params is not a multiple of n_cols.
        ax.set_visible(False)
        continue
    ax.violinplot(samples_np[:,i])
    ax.set_ylim(lower_bound[i], upper_bound[i])
    ax.set_title(f"{parameter_names[i]}")
fig.suptitle('Distribution of parameter values over the samples')
fig.tight_layout()
plt.show()

In [None]:
# Initialize log of the search
search_log_info = {}
search_log_info['dt'] = dt
search_log_info['length'] = length
search_log_info['chunk_size'] = chunk_size
search_log_info['num_chunks'] = num_chunks
search_log_info['cutoff'] = cutoff
search_log_info['noise_seed'] = noise_seed
search_log_info['initial_conditions'] = initial_conditions
search_log_info['message'] = message
search_log_info['total_simulations'] = num_sim
search_log_info['search_width'] = search_width
search_log_info['min_norm'] = [-1]
search_log_info['point'] = point


# To keep track of round runtime
runtime = datetime.now()-datetime.now()
st_time = datetime.now()
st_time_string = st_time.strftime('%D, %H:%M:%S')
print(f'Start time: {st_time_string}')

results = joblib.Parallel(n_jobs=n_jobs, verbose = 1)(joblib.delayed(WC_stochastic_heun_PSD)(
    
                                                                        parameters_samples[i],
                                                                        length = length,
                                                                        dt = dt,
                                                                        initial_conditions = initial_conditions[i],
                                                                        noise_seed = noise_seed,
                                                                        PSD_cutoff = cutoff,
                                                                        remove_bad = True
    
                                                        ) for i in range(num_chunks)
                                        )
# Group the simulations
psd_E, _, pars = zip(*results)
valid_PSD = []
valid_pars = []
for i in range(len(psd_E)):
    if len(psd_E[i])>1:
        valid_PSD.append(psd_E[i])
        valid_pars.append(pars[i])

# If no valid simulation remains
if len(valid_PSD) == 0:
    print("No valid simulations remained")
    search_log_info['valid_simulations'] = 0
    search_log_info['message'] += ' (No valid simulations remained)'
else:
    simulated_samples = torch.cat(valid_PSD, dim = 0)
    parameters_samples_good = torch.cat(valid_pars, dim = 0)
    cleaned_PSD, cleaned_parameters = simulated_samples, parameters_samples_good
    search_log_info['valid_simulations'] = cleaned_PSD.shape[0]

# Define the log_info dict
runtime = datetime.now()-st_time
search_log_info['data'] = {'x':cleaned_PSD,'theta':cleaned_parameters}
search_log_info['date'] = datetime.now().strftime('%D, %H:%M:%S')
search_log_info['runtime'] = runtime

# Keep track of the progress
f_time_string = datetime.now().strftime('%D, %H:%M:%S')
runtime_string = str(runtime)
print(f'Finish time: {f_time_string}, Runtime: {runtime_string}')
print(f"Simulations shape: {cleaned_PSD.shape[0]}, {cleaned_PSD.shape[1]}", end = '\t')
print(f"Parameters shape: {cleaned_parameters.shape[0]}, {cleaned_parameters.shape[1]}\n")

In [None]:
# Persist the simulated dataset together with its generation settings.
# NOTE(review): enforce_replace=False presumably avoids overwriting an existing log — confirm in lib.utils.
save_log(
    search_log_info,
    glob_data_path,
    file_name,
    enforce_replace=False,
)

In [None]:
# Bring the empirical (target) PSDs and the simulated PSDs onto the same
# frequency window, each spectrum normalised so that its peak equals 1.
real = get_real_individual_PSD_scaled(cutoff = 4, Hz = 80)
target_reals = real
# NOTE(review): 156 bins presumably = 2-80 Hz at 0.5 Hz resolution for cutoff=4, Hz=80 — confirm.
target_PSDs_all_frequency_range = ensure_torch(target_reals).view(-1,156)
# Out-of-place division: ensure_torch/.view may share memory with `real`, and
# an in-place /= would silently rescale it as well.
target_PSDs_all_frequency_range = target_PSDs_all_frequency_range / target_PSDs_all_frequency_range.amax(1).view(-1,1)
train_x_all_freq = search_log_info['data']['x']
train_theta = search_log_info['data']['theta']
# Zero placeholder parameters for the targets, as wide as the training parameters.
target_theta = ensure_torch(torch.zeros(target_PSDs_all_frequency_range.shape[0], train_theta.shape[1]))
print(f'Targets full range: \tx: {target_PSDs_all_frequency_range.shape}\t theta: {target_theta.shape}')
print(target_PSDs_all_frequency_range.amax(1))
frequency_range = [4,160]
# Target columns are offset by 4 relative to simulated bins (presumably because of cutoff=4).
target_PSDs = ensure_torch(target_PSDs_all_frequency_range[:,frequency_range[0]-4:frequency_range[1]-4])
freq = [0.5*i for i in range(frequency_range[0],frequency_range[1])]
train_x = train_x_all_freq[:,frequency_range[0]:frequency_range[1]].float()
# Out-of-place: if the PSDs are already float32, .float() returns a view of
# search_log_info['data']['x'], and /= would corrupt the logged dataset.
# NOTE(review): an all-zero PSD row would produce NaNs here — confirm remove_bad excludes those.
train_x = train_x / torch.amax(train_x,1).view(-1,1)
print(f'To fit shape: \tx: {train_x.shape}\t theta: {train_theta.shape}')
print(train_x.amax(1).mean())

In [None]:
from tqdm import tqdm

# Nearest-neighbour initialisation: for each target PSD, select the training
# simulation with the smallest correlation loss and record its parameters.
all_errs = []
loss_fn = 'correlation'  # label stored in target_log; the call below uses correlation_loss_fn
sorted_target = []  # one record per target, in the original target order
# Use the GPU when present, but fall back to CPU instead of crashing.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
target_PSDs = target_PSDs.to(device)
train_x = train_x.to(device)
for i in tqdm(range(target_PSDs.shape[0])):
    # NOTE(review): assumes one loss value per row of train_x — confirm.
    error = correlation_loss_fn(target_PSDs[i].view(1,-1),train_x)
    err_index = torch.argmin(error).cpu()
    sorted_target.append({
            'theta':train_theta[err_index].view(1,-1),
            'x': train_x[err_index].cpu().view(1,-1),
            'error':error[err_index].cpu().view(1,-1),
            'target':target_PSDs[i].cpu().view(1,-1),
            'target_theta':target_theta[i].view(1,-1),
            'real_index':torch.tensor(i).view(1,-1)
    })
    all_errs.append(error[err_index].item())
all_errs = torch.tensor(all_errs)
# Worst (largest-loss) fits first
sorted_errors = torch.argsort(all_errs, descending = True)
sorted_target_info = [sorted_target[i] for i in sorted_errors]


def _stack_field(records, key):
    """Concatenate the `key` tensor of every record dict along dim 0."""
    return torch.cat([record[key] for record in records], dim = 0)


target_log = {
    'original': {'x': target_PSDs.cpu(),'theta':target_theta},
    'sorted':{
        'data':{'source':glob_data_path, 'x': train_x.cpu(), 'theta':train_theta},
        'loss_fn': loss_fn,
        'indices':sorted_errors,
    }
}

# Best fit for each target, in the original target order
target_log['original']['fits'] = {
    key: _stack_field(sorted_target, key)
    for key in ('theta', 'x', 'error', 'target', 'target_theta')
}

# The same fits, ordered worst-first
target_log['sorted']['worst_info_all'] = {
    key: _stack_field(sorted_target_info, key)
    for key in ('theta', 'x', 'error', 'target', 'target_theta', 'real_index')
}
# Permutation mapping the worst-first rows back to the original target order
target_log['original']['reorder'] = torch.argsort(target_log['sorted']['worst_info_all']['real_index'][:,0])

In [None]:
# Quick diagnostics of the nearest-neighbour fits (rows are ordered worst-first).
_worst = target_log['sorted']['worst_info_all']
_worst_err = _worst['error']
print(_worst['x'].shape)
print(target_log['sorted']['data']['x'].shape)
print()
# Mean error over all targets, then the mean and min over the 40 worst-fitted ones
_head = _worst_err[:40]
print(_worst_err.mean())
print(_head.mean())
print(_head.min())
print()
print(_worst_err.max())
print(_worst_err.min())
print()
# Undoing the worst-first ordering must reproduce the original targets exactly.
_restored = _worst['target'][target_log['original']['reorder']]
print((_restored == target_log['original']['x']).all())

In [None]:
# Save the per-target fits next to the dataset, then bump the run counter so
# the next file_name built from `count` differs from this one.
save_log(target_log, glob_data_path, file_name + '_fits')
count = count + 1