In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import h5py
import matplotlib.pyplot as plt
import hdc
import numpy as np
from sklearn import svm
from itertools import combinations

%matplotlib notebook

In [2]:
# TODO: absolute, machine-specific path -- make this configurable for other users.
dataFile = '/Users/andy/Research/py_hdc_cont/emg_mat/armPosition/sub1exp0_accHV.mat'

# Seed the global numpy RNG so contextVec and the np.random trial splits drawn
# later in the notebook are reproducible under Restart & Run All.
SEED = 0
np.random.seed(SEED)

# file is saved in hdf5 format. The handle is intentionally left open: later
# cells dereference the object references stored in experimentData.
file = h5py.File(dataFile,'r')
experimentData = file['experimentData']
keys = list(experimentData.keys())

print('Available keys include:')
for k in keys:
    print('  ' + k)

# Look datasets up by name instead of by position in `keys`; the position of a
# given dataset in that list changes whenever the file gains or loses datasets.
numTrials, numPositions, numGestures = experimentData['emgHV'].shape
# Hypervector dimension, taken from the EMG hypervectors that are fed to the AM.
D = file[experimentData['emgHV'][0,0,0]].shape[1]

# Random bipolar vector per arm position; currently only referenced by
# commented-out lines in the evaluation loop.
contextVec = np.random.choice([-1.0, 1.0], size=(numPositions,D))

Available keys include:
  accHV128
  accHV16
  accHV32
  accHV4
  accHV64
  accHV8
  accMeanFeat
  accRaw
  accStdFeat
  emgFeat
  emgHV
  emgRaw
  expGestLabel
  expPosLabel


In [3]:
# Every unordered set of `numTrainPositions` arm positions; each set is used
# once as the training positions of the evaluation loop below.
numTrainPositions = 2
trainCombinations = [combo for combo in combinations(np.arange(numPositions), numTrainPositions)]
numCombinations = len(trainCombinations)

In [4]:
# For every training-position combination and every test position: train an HD
# associative memory and a linear SVM on one randomly chosen trial per gesture,
# then score both on the remaining trials recorded at the test position.
numIters = 1
NGRAM_LEN = 5          # consecutive feature frames stacked into one SVM sample
FEATURE_SCALE = 6400   # scale applied to emgFeat values (carried over from the original analysis)

meanHDAcc = np.zeros((numCombinations,numPositions))
meanSVMAcc = np.zeros((numCombinations,numPositions))

# [combination, test position, cluster]: how often AM.match returned each cluster id
clustHits = np.zeros((numCombinations,numPositions,numPositions))
clustCorrectHits = np.zeros((numCombinations,numPositions,numPositions))
clustIncorrectHits = np.zeros((numCombinations,numPositions,numPositions))


def _load_trial(t, pos, g):
    """Load one recording restricted to samples whose expGestLabel is > 0.

    Returns (ng, feat): ng is the element-wise product of the emgHV and
    accHV128 rows; feat is emgFeat transposed to (samples, channels).
    """
    active = file[experimentData['expGestLabel'][t,pos,g]][0,:] > 0
    emgHV = file[experimentData['emgHV'][t,pos,g]][active,:]
    accHV = file[experimentData['accHV128'][t,pos,g]][active,:]
    feat = file[experimentData['emgFeat'][t,pos,g]][:,active].T
    return emgHV*accHV, feat


def _ngram_features(feat, ngramLen=NGRAM_LEN, scale=FEATURE_SCALE):
    """Stack `ngramLen` consecutive rows of `feat` into one scaled row.

    Row i of the result is [feat[i], ..., feat[i+ngramLen-1]] * scale, so the
    shape is (max(numEx-ngramLen+1, 0), numCh*ngramLen); recordings shorter
    than `ngramLen` give an empty array instead of raising.
    """
    numEx, numCh = feat.shape
    numWin = max(numEx - ngramLen + 1, 0)
    x = np.zeros((numWin, numCh*ngramLen))
    for i in range(ngramLen):
        x[:, i*numCh:(i+1)*numCh] = feat[i:i+numWin, :]*scale
    return x


for apComb in range(numCombinations):
    for apTest in range(numPositions):
        hdAcc = []
        svmAcc = []
        for _ in range(numIters):
            # NOTE(review): training below depends only on apComb and trainTrials,
            # not on apTest, so it could be hoisted out of the apTest loop.
            AM = hdc.Memory(D)
            # one random training trial per gesture (shared by all training
            # positions); every other trial is used for testing
            trainTrials = np.random.randint(numTrials,size=numGestures)

            # train AM and collect SVM training windows
            XtrainParts = []
            ytrainParts = []
            for apTrain in trainCombinations[apComb]:
                for g in range(numGestures):
                    ng, feat = _load_trial(int(trainTrials[g]), apTrain, g)
                    AM.train(ng,vClass=g,vClust=apTrain)
                    x = _ngram_features(feat)
                    XtrainParts.append(x)
                    ytrainParts.append(g*np.ones(len(x)))
            # concatenate once instead of growing an array with a hard-coded width
            Xtrain = np.concatenate(XtrainParts)
            ytrain = np.concatenate(ytrainParts)

            # train SVM (or other model)
            clf = svm.SVC(decision_function_shape='ovo',kernel='linear',C=1)
            clf.fit(Xtrain,ytrain)

            # test AM and SVM on the held-out trials at the test position
            for g in range(numGestures):
                for t in range(numTrials):
                    if t == trainTrials[g]:
                        continue
                    ng, feat = _load_trial(t, apTest, g)
                    label,clust,_ = AM.match(np.asarray(ng),bipolar=True)
                    clust = np.asarray(clust)
                    label = np.asarray(label)
                    # an empty segment has no accuracy; skip it rather than append 0/0 = NaN
                    if len(label) > 0:
                        hdAcc.append(np.sum(label == g)/len(label))

                    correct = label == g
                    for p in range(numPositions):
                        clustHits[apComb,apTest,p] += np.count_nonzero(clust == p)
                        clustCorrectHits[apComb,apTest,p] += np.count_nonzero(clust[correct] == p)
                        clustIncorrectHits[apComb,apTest,p] += np.count_nonzero(clust[~correct] == p)

                    x = _ngram_features(feat)
                    # SVC.predict rejects zero-sample input
                    if len(x) > 0:
                        yhat = clf.predict(x)
                        svmAcc.append(np.sum(yhat == g)/len(yhat))

        meanHDAcc[apComb,apTest] = np.mean(hdAcc)
        meanSVMAcc[apComb,apTest] = np.mean(svmAcc)
        

In [5]:
def plot_arm_position(acc,vmin,vmax):
    """Draw a 2-D accuracy matrix as an annotated heat map.

    acc        : 2-D array of fractions (e.g. meanHDAcc); displayed scaled by 100.
    vmin, vmax : colour-scale limits applied to the scaled values.
    Each cell is labelled with its scaled value; a colour bar is added.
    """
    nRows, nCols = acc.shape
    pct = acc*100

    fig = plt.figure(figsize=(8, 12))
    ax = fig.add_subplot(111)
    # one unit square per matrix cell: x spans [0, nCols], y spans [0, nRows]
    im = ax.imshow(pct, extent=[0, nCols, 0, nRows], interpolation='None',
                   cmap='viridis', vmin=vmin, vmax=vmax)

    # offsets from a cell's lower-left corner to its centre
    halfW = (nCols - 0) / (2.0 * nCols)
    halfH = (nRows - 0) / (2.0 * nRows)
    colLefts = np.linspace(start=0, stop=nCols, num=nCols, endpoint=False)
    # row 0 is drawn at the top, so row bottoms run from high y to low y
    rowBottoms = np.flip(np.linspace(start=0, stop=nRows, num=nRows, endpoint=False))
    threshold = (vmin+vmax)/2

    for r, yb in enumerate(rowBottoms):
        for c, xl in enumerate(colLefts):
            value = pct[r, c]
            # white text below the midpoint of the colour range, black otherwise
            colour = 'white' if value < threshold else 'black'
            ax.text(xl + halfW, yb + halfH, '%.4f' % (value), color=colour, ha='center', va='center')

    fig.colorbar(im)


In [6]:
# Mean HD accuracy: rows are training combinations, columns are test positions.
plot_arm_position(meanHDAcc, vmin=0, vmax=100)

<IPython.core.display.Javascript object>

In [7]:
# Mean SVM accuracy: rows are training combinations, columns are test positions.
plot_arm_position(meanSVMAcc, vmin=0, vmax=100)

<IPython.core.display.Javascript object>

In [8]:
def plot_clust_hits(clust,vmin,vmax,figsize=(5, 50),titles=None):
    """Plot one annotated heat map of cluster-hit percentages per combination.

    clust      : 3-D array of counts indexed [combination, test position, cluster].
                 Each row (test position) is normalised to percentages of its total.
    vmin, vmax : colour-scale limits for the percentages.
    figsize    : figure size in inches (default is the previous fixed size).
    titles     : optional sequence with one subplot title per combination.

    Rows whose total count is zero have no defined percentage and are shown as NaN.
    The input array is not modified.
    """
    # Normalise in float: writing percentages back into a copy of the input
    # would keep its dtype and truncate them for integer count arrays.
    counts = np.asarray(clust, dtype=float)
    numComb, h, w = counts.shape
    with np.errstate(divide='ignore', invalid='ignore'):
        data = counts / counts.sum(axis=2, keepdims=True) * 100

    extent = [0, w, 0, h]
    # cell-centre offsets and cell corners are the same for every subplot
    jump_x = w / (2.0 * w)
    jump_y = h / (2.0 * h)
    x_positions = np.linspace(start=0, stop=w, num=w, endpoint=False)
    # row 0 is drawn at the top, so y positions run from high to low
    y_positions = np.flip(np.linspace(start=0, stop=h, num=h, endpoint=False))
    threshold = (vmin+vmax)/2

    fig = plt.figure(figsize=figsize)
    for comb in range(numComb):
        ax = fig.add_subplot(numComb,1,comb+1)
        ax.imshow(data[comb], extent=extent, interpolation='None', cmap='viridis', vmin=vmin, vmax=vmax)
        if titles is not None:
            ax.set_title(titles[comb])

        for y_index, y in enumerate(y_positions):
            for x_index, x in enumerate(x_positions):
                label = data[comb, y_index, x_index]
                # white text below the midpoint of the colour range, black otherwise
                color = 'white' if label < threshold else 'black'
                ax.text(x + jump_x, y + jump_y, '%.4f' % (label), color=color, ha='center', va='center')




In [9]:
# Share of AM matches landing in each cluster, over all test predictions.
plot_clust_hits(clustHits, vmin=0, vmax=100)

<IPython.core.display.Javascript object>

In [10]:
# Share of AM matches landing in each cluster, over correctly labelled samples only.
plot_clust_hits(clustCorrectHits, vmin=0, vmax=100)

<IPython.core.display.Javascript object>

In [11]:
# Share of AM matches landing in each cluster, over incorrectly labelled samples only.
plot_clust_hits(clustIncorrectHits, vmin=0, vmax=100)

<IPython.core.display.Javascript object>