<a href="https://colab.research.google.com/github/DiliSimon/generative-snowboard/blob/master/inference_pose.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [9]:
import os
import pickle
from tensorflow import keras
import matplotlib.pyplot as plt
import numpy as np

In [2]:
model_weight_path = "kp_rnn_normalized.weights.h5"

In [5]:
with open("keypoint_sequences_normalized.pkl", "rb") as infile:
  keypoint_sequences = pickle.load(infile)
with open("target_predictions_normalized.pkl", "rb") as infile:
  target_predictions = pickle.load(infile)
with open("keypoint_sequences_val_normalized.pkl", "rb") as infile:
  keypoint_sequences_val = pickle.load(infile)
with open("target_predictions_val_normalized.pkl", "rb") as infile:
  target_predictions_val = pickle.load(infile)

In [3]:
def create_model():
  model = keras.Sequential()
  model.add(keras.Input(shape=(None, 50), dtype="float32")) # unknown number of time steps to look into the past, 50 features (25 keypoint ordered pairs)
  model.add(keras.layers.LSTM(64, return_sequences=True))
  model.add(keras.layers.Dropout(0.3))
  model.add(keras.layers.BatchNormalization())
  model.add(keras.layers.LSTM(64))
  model.add(keras.layers.Dropout(0.3))
  model.add(keras.layers.Dense(64, activation="sigmoid"))
  model.add(keras.layers.Dense(50, activation="sigmoid")) # output 50 features (25 keypoint ordered pairs)

  model.compile(
    optimizer=keras.optimizers.SGD(learning_rate=3e-3, momentum=0.2, nesterov=False),
    loss=keras.losses.MeanSquaredError(),
    metrics=[],
  )

  model.load_weights(model_weight_path)

  return model

In [29]:
def invert_scale(keypoints):
  for i in range(len(keypoints)):
    if i % 2 == 0:
      keypoints[i] = keypoints[i] * 1280
    else:
      keypoints[i] = keypoints[i] * 720
  return keypoints

In [11]:
model = create_model()

  saveable.load_own_variables(weights_store.get(inner_path))


In [42]:
num_predict_frames = 24
current_window = np.copy(keypoint_sequences_val[:1])

for i in range(num_predict_frames):
  print(current_window.shape)
  next_frame = model.predict(current_window)
  print(next_frame.shape)
  current_window = np.append(current_window, next_frame[np.newaxis, :, :], axis = 1)

(1, 24, 50)
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step
(1, 50)
(1, 25, 50)
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 408ms/step
(1, 50)
(1, 26, 50)
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step
(1, 50)
(1, 27, 50)
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step
(1, 50)
(1, 28, 50)
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step
(1, 50)
(1, 29, 50)
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step
(1, 50)
(1, 30, 50)
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step
(1, 50)
(1, 31, 50)
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step
(1, 50)
(1, 32, 50)
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step
(1, 50)
(1, 33, 50)
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step
(1, 50)
(1, 34, 50)
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/

In [45]:
current_window = np.squeeze(current_window)
for i in range(current_window.shape[0]):
  current_window[i] = invert_scale(current_window[i])

[[919.359      456.945      907.529      ... 658.802      929.066
  654.882     ]
 [940.83       449.127      917.452      ... 656.786      946.778
  651.        ]
 [942.864      447.222      929.056      ... 649.03       962.393
  643.098     ]
 ...
 [476.50543213 201.44275904 701.02020264 ... 411.75577641 660.78224182
  425.97521782]
 [476.342659   201.84335232 698.66531372 ... 412.89410591 660.27709961
  427.34344482]
 [476.31816864 202.20834732 696.39930725 ... 414.01938915 659.82009888
  428.64429474]]


In [46]:
for i in range(current_window.shape[0]):
  print(current_window[i])

[ 919.359  456.945  907.529  472.643  878.203  476.559  854.614  505.972
  850.724  525.493  935.066  466.793  952.618  517.713  917.322  547.081
  927.21   547.05   909.513  551.011  913.407  601.943  919.337  650.931
  944.844  541.204  978.079  586.212 1011.34   635.239  913.489  449.184
  923.271  449.082  899.747  449.149  930.98   445.248 1005.51   650.913
 1015.29   647.066 1015.37   637.253  909.485  662.676  901.77   658.802
  929.066  654.882]
[ 940.83   449.127  917.452  466.84   889.971  474.58   866.391  498.161
  854.703  521.583  946.804  462.895  964.402  507.947  936.987  541.206
  938.89   541.27   919.339  547.109  929.165  598.024  938.874  648.921
  956.552  537.265  991.814  576.459 1027.03   625.45   933.086  445.202
  940.83   443.285  915.385  445.181    0.       0.    1021.16   637.192
 1030.98   635.222 1031.04   629.364  929.093  660.685  921.303  656.786
  946.778  651.   ]
[ 942.864  447.222  929.056  462.879  899.69   468.716  872.308  498.084
  860.563  