In [1]:
import sys, os
import torch
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from PIL import Image
import moviepy.video.io.ImageSequenceClip
from functions import gen_ztta
from models import StateEncoder, OffsetEncoder, TargetEncoder, LSTM, Decoder
from skeleton import Skeleton
from LaFan import LaFan1
from functions import gen_ztta
from config import *

# Data

In [2]:
# Evaluation dataset: LaFan1 is built with train=False and debug=False.
# `data` and `test` are not defined in this notebook; presumably they come from
# the star import of `config` (TODO confirm).
lafan = LaFan1(
    "./lafan1/lafan1_small/flipped",
    seq_len=data["seq_length"],
    offset=10,
    train=False,
    debug=False,
)

# Batches are served in dataset order (shuffle=False) from the main process
# (num_workers=0).
lafan_loader = DataLoader(
    lafan,
    batch_size=test["batch_size"],
    shuffle=False,
    num_workers=0,
)

Building the data set... ['subject5']
Processing file dance2_subject5.bvh
Nb of sequences : 672



# Models

In [3]:
# Build the skeleton from the configured per-joint offsets and parent indices.
skeleton = Skeleton(offsets=data["offsets"], parents=data["parents"])
# Move the skeleton to the compute device used by the networks below.
skeleton.to(device)
# Drop the joints listed in the config; forward_kinematics is later called on
# this reduced skeleton.
skeleton.remove_joints(data["joints_to_remove"])

In [4]:
# Folder that holds one checkpoint file per network.
models_folder = "./models/small_dataset_test_00/epoch_1/"

# Every network is created, moved to `device`, and then filled with its saved
# weights. The checkpoints are deserialised onto the CPU (map_location) and
# load_state_dict copies them into the already-placed module.

# --- State encoder ---
state_encoder = StateEncoder(in_dim=model["state_input_dim"]).to(device)
state_encoder.load_state_dict(
    torch.load(f"{models_folder}/state_encoder.pkl", map_location=torch.device('cpu'))
)

# --- Offset encoder ---
offset_encoder = OffsetEncoder(in_dim=model["offset_input_dim"]).to(device)
offset_encoder.load_state_dict(
    torch.load(f"{models_folder}/offset_encoder.pkl", map_location=torch.device('cpu'))
)

# --- Target encoder ---
target_encoder = TargetEncoder(in_dim=model["target_input_dim"]).to(device)
target_encoder.load_state_dict(
    torch.load(f"{models_folder}/target_encoder.pkl", map_location=torch.device('cpu'))
)

# --- Recurrent core: hidden size is twice the input size ---
lstm = LSTM(in_dim=model["lstm_dim"], hidden_dim=model["lstm_dim"] * 2).to(device)
lstm.load_state_dict(
    torch.load(f"{models_folder}/lstm.pkl", map_location=torch.device('cpu'))
)

# --- Decoder: consumes the LSTM output (2 * lstm_dim features) ---
decoder = Decoder(in_dim=model["lstm_dim"]*2, out_dim=model["decoder_output_dim"]).to(device)
# Left as the last expression so the key-matching report is shown as cell output.
decoder.load_state_dict(
    torch.load(f"{models_folder}/decoder.pkl", map_location=torch.device('cpu'))
)

<All keys matched successfully>

# Evaluation

In [5]:
# Per-timestep embedding with one entry per frame of the sequence; the
# evaluation loop adds ztta[:, t] to each encoder output at step t.
# NOTE(review): presumably a time-to-arrival embedding — confirm in functions.gen_ztta.
ztta = gen_ztta(length=data["seq_length"]).to(device)

In [6]:
# Put every network into evaluation mode before running inference.
state_encoder.eval()
offset_encoder.eval()
target_encoder.eval()
lstm.eval()
# Last expression of the cell: its return value (the decoder itself) is what
# the notebook prints as the cell output.
decoder.eval()

Decoder(
  (fc0): Linear(in_features=1536, out_features=512, bias=True)
  (fc1): Linear(in_features=512, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=91, bias=True)
  (fc_conct): Linear(in_features=256, out_features=4, bias=True)
  (ac_sig): Sigmoid()
)

In [11]:
# Index (within the batch) of the sequence that gets rendered to PNG frames.
sample_to_save = 20

# plot_pose writes its images into this folder; create it up front so that
# saving the first frame cannot fail on a fresh checkout.
os.makedirs('./results/temp', exist_ok=True)

# NOTE: plot_pose is defined in the "Plotting functions" section further down
# in this notebook; that cell has to be executed before this one.

with torch.no_grad():
    # Only the first batch of the loader is evaluated. The built-in next() is
    # used because the iterator's Python-2 style .next() method is not
    # available on recent PyTorch DataLoader iterators.
    sampled_batch = next(iter(lafan_loader))

    # State inputs
    local_q = sampled_batch['local_q'].to(device)
    root_v = sampled_batch['root_v'].to(device)
    contact = sampled_batch['contact'].to(device)

    # Offset inputs (rotations flattened to one vector per sample)
    root_p_offset = sampled_batch['root_p_offset'].to(device)
    local_q_offset = sampled_batch['local_q_offset'].to(device)
    local_q_offset = local_q_offset.view(local_q_offset.size(0), -1)

    # Target inputs (flattened to one vector per sample)
    target = sampled_batch['target'].to(device)
    target = target.view(target.size(0), -1)

    # Root position
    root_p = sampled_batch['root_p'].to(device)

    # X: positions used for the first-frame, last-frame and ground-truth plots
    X = sampled_batch['X'].to(device)

    # The first and last frame of the rendered sample are identical for every
    # step, so they are converted once. view(22, 3) gives one xyz row per joint.
    start_pose = X[sample_to_save, 0].view(22, 3).detach().cpu().numpy()
    target_pose = X[sample_to_save, -1].view(22, 3).detach().cpu().numpy()

    lstm.init_hidden(local_q.size(0))

    root_pred = None
    local_q_pred = None
    contact_pred = None
    root_v_pred = None

    for t in range(lafan.cur_seq_length - 1):
        if t == 0:
            # First step: seed the recurrence with the ground-truth frame 0.
            root_p_t = root_p[:, t]
            local_q_t = local_q[:, t]
            local_q_t = local_q_t.view(local_q_t.size(0), -1)
            contact_t = contact[:, t]
            root_v_t = root_v[:, t]
        else:
            # Later steps: feed the previous prediction back in. Index [0]
            # drops the leading dimension added by unsqueeze(0) below.
            root_p_t = root_pred[0]
            local_q_t = local_q_pred[0]
            contact_t = contact_pred[0]
            root_v_t = root_v_pred[0]

        state_input = torch.cat([local_q_t, root_v_t, contact_t], -1)

        # Offsets are measured from the current state to the offset reference.
        root_p_offset_t = root_p_offset - root_p_t
        local_q_offset_t = local_q_offset - local_q_t
        offset_input = torch.cat([root_p_offset_t, local_q_offset_t], -1)

        target_input = target

        h_state = state_encoder(state_input)
        h_offset = offset_encoder(offset_input)
        h_target = target_encoder(target_input)

        # Add the step-t embedding to all three encodings.
        h_state += ztta[:, t]
        h_offset += ztta[:, t]
        h_target += ztta[:, t]

        h_in = torch.cat([h_state, h_offset, h_target], -1).unsqueeze(0)
        h_out = lstm(h_in)

        h_pred, contact_pred = decoder(h_out)

        # The first 88 channels are added to the current rotations (88 = 22
        # quaternions of 4 components, see the reshape below).
        local_q_v_pred = h_pred[:, :, :88]
        local_q_pred = local_q_v_pred + local_q_t

        # Reshape to (..., joints, 4) and normalise each quaternion to unit
        # length before forward kinematics.
        local_q_pred_ = local_q_pred.view(local_q_pred.size(0), local_q_pred.size(1), -1, 4)
        local_q_pred_ = local_q_pred_ / torch.norm(local_q_pred_, dim=-1, keepdim=True)

        # The remaining channels are added to the current root position.
        root_v_pred = h_pred[:, :, 88:]
        root_pred = root_v_pred + root_p_t

        pos_pred = skeleton.forward_kinematics(local_q_pred_, root_pred)

        # Saving images: every frame stacks first frame / current / last frame,
        # once with the predicted pose and once with the ground truth at t + 1.
        predicted_pose = pos_pred[0, sample_to_save].view(22, 3).detach().cpu().numpy()
        truth_pose = X[sample_to_save, t + 1].view(22, 3).detach().cpu().numpy()

        plot_pose(np.concatenate([start_pose, predicted_pose, target_pose], 0),
                  t, './results/temp/pred')
        plot_pose(np.concatenate([start_pose, truth_pose, target_pose], 0),
                  t, './results/temp/gt')

In [14]:
# Assemble the frames written to ./results/temp/ by the evaluation loop into an
# mp4. NOTE: save_video is defined in the "Plotting functions" section further
# down in this notebook; that cell has to be executed before this one.
save_video("./results/temp/", "./results/small_dataset_test_00.mp4")

Moviepy - Building video ./results/small_dataset_test_00.mp4.
Moviepy - Writing video ./results/small_dataset_test_00.mp4



                                                                                                                                                                                      

Moviepy - Done !
Moviepy - video ready ./results/small_dataset_test_00.mp4


# Plotting functions

In [8]:
def plot_pose(pose, cur_frame, prefix):
    """Draw three stacked skeletons as 3D stick figures and save them as a PNG.

    Args:
        pose: NumPy array with ``3 * num_joint`` rows and the x, y, z
            coordinates in columns 0, 1, 2. The rows hold three skeletons one
            after another; they are drawn in red, blue and green respectively.
        cur_frame: Integer frame index, zero-padded to two digits in the
            file name.
        prefix: Path prefix of the output; the image is written to
            ``f"{prefix}_{cur_frame:02}.png"``.
    """
    # Parent joint index for every joint; -1 marks the root.
    parents = [-1, 0, 1, 2, 3, 0, 5, 6, 7, 0, 9, 10, 11, 12, 11, 14, 15, 16, 11, 18, 19, 20]
    num_joint = pose.shape[0] // 3

    fig = plt.figure(figsize=(8, 8))
    try:
        ax = fig.add_subplot(111, projection='3d')

        # One line segment per bone (joint -> parent) and per skeleton. Source
        # column 1 is plotted on the vertical axis, column 2 on the depth axis.
        for joint, parent in enumerate(parents):
            if joint == 0:
                continue  # the root has no parent to connect to
            for skeleton_idx, colour in enumerate(('r', 'b', 'g')):
                shift = skeleton_idx * num_joint
                child_row = joint + shift
                parent_row = parent + shift
                ax.plot([pose[child_row, 0], pose[parent_row, 0]],
                        [pose[child_row, 2], pose[parent_row, 2]],
                        [pose[child_row, 1], pose[parent_row, 1]], c=colour)

        # Use one cubic bounding box for all three axes so the figure is not
        # distorted. True division keeps the box centred on the data and keeps
        # it from collapsing to zero width when the coordinates span less than
        # two units (floor division would truncate both values).
        xmin, xmax = np.min(pose[:, 0]), np.max(pose[:, 0])
        ymin, ymax = np.min(pose[:, 2]), np.max(pose[:, 2])
        zmin, zmax = np.min(pose[:, 1]), np.max(pose[:, 1])
        half_extent = np.max([xmax - xmin, ymax - ymin, zmax - zmin]) / 2
        xmid = (xmax + xmin) / 2
        ymid = (ymax + ymin) / 2
        zmid = (zmax + zmin) / 2
        ax.set_xlim(xmid - half_extent, xmid + half_extent)
        ax.set_ylim(ymid - half_extent, ymid + half_extent)
        ax.set_zlim(zmid - half_extent, zmid + half_extent)

        fig.savefig(f"{prefix}_{cur_frame:02}.png", dpi=200, bbox_inches='tight')
    finally:
        # Always release this figure, even when plotting or saving fails;
        # the function is called twice per frame in a loop.
        plt.close(fig)

In [13]:
def save_video(frames_loc, filepath, fps=30, prefix="pred"):
    """Assemble the PNG frames found in a folder into a video file.

    Args:
        frames_loc: Directory that contains the frame images.
        filepath: Destination path of the video that is written.
        fps: Frames per second of the resulting clip (default 30).
        prefix: Only ``.png`` files whose name starts with this prefix are
            used (default ``"pred"``).

    Raises:
        FileNotFoundError: If no matching frame exists in ``frames_loc``.
    """
    def frame_order(name):
        # Order by the number after the last underscore when there is one, so
        # frame 100 sorts after frame 99; other names follow alphabetically.
        stem = os.path.splitext(name)[0]
        index = stem.rsplit("_", 1)[-1]
        if index.isdecimal():
            return (0, int(index), name)
        return (1, 0, name)

    # os.listdir returns entries in arbitrary order, so the frames have to be
    # sorted explicitly to play back in sequence.
    names = sorted(
        (name for name in os.listdir(frames_loc)
         if name.endswith(".png") and name.startswith(prefix)),
        key=frame_order,
    )
    if not names:
        raise FileNotFoundError(
            f"no '{prefix}*.png' frames found in {frames_loc!r}"
        )

    frames = [os.path.join(frames_loc, name) for name in names]
    clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip(frames, fps=fps)
    clip.write_videofile(filepath)