In [None]:
import torch
from torch import nn
from torch.autograd import Variable

import losswise

from prettytable import PrettyTable
from tqdm import tqdm
import numpy as np
import os
import sys
import pickle
import random

from datasets import BurstDataset, ShuffledBatchSequentialSampler, FakeBurstDataset
from prep_dataset import BurstDatasetStandardizer
from models import Encoder, Decoder
from eval_functions import plot_autoencoding, autoencode, encode

import matplotlib.pyplot as plt

sys.path.append('../')
import utils
from readers.patient_info import PatientInfo

### Load datasets

In [None]:
# Directory holding the pickled artefacts read by the cells below:
# params_dict.pkl, datasets.pkl, standardizer.pkl and the per-epoch encoder/decoder checkpoints.
SAVE_DIR = 'saved_encs/config18/'

In [None]:
def _load_pickle(filename):
    """Unpickle ``filename`` from ``SAVE_DIR`` and return the stored object.

    The file is opened in binary mode inside a ``with`` block so the handle is
    closed even when unpickling raises.
    """
    # NOTE: pickle.load can execute arbitrary code; only point SAVE_DIR at trusted files.
    with open(os.path.join(SAVE_DIR, filename), 'rb') as pickle_file:
        return pickle.load(pickle_file)


params_dict = _load_pickle('params_dict.pkl')
(train_dataset, dev_dataset, test_dataset) = _load_pickle('datasets.pkl')
standardizer = _load_pickle('standardizer.pkl')

### Load models

In [None]:
# Seed every random number generator used in this notebook (torch, numpy, stdlib random)
# with the same value so reruns are repeatable.
SEED = 1
for seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    seed_rng(SEED)

In [None]:
# Network hyper-parameters; everything except INPUT_SIZE is read from params_dict.
INPUT_SIZE = 1  # This CANNOT be changed!
HIDDEN_SIZE = params_dict['hidden_size']
BIDIRECTIONAL = params_dict['bidirectional']
NUM_LAYERS = params_dict['num_layers']
EXTRA_INPUT_DIM = params_dict['extra_input_dim']

# Build both halves of the autoencoder from the same settings.
encoder = Encoder(INPUT_SIZE, HIDDEN_SIZE, num_layers=NUM_LAYERS, bidirectional=BIDIRECTIONAL)
decoder = Decoder(
    HIDDEN_SIZE,
    INPUT_SIZE,
    num_layers=NUM_LAYERS,
    encoder_bidirectional=BIDIRECTIONAL,
    extra_input_dim=EXTRA_INPUT_DIM)

# Move both models onto the GPU when CUDA is available.
if torch.cuda.is_available():
    encoder, decoder = encoder.cuda(), decoder.cuda()

In [None]:
NUM_EPOCHS = params_dict['num_epochs']
# Epoch whose checkpoint files are loaded. This is a hard-coded value; it is NOT derived
# from NUM_EPOCHS.
# NOTE(review): the previous comment said "by default, use the last epoch" -- confirm that
# 278 really is the final saved epoch for this configuration.
use_epoch = 278

# Without a GPU, remap tensors that were saved from CUDA onto the CPU so torch.load succeeds.
# With a GPU available the default behaviour (map_location=None) is kept unchanged.
checkpoint_map_location = None if torch.cuda.is_available() else (lambda storage, location: storage)

encoder.load_state_dict(torch.load(os.path.join(SAVE_DIR, 'epoch{}_enc.pkl'.format(use_epoch)),
                                   map_location=checkpoint_map_location))
decoder.load_state_dict(torch.load(os.path.join(SAVE_DIR, 'epoch{}_dec.pkl'.format(use_epoch)),
                                   map_location=checkpoint_map_location))
# Put both models in evaluation mode before they are used for plotting / encoding.
encoder.eval()
decoder.eval()

## Plot the autoencoding

In [None]:
# Flag forwarded as `reverse=` to plot_autoencoding in the plotting cell below.
# NOTE(review): this key contains a space ('train reversed') while the other params_dict keys
# read in this notebook use underscores -- confirm it matches the key written at training time.
TRAIN_REVERSED = params_dict['train reversed']

In [None]:
# Display the number of samples in the train and test datasets.
len(train_dataset), len(test_dataset)

In [None]:
# Preprocessing settings read from params_dict.
downsample_factor = params_dict['downsample_factor']
# `robust` is passed to dataset.get_undownsampled_item(...) in the plotting cell below.
robust = params_dict['robust_scale']

In [None]:
%matplotlib inline
dataset = test_dataset
for i in range(len(dataset)-15, len(dataset)):
    sample = dataset[i]
    try:
        # for updated datasets
        undownsampled = dataset.get_undownsampled_item(i, standardizer, robust)
    except AttributeError:
        print('old dataset')
        # for old datasets
        undownsampled = None
    mse = plot_autoencoding(sample, encoder, decoder, toss_encoder_output=False, 
                            reverse=TRAIN_REVERSED, undownsampled=undownsampled)

## Plot encodings in 2d for all patients

In [None]:
import plotly
import plotly.figure_factory
from plotly.graph_objs import *
from sklearn.decomposition import IncrementalPCA
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
plotly.offline.init_notebook_mode(connected=True)

In [None]:
def encode_and_cat(bursts, masks):
    """Encode a batch of bursts and flatten the encoder state into one vector per sample.

    Calls ``encode(bursts, masks, encoder)`` with the module-level ``encoder``. The returned
    hidden and cell states are each sliced along their first axis, the slices are
    concatenated along dim 1, and the hidden part is followed by the cell part.

    Args:
        bursts: batch of bursts, forwarded unchanged to ``encode``.
        masks: masks for ``bursts``, forwarded unchanged to ``encode``.

    Returns:
        2-D numpy array with one row per entry along axis 1 of the states
        (presumably the batch axis -- TODO confirm against ``Encoder``).
    """
    _out, hidden, cell = encode(bursts, masks, encoder)
    catted_hidden = torch.cat([hidden[i, :, :] for i in range(hidden.size(0))], dim=1)
    # Walk the cell state's own leading dimension instead of borrowing hidden's size.
    catted_cell = torch.cat([cell[i, :, :] for i in range(cell.size(0))], dim=1)
    encodings = torch.cat([catted_hidden, catted_cell], dim=1)
    return encodings.data.cpu().numpy()

In [None]:
def plot_encodings(indices_list, labels, pca, max_points_per_label=5):
    """Scatter-plot 2-D projections of test-set burst encodings, one trace per label.

    For every group of indices the bursts and masks are taken from the module-level
    ``test_dataset``, encoded with ``encode_and_cat`` and projected with ``pca.transform``.

    Args:
        indices_list: list of index collections into ``test_dataset``; each collection holds
            the indices that share one label.
        labels: list of labels, aligned with ``indices_list``; used as trace names.
        pca: object exposing ``transform``; columns 0 and 1 of its output are plotted as
            PC1 and PC2.
        max_points_per_label: when not None and a group has more points than this, the group
            is summarised by that many k-means cluster centres instead of its raw points.

    Returns:
        None. The figure is rendered with ``plotly.offline.iplot``.
    """
    # Imported here because KMeans is not imported at module level; sklearn itself is
    # already a dependency of this notebook.
    from sklearn.cluster import KMeans

    traces = []
    for inds, label in zip(indices_list, labels):
        bursts = np.take(test_dataset.all_bursts, inds, axis=0)
        masks = np.take(test_dataset.all_burst_masks, inds, axis=0)
        X = encode_and_cat(bursts, masks)
        print(X.shape)
        X_after_pca = pca.transform(X)
        if max_points_per_label is not None and X_after_pca.shape[0] > max_points_per_label:
            # if there's too many points, use kmeans to get summary points
            kmeans = KMeans(n_clusters=max_points_per_label, random_state=0).fit(X_after_pca)
            summary_points = kmeans.cluster_centers_
        else:
            summary_points = X_after_pca
        trace = Scatter(
            x=summary_points[:,0],
            y=summary_points[:,1],
            mode='markers',
            name=label,
            marker=Marker(
                size=12,
                line=Line(
                    color='rgba(217, 217, 217, 0.14)',
                    width=0.5),
                opacity=0.8))
        traces.append(trace)
    data = Data(traces)
    layout = Layout(xaxis=XAxis(title='PC1', showline=False),
                    yaxis=YAxis(title='PC2', showline=False))
    fig = Figure(data=data, layout=layout)
    plotly.offline.iplot(fig)

In [None]:
# BurstDataset built from the describe_bs script output directory. In the cells visible here
# it is only used for its parse_burst_info() method when grouping bursts by patient.
# NOTE(review): absolute, machine-specific path -- consider making it configurable.
newDataset = BurstDataset('/home/alice-eeg/NFS/script_output/describe_bs/')

In [None]:
# PatientInfo reader pointed at the patient outcome directory.
# NOTE(review): `patientInfo` is not referenced in the cells visible here -- confirm it is used later.
patientInfo = PatientInfo('../../../patient_outcome_info/')

In [None]:
# Display the number of entries in test_dataset.all_bursts.
len(test_dataset.all_bursts)

In [None]:
# Fit a 2-component PCA on the test-set encodings incrementally, one mini-batch at a time.
pca = IncrementalPCA(n_components=2)
all_encodings = []  # NOTE(review): never filled in this cell; kept in case later cells read it
batch_size = 10
# drop_last=True keeps every batch at exactly `batch_size` samples; a trailing partial batch
# is therefore not used for fitting.
data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)
for batch in tqdm(data_loader):
    masks = batch['mask']
    bursts = batch['burst']
    # encode_and_cat takes (bursts, masks) in that order and runs the encoder itself, so no
    # separate encode() call is needed here.
    batch_encodings = encode_and_cat(bursts, masks)
    pca.partial_fit(batch_encodings)

In [None]:
# Group the test-set burst indices by patient in a single pass.
# Grouping on the parsed patient id gives exact matches; a substring test such as
# '<pt>_' in burst_info would also catch other ids that merely end with the same characters.
inds_by_pt = {}
for i, burst_info in enumerate(test_dataset.all_burst_info):
    edf, episode_start_ind, episode_end_ind, burst_num = newDataset.parse_burst_info(burst_info)
    pt = utils.get_pt_from_edf_name(edf)
    inds_by_pt.setdefault(pt, []).append(i)
pts = set(inds_by_pt)
# Sort so the trace order in the plot is the same on every run
# (assumes patient ids are mutually comparable, e.g. all strings -- TODO confirm).
labels = sorted(pts)
indices_list = [np.array(inds_by_pt[pt]) for pt in labels]
plot_encodings(indices_list, labels, pca)