In [47]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pytorch3d.transforms import quaternion_invert, quaternion_apply, axis_angle_to_quaternion

from torch.utils.data import DataLoader, Dataset

from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

import pandas as pd

force_cpu = True
if not force_cpu:
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
else:
    device = torch.device("cpu")

print("Using device", device)


class JointDataset(Dataset):
    def __init__(self, x, joint_names):
        super().__init__()
        self.joint_names = joint_names
        self.x = self._transform(self._scale(x))
        self.furthest_distance = None
        self.centroid = None

    def __len__(self):
        return self.x.size(0)

    def __getitem__(self, idx):
        return self.x[idx]

    def _scale(self, x):
        # Center the points and reduce to unit-sphere
        x = x.reshape(-1, 3)
        self.centroid = torch.mean(x, axis=0)
        x -= self.centroid
        self.furthest_distance = torch.max(torch.sqrt(torch.sum(x ** 2, axis=-1)))
        x /= self.furthest_distance
        # Put back into the original shape
        x = x.reshape(-1, 32, 3)
        return x

    def _unscale(self, x):
        if self.furthest_distance is None:
            raise ValueError("Dataset has not been scaled yet")
        x *= self.furthest_distance
        x += self.centroid
        return x
    
    def _transform(self, x):
        """
            Compute the transformation matrix that transforms the reference points into the target points.
            references = torch.from_numpy(np.random.randint(0, 10, size=((n_obs, 3)))).float()
            targets = torch.from_numpy(np.random.randint(0, 10, size=((n_obs, n_joints, 3)))).float()
        """
        
        pelvis_idx = list(self.joint_names).index("PELVIS")
        
        # Get the references
        references = x[:, pelvis_idx]
        
        # Split the tensor into two parts, before and after the pelvis_idx
        first_part = x[:, :pelvis_idx]
        second_part = x[:, pelvis_idx+1:]
        
        # Stack the two parts together
        targets = torch.cat((first_part, second_part), dim=1)
        
        print("targets shape", targets.shape)
        print("references shape", references.shape)
                
        n_joints = targets.size(1)
        
        # Normalize A and B
        normalized_A = references / torch.norm(references, dim=-1, keepdim=True)
        normalized_B = targets / torch.norm(targets, dim=-1, keepdim=True)
        
        # Calculate the axis of rotation
        axis = torch.cross(normalized_A.unsqueeze(1), normalized_B)
        axis = axis / torch.norm(axis, dim=-1, keepdim=True)
        
        # The angle of rotation is the arccosine of the dot product of the normalized vectors
        angle_of_rotation = torch.acos((normalized_A.unsqueeze(1) * normalized_B).sum(-1))
        
        # The axis-angle representation is the axis of rotation scaled by the angle of rotation
        axis_angle = axis * angle_of_rotation.unsqueeze(-1)
        
        # Convert the axis-angle representation to a quaternion
        rotation = axis_angle_to_quaternion(axis_angle)
        
        # Apply the rotation to point A
        A_rotated = quaternion_apply(rotation, references.unsqueeze(1).expand(-1, n_joints, -1))
        
        # Compute the translation vector
        translation = targets - A_rotated
        
        # Apply the translation to point A
        A_transformed = A_rotated + translation

        transformations = torch.cat((rotation, translation), dim=-1)
        flatten_transformations = transformations.reshape(transformations.size(0), -1)
        
        print("transformations shape", flatten_transformations.shape)
        print("references shape", references.shape)
        data = torch.cat((references, flatten_transformations), dim=-1)
        print("data shape", data.shape)
        
        return data
    
    def _untransform(self, x):
        
        transformations = torch.split(x, 4, dim=-1) 

        print(transformations.shape)
        rotation, translation = torch.split(transformation, 4, dim=-1)
        # Apply the rotation to point A
        A_rotated = quaternion_apply(rotation, references.unsqueeze(1).expand(-1, n_joints, -1))

        # Compute the translation vector
        translation = targets - A_rotated

        # Apply the translation to point A
        A_transformed = A_rotated + translation
        
        
    def untransform_and_unscale(self):
        x = self._untransform(self.x)
        x = self._unscale(x)
        return x

def make_joint_dataset(device: torch.device):
    with open("data/unlabelled/camera/joints/front_sit_stand.csv") as f:
        data = pd.read_csv(f, header=0)
        data = data.rename(columns={"x-axis": "x", "y-axis": "y", "z-axis": "z", "joint_names": "joint_name"})
        # print(data.head())

    # Get the unique joint names and frame IDs
    joint_names = np.sort(data["joint_name"].unique())
    frame_ids = data["frame_id"].unique()

    # Create a multi-index using 'frame_id' and 'joint_names'
    data.set_index(['frame_id', 'joint_name'], inplace=True)

    # Sort the index to ensure the data is in the correct order
    data.sort_index(inplace=True)

    # Convert the DataFrame to a NumPy array and reshape it
    x = data[['x', 'y', 'z']].values.reshape((frame_ids.size, joint_names.size, 3))

    # Convert the NumPy array to a PyTorch tensor
    x = torch.from_numpy(x).float().to(device)

    dataset = JointDataset(joint_names=joint_names, x=x)
    return dataset

Using device cpu


In [48]:
trainings_data = make_joint_dataset(device)

targets shape torch.Size([2520, 31, 3])
references shape torch.Size([2520, 3])
transformations shape torch.Size([2520, 217])
references shape torch.Size([2520, 3])
data shape torch.Size([2520, 220])
