<h1> In this notebook we integrate over the zero responses (by Monte Carlo sampling) to obtain the likelihood of the original data. </h1>

In [11]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch

from neuraldistributions.models import transforms
from neuraldistributions.datasets import static

from multiprocessing import Pool

import os
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [12]:
import numpy as np
import torch

from neuraldistributions.datasets import mouse_static_loaders, extract_data_key
from neuraldistributions.models import poisson, zig, flowfa, ziffa, flowfa_ident
from neuraldistributions.trainers import base_trainer
from neuraldistributions.utility import get_loglikelihood

from torch.distributions import LowRankMultivariateNormal

from tqdm import tqdm

In [13]:
# Seed handed to the data loaders.
random_seed = 42

# Directory holding the zipped datasets and the archives to evaluate.
dataset_dir = "../project/data"
datasets = ["static_edited_dsampled.zip"]
scan_id = [2, 1]
dataset_paths = ["/".join((dataset_dir, archive)) for archive in datasets]

# Per-dataset settings; entries are looked up by dataset index.
areas = [["V1", "LM"]]
neurons_ns = [1000, 907]

In [14]:
# Pick which configured dataset to evaluate and look up its settings.
dataset_index = 0
dataset_path = dataset_paths[dataset_index]
data_key = extract_data_key(dataset_path)
area, neurons_n = areas[dataset_index], neurons_ns[dataset_index]

# Keyword arguments forwarded verbatim to mouse_static_loaders.
dataset_config = dict(
    paths=[dataset_path],
    seed=random_seed,
    batch_size=64,
    area=area,
    neurons_n=neurons_n,
    normalize_images=True,
    normalize_neurons=True,
    return_more=True,
    device=device,
    shuffle_train=True,
    return_real_responses=False,
)

dataloaders = mouse_static_loaders(**dataset_config)

100%|██████████| 5994/5994 [00:05<00:00, 1002.79it/s]
100%|██████████| 5994/5994 [00:00<00:00, 624628.14it/s]
100%|██████████| 5994/5994 [00:00<00:00, 636778.66it/s]
100%|██████████| 5994/5994 [00:00<00:00, 667462.92it/s]


In [21]:
from neuralpredictors.training import LongCycler
def calcLossForDataset(model, dataset, neurons, samples_amount=100, in_bits=False):
    """Monte Carlo estimate of the average negative log-likelihood per neuron.

    Responses that are <= 0 are integrated out: for every batch,
    ``samples_amount`` values are drawn uniformly from (-1, 0) for each such
    response, while positive responses keep their observed value. The filled-in
    responses are passed through ``model.sample_transform`` and scored under
    the model's Gaussian (plus the transform's log-determinant); the likelihood
    is then averaged over the draws with a log-mean-exp.

    Args:
        model: Model exposing ``forward(images)``, ``sample_transform(x)``
            returning ``(transformed, logdet)`` and, unless its class name
            contains "Ident", a ``C_and_psi_diag`` attribute.
        dataset: Loaders accepted by ``LongCycler``. Each batch is indexed as
            ``batch[0]`` (images) and ``batch[1]`` (responses of shape
            ``(batch, neurons)``).
        neurons: Number of neurons used to normalise the summed loss.
        samples_amount: Number of Monte Carlo draws per batch.
        in_bits: If true, divide by ln(2) to report bits instead of nats.

    Returns:
        The summed negative log-likelihood divided by the number of images and
        by ``neurons`` (and by ln(2) when ``in_bits`` is true).

    Raises:
        ValueError: If ``dataset`` yields no batches.
    """
    # Width of the uniform proposal U(-proposal_width, 0) used for the
    # non-positive responses. With width 1 the volume correction below is 0.
    proposal_width = 1.0
    log_samples_amount = float(np.log(samples_amount))
    log_proposal_width = float(np.log(proposal_width))
    uses_identity_cov = "Ident" in model.__class__.__name__

    losses = 0.0
    samples_count = 0

    model.eval()
    with torch.no_grad():
        for _, batch in LongCycler(dataset):
            images, targets = batch[0], batch[1]
            # Take the neuron count from the data instead of hard-coding it.
            n_neurons = targets.shape[-1]

            # Mean of the Gaussian in the transformed response space.
            mu = model.forward(images)

            if uses_identity_cov:
                dist = torch.distributions.multivariate_normal.MultivariateNormal(
                    mu, torch.eye(n_neurons, device=mu.device, dtype=mu.dtype)
                )
            else:
                C, psi_diag = model.C_and_psi_diag
                dist = LowRankMultivariateNormal(mu, C.T, psi_diag)

            # Split responses into "zero" (<= 0) and observed (> 0) entries.
            zero_mask = targets <= 0
            nonzero_mask = torch.logical_not(zero_mask)

            # Draw the proposals directly on the targets' device so the boolean
            # mask, the samples and the targets all live on the same device
            # (no CPU round trip, no reliance on a global ``device``).
            samples = torch.empty(
                samples_amount,
                targets.shape[0],
                n_neurons,
                dtype=torch.float32,
                device=targets.device,
            ).uniform_(-proposal_width, 0)
            # Observed responses overwrite the proposal in every draw.
            samples[:, nonzero_mask] = targets[nonzero_mask].to(samples.dtype)

            transformed_targets, logdet = model.sample_transform(samples)

            # Shape (samples_amount, batch): joint log-density of each draw.
            log_likelihood = dist.log_prob(transformed_targets) + logdet.sum(dim=2)

            # log-mean-exp over the draws, plus the log-volume of the uniform
            # proposal (one factor of ``proposal_width`` per zero response).
            log_lik_per_image = (
                torch.logsumexp(log_likelihood, dim=0)
                - log_samples_amount
                + zero_mask.sum(dim=1) * log_proposal_width
            )

            losses += -torch.sum(log_lik_per_image).item()
            samples_count += len(batch[0])
            # Release the large per-batch tensors before the next iteration.
            del samples, transformed_targets, logdet, log_likelihood, targets, images

    if samples_count == 0:
        raise ValueError("dataset yielded no batches; cannot compute an average loss")

    average_loss = losses / samples_count / neurons
    if not in_bits:
        return average_loss
    return average_loss / np.log(2)

In [22]:
# Quick check on the training split with only 10 Monte Carlo samples, in bits.
# map_location puts the unpickled model on the device passed to the data loaders.
# NOTE(review): torch.load unpickles arbitrary objects - only load trusted files.
calcLossForDataset(
    torch.load("./models/FlowFA", map_location=device),
    dataloaders["train"],
    neurons=neurons_n,
    samples_amount=10,
    in_bits=True,
)

2.167603544449458

In [None]:
# Entries in ./models that are not models to evaluate.
skipped_entries = {"results.txt", ".ipynb_checkpoints", "FlowFA-Original"}

# sorted() makes the evaluation order (and thus the printed report) deterministic;
# os.listdir returns entries in arbitrary order.
for file in sorted(os.listdir("./models")):
    if file in skipped_entries:
        continue
    # map_location puts the unpickled model on the device passed to the data loaders.
    model = torch.load(f"./models/{file}", map_location=device)
    print(file)
    for samples in [10, 100, 1000, 5000]:
        print(f"\tSample size: {samples}")
        # Print each split as soon as it is done instead of after all three.
        for label, split in (("Train", "train"), ("Val", "validation"), ("Test", "test")):
            split_loss = calcLossForDataset(model, dataloaders[split], neurons_n, samples, True)
            print(f"\t\t{label} loss: {split_loss}")
    # Drop the reference so the model's memory can be reclaimed before the next load.
    del model