# Full pipeline for keypoint prediction using GRU-Real-NVP from pytorch-ts in reconstruction mode on the VoxCeleb dataset

Grady King

I install pytorch-ts with pytorchSetup.sh

# Install dependencies and import libraries

In [1]:
pip install pytorchts gluonts==0.9.3

Note: you may need to restart the kernel to use updated packages.


In [4]:
pip install scikit-image==0.18.3 imageio[ffmpeg] imageio[pyav] ffmpeg

Collecting scikit-image==0.18.3
  Downloading scikit-image-0.18.3.tar.gz (29.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.2/29.2 MB[0m [31m54.7 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25h  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Collecting ffmpeg
  Downloading ffmpeg-1.4.tar.gz (5.1 kB)
  Preparing metadata (setup.py) ... [?25ldone
Collecting imageio-ffmpeg (from imageio[ffmpeg])
  Obtaining dependency information for imageio-ffmpeg from https://files.pythonhosted.org/packages/a0/2d/43c8522a2038e9d0e7dbdf3a61195ecc31ca576fb1527a528c877e87d973/imageio_ffmpeg-0.6.0-py3-none-manylinux2014_x86_64.whl.metadata
  Downloading imageio_ffmpeg-0.6.0-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting av (from imageio[ffmpeg])
  Obtaining dependency information for av from https://files.pythonhosted.org/packages/46/b0/6380e

In [3]:
pip install natsort

Collecting natsort
  Obtaining dependency information for natsort from https://files.pythonhosted.org/packages/ef/82/7a9d0550484a62c6da82858ee9419f3dd1ccc9aa1c26a1e43da3ecd20b0d/natsort-8.4.0-py3-none-any.whl.metadata
  Downloading natsort-8.4.0-py3-none-any.whl.metadata (21 kB)
Downloading natsort-8.4.0-py3-none-any.whl (38 kB)
Installing collected packages: natsort
Successfully installed natsort-8.4.0
Note: you may need to restart the kernel to use updated packages.


In [1]:
import os, sys  
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np
import imageio
from tqdm import trange
import tensorflow.compat.v1 as tf
import pickle, gc, yaml
from torch import nn
from torch.autograd import Variable
import random
import matplotlib.pyplot as plt
import pandas as pd
#from gluonts.dataset.pandas import PandasDataset
os.environ["CUDA_VISIBLE_DEVICES"]='0'

2026-01-18 16:15:01.686443: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Select the compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Display the accelerator name. The query is guarded because
# torch.cuda.get_device_name() raises on machines without CUDA, even though
# `device` above has already fallen back to the CPU.
torch.cuda.get_device_name() if torch.cuda.is_available() else "cpu (CUDA not available)"

'NVIDIA GeForce RTX 3090'

# Import keypoints of 44 VoxCeleb test videos

In [3]:
# Read the pickled keypoint time series of the test videos.
# NOTE: pickle can execute arbitrary code while loading; only open files from a trusted source.
KP_PICKLE_PATH = "kp_test_44_vox.pkl"
with open(KP_PICKLE_PATH, "rb") as kp_file:
    kp_time_series = pickle.load(kp_file)

# Number of loaded videos (displayed as the cell output)
len(kp_time_series)

44

# Convert the per-frame keypoint lists into one dictionary of tensors per video

In [4]:
# Build, for every video, a dictionary of stacked keypoint tensors:
#   kp_dict_init[video_idx]["value"]    -> shape (1, num_frames, 10, 2)
#   kp_dict_init[video_idx]["jacobian"] -> shape (1, num_frames, 10, 2, 2)
kp_dict_init = []
for video_idx in range(len(kp_time_series)):
    video_entry = kp_time_series[video_idx]
    # Unwrap the 'kp' payload in place (same side effect as before), but only
    # when the entry is still wrapped. This keeps the cell idempotent:
    # re-running it no longer fails on entries that were already unwrapped.
    if isinstance(video_entry, dict) and 'kp' in video_entry:
        kp_time_series[video_idx] = video_entry['kp']
    video_frames = kp_time_series[video_idx]

    # One (1, 10, 2) mean tensor and one (1, 10, 2, 2) jacobian tensor per frame
    init_mean = [
        torch.tensor(video_frames[frame_idx]['value'].reshape(1, 10, 2))
        for frame_idx in range(len(video_frames))
    ]
    init_jacobian = [
        torch.tensor(video_frames[frame_idx]['jacobian'].reshape(1, 10, 2, 2))
        for frame_idx in range(len(video_frames))
    ]

    # Concatenate along the frame axis, then add a leading singleton axis
    init_mean = torch.cat(init_mean).unsqueeze(0)
    init_jacobian = torch.cat(init_jacobian).unsqueeze(0)

    if torch.cuda.is_available():
        # add tensor to cuda
        init_mean = init_mean.to('cuda:0')
        init_jacobian = init_jacobian.to('cuda:0')

    kp_dict_both = {"value": init_mean, "jacobian": init_jacobian}
    kp_dict_init.append(kp_dict_both)

# Apply min-max scaling to the keypoints and split them into 24-frame series


In [5]:
# Merge each video's keypoint values (2 per keypoint) and flattened jacobians
# (4 per keypoint) into a single (num_frames, 60) numpy array.
kp_list_test = []
for kp_entry in kp_dict_init:
    jacobian_flat = kp_entry['jacobian'].reshape(1, -1, 10, 4)
    per_frame_features = torch.cat((kp_entry['value'], jacobian_flat), dim=-1).reshape(-1, 60)
    kp_list_test.append(np.array(per_frame_features.cpu()))

# Keep every second frame of each video
reduced_keypoint_list_test = [kp[::2] for kp in kp_list_test]
print(len(reduced_keypoint_list_test))
print(reduced_keypoint_list_test[0].shape)
    
#####  per-video min-max scaling of the 60 feature dimensions ######
# The min/max statistics are taken only from the first `step_interval` frames
# of every `min_required_steps`-frame window (presumably the frames the model
# observes as context -- confirm against the training pipeline).
step_interval = 8         # frames per window that contribute to the statistics
min_required_steps = 24   # window length, also the stride between windows

kp_list_test_std = []
min_list = []
range_list = []
for video_idx in range(len(reduced_keypoint_list_test)):
    data = reduced_keypoint_list_test[video_idx]
    data_length = len(data)

    selected_data = []
    for i in range(0, data_length - min_required_steps + 1, min_required_steps):
        selected_data.extend(data[i:i + step_interval])
    if not selected_data:
        # np.min on an empty selection would fail with an opaque numpy error
        raise ValueError(
            f"video {video_idx} has only {data_length} frames; "
            f"at least {min_required_steps} are needed to compute min-max statistics"
        )

    min_values = np.min(selected_data, axis=0)  # 60 minima for this video
    max_values = np.max(selected_data, axis=0)  # 60 maxima for this video
    range_values = max_values - min_values
    # A dimension that is constant over the selected frames has zero range;
    # dividing by it would produce NaN/inf. Use a unit range instead. The same
    # patched range is stored in range_list, so the later inverse transform
    # (x * range + min) stays consistent.
    range_values[range_values == 0] = 1

    kp_one_video_std = (data - min_values) / range_values
    kp_list_test_std.append(kp_one_video_std)
    min_list.append(min_values)
    range_list.append(range_values)

test_trajs = kp_list_test_std
print(len(test_trajs))
print(test_trajs[0].shape)

44
(59, 60)
44
(59, 60)


In [6]:
import numpy as np

# Cut every standardized test video into consecutive, non-overlapping 24-frame
# series; trailing frames that do not fill a whole series are dropped.
SERIES_LENGTH = 24

all_series_test = []      # the 24-frame series of all videos, in video order
video_indices_test = []   # for each series, the index of the video it came from

for video_idx, video in enumerate(kp_list_test_std):
    full_series_frames = (video.shape[0] // SERIES_LENGTH) * SERIES_LENGTH
    for start_frame in range(0, full_series_frames, SERIES_LENGTH):
        all_series_test.append(video[start_frame:start_frame + SERIES_LENGTH])
        video_indices_test.append(video_idx)

# Stack into arrays: (total_series_count, 24, 60) and (total_series_count,)
all_series_test = np.array(all_series_test)
video_indices_test = np.array(video_indices_test)

print(f"Shape of all_series: {all_series_test.shape}")
print(f"Shape of video_indices: {video_indices_test.shape}")

Shape of all_series: (255, 24, 60)
Shape of video_indices: (255,)


In [7]:
# Adjust start_time and time_delta
from torch.utils.data import Dataset

# Define the start date and the sampling frequency.
# The frequency is kept in its own variable and handed to pd.date_range by the
# dataset; passing `freq=` to pd.Timestamp is deprecated (it emits a
# FutureWarning) and is rejected by pandas 2.x.
freq = "1s"
start_time = pd.Timestamp("1999-05-01 00:00:00")

# Dataset wrapper that serves each series as a {'start', 'target'} entry
class CustomDataset(Dataset):
    """Expose an indexable collection of series as dictionary entries.

    Args:
        features: indexable collection of per-series arrays; ``features[i]``
            must support ``.T`` (for the (24, 60) series built in this
            notebook the served target is therefore (60, 24)).
        start_date: timestamp assigned to the first series.
        freq: pandas frequency string; series ``i`` receives the timestamp
            ``start_date + i * freq``.
    """

    def __init__(self, features, start_date, freq):
        self.features = features
        self.start_date = start_date
        self.freq = freq
        # One timestamp per series, spaced by `freq`
        self.timestamps = pd.date_range(start=start_date, periods=len(features), freq=freq)

    def __len__(self):
        """Return the number of series."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return series ``idx`` as ``{'start': timestamp, 'target': transposed series}``."""
        target = self.features[idx].T
        start = self.timestamps[idx]
        return dict(start=start, target=target)

# Create the test dataset (only the test split is built in this notebook).
# The shared `freq` variable is used instead of repeating the "1s" literal.
#train_ds = CustomDataset(features=all_series, start_date=start_time, freq="1s")
test_ds = CustomDataset(features=all_series_test, start_date=start_time, freq=freq)

  start_time = pd.Timestamp("1999-05-01 00:00:00", freq = "1s")  # Ensure freq is set


## Load the trained model and run inference

In [9]:
from gluonts.dataset.multivariate_grouper import MultivariateGrouper
from gluonts.dataset.repository.datasets import dataset_recipes, get_dataset
from tempflow_estimator_SNF import TempFlowEstimator
from tempflow_network_SNF import TempFlowPredictionNetwork
from gluonts.evaluation.backtest import make_evaluation_predictions
from gluonts.evaluation import MultivariateEvaluator	
from pts.model.utils import get_module_forward_input_names
from gluonts.torch.model.predictor import PyTorchPredictor
# 1. Load the trained predictor and extract the weights of its network.
# NOTE: torch.load unpickles a complete Python object here (not only tensors);
# only load checkpoints that come from a trusted source.
predictor = torch.load("Checkpoints/GRU-NF_3883videos_vox_8-16.pth")
#predictor = torch.load("Checkpoints/GRU-NF_syntheticdata_10-14.pth")
state_dict = predictor.prediction_net.state_dict()

# 2. Create the new model with the SNF flow ("SRealNVP")
snf_model = TempFlowPredictionNetwork(
    num_parallel_samples=100,  # or any desired number
    target_dim=60,
    prediction_length=16,
    cell_type='GRU',
    num_layers=3,
    num_cells=512,
    flow_type="SRealNVP",
    hidden_size=512,
    n_hidden=3,
    conditioning_length=1,
    dropout_rate=0.2,
    input_size=60,
    context_length=8,
    scaling=False,
    history_length=8,
    lags_seq=[],
    dequantize=False,
    n_blocks=5,
).to(device)

# 3. Load the weights of the RealNVP-based model.
# strict=False tolerates parameter-name mismatches silently, so report them:
# missing parameters keep their freshly initialised values.
load_result = snf_model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print("Parameters not found in the checkpoint (left at initial values):", load_result.missing_keys)
if load_result.unexpected_keys:
    print("Checkpoint parameters ignored by the SNF model:", load_result.unexpected_keys)

# 4. Wrap the SNF model in a PyTorchPredictor, reusing the original input transform
input_names = get_module_forward_input_names(snf_model)

snf_predictor = PyTorchPredictor(
    input_transform=predictor.input_transform,
    input_names=input_names,
    prediction_net=snf_model,
    batch_size=predictor.batch_size,
    freq=predictor.freq,
    prediction_length=predictor.prediction_length,
    device=predictor.device,
)

# 5. Use the SNF model for inference
forecast_it, ts_it = make_evaluation_predictions(
    dataset=test_ds,
    predictor=snf_predictor,
    num_samples=100,
)

In [10]:
# Consume the forecast and ground-truth iterators (forecasts first), then
# compute the aggregate multivariate metrics.
forecasts, targets = list(forecast_it), list(ts_it)
evaluator = MultivariateEvaluator()
agg_metric, _ = evaluator(targets, forecasts)

  return _shift_timestamp_helper(ts, ts.freq, offset)
  date_before_forecast = forecast.index[0] - forecast.index[0].freq
Running evaluation: 255it [00:00, 314.99it/s]
  date_before_forecast = forecast.index[0] - forecast.index[0].freq
Running evaluation: 255it [00:00, 314.43it/s]
  date_before_forecast = forecast.index[0] - forecast.index[0].freq
Running evaluation: 255it [00:00, 315.07it/s]
  date_before_forecast = forecast.index[0] - forecast.index[0].freq
Running evaluation: 255it [00:00, 315.36it/s]
  date_before_forecast = forecast.index[0] - forecast.index[0].freq
Running evaluation: 255it [00:00, 315.79it/s]
  date_before_forecast = forecast.index[0] - forecast.index[0].freq
Running evaluation: 255it [00:00, 320.56it/s]
  date_before_forecast = forecast.index[0] - forecast.index[0].freq
Running evaluation: 255it [00:00, 319.59it/s]
  date_before_forecast = forecast.index[0] - forecast.index[0].freq
Running evaluation: 255it [00:00, 316.99it/s]
  date_before_forecast = forecast.

In [11]:
# Report the aggregate metrics; the "CRPS" line shows the 'mean_wQuantileLoss' entry.
metric_lines = (
    ("CRPS: {}", 'mean_wQuantileLoss'),
    ("ND: {}", 'ND'),
    ("NRMSE: {}", 'NRMSE'),
    ("MSE: {}", 'MSE'),
)
for line_template, metric_key in metric_lines:
    print(line_template.format(agg_metric[metric_key]))

CRPS: 0.3748036171651725
ND: 0.47961225223755605
NRMSE: 0.6273260731986486
MSE: 0.1059282181328893


In [10]:
# Number of ground-truth frames kept in front of every forecast
# (equals the context_length=8 given to the model).
CONTEXT_LENGTH = 8

# Step 1: Stack the samples of all forecasts.
# Shape: (n_series, n_samples, prediction_length, n_features), e.g. (255, 100, 16, 60)
forecast_samples_array = np.array([forecast.samples for forecast in forecasts])
print("Shape of forecast_samples_array:", forecast_samples_array.shape)

# Step 2: Take only the context frames of each ground-truth series and add a
# sample axis. Shape: (n_series, 1, CONTEXT_LENGTH, n_features)
context_frames = all_series_test[:, np.newaxis, :CONTEXT_LENGTH, :]

# Step 3: Repeat the context once per forecast sample. Only the context frames
# are tiled (not the whole 24-frame series), since nothing else is used.
# Shape: (n_series, n_samples, CONTEXT_LENGTH, n_features)
tiled_context = np.tile(context_frames, (1, forecast_samples_array.shape[1], 1, 1))

# Step 4: Concatenate ground-truth context and forecast along the time axis.
# Shape: (n_series, n_samples, CONTEXT_LENGTH + prediction_length, n_features)
test_gt_pred = np.concatenate((tiled_context, forecast_samples_array), axis=2)

print("Final Shape of test_gt_pred:", test_gt_pred.shape)

Shape of forecast_samples_array: (255, 100, 16, 60)
Final Shape of test_gt_pred: (255, 100, 24, 60)


In [11]:
pip install dcor

Collecting dcor
  Obtaining dependency information for dcor from https://files.pythonhosted.org/packages/a4/70/d82c194d53d684b6e75a228170a36f414cc86f5824693f6b0e443032461d/dcor-0.7-py3-none-any.whl.metadata
  Downloading dcor-0.7-py3-none-any.whl.metadata (8.1 kB)
Collecting array-api-compat (from dcor)
  Obtaining dependency information for array-api-compat from https://files.pythonhosted.org/packages/df/5d/493b1b5528ab5072feae30821ff3a07b7a0474213d548efb1fdf135f85c1/array_api_compat-1.13.0-py3-none-any.whl.metadata
  Downloading array_api_compat-1.13.0-py3-none-any.whl.metadata (2.5 kB)
Downloading dcor-0.7-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading array_api_compat-1.13.0-py3-none-any.whl (58 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.6/58.6 kB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: array-api-compat,

In [None]:
import dcor

# Flatten every sequence into one vector per row. The sizes are derived from
# the arrays themselves instead of hard-coding the series count and 24*60.
n_series = all_series_test.shape[0]
gt_flat = all_series_test.reshape(n_series, -1)          # (n_series, frames * features)
# Every predicted sample becomes its own row; each row starts with the
# ground-truth context frames followed by the forecast frames.
pred_flat = test_gt_pred.reshape(-1, gt_flat.shape[1])   # (n_series * n_samples, frames * features)

# Compute energy distance between the full ground truth and predicted distributions
ed = dcor.energy_distance(gt_flat, pred_flat)

print("Energy Distance between ground truth and predicted keypoint distributions:", ed)

## Unstandardization (undo the min-max scaling)

In [18]:
# Number of full 24-frame segments contributed by each video.
# Every video gets an entry (0 when it is shorter than `frames`), so position i
# in num_batch_video always refers to the same video as position i in
# range_list / min_list. Skipping short videos would shift all later indices.
frames = 24
num_batch_video = [x.shape[0] // frames for x in kp_list_test_std]
num_full_batches_all = sum(num_batch_video)  # total number of segments over all videos
print('number of batches of each video:', len(num_batch_video))

number of batches of each video: 44


In [19]:
test_video_unstd_list = []

# Walk through the videos in order. `segments_seen` is the running offset of
# the current video's first segment along the first axis of test_gt_pred.
segments_seen = 0
for video_idx, n_segments in enumerate(num_batch_video):
    start_idx = segments_seen
    end_idx = start_idx + n_segments
    test_video = test_gt_pred[start_idx:end_idx]  # Shape: (num_segments, num_samples, 24, 60)

    # Per-video scaling statistics saved during standardization
    video_range = range_list[video_idx]
    video_min = min_list[video_idx]

    # Invert the min-max scaling for each predicted sample separately;
    # every list element has shape (num_segments, 24, 60).
    video_segments_list = [
        test_video[:, sample_idx, :, :] * video_range + video_min
        for sample_idx in range(test_video.shape[1])
    ]

    test_video_unstd_list.append(video_segments_list)
    segments_seen = end_idx

# Resulting structure:
# test_video_unstd_list[video_idx][sample_idx] -> array of shape (num_segments, 24, 60)
print(f"Total videos: {len(test_video_unstd_list)}")
print(f"First video has {len(test_video_unstd_list[0])} sets, each with shape {test_video_unstd_list[0][0].shape}")

Total videos: 44
First video has 100 sets, each with shape (2, 24, 60)


In [20]:
import pickle

# Save the unstandardized per-video samples (test_video_unstd_list) as a .pkl file.
# The file name is defined once so the written path and the confirmation
# message cannot drift apart.
output_path = "GRU-SNF_vox8-16_test_video_unstd_list_100_mcmc.pkl"
with open(output_path, "wb") as f:
    pickle.dump(test_video_unstd_list, f)

print(f"test_video_unstd_list has been saved as '{output_path}'.")

test_video_unstd_list has been saved as 'GRU-SNF_vox10-14_test_video_unstd_list_100_reduced_framerate_mcmc.pkl'.
