# Get data into dictionary

In [1]:
from sklearn import metrics

In [187]:
import scipy.io
import os
import numpy as np
path = 'groundTruthmat/'
dataset = []
files = os.listdir(path)
data_dict = {}
# devide four folds
# s1: P03 – P15
# s2: P16 – P28
# s3: P29 – P41
# s4: P42 – P54
for file in files:
    read = scipy.io.loadmat(path + file)
    data_dict[file] = read['labseqid']

# Helper functions

In [188]:
def one_hot():
# create frame_label for one hot encoding
    frame_label = {}
    for i in range(48):
        frame_label[i] = 48*[0]
        frame_label[i][i] = 1
    frame_label[-1] = 48*[0]
    return frame_label

frame_label = one_hot()

In [189]:
def get_input_frames(data_dict, trainProportion):
    new_dict = {}
    train_y = []
    for filename, frames in data_dict.items():
        frameLen = len(frames)
        
        inputLen = round(frameLen*trainProportion)  
        inputFrames = frames[:inputLen]
        new_dict[filename] = inputFrames
    return new_dict

In [208]:
def get_input_y(data_dict, input_dict):
    train_y = []
    for filename, frames in data_dict.items():
        frame_len = len(input_dict[filename])
#         print(input_dict[filename])
        y = [input_dict[filename][frame_len-1][0]]
        for frame in frames[frame_len:]:
#             print(frame[0], y[-1])s
            if frame[0] != y[-1]:
                y.append(frame[0])
        y.pop(0)

        train_y.append(y)
    return train_y

In [191]:
# convert seccessive frmaes of same action to single action
""" data_dict structure eample
    {'P25_stereo01_P25_sandwich.mat': array([[ 0],
        [ 0],
        [ 0],
        ...,
        [39],
        [39],
        [39]], dtype=int32),"""

def frames_to_action(input_dict):
    for filename, frames in input_dict.items():
        action_list = []
        for frame in frames:
            if len(action_list) == 0:
                action_list.append(frame[0])
            else:
                if frame[0] != action_list[-1]:
                    action_list.append(frame[0])
        input_dict[filename] = action_list   
    return input_dict

In [192]:
# add padding, each video has same length of action input
def add_padding(data_dict):
    maxLen = 0
    count = 0
    for frames in data_dict.values():
        if len(frames) > maxLen:
            maxLen = len(frames)
    for filename, frames in data_dict.items():
#         print(frames)
        data_dict[filename] = (maxLen - len(frames)) * [-1]  + frames
    return data_dict, maxLen
            

In [206]:
def add_padding_to_y(trainY):
    maxLen = 0
    count = 0
    new_trainY = []
    for frames in trainY:
        if len(frames) > maxLen:
            maxLen = len(frames)
    for frames in trainY:
        temp = frames + (maxLen - len(frames)) * [-1]
        new_trainY.append(temp)
    return new_trainY, maxLen

In [194]:
def feature_encoding(input_dict):
    for filename, frames in input_dict.items():
        # get corresponding one-hot encode
        new = []
        for each in frames:
            new.append(frame_label[each])
        input_dict[filename] = new
    return input_dict

In [213]:
def label_encoding(trainY):
    new_trainY = []
    for frames in trainY:
        new = []
        for each in frames:
            new.append(frame_label[each])
        new_trainY.append(new)
    return new_trainY

In [196]:
train_y

[]

# LSTM

In [351]:
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM, RepeatVector
from keras.layers import Dropout, Masking, TimeDistributed, Activation
def lstm_model(frame_len, max_timesteps):
#     clear_session() 
    
#     model = Sequential()


#     model.add(LSTM(100, input_shape=(frame_len, 48), return_sequences=True))
#     model.add(Dense(48, activation='tanh'))
#     model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'], sample_weight_mode="temporal")
#     model.summary()
    model = Sequential()
    model.add(Masking(mask_value=48*[0], input_shape=(frame_len,48)))
    model.add(LSTM(100))
    model.add(RepeatVector(23))
    model.add(LSTM(100, activation='sigmoid', return_sequences=True))
    model.add(TimeDistributed(Dense(48)))
#     model.add(Dense(48, activation='tanh'))
    model.compile(loss='mae', optimizer='adam', metrics=['accuracy'])
    
    model.summary()
    return model

# encoder_input_layer = Input(shape=(sequence_len, frame_len,))
# decoder_input_layer = Input(shape=(sequence_len, frame_len,))

# encoder layer
# model.add(LSTM(100, activation='relu', input_shape=(3, 1)))

# # repeat vector
# model.add(RepeatVector(3))

# # decoder layer
# model.add(LSTM(100, activation='relu', return_sequences=True))

# model.add(TimeDistributed(Dense(1)))
# model.compile(optimizer='adam', loss='mse')

# print(model.summary())

# 4-fold cross validation

In [227]:
import tensorflow as tf
from keras import backend as K
def train_model(trainX, trainY, model):

#     trainX = K.cast_to_floatx(trainX)
#     print(trainY)
#     trainY = K.cast_to_floatx(trainY)
    model.fit(np.array(trainX), np.array(trainY), epochs = 10, batch_size = 256)
    

In [339]:
def evaluation(testX, testY, model):
    predictions = model.predict(testX)
    print(predictions)
    results = []
    count = 0
    for i in range(len(predictions)):
        each_video = []
        for j in range(len(predictions[i]))
            result = np.array([0] * 48)
            index = predictions[j].argmax(axis=-1)
            result[index] = 1
#             if pre == 1:
#                 count += 1
            each_video.append(result)
        results.append(each_video)
    print(results)
    return (count/len(testY))

In [324]:
import copy
def cross_validation(input_dict, encoded_y, model):
    file_count = 0
    s1_x, s2_x, s3_x, s4_x = [], [], [], []
    s1_y, s2_y, s3_y, s4_y = [], [], [], []
    # s1: P03 – P15
    # s2: P16 – P28
    # s3: P29 – P41
    # s4: P42 – P54
    count = 0
    for filename, frames in input_dict.items():
        if int(filename[1:3]) <= 15:
            s1_x.append(input_dict[filename])
            s1_y.append(encoded_y[file_count])
        elif 16 <= int(filename[1:3]) <= 28:
            s2_x.append(input_dict[filename])
            s2_y.append(encoded_y[file_count])
        elif 29 <= int(filename[1:3]) <= 41:
            s3_x.append(input_dict[filename])
            s3_y.append(encoded_y[file_count])
        elif 42 <= int(filename[1:3]) <= 54:
            s4_x.append(input_dict[filename])
            s4_y.append(encoded_y[file_count])
        file_count += 1
        
    splits_x = [s1_x, s2_x, s3_x, s4_x]
    splits_y = [s1_y, s2_y, s3_y, s4_y]
    final_acc = 0
    for i in range(4):
        trainX = None
        trainY = None
        for x in range(4):
            if splits_x[x] != splits_x[i]:
                if trainX == None:
                    trainX = copy.deepcopy(splits_x[x])
                else:
                    trainX += copy.deepcopy(splits_x[x])

        for y in range(4):
            if splits_y[y] != splits_y[i]:
                if trainY == None:
                    trainY = copy.deepcopy(splits_y[y])
                else:
                    trainY += copy.deepcopy(splits_y[y])
        testX = splits_x[i]
        testY = splits_y[i]
        
        print(np.array(trainX).shape, np.array(trainY).shape, np.array(testX).shape,np.array(testY).shape)
        train_model(trainX, trainY, model)
        final_acc += evaluation(testX, testY, model)
        
    print(final_acc/4) 
    return (final_acc/4) 

In [352]:
from keras.backend import clear_session
input_proportion = [0.1]
# input_proportion = [0.1, 0.2, 0.3, 0.4, 0.5]
results = []
for proportion in input_proportion:
    input_dict = get_input_frames(data_dict, proportion)
    train_y = get_input_y(data_dict, input_dict)
    input_dict = frames_to_action(input_dict)
    input_dict, frame_len = add_padding(input_dict)
    input_dict = feature_encoding(input_dict)
    
    train_y, max_timesteps = add_padding_to_y(train_y)
    encoded_y = label_encoding(train_y)
    encoded_y = np.array(encoded_y)
#     encoded_y = encoded_y.reshape(23, 48)
    model = lstm_model(frame_len, max_timesteps)
    
    result = cross_validation(input_dict, encoded_y, model)
    results.append(result)
    clear_session() 

Model: "sequential_22"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
masking_1 (Masking)          (None, 6, 48)             0         
_________________________________________________________________
lstm_24 (LSTM)               (None, 100)               59600     
_________________________________________________________________
repeat_vector_8 (RepeatVecto (None, 23, 100)           0         
_________________________________________________________________
lstm_25 (LSTM)               (None, 23, 100)           80400     
_________________________________________________________________
time_distributed_9 (TimeDist (None, 23, 48)            4848      
Total params: 144,848
Trainable params: 144,848
Non-trainable params: 0
_________________________________________________________________
(1460, 6, 48) (1460, 23, 48) (252, 6, 48) (252, 23, 48)
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Ep

ValueError: Found input variables with inconsistent numbers of samples: [48, 23]

In [None]:
for frames in input_dict.values():
    print(frames)

In [None]:
import tf_models, datase                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           
ts, utils, metrics
from utils import imshow_
model, param_str = tf_models.temporal_convs_linear(n_nodes[0], conv, n_classes, n_feat, 
                                                    max_len, causal=causal, return_param_str=True)

In [56]:
input_dict["P25_cam01_P25_coffee.mat"]

[[1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 [1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 [1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 [1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 [1,
  0,
  0,
  0,
  0,
  0,
  0,
 

In [61]:
import keras
from keras.utils import np_utils
a = [[7,2], [4,5,6]]

y = np_utils.to_categorical(a)
y



ValueError: setting an array element with a sequence.

In [134]:
import keras
a = [[0,2], [4,5,6]]

seq=keras.preprocessing.sequence.pad_sequences(a, maxlen=None, padding='pre', value=0.0)
seq = keras.utils.to_categorical(seq)
seq



array([[[1., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0.]],

       [[0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1.]]], dtype=float32)

In [69]:
for i in range(1):
    print(i)

0


In [155]:
trainy

NameError: name 'trainY' is not defined