In [None]:
# imports

import os

import matplotlib.pyplot as plt
import torch

from sportspose.dataset import SportsPoseDataset
from sportspose.plot_skeleton_3d import plot_skeleton_3d

In [None]:
### Path to the downloaded SportsPose data ###
# Edit the default below, or set the SPORTSPOSE_DATA environment variable to
# point the notebook at another download location without modifying it.
datapath = os.environ.get("SPORTSPOSE_DATA", "data/SportsPose")

In [None]:
# Load the same data at two granularities: one sample per video,
# and one sample per individual frame.
dataset_pervideo = SportsPoseDataset(data_dir=datapath, sample_level="video")
dataset_perframe = SportsPoseDataset(data_dir=datapath, sample_level="frame")

n_videos = len(dataset_pervideo)
n_frames = len(dataset_perframe)
print(f"{n_videos} videos were found for a total of {n_frames} frames!")

In [None]:
# loading example from dataset
# Index of the frame-level sample to visualize; must lie in
# [0, len(dataset_perframe)).
FRAME_INDEX = 300
assert 0 <= FRAME_INDEX < len(dataset_perframe), (
    f"FRAME_INDEX={FRAME_INDEX} is out of range for {len(dataset_perframe)} frames"
)
sample = dataset_perframe[FRAME_INDEX]

# NOTE(review): each frame-level field appears to keep a leading axis of
# length 1, hence the trailing [0] below — confirm against SportsPoseDataset.

# load frame from the "right" camera view
frameRight = sample["video"]["image"]["right"][0]
# load 2D reprojection to "right" view
jointsRight = sample["joints_2d"]["right"][0]
# load 3D joints
joints3D = sample["joints_3d"]["data_points"][0]

In [None]:
# plotting both 3d view and 2d reprojection
# Left panel: right-view image with the reprojected 2D joints overlaid.
# Right panel: the same sample's 3D joints drawn by plot_skeleton_3d.
fig_example = plt.figure(figsize=(10, 5))
ax2d = fig_example.add_subplot(1, 2, 1)
ax3d = fig_example.add_subplot(1, 2, 2, projection="3d")

ax2d.imshow(frameRight)
# assumes jointsRight is (n_joints, >=2) with image x in column 0 and y in column 1 — TODO confirm
ax2d.scatter(jointsRight[:, 0], jointsRight[:, 1])
ax2d.set_title("2D reprojection (right view)")
ax2d.axis('off')

plot_skeleton_3d(joints3D, ax=ax3d)
ax3d.set_title("3D joints")

# Use the figure object rather than the pyplot state machine.
fig_example.tight_layout()
plt.show()

In [None]:
# Example of picking sequences based on requirements:
# restrict the video-level dataset with a metadata whitelist
# (tag / person / activity).
outdoor_whitelist = {
    "metadata": {
        "tag": "outdoors",
        "person_id": "S22",
        "activity": "jump",
    },
}
dataset_outdoors = SportsPoseDataset(
    data_dir=datapath,
    sample_level="video",
    whitelist=outdoor_whitelist,
)
# batch_size=1 and shuffle=False are the DataLoader defaults, spelled out so
# it is clear each iteration yields exactly one video, in dataset order.
dataloader = torch.utils.data.DataLoader(dataset_outdoors, batch_size=1, shuffle=False)

In [None]:
# Show the first "right"-view frame of every video matching the whitelist.
n_matches = len(dataloader)
if n_matches == 0:
    # plt.subplots(1, 0) would raise ValueError; report instead.
    print("No videos matched the whitelist.")
else:
    # squeeze=False keeps `axs` 2-D even when only one video matches;
    # otherwise plt.subplots returns a bare Axes and indexing it fails.
    fig, axs = plt.subplots(1, n_matches, figsize=(5 * n_matches, 5), squeeze=False)
    for i, sample in enumerate(dataloader):
        # [0, 0]: batch item 0 (the loader yields one video per batch), then
        # presumably time step 0 of that video — TODO confirm axis order.
        frame = sample["video"]["image"]["right"][0, 0]
        axs[0, i].imshow(frame)
        axs[0, i].set_title(f"video {i}")
        axs[0, i].axis('off')
    plt.show()