In [132]:
import pandas as pd
import numpy as np
from sklearn.ensemble import ExtraTreesClassifier
import warnings
warnings.filterwarnings('ignore')
from sklearn import preprocessing
from sklearn.metrics import accuracy_score
from sklearn.neural_network import MLPClassifier

In [134]:
def chunk_filter_data(X,Y,chunk_size,train=False):
    inter_val = chunk_size
    if train == True:
        inter_val = 1
    X_out = []
    Y_out = []
    for i in range(0,len(X),inter_val):
        x = X[i:i+chunk_size]
        y = Y[i:i+chunk_size]
        if (np.all(y == y[0])) == False:
            continue
        if x.shape[0] != chunk_size:
            continue
        X_out.append(x.flatten())
        Y_out.append(y[0])
    return np.array(X_out), np.array(Y_out)

In [135]:
df = pd.read_pickle("./calibdata.pkl")

In [136]:
continous = np.mean(np.array(df["Gaze Target XY"].values.tolist()),axis=1)
fixation = [0]
for i in range(1,len(continous)):
    fixation.append(continous[i] - continous[i-1])
for i in range(0,len(fixation)):
    if fixation[i] == 0:
        fixation[i] = True
    else:
        fixation[i] = False
df['fixation'] = fixation

In [137]:
feat = "Embeddings Hist"
control = "Subject"
df_group = df.groupby([control])["fixation",feat]
g_id = list(df_group.groups.keys())

In [138]:
chunk_size = 9

Y_OUT, Y_GT = None, None
acc = []

for i in range(len(g_id)):
    
    print("Participant",g_id[i])
    
    X_train, Y_train, X_test, Y_test = None, None, None, None
    
    for j in range(len(g_id)):
        df_extract = df_group.get_group(g_id[j])
        x = np.array(df_extract[feat].tolist())
        y = np.array(df_extract["fixation"].tolist())
        x = x.reshape(x.shape[0],x.shape[1]*x.shape[2])
        y = y.reshape(-1,1)
        if g_id[i] == g_id[j]:
            X_test = x
            Y_test = y
        else:
            if X_train is None:
                X_train = x
                Y_train = y
            else:
                X_train = np.vstack((X_train,x))
                Y_train = np.vstack((Y_train,y))
        
    
    X_train = np.nan_to_num(X_train)
    X_test = np.nan_to_num(X_test)
    
    X_train, Y_train = chunk_filter_data(X_train,Y_train,chunk_size,train=True)
    X_test, Y_test = chunk_filter_data(X_test,Y_test,chunk_size)

    print("\t Train shape:",X_train.shape,Y_train.shape)
    print("\t Test shape:",X_test.shape,Y_test.shape)
    
    min_max_scaler_x = preprocessing.MinMaxScaler()
    X_train = min_max_scaler_x.fit_transform(X_train)
    X_test = min_max_scaler_x.transform(X_test)
    
#     clf = ExtraTreesClassifier(n_estimators=1000)
    clf = MLPClassifier()
    
    clf.fit(X_train,Y_train)
    
    Y_pred = clf.predict(X_test).reshape(-1,1)
    
    if Y_OUT is None:
        Y_OUT = Y_pred
        Y_GT = Y_test
    else:
        Y_OUT = np.vstack((Y_OUT,Y_pred))
        Y_GT = np.vstack((Y_GT,Y_test))
    
    print("\t Accuracy",accuracy_score(Y_test,Y_pred))
    acc.append(accuracy_score(Y_test,Y_pred))


Participant 0
	 Train shape: (6840, 234) (6840, 1)
	 Test shape: (85, 234) (85, 1)
	 Accuracy 0.7764705882352941
Participant 1
	 Train shape: (6840, 234) (6840, 1)
	 Test shape: (85, 234) (85, 1)
	 Accuracy 0.7647058823529411
Participant 2
	 Train shape: (6840, 234) (6840, 1)
	 Test shape: (85, 234) (85, 1)
	 Accuracy 0.8
Participant 3
	 Train shape: (6840, 234) (6840, 1)
	 Test shape: (85, 234) (85, 1)
	 Accuracy 0.6470588235294118
Participant 4
	 Train shape: (6840, 234) (6840, 1)
	 Test shape: (85, 234) (85, 1)
	 Accuracy 0.7176470588235294
Participant 5
	 Train shape: (6840, 234) (6840, 1)
	 Test shape: (85, 234) (85, 1)
	 Accuracy 0.6941176470588235
Participant 6
	 Train shape: (6840, 234) (6840, 1)
	 Test shape: (85, 234) (85, 1)
	 Accuracy 0.6352941176470588
Participant 7
	 Train shape: (6840, 234) (6840, 1)
	 Test shape: (85, 234) (85, 1)
	 Accuracy 0.7529411764705882
Participant 8
	 Train shape: (6840, 234) (6840, 1)
	 Test shape: (85, 234) (85, 1)
	 Accuracy 0.729411764705882

In [139]:
print("Accuracy",np.mean(acc),"SD",np.std(acc))

Accuracy 0.7247058823529412 SD 0.05079301916452448
