In [2]:
# Render matplotlib figures inline and auto-reload edited modules
#@formatter:off
%matplotlib inline
%load_ext autoreload
%autoreload 2
#@formatter:on

import torch
from torch import nn
from torch import Tensor
import networkx as nx
from itertools import combinations

import pandas as pd
from datasets.RSO_LModule import RSO_LModule
import os

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
DEBUG = False
SAMPLE_RATE = 30.0

# SAMPLE_RATE * 60, i.e. one minute of data at SAMPLE_RATE
plot_length = int(60 * SAMPLE_RATE)

# Use the local storage copy when that directory exists; otherwise fall back to the share.
local_path = "/home/yale1/ma-jonah-data/"
base_path = local_path if os.path.exists(local_path) else "/share/data/yhartmann/data/ma-jonah/"
print(base_path)


/home/yale1/ma-jonah-data/


# Load Data

In [7]:
# Build the project data module on top of base_path.
data_module = RSO_LModule(data_dir=base_path, batch_size=1, n_jobs=1, debug=DEBUG, in_mem=False)
# NOTE(review): `_load_datasets` is underscore-prefixed, i.e. private by convention;
# `sessions=[1]` presumably restricts loading to session 1 — confirm against RSO_LModule.
all_datasets = data_module._load_datasets(sessions=[1])
# Concatenate the `optitrack_data` frame of every loaded dataset into a single DataFrame.
all_optitrack_data = pd.concat([d.optitrack_data for d in all_datasets])

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:06<00:00,  6.66s/it]


In [8]:
# Show the concatenated tracking data (pandas truncates the rendered rows and columns).
all_optitrack_data

Unnamed: 0_level_0,Ab,Ab,Ab,Chest,Chest,Chest,Head,Head,Head,Hip,...,RShoulder,RThigh,RThigh,RThigh,RToe,RToe,RToe,RUArm,RUArm,RUArm
Unnamed: 0_level_1,Position,Position,Position,Position,Position,Position,Position,Position,Position,Position,...,Position,Position,Position,Position,Position,Position,Position,Position,Position,Position
Frame,X,Y,Z,X,Y,Z,X,Y,Z,X,...,Z,X,Y,Z,X,Y,Z,X,Y,Z
0,34.911232,102.498116,16.107584,35.102566,117.664268,18.061092,34.254826,147.644104,15.279250,34.697842,...,21.951651,43.865959,94.707199,16.072893,50.566929,4.496662,14.127948,54.588383,133.739410,20.517035
1,34.911255,102.495071,16.100935,35.107567,117.658165,18.054359,34.251488,147.640305,15.268892,34.697655,...,21.939573,43.865757,94.703903,16.067360,50.564671,4.498981,14.126810,54.595005,133.734116,20.504158
2,34.909985,102.494904,16.090366,35.108509,117.658875,18.037148,34.243488,147.634247,15.236509,34.696449,...,21.919186,43.864536,94.703789,16.058727,50.565918,4.498118,14.127122,54.599728,133.731659,20.484558
3,34.906654,102.494957,16.074156,35.108437,117.658401,18.013845,34.242786,147.629715,15.194829,34.692921,...,21.888792,43.860931,94.703514,16.041416,50.564911,4.497473,14.125580,54.606785,133.724533,20.458763
4,34.904053,102.495766,16.053795,35.104115,117.659447,17.983305,34.227776,147.625946,15.145537,34.689236,...,21.849388,43.857185,94.702995,16.025818,50.565292,4.495665,14.126169,54.610489,133.721817,20.424810
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
26694,2.572512,99.832474,18.930128,2.874780,114.380394,21.297070,2.113917,144.455612,19.022125,2.678054,...,23.956131,11.365096,92.335617,16.214571,18.782452,4.465117,31.220150,22.468884,128.610260,18.672184
26695,1.143206,100.090553,16.349379,1.227071,114.649902,18.697294,0.087116,144.747742,16.583263,1.240797,...,21.425528,9.943723,92.565964,13.738190,18.770275,4.695205,30.950472,20.674820,129.070221,16.253668
26696,-0.411177,100.690475,13.786022,-0.474266,115.220810,16.118238,-1.935193,145.310287,14.152716,-0.304639,...,18.942808,8.396120,93.172768,11.177876,18.728424,5.116483,30.399231,18.870667,129.872849,13.953614
26697,-2.095114,101.285912,11.412931,-2.163943,115.822411,13.650352,-3.830245,145.912155,11.820728,-1.959797,...,16.583031,6.730126,93.811684,8.741262,18.694496,5.645886,29.356703,17.096020,130.702484,11.786323


# Reference implementation

In [9]:
# Maps every joint name to the name of its parent joint.
# 'Hip' maps to itself, which makes it the root of the hierarchy.
# The insertion order of the keys defines the integer joint index used by
# get_joint_index (Ab -> 0, Chest -> 1, Head -> 2, Hip -> 3, ...).
parents = {
    'Ab': 'Hip',
    'Chest': 'Ab',
    'Head': 'Neck',
    'Hip': 'Hip',
    'LFArm': 'LUArm',
    'LFoot': 'LShin',
    'LHand': 'LFArm',
    'LShin': 'LThigh',
    'LShoulder': 'Chest',
    'LThigh': 'Hip',
    'LToe': 'LFoot',
    'LUArm': 'LShoulder',
    'Neck': 'Chest',
    'RFArm': 'RUArm',
    'RFoot': 'RShin',
    'RHand': 'RFArm',
    'RShin': 'RThigh',
    'RShoulder': 'Chest',
    'RThigh': 'Hip',
    'RToe': 'RFoot',
    'RUArm': 'RShoulder'
}


def get_joint_index(joint_name: str):
    """Return the position of ``joint_name`` among the keys of ``parents``.

    Raises:
        ValueError: if ``joint_name`` is not a key of ``parents``.
    """
    ordered_names = [*parents]
    return ordered_names.index(joint_name)


# Integer view of the hierarchy: joint index -> parent joint index.
parents_index = {
    get_joint_index(child): get_joint_index(parent)
    for child, parent in parents.items()
}

# Undirected graphs with one edge per (joint, parent) entry, once keyed by
# joint name and once keyed by joint index.
graph = nx.Graph()
graph.add_edges_from(parents.items())
graph_index = nx.Graph()
graph_index.add_edges_from(parents_index.items())

# source -> target -> list of nodes on the shortest path (both endpoints included).
all_shortest_paths = dict(nx.all_pairs_shortest_path(graph))
all_shortest_paths_index = dict(nx.all_pairs_shortest_path(graph_index))


def sgn(path, m: int):
    """Sign of the step ``path[m] -> path[m + 1]``.

    Returns 1 when ``path[m + 1]`` is the parent of ``path[m]`` according to
    ``parents_index``, and -1 otherwise.
    """
    parent_of_current = parents_index[path[m]]
    if parent_of_current == path[m + 1]:
        return 1
    return -1


def get_ith_bone(input: Tensor, i: int) -> Tensor:
    """Extract the three consecutive columns that belong to entry ``i``.

    Args:
        input: 2-D tensor whose columns are laid out in groups of three
            (columns ``3*i``, ``3*i + 1``, ``3*i + 2`` belong to entry ``i``).
        i: index of the group to extract.

    Returns:
        A new tensor of shape ``(input.shape[0], 3)`` holding those columns
        for every row.
    """
    batch_size = input.shape[0]
    start = 3 * i
    # Build the index on the same device as `input`; `gather` rejects an
    # index tensor that lives on a different device.
    bone_index = torch.tensor(
        [[start, start + 1, start + 2]], device=input.device
    ).expand(batch_size, -1)
    return input.gather(1, bone_index)


def get_long_range_relative_position(input: Tensor, joint_u: int, joint_v: int) -> Tensor:
    """Sum the signed 3-vectors along the shortest path from ``joint_u`` to ``joint_v``.

    For every step ``path[m] -> path[m + 1]`` the 3-vector stored in ``input``
    for ``path[m]`` is added with the sign given by ``sgn(path, m)``; the
    vector of the final node of the path is not included.

    Args:
        input: 2-D tensor with three consecutive columns per joint.
        joint_u: index of the joint the path starts at.
        joint_v: index of the joint the path ends at.

    Returns:
        Tensor of shape ``(input.shape[0], 3)``; all zeros when the path has
        a single node.
    """
    path = all_shortest_paths_index[joint_u][joint_v]
    # Accumulate on the input's device and never in a lower precision than the
    # input: a default-dtype accumulator would silently down-cast float64
    # contributions during the in-place add.
    acc_dtype = torch.promote_types(input.dtype, torch.get_default_dtype())
    result = torch.zeros([input.shape[0], 3], dtype=acc_dtype, device=input.device)
    for m in range(len(path) - 1):
        bone_m = get_ith_bone(input, path[m])
        result += sgn(path, m) * bone_m
    return result


def compose_output(outputs: Tensor) -> Tensor:
    """Accumulate per-joint 3-vectors along the path from joint 3 to each joint.

    Every row of ``outputs`` is split into consecutive groups of three values
    (one group per joint). For joint ``i`` the result is the sum of the groups
    of all joints on the shortest path from joint 3 to joint ``i``, both
    endpoints included.

    Args:
        outputs: 2-D tensor, one row per sample, three columns per joint.

    Returns:
        Tensor with the same shape as ``outputs`` holding the accumulated
        vectors.
    """
    # Index 3 is 'Hip' in the key order of `parents`, the joint that is its own parent.
    root_joint = 3
    # Allocate on the input's device and in at least the input's precision so the
    # in-place adds below neither fail on a device mismatch nor down-cast float64.
    acc_dtype = torch.promote_types(outputs.dtype, torch.get_default_dtype())
    results = torch.zeros(outputs.shape, dtype=acc_dtype, device=outputs.device)
    for output, result in zip(outputs, results):
        joints = output.split(3)
        for i in range(len(joints)):
            # `result` is a view into `results`, so the slice update writes through.
            for bone_path_index in all_shortest_paths_index[root_joint][i]:
                result[i*3:(i+1)*3] += joints[bone_path_index]
    return results
  

class CompositionalLoss_ref(nn.Module):
    """Straightforward (slow) reference implementation of the compositional loss.

    For every unordered pair of joint indices the path-composed relative
    position computed from ``input`` is compared with the direct difference of
    the two joints in ``target`` using an L1 norm per sample.

    Args:
        reduction: ``'mean'`` averages the per-sample norms of each pair before
            adding them to the loss; any other value adds their sum instead.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.P = list(combinations(parents_index.keys(), r=2))
        self.avg_reduction = reduction == 'mean'

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """Return the loss as a tensor of shape ``(1,)``."""
        total = torch.zeros(1)

        for joint_u, joint_v in self.P:
            composed = get_long_range_relative_position(input, int(joint_u), int(joint_v))
            direct = get_ith_bone(target, joint_u) - get_ith_bone(target, joint_v)
            residual = composed - direct

            # One L1 norm per sample, computed row by row on purpose: this class
            # is the unoptimised baseline.
            n_samples = residual.shape[0]
            per_sample = torch.zeros(n_samples)
            for row in range(n_samples):
                per_sample[row] = torch.linalg.norm(residual[row], ord=1)

            total += per_sample.mean() if self.avg_reduction else per_sample.sum()
        return total

# New Implementation

In [10]:
# list(combinations(parents_index.keys(), r=2))

In [11]:
# Inspect the precomputed paths: source joint index -> target joint index -> joint indices on the path.
all_shortest_paths_index

{0: {0: [0],
  3: [0, 3],
  1: [0, 1],
  9: [0, 3, 9],
  18: [0, 3, 18],
  8: [0, 1, 8],
  12: [0, 1, 12],
  17: [0, 1, 17],
  7: [0, 3, 9, 7],
  16: [0, 3, 18, 16],
  11: [0, 1, 8, 11],
  2: [0, 1, 12, 2],
  20: [0, 1, 17, 20],
  5: [0, 3, 9, 7, 5],
  14: [0, 3, 18, 16, 14],
  4: [0, 1, 8, 11, 4],
  13: [0, 1, 17, 20, 13],
  10: [0, 3, 9, 7, 5, 10],
  19: [0, 3, 18, 16, 14, 19],
  6: [0, 1, 8, 11, 4, 6],
  15: [0, 1, 17, 20, 13, 15]},
 3: {3: [3],
  0: [3, 0],
  9: [3, 9],
  18: [3, 18],
  1: [3, 0, 1],
  7: [3, 9, 7],
  16: [3, 18, 16],
  8: [3, 0, 1, 8],
  12: [3, 0, 1, 12],
  17: [3, 0, 1, 17],
  5: [3, 9, 7, 5],
  14: [3, 18, 16, 14],
  11: [3, 0, 1, 8, 11],
  2: [3, 0, 1, 12, 2],
  20: [3, 0, 1, 17, 20],
  10: [3, 9, 7, 5, 10],
  19: [3, 18, 16, 14, 19],
  4: [3, 0, 1, 8, 11, 4],
  13: [3, 0, 1, 17, 20, 13],
  6: [3, 0, 1, 8, 11, 4, 6],
  15: [3, 0, 1, 17, 20, 13, 15]},
 1: {1: [1],
  0: [1, 0],
  8: [1, 8],
  12: [1, 12],
  17: [1, 17],
  3: [1, 0, 3],
  11: [1, 8, 11],
  2: [1,

In [24]:
class CompositionalLoss_new(nn.Module):
    """Batched implementation of the compositional loss.

    For every unordered pair of joints ``(u, v)`` the signed 3-vectors stored
    in ``input`` for the joints on the shortest path from ``u`` to ``v``
    (the final joint excluded) are summed and compared, with an L1 norm per
    sample, against ``target[u] - target[v]``.

    Both tensors are expected to be 2-D with three consecutive columns per
    joint, in the key order of the ``parents`` table defined in ``__init__``.

    Args:
        reduction: ``'mean'`` averages the per-sample norms of each pair over
            the batch and then sums over pairs; any other value sums every norm.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        # Joint name -> parent joint name; 'Hip' is its own parent (the root).
        parents = {
            'Ab': 'Hip',
            'Chest': 'Ab',
            'Head': 'Neck',
            'Hip': 'Hip',
            'LFArm': 'LUArm',
            'LFoot': 'LShin',
            'LHand': 'LFArm',
            'LShin': 'LThigh',
            'LShoulder': 'Chest',
            'LThigh': 'Hip',
            'LToe': 'LFoot',
            'LUArm': 'LShoulder',
            'Neck': 'Chest',
            'RFArm': 'RUArm',
            'RFoot': 'RShin',
            'RHand': 'RFArm',
            'RShin': 'RThigh',
            'RShoulder': 'Chest',
            'RThigh': 'Hip',
            'RToe': 'RFoot',
            'RUArm': 'RShoulder'
        }

        # Build the name list once instead of once per joint.
        joint_names = list(parents.keys())
        parents_index = [joint_names.index(parents[x]) for x in joint_names]
        # Attribute name kept for compatibility; the value is a list indexed by
        # joint, holding the parent's joint index.
        self.parents_index_dict = parents_index

        graph_index = nx.Graph()
        graph_index.add_edges_from([(x, parents_index[x]) for x in range(len(parents_index))])
        # source -> target -> list of joint indices on the shortest path.
        self.all_shortest_paths_index = dict(nx.all_pairs_shortest_path(graph_index))

        self.P = list(combinations(range(len(parents)), r=2))
        self.avg_reduction = reduction == 'mean'

    def sgn(self, path, m: int):
        """Return 1 if ``path[m + 1]`` is the parent of ``path[m]``, else -1."""
        return 1 if self.parents_index_dict[path[m]] == path[m + 1] else -1

    @staticmethod
    def get_ith_bone_idx(i: int):
        """Return the three column indices that belong to joint ``i``."""
        start_index = i * 3
        return [start_index, start_index + 1, start_index + 2]

    def get_long_range_relative_position(self, input: Tensor, joint_u: int, joint_v: int) -> Tensor:
        """Sum the signed 3-vectors along the shortest path ``joint_u -> joint_v``.

        Returns:
            Tensor of shape ``(input.shape[0], 3)``; all zeros when both joints
            are the same.
        """
        path = self.all_shortest_paths_index[joint_u][joint_v]
        if len(path) < 2:
            # A single-node path has no steps. Without this guard the empty
            # index would yield a result of shape (batch,) instead of (batch, 3).
            return input.new_zeros((input.shape[0], 3))
        idx = [self.get_ith_bone_idx(i) for i in path[:-1]]
        # One sign per step, shaped (steps, 1) so it broadcasts over the three
        # coordinates; created on the input's device to avoid a device mismatch.
        sign = torch.tensor(
            [[self.sgn(path, m)] for m in range(len(path) - 1)],
            device=input.device,
        )
        # input[:, idx] has shape (batch, steps, 3); summing over dim 1 collapses the steps.
        return (input[:, idx] * sign).sum(dim=1)

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """Return the loss as a 0-dim tensor."""
        delta_batch = []

        # Todo: consider pre-computing the index and sign for each pair of joints as the pairs are already determined on init
        for u, v in self.P:
            delta_j = self.get_long_range_relative_position(input, int(u), int(v))
            delta_j_gt = target[:, self.get_ith_bone_idx(u)] - target[:, self.get_ith_bone_idx(v)]
            delta_batch.append(delta_j - delta_j_gt)

        # stacked shape: (num_pairs, batch, 3)
        stacked = torch.stack(delta_batch, dim=0)
        # norm shape: (num_pairs, batch)
        delta_norms = torch.linalg.norm(stacked, ord=1, dim=2)

        if self.avg_reduction:
            # mean over the batch per pair, then sum over pairs
            return delta_norms.mean(dim=1).sum()
        return delta_norms.sum()


# Calculate Compositional Loss

In [13]:
# Convert the concatenated DataFrame to a tensor; torch.from_numpy shares memory
# with the NumPy array returned by to_numpy().
input_tensor = torch.from_numpy(all_optitrack_data.to_numpy())
# Display the shape as a quick check (rows x columns).
input_tensor.shape

torch.Size([26187, 63])

In [25]:
%%time

# Time the batched loss, passing the same tensor as input and as target.
new_loss = CompositionalLoss_new()
new_loss(input_tensor, input_tensor)

CPU times: user 22 s, sys: 162 ms, total: 22.2 s
Wall time: 926 ms


tensor(50275.8516)

In [15]:
# `break` outside a loop is a SyntaxError; presumably used to halt "Run All" before the cells below.
break

SyntaxError: 'break' outside loop (668683560.py, line 1)

In [None]:
%%time

# Same run with both tensors cast to float16; the recorded result below is `inf`.
new_loss = CompositionalLoss_new()
new_loss(input_tensor.to(dtype=torch.float16), input_tensor.to(dtype=torch.float16))

CPU times: user 24.1 s, sys: 329 ms, total: 24.4 s
Wall time: 1.07 s


tensor(inf, dtype=torch.float16)

In [None]:
%%time

# Same run with both tensors cast to float64.
new_loss = CompositionalLoss_new()
new_loss(input_tensor.double(), input_tensor.double())

CPU times: user 22.1 s, sys: 550 ms, total: 22.7 s
Wall time: 985 ms


tensor(105003.9251, dtype=torch.float64)

In [None]:
# `break` outside a loop is a SyntaxError; presumably used to halt "Run All" before the slow reference run below.
break

In [None]:
%%time

# Time the reference implementation on the same tensor, for comparison of value and runtime.
reference_loss = CompositionalLoss_ref()
reference_loss(input_tensor, input_tensor)

CPU times: user 16min 29s, sys: 31.1 s, total: 17min
Wall time: 56.4 s


tensor([105003.9062])