In [None]:
# Detect whether this notebook is executing on Google Colab by checking the
# string form of the active IPython shell object for 'google.colab'.
# NOTE: get_ipython() only exists inside IPython/Jupyter, so this cell will
# raise NameError if run as a plain Python script.
RunningInCOLAB = 'google.colab' in str(get_ipython())
if RunningInCOLAB:
    # IPython shell magic: clone the project repository into the Colab VM
    # (presumably so the repo's files are available there — confirm).
    !git clone https://github.com/MJC598/Neuron_Burst_Analysis.git

In [None]:
%matplotlib widget
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import scipy.io
import random
import time
import pandas as pds
from sklearn.metrics import r2_score

## Vanilla Auto-Encoders
<img src="images/AE.png" alt="Vanilla Auto-Encoder" width="400"/>
#### What is it?
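* An unsupervised method for learning compact encodings of data: the network is trained to reconstruct its own input through a narrow bottleneck. This also enables de-noising and yields robust data representations (as in the paper we examined).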
    - https://en.wikipedia.org/wiki/Autoencoder

#### Papers Using Vanilla Auto-Encoders on Time Series:
* Worth noting: most of the papers I found use the AE as a dimensionality-reduction step whose output is fed into another network that does the predictive analysis. I included one of them, plus another paper that uses an LSTM as the encoder (much closer to what the VAE papers do).
* https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0180944
* https://www.nature.com/articles/s41598-019-55320-6

## Variational Auto-Encoders
<img src="images/VAE.png" alt="Variational Auto-Encoder" width="400"/>
#### What is it?
* A generative variant of the AE whose latent space is continuous. This is achieved by having the encoder output a mean $\mu$ and a standard deviation $\sigma$ for each latent dimension, from which the latent code is sampled.
     - https://towardsdatascience.com/intuitively-understanding-variational-autoencoders-1bfe67eb5daf

#### Papers Using Variational Auto-Encoders in a Time Series:
* The initial VAE paper: https://arxiv.org/pdf/1312.6114.pdf
* Recurrent VAE: https://arxiv.org/pdf/1412.6581.pdf

In [None]:
class VanillaAutoEncoder(nn.Module):
    """Convolutional auto-encoder for fixed-length 1-D signals.

    The encoder halves the signal length four times (D_in -> D_in // 16) with
    Conv1d + ReLU + MaxPool1d stages. The decoder doubles it back four times
    with Upsample + Linear + ReLU stages. It then projects the length-D_in
    signal to D_out values.

    NOTE(review): the original draft mixed Conv2d/MaxPool2d layers that used
    D_in as a channel count with Linear layers that used D_in as a length. It
    also had missing commas and float (D_in/2) layer sizes, so it could not
    run. This version reads D_in as the signal length, which is what the
    pooling/upsampling/Linear sizes imply. Confirm against the data pipeline.

    Args:
        D_in: Length of the input signal (last dimension). Must be a positive
            multiple of 16 so every pooling stage halves it exactly.
        MinSize: Accepted for backward compatibility; currently unused.
        D_out: Length of the decoded output (last dimension).
        channels: Number of input channels (second dimension). It is kept
            constant through the whole network.

    Input shape is ``(batch, channels, D_in)``. ``forward`` returns
    ``(decoded, code)`` with shapes ``(batch, channels, D_out)`` and
    ``(batch, channels, D_in // 16)``.

    Raises:
        ValueError: if D_in is not a positive multiple of 16.
    """

    def __init__(self, D_in, MinSize, D_out, channels=1):
        super(VanillaAutoEncoder, self).__init__()
        # Four exact halvings must land on an integer length, and the decoder's
        # Linear sizes below rely on that.
        if D_in < 16 or D_in % 16 != 0:
            raise ValueError(
                f"D_in must be a positive multiple of 16, got {D_in!r}"
            )
        c = channels
        self.encoder = nn.Sequential(
            nn.Conv1d(c, c, 1),               # pointwise channel mixing
            nn.Linear(D_in, D_in),            # mixes along the time axis
            # padding=1 keeps the length unchanged so only pooling shrinks it.
            nn.Conv1d(c, c, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(2),                  # D_in      -> D_in // 2
            nn.Conv1d(c, c, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(2),                  # D_in // 2 -> D_in // 4
            nn.Conv1d(c, c, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(2),                  # D_in // 4 -> D_in // 8
            nn.Conv1d(c, c, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(2),                  # D_in // 8 -> D_in // 16
            nn.Conv1d(c, c, 3, padding=1),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2),      # D_in // 16 -> D_in // 8
            nn.Linear(D_in // 8, D_in // 8),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),      # D_in // 8 -> D_in // 4
            nn.Linear(D_in // 4, D_in // 4),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),      # D_in // 4 -> D_in // 2
            nn.Linear(D_in // 2, D_in // 2),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),      # D_in // 2 -> D_in
            nn.Linear(D_in, D_in),
            nn.ReLU(),
            nn.Conv1d(c, c, 1),
            nn.Linear(D_in, D_out),           # final projection to D_out
        )

    def forward(self, x):
        """Encode ``x`` and decode the resulting code.

        Returns:
            ``(decoded, code)``: the decoder output and the bottleneck code.
        """
        # If we wanted, the Sequential stages could be broken apart to add
        # skip connections between matching encoder/decoder resolutions.
        code = self.encoder(x)
        decoded = self.decoder(code)
        return decoded, code