In [None]:
# Install livelossplot; the LiveLossLogger class defined in this notebook imports PlotLosses from it.
!pip install livelossplot
# !pip install git+git://github.com/stared/livelossplot.git

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import mne
import math
import os
import matplotlib.pyplot as plt
import matplotlib.style as plt_style
import seaborn as sns
import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import pytorch_lightning as pl

import umap

plt_style.use('ggplot')

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.loggers.base import rank_zero_experiment
from livelossplot import PlotLosses

class LiveLossLogger(pl.loggers.LightningLoggerBase):
    """Lightning logger that forwards every logged metric dict to a livelossplot chart."""

    def __init__(self):
        super().__init__()
        # A single PlotLosses instance receives all metrics for the run.
        self.liveloss = PlotLosses()

    @property
    def name(self):
        """Fixed logger name."""
        return 'MyLogger'

    @property
    @rank_zero_experiment
    def experiment(self):
        """There is no backing experiment object, so this is always None."""
        return None

    @property
    def version(self):
        """Fixed experiment version string."""
        return '0.1'

    @rank_zero_only
    def log_hyperparams(self, params):
        """Hyperparameters are not recorded by this logger."""
        return None

    @rank_zero_only
    def log_metrics(self, metrics, step):
        """Push the metrics dict to the plot and redraw it; `step` is not used."""
        plot = self.liveloss
        plot.update(metrics)
        plot.send()

# Feature Flags

In [None]:
# Feature flags: later cells are wrapped in `if RUN_UMAP:` / `if RUN_RNN:` / `if RUN_VAE:`
# so that only the selected experiments execute on a full run.
RUN_UMAP = False  # UMAP embedding of per-patient PSDs
RUN_RNN = False   # GRU autoencoder on raw epochs
RUN_VAE = True    # VAE on saved PSD epochs

# Tasks

### Questions
1) How long should each epoch be?
2) Raw data + LFADS vs PSA + LFADS vs STFT + ConvVAE?

### Todo
* Exploration: Q1     - Run various epoch lengths of raw data through convolutional autoencoder and UMAP embeddings
* Exploration: Q1, Q2 - Run various epoch lengths of STFT through convolutional autoencoder and UMAP embeddings
* Exploration: Q1     - Run various epoch lengths of raw data through convolutional LSTM and UMAP embeddings

* Build LFADS network
* Q2: Run raw data through LFADS network
* Q2: Run PSAs through LFADS network

### In Progress
* Exploration: Q1, Q2 - Run various epoch lengths of PSA of individual epochs through convolutional LSTM and UMAP embeddings
* (Sayan) Exploration: STFT each epoch and run through UMAP

### Done
* Visualize raw data
* Epoch data
* (Sayan) Find alternative to Kaggle
* (Nizar) Exploration: power spectral analysis of each epoch
* (Nizar) Exploration: Run PSA through UMAP

# Dataset Preprocessing

In [None]:
def compute_power_spectrum(data_preproc, sfreq, one_sided=True):
    """FFT-based power spectrum of a 1-D signal.

    The original scratch cell was not valid Python (empty assignment, `return`
    outside a function, undefined `power_spec`); it is wrapped here as a helper.
    NOTE(review): `power_spec` is reconstructed as |FFT|^2 -- confirm this is
    the intended definition.

    Args:
        data_preproc: 1-D array-like of samples.
        sfreq: sampling frequency in Hz.
        one_sided: if True, return only the strictly positive frequencies.

    Returns:
        (freq, power_spec): frequency bins and the power at each bin.
    """
    data_preproc = np.asarray(data_preproc)
    fourier = np.fft.fft(data_preproc)
    power_spec = np.abs(fourier) ** 2
    freq = np.fft.fftfreq(data_preproc.shape[0], d=1.0/sfreq)
    if one_sided:
        positive = freq > 0
        return freq[positive], power_spec[positive]
    return freq, power_spec

In [None]:
# Set MNE's log level to WARNING for the preprocessing helpers below.
mne.set_log_level("WARNING")

def ch_keep(edf_files):
    """Return the channel names that occur most often across the given recordings.

    Args:
        edf_files: mapping patient id -> {"data": [edf_path, ...], ...}; only the
            first path of each patient is inspected.

    Returns:
        List of channel names whose occurrence count equals the maximum count,
        in order of first appearance. Empty list when `edf_files` is empty
        (the original raised KeyError on the missing "Occurrences" column).
    """
    from collections import Counter

    ch_occurrences = Counter()
    for data_dict in edf_files.values():
        edf = mne.io.read_raw_edf(data_dict["data"][0])
        ch_occurrences.update(edf.info['ch_names'])

    if not ch_occurrences:
        return []

    max_occ = max(ch_occurrences.values())
    # Counter keeps insertion order, so the result order matches first appearance.
    return [ch for ch, occurrences in ch_occurrences.items() if occurrences == max_occ]

def create_epochs(eeg_edf, channels_keep=None, preload=False, duration=5.0, overlap=2.5, bleed=0.2):
    """Read an EDF recording, notch-filter powerline noise and cut fixed-length epochs.

    Args:
        eeg_edf: path to the EDF file.
        channels_keep: channel names to retain. None (default) skips the keep-list
            filter, so only the always-excluded auxiliary channels are dropped.
            The default makes calls that pass just a path work.
        preload: forwarded to mne.Epochs.
        duration: epoch length in seconds (tmax of each epoch).
        overlap: overlap between consecutive epochs in seconds.
        bleed: extra time before each event, used as tmin and for the baseline window.

    Returns:
        mne.Epochs with unwanted channels removed.
    """
    raw = mne.io.read_raw_edf(eeg_edf)
    sfreq = raw.info['sfreq']
    powerline_freqs = np.arange(60, sfreq/2, 60)

    # `raw` is local to this function, so filter it in place instead of copying the data.
    raw_notch = raw.load_data()
    if len(powerline_freqs) > 0:
        # No harmonic of 60 Hz lies below Nyquist when sfreq <= 120; skip the filter then.
        raw_notch = raw_notch.notch_filter(freqs=powerline_freqs)  # remove powerline noise

    events = mne.make_fixed_length_events(raw_notch, start=0, stop=None, duration=duration+bleed, overlap=overlap+bleed)
    epochs = mne.Epochs(raw_notch, events, tmin=-1.0*bleed, tmax=duration, baseline=(None, 0.0), preload=preload)

    excluded = {'PHOTIC-REF', 'IBI', 'BURSTS', 'SUPPR'}
    ch_drop = [
        ch for ch in raw.info["ch_names"]
        if ch in excluded or (channels_keep is not None and ch not in channels_keep)
    ]
    if ch_drop:
        epochs.drop_channels(ch_drop)
    return epochs

def get_psds(epochs):
    """Welch power spectral density of every epoch.

    Returns:
        (psds, freqs) as produced by mne.time_frequency.psd_array_welch with
        segment averaging set to 'mean'.
    """
    epoch_data = epochs.get_data()
    sampling_rate = epochs.info['sfreq']
    welch_psds, welch_freqs = mne.time_frequency.psd_array_welch(
        epoch_data, sfreq=sampling_rate, average='mean'
    )
    return welch_psds, welch_freqs

def get_edfs_by_patient(root_dir, n=25):
    """Collect the first EDF file of up to `n` patients found under `root_dir`.

    The patient id and the annotation (e.g. 'epilepsy') are read from fixed
    positions of the file path: component -3 is the patient id, component -6
    the annotation directory.

    Args:
        root_dir: directory to walk recursively.
        n: maximum number of patients to return (the original comparison
            `num > n` let n+1 patients through).

    Returns:
        dict patient_id -> {"data": [first_edf_path], "annotations": str}.
        Only the first recording found for a patient is kept.
    """
    edfs_by_patients_dict = {}

    pid_ind = -3
    ann_ind = -6
    for root, d_names, f_names in os.walk(root_dir):
        # os.walk order is unspecified; sort so the same patients/files are picked every run.
        d_names.sort()
        for f in sorted(f_names):
            if os.path.splitext(f)[-1] != ".edf":
                continue
            full_path = os.path.join(root, f)
            file_path = full_path.replace(os.sep, '/').split('/')
            patient_id = file_path[pid_ind]
            if patient_id in edfs_by_patients_dict:
                continue
            if len(edfs_by_patients_dict) >= n:
                return edfs_by_patients_dict
            edfs_by_patients_dict[patient_id] = {"data": [full_path], "annotations": file_path[ann_ind]}

    return edfs_by_patients_dict


def files_df(edfs):
    """Flatten the per-patient EDF mapping into one row per file.

    Columns: filename | patient_id | label, where label is 1 when the patient's
    annotation is 'epilepsy' and 0 otherwise.
    """
    records = [
        {
            'filename': edf_path,
            'patient_id': patient_id,
            'label': 1 if info['annotations'] == 'epilepsy' else 0,
        }
        for patient_id, info in edfs.items()
        for edf_path in info["data"]
    ]
    return pd.DataFrame(records)

def epoch_df(data_df, channels_keep, out_dir="/kaggle/working/data/",
             csv_path="/kaggle/working/data.csv", input_prefix="../input/"):
    """Epoch every recording, save one PSD array per epoch, and index them in a DataFrame.

    Args:
        data_df: DataFrame with 'filename', 'patient_id' and 'label' columns.
        channels_keep: channel names forwarded to create_epochs.
        out_dir: directory receiving one .npy file per epoch.
        csv_path: where the resulting index is written as CSV.
        input_prefix: leading path prefix stripped from each filename before it is
            turned into an output name. The original sliced a fixed `[9:-4]`,
            which only works for paths starting with '../input/' and ending in a
            4-character extension.

    Returns:
        DataFrame with one row per saved epoch: filename | patient_id | label.
    """
    os.makedirs(out_dir, exist_ok=True)
    epochs_list = []
    for entry in tqdm.tqdm(data_df.to_dict('records')):
        filename, patient, label = entry['filename'], entry['patient_id'], entry['label']
        epochs = create_epochs(filename, channels_keep, preload=True)
        psds, _ = get_psds(epochs)

        # Turn the input path into a flat, extension-free stem for the output files.
        stem = os.path.splitext(filename)[0]
        if stem.startswith(input_prefix):
            stem = stem[len(input_prefix):]
        stem = stem.replace("/", "_")

        for idx, psd in enumerate(psds):
            epoch_filename = os.path.join(out_dir, f"{stem}_{idx}.npy")
            np.save(epoch_filename, psd)
            epochs_list.append({'filename': epoch_filename, 'patient_id': patient, 'label': label})
    epoch_dataframe = pd.DataFrame(epochs_list)
    epoch_dataframe.to_csv(csv_path)
    return epoch_dataframe

def preprocessing_pipeline(n=25):
    """Run the whole preprocessing chain and return the per-epoch index DataFrame.

    Collects EDF files from the epilepsy and no-epilepsy folders (`n` is passed
    to get_edfs_by_patient for each), selects the channels to keep, and writes
    the PSD of every epoch to disk via epoch_df.
    """
    group_roots = (
        '../input/tuh-eeg-epilepsy-v100/edf/epilepsy/',
        '../input/tuh-eeg-epilepsy-v100/edf/no_epilepsy/',
    )
    patient_edfs = {}
    for group_root in group_roots:
        patient_edfs.update(get_edfs_by_patient(group_root, n))

    return epoch_df(files_df(patient_edfs), ch_keep(patient_edfs))

In [None]:
# Quick look at the mapping returned by get_edfs_by_patient for one epilepsy
# and one no-epilepsy sub-folder; the merged dict is displayed as the cell output.
patient_edfs = get_edfs_by_patient('../input/tuh-eeg-epilepsy-v100/edf/epilepsy/01_tcp_ar/003/')
patient_edfs.update(get_edfs_by_patient('../input/tuh-eeg-epilepsy-v100/edf/no_epilepsy/01_tcp_ar/027/'))
patient_edfs

In [None]:
# Build the per-epoch index DataFrame; later cells (statistics, PSDDataset) read `epochs`.
epochs = preprocessing_pipeline(25)

In [None]:
# Sanity check: every saved epoch file listed in `epochs` must load without error.
for epoch_path in tqdm.tqdm(epochs['filename']):
    np.load(epoch_path)
print("All good.")

Compute the global mean and standard deviation from a random sample of epochs and save them for normalization.

In [None]:
N = 1000
# Sampling without replacement fails when N exceeds the number of epochs, so cap it.
n_sampled = min(N, len(epochs))
rnd_idxs = np.random.choice(len(epochs), n_sampled, replace=False)
sample = np.stack([np.load(epochs.iloc[x]['filename']) for x in rnd_idxs])
print(sample.shape)

# Element-wise statistics across the sampled epochs.
mean, sd = np.mean(sample, axis=0), np.std(sample, axis=0)

# Store tensors rather than NumPy arrays so the normalization transform that
# loads these files operates on torch tensors directly.
torch.save(torch.from_numpy(mean), '/kaggle/working/means.pt')
torch.save(torch.from_numpy(sd), '/kaggle/working/stds.pt')

# Data Exploration: Basics

In [None]:
# with open("/kaggle/input/tuh-eeg-epilepsy-v100/edf/no_epilepsy/03_tcp_ar_a/076/00007671/s002_2011_02_03/00007671_s002.txt") as example_eeg_descr_file:
#     for line in example_eeg_descr_file.readlines():
#         if line != "\n":
#             print(line)

In [None]:
# Load one example recording; the next cell plots its channels. The Info object is the cell output.
example_eeg = mne.io.read_raw_edf("../input/tuh-eeg-epilepsy-v100/edf/epilepsy/01_tcp_ar/003/00000355/s003_2013_01_04/00000355_s003_t000.edf")
example_eeg.info

In [None]:
NUM_CHANNELS = len(example_eeg.info.ch_names)

sfreq = example_eeg.info['sfreq']

t_start = 0
t_end = len(example_eeg) / sfreq

# Extract every channel over the whole recording (t_start to t_end seconds).
data, times = example_eeg[:NUM_CHANNELS, int(sfreq * t_start):int(sfreq * t_end)]

NUM_COLS = 3
NUM_ROWS = math.ceil(NUM_CHANNELS / NUM_COLS)

# squeeze=False keeps `axs` two-dimensional even when only one row is needed.
fig, axs = plt.subplots(NUM_ROWS, NUM_COLS, figsize=(30, 10), sharex=True, squeeze=False)

# Fill the grid column by column: channel k goes to row k % NUM_ROWS, column k // NUM_ROWS.
for ch_num, ch_name in enumerate(example_eeg.info.ch_names[:NUM_CHANNELS]):
    chart_col, chart_row = divmod(ch_num, NUM_ROWS)
    axs[chart_row, chart_col].plot(times, data[ch_num])
    axs[chart_row, chart_col].title.set_text(ch_name)

# Data Exploration: UMAP

In [None]:
def get_spectogram(epochs):
    """Per-epoch multitaper time-frequency representation.

    Frequencies are spaced linearly from 1 Hz to half the sampling rate, with
    as many frequency bins as there are time samples in an epoch.
    """
    nyquist = epochs.info['sfreq'] / 2.0
    n_freqs = epochs.times.shape[0]
    freqs = np.linspace(1.0, nyquist, n_freqs)

    return mne.time_frequency.tfr_multitaper(
        epochs,
        freqs=freqs,
        n_cycles=2.0,
        time_bandwidth=3.0,
        return_itc=False,
        average=False,
    )

def UMAP(psds_of_patients_dict):
    """Embed each patient's PSD epochs with UMAP.

    Args:
        psds_of_patients_dict: mapping patient id -> array whose first axis indexes
            epochs; all remaining axes are flattened into one feature vector.

    Returns:
        dict patient id -> embedding returned by `reducer.fit_transform`.

    The input dict is left untouched. The original overwrote its values with the
    flattened arrays, so calling it a second time on the same dict failed.
    """
    reducer = umap.UMAP()
    embeddings = {}

    for p_id, psds in psds_of_patients_dict.items():
        psds = np.asarray(psds)
        flattened = psds.reshape(psds.shape[0], -1)
        embeddings[p_id] = reducer.fit_transform(flattened)

    return embeddings

def plot_umap(embeddings_of_patients_dict, legend=False):
    """Scatter the first two embedding dimensions of every patient on the current axes.

    When `legend` is true, the patient ids are used as legend entries.
    """
    for embedding in embeddings_of_patients_dict.values():
        plt.scatter(embedding[:, 0], embedding[:, 1])

    if legend:
        plt.legend(list(embeddings_of_patients_dict.keys()))
        
def UMAP_example(edf_non_epilepsy, edf_epilepsy):
    """Embed the PSDs of one non-epileptic and one epileptic recording with UMAP.

    Returns:
        dict with keys "no_epilepsy" and "epilepsy" mapping to the embeddings.

    Fixes two call errors in the original: create_epochs was called without its
    channel list, and UMAP was given two positional arrays instead of one dict.
    """
    # Keep only channels present in both recordings, in the first file's order.
    no_ep_channels = mne.io.read_raw_edf(edf_non_epilepsy).info['ch_names']
    ep_channels = set(mne.io.read_raw_edf(edf_epilepsy).info['ch_names'])
    channels_keep = [ch for ch in no_ep_channels if ch in ep_channels]

    epochs_no_ep = create_epochs(edf_non_epilepsy, channels_keep)
    epochs_ep = create_epochs(edf_epilepsy, channels_keep)
    psd_no_ep, _ = get_psds(epochs_no_ep)
    psd_ep, _ = get_psds(epochs_ep)
    return UMAP({"no_epilepsy": psd_no_ep, "epilepsy": psd_ep})

## Exploring UMAP of EEG Power Spectra
### Comparing non-epileptic patients only and epileptic patients only

In [None]:
if RUN_UMAP:
    mne.set_log_level(verbose="WARNING")

    no_ep_files = get_edfs_by_patient("../input/tuh-eeg-epilepsy-v100/edf/no_epilepsy/01_tcp_ar/")
    ep_files = get_edfs_by_patient("../input/tuh-eeg-epilepsy-v100/edf/epilepsy/01_tcp_ar")

    # create_epochs needs the channel list; use the most common channels over both groups.
    umap_channels_keep = ch_keep({**no_ep_files, **ep_files})

    no_ep_psds_by_patient = {}
    ep_psds_by_patient = {}

    for files, psds_by_patient in ((no_ep_files, no_ep_psds_by_patient),
                                   (ep_files, ep_psds_by_patient)):
        for p_id, data_dict in files.items():
            # data_dict["data"] is a list of paths; create_epochs expects a single path.
            # A distinct name avoids overwriting the global `epochs` DataFrame
            # that the VAE cells read.
            patient_epochs = create_epochs(data_dict["data"][0], umap_channels_keep)
            psd, _ = get_psds(patient_epochs)
            psds_by_patient[p_id] = psd

In [None]:
if RUN_UMAP:
    # One embedding per patient, computed separately for each group.
    no_ep_embeddings = UMAP(no_ep_psds_by_patient)
    ep_embeddings = UMAP(ep_psds_by_patient)

In [None]:
if RUN_UMAP:
    plt.figure(figsize=(20,5))

    # Left panel: non-epileptic patients; right panel: epileptic patients.
    umap_panels = (
        (no_ep_embeddings, "PSD UMAP Embeddings from Non-Epileptic Patients"),
        (ep_embeddings, "PSD UMAP Embeddings from Epileptic Patients"),
    )
    for panel_pos, (group_embeddings, panel_title) in enumerate(umap_panels, start=1):
        plt.subplot(1, 2, panel_pos)
        plot_umap(group_embeddings)
        plt.title(panel_title)

### Comparing non-epileptic patients vs epileptic patients

In [None]:
if RUN_UMAP:
    # Stack all per-patient embeddings of each group into one array (rows = epochs).
    no_ep_embeddings_arr = np.concatenate(list(no_ep_embeddings.values()), axis=0)
    ep_embeddings_arr = np.concatenate(list(ep_embeddings.values()), axis=0)

In [None]:
if RUN_UMAP:
    plt.figure(figsize=(20,5))

    # Panel 1 overlays both groups; panels 2 and 3 show each group on its own.
    comparison_panels = (
        ((no_ep_embeddings_arr, ep_embeddings_arr), ["no ep", "ep"]),
        ((no_ep_embeddings_arr,), ["no ep"]),
        ((ep_embeddings_arr,), ["ep"]),
    )
    for panel_pos, (point_sets, legend_entries) in enumerate(comparison_panels, start=1):
        plt.subplot(1, 3, panel_pos)
        for points in point_sets:
            plt.scatter(points[:,0], points[:,1])
        plt.legend(legend_entries)

## Exploring UMAP of EEG Spectrograms (TBC)
### Comparing non-epileptic patients only and epileptic patients only

In [None]:
# spectogram = get_spectogram(epochs_no_ep)
# print(spectogram)
# print(spectogram.info)
# print(spectogram.data.shape)
# print(spectogram.plot([0]))

# Basic RNN Autoencoder

In [None]:
# TODO: NORMALIZE
def cross_correlation(x, y, num_channels):
    """Grouped 1-D cross-correlation of `x` with kernel `y`.

    PyTorch's conv1d slides the kernel without flipping it, i.e. it already
    computes cross-correlation rather than true convolution. `num_channels` is
    passed as `groups`.
    """
    return F.conv1d(input=x, weight=y, groups=num_channels)

# cross_corr = cross_correlation(x, out, self.num_ch)
# auto_corr = cross_correlation(x, x, self.num_ch)
# loss = F.mse_loss(cross_corr, auto_corr)

In [None]:
from torch.utils.data import Dataset, DataLoader

class MNEDataset(Dataset):
    """Dataset of fixed-length EEG epochs cut from a single EDF recording.

    Each item is a float tensor obtained by transposing one epoch's
    (channels, samples) array, i.e. shaped (samples, channels).
    """

    def __init__(self, filename, drop_channels=None):
        """
        Args:
            filename: path to the EDF recording.
            drop_channels: optional channel names to remove. Names that are not
                present in the epochs are ignored. The default is None rather than
                a shared mutable list.
        """
        # create_epochs takes a keep-list; pass every channel of this recording so
        # only the channels it always excludes are removed at that stage.
        all_channels = mne.io.read_raw_edf(filename).info['ch_names']
        self.epochs = create_epochs(filename, all_channels, duration=5.0, overlap=4.5, bleed=0.2, preload=True)

        # Dropping a channel that is already gone would raise, so filter first.
        present = set(self.epochs.info['ch_names'])
        to_drop = [ch for ch in (drop_channels or []) if ch in present]
        if to_drop:
            self.epochs.drop_channels(to_drop)

        self.num_ch = self.epochs.info['nchan']
        self.window_len = self.epochs[0].get_data().shape[2]
        self.sfreq = self.epochs.info['sfreq']
        self.ch_names = self.epochs.info['ch_names']

    def __len__(self):
        """Number of epochs."""
        return len(self.epochs)

    def __getitem__(self, index):
        """Return epoch `index` as a float32 tensor with time on the first axis."""
        return torch.tensor(self.epochs[index].get_data()[0].T).float()

In [None]:
class RNNAutoEncoder(pl.LightningModule):
    """Single-layer GRU encoder followed by a single-layer GRU decoder.

    Both GRUs use `num_ch` as input and hidden size and expect batch-first input.
    """

    def __init__(self, num_ch):
        super().__init__()
        self.num_ch = num_ch
        self.lr = 1.0e-3

        gru_config = dict(input_size=num_ch, hidden_size=num_ch, num_layers=1, batch_first=True)
        self.encoder = nn.GRU(**gru_config)
        self.decoder = nn.GRU(**gru_config)

    def forward(self, x):
        """Encode `x`, then decode the encoder outputs starting from the encoder's final state.

        Returns:
            (encoder_out, encoder_hidden, decoder_out, decoder_hidden)
        """
        encoded_seq, encoded_state = self.encoder(x)
        decoded_seq, decoded_state = self.decoder(encoded_seq, encoded_state)
        return encoded_seq, encoded_state, decoded_seq, decoded_state

    def calc_loss(self, x, out):
        """MSE between the cosine similarity of `x` and `out` (taken along dim 1) and 1."""
        similarity = F.cosine_similarity(x, out, dim=1, eps=1.0e-10)
        target = torch.ones(similarity.shape, device=self.device)
        return F.mse_loss(similarity, target)

    def _step(self, batch, metric_name):
        # Shared by training and validation: reconstruct the batch, score it, log per epoch.
        _, _, reconstruction, _ = self(batch)
        loss = self.calc_loss(batch, reconstruction)
        self.log(metric_name, loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        """One training step; logs 'train_loss'."""
        return self._step(batch, 'train_loss')

    def validation_step(self, batch, batch_idx):
        """One validation step; logs 'val_loss'."""
        return self._step(batch, 'val_loss')

    def configure_optimizers(self):
        """Adam over all parameters with learning rate `self.lr`."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [None]:
if RUN_RNN:
    # Build the epoch dataset from one example recording, removing four auxiliary channels.
    dataset_path = "../input/tuh-eeg-epilepsy-v100/edf/epilepsy/01_tcp_ar/003/00000355/s003_2013_01_04/00000355_s003_t000.edf"
    dataset = MNEDataset(dataset_path,
                         drop_channels=['PHOTIC-REF',
                                        'IBI',
                                        'BURSTS',
                                        'SUPPR']
                        )

    # 95% / 5% random split into training and validation epochs.
    train_len = int(len(dataset)*0.95)
    val_len = len(dataset)-train_len

    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_len, val_len])

    # NOTE(review): the validation loader is shuffled too, so "the first validation
    # batch" used in later cells differs between runs -- confirm this is intended.
    train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=True)

In [None]:
# LOAD=True restores the model from `checkpoint_path` instead of starting from fresh weights.
LOAD = True
checkpoint_path = "../input/seizure-risk-assessment-results/sess_and_model.ckpt"

if RUN_RNN:
    model = RNNAutoEncoder.load_from_checkpoint(checkpoint_path=checkpoint_path, num_ch=dataset.num_ch) if LOAD else RNNAutoEncoder(dataset.num_ch)
    # Early stopping watches the training loss with a patience of 5 epochs.
    # NOTE(review): auto_lr_find=True is set but no trainer.tune(...) call is visible
    # in this notebook -- confirm the learning-rate finder actually runs.
    trainer = pl.Trainer(gpus=1,
                         #logger=LiveLossLogger(),
                         max_epochs=100,
                         callbacks=[
                             pl.callbacks.early_stopping.EarlyStopping(
                                 monitor='train_loss',
                                 mode='min',
                                 patience=5
                             )
                         ],
                        auto_lr_find=True)

In [None]:
if RUN_RNN:
    # Pass the loaders positionally: the keyword for the training loader differs
    # between PyTorch Lightning releases, the positional order does not.
    trainer.fit(model, train_dataloader, val_dataloader)

In [None]:
if RUN_RNN:
    # Persist the trainer state and model weights to the working directory.
    trainer.save_checkpoint("sess_and_model.ckpt")

In [None]:
def graph_expected_vs_preds(sample, out, filename=None, sfreq=None, ch_names=None):
    """Plot predicted (left column) and expected (right column) EEG, one row per channel.

    Only the first element of the batch is plotted.

    Args:
        sample: input tensor, indexed as (batch, time, channel).
        out: model output tensor with the same layout.
        filename: if given, the figure is saved there (dpi=150); otherwise it is shown.
        sfreq: sampling frequency in Hz; defaults to the global `dataset.sfreq`.
        ch_names: channel names used as titles; defaults to the global `dataset.ch_names`.
    """
    if sfreq is None:
        sfreq = dataset.sfreq
    if ch_names is None:
        ch_names = dataset.ch_names

    # .cpu() makes this work when the tensors live on a GPU.
    expected_eeg = sample.detach().cpu().numpy()[0]
    predicted_eeg = out.detach().cpu().numpy()[0]

    n_samples = min(expected_eeg.shape[0], predicted_eeg.shape[0])
    num_channels = min(len(ch_names), expected_eeg.shape[1], predicted_eeg.shape[1])

    # Build the time axis from the sample count; a float-step arange can come out
    # one element longer or shorter than the data.
    times = np.arange(n_samples) / sfreq

    # squeeze=False keeps `axs` two-dimensional even for a single channel.
    fig, axs = plt.subplots(num_channels, 2, figsize=(20, 50), sharex=True, squeeze=False)
    fig.tight_layout(pad=2.0)

    for row, ch_name in enumerate(ch_names[:num_channels]):
        predicted_ax, expected_ax = axs[row]

        predicted_ax.plot(times, predicted_eeg[:n_samples, row], label="Predicted EEG")
        predicted_ax.title.set_text(ch_name + " - Predicted EEG")

        expected_ax.plot(times, expected_eeg[:n_samples, row])
        expected_ax.title.set_text(ch_name + " - Expected EEG")

    plt.xlabel('Seconds')
    plt.ylabel(r'$\mu V$')  # raw string: '\m' is not a valid escape sequence

    if filename:
        plt.savefig(filename, dpi=150)
    else:
        plt.show()
    

In [None]:
if RUN_RNN:
    model.eval() # To turn on training mode, model.train() - funcs from PyTorch nn.Module

    # Take the first element of one validation batch, keeping the batch dimension.
    sample = next(iter(val_dataloader))
    sample = sample[0:1]

    # Reconstruct it and save the predicted-vs-expected plot.
    encoder_out, encoder_hidden, out, decoder_hidden = model(sample)
    graph_expected_vs_preds(sample, out, filename="basic-rnn-validation-results.png")

In [None]:
if RUN_RNN:
    model.eval() # To turn on training mode, model.train() - funcs from PyTorch nn.Module

    # Synthetic input: a linear ramp from 0 to `line_height` over `length` seconds,
    # offset by 1, identical on every channel.
    line_height = 321
    length = 10

    # Use the dataset's own sampling rate and channel count: the model was built with
    # dataset.num_ch inputs, which differs from NUM_CHANNELS of the raw example recording.
    n_samples = int(length * dataset.sfreq)
    ramp = torch.linspace(0, line_height, n_samples)
    sample = torch.ones((1, n_samples, dataset.num_ch)) + ramp[None, :, None]

    encoder_out, encoder_hidden, out, decoder_hidden = model(sample)
    graph_expected_vs_preds(sample, out, filename="basic-rnn-assert-line-test.png")

In [None]:
if RUN_RNN:
    model.eval() # To turn on training mode, model.train() - funcs from PyTorch nn.Module

    # Synthetic input: an 80 Hz sine wave offset by 1, identical on every channel.
    # NOTE(review): `amplitude` is defined but not applied to the sine wave -- confirm intent.
    amplitude = 10
    freq = 80
    length = 10

    # Use the dataset's own sampling rate and channel count: the model was built with
    # dataset.num_ch inputs, which differs from NUM_CHANNELS of the raw example recording.
    n_samples = int(length * dataset.sfreq)
    sine_wave = torch.sin(torch.linspace(0, 2*np.pi*freq*length, n_samples))
    sample = torch.ones((1, n_samples, dataset.num_ch)) + sine_wave[None, :, None]

    encoder_out, encoder_hidden, out, decoder_hidden = model(sample)
    graph_expected_vs_preds(sample, out, filename="basic-rnn-assert-80Hz-sine-test.png")

# VAE with Power Spectrum Density / Spectrograms

In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

class PSDDataset(Dataset):
    """Dataset over PSD arrays saved as .npy files and indexed by a DataFrame.

    The DataFrame needs 'filename', 'label' and 'patient_id' columns.
    """

    def __init__(self, epochs_df, drop_channels=None, transform=None):
        """
        Args:
            epochs_df: index DataFrame, one row per saved epoch.
            drop_channels: accepted for signature compatibility; not used.
            transform: optional callable applied to the loaded array.
        """
        self.df = epochs_df
        self.transform = transform

    def __len__(self):
        """Number of indexed epochs."""
        return len(self.df)

    def __getitem__(self, index):
        """Return (float32 tensor, label, patient id) for row `index`."""
        item = self.df.iloc[index]
        x = np.load(item['filename'])
        if self.transform is not None:
            x = self.transform(x)
        # as_tensor accepts both arrays and tensors; torch.tensor() on an existing
        # tensor would copy it and emit a warning.
        return torch.as_tensor(x).float(), item['label'], item['patient_id']

In [None]:
class GaussianVAE(pl.LightningModule):
    """Fully connected VAE with a Gaussian decoder.

    The encoder emits `latent_dim` numbers per input: the first half is the
    posterior mean, the second half the posterior log standard deviation. The
    decoder emits `2 * img_dim` numbers: reconstruction mean followed by
    reconstruction log-variance.
    """

    def __init__(self, input_shape, hidden=500, latent_dim=4):
        """
        Args:
            input_shape: shape of one input; its product is the flattened size.
            hidden: width of the hidden layer in encoder and decoder.
            latent_dim: size of the encoder output (mean + log-sigma), must be even.
        """
        super().__init__()
        if latent_dim < 2 or latent_dim % 2 != 0:
            raise ValueError("latent_dim must be a positive even number (mean and log-sigma halves)")
        # Plain Python int rather than a NumPy integer for layer sizes and view().
        self.img_dim = int(np.prod(input_shape))
        self.z_dim = latent_dim // 2
        self.encoder = nn.Sequential(
            nn.Linear(self.img_dim, hidden),
            nn.Tanh(),
            nn.Linear(hidden, latent_dim)
        )
        self.decoder = nn.Sequential(
            nn.Linear(self.z_dim, hidden),
            nn.Tanh(),
            nn.Linear(hidden, self.img_dim*2)
        )
        self.lr = 1.0e-3

    def unpack_var_params(self, params, d=None):
        """Split encoder output into its two halves of width `d` (default: the latent size)."""
        if d is None:
            d = self.z_dim
        mu, logvar = params[:, :d], params[:, d:]
        mu, logvar = mu.view(-1, d), logvar.view(-1, d)
        return mu, logvar

    def unpack_decoder_params(self, params):
        """Split decoder output into (mean, log-variance) halves."""
        n = params.size(1)//2
        mu, logvar = params[:, :n], params[:, n:]
        return mu, logvar

    def sample(self, q_mu, q_logsigma):
        """Reparameterized sample: q_mu + eps * exp(q_logsigma)."""
        z = torch.randn_like(q_mu, device=self.device)*torch.exp(q_logsigma) + q_mu
        return z

    def gaussian_log_loss(self, x, recon_x):
        """Gaussian negative log-likelihood (up to a constant), averaged over all elements."""
        mu_x, logvar_x = self.unpack_decoder_params(recon_x)
        part1 = torch.mean(logvar_x)
        sigma = logvar_x.mul(0.5).exp_()
        part2 = torch.mean(((x - mu_x) / sigma) ** 2)
        gll = .5 * (part1 + part2)
        return gll

    def KLD(self, mu, logvar):
        """KL divergence from N(mu, exp(logvar)) to N(0, I), averaged over all elements."""
        kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
        return kld

    def loss(self, x, x_rec, q_mu, q_logvar):
        """Scalar loss: reconstruction NLL plus KL term. `q_logvar` is a log-variance."""
        return self.gaussian_log_loss(x, x_rec) + self.KLD(q_mu, q_logvar)

    def forward(self, x):
        """Encode, sample a latent code, decode.

        Returns:
            (decoder output, posterior mean, posterior log-sigma)
        """
        var_params = self.encoder(x) # Variational params from data
        q_mu, q_logsigma = self.unpack_var_params(var_params)
        z = self.sample(q_mu, q_logsigma)
        decode_out = self.decoder(z)
        return decode_out, q_mu, q_logsigma

    def _shared_step(self, batch, metric_name):
        # Shared by training and validation.
        x, _, _ = batch
        x = x.view(-1, self.img_dim)
        x_rec, q_mu, q_logsigma = self(x)
        # sample() uses the encoder output as log-sigma, while KLD expects a
        # log-variance; convert (log var = 2 * log sigma) so both terms describe the
        # same posterior. The original passed log-sigma straight into KLD.
        loss = self.loss(x, x_rec, q_mu, 2.0 * q_logsigma)
        self.log(metric_name, loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        """One training step; logs 'train_loss'."""
        return self._shared_step(batch, 'train_loss')

    def validation_step(self, batch, batch_idx):
        """One validation step; logs 'val_loss'."""
        return self._shared_step(batch, 'val_loss')

    def configure_optimizers(self):
        """Adam over all parameters with learning rate `self.lr`."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
    
    
class BernoulliVAE(pl.LightningModule):
    """Fully connected VAE with a Bernoulli decoder, trained on a sampled ELBO.

    The encoder emits `latent_dim` numbers per input (mean half, log-sigma half);
    the decoder emits one logit per input element.
    """

    def __init__(self, input_shape, hidden=500, latent_dim=4):
        """
        Args:
            input_shape: shape of one input; its product is the flattened size.
            hidden: width of the hidden layer in encoder and decoder.
            latent_dim: size of the encoder output (mean + log-sigma), must be even.
        """
        super().__init__()
        if latent_dim < 2 or latent_dim % 2 != 0:
            raise ValueError("latent_dim must be a positive even number (mean and log-sigma halves)")
        # Plain Python int rather than a NumPy integer for layer sizes and view().
        self.img_dim = int(np.prod(input_shape))
        self.z_dim = latent_dim // 2
        self.encoder = nn.Sequential(
            nn.Linear(self.img_dim, hidden),
            nn.Tanh(),
            nn.Linear(hidden, latent_dim)
        )
        self.decoder = nn.Sequential(
            nn.Linear(self.z_dim, hidden),
            nn.Tanh(),
            nn.Linear(hidden, self.img_dim)
        )
        self.lr = 1.0e-3

    def gaussian_logprob(self, x, mu=None, log_std=None, eps=1e-7):
        """Element-wise log density of N(mu, exp(log_std)^2) at `x`.

        `mu` defaults to 0 and `log_std` defaults to `mu`, so omitting both gives
        the standard normal. `eps` guards the division.
        """
        # `is None` avoids an element-wise tensor comparison against None.
        if mu is None:
            mu = torch.zeros(1, device=self.device)
        if log_std is None:
            log_std = mu
        return -(0.5*np.log(2*np.pi) + log_std + 0.5*((x-mu)/(torch.exp(log_std)+eps))**2)

    def bernoulli_log_density(self, x, logit_means): # TODO: change this to gaussian
        """ x: (batch_size, h*w), expected to be binary
            logit_means: (batch_size, h*w)
            returns: (batch_size, h*w)
            p^x*(1-p)^(1-x) --> xlogp + (1-x)log(1-p)
        """
        b = x*2-1 # [0,1] -> [-1,1]
        # logsigmoid(t) == -log1p(exp(-t)) but does not overflow for large |t|.
        return F.logsigmoid(b*logit_means)

    def log_prior(self, z):
        """Element-wise log density of z under a standard normal prior.

        z: (batch_size, latent size); returns the same shape. The caller sums over
        the latent axis. The original multiplied this by 2, which doubled the
        prior's weight once that sum was taken.
        """
        return self.gaussian_logprob(z)

    def joint_log_density(self, x, z, y):
        """ x: (batch_size, h*w)
            z: (batch_size, latent size)
            y: decoder logits, (batch_size, h*w)
            returns: (batch_size,) -- log p(z) + log p(x | z)
        """
        l_prior = self.log_prior(z)
        ll = self.bernoulli_log_density(x, y)
        return l_prior.sum(axis=1) + ll.sum(axis=1)

    def log_q(self, z, q_mu, q_logsigma):
        """Log density of z under the variational posterior, summed over the latent axis."""
        return self.gaussian_logprob(z, q_mu, q_logsigma).sum(axis=1)

    def sample(self, q_mu, q_logsigma):
        """Reparameterized sample: q_mu + eps * exp(q_logsigma)."""
        z = torch.randn_like(q_mu, device=self.device)*torch.exp(q_logsigma) + q_mu
        return z

    def unpack_var_params(self, params, d=None):
        """Split encoder output into its two halves of width `d` (default: the latent size)."""
        if d is None:
            d = self.z_dim
        mu, logvar = params[:, :d], params[:, d:]
        mu, logvar = mu.view(-1, d), logvar.view(-1, d)
        return mu, logvar

    def neg_elbo(self, joint_ll, log_q_z):
        """Negative batch-mean of the single-sample ELBO estimate (scalar)."""
        elbo_estimate = torch.mean(joint_ll - log_q_z)
        return -elbo_estimate

    def forward(self, x):
        """Encode, sample a latent code, decode.

        Returns:
            (decoder logits, sampled z, posterior mean, posterior log-sigma)
        """
        var_params = self.encoder(x) # Variational params from data
        q_mu, q_logsigma = self.unpack_var_params(var_params)
        z = self.sample(q_mu, q_logsigma)
        y = self.decoder(z)
        return y, z, q_mu, q_logsigma

    def _shared_step(self, batch, metric_name):
        # Shared by training and validation.
        x, _, _ = batch
        x = x.view(-1, self.img_dim)
        y, z, q_mu, q_logsigma = self(x)
        joint_ll = self.joint_log_density(x, z, y) # joint likelihood of z and x under the model
        log_q_z = self.log_q(z, q_mu, q_logsigma)  # likelihood of z under the variational distribution
        loss = self.neg_elbo(joint_ll, log_q_z)
        self.log(metric_name, loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        """One training step; logs 'train_loss'."""
        return self._shared_step(batch, 'train_loss')

    def validation_step(self, batch, batch_idx):
        """One validation step; logs 'val_loss'."""
        return self._shared_step(batch, 'val_loss')

    def configure_optimizers(self):
        """Adam over all parameters with learning rate `self.lr`."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

In [None]:
if RUN_VAE:
    # as_tensor makes the normalization below plain tensor arithmetic whether the
    # statistics were stored as NumPy arrays or as tensors.
    MU = torch.as_tensor(torch.load('/kaggle/working/means.pt'))
    SD = torch.as_tensor(torch.load('/kaggle/working/stds.pt'))
    dataset = PSDDataset(epochs,
                         transform=torchvision.transforms.Compose([
                                      torchvision.transforms.ToTensor(),
                                      lambda x: (x-MU)/SD,  # standardize element-wise
                                      lambda x: x.float()
                                  ])
                        )

    # 90% / 10% random split into training and validation epochs.
    # NOTE(review): the split is per epoch, so epochs of the same patient can land
    # in both sets -- confirm this is acceptable for the evaluation.
    train_len = int(len(dataset)*0.90)
    val_len = len(dataset)-train_len
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_len, val_len])

    train_dataloader = DataLoader(train_dataset,batch_size=64, shuffle=True)
    val_dataloader = DataLoader(val_dataset,batch_size=64, shuffle=False)

In [None]:
# LOAD=True restores the model from `checkpoint_path` instead of starting from fresh weights.
LOAD = False
checkpoint_path = "/kaggle/working/sess_and_model_vae.ckpt"
psd_shape = np.array([24, 129])
VAE = {'gaussian': GaussianVAE, 'ber': BernoulliVAE}
ll = 'gaussian'  # key into VAE: which likelihood model to train
if RUN_VAE:
    # Look the class up first: `VAE` is a dict, so VAE.load_from_checkpoint raised
    # AttributeError whenever LOAD was True.
    vae_cls = VAE[ll]
    model = vae_cls.load_from_checkpoint(checkpoint_path=checkpoint_path, input_shape=psd_shape) if LOAD else vae_cls(psd_shape)
    trainer = pl.Trainer(gpus=1,
                         logger=LiveLossLogger(),
                         max_epochs=100,
                         callbacks=[
                             pl.callbacks.early_stopping.EarlyStopping(
                                 monitor='train_loss',
                                 mode='min',
                                 patience=5
                             )
                         ],
                         auto_lr_find=True)
    # Pass the loaders positionally: the keyword for the training loader differs
    # between PyTorch Lightning releases, the positional order does not.
    trainer.fit(model, train_dataloader, val_dataloader)
    trainer.save_checkpoint("/kaggle/working/sess_and_model_vae.ckpt")

# Latent Distribution of Batch

###  Training set

In [None]:
import seaborn as sns

def plot_latent_distribution(model, dataloader, legend="epilepsy"):
    """Scatter the first two posterior-mean dimensions of every item in `dataloader`.

    Args:
        model: a VAE from this notebook. The posterior mean is taken as the
            second-to-last element of the forward output, which holds for both
            GaussianVAE (3 outputs) and BernoulliVAE (4 outputs).
        dataloader: yields (x, label, patient_id) batches.
        legend: 'epilepsy' colors points by label; any other value colors by patient id.

    Returns:
        (zs, labels): stacked posterior means and the concatenated labels.
    """
    zs = []
    labels = []
    pat = []
    with torch.no_grad():  # inference only
        for x, t, pid in dataloader:
            outputs = model(x.view(-1, model.img_dim).to(model.device))
            q_mu = outputs[-2]
            # .cpu() so this also works when the model lives on a GPU.
            zs.append(q_mu.detach().cpu().numpy())
            labels.append(np.asarray(t))
            pat.append(np.asarray(pid))
    zs = np.vstack(zs)
    print(zs.shape)
    labels = np.concatenate(labels)
    pat = np.concatenate(pat)
    sns.set(rc={'figure.figsize':(10,8)})
    colors = labels if legend == 'epilepsy' else pat
    sns.scatterplot(x=zs[:,0], y=zs[:,1], hue=colors, palette='deep', legend='full')
    return zs, labels

if RUN_VAE:
    # Training set, colored by epilepsy label (the function's default legend).
    model.eval()
    zs, labels = plot_latent_distribution(model, train_dataloader)

In [None]:
if RUN_VAE:
    # Training set, colored by patient id.
    model.eval()
    zs, labels = plot_latent_distribution(model, train_dataloader, legend="patients")

### Validation set

In [None]:
if RUN_VAE:
    # Validation set, colored by epilepsy label (the function's default legend).
    model.eval()
    zs, labels = plot_latent_distribution(model, val_dataloader)

In [None]:
if RUN_VAE:
    # Validation set, colored by patient id.
    model.eval()
    zs, labels = plot_latent_distribution(model, val_dataloader, legend="patients")

# Visualize Reconstructed PSDs with Trained VAE

In [None]:
import matplotlib.pyplot as plt

def transform_logit(y):
    """Map logits to probabilities with the logistic sigmoid.

    ``torch.sigmoid`` replaces the explicit ``exp(y) / (1 + exp(y))``, which
    overflows to ``inf / inf = nan`` for large positive logits.

    Args:
        y: tensor of logits.

    Returns:
        Tensor of the same shape with values in ``[0, 1]``.
    """
    return torch.sigmoid(y)

# Show N training inputs (left half) next to their reconstructions (right half),
# stacked vertically in one image.
if RUN_VAE:
    model.eval() # To turn on training mode, model.train() - funcs from PyTorch nn.Module
    h,w = psd_shape
    N = 25  # number of (input, reconstruction) rows in the mosaic
    img = np.zeros((N*h, w*2))
    img2 = np.zeros((N*h, w))  # not read in this cell; kept so the notebook global still exists
    # Pure inference: no autograd graph is needed for plotting.
    with torch.no_grad():
        for idx, (sample, _, _) in enumerate(train_dataloader):
            # Stop after N batches; `img` only has room for N rows.
            if idx == N:
                break
            # Only the first sample of each batch is displayed.
            x = sample[0]
            x = x.view(-1, model.img_dim)
            x_rec, _,_ = model(x)
            # NOTE(review): index [1] selects the second slice of the reshaped output;
            # with a single input sample this only works if the decoder emits more
            # than h*w values per sample -- confirm against the model definition.
            img[idx*h:(idx+1)*h, w:] = x_rec.view(-1, h, w).detach().numpy()[1] # > 0
            img[idx*h:(idx+1)*h, :w] = sample[0]  #> 0
    fig = plt.figure(figsize=(12,24))
    plt.imshow(img,vmin=-1,vmax=1)

# BernoulliVAE

In [None]:
# Train (or reload) the Bernoulli VAE on PSDs binarized around the stored per-feature mean.
if RUN_VAE:
    MU, SD = torch.load('/kaggle/working/means.pt'), torch.load('/kaggle/working/stds.pt')
    dataset = PSDDataset(epochs,
                         transform=torchvision.transforms.Compose([
                                      torchvision.transforms.ToTensor(),
                                      lambda x: (x-MU)/SD > 0, # binarize
                                      lambda x: x.float()
                                  ])
                        )

    # 90/10 random train/validation split.
    train_len = int(len(dataset)*0.90)
    val_len = len(dataset)-train_len
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_len, val_len])

    # NOTE: these rebind the loaders used by the earlier cells to the binarized data.
    train_dataloader = DataLoader(train_dataset,batch_size=64, shuffle=True)
    val_dataloader = DataLoader(val_dataset,batch_size=64, shuffle=False)

    LOAD = False
    # NOTE(review): same path as the checkpoint written by the earlier training cell,
    # so saving below overwrites it -- confirm this is intended.
    checkpoint_path = "/kaggle/working/sess_and_model_vae.ckpt"
    psd_shape = np.array([24, 129])
    VAE = {'gaussian': GaussianVAE, 'ber': BernoulliVAE}
    ll = 'ber'

    # `VAE` is a dict of classes, so the checkpoint loader must be looked up on the
    # selected class (`VAE[ll]`), not on the dict itself.
    model_ber = VAE[ll].load_from_checkpoint(checkpoint_path=checkpoint_path, input_shape=psd_shape) if LOAD else VAE[ll](psd_shape)
    # NOTE(review): `auto_lr_find=True` presumably only takes effect when
    # `trainer.tune(...)` is called -- confirm for the installed Lightning version.
    trainer = pl.Trainer(gpus=1,
                         logger=LiveLossLogger(),
                         max_epochs=100,
                         callbacks=[
                             pl.callbacks.early_stopping.EarlyStopping(
                                 monitor='train_loss',
                                 mode='min',
                                 patience=5
                             )
                         ],
                         auto_lr_find=True)
    trainer.fit(model_ber, train_dataloader=train_dataloader, val_dataloaders=val_dataloader)
    trainer.save_checkpoint("/kaggle/working/sess_and_model_vae.ckpt")
    # torch.save(model_ber, "/kaggle/working/vae.pt")

### Training set Latent Distribution

In [None]:
def plot_latent_distribution(model_ber, dataloader, legend='epilepsy'):
    """Scatter-plot the first two posterior-mean coordinates of every sample.

    This definition replaces the earlier function of the same name once its
    cell has run, so it accepts both VAE output layouts used in the notebook.

    Args:
        model_ber: VAE exposing ``img_dim`` whose forward pass returns a tuple
            with the posterior mean second from the end, i.e.
            ``(logit_means, z, q_mu, q_ls)`` or ``(x_rec, q_mu, q_ls)``.
        dataloader: iterable of ``(x, t, pid)`` batches.
        legend: ``'epilepsy'`` colours the points by ``t``; any other value
            colours them by ``pid`` (presumably class label and patient id --
            confirm against the dataset definition).

    Returns:
        ``(zs, labels)``: the stacked posterior means and the concatenated
        ``t`` values.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    zs = []
    labels = []
    pat = []
    # Pure inference: no autograd graph is needed for plotting.
    with torch.no_grad():
        for x, t, pid in dataloader:
            # Flatten with the model's own input size rather than the
            # notebook-global `psd_shape`, which may not match this model.
            x = x.view(-1, model_ber.img_dim)
            outputs = model_ber(x)
            # The posterior mean sits second from the end in both VAE variants.
            q_mu = outputs[-2]
            zs.append(q_mu.detach().cpu().numpy())
            labels.append(t)
            pat.append(pid)
    if not zs:
        raise ValueError("dataloader yielded no batches; nothing to plot")
    zs = np.vstack(zs)
    print(zs.shape)
    labels = np.concatenate(labels)
    pat = np.concatenate(pat)
    sns.set(rc={'figure.figsize':(10,8)})
    colors = labels if legend == 'epilepsy' else pat
    sns.scatterplot(x=zs[:,0], y=zs[:,1], hue=colors, palette='deep', legend='full')
    return zs, labels

# Latent scatter of `model_ber` on the training loader, default legend ('epilepsy').
# Note: the first returned value (the stacked means) is bound to the name `q_mu` here.
if RUN_VAE:
    model_ber.eval()
    q_mu, labels = plot_latent_distribution(model_ber, train_dataloader)

In [None]:
# Same training-loader scatter for `model_ber`, with legend="patients" so the hue comes from `pid`.
if RUN_VAE:
    model_ber.eval()
    zs, labels = plot_latent_distribution(model_ber, train_dataloader, legend="patients")

### Validation set Latent Distribution

In [None]:
# Latent scatter of `model_ber` on the validation loader, default legend ('epilepsy').
# Note: the first returned value (the stacked means) is bound to the name `q_mu` here.
if RUN_VAE:
    model_ber.eval()
    q_mu, labels = plot_latent_distribution(model_ber, val_dataloader)

In [None]:
# Validation-loader scatter for `model_ber`, with legend="patients" so the hue comes from `pid`.
if RUN_VAE:
    model_ber.eval()
    zs, labels = plot_latent_distribution(model_ber, val_dataloader, legend="patients")

In [None]:
def transform_logit(y):
    """Map logits to probabilities with the logistic sigmoid.

    ``torch.sigmoid`` replaces the explicit ``exp(y) / (1 + exp(y))``, which
    overflows to ``inf / inf = nan`` for large positive logits.

    Args:
        y: tensor of logits.

    Returns:
        Tensor of the same shape with values in ``[0, 1]``.
    """
    return torch.sigmoid(y)

# Show N training inputs (left half) next to the decoded Bernoulli means (right half),
# stacked vertically in one image.
if RUN_VAE:
    model_ber.eval() # To turn on training mode, model.train() - funcs from PyTorch nn.Module
    h,w = psd_shape
    N = 25  # number of (input, reconstruction) rows in the mosaic
    img = np.zeros((N*h, w*2))
    img2 = np.zeros((N*h, w))  # not read in this cell; kept so the notebook global still exists
    # Pure inference: no autograd graph is needed for plotting.
    with torch.no_grad():
        for idx, batch in enumerate(train_dataloader):
            # Stop after N batches; `img` only has room for N rows.
            if idx == N:
                break
            sample, t, pid = batch
            # Only the first sample of each batch is displayed.
            x = sample[0]
            x = x.view(-1, model_ber.img_dim)
            y,_,_,_ = model_ber(x)
            # The first model output is passed through the sigmoid to get per-pixel means.
            mu = transform_logit(y)
            img[idx*h:(idx+1)*h, w:] = mu.view(-1, h, w).detach().numpy()
            img[idx*h:(idx+1)*h, :w] = sample[0]
    fig = plt.figure(figsize=(12,24))
    plt.imshow(img)

# Latent Interpolation Along Lattice

In [None]:
def plot_learned_latent_space(model, n=50, z0=(-6, 6), z1=(-6, 6), shape=(24, 129)):
    """Decode an ``n x n`` lattice of 2-D latent points and show them as one mosaic.

    Args:
        model: object whose ``decoder(z)`` maps a ``(1, 2)`` latent tensor to
            ``h * w`` logits.
        n: number of lattice points along each latent axis.
        z0: ``(min, max)`` of the first latent coordinate (horizontal axis).
        z1: ``(min, max)`` of the second latent coordinate (vertical axis).
        shape: ``(h, w)`` of one decoded image; the default is the previously
            hard-coded ``(24, 129)``.

    Returns:
        None. The mosaic is drawn with ``plt.imshow``.
    """
    h, w = shape
    img = np.zeros((n*h, n*w))
    # Loop-invariant: the horizontal lattice coordinates are the same for every row.
    xs = np.linspace(*z0, n)
    # Pure inference: no autograd graph is needed for plotting.
    with torch.no_grad():
        for i, y in enumerate(np.linspace(*z1, n)):
            # Rows are filled bottom-up so the largest z1 ends up at the top,
            # matching the `extent` passed to imshow below.
            row = n - 1 - i
            for j, x in enumerate(xs):
                z = torch.Tensor([[x, y]])
                # torch.sigmoid stays finite where exp(y)/(1+exp(y)) would give nan.
                x_ber = torch.sigmoid(model.decoder(z))
                x_ber = x_ber.reshape(h, w).detach().numpy()
                img[row*h:(row+1)*h, j*w:(j+1)*w] = x_ber
    fig=plt.figure(figsize=(30,30))
    plt.imshow(img, extent=[*z0, *z1])

# Decode a 15 x 15 lattice over the default latent ranges with `model_ber`.
if RUN_VAE:
    model_ber.eval()
    plot_learned_latent_space(model_ber, n=15)

In [None]:
# Display the model summary. `model_ber` only exists when the VAE cells ran, so this is
# guarded like every other VAE cell; `None` renders no output.
model_ber if RUN_VAE else None