In [15]:
# Pytorch utilities
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

from scipy.spatial.transform import Rotation as R
#from sklearn.linear_model import LinearRegression
#from sklearn.preprocessing import PolynomialFeatures
#from sklearn.pipeline import make_pipeline

# Plotting utilities
%matplotlib widget
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from IPython.display import HTML
import matplotlib.animation as animation
from torch.utils.tensorboard import SummaryWriter
from timeit import default_timer as timer
import pyprind

# Directory and file utilities
from os import listdir
from os.path import isfile, isdir, join
import json

In [16]:
with open('../../results/small_body_S.json', 'r') as j:
    jd = json.load(j)
    tr, val, test = jd['train'], jd['validation'], jd['test']
    tr_inputs, tr_predictions, tr_groundtruth, tr_lengths = tuple(torch.tensor(tr[n]) for n in ['inputs', 'predictions',
                                                                                              'labels', 'lengths'])
    val_inputs, val_predictions, val_groundtruth, val_length = tuple(torch.tensor(val[n]) for n in ['inputs', 'predictions',
                                                                                              'labels', 'lengths'])
    test_inputs, test_predictions, test_groundtruth, test_lengths = tuple(torch.tensor(test[n]) for n in ['inputs', 'predictions',
                                                                                              'labels', 'lengths'])

In [10]:
tr_inputs.shape, val_predictions.shape, test_lengths.shape

(torch.Size([288, 250, 52]), torch.Size([32, 250, 26]), torch.Size([38]))

In [11]:
frames = [i for i in range(1,9,2)]
video_n = 34

c_inputs = tr_inputs.clone()
c_output = tr_predictions.clone()
c_labels = tr_groundtruth.clone()


In [None]:
for vid in range(c_labels.shape[0]): 
    c_inputs[vid,:,::2].mul_(tr_inp_scale[vid, 0])
    c_inputs[vid,:,1::2].mul_(tr_inp_scale[vid, 1])
    c_output[vid].mul_(tr_out_scale[vid])
    c_labels[vid].mul_(tr_out_scale[vid])
    
    c_inputs[vid,:,::2].mul_(tr_mx[vid, 1])
    c_inputs[vid,:,1::2].mul_(tr_my[vid, 1])
    c_output[vid].mul_(tr_mz[vid])
    c_labels[vid].mul_(tr_mz[vid])

In [12]:
from matplotlib.animation import FuncAnimation

fig = plt.figure()
fig.set_tight_layout(True)
ax = fig.add_subplot(111, projection='3d')
# Query the figure's on-screen size and DPI. Note that when saving the figure to
# a file, we need to provide a DPI for that separately.
print('fig size: {0} DPI, size in inches {1}'.format(
    fig.get_dpi(), fig.get_size_inches()))

frames = [i for i in range(1,60)]
inp = c_inputs
preds = c_output
bodiesXY = torch.chunk(inp[video_n, frames, :], len(frames), dim=0)
pred_bodiesZ = torch.chunk(preds[video_n, frames, :], len(frames), dim=0)
def update(j):
    
    label = 'timestep {0}'.format(j)
    # Update the line and the axes (with a new xlabel). Return a tuple of
    # "artists" that have to be redrawn for this frame.
    
    x = bodiesXY[j].squeeze()[::2].tolist()
    y = bodiesXY[j].squeeze()[1::2].tolist()
    pred_z = pred_bodiesZ[j].squeeze().tolist()

    r = R.from_euler('y', -60, degrees=True)

    xyz1 = np.asarray([c for c in zip(x, y, pred_z)])
    xyz1 = r.apply(xyz1)
    x1 = xyz1[:,0]
    y1 = xyz1[:,1]
    pred_z = xyz1[:,2]

    r_arm = [[c[i] for i in [1, 0, 9, 10, 11]] for c in [x1, y1, pred_z]]
    l_arm = [[c[i] for i in [0, 3, 4, 5]] for c in [x1, y1, pred_z]]
    r_leg = [[c[i] for i in [0, 2, 12, 13, 14, 22, 23, 24]] for c in [x1, y1, pred_z]]
    l_leg = [[c[i] for i in [2, 6, 7, 8, 19, 20, 21]] for c in [x1, y1, pred_z]]
    head = [[c[i] for i in [18, 17, 1, 15, 16]] for c in [x1, y1, pred_z]]

    ax.set_xdata(r_arm[0], r_arm[1], r_arm[2])
    ax.plot(l_arm[0], l_arm[1], l_arm[2])
    ax.plot(r_leg[0], r_leg[1], r_leg[2])
    ax.plot(l_leg[0], l_leg[1], l_leg[2])
    ax.plot(head[0], head[1], head[2])

    lims = ax.get_xlim(), ax.get_ylim(), ax.get_zlim()
    spans = lims[0][1]-lims[0][0], lims[1][1]-lims[1][0], lims[2][1]-lims[2][0]
    span = max(spans)
    margins = [(span-s)/2 for  s in spans]
    ax.set_xlim(lims[0][0]-margins[0], lims[0][1]+margins[0])
    ax.set_ylim(lims[1][0]-margins[1], lims[1][1]+margins[1])
    ax.set_zlim(lims[2][0]-margins[2], lims[2][1]+margins[2])

    ax.view_init(elev=-65., azim=-90.)
    
    ax.set_xlabel(label)
    
    return line, ax

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

fig size: 100.0 DPI, size in inches [6.4 4.8]


In [13]:
# FuncAnimation will call the 'update' function for each frame; here
# animating over 10 frames, with an interval of 200ms between frames.
anim = FuncAnimation(fig, update, frames=np.arange(0, 60), interval=100)
    # plt.show() will just loop the animation forever.
plt.show()

In [14]:
anim.save('body.gif', dpi=80, writer='imagemagick')


TypeError: 'MovieWriterRegistry' object is not an iterator