## Import

In [None]:
import os, sys, inspect, time

import numpy as np
import torch 
import matplotlib.pyplot as plt
torch.multiprocessing.set_sharing_strategy('file_system')

import discrepancy, visualization
from algorithms import ABC_algorithms, TPABC, SMCABC, SMC2ABC, SNLABC, SNL2ABC
import distributions 
import scipy.stats as stats

import utils_os, utils_math

%load_ext autoreload
%autoreload 2

## Problem Definition

In [None]:
from problems.ABC_problems import ABC_Problem

class Neuronal_Problem(ABC_Problem):
    """ABC problem backed by a fixed table of recorded (X, Y) pairs.

    "Simulating" does not run a model: for a parameter value theta, rows of
    ``data['Y']`` are looked up by their (rounded) ``data['X']`` value near
    theta. When no exact key exists, the row with the closest X is used.

    Args:
        data: dict with ``'X'`` (one value per recorded sample; acts as the
            prior support) and ``'Y'`` (2-D: samples x variables). Both must
            have the same number of rows.
        N: number of posterior samples; must not exceed the number of rows.
        n: number of rows returned by one call to ``simulator``.
    """

    def __init__(self, data, N=100, n=100):

        assert N <= data['Y'].shape[0]
        assert data['Y'].ndim == 2
        assert data['X'].shape[0] == data['Y'].shape[0]

        self.N = N  # number of posterior samples
        self.n = n  # length of a simulated dataset x = {x_1, ..., x_n}; ~number of trials is a sensible choice
        # NOTE: self.d (dimension of the learned statistic) is deliberately not set
        # here; set hyperparams.hidden_ratio instead.
        self.prior_args = np.array([[0, 1]])  # bounds on theta (here: X lies in [0, 1])

        self.all_thetas = data['X']
        self.sim_accuracy = 5  # decimal digits kept when a theta value is used as a lookup key
        # Lookup table: rounded X -> Y row, built from all recorded samples.
        # If several rows round to the same X, the last one wins.
        self.sim = {np.round(x, self.sim_accuracy): y for x, y in zip(data['X'], data['Y'])}
        self.K = 1  # number of parameters (theta is a scalar)
        self.stat = 'raw'  # 'raw': no known sufficient statistic; y_obs = statistics(data_obs)

        # Observed data; first dim = samples. The algorithm object (not this
        # problem) holds y_obs, computed from these data as statistics(data_obs).
        self.data_obs = data['Y']

        # Unfinished feature (speeds up rejection sampling); keep False for now.
        self.is_batch_sampling_supported = False

    def get_true_theta(self):
        """Return None: no ground-truth theta exists, and `statistics` ignores theta."""
        pass

    def sample_from_prior(self, size=1):
        """Draw `size` values uniformly, with replacement, from the recorded X values."""
        return np.random.choice(self.all_thetas, size=size, replace=True)

    # The calling algorithm samples one theta per simulation, generates n rows,
    # and computes one statistic vector per theta (repeated ~num_sim times).
    def simulator(self, theta):
        """Return `n` recorded Y rows taken from the neighbourhood of `theta`.

        Each row is looked up at theta plus uniform jitter in [-0.001, 0.001),
        rounded to `sim_accuracy` digits. If that value is not an exact key of
        the lookup table, the row of the closest recorded X is used.

        Args:
            theta: array-like holding exactly one value (any shape of size 1).

        Returns:
            ndarray of shape (n, number of Y columns).
        """
        theta = np.asarray(theta)
        assert theta.size == 1
        # ravel() accepts 0-d and (1, 1) inputs; theta[0] would break on both.
        center = theta.ravel()[0]
        y = np.empty((self.n, self.data_obs.shape[1]))
        for i in range(self.n):
            # jitter so that neighbouring recorded locations are sampled too
            t = np.round(center + (np.random.rand() - 0.5) * 0.002, self.sim_accuracy)
            if t in self.sim:
                y[i] = self.sim[t]
            else:
                # unseen value: take Y of the closest recorded X
                nearest = self.all_thetas[np.argmin(np.abs(self.all_thetas - t))]
                y[i] = self.sim[np.round(nearest, self.sim_accuracy)]
        return y

    def _ss_corr(self, Z):
        """Return the upper off-diagonal entries of Z^T Z / Z.shape[0].

        This is the uncentred second-moment matrix of the columns of `Z`, not a
        normalised correlation. The result has shape (1, d*(d-1)/2), the same
        row shape the former np.matrix implementation produced.
        """
        Z = np.asarray(Z)
        # np.mat semantics: a 1-D input is treated as a single row.
        M = np.atleast_2d(Z)
        V = M.T @ M / Z.shape[0]
        d = V.shape[0]
        return V[np.triu_indices(d, k=1)].reshape(1, -1)

    def statistics(self, data, theta=None):
        """Summary statistic of a dataset; only the 'raw' mode is implemented.

        In 'raw' mode this returns the pairwise second moments between the data
        columns (see `_ss_corr`); marginals are not included and `theta` is
        ignored.

        Raises:
            NotImplementedError: for any other `self.stat` mode.
        """
        if self.stat != 'raw':
            raise NotImplementedError('No ground truth statistics')
        return self._ss_corr(data)

In [None]:
import pickle as pkl
# Recorded dataset: a pickled dict with (at least) keys 'X' and 'Y'.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
DATA_PATH = '/home/nina/CopulaGP/plos_fig5_data/ST260_Day1_Dataset.pkl'
with open(DATA_PATH, "rb") as f:
    data = pkl.load(f)

Nvar = 109  # keep only the first Nvar variables (columns of Y)
data['Y'] = data['Y'][:, :Nvar]
print(data['Y'].shape)  # samples x neuronal/behavioral variables

problem = Neuronal_Problem(data)  # defaults: N=100, n=100

DIR = 'results/Neuronal'  # output directory, passed on as hyperparams.save_dir

In [None]:
### Sequential Neural Likelihood with learned summary statistics (SNL2_ABC)
hyperparams = ABC_algorithms.Hyperparams()
hyperparams.save_dir = DIR
hyperparams.device = 'cuda:0'                    # requires a CUDA GPU
hyperparams.num_sim = 1000                       # number of simulations (presumably per round — verify in SNL2ABC)
hyperparams.L = 5                                # number of learning rounds
# Dimensionality of S(x); presumably a fraction of the raw statistic size — TODO confirm.
hyperparams.hidden_ratio = 0.1
# Network architecture of S(x): 'plain' (not a CNN); presumably fully connected — TODO confirm.
hyperparams.type = 'plain'
# MI estimator: 'DV' = Donsker-Varadhan bound as used by MINE (Belghazi et al., 2018);
# 'JSD' and 'DC' are the alternatives named in the paper.
hyperparams.estimator = 'DV'
# Neural density estimator: MAF (D > 1) or MDN (D = 1); D here appears to be K — TODO confirm.
hyperparams.nde = 'MAF'

# `eculidean_dist` is spelled as in the discrepancy module; do not "fix" it here.
snl2_abc = SNL2ABC.SNL2_ABC(problem, discrepancy=discrepancy.eculidean_dist, hyperparams=hyperparams)


In [None]:
# Run the SNL2 ABC procedure configured above (hyperparams.L learning rounds;
# hyperparams.num_sim simulations, presumably per round — see SNL2ABC).
snl2_abc.run()


In [None]:
# Sanity check: draws from snl2_abc.prior() should still spread over [0, 1]
# rather than collapse onto a single point.
theta = np.empty(1000)
for draw_idx in range(theta.shape[0]):
    theta[draw_idx] = snl2_abc.prior()
plt.xlim([0, 1])
plt.hist(theta)

In [None]:
# Visualize the latents s(x): split the recorded data into nbins equal-width
# bins of X, compute the summary statistic of each bin, map it through
# snl2_abc.convert_stat, and embed the per-bin results in 2-D with t-SNE.
# NOTE: an empty bin would give NaN statistics (0/0 in _ss_corr); presumably
# t-SNE then rejects the input.
nbins = 100
stats2plot = []
for i in range(nbins):
    lo, hi = i / nbins, (i + 1) / nbins
    # Bins are (lo, hi]; the first bin is closed on the left so X == 0 is not dropped.
    above_lo = (data['X'] >= lo) if i == 0 else (data['X'] > lo)
    mask = above_lo & (data['X'] <= hi)
    get_stat = snl2_abc.convert_stat(snl2_abc.problem.statistics(data['Y'][mask]))
    stats2plot.append(get_stat)
from sklearn.manifold import TSNE
X_embedded = TSNE(n_components=2).fit_transform(np.array(stats2plot).squeeze())
X_embedded.shape

In [None]:
import matplotlib.cm as cm
# One point per X bin, coloured along the rainbow from the first to the last bin.
rainbow_by_bin = cm.rainbow(np.linspace(0, 1, nbins))
plt.scatter(X_embedded[:, 0], X_embedded[:, 1], color=rainbow_by_bin)
# Mark bins from 60/160 to 120/160 of the X range (late part of the corridor) with black crosses.
late_lo = int(nbins * 60 / 160)
late_hi = int(nbins * 120 / 160)
late_pts = X_embedded[late_lo:late_hi]
plt.scatter(late_pts[:, 0], late_pts[:, 1], marker='x', color='k')

In [None]:
# MI between statistics and parameters, pooled over rounds 0 .. snl2_abc.l
# (snl2_abc.l is presumably the index of the last completed round).
upto_current = slice(0, snl2_abc.l + 1)
stacked_stats = np.vstack(snl2_abc.all_stats[upto_current])
stacked_samples = np.vstack(snl2_abc.all_samples[upto_current])
all_stats = torch.tensor(stacked_stats).float()
all_samples = torch.tensor(stacked_samples).float()
print(all_samples.shape)
snl2_abc.vae_net.MI(all_stats, all_samples, n=100)  # n: number of shuffles

In [None]:
# MI between statistics and parameters, using only round snl2_abc.l.
last_round = slice(snl2_abc.l, snl2_abc.l + 1)
stacked_stats = np.vstack(snl2_abc.all_stats[last_round])
stacked_samples = np.vstack(snl2_abc.all_samples[last_round])
all_stats = torch.tensor(stacked_stats).float()
all_samples = torch.tensor(stacked_samples).float()
print(all_samples.shape)
snl2_abc.vae_net.MI(all_stats, all_samples, n=100)  # n: number of shuffles

In [None]:
# Calculate MI on freshly pooled samples: for each theta drawn from the prior,
# compute the statistic of all recorded rows with X in (theta - 1e-3, theta + 1e-3].
# The number of rows varies per theta; the width of the neighbourhood is fixed.
new_stats = []
new_samples = []
for _ in range(1000):  # 1000 theta draws, with replacement
    theta = snl2_abc.problem.sample_from_prior()
    mask = (data['X'] > theta - 1e-3) & (data['X'] <= theta + 1e-3)
    # Named nb_* so the module alias `stats` (scipy.stats) is not overwritten.
    nb_stats = snl2_abc.problem.statistics(data['Y'][mask])
    nb_theta = data['X'][mask].mean()  # mean X of the neighbourhood (theta itself would also do)
    new_stats.append(nb_stats)
    new_samples.append(nb_theta)
# Stack into one ndarray first: torch.tensor on a list of ndarrays is very slow.
new_stats = torch.tensor(np.array(new_stats)).float().squeeze()
new_samples = torch.tensor(np.array(new_samples)).float().reshape((-1, 1))
print(new_stats.shape, new_samples.shape)
snl2_abc.vae_net.MI(new_stats, new_samples, n=100)  # n here is the number of shuffles

# MINE for comparison

In [None]:
from mine import train_MINE # load another implementation, the one I used for PLoS

In [None]:
# Baseline: MI estimate from the separate MINE implementation, between the
# first Nvar variables of Y and X.
mine_x = torch.tensor(data['X'][:]).float()
train_MINE(
    data['Y'][:, :Nvar],
    x=mine_x,
    H=1000,
    lr=0.01,
    batches=1,
    n_epoch=2000,
    device=torch.device("cuda:0"),
)