In [234]:
from glob import glob
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import butter, filtfilt, iirnotch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler

#read trajectory data
def resample(data, target_length):
    """Linearly interpolate a 1-D sequence onto ``target_length`` evenly spaced points.

    The new positions span the original index range ``[0, len(data) - 1]``, so the
    first and last samples are kept exactly and the rest are linearly interpolated.
    """
    src_positions = np.arange(len(data))
    dst_positions = np.linspace(0, len(data) - 1, num=target_length)
    return np.interp(dst_positions, src_positions, data)


def bw_bandpass(data, lowcut, highcut, fs, order=2):
    """Zero-phase Butterworth band-pass filter applied along axis 1.

    Parameters
    ----------
    data : ndarray
        Signal(s) to filter; time runs along axis 1.
    lowcut, highcut : float
        Pass-band edges in Hz; must lie strictly between 0 and ``fs / 2``.
    fs : float
        Sampling rate in Hz.
    order : int
        Butterworth order. ``filtfilt`` runs the filter forward and backward,
        so the magnitude response is squared and there is no phase shift.

    Returns
    -------
    ndarray
        Filtered copy with the same shape as ``data``.
    """
    passband = np.array([lowcut, highcut]) / (fs * 0.5)
    num, den = butter(order, passband, btype='band', output='ba', analog=False)
    return filtfilt(num, den, data, axis=1)

def bw_notch5060(data, fs, order=2):
    """Suppress 50 Hz and 60 Hz mains interference along axis 1.

    Applies two zero-phase Butterworth band-stop filters in sequence: first
    over 49-51 Hz, then over 59-61 Hz.

    Parameters
    ----------
    data : ndarray
        Signal(s) to filter; time runs along axis 1.
    fs : float
        Sampling rate in Hz; 61 Hz must stay below ``fs / 2`` for ``butter``.
    order : int
        Butterworth order of each band-stop filter.

    Returns
    -------
    ndarray
        Filtered copy with the same shape as ``data``.
    """
    half_fs = fs * 0.5
    filtered = data
    for mains in (50.0, 60.0):
        stop_band = np.array([mains - 1.0, mains + 1.0]) / half_fs
        num, den = butter(order, stop_band, btype='bandstop', output='ba', analog=False)
        filtered = filtfilt(num, den, filtered, axis=1)
    return filtered

#create a pair of timestamp with 15 dp eeg window (prior to target angular)
def extract_eeg_chunks(ts, eegs, fs=125, offset_index=250, len_group=16):
    """Cut one EEG window ending just before each (offset) timestamp.

    For every timestamp ``t`` (seconds) the window is the ``len_group`` samples
    immediately preceding sample ``round(t * fs) + offset_index``.

    Parameters
    ----------
    ts : array-like of float
        Timestamps in seconds. The caller passes shape (samples, 1); for
        multi-dimensional input the first element of each row is used.
        1-D input is also accepted.
    eegs : ndarray, shape (channels, samples)
        EEG of a single trial.
    fs : float
        Sampling rate used to turn seconds into sample indices (default 125).
    offset_index : int
        Sample shift added to every index (default 250, i.e. 2 s at 125 Hz).
    len_group : int
        Window length in samples (default 16).

    Returns
    -------
    ndarray, shape (len(ts), channels, len_group)

    Raises
    ------
    ValueError
        If a window would start before sample 0 or end past the last sample.
        Previously such windows were silently truncated or wrapped around.
    """
    ts = np.asarray(ts, dtype=float)
    while ts.ndim > 1:  # keep the first element of each row, as before
        ts = ts[..., 0]

    ends = (np.round(ts * fs) + offset_index).astype(int)
    n_samples = eegs.shape[1]

    # Collect windows in a list and stack once (instead of quadratic concatenation).
    chunks = []
    for end in ends:
        start = end - len_group
        if start < 0 or end > n_samples:
            raise ValueError(
                f"EEG window [{start}, {end}) is outside the signal of {n_samples} samples")
        chunks.append(eegs[:, start:end])

    if not chunks:
        return np.empty((0, eegs.shape[0], len_group), dtype=eegs.dtype)
    return np.stack(chunks, axis=0)

def norm(x):
    """Min-max scale each feature (last axis) of a 3-D array to [0, 1].

    The first two axes are flattened together, so one scaler is fitted per
    feature over all trials and samples. The result keeps the input's shape.
    """
    original_shape = x.shape
    flat = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
    scaled = MinMaxScaler(feature_range=(0, 1)).fit_transform(flat)
    return scaled.reshape(original_shape)


# Channel index -> electrode label for the 16 EEG rows kept from each file (rows 1..16).
chs_map = {0: 'FP1', 1: 'FP2', 2: 'C3', 3: 'C4', 4: 'T5', 5: 'T6', 6: 'O1', 7: 'O2',
           8: 'P4', 9: 'P3', 10: 'T4', 11: 'T3', 12: 'F4', 13: 'F3', 14: 'F8', 15: 'F7'}


RESAMPLE_SIZE = 20  # points per resampled trajectory
fs = 125  # EEG sampling rate (Hz)
angs_path = os.path.join(os.getcwd(), 'outputs', '*ANGLE*')  # replace with the data directory
eeg_path = os.path.join(os.getcwd(), 'outputs', '*EEG*')  # replace with the data directory

# glob() returns files in arbitrary, filesystem-dependent order. Sort both lists so the
# i-th angle file is paired with the i-th EEG file (assumes their names sort alike).
angs_files = sorted(glob(angs_path))
eeg_files = sorted(glob(eeg_path))
assert len(angs_files) == len(eeg_files), (
    f"{len(angs_files)} angle files vs {len(eeg_files)} EEG files; cannot pair them")

angs = np.array([])  # angular data of each kept trial, (trial, sample, 1)
timestamp = np.array([])  # seconds since trial start of each kept trial, (trial, sample, 1)
trial_eegs = np.array([])  # eeg signal of each trial (trial, sample, dp); currently unused
chunks_eeg = np.array([])  # eeg chunks associated with each timestamp, (trial, sample, chs, dp, 1)
eegs = np.array([])  # filtered eeg of each kept trial, (trial, chs, dp)

for fa, fe in zip(angs_files, eeg_files):
    tmp_a = np.load(fa)
    tmp_e = np.load(fe)[1:17]  # keep the 16 EEG channel rows

    # Remove each channel's own DC offset (a single global mean leaves per-channel offsets).
    tmp_e = tmp_e - np.mean(tmp_e, axis=1, keepdims=True)
    tmp_e = bw_notch5060(tmp_e, fs)  # suppress 50/60 Hz mains
    tmp_e = bw_bandpass(tmp_e, 1.0, 40.1, fs, 2)  # keep roughly 1-40 Hz
    tmp_e = tmp_e[:, 0:600]

    ang = np.expand_dims(resample(tmp_a[:, 0], RESAMPLE_SIZE), [0, -1])
    ts = np.expand_dims(resample(tmp_a[:, 1], RESAMPLE_SIZE), [0, -1])
    ts = (ts - ts.min()) * 10**-9  # timestamps are in nanoseconds -> seconds from trial start

    eeg = np.expand_dims(tmp_e, 0)
    if 1.9 <= ts.max() - ts.min() <= 2.1:  # keep only trajectories lasting about 2 s
        angs = np.concatenate([angs, ang], axis=0) if angs.size else ang
        eegs = np.concatenate([eegs, eeg], axis=0) if eegs.size else eeg
        # Timestamps are consumed by the chunking loop below and by later cells (X / v_X).
        timestamp = np.concatenate([timestamp, ts], axis=0) if timestamp.size else ts

for eeg, ts in zip(eegs, timestamp):
    chunks = np.expand_dims(extract_eeg_chunks(ts, eeg), [0, -1])
    chunks_eeg = np.concatenate([chunks_eeg, chunks]) if chunks_eeg.size else chunks

angs = norm(angs)
print(f"Angular size : {angs.shape}, | EEG Chunk size {chunks_eeg.shape}")

Angular size : (18, 20, 1), | EEG Chunk size (18, 20, 16, 16, 1)


In [271]:
# Sampling configuration read by get_train_sample.
obs_max, train_N = 10, 14  # max observed time points per sample; trials drawn for training
train_p = np.random.permutation(15)  # random order of trial indices 0..14; first train_N are used

def get_train_sample(eegs, angs, obs_len, max_obs=None, train_indices=None, rng=np.random):
    """Draw one random context/target sample from a random training trial.

    The time points of one trial are randomly permuted. The first ``n``
    (1..max_obs) of them become observations, and the next one, ``perm[n]``,
    becomes the target. Each observation therefore carries its own time, and
    the target never coincides with an observation. (Previously the times were
    permuted but the EEG rows were not, so they were mismatched.)

    Parameters
    ----------
    eegs : ndarray, shape (trials, >= obs_len, H, W, 1)
        EEG chunks, e.g. ``chunks_eeg``.
    angs : unused
        Kept for call compatibility.
    obs_len : int
        Number of time points per trial; times are ``linspace(0, 1, obs_len)``.
    max_obs : int, optional
        Maximum number of observations; defaults to the global ``obs_max``.
        Must be smaller than ``obs_len`` so a target point remains.
    train_indices : sequence of int, optional
        Trials to sample from; defaults to ``train_p[:train_N]``.
    rng : optional
        Provides ``randint`` and ``permutation`` (``np.random`` or a ``RandomState``).

    Returns
    -------
    [obs_eeg, target_t], [target_eeg], d, target_index
        obs_eeg : (n, H, W, 2); channel 0 is the time broadcast over the chunk,
            channel 1 is the EEG.
        target_t : (1, 1) target time.
        target_eeg : (1, H, W, 2); EEG in channel 1, channel 0 left at zero.
        d : trial index.
        target_index : time index of the target.

    Raises
    ------
    ValueError
        If ``max_obs`` is not in ``[1, obs_len - 1]``.
    """
    if max_obs is None:
        max_obs = obs_max
    if train_indices is None:
        train_indices = train_p[:train_N]
    if not 1 <= max_obs < obs_len:
        raise ValueError(f"max_obs={max_obs} must be in [1, obs_len - 1={obs_len - 1}]")

    n = rng.randint(0, max_obs) + 1
    d = train_indices[rng.randint(0, len(train_indices))]

    times = np.linspace(0, 1, obs_len)
    perm = rng.permutation(obs_len)
    h, w = eegs.shape[2:4]

    obs_eeg = np.zeros((n, h, w, 2))
    for i, k in enumerate(perm[:n]):
        obs_eeg[i, :, :, 0] = times[k]
        obs_eeg[i, :, :, 1] = eegs[d, k, :, :, 0]

    k = perm[n]
    target_t = np.array([[times[k]]])
    target_eeg = np.zeros((1, h, w, 2))
    target_eeg[0, :, :, 1] = eegs[d, k, :, :, 0]

    return [obs_eeg, target_t], [target_eeg], d, k
# obs, out, _,_ = get_train_sample(chunks_eeg, angs, 20)

# Test EEG AE 

In [785]:
# Data preparation: the first N_TRAIN trials are used for training, the rest for validation.
N_TRAIN = 15

# chunks_eeg comes from the loading cell with shape (trial, sample, chs, dp, 1); drop the
# trailing singleton so Y is (trial, sample, chs, dp), as in the recorded output below.
# All 16 channels are kept.
test_eeg_ec_data = chunks_eeg[..., 0] if chunks_eeg.ndim == 5 else chunks_eeg

X = timestamp[:N_TRAIN, :, :]  # time, (trial, sample, 1)
Y = test_eeg_ec_data[:N_TRAIN]  # eeg

v_X = timestamp[N_TRAIN:]
v_Y = test_eeg_ec_data[N_TRAIN:]

# NOTE: this rebinds the obs_max that get_train_sample reads (10 in the earlier cell).
obs_max = 5  # max number of observed points per training sample
d_N = X.shape[0]  # number of training trials
d_x, d_y = X.shape[-1], Y.shape[-1]  # size of the last axis of X (time) and of Y
time_len = X.shape[1]

v_X.shape, v_Y.shape, X.shape, Y.shape

((3, 20, 1), (3, 20, 16, 16), (15, 20, 1), (15, 20, 16, 16))

In [None]:
# Time and EEG chunk of the third time point of the first training trial.
X[0, 2], Y[0, 2]

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D
import torch.optim as optim

def get_train_sample_eeg():
    """Draw one random (observations, target) sample for CNMP training.

    Picks a random training trial ``d`` and a random permutation of its
    ``time_len`` time points. The first ``n`` (1..obs_max) permuted points are
    observations, and point ``perm[n]`` is the target.

    Reads the notebook globals ``obs_max``, ``d_N``, ``d_y``, ``time_len``, ``X``
    and ``Y``.

    Returns
    -------
    observations : torch.Tensor (float64), shape (1, n, 2, d_y)
        Row 0 holds the observation time repeated ``d_y`` times; row 1 holds the signal.
    target_X : torch.Tensor (float64), shape (1, d_y)
        Target time repeated ``d_y`` times.
    target_Y : torch.Tensor (float64), shape (1, d_y)
        Target signal.

    NOTE(review): writing ``signal[i]`` / ``Y[d, perm[n]]`` into a length-``d_y``
    slot only works when ``Y`` is (trial, time, d_y). With a 4-D ``Y`` each point
    is a 2-D chunk and these assignments fail. Confirm the intended layout.
    """
    n = np.random.randint(0, obs_max) + 1
    d = np.random.randint(0, d_N)

    # Width is d_y (was hard-coded to 15), matching the rows written below and the target arrays.
    observations = np.zeros((1, n, 2, d_y))
    target_X = np.zeros((1, d_y))
    target_Y = np.zeros((1, d_y))

    perm = np.random.permutation(time_len)  # random order of time points

    times = np.repeat(X[d, perm[:n]], d_y, axis=1)  # (n, d_y): each observed time repeated
    signal = Y[d, perm[:n]]

    for i in range(n):
        observations[0, i, 0, :] = times[i]
        observations[0, i, 1, :] = signal[i]

    target_X[0] = np.repeat(X[d, perm[n]], Y.shape[-1])
    target_Y[0] = Y[d, perm[n]]

    return torch.from_numpy(observations), torch.from_numpy(target_X), torch.from_numpy(target_Y)

def predict_model(observations, target_X, plot=True):
    """Run the global ``model`` on one context set and return the predicted mean and std.

    Parameters
    ----------
    observations : ndarray
        Context points passed to ``model``. The plot branch indexes it as
        (points, d_x + d_y), with time in column 0.
    target_X : ndarray
        Query inputs, one row per prediction.
    plot : bool
        If True, draw one figure per output feature showing every training
        trial, the prediction with its ±std error bars, and the observations.

    Returns
    -------
    predicted_Y, predicted_std : ndarray
        The first ``d_y`` and last ``d_y`` output columns; std is passed through softplus.
    """
    with torch.no_grad():
        prediction = model(torch.from_numpy(observations), torch.from_numpy(target_X)).numpy()
    predicted_Y = prediction[:, :d_y]
    # Numerically stable softplus (same as F.softplus in log_prob_loss);
    # np.log(1 + np.exp(x)) overflows to inf for large x.
    predicted_std = np.logaddexp(0.0, prediction[:, d_y:])

    if plot:  # default plot; customize as needed
        # Predictions are drawn against the time axis of the last training trial.
        time_axis = X[d_N - 1, :, 0]
        for i in range(d_y):
            plt.figure(figsize=(5, 5))
            for j in range(d_N):
                plt.plot(X[j, :, 0], Y[j, :, i])  # assumes X[j, :, 0] is time
            plt.plot(time_axis, predicted_Y[:, i], color='black')
            plt.errorbar(time_axis, predicted_Y[:, i], yerr=predicted_std[:, i], color='black', alpha=0.4)
            plt.scatter(observations[:, 0], observations[:, d_x + i], marker="X", color='black')
            plt.show()
    return predicted_Y, predicted_std

def log_prob_loss(output, target):
    """Negative mean log-likelihood of ``target`` under a diagonal Gaussian.

    ``output`` is split in half along the last axis. The first half is the mean,
    and the second half is a raw scale that is mapped through softplus to a
    positive std. Log-probabilities are summed over the last axis (one
    independent Gaussian per row) and averaged over the remaining axes.
    """
    loc, raw_scale = torch.chunk(output, 2, dim=-1)
    normal = D.Normal(loc=loc, scale=F.softplus(raw_scale))
    return -D.Independent(normal, 1).log_prob(target).mean()



class CNMP(nn.Module):
    """Conditional Neural Movement Primitive.

    Encodes each observation, averages the encodings over the leading axis, and
    decodes that average together with each target input into a mean and a raw
    std per output dimension.

    Parameters
    ----------
    y_dim : int, optional
        Size of the last axis of each observation (encoder input); the decoder
        outputs ``2 * y_dim`` values. Defaults to the global ``d_y``.
    x_dim : int, optional
        Size of the last axis of ``target_t``. Defaults to the global ``d_x``.
    hidden : int
        Width of the hidden layers and of the representation (default 128).
    debug : bool
        If True, ``forward`` prints input and representation shapes. Off by
        default so the training loop is not flooded with a print every step.
    """

    def __init__(self, y_dim=None, x_dim=None, hidden=128, debug=False):
        super().__init__()
        y_dim = d_y if y_dim is None else y_dim
        x_dim = d_x if x_dim is None else x_dim
        self.debug = debug

        # Encoder: maps each observation to a hidden-size representation.
        self.encoder = nn.Sequential(
            nn.Linear(y_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden),
        )

        # Decoder: maps (mean representation, target_t) to mean and raw std per output dimension.
        self.decoder = nn.Sequential(
            nn.Linear(hidden + x_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 2 * y_dim),
        )

    def forward(self, observations, target_t):
        """Predict output parameters for every row of ``target_t``.

        Also stores the averaged representation in the module-level global
        ``dbg`` for inspection.

        Returns
        -------
        torch.Tensor
            One row of ``2 * y_dim`` values per row of ``target_t``.
        """
        global dbg
        if self.debug:
            print(observations.shape, target_t.shape)
        r = self.encoder(observations)
        r_mean = torch.mean(r, dim=0)  # average over the leading (observation) axis
        dbg = r_mean
        if self.debug:
            print(r_mean.shape)
        r_mean = r_mean.repeat(target_t.shape[0], 1)  # one copy per target row

        concat = torch.cat((r_mean, target_t), dim=-1)
        return self.decoder(concat)
# Smoke test: build an untrained model and run one forward pass on a single training sample.
dbg = None  # CNMP.forward stores its averaged representation here
model = CNMP().double()  # float64 to match the tensors get_train_sample_eeg builds from np.zeros
model.eval()  # no behavioral effect on a Linear/ReLU-only network
with torch.no_grad():
    # NOTE(review): observations are 4-D (1, n, 2, ·), so r_mean in CNMP.forward is 3-D and
    # `.repeat(target_t.shape[0], 1)` is expected to raise; confirm the intended layout.
    obs, targett, targetout = get_train_sample_eeg()
    print(targetout.shape, targett.shape)
    out = model(obs, targett)

In [None]:
# Shape of the averaged representation captured by the most recent CNMP.forward call.
dbg.shape

In [None]:
import matplotlib.pyplot as plt
import math
import pylab as pl
from IPython import display
from IPython.core.display import HTML
from IPython.core.display import display as html_width

# Training: a fresh CNMP optimized with Adam on one random sample per step. Weights with
# the lowest validation MSE seen so far are saved to disk.
model = CNMP().double()
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

smooth_losses = [0]  # per plot window: average of the losses sampled in that window
losses = []  # loss sampled every loss_checkpoint steps
loss_checkpoint = 1000
plot_checkpoint = 1000
validation_checkpoint = 100
validation_error = 9999999  # best validation error so far (large sentinel start value)

for step in range(1000000):  # one random training sample per step
    observations, target_t, target_output = get_train_sample_eeg()
    
    optimizer.zero_grad()

    output = model(observations, target_t)
    loss = log_prob_loss(output, target_output)
    loss.backward()
    optimizer.step()
    
    if step % loss_checkpoint == 0:
        # NOTE(review): loss.data stores 0-d tensors; loss.item() would store plain floats.
        losses.append(loss.data)
        smooth_losses[-1] += loss.data/(plot_checkpoint/loss_checkpoint)
    
    if step % validation_checkpoint == 0:
        # Mean squared error per validation trial, conditioned on that trial's first point only,
        # averaged over validation trials.
        # NOTE(review): np.concatenate((v_X[i,0], v_Y[i,0])) needs both parts to have the same
        # ndim, i.e. a 3-D v_Y; with a 4-D v_Y (one 2-D chunk per point) it is expected to fail.
        current_error = 0
        for i in range(v_X.shape[0]):
            predicted_Y,predicted_std = predict_model(np.array([np.concatenate((v_X[i,0],v_Y[i,0]))]), v_X[i], plot= False)
            current_error += np.mean((predicted_Y - v_Y[i,:])**2) / v_X.shape[0]
        if current_error < validation_error:
            validation_error = current_error
            # NOTE: torch.save writes PyTorch's own serialization format despite the .h5 extension.
            torch.save(model.state_dict(), 'cnmp_best_validation.h5')
            print(' New validation best. Error is ', current_error)
        
    if step % plot_checkpoint == 0:
        # clear the previous output so the cell only shows the latest plots
        display.clear_output(wait=True)
        # NOTE(review): displays the current figure before any new plot is drawn; likely an
        # empty figure left over from plt.show() and probably removable.
        display.display(pl.gcf())
        
        print(step)
        # plot raw and smoothed training losses
        
        plt.figure(figsize=(15,5))
        plt.subplot(121)
        plt.title('Train Loss')
        plt.plot(range(len(losses)),losses)
        plt.subplot(122)
        plt.title('Train Loss (Smoothed)')
        plt.plot(range(len(smooth_losses)),smooth_losses)
        plt.show()
        
        # plot predictions for every validation trial
        for i in range(v_X.shape[0]):
            predict_model(np.array([np.concatenate((v_X[i,0],v_Y[i,0]))]), v_X[i])
        
        # start a new smoothing window
        if step!=0:
            smooth_losses.append(0)
print('Finished Training')