In [1]:
import pandas as pd
import numpy as np
import pickle
from scipy import stats
import tensorflow as tf
from sklearn import metrics
from sklearn.model_selection import train_test_split
RANDOM_SEED = 42

In [25]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
columns = ['user','activity','timestamp', 'x-axis', 'y-axis', 'z-axis']
df = pd.read_csv('data/WISDM_ar_v1.1_raw.txt', header = None, names = columns)
df = df.dropna()

In [3]:
df.head()

Unnamed: 0,user,activity,timestamp,x-axis,y-axis,z-axis
0,33,Jogging,49105962326000,-0.694638,12.680544,0.503953
1,33,Jogging,49106062271000,5.012288,11.264028,0.953424
2,33,Jogging,49106112167000,4.903325,10.882658,-0.081722
3,33,Jogging,49106222305000,-0.612916,18.496431,3.023717
4,33,Jogging,49106332290000,-1.18497,12.108489,7.205164


In [4]:
df['timestamp'][0]

49105962326000

In [5]:
df['timestamp'][200]

49126972305000

In [6]:
import datetime

In [7]:
df.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 1098203 entries, 0 to 1098203
Data columns (total 6 columns):
user         1098203 non-null int64
activity     1098203 non-null object
timestamp    1098203 non-null int64
x-axis       1098203 non-null float64
y-axis       1098203 non-null float64
z-axis       1098203 non-null float64
dtypes: float64(3), int64(2), object(1)
memory usage: 58.7+ MB


In [8]:
N_TIME_STEPS = 200
N_FEATURES = 3
step = 20
segments = []
labels = []
for i in range(0, len(df) - N_TIME_STEPS, step):
    xs = df['x-axis'].values[i: i + N_TIME_STEPS]
    ys = df['y-axis'].values[i: i + N_TIME_STEPS]
    zs = df['z-axis'].values[i: i + N_TIME_STEPS]
    label = stats.mode(df['activity'][i: i + N_TIME_STEPS])[0][0]
    segments.append([xs, ys, zs])
    labels.append(label)

In [9]:
np.array(segments).shape

(54901, 3, 200)

In [10]:
reshaped_segments = np.asarray(segments, dtype= np.float32).reshape(-1, N_TIME_STEPS, N_FEATURES)
labels = np.asarray(pd.get_dummies(labels), dtype = np.float32)

In [11]:
reshaped_segments.shape

(54901, 200, 3)

In [12]:
labels[0]

array([0., 1., 0., 0., 0., 0.], dtype=float32)

In [13]:
X_train, X_test, y_train, y_test = train_test_split(
        reshaped_segments, labels, test_size=0.2, random_state=RANDOM_SEED)

In [14]:
len(X_train)

43920

In [15]:
len(X_test)

10981

In [16]:
N_CLASSES = 6
N_HIDDEN_UNITS = 64

In [17]:
def create_LSTM_model(inputs):
    W = {
        'hidden': tf.Variable(tf.random_normal([N_FEATURES, N_HIDDEN_UNITS])),
        'output': tf.Variable(tf.random_normal([N_HIDDEN_UNITS, N_CLASSES]))
    }
    biases = {
        'hidden': tf.Variable(tf.random_normal([N_HIDDEN_UNITS], mean=1.0)),
        'output': tf.Variable(tf.random_normal([N_CLASSES]))
    }
    
    X = tf.transpose(inputs, [1, 0, 2])
    X = tf.reshape(X, [-1, N_FEATURES])
    hidden = tf.nn.relu(tf.matmul(X, W['hidden']) + biases['hidden'])
    hidden = tf.split(hidden, N_TIME_STEPS, 0)

    lstm_layers = [tf.contrib.rnn.BasicLSTMCell(N_HIDDEN_UNITS, forget_bias=1.0) for _ in range(2)]
    lstm_layers = tf.contrib.rnn.MultiRNNCell(lstm_layers)

    outputs, _ = tf.contrib.rnn.static_rnn(lstm_layers, hidden, dtype=tf.float32)

    lstm_last_output = outputs[-1]

    return tf.matmul(lstm_last_output, W['output']) + biases['output']

In [26]:
tf.reset_default_graph()

X = tf.placeholder(tf.float32, [None, N_TIME_STEPS, N_FEATURES], name="input")
Y = tf.placeholder(tf.float32, [None, N_CLASSES])

In [27]:
pred_Y = create_LSTM_model(X)

pred_softmax = tf.nn.softmax(pred_Y, name="y_")

In [28]:
L2_LOSS = 0.0015

l2 = L2_LOSS * \
    sum(tf.nn.l2_loss(tf_var) for tf_var in tf.trainable_variables())

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = pred_Y, labels = Y)) + l2

In [29]:
LEARNING_RATE = 0.0025

optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE).minimize(loss)

correct_pred = tf.equal(tf.argmax(pred_softmax, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, dtype=tf.float32))

In [30]:
N_EPOCHS = 50
BATCH_SIZE = 1024

In [None]:
saver = tf.train.Saver()

history = dict(train_loss=[], 
                     train_acc=[], 
                     test_loss=[], 
                     test_acc=[])

sess=tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

train_count = len(X_train)

for i in range(1, N_EPOCHS + 1):
    for start, end in zip(range(0, train_count, BATCH_SIZE),
                          range(BATCH_SIZE, train_count + 1,BATCH_SIZE)):
        sess.run(optimizer, feed_dict={X: X_train[start:end],
                                       Y: y_train[start:end]})

    _, acc_train, loss_train = sess.run([pred_softmax, accuracy, loss], feed_dict={
                                            X: X_train, Y: y_train})

    _, acc_test, loss_test = sess.run([pred_softmax, accuracy, loss], feed_dict={
                                            X: X_test, Y: y_test})

    history['train_loss'].append(loss_train)
    history['train_acc'].append(acc_train)
    history['test_loss'].append(loss_test)
    history['test_acc'].append(acc_test)

    if i != 1 and i % 10 != 0:
        continue

    print(f'epoch: {i} test accuracy: {acc_test} loss: {loss_test}')
    
predictions, acc_final, loss_final = sess.run([pred_softmax, accuracy, loss], feed_dict={X: X_test, Y: y_test})

print()
print(f'final results: accuracy: {acc_final} loss: {loss_final}')

In [None]:
pickle.dump(predictions, open("predictions.p", "wb"))
pickle.dump(history, open("history.p", "wb"))
tf.train.write_graph(sess.graph_def, '.', './checkpoint/har.pbtxt')  
saver.save(sess, save_path = "./checkpoint/har.ckpt")
sess.close()

In [None]:
history = pickle.load(open("history.p", "rb"))
predictions = pickle.load(open("predictions.p", "rb"))

In [None]:
LABELS = ['Downstairs', 'Jogging', 'Sitting', 'Standing', 'Upstairs', 'Walking']

In [None]:
from tensorflow.python.tools import freeze_graph

MODEL_NAME = 'har'

input_graph_path = 'checkpoint/' + MODEL_NAME+'.pbtxt'
checkpoint_path = './checkpoint/' +MODEL_NAME+'.ckpt'
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_frozen_graph_name = 'frozen_'+MODEL_NAME+'.pb'

freeze_graph.freeze_graph(input_graph_path, input_saver="",
                          input_binary=False, input_checkpoint=checkpoint_path, 
                          output_node_names="y_", restore_op_name="save/restore_all",
                          filename_tensor_name="save/Const:0", 
                          output_graph=output_frozen_graph_name, clear_devices=True, initializer_nodes="")

In [18]:
frozen_graph = "./frozen_har.pb"
with tf.gfile.GFile(frozen_graph, "rb") as f:
    restored_graph_def = tf.GraphDef()
    restored_graph_def.ParseFromString(f.read())

In [19]:
with tf.Graph().as_default() as graph:
    tf.import_graph_def(
    restored_graph_def,
    input_map=None,
    return_elements=None,
    name=""
    )

In [56]:
y_ = graph.get_tensor_by_name("y_:0")
input = graph.get_tensor_by_name("input:0")
sess = tf.Session(graph=graph)
feed_input = {input:out}
result = sess.run(y_, feed_dict=feed_input)
print(result)

[[0.74011135 0.00457557 0.00290634 0.0041967  0.2425842  0.00562574]]


In [62]:
result = list(result[0])

In [63]:
print(result)

[0.74011135, 0.004575569, 0.0029063362, 0.0041967044, 0.2425842, 0.0056257388]


In [65]:
result.index(max(result))

0

In [41]:
inputs_batch = [[ 0.043095629662275314, 2.5282769203186035, 9.437943458557129],[ 0.05267243832349777, 2.6001031398773193, 9.490615844726562],[ 0.05267243832349777, 2.561795711517334, 9.409213066101074],[ 0.08140286058187485, 2.6240451335906982, 9.390059471130371],[ 0.05267243832349777, 2.6431987285614014, 9.332598686218262],[ 0.038307227194309235, 2.557007312774658, 9.418789863586426],[ 0.05746084079146385, 2.6001031398773193, 9.437943458557129],[ 0.028730420395731926, 2.5330653190612793, 9.452308654785156],[ -0.07182604819536209, 2.5905263423919678, 9.385271072387695],[ 0.08140286058187485, 2.5761611461639404, 9.366117477416992],[ 0.047884032130241394, 2.5761611461639404, 9.476250648498535],[ 0.07182604819536209, 2.5761611461639404, 9.433155059814453],[ 0.023942016065120697, 2.5905263423919678, 9.351752281188965],[ 0.043095629662275314, 2.5713725090026855, 9.361329078674316],[ 0.014365210197865963, 2.609679937362671, 9.442731857299805],[ 0.05267243832349777, 2.585737943649292, 9.399636268615723],[ 0.11013327538967133, 2.5713725090026855, 9.428366661071777],[ 0.033518824726343155, 2.604891538619995, 9.433155059814453],[ 0.08619125932455063, 2.5953147411346436, 9.404424667358398],[ 0, 2.604891538619995, 9.385271072387695],[ 0.08140286058187485, 2.561795711517334, 9.308655738830566],[ 0.07182604819536209, 2.585737943649292, 9.375694274902344],[ 0.08619125932455063, 2.5474305152893066, 9.457097053527832],[ 0.05267243832349777, 2.5953147411346436, 9.375694274902344],[ 0.033518824726343155, 2.5761611461639404, 9.41400146484375],[ 0.07182604819536209, 2.585737943649292, 9.318233489990234],[ 0.028730420395731926, 2.585737943649292, 9.452308654785156],[ 0.090979665517807, 2.628833532333374, 9.409213066101074],[ 0.038307227194309235, 2.5522189140319824, 9.433155059814453],[ 0.05746084079146385, 2.557007312774658, 9.370905876159668],[ 0.1149216815829277, 2.5953147411346436, 9.437943458557129],[ 0.033518824726343155, 2.6001031398773193, 9.452308654785156],[ 0.08619125932455063, 2.5713725090026855, 9.346963882446289],[ 0.05267243832349777, 2.5953147411346436, 9.390059471130371],[ 0.05267243832349777, 2.6001031398773193, 9.433155059814453],[ 0.019153613597154617, 2.6240451335906982, 9.404424667358398],[ 0.05267243832349777, 2.6001031398773193, 9.514557838439941],[ 0.06224924325942993, 2.5905263423919678, 9.375694274902344],[ 0.05746084079146385, 2.557007312774658, 9.399636268615723],[ 0.038307227194309235, 2.63362193107605, 9.38048267364502],[ 0.023942016065120697, 2.5953147411346436, 9.495404243469238],[ 0.033518824726343155, 2.604891538619995, 9.394847869873047],[ 0.07661445438861847, 2.6192567348480225, 9.289502143859863],[ 0.06703764945268631, 2.6384103298187256, 9.418789863586426],[ 0.019153613597154617, 2.561795711517334, 9.41400146484375],[ 0.019153613597154617, 2.585737943649292, 9.41400146484375],[ 0.043095629662275314, 2.6144683361053467, 9.437943458557129],[ 0.043095629662275314, 2.5953147411346436, 9.370905876159668],[ 0.028730420395731926, 2.585737943649292, 9.318233489990234],[ 0.05267243832349777, 2.6001031398773193, 9.390059471130371],[ 0.019153613597154617, 2.5713725090026855, 9.332598686218262],[ 0.09576806426048279, 2.6384103298187256, 9.337387084960938],[ 0.07182604819536209, 2.609679937362671, 9.38048267364502],[ 0.014365210197865963, 2.432508945465088, 9.485827445983887],[ -0.004788403399288654, 2.6192567348480225, 9.442731857299805],[ 0.05746084079146385, 2.5761611461639404, 9.370905876159668],[ 0.033518824726343155, 2.5522189140319824, 9.38048267364502],[ 0.05746084079146385, 2.5953147411346436, 9.385271072387695],[ 0.08140286058187485, 2.6240451335906982, 9.275136947631836],[ 0.06703764945268631, 2.5713725090026855, 9.44752025604248],[ 0.05746084079146385, 2.609679937362671, 9.418789863586426],[ 0.10534487664699554, 2.5953147411346436, 9.457097053527832],[ 0.09576806426048279, 2.5713725090026855, 9.41400146484375],[ 0.08140286058187485, 2.561795711517334, 9.294290542602539],[ 0.1149216815829277, 2.604891538619995, 9.734824180603027],[ 0.014365210197865963, 2.5091233253479004, 9.481039047241211],[ 0.06703764945268631, 2.537853717803955, 9.394847869873047],[ 0.05267243832349777, 2.647987127304077, 9.337387084960938],[ 0.08140286058187485, 2.6192567348480225, 9.433155059814453],[ 0.090979665517807, 2.604891538619995, 9.390059471130371],[ 0.05746084079146385, 2.5665841102600098, 9.399636268615723],[ 0.1149216815829277, 2.5043349266052246, 9.308655738830566],[ 0.023942016065120697, 2.5474305152893066, 9.385271072387695],[ 0.06703764945268631, 2.5522189140319824, 9.361329078674316],[ 0.023942016065120697, 2.628833532333374, 9.289502143859863],[ 0.13407529890537262, 2.7437551021575928, 9.299078941345215],[ 0.038307227194309235, 2.6144683361053467, 9.428366661071777],[ 0.09576806426048279, 2.5713725090026855, 9.375694274902344],[ 0.07182604819536209, 2.518700122833252, 9.255983352661133],[ 0.08140286058187485, 2.5091233253479004, 9.677363395690918],[ 0.043095629662275314, 2.499546527862549, 9.409213066101074],[ 0.11013327538967133, 2.561795711517334, 9.366117477416992],[ 0.033518824726343155, 2.5474305152893066, 9.38048267364502],[ 0.033518824726343155, 2.604891538619995, 9.327810287475586],[ 0.06224924325942993, 2.604891538619995, 9.41400146484375],[ 0.06224924325942993, 2.5474305152893066, 9.32302188873291],[ 0.10534487664699554, 2.557007312774658, 9.38048267364502],[ 0.12928688526153564, 2.5761611461639404, 9.375694274902344],[ 0.05746084079146385, 2.5665841102600098, 9.409213066101074],[ 0.028730420395731926, 2.557007312774658, 9.375694274902344],[ 0.06224924325942993, 2.5761611461639404, 9.390059471130371],[ 0.07661445438861847, 2.604891538619995, 9.308655738830566],[ 0.08619125932455063, 2.5665841102600098, 9.299078941345215],[ 0.038307227194309235, 2.542642116546631, 9.433155059814453],[ 0.08140286058187485, 2.6144683361053467, 9.452308654785156],[ 0.033518824726343155, 2.580949544906616, 9.409213066101074],[ 0.08140286058187485, 2.6001031398773193, 9.332598686218262],[ 0.038307227194309235, 2.6431987285614014, 9.370905876159668],[ 0.08140286058187485, 2.63362193107605, 9.366117477416992],[ 0.038307227194309235, 2.5330653190612793, 9.346963882446289],[ 0.06703764945268631, 2.5665841102600098, 9.32302188873291],[ 0.038307227194309235, 2.5665841102600098, 9.437943458557129],[ 0.10534487664699554, 2.6144683361053467, 9.346963882446289],[ 0.05746084079146385, 2.6144683361053467, 9.423578262329102],[ 0.023942016065120697, 2.557007312774658, 9.409213066101074],[ 0.090979665517807, 2.5665841102600098, 9.404424667358398],[ 0.014365210197865963, 2.585737943649292, 9.533711433410645],[ 0.05267243832349777, 2.6431987285614014, 9.394847869873047],[ 0.028730420395731926, 2.585737943649292, 9.409213066101074],[ 0.014365210197865963, 2.5761611461639404, 9.437943458557129],[ 0.1149216815829277, 2.5761611461639404, 9.47146224975586],[ 0.033518824726343155, 2.652775526046753, 8.978256225585938],[ 0.043095629662275314, 2.6384103298187256, 9.461885452270508],[ 0.06703764945268631, 2.5953147411346436, 9.332598686218262],[ 0.24420857429504395, 2.6575639247894287, 9.773131370544434],[ -0.038307227194309235, 2.4803929328918457, 9.677363395690918],[ -0.023942016065120697, 2.671929121017456, 9.409213066101074],[ 0.05746084079146385, 2.6001031398773193, 9.289502143859863],[ 0.12449848651885986, 2.5953147411346436, 9.346963882446289],[ 0.16759411990642548, 2.604891538619995, 9.466673851013184],[ 0.06224924325942993, 2.5330653190612793, 9.361329078674316],[ 0.090979665517807, 2.6815059185028076, 9.57201862335205],[ 0.07182604819536209, 2.6384103298187256, 9.44752025604248],[ 0.181959331035614, 2.580949544906616, 9.476250648498535],[ 0.038307227194309235, 2.580949544906616, 9.370905876159668],[ 0.05746084079146385, 2.557007312774658, 9.327810287475586],[ 0.06703764945268631, 2.557007312774658, 9.428366661071777],[ 0.07182604819536209, 2.6623523235321045, 9.648633003234863],[ 0.05746084079146385, 2.580949544906616, 9.399636268615723],[ 0.05267243832349777, 2.537853717803955, 9.44752025604248],[ 0.10055647045373917, 2.5330653190612793, 9.370905876159668],[ 0.009576806798577309, 2.403778553009033, 9.519346237182617],[ -0.014365210197865963, 2.6431987285614014, 9.232041358947754],[ 0.033518824726343155, 2.6001031398773193, 9.32302188873291],[ 0.06703764945268631, 2.609679937362671, 9.35654067993164],[ -0.06224924325942993, 2.604891538619995, 9.265560150146484],[ 0.033518824726343155, 2.557007312774658, 9.586383819580078],[ 0.10534487664699554, 2.557007312774658, 9.41400146484375],[ 0.08140286058187485, 2.5905263423919678, 9.47146224975586],[ -0.10534487664699554, 2.6192567348480225, 9.275136947631836],[ 0.05267243832349777, 2.609679937362671, 9.485827445983887],[ 0.047884032130241394, 2.7964274883270264, 9.409213066101074],[ 0.12449848651885986, 2.671929121017456, 9.38048267364502],[ 0.10055647045373917, 2.604891538619995, 9.481039047241211],[ 0.023942016065120697, 2.6144683361053467, 9.361329078674316],[ 0.090979665517807, 2.5522189140319824, 9.222464561462402],[ 0.07661445438861847, 2.5474305152893066, 9.428366661071777],[ 0.10055647045373917, 2.6384103298187256, 9.394847869873047],[ 0.05746084079146385, 2.585737943649292, 9.385271072387695],[ 0.14365209639072418, 2.6001031398773193, 9.32302188873291],[ 0.19632454216480255, 2.719813108444214, 9.47146224975586],[ 0.033518824726343155, 2.561795711517334, 9.423578262329102],[ 0.1388636976480484, 2.580949544906616, 9.351752281188965],[ -0.090979665517807, 2.5091233253479004, 8.939949035644531],[ 0.07661445438861847, 2.609679937362671, 9.543288230895996],[ 0.1149216815829277, 2.5953147411346436, 9.452308654785156],[ 0.05267243832349777, 2.6431987285614014, 9.437943458557129],[ 0.028730420395731926, 2.609679937362671, 9.318233489990234],[ 0.10055647045373917, 2.585737943649292, 9.35654067993164],[ 0.033518824726343155, 2.6240451335906982, 9.423578262329102],[ 0.047884032130241394, 2.5330653190612793, 9.332598686218262],[ 0.05267243832349777, 2.557007312774658, 9.44752025604248],[ 0, 2.5665841102600098, 9.255983352661133],[ 0.009576806798577309, 2.561795711517334, 9.265560150146484],[ 0.11971008777618408, 2.5713725090026855, 9.466673851013184],[ 0.08140286058187485, 2.580949544906616, 9.308655738830566],[ 0.033518824726343155, 2.5713725090026855, 9.394847869873047],[ 0.08619125932455063, 2.5330653190612793, 9.38048267364502],[ 0.09576806426048279, 2.5953147411346436, 9.27034854888916],[ -0.014365210197865963, 2.6001031398773193, 9.289502143859863],[ 0.06224924325942993, 2.604891538619995, 9.346963882446289],[ 0.019153613597154617, 2.5522189140319824, 9.260771751403809],[ 0.028730420395731926, 2.6575639247894287, 9.615114212036133],[ 0.10055647045373917, 2.5953147411346436, 9.490615844726562],[ 0.07182604819536209, 2.557007312774658, 9.428366661071777],[ -0.05267243832349777, 2.5282769203186035, 8.853757858276367],[ 0.10055647045373917, 2.6575639247894287, 9.485827445983887],[ 0.05746084079146385, 2.604891538619995, 9.313444137573242],[ 0, 2.5474305152893066, 9.428366661071777],[ 0.047884032130241394, 2.5522189140319824, 9.524134635925293],[ 0.043095629662275314, 2.5953147411346436, 9.279925346374512],[ 0.07661445438861847, 2.557007312774658, 9.38048267364502],[ 0.06703764945268631, 2.561795711517334, 9.423578262329102],[ 0.09576806426048279, 2.6575639247894287, 9.433155059814453],[ -0.1149216815829277, 2.494758129119873, 9.423578262329102],[ 0, 2.542642116546631, 9.44752025604248],[ 0.08619125932455063, 2.6575639247894287, 9.21288776397705],[ 0, 2.5330653190612793, 9.21288776397705],[ 0.033518824726343155, 2.6001031398773193, 9.452308654785156],[ 0.11013327538967133, 2.561795711517334, 9.385271072387695],[ 0.009576806798577309, 2.557007312774658, 9.466673851013184],[ 0.11013327538967133, 2.557007312774658, 9.552865028381348],[ 0.10055647045373917, 2.585737943649292, 9.418789863586426],[ 0.05267243832349777, 2.585737943649292, 9.495404243469238],[ 0.07182604819536209, 2.5761611461639404, 9.346963882446289],[ 0.1149216815829277, 2.6144683361053467, 9.409213066101074],[ 0.043095629662275314, 2.537853717803955, 9.399636268615723],[ 0.033518824726343155, 2.6144683361053467, 9.428366661071777],[ 0.06703764945268631, 2.5953147411346436, 9.390059471130371],[ 0.06224924325942993, 2.5953147411346436, 9.366117477416992]]

In [42]:
import numpy as np

(200, 3)

In [51]:
out = np.reshape(inputs_batch,  (1,200,3))

In [52]:
out.shape

(1, 200, 3)