# Get data into dictionary

In [1]:
from sklearn import metrics

In [2]:
import scipy.io
import os
import numpy as np
path = 'groundTruthmat/'
dataset = []
files = os.listdir(path)
data_dict = {}
# devide four folds
# s1: P03 – P15
# s2: P16 – P28
# s3: P29 – P41
# s4: P42 – P54
for file in files:
    read = scipy.io.loadmat(path + file)
    data_dict[file] = read['labseqid']

In [3]:
def one_hot():
# create frame_label for one hot encoding
    frame_label = {}
    for i in range(48):
        frame_label[i] = 48*[0]
        frame_label[i][i] = 1
    return frame_label

frame_label = one_hot()

In [4]:
def get_input_x_y(data_dict, trainPortion):
    new_dict = {}
    train_y = []
    for filename, frames in data_dict.items():
        frameLen = len(frames)
        
        inputLen = round(frameLen*trainPortion)  
        inputFrames = frames[:inputLen]
        y = frames[inputLen-1]
        for frame in frames[inputLen:]:
#             print(frame, y)
            if frame != y[0]:
                y[0] = frame[0]
                break
        train_y.append(y)
        new_dict[filename] = inputFrames
    return new_dict, train_y

In [30]:
def add_padding(data_dict):
    maxLen = 0
    count = 0
    for frames in data_dict.values():
        
        if len(frames) > maxLen:
            maxLen = len(frames)
    for filename, frames in data_dict.items():
        data_dict[filename] = (maxLen - len(frames)) * [0] + [i[0] for i in frames]
    return data_dict, maxLen
            

In [31]:
def feature_encoding(input_dict):
    for filename, frames in input_dict.items():
        # get corresponding one-hot encode
        new = []
        for each in frames:
            new.append(frame_label[each])
        input_dict[filename] = new
    return input_dict

In [32]:
def label_encoding(train_y):
    encoded_y = []
    for label in train_y:
        new_label = frame_label[label[0]]
        encoded_y.append(new_label)
    return encoded_y

# LSTM

In [1]:
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from keras.layers import Dropout
def lstm_model(frame_len):
    clear_session() 
    model = Sequential()
    model.add(LSTM(100, input_shape=(frame_len,48), return_sequences=False))
    model.add(Dense(48, activation='tanh'))
    model.compile(loss='mae', optimizer='adam', metrics=['accuracy'])
    model.summary()
    return model

# 4-fold cross validation

In [41]:
def train_model(trainX, trainY, model):
    model.fit(np.array(trainX), np.array(trainY), epochs = 10, batch_size = 256)
    

In [35]:
def evaluation(testX, testY, model):
    predictions = model.predict(testX)
    results = []
    count = 0
    for i in range(len(predictions)):
        result = np.array([0] * 48)
        index = predictions[i].argmax(axis=-1)
        result[index] = 1
        pre = metrics.accuracy_score(result, testY[i])
        if pre == 1:
            count += 1
    print(count/len(testY))
    return (count/len(testY))

In [36]:
import copy
def cross_validation(input_dict, encoded_y, model):
    file_count = 0
    s1_x, s2_x, s3_x, s4_x = [], [], [], []
    s1_y, s2_y, s3_y, s4_y = [], [], [], []
    # s1: P03 – P15
    # s2: P16 – P28
    # s3: P29 – P41
    # s4: P42 – P54
    count = 0
    for filename, frames in input_dict.items():
        if int(filename[1:3]) <= 15:
            s1_x.append(input_dict[filename])
            s1_y.append(encoded_y[file_count])
        elif 16 <= int(filename[1:3]) <= 28:
            s2_x.append(input_dict[filename])
            s2_y.append(encoded_y[file_count])
        elif 29 <= int(filename[1:3]) <= 41:
            s3_x.append(input_dict[filename])
            s3_y.append(encoded_y[file_count])
        elif 42 <= int(filename[1:3]) <= 54:
            s4_x.append(input_dict[filename])
            s4_y.append(encoded_y[file_count])
        file_count += 1
        
    splits_x = [s1_x, s2_x, s3_x, s4_x]
    splits_y = [s1_y, s2_y, s3_y, s4_y]
    final_acc = 0
    for i in range(4):
        trainX = None
        trainY = None
        for x in range(4):
            if splits_x[x] != splits_x[i]:
                if trainX == None:
                    trainX = copy.deepcopy(splits_x[x])
                else:
                    trainX += copy.deepcopy(splits_x[x])

        for y in range(4):
            if splits_y[y] != splits_y[i]:
                if trainY == None:
                    trainY = copy.deepcopy(splits_y[y])
                else:
                    trainY += copy.deepcopy(splits_y[y])
        testX = splits_x[i]
        testY = splits_y[i]
        print(np.array(trainX).shape, np.array(trainY).shape, np.array(testX).shape,np.array(testY).shape)
        train_model(trainX, trainY, model)
        final_acc += evaluation(testX, testY, model)
        
    print(final_acc/4) 
    return (final_acc/4) 

In [2]:
from keras.backend import clear_session
input_proportion = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
# input_proportion = [0.1, 0.2]
results = []
for proportion in input_proportion:
    input_dict, train_y = get_input_x_y(data_dict, proportion)
    input_dict, frame_len = add_padding(input_dict)
    input_dict = feature_encoding(input_dict)
    encoded_y = label_encoding(train_y)
    model = lstm_model(frame_len)
    result = cross_validation(input_dict, encoded_y, model)
    results.append(result)
    clear_session() 

NameError: name 'get_input_x_y' is not defined

In [43]:
results

[0.9346375492427392,
 0.9420017031118413,
 0.6593585180819996,
 0.7241083223368788,
 0.7683955406846872,
 0.7602085603136483,
 0.7313101374006132,
 0.7721406072919286,
 0.8289806590800173]

In [None]:
import tf_models, datase                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           
ts, utils, metrics
from utils import imshow_
model, param_str = tf_models.temporal_convs_linear(n_nodes[0], conv, n_classes, n_feat, 
                                                    max_len, causal=causal, return_param_str=True)