In [1]:
import nu_smrutils as u
import pandas as pd
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
%matplotlib notebook

# Let DataFrame displays show every column and every row.
# NOTE(review): unlimited rows means a full frame (~hundreds of thousands of rows here)
# would render in full -- display slices such as .head() instead.
pd.set_option('display.max_columns', None, 'display.max_rows', None)

In [2]:
# Load the per-subject data and print a summary, using subject 0 as the representative.
# NOTE(review): u.loaddat presumably unpickles the file -- only load trusted files.
filename = 'datasets/aBNCI2014004R.pickle'
d = u.loaddat(filename)
df = d[0].to_data_frame()
subjects = len(d)

# Metadata columns in the exported frame (the same ones produce_coordinates selects);
# every other column is treated as a channel, independent of column order.
META_COLS = ['time', 'condition', 'epoch']
channel_names = [c for c in df.columns if c not in META_COLS]
epoch_labels = df['epoch'].unique()
# Count samples in the first epoch actually present instead of assuming label 1 exists.
points_per_epoch = int((df['epoch'] == epoch_labels[0]).sum()) if len(epoch_labels) else 0

print('Subjects: ' + str(subjects))
print('Channels: ' + str(channel_names))
# The epoch and sample counts below describe subject 0 only.
print('Total epochs: ' + str(len(epoch_labels)))
print('Points per epoch: ' + str(points_per_epoch))

Subjects: 9
Channels: ['C3', 'Cz', 'C4']
Total epochs: 720
Points per epoch: 321


In [3]:
def produce_coordinates(d, subject, l_epoch, r_epoch, sampling, t, channels,
                        left_label='left_hand', right_label='right_hand'):
    """Extract per-channel signal traces for one left-hand and one right-hand epoch.

    Parameters
    ----------
    d : sequence
        Per-subject objects exposing ``to_data_frame()``. The resulting frame must
        contain 'time', 'condition' and 'epoch' columns plus one column per channel.
    subject : int
        Index into ``d``.
    l_epoch, r_epoch : int
        Positional index (not the epoch label) into the unique epoch labels of the
        left-hand / right-hand trials, in order of first appearance.
    sampling : int
        Keep every ``sampling``-th sample of the selected window.
    t : sequence of int
        Indices into the unique time points. The window spans ``times[t[0]]`` to
        ``times[t[-1]]`` inclusive, so a one-element list selects a single time point.
    channels : sequence of str
        Channel columns to extract, in output order.
    left_label, right_label : str
        Values of the 'condition' column that identify each class.

    Returns
    -------
    right, left : list of list
        One list per channel (same order as ``channels``) holding the windowed,
        down-sampled values for the chosen right-hand and left-hand epoch.
    """
    channels = list(channels)
    data = d[subject].to_data_frame()[['time', 'condition', 'epoch'] + channels]
    times = data['time'].unique()
    t_start, t_end = times[t[0]], times[t[-1]]

    def _window(label, epoch_idx):
        # Rows of one condition's chosen epoch inside the time window. Built once per
        # condition rather than once per channel, since the mask does not depend on it.
        cond = data.loc[data['condition'] == label]
        epoch = cond['epoch'].unique()[epoch_idx]
        mask = (cond['epoch'] == epoch) & (cond['time'] >= t_start) & (cond['time'] <= t_end)
        return cond.loc[mask]

    right_win = _window(right_label, r_epoch)
    left_win = _window(left_label, l_epoch)
    right = [right_win[c].to_list()[::sampling] for c in channels]
    left = [left_win[c].to_list()[::sampling] for c in channels]
    return right, left

In [4]:
# Trajectory of one left-hand and one right-hand epoch for a single subject,
# plotted in the 3-D space spanned by three channels.
params = {
    'subject': 2,
    'l_epoch': 5,  # positional index into the subject's left-hand epochs (0-359 here)
    'r_epoch': 5,  # positional index into the subject's right-hand epochs (0-359 here)
    'sampling': 1,  # keep every n-th sample
    't': [0, 10],  # first/last index into the epoch's time points (0-320 here)
    'channels': ['C3', 'Cz', 'C4'],  # exactly 3, mapped to x, y, z in that order
}
# A 3-D scatter needs one channel per axis; fewer would fail on channels[2] / right[2].
assert len(params['channels']) == 3, "params['channels'] must name exactly 3 channels (x, y, z)"

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')
ax.view_init(elev=10, azim=125)

# Fixed symmetric axis limits.
# NOTE(review): confirm +/-5 covers the amplitude in the units to_data_frame() returns.
ax.set_xlim3d(-5, 5)
ax.set_ylim3d(-5, 5)
ax.set_zlim3d(-5, 5)

ax.set_xlabel(params['channels'][0])
ax.set_ylabel(params['channels'][1])
ax.set_zlabel(params['channels'][2])

right, left = produce_coordinates(d, **params)
# Right hand -- red, left hand -- blue
ax.scatter(*right, marker='o', color='red', label='right_hand')
ax.scatter(*left, marker='o', color='blue', label='left_hand')
ax.set_title('Subject {subject}: left epoch {l_epoch} vs right epoch {r_epoch}'.format(**params))
ax.legend()
plt.show()

<IPython.core.display.Javascript object>

In [5]:
# The same trajectory plot for every subject: one 3-D panel per subject showing
# the sample(s) selected by params2['t'].
params2 = {
    'l_epoch': 5,
    'r_epoch': 5,
    'sampling': 1,
    't': [50],  # single time index -> one point per class per subject
    'channels': ['C3', 'Cz', 'C4'],  # exactly 3, mapped to x, y, z in that order
}
assert len(params2['channels']) == 3, "params2['channels'] must name exactly 3 channels (x, y, z)"

# Near-square grid sized from the subject count (3x3 for 9 subjects). A hard-coded
# 3x3 grid would raise in add_subplot for more than 9 subjects.
ncols = max(1, int(np.ceil(np.sqrt(subjects))))
nrows = -(-subjects // ncols)  # ceiling division

fig = plt.figure()
axes = []
for i in range(subjects):
    ax = fig.add_subplot(nrows, ncols, i + 1, projection='3d')
    ax.set_xlabel(params2['channels'][0])
    ax.set_ylabel(params2['channels'][1])
    ax.set_zlabel(params2['channels'][2])
    ax.view_init(elev=10, azim=125)
    ax.set_title('Subject {}'.format(i))
    right, left = produce_coordinates(d, i, **params2)
    # Right hand -- red, left hand -- blue
    ax.scatter(*right, marker='o', color='red', label='right_hand')
    ax.scatter(*left, marker='o', color='blue', label='left_hand')
    axes.append(ax)

# One legend is enough; every panel uses the same colour coding.
if axes:
    axes[0].legend()
plt.show()

<IPython.core.display.Javascript object>