In [None]:
RunningInCOLAB = 'google.colab' in str(get_ipython())
if RunningInCOLAB:
    !git clone https://github.com/MJC598/Neuron_Burst_Analysis.git

In [None]:
%matplotlib widget
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import scipy.io
import random
import time
import pandas as pds
from sklearn.metrics import r2_score

from abc import abstractmethod

# Seed every RNG this notebook imports (python `random`, numpy, torch) so
# stochastic steps (weight init, reparameterization noise, shuffling) repeat
# across runs. Trailing `;` hides the Generator repr in the cell output.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED);

In [None]:
class VAE(nn.Module):
    """Base class for the variational autoencoders in this notebook.

    Provides the pieces of the (negative) ELBO loss -- a single-sample KL
    estimate and a Gaussian reconstruction log-likelihood. Subclasses
    implement ``forward`` (and optionally ``generate``).

    KL Divergence and Reconstruction derived from:
    https://towardsdatascience.com/variational-autoencoder-demystified-with-pytorch-implementation-3a06bee395ed
    """

    def __init__(self, device=None):
        """
        Args:
            device: torch.device stored as ``self.device``. When None, uses
                "cuda:0" if CUDA is available, otherwise "cpu".
        """
        super(VAE, self).__init__()
        if device is None:
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device
        # Log standard deviation of the Gaussian decoder p(x|z). It is registered
        # once here so it appears in .parameters() and the optimizer can learn
        # it. It starts at 0, i.e. scale = exp(0) = 1.
        self.log_scale = nn.Parameter(torch.zeros(1))

    def generate(self, x):
        raise NotImplementedError

    def kl_div(self, sample, mean, std):
        """Single-sample Monte-Carlo estimate of KL(q(z|x) || p(z)).

        q = Normal(mean, std), p = Normal(0, 1). Computes
        log q(sample) - log p(sample) element-wise and sums the last axis.

        Returns:
            Tensor with the last dimension of ``sample`` reduced away.
        """
        # Distribution objects have no .to(); they live on the device of their
        # parameters, and zeros_like/ones_like already match mean/std.
        p = torch.distributions.Normal(torch.zeros_like(mean), torch.ones_like(std))
        q = torch.distributions.Normal(mean, std)

        log_qzx = q.log_prob(sample)
        log_pz = p.log_prob(sample)

        return (log_qzx - log_pz).sum(-1)

    def reconstruction(self, x_hat, x, dims=None):
        """Log-likelihood of ``x`` under Normal(x_hat, exp(self.log_scale)).

        Args:
            x_hat: decoder output, same shape as ``x`` (broadcastable).
            x: target batch; dim 0 is treated as the batch axis.
            dims: axes to sum over. Defaults to every non-batch axis, which is
                (1, 2, 3) for 4-D input and also works for 3-D sequences.

        Returns:
            Per-example log-likelihood (shape ``(batch,)`` with default dims).
        """
        if dims is None:
            dims = tuple(range(1, x.dim()))
        scale = torch.exp(self.log_scale).to(x_hat.device)
        dist = torch.distributions.Normal(x_hat, scale)
        return dist.log_prob(x).sum(dim=dims)

    def ELBOLoss(self, x, sample, mean, std, x_hat):
        """Negative ELBO averaged over the batch: mean(KL - log p(x|z))."""
        kl = self.kl_div(sample, mean, std)
        # kl_div only reduces the last (latent) axis. For sequence latents of
        # shape (batch, seq, latent) it is still (batch, seq). Sum the remaining
        # non-batch axes so it lines up with the per-example reconstruction
        # term instead of broadcasting against it incorrectly.
        if kl.dim() > 1:
            kl = kl.reshape(kl.shape[0], -1).sum(dim=1)
        return (kl - self.reconstruction(x_hat, x)).mean()

    @abstractmethod
    def forward(self, x):
        # nn.Module does not enforce @abstractmethod, so fail loudly instead of
        # silently returning None.
        raise NotImplementedError
    

In [None]:
class LSTM_VAE(VAE):
    """Variational autoencoder with an LSTM encoder and an LSTM decoder.

    Encoder: LSTM over the input, then per-timestep linear heads producing the
    latent mean and log-variance. Decoder: latent sample -> Linear + ReLU back
    to ``input_size`` features -> LSTM.

    Args:
        input_size: features per timestep of the input.
        hidden_size: hidden size of both LSTMs.
        latent_size: dimensionality of the latent code.
        num_layers: stacked layers in both LSTMs.
        device: forwarded to ``VAE.__init__`` (stored as ``self.device``).
    """
    def __init__(self, input_size, hidden_size, latent_size, num_layers, device=None):
        super(LSTM_VAE, self).__init__(device)
        self.en_lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                               num_layers=num_layers, batch_first=True)
        self.en_mean = nn.Linear(hidden_size, latent_size)
        self.en_var = nn.Linear(hidden_size, latent_size)
        self.de_fc = nn.Linear(latent_size, input_size)
        self.relu = nn.ReLU()
        # NOTE(review): the decoder LSTM emits `hidden_size` features, so x_hat
        # matches x in shape only when hidden_size == input_size. Otherwise the
        # reconstruction loss cannot be computed; consider an output Linear.
        self.de_lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                               num_layers=num_layers, batch_first=True)
        # No fixed (h0, c0) tensors are stored. nn.LSTM defaults to zero initial
        # states sized to the actual batch on every call, so a short final batch
        # works and the states always sit on the same device as the input. The
        # previous version depended on an undefined global `batch_size`.

    def reparameterization(self, mean, var):
        """Reparameterization trick: z = mean + var * eps, eps ~ N(0, 1).

        ``var`` is the standard deviation (forward passes exp(0.5 * log_var));
        the name is kept for backward compatibility.
        """
        # randn_like (standard normal), not rand_like (uniform on [0, 1)).
        # It is already on var's device.
        eps = torch.randn_like(var)
        return mean + var * eps

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        Args:
            x: input batch; the LSTMs use batch_first=True, so
               (batch, seq_len, input_size).

        Returns:
            (x_hat, mean, log_var): decoder LSTM output and the per-timestep
            latent mean / log-variance.
        """
        enc_out, _ = self.en_lstm(x)
        mean = self.en_mean(enc_out)
        log_var = self.en_var(enc_out)
        std = torch.exp(0.5 * log_var)  # sigma = exp(log(sigma^2) / 2)
        z = self.reparameterization(mean, std)
        dec_in = self.relu(self.de_fc(z))
        x_hat, _ = self.de_lstm(dec_in)
        return x_hat, mean, log_var
        

In [None]:
class GRU_VAE(VAE):
    """Placeholder for a GRU-based VAE; not implemented yet.

    ``forward`` currently returns its input unchanged.

    Args:
        device: forwarded to ``VAE.__init__`` (stored as ``self.device``).
    """
    def __init__(self, device=None):
        # Pass device explicitly: VAE.__init__ takes a `device` argument, and
        # calling it with no arguments raised TypeError.
        super(GRU_VAE, self).__init__(device)

    def forward(self, x):
        # TODO: replace this identity stub with a GRU encoder/decoder.
        return x