In [None]:
# Third-party numerical / deep-learning stack.
import numpy as np
import torch
from torch import nn
import pickle as pkl
# Project-local helpers. The star imports bring many names into scope that are used
# unqualified below (e.g. plt, get_log, ensure_numpy, ensure_torch, DRFFIT) —
# NOTE(review): which module provides which name cannot be told from this notebook.
from lib.utils import *
from lib.drffit.uniform_sampler import uniform_around_sampler as uniform_sampler
from lib.Wilson_Cowan.parameters_info import parameters_alpha_peak, parameters_range_bounds, parameters_lower_bound
# Short aliases for the parameter lower bounds and ranges; used by norm_pars/denorm_pars
# and passed to the DRFFIT constructor and the samplers.
theta_min = parameters_lower_bound
theta_range = parameters_range_bounds
from lib.Feature_extraction.AE import *
from lib.Feature_extraction.PCA_features import *
from lib.drffit.subspace_estimator import *
from lib.drffit.drffit import *
# presumably applies the project's matplotlib defaults — TODO confirm in lib.utils
set_mpl()
def norm_pars(pars):
    """Shift ``pars`` by ``theta_min`` and divide the result by ``theta_range``."""
    offset_from_min = pars - theta_min
    return offset_from_min / theta_range
def denorm_pars(pars):
    """Multiply ``pars`` by ``theta_range`` and add ``theta_min`` (undoes ``norm_pars``)."""
    scaled = pars * theta_range
    return scaled + theta_min
# Plot options forwarded to every plot_fn_reconstruction call in the cells below.
rescale_plot = 1.0
use_log = False
# Kernel density estimator used by rec_gof_density to draw the correlation densities.
from scipy.stats import gaussian_kde
def correlation_reconstruction(original,reconstructed):
    """Row-wise Pearson correlation between two tensors.

    Each row is mean-centred along ``dim=1`` and the cosine similarity of the
    centred rows is taken, which equals the Pearson correlation coefficient.

    Parameters
    ----------
    original, reconstructed : torch.Tensor
        Tensors of matching shape; the correlation is computed along ``dim=1``.

    Returns
    -------
    numpy.ndarray
        One correlation value per row.
    """
    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-8)
    centered_original = original - original.mean(dim=1, keepdim=True)
    centered_reconstructed = reconstructed - reconstructed.mean(dim=1, keepdim=True)
    pearson = cos(centered_original, centered_reconstructed)
    # Tensor.numpy() raises for tensors that require grad or live on an accelerator;
    # detach() and cpu() are no-ops for plain CPU tensors, so existing callers are unaffected.
    return pearson.detach().cpu().numpy()
def rec_gof_density(testing_x, reconstructed, name = 'default', reconstructed_SE = None, figsize = (10,8), x_lim = [0.0,1.0], labels = None):
    """Plot kernel-density estimates of per-sample reconstruction correlations.

    Parameters
    ----------
    testing_x : torch.Tensor
        Reference data; each row is compared with the matching reconstructed row.
    reconstructed : torch.Tensor
        Reconstruction drawn as the black 'Feature fn' curve.
    name : str
        Prefix used in the figure title.
    reconstructed_SE : torch.Tensor, list/tuple of torch.Tensor, or None
        Extra reconstruction(s) to overlay. A single tensor is drawn in red and
        labelled 'SE'; a list/tuple gets one coloured curve per item.
    figsize : tuple
        Figure size passed to ``plt.figure``.
    x_lim : sequence of two floats
        Interval of correlation values over which the densities are evaluated.
    labels : sequence or None
        Legend labels for the items of a list-valued ``reconstructed_SE``.
        When missing or empty, the item index is used instead.
    """
    reconstructed_gof = correlation_reconstruction(testing_x, reconstructed)
    x = np.linspace(x_lim[0], x_lim[1], 500)
    plt.figure(figsize = figsize)
    plt.subplot(1,1,1)
    plt.plot(x, gaussian_kde(reconstructed_gof)(x), '-k', lw = 1.0, label = 'Feature fn')
    if reconstructed_SE is not None:
        if isinstance(reconstructed_SE, (list, tuple)):
            # The previous default ([]) was never None, so the index fallback could not
            # trigger and labels[i] raised IndexError when no labels were supplied.
            if labels is None or len(labels) == 0:
                labels = list(range(len(reconstructed_SE)))
            for i, reconstructed_SE_item in enumerate(reconstructed_SE):
                reconstructed_gof_SE = correlation_reconstruction(testing_x, reconstructed_SE_item)
                # Fall back to the index if fewer labels than curves were given.
                curve_label = labels[i] if i < len(labels) else i
                plt.plot(x, gaussian_kde(reconstructed_gof_SE)(x), f'-C{i}', lw = 1.0, label = f'{curve_label}')
        else:
            reconstructed_gof_SE = correlation_reconstruction(testing_x, reconstructed_SE)
            plt.plot(x, gaussian_kde(reconstructed_gof_SE)(x), '-r', lw = 1.0, label = 'SE')
    plt.legend(ncol = 1, prop = {'size': 10})
    plt.xlabel('Correlation value')
    plt.ylabel('Density')
    plt.title(f'{name} reconstruction Goodness of Fit')
    plt.grid(True, linestyle = '--')
    plt.tight_layout()
    plt.show()

# Load Training Data

In [None]:
from os import walk
glob_data_path = '../Data/WC/initialization/'
DRFFIT_path = '../Data/WC/initialization/DRFFIT_objects/'
sub_dir = glob_data_path
# Collect, in sorted order, every file name found under glob_data_path (sub-directories
# included) that does not contain 'DRFFIT', with its last four characters stripped.
list_of_init_data = sorted(
    file_n[:-4]
    for (dirpath, dirnames, filenames) in walk(glob_data_path)
    for file_n in filenames
    if 'DRFFIT' not in file_n
)
print(list_of_init_data)

In [None]:
# Use the first (alphabetically) initialization data set; the DRFFIT file stem is derived from it.
data_file_name = list_of_init_data[0]
DRFFIT_file_name = f'{data_file_name}_DRFFIT'
print(f'{sub_dir}{data_file_name}')
train_data_log = get_log(sub_dir, data_file_name)
data_info(train_data_log)


In [None]:
rescale_individually = True
# Pull the simulated outputs ('x') and their parameters ('theta') out of the loaded log.
train_x_all_freq, train_theta = (train_data_log['data'][key] for key in ('x', 'theta'))
print(f'DRFFIT data all full range: \tx: {train_x_all_freq.shape}\t theta: {train_theta.shape}')
# Shape of the per-row maximum over dim 1.
print(train_x_all_freq.amax(1).shape)

## Define frequency range and log scale (if used)

In [None]:
# Column indices [start, stop) of train_x_all_freq that are kept.
frequency_range = [4,160]
# Axis values for the plots: half the column index (assumes 0.5 Hz bins — TODO confirm).
freq = [0.5*i for i in range(frequency_range[0],frequency_range[1])]
train_x = train_x_all_freq[:,frequency_range[0]:frequency_range[1]].float()
# Scale every row by its own maximum. The division is done out of place on purpose:
# when the source is already float32, .float() returns the slice view itself, and an
# in-place "/=" would silently overwrite train_x_all_freq (and train_data_log).
train_x = train_x / torch.amax(train_x,1).view(-1,1)
print(f'DRFFIT data max: \tx: {train_x.amax()}')

## Shuffle data before providing to DRFFIT

In [None]:
# Draw one random permutation of the sample indices (NumPy global RNG, unseeded) and
# apply it to x and theta together so the pairs stay aligned.
train_indices = np.arange(train_x.shape[0])
np.random.shuffle(train_indices)
train_indices = torch.as_tensor(train_indices, dtype=int)
train_x = train_x[train_indices]
train_theta = train_theta[train_indices]
print(f'DRFFIT data: \tx: {train_x.shape}\t theta: {train_theta.shape}')

# Create and initialize DRFFIT

In [None]:
# Directory that the save/load cells further down use for DRFFIT objects.
DRFFIT_objects_path = DRFFIT_path

# Number of features requested from both the PCA and the autoencoder feature functions below.
num_features = 5
# Device the evaluation cells move their inputs to; 'cuda' requires a CUDA-capable GPU.
device = 'cuda'

# Initialize drffit object
# Arguments: number of theta columns, number of x columns, parameter lower bounds and ranges.
drffit = DRFFIT(train_theta.shape[1], train_x.shape[1], theta_min, theta_range)

# Add dataset to the DRFFIT object
drffit.add_data(train_x, theta=train_theta)
#drffit.initialize_subspace_estimator()
#drffit.initialize_subspace_estimator("PCA")

## PCA function

In [None]:
# PCA
# Build a PCA feature function with `num_features` features, register it on the
# DRFFIT object under the name "PCA" and train it
# (presumably on the data added via drffit.add_data — TODO confirm).
pca = PCA_features(num_features)
drffit.add_custom_fn(pca, pca.num_features,"PCA")
drffit.train_feature_fn(name = "PCA")
# Handle on the registered feature function, used by the reconstruction checks below.
PCA_fn = drffit.feature_fn['PCA']['fn']

In [None]:
# The reconstruction check is run on the training data itself (no separate test set here).
testing_x = train_x
# Map the data to PCA features, then back through the PCA model's inverse transform.
PCA_true = ensure_numpy(PCA_fn.feature_fn(testing_x))
reconstructed_train_x_PCA = ensure_numpy(PCA_fn.pca_model.inverse_transform(PCA_true))
# From here on testing_x holds whatever ensure_numpy returns; later cells convert it back with ensure_torch when needed.
testing_x = ensure_numpy(testing_x)
plot_fn_reconstruction(testing_x,reconstructed_train_x_PCA, freq = freq, rescale_plot = rescale_plot, use_log = use_log)
# Density of the per-row correlation between the data and its PCA reconstruction.
rec_gof_density(ensure_torch(testing_x), ensure_torch(reconstructed_train_x_PCA), name = 'PCA', reconstructed_SE = None, figsize = (10,8), x_lim = [0.95,1])

## PCA Subspace Estimator

In [None]:
# One entry per subspace estimator to train:
# [hidden units, skip_connection flag, activation module, learning rate].
architectures = [
    [[50, 50, 50], True, nn.SiLU(), 0.00025],
    [[50, 50, 50], True, nn.SiLU(), 0.0001],
    [[50, 50, 50], True, nn.SiLU(), 0.00005],
    #[[50, 50, 50], True, nn.SiLU()],
    #[[100, 100], True, nn.SiLU()],
    #[[100, 50, 50], True, nn.SiLU()],
    #[[100, 100, 100], True, nn.SiLU()],
    #[[100, 100, 100, 50], True, nn.SiLU()],
    #[[100, 100, 100, 50, 50], True, nn.SiLU()],
    #[[100, 100, 100, 50, 50, 50], True, nn.SiLU()],
]
for i, (hidden_units, use_skip, activation, learning_rate) in enumerate(architectures):
    # Read again by the evaluation cell to decide whether features must be de-normalised.
    PCA_SE_norm_features = False
    PCA_SE_name = f"PCA0_{i}"
    print(f'{PCA_SE_name}:')

    # 1) create the estimator, 2) attach the "PCA" feature function to it,
    # 3) set its feature data, 4) train all of its networks.
    drffit.initialize_subspace_estimator(PCA_SE_name)
    drffit.uniform_add_feature_fn_to_subspace_estimator(
        subspace_estimator_name=PCA_SE_name,
        feature_fn_name="PCA",
        device = 'cuda',
        combine = True,
        enforce_replace = True,
        architecture = {'units': hidden_units, 'skip_connection': use_skip},
        activation_fn = activation,
    )
    drffit.set_feature_data_of_subspace_estimator(
        subspace_estimator_name=PCA_SE_name,
        norm_features = PCA_SE_norm_features,
        split = [0.8,1.0],
        overwrite_data = True,
    )
    # val_history is overwritten on every iteration, so only the last run's value survives the loop.
    val_history = drffit.train_all_nets_from_subspace_estimator(
        name = PCA_SE_name,
        epochs = 5000,
        batch_size = 32,
        lr = learning_rate,
        scheduler_kwargs = {'gamma' : 0.5, 'step_size' : 75},
        weight_decay = 0.0005,
        clip_max_norm = 1.0,
        amsgrad = False,
        patience = 25,
        multi_reset = 1,
        threshold_gain = 0.1,
        print_rate = 1,
        verbose = 1,
        enforce_replace=False,
        skip_trained=False,
        return_val = True,
        rescale_loss = 10000.0,
    )

In [None]:
#plot_train_loss(val_history[0], cutoff = 1, use_logscale = True)

In [None]:
labels = []
reconstructed_SE_PCAs = []
for SE in drffit.subspace_estimator:
    # Keep only the PCA estimators trained above (names of the form "PCA0_<i>").
    # The former extra test for 'PCA0_0' was already implied by this condition.
    if 'PCA0' not in SE or '_' not in SE:
        continue
    print(SE)
    labels.append(SE)
    se_PCA = drffit.subspace_estimator[SE]['subspace_estimator']
    testing_x = ensure_numpy(testing_x)
    # Predict the PCA features from the normalised parameters with every net of the
    # estimator and concatenate the outputs along dim 1.
    PCA_test = ensure_numpy(torch.cat([se_PCA.net_list[i](ensure_torch(norm_pars(train_theta)).to(device)) for i in range(len(se_PCA.net_list))],dim = 1))
    if PCA_SE_norm_features:
        # NOTE(review): this is (x + mean) * std, whereas undoing a z-score is usually
        # x * std + mean — confirm against how the subspace estimator normalises features.
        PCA_test = ((PCA_test + ensure_numpy(se_PCA.feature_mean)) * ensure_numpy(se_PCA.feature_std))
    # Map the predicted features back to data space once and reuse the result for both
    # the plot and the density comparison (it used to be computed twice).
    reconstructed_SE_PCA = ensure_numpy(PCA_fn.pca_model.inverse_transform(PCA_test))
    plot_fn_reconstruction(testing_x, reconstructed_train_x_PCA, reconstructed_SE = reconstructed_SE_PCA, freq = freq, rescale_plot = rescale_plot, use_log = use_log)
    reconstructed_SE_PCAt = ensure_torch(reconstructed_SE_PCA)
    reconstructed_SE_PCAs.append(reconstructed_SE_PCAt)
rec_gof_density(ensure_torch(testing_x), ensure_torch(reconstructed_train_x_PCA), name = 'PCA', reconstructed_SE = reconstructed_SE_PCAs, figsize = (10,6), x_lim = [0.90,1], labels = labels)

In [None]:
from os import makedirs
print(DRFFIT_objects_path,DRFFIT_file_name)
# Make sure the output directory exists before anything is written to it.
makedirs(DRFFIT_objects_path, exist_ok = True)
save_SE_PCA = True
if save_SE_PCA:
    # Persist the first PCA subspace estimator ("PCA0_0") under a file stem that encodes the feature count.
    drffit.save_SE_as_log(
        model_path = DRFFIT_objects_path,
        file_name = f'{DRFFIT_file_name}_SE_PCA{num_features}LD',
        name = "PCA0_0",
        enforce_replace = False,
    )

## Autoencoder function

In [None]:
# One entry per autoencoder to train: [hidden units, activation module, learning rate].
architectures = [
    [[25], nn.SiLU(), 0.0025],
    #[[25], nn.SiLU(), 0.00125],
    #[[25], nn.SiLU(), 0.0005],
    #[[25], nn.SiLU(), 0.00025],
    #[[1000], nn.PReLU(), 0.00025],
    #[[500,50], nn.SiLU(), 0.00025],
    #[[500,50], nn.PReLU(), 0.00025],
]
for i, (hidden_units, activation, learning_rate) in enumerate(architectures):
    AE_name = f'default_AE0_{i}'
    print(AE_name)

    # Register the autoencoder; linear_norm is used for norm, denorm and the output function alike.
    drffit.add_AE(
        name = AE_name,
        num_features=num_features,
        device = 'cuda',
        replace = True,
        norm = linear_norm,
        denorm = linear_norm,
        architecture = {'units': hidden_units},
        out_fn = linear_norm,
        activation_fn = activation,
    )

    # val_history is overwritten on every iteration, so only the last run's value survives the loop.
    val_history = drffit.train_feature_fn(
        name = AE_name,
        epochs = 100000,
        lr = learning_rate,
        scheduler_kwargs = {'gamma' : 0.25, 'step_size' : 50},
        weight_decay = 0.01,
        clip_max_norm = 5.0,
        amsgrad = False,
        batch_size = 32,
        split = [0.8,1.0],
        patience = 50,
        multi_reset = 1,
        threshold_gain = 0.5,
        print_rate = 1,
        verbose = 1,
        enforce_replace=True,
        return_val = True,
    )

In [None]:
#plot_train_loss(val_history, cutoff = 1, use_logscale= True)

In [None]:
# Reconstruct the data with each trained autoencoder and collect the results for the density plot.
reconstructed_SE_AEs = []
labels = []
# Only the first autoencoder ('default_AE0_0') is evaluated.
for i in range(1):
    AE_name = f'default_AE0_{i}'
    print(AE_name)
    labels.append(i)
    AE_fn = drffit.feature_fn[AE_name]['fn']
    # Full pass of the data through the autoencoder on `device`, result moved back to the CPU.
    reconstructed_train_x_AE = ensure_numpy(AE_fn(ensure_torch(testing_x).to(device)).to("cpu"))
    testing_x = ensure_numpy(testing_x)
    plot_fn_reconstruction(testing_x,reconstructed_train_x_AE, freq = freq, rescale_plot = rescale_plot, use_log = use_log)
    reconstructed_SE_AEt = ensure_torch(reconstructed_train_x_AE)
    reconstructed_SE_AEs.append(reconstructed_SE_AEt)
# NOTE(review): the black 'Feature fn' curve here is the PCA reconstruction although the
# title says 'AE'; the AE reconstructions are the overlaid curves. Looks like a deliberate
# PCA-vs-AE comparison — confirm.
rec_gof_density(ensure_torch(testing_x), ensure_torch(reconstructed_train_x_PCA), name = 'AE', reconstructed_SE = reconstructed_SE_AEs, figsize = (10,6), x_lim = [0.90,1], labels = labels)

## Autoencoder Subspace Estimator

In [None]:
# One entry per subspace estimator to train:
# [hidden units, skip_connection flag, activation module, learning rate].
architectures = [
    [[100, 50, 50], True, nn.SiLU(), 0.025],
    [[100, 50, 50], True, nn.SiLU(), 0.0125],
    [[100, 50, 50], True, nn.SiLU(), 0.005],
    [[100, 50, 50], True, nn.SiLU(), 0.0025],
    [[100, 50, 50], True, nn.SiLU(), 0.00125],
    [[100, 50, 50], True, nn.SiLU(), 0.0005],
    #[[100, 100], True, nn.PReLU(), 0.01],
    #[[100, 100], True, nn.SiLU(), 0.05],
    #[[100, 100], True, nn.PReLU(), 0.05],
]
for i, (hidden_units, use_skip, activation, learning_rate) in enumerate(architectures):
    # Read again by the evaluation cell to decide whether features must be de-normalised.
    AE_SE_norm_features = False
    SE_Name = f"default_AE0_0_{i}"
    print(f'{SE_Name}:')

    # 1) create the estimator, 2) attach the "default_AE0_0" feature function to it,
    # 3) set its feature data, 4) train all of its networks.
    drffit.initialize_subspace_estimator(SE_Name)
    drffit.uniform_add_feature_fn_to_subspace_estimator(
        subspace_estimator_name=SE_Name,
        feature_fn_name="default_AE0_0",
        device = 'cuda',
        combine = True,
        enforce_replace = True,
        architecture = {'units': hidden_units, 'skip_connection': use_skip},
        activation_fn = activation,
    )
    drffit.set_feature_data_of_subspace_estimator(
        subspace_estimator_name=SE_Name,
        norm_features = AE_SE_norm_features,
        split = [0.80,1.0],
        overwrite_data = True,
    )
    # val_history is overwritten on every iteration, so only the last run's value survives the loop.
    val_history = drffit.train_all_nets_from_subspace_estimator(
        name = SE_Name,
        epochs = 5000,
        batch_size = 32,
        lr = learning_rate,
        scheduler_kwargs = {'gamma' : 0.25, 'step_size' : 50},
        weight_decay = 0.01,
        clip_max_norm = 1.0,
        amsgrad = False,
        patience = 10,
        multi_reset = 5,
        threshold_gain = 0.1,
        print_rate = 1,
        verbose = 1,
        enforce_replace=False,
        skip_trained=False,
        return_val = True,
        rescale_loss = 10000.0,
    )


In [None]:
# Handle on the autoencoder whose features the "default_AE0_0_<i>" subspace estimators were trained on.
AE_fn = drffit.feature_fn['default_AE0_0']['fn']

In [None]:
#print(torch.tensor(val_history).shape)
#plot_train_loss(val_history[0], cutoff = 1, use_logscale= False)

In [None]:
#device = 'cuda'
# Start the comparison with the first PCA subspace-estimator reconstruction, labelled 'PCA'.
labels = ['PCA']
reconstructed_SE_AEs = [reconstructed_SE_PCAs[0]]
for SE in drffit.subspace_estimator:
    # Only estimators whose name contains both 'AE0' and '_0_3' are evaluated
    # (i.e. "default_AE0_0_3" among the names created by the training loop).
    if 'AE0' in SE and '_0_3' in SE:
        print(SE)
        se_AE = drffit.subspace_estimator[SE]['subspace_estimator']
        testing_x = ensure_numpy(testing_x)
        # Predict the AE features from the normalised parameters with every net of the
        # estimator and concatenate the outputs along dim 1.
        AE_test = ensure_numpy(torch.cat([se_AE.net_list[i](ensure_torch(norm_pars(train_theta)).to(device)) for i in range(len(se_AE.net_list))],dim = 1))
        if AE_SE_norm_features:
            # NOTE(review): this is (x + mean) * std, whereas undoing a z-score is usually
            # x * std + mean — confirm against how the subspace estimator normalises features.
            AE_test = ((AE_test + ensure_numpy(se_AE.feature_mean)) * ensure_numpy(se_AE.feature_std))
        # Decode the predicted features: decoder -> out_fn -> denorm.
        reconstructed_SE_AE = ensure_numpy(AE_fn.denorm(AE_fn.out_fn(AE_fn.decoder(ensure_torch(AE_test).to(device)))))
        plot_fn_reconstruction(testing_x, reconstructed_train_x_AE, reconstructed_SE = reconstructed_SE_AE, freq = freq, rescale_plot = rescale_plot, use_log = use_log)
        reconstructed_SE_AEs.append(ensure_torch(reconstructed_SE_AE))
        # The legend label is the second underscore-separated token of the estimator name.
        labels.append(SE.split('_')[1])
rec_gof_density(ensure_torch(testing_x), ensure_torch(reconstructed_train_x_AE), name = 'AE', reconstructed_SE = reconstructed_SE_AEs, figsize = (12,8), x_lim = [0.9,1], labels = labels)

In [None]:
save_SE_AE = True
save_AE = True
AE_name = 'default_AE0_0'
if save_SE_AE:
    # Persist the AE subspace estimator inspected in the evaluation cell. The name must be
    # one created by the training loop (f"default_AE0_0_{i}"); the former "default_AE2_0_3"
    # is never created in this notebook.
    drffit.save_SE_as_log(
            model_path = DRFFIT_objects_path,
            file_name = DRFFIT_file_name+f'_SE_AE{num_features}LD',
            name = f"{AE_name}_3",
            enforce_replace = False,
    )
    # The autoencoder itself is only saved together with its subspace estimator.
    if save_AE:
        drffit.save_AE(
                model_path=DRFFIT_objects_path,
                file_name = DRFFIT_file_name+f'_AE{num_features}LD',
                name = AE_name, enforce_replace = False
        )

# Build a second DRFFIT object that hosts the chosen subspace estimators on the CPU, and save it for later use

In [None]:
#DRFFIT_objects_path = sub_dir

#num_features = 5
# Everything attached to this second DRFFIT object is placed on the CPU.
device_CPU = 'cpu'

# Initialize drffit object
drffit_CPU = DRFFIT(train_theta.shape[1], train_x.shape[1], theta_min, theta_range)

# Add dataset to the DRFFIT object
drffit_CPU.add_data(train_x, theta=train_theta)
# Two estimators are created: one with no name argument (read back as 'default' in a
# later cell) and one named "PCA".
drffit_CPU.initialize_subspace_estimator()
drffit_CPU.initialize_subspace_estimator("PCA")
testing_x = train_x
# PCA
# A fresh PCA feature function is trained here rather than reused from `drffit`.
pca = PCA_features(num_features)
drffit_CPU.add_custom_fn(pca, pca.num_features,"PCA")
drffit_CPU.train_feature_fn(name = "PCA")
PCA_SE_norm_features = False
# Attach the "PCA" feature function to the "PCA" estimator; no architecture/activation
# is passed here, so DRFFIT's own defaults apply.
# NOTE(review): presumably these must match the architecture of the saved "PCA0_0"
# estimator loaded next — confirm.
drffit_CPU.uniform_add_feature_fn_to_subspace_estimator(

                        subspace_estimator_name="PCA",
                        feature_fn_name="PCA",
                        device = device_CPU,
                        enforce_replace = True,
                        combine = True

)
# Load the PCA subspace estimator written by the earlier save cell into the "PCA" slot.
drffit_CPU.load_SE_from_log(

                        DRFFIT_objects_path,
                        DRFFIT_file_name+f'_SE_PCA{num_features}LD',
                        name = "PCA",
                        device = device_CPU,
                        enforce_replace = True

)

# Attach a sampler built from the parameter bounds and request a subspace around
# parameters_alpha_peak with width 0.1.
PCA_sampler = uniform_sampler(theta_min, theta_range = theta_range)
drffit_CPU.set_sampler_for_subspace(PCA_sampler, sampler_name='PCA', subspace_estimator_name="PCA", enforce_replace = True)
subspace = drffit_CPU.get_DRFFIT_subspace_from_subspace_estimator(sampler_name="PCA",name="PCA", point = parameters_alpha_peak, width = 0.1)#, target = target_PSDs[0])
print(subspace.shape)


In [None]:
# Repeat the PCA reconstruction check with the objects held by the CPU DRFFIT instance.
PCA_fn = drffit_CPU.feature_fn['PCA']['fn']
se_PCA = drffit_CPU.subspace_estimator['PCA']['subspace_estimator']
PCA_true = ensure_numpy(PCA_fn.feature_fn(testing_x))
reconstructed_train_x_PCA = ensure_numpy(PCA_fn.pca_model.inverse_transform(PCA_true))
# Features predicted from the normalised parameters by every net of the loaded estimator.
PCA_test = ensure_numpy(torch.cat([se_PCA.net_list[i](ensure_torch(norm_pars(train_theta)).to(device_CPU)) for i in range(len(se_PCA.net_list))],dim = 1))
if PCA_SE_norm_features:
    # NOTE(review): unlike the GPU evaluation cell, feature_mean/feature_std are not passed
    # through ensure_numpy here, and the formula is (x + mean) * std rather than
    # x * std + mean — confirm both before enabling PCA_SE_norm_features.
    PCA_test = ((PCA_test + se_PCA.feature_mean) * se_PCA.feature_std)
reconstructed_SE_PCA = ensure_numpy(PCA_fn.pca_model.inverse_transform(PCA_test))
testing_x = ensure_numpy(testing_x)
plot_fn_reconstruction(testing_x, reconstructed_train_x_PCA, reconstructed_SE = reconstructed_SE_PCA, freq = freq, rescale_plot = rescale_plot, use_log = use_log)

In [None]:
# AE pretrained
# Load the autoencoder written by the earlier save cell onto the CPU. No `name` is
# given; a later cell reads it back as drffit_CPU.feature_fn['default_AE'].
drffit_CPU.add_AE(
    
            pre_trained=True,
            model_path=DRFFIT_objects_path,
            file_name = DRFFIT_file_name+f'_AE{num_features}LD',
            device = device_CPU,
            replace = True
    
 )


In [None]:
AE_SE_norm_features = False
# No estimator or feature-function names are passed in this cell, so DRFFIT's defaults
# apply; a later cell reads the results back as subspace_estimator['default'] and
# feature_fn['default_AE'].
# NOTE(review): no architecture is passed either — presumably the defaults must match
# the saved AE subspace estimator loaded below; confirm.
drffit_CPU.uniform_add_feature_fn_to_subspace_estimator(
    
            device = device_CPU,
            enforce_replace = True,
            combine = True
    
)

# Load the AE subspace estimator written by the earlier save cell.
drffit_CPU.load_SE_from_log(
    
            DRFFIT_objects_path,
            file_name = DRFFIT_file_name+f'_SE_AE{num_features}LD',
            device = device_CPU,
            enforce_replace = True
    
)

# Attach a sampler built from the parameter bounds and request a subspace around
# parameters_alpha_peak with width 0.1.
AE_sampler = uniform_sampler(theta_min, theta_range = theta_range)
drffit_CPU.set_sampler_for_subspace(AE_sampler, enforce_replace = True)
subspace = drffit_CPU.get_DRFFIT_subspace_from_subspace_estimator(point = parameters_alpha_peak, width = 0.1)
print(subspace.shape)

In [None]:
# Repeat the AE reconstruction check with the objects held by the CPU DRFFIT instance.
AE_fn = drffit_CPU.feature_fn['default_AE']['fn']
se_AE = drffit_CPU.subspace_estimator['default']['subspace_estimator']
# Full autoencoder pass over the data.
reconstructed_train_x_AE = ensure_numpy(AE_fn(ensure_torch(testing_x).to(device_CPU)).to("cpu"))
testing_x = ensure_numpy(testing_x)
# Features predicted from the normalised parameters by every net of the loaded estimator.
AE_test = ensure_numpy(torch.cat([se_AE.net_list[i](ensure_torch(norm_pars(train_theta)).to(device_CPU)) for i in range(len(se_AE.net_list))],dim = 1))
if AE_SE_norm_features:
    # NOTE(review): unlike the GPU evaluation cell, feature_mean/feature_std are not passed
    # through ensure_numpy here, and the formula is (x + mean) * std rather than
    # x * std + mean — confirm both before enabling AE_SE_norm_features.
    AE_test = ((AE_test + se_AE.feature_mean) * se_AE.feature_std)
# Decode the predicted features: decoder -> out_fn -> denorm.
reconstructed_SE_AE = ensure_numpy(AE_fn.denorm(AE_fn.out_fn(AE_fn.decoder(ensure_torch(AE_test).to(device_CPU)))))
plot_fn_reconstruction(testing_x, reconstructed_train_x_AE, reconstructed_SE = reconstructed_SE_AE, freq = freq, rescale_plot = rescale_plot, use_log = use_log)

In [None]:
# Persist the whole CPU-hosted DRFFIT object under a file stem that encodes the feature count.
# enforce_replace = False — presumably an existing file is left untouched; TODO confirm in save_log.
save_log(drffit_CPU,DRFFIT_objects_path, DRFFIT_file_name+f'_object{num_features}LD', enforce_replace = False)