In [1]:
# Load IPython's autoreload extension; mode 2 reloads imported modules before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import torch, os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from wassa.wassa_plots import plot_results_std, plot_SM, plot_colored_raster
from wassa.dataset_generation import sm_generative_model, generate_dataset
from wassa.wassa_utils import train_and_plot

In [3]:
# Date tag used as a prefix in the names of the cached result files.
date = '2024_12_10'

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [4]:
class dataset_parameters():
    """Settings of the synthetic spiking-motif dataset, stored as class attributes.

    ``get_parameters`` flattens the main settings into an underscore-separated
    tag (used in this notebook to build cache file names).
    """
    seed = 666

    # --- size of the raster plots ---
    N_pre = 100  # number of neurons
    N_timesteps = 255  # number of timesteps for the raster plot (in ms)
    N_samples = 100  # total number of samples in the dataset

    # --- spiking motifs ---
    N_delays = 51  # number of timesteps in spiking motifs, must be an odd number for convolutions
    N_SMs = 4  # number of structured spiking motifs
    N_involved = N_pre*torch.ones(N_SMs)  # number of neurons involved in each spiking motif
    avg_fr = 20  # average firing rate of the neurons (in Hz)
    std_fr = .1  # standard deviation for the firing rates of the different neurons
    # Per-neuron firing rates; drawn once, when this class body is executed.
    frs = torch.normal(avg_fr, std_fr, size=(N_pre,)).abs()
    freq_sms = 16*torch.ones(N_SMs)  # frequency of occurrence of the different spiking motifs (in Hz)
    overlapping_sms = False  # whether sequences are allowed to overlap

    # --- noise ---
    temporal_jitter = .1  # temporal jitter for the spike generation in motifs
    dropout_proba = 0  # probabilistic participation of the different neurons in the spiking motif
    additive_noise = .1  # percentage of background noise/spontaneous activity
    warping_coef = 1  # coefficient for time warping

    def get_parameters(self):
        """Return the identifying settings joined with underscores."""
        fields = (
            self.N_pre,
            self.N_delays,
            self.N_SMs,
            self.N_timesteps,
            self.N_samples,
            self.N_involved.mean(),
            self.avg_fr,
            self.freq_sms.mean(),
            self.overlapping_sms,
            self.temporal_jitter,
            self.dropout_proba,
            self.additive_noise,
            self.warping_coef,
            self.seed,
        )
        # f-string formatting (not str()) so that 0-dim tensors render as plain numbers.
        return '_'.join(f'{field}' for field in fields)

In [5]:
# Sweep configuration: number of repetitions, one dataset seed per repetition,
# and the lambda values to scan.
N_iter = 5
seeds = torch.arange(0, N_iter)
lambdaz = [0] + [1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 5e-2, 1e-1]

In [6]:
def performance_as_a_function_of_lambda(world_parameters, training_parameters, date, lambdaz, N_iter = 5, seeds = None, device='cpu'):
    """Sweep ``lambda_`` and score models trained with MSE, EMD and both losses.

    For every seed and every value in ``lambdaz`` a dataset is generated and
    ``train_and_plot`` is called three times: with the MSE parameters, with the
    EMD parameters, and with both. The first six values returned by each call
    are stored. The outcome is cached on disk and simply reloaded when the
    cache file already exists.

    Parameters
    ----------
    world_parameters : class
        Dataset parameter class. It is instantiated to build the cache name and
        passed to ``generate_dataset``. Its ``seed`` attribute is overwritten
        for each repetition and restored before returning.
    training_parameters : class
        Training parameter class; instantiated twice (one copy gets
        ``loss_type = 'emd'``) and ``lambda_`` is set on both copies.
    date : str
        Prefix of the cache file name; also forwarded to ``train_and_plot``.
    lambdaz : sequence
        Values of ``lambda_`` to scan.
    N_iter : int
        Number of repetitions (one seed each).
    seeds : tensor or None
        One seed per repetition (``seeds.size(0)`` must equal ``N_iter``);
        drawn at random in ``[0, 1000)`` when ``None``.
    device : str or torch.device
        Forwarded to ``generate_dataset`` and ``train_and_plot``.

    Returns
    -------
    results, lambdaz
        ``results`` has shape ``(3, N_iter, len(lambdaz), 6)`` when computed
        here; when loaded, both values are whatever the cache file contains.
    """
    if seeds is not None:
        assert seeds.size(0)==N_iter
    else:
        seeds = torch.randint(1000,[N_iter])

    params_mse = training_parameters()
    params_emd = training_parameters()
    params_emd.loss_type = 'emd'

    file_name = f'results/{date}_performance_as_a_function_of_lambda_{world_parameters().get_parameters()}_{params_emd.get_parameters()}_{lambdaz[0]}_{lambdaz[-1]}'
    print(file_name)

    if os.path.isfile(file_name):
        results, lambdaz = torch.load(file_name, map_location='cpu')
        return results, lambdaz

    # Create the output directory up front so the final save cannot fail
    # after the whole sweep has been computed.
    os.makedirs(os.path.dirname(file_name), exist_ok=True)

    results = torch.zeros([3,N_iter,len(lambdaz),6])
    # Order matches the first axis of `results`: MSE only, EMD only, both.
    param_sets = ([params_mse], [params_emd], [params_emd, params_mse])

    # The seed is written on the (shared) class. Remember it so it can be put
    # back: it is part of the cache name, and leaving it modified would make
    # the file name of later calls depend on whether this call hit the cache.
    initial_seed = world_parameters.seed
    pbar = tqdm(total=len(lambdaz)*N_iter)
    try:
        for i in range(N_iter):
            world_parameters.seed = seeds[i]
            for ind_f, lambda_ in enumerate(lambdaz):
                params_mse.lambda_ = lambda_
                params_emd.lambda_ = lambda_
                sm, trainset_input, trainset_output, testset_input, testset_output = generate_dataset(world_parameters,verbose = False,device=device)
                for ind_m, params in enumerate(param_sets):
                    # train_and_plot returns eight values; only the first six are kept.
                    m0, m1, m2, m3, m4, m5, _, _ = train_and_plot(sm, trainset_input, testset_input, testset_output, params, date, iteration = i, device=device)
                    for ind_k, value in enumerate((m0, m1, m2, m3, m4, m5)):
                        results[ind_m,i,ind_f,ind_k] = value
                pbar.update(1)
    finally:
        pbar.close()
        world_parameters.seed = initial_seed

    torch.save([results, lambdaz], file_name)
    return results, lambdaz

## Hyperparameter tuning: optimal lambda for max cross correlation

In [7]:
class training_parameters:
    """Training settings, stored as class-level defaults.

    The dependent settings (``lambda_``, ``smoothwind``, ``penalty_type``) are
    reconciled in ``__init__`` rather than in the class body, so every new
    instance reflects the *current* class-level values. This matters because
    the notebook switches ``training_parameters.penalty_type`` between runs:
    a class-body check would only ever see the first value.
    """
    kernel_size = (dataset_parameters.N_SMs, dataset_parameters.N_pre, dataset_parameters.N_delays)
    loss_type = 'mse'
    N_learnsteps = 1000
    learning_rate = .001
    penalty_type = 'max_cc'
    smoothwind = 40  # kept only for 'smoothed*' penalties, see __init__
    lambda_ = .0005
    batch_size = None
    output = 'linear'
    do_bias = True
    zeros = 'ignore'
    wass_order = 1
    weight_init = None

    def __init__(self):
        """Resolve the settings that depend on each other for this instance."""
        if not self.penalty_type:
            # No penalty selected: its weight is irrelevant.
            self.lambda_ = 0
        elif self.penalty_type[:8] != 'smoothed':
            # The smoothing window only applies to 'smoothed*' penalties.
            self.smoothwind = 0
        if self.lambda_ == 0:
            # A zero weight disables the penalty altogether.
            self.penalty_type = None

    def get_parameters(self):
        """Return the settings joined with underscores (used in cache file names)."""
        name = f'{self.loss_type}_{self.output}_{self.penalty_type}_{self.do_bias}_{self.kernel_size}_{self.N_learnsteps}_{self.learning_rate}_{self.lambda_}_{self.batch_size}_{self.smoothwind}'
        if self.loss_type == 'emd':
            # These two settings are only included for the EMD loss.
            name += f'_{self.zeros}_{self.wass_order}'
        return name

In [None]:
# Run the lambda sweep (or reload it from its cache file) with the penalty
# type currently set on training_parameters.
results, lambdaz = performance_as_a_function_of_lambda(
    dataset_parameters,
    training_parameters,
    date,
    lambdaz,
    N_iter=N_iter,
    seeds=seeds,
    device=device,
)

results/2024_12_10_performance_as_a_function_of_lambda_100_51_4_255_100_100.0_20_16.0_False_0.1_0_0.1_1_666_emd_linear_max_cc_True_(4, 100, 51)_1000_0.001_0.0005_None_0_ignore_1_0_0.1


  5%|███████▍                                                                                                                                            | 2/40 [00:11<03:16,  5.18s/it]

In [None]:
# Labels and colours used by this cell and by the later plotting cells.
name_metrics = ['factors similarity', 'kernels similarity', 'mean timings\nsimilarity', 'MSE', 'EMD', 'EMD means']
name_methods = ['MSE', 'EMD', 'combined']
colors = ['darkolivegreen','blue', 'orangered']
# The swept quantity on the x-axis is lambda (the values in `lambdaz`),
# not the temporal jitter.
xlabel = 'lambda'

# One figure per metric, one curve per training method.
results = results.cpu()
for i in range(len(name_metrics)):
    fig, ax = plt.subplots()
    for m in range(len(name_methods)):
        ax = plot_results_std(ax,results[m,:,:,i],lambdaz,xlabel,name_metrics[i],name_methods[m],colors[m], logplot=True)
        if i<3:
            # Only for the first three metrics (the similarity scores): report
            # the best score, averaged over repetitions, and the lambda reaching it.
            print(f'for {name_metrics[i]} with AE trained with {name_methods[m]}')
            mean_scores = results[m,:,:,i].mean(axis=0)
            # A single call returns the maximum and its index, so the two always agree.
            val_max, ind_max = mean_scores.max(dim=0)
            print(f'{val_max} is the max at lambda = {lambdaz[ind_max]}')

## Hyperparameter tuning: optimal lambda for cross correlation

In [None]:
# Switch the penalty on the parameter class, then rerun (or reload) the sweep.
training_parameters.penalty_type = 'cc'
results, lambdaz = performance_as_a_function_of_lambda(
    dataset_parameters,
    training_parameters,
    date,
    lambdaz,
    N_iter=N_iter,
    seeds=seeds,
    device=device,
)

In [None]:
# One figure per metric, one curve per training method.
results = results.cpu()
for i, metric_name in enumerate(name_metrics):
    fig, ax = plt.subplots()
    for m, method_name in enumerate(name_methods):
        scores = results[m,:,:,i]
        ax = plot_results_std(ax,scores,lambdaz,xlabel,metric_name,method_name,colors[m], logplot=True)
        if i >= 3:
            continue
        # First three metrics only: report the best mean score and its lambda.
        print(f'for {metric_name} with AE trained with {method_name}')
        mean_scores = scores.mean(axis=0)
        val_max = max(mean_scores)
        ind_max = np.argmax(mean_scores)
        print(f'{val_max} is the max at lambda = {lambdaz[ind_max]}')

## Hyperparameter tuning: optimal lambda for smoothed orthogonality

In [None]:
# Switch the penalty on the parameter class, then rerun (or reload) the sweep.
# NOTE(review): confirm that `smoothwind` has the intended value for a
# 'smoothed*' penalty when the penalty is switched on the class like this.
training_parameters.penalty_type = 'smoothed_orthogonality'
results, lambdaz = performance_as_a_function_of_lambda(
    dataset_parameters,
    training_parameters,
    date,
    lambdaz,
    N_iter=N_iter,
    seeds=seeds,
    device=device,
)

In [None]:
# One figure per metric, one curve per training method.
results = results.cpu()
for i, metric_name in enumerate(name_metrics):
    fig, ax = plt.subplots()
    for m, method_name in enumerate(name_methods):
        scores = results[m,:,:,i]
        ax = plot_results_std(ax,scores,lambdaz,xlabel,metric_name,method_name,colors[m], logplot=True)
        if i >= 3:
            continue
        # First three metrics only: report the best mean score and its lambda.
        print(f'for {metric_name} with AE trained with {method_name}')
        mean_scores = scores.mean(axis=0)
        val_max = max(mean_scores)
        ind_max = np.argmax(mean_scores)
        print(f'{val_max} is the max at lambda = {lambdaz[ind_max]}')

## Hyperparameter tuning: optimal lambda for kernels' orthogonality

In [None]:
# Switch the penalty on the parameter class, then rerun (or reload) the sweep.
training_parameters.penalty_type = 'kernels_orthogonality'
results, lambdaz = performance_as_a_function_of_lambda(
    dataset_parameters,
    training_parameters,
    date,
    lambdaz,
    N_iter=N_iter,
    seeds=seeds,
    device=device,
)

In [None]:
# One figure per metric, one curve per training method.
results = results.cpu()
for i, metric_name in enumerate(name_metrics):
    fig, ax = plt.subplots()
    for m, method_name in enumerate(name_methods):
        scores = results[m,:,:,i]
        ax = plot_results_std(ax,scores,lambdaz,xlabel,metric_name,method_name,colors[m], logplot=True)
        if i >= 3:
            continue
        # First three metrics only: report the best mean score and its lambda.
        print(f'for {metric_name} with AE trained with {method_name}')
        mean_scores = scores.mean(axis=0)
        val_max = max(mean_scores)
        ind_max = np.argmax(mean_scores)
        print(f'{val_max} is the max at lambda = {lambdaz[ind_max]}')

## Hyperparameter tuning: optimal lambda for sparsity

In [None]:
# Switch the penalty on the parameter class, then rerun (or reload) the sweep.
training_parameters.penalty_type = 'sparsity'
results, lambdaz = performance_as_a_function_of_lambda(
    dataset_parameters,
    training_parameters,
    date,
    lambdaz,
    N_iter=N_iter,
    seeds=seeds,
    device=device,
)

In [None]:
# One figure per metric, one curve per training method.
results = results.cpu()
for i, metric_name in enumerate(name_metrics):
    fig, ax = plt.subplots()
    for m, method_name in enumerate(name_methods):
        scores = results[m,:,:,i]
        ax = plot_results_std(ax,scores,lambdaz,xlabel,metric_name,method_name,colors[m], logplot=True)

In [None]:
from seqnmf import seqnmf
from wassa import WassA
from wassa_metrics import get_similarity

def performance_seqnmf_as_a_function_of_lambda(world_parameters, training_parameters, date, lambdaz, N_iter = 5, seeds = None, device='cpu'):
    """Sweep seqNMF's ``Lambda`` and score the factors it recovers.

    For every seed and every value in ``lambdaz`` a dataset is generated,
    ``seqnmf`` is run on the training samples concatenated along the last
    axis (or its cached output is reloaded), the factors ``W`` are wrapped in
    a ``WassA`` object and scored with ``get_similarity``. Entries stay at
    zero when ``W`` contains NaNs. The outcome is cached on disk and simply
    reloaded when the cache file already exists.

    Parameters
    ----------
    world_parameters : class
        Dataset parameter class. It is instantiated to build cache names and
        passed to ``generate_dataset``; ``N_SMs`` and ``N_delays`` are read
        from it. Its ``seed`` attribute is overwritten for each repetition
        and restored before returning.
    training_parameters : class
        Only used to build the cache file name.
    date : str
        Prefix of the cache file names.
    lambdaz : sequence
        Values passed to ``seqnmf`` as ``Lambda``.
    N_iter : int
        Number of repetitions (one seed each).
    seeds : tensor or None
        One seed per repetition (``seeds.size(0)`` must equal ``N_iter``);
        drawn at random in ``[0, 1000)`` when ``None``.
    device : str or torch.device
        Forwarded to ``generate_dataset``, ``WassA`` and ``get_similarity``.

    Returns
    -------
    results, lambdaz
        ``results`` has shape ``(1, N_iter, len(lambdaz), 6)`` when computed
        here; when loaded, both values are whatever the cache file contains.
    """
    if seeds is not None:
        assert seeds.size(0)==N_iter
    else:
        seeds = torch.randint(1000,[N_iter])

    params_emd = training_parameters()
    params_emd.loss_type = 'emd'
    # No '_' after 'lambda' below: kept as is so existing cache files still match.
    file_name = f'results/{date}_performance_seqnmf_as_a_function_of_lambda{world_parameters().get_parameters()}_{params_emd.get_parameters()}_{lambdaz[0]}_{lambdaz[-1]}'
    print(file_name)

    if os.path.isfile(file_name):
        results, lambdaz = torch.load(file_name, map_location='cpu')
        return results, lambdaz

    # Create the output directory up front so saving cannot fail mid-sweep.
    os.makedirs(os.path.dirname(file_name), exist_ok=True)

    results = torch.zeros([1,N_iter,len(lambdaz),6])

    # The seed is written on the (shared) class. Remember it so it can be put
    # back: it is part of the cache names, and leaving it modified would make
    # later file names depend on whether this call hit the cache.
    initial_seed = world_parameters.seed
    pbar = tqdm(total=len(lambdaz)*N_iter)
    try:
        for i in range(N_iter):
            world_parameters.seed = seeds[i]
            for ind_f, lambda_ in enumerate(lambdaz):
                sm, trainset_input, trainset_output, testset_input, testset_output = generate_dataset(world_parameters,verbose = False,device=device)
                # get_parameters is an instance method: call it on an instance,
                # as done for `file_name` above (calling it on the class raises TypeError).
                path_seqnmf_ = f'results/{date}_seqnmf_{world_parameters().get_parameters()}_{lambda_}'
                if os.path.isfile(path_seqnmf_):
                    W, H, cost, loadings, power = torch.load(path_seqnmf_)
                else:
                    # Concatenate all training samples along the last axis.
                    samples = [trainset_input[ind_s] for ind_s in range(trainset_input.shape[0])]
                    seqnmf_input = torch.cat(samples,dim=-1).to('cpu')
                    W, H, cost, loadings, power = seqnmf(seqnmf_input, K=world_parameters.N_SMs, L=world_parameters.N_delays, Lambda=lambda_, max_iter=1000)
                    torch.save([W, H, cost, loadings, power],path_seqnmf_)

                learnt_kernels = WassA(sm.SMs.shape,weight_init=torch.tensor(W.swapaxes(0,1),dtype=torch.float32),device=device)
                # Skip scoring when the factorisation produced NaNs.
                if np.isnan(W).sum()==0:
                    metrics = get_similarity(sm,learnt_kernels,testset_input,device=device)
                    m0, m1, m2, m3, m4, m5 = metrics
                    for ind_k, value in enumerate((m0, m1, m2, m3, m4, m5)):
                        results[0,i,ind_f,ind_k] = value

                pbar.update(1)
    finally:
        pbar.close()
        world_parameters.seed = initial_seed

    torch.save([results, lambdaz], file_name)
    return results, lambdaz

In [None]:
# Run (or reload from cache) the seqNMF sweep; everything is kept on the CPU here.
results, lambdaz = performance_seqnmf_as_a_function_of_lambda(
    dataset_parameters,
    training_parameters,
    date,
    lambdaz,
    N_iter=N_iter,
    seeds=seeds,
    device='cpu',
)

In [None]:
# Summary grid: one row per metric (the first three), one column per penalty type,
# one curve per training method.
# NOTE(review): `lambdaz_both` and `lambdaz_ind` are not defined in any cell visible
# above — confirm where they are supposed to come from.
# NOTE(review): `results` is indexed with five indices below, whereas the tensors
# built by the sweep functions above have four dimensions; this cell presumably
# expects results stacked over penalty types — verify before running.
pen_types = ['max_cc', 'cc', 'smoothed_orthogonality', 'kernels_orthogonality']

results = results.cpu()
fig, ax = plt.subplots(3,4, figsize = (20,20))
for i in range(3):
    for p, pen_type in enumerate(pen_types):
        for m in range(len(name_methods)):
            # Per-penalty lambda grid (rebinds the notebook-level `lambdaz`).
            lambdaz = lambdaz_both[lambdaz_ind[p]]
            ax[i,p] = plot_results_std(ax[i,p],results[m,:,p,:,i],lambdaz,xlabel,name_metrics[i],name_methods[m]+pen_type,colors[m], logplot=False)