In [1]:
import numpy as np
import matplotlib.pyplot as plt
from abc import ABC

%matplotlib qt

class FigureBase(ABC):
    """Base class for `display_figure` option for .core.video.VideoPoseProcessor()

    Holds a matplotlib figure/axes pair (`self.fig`, `self.ax`) and a registry
    of the artists currently drawn (`self.artists`, keyed by caller-chosen
    keys) so that they can be removed together before the next redraw.
    """
    def __init__(self,  ax=None, fig=None, figsize=None,
                 add_subplot=None, axes_3d=False, **kwargs):
        """Create or adopt the figure and axes.

        ax: existing axes to draw on; when given, all other arguments are ignored
        fig: existing figure to add a subplot to (used only when `ax` is None)
        figsize: size passed to `plt.figure` when a new figure is created
        add_subplot: positional arguments for `fig.add_subplot`, default (1, 1, 1)
        axes_3d: create the subplot with `projection='3d'`
        **kwargs: accepted and ignored, so subclasses can forward their whole
            keyword dictionary without filtering it first
        """
        self.artists = {}
        self.setup_figure(ax=ax, fig=fig, figsize=figsize,
                          add_subplot=add_subplot, axes_3d=axes_3d)

    def setup_figure(self, ax=None, fig=None, figsize=None, add_subplot=None, axes_3d=False):
        """Set `self.fig` and `self.ax`; see `__init__` for the arguments."""
        if ax is None:
            if fig is None:
                # No figure supplied: make a new one with constrained layout.
                self.fig = plt.figure(figsize=figsize)
                self.fig.set_constrained_layout(True)
            else:
                self.fig = fig
            add_subplot = (1, 1, 1) if add_subplot is None else add_subplot
            add_subplot_kwargs = dict(projection='3d') if axes_3d else {}
            self.ax = self.fig.add_subplot(*add_subplot, **add_subplot_kwargs)
        else:
            # Adopt the caller's axes and look up the figure that owns them.
            self.ax = ax
            self.fig = ax.get_figure()

    def clear_artists(self):
        """Remove all artists from plot and forget them.

        Artists that are already detached from their axes are skipped. The
        registry is emptied afterwards so that stale entries are not removed
        again on the next call.
        """
        for artist in list(self.artists.values()):
            try:
                artist.remove()
            except (NotImplementedError, ValueError):
                # Raised when the artist is no longer attached to an axes
                # (e.g. it was removed elsewhere); nothing left to do.
                # Unlike a bare `except`, this lets KeyboardInterrupt through.
                pass
        self.artists.clear()

    def add(self, artist_key, artist_object):
        """Add artist to plot

        Registers `artist_object` under `artist_key`; an existing entry with
        the same key is replaced in the registry (it is not removed from the
        axes here).
        """
        self.artists[artist_key] = artist_object


class LandmarkFigure(FigureBase):
    """3D figure that draws landmarks and the connections between them.

    Data coordinates are permuted when drawn: data x goes to the matplotlib
    z axis, data y to the matplotlib x axis and data z to the matplotlib
    y axis (see `set_axis` and `add_plot_func`).
    """
    def __init__(self, pose_connections=[], **kwargs):
        """pose_connections: list of connected pairs of landmark index

        **kwargs: forwarded both to `FigureBase.__init__` (figure/axes
            options) and to `set_axis` (axis_limits, tick_interval, view);
            `axes_3d` is always forced to True.
        """
        kwargs = kwargs | dict(axes_3d=True)
        super().__init__(**kwargs)
        # Copy each pair into a list so it can be used for numpy fancy indexing.
        self.conns = [list(conn) for conn in pose_connections]
        self.set_axis(**kwargs)
        self.add_plot_func('plot')

    def set_axis(self, axis_limits=[[-1., 1.]] * 3, tick_interval=0.5, view=(10, -90), **kwargs):
        """Setting 3D axes

        axis_limits: [min, max] for data x, y, z, in that order
        tick_interval: currently unused (only the disabled tick code below reads it)
        view: (elev, azim) or (elev, azim, roll) for `view_init`; missing
            entries are passed as None
        **kwargs: ignored, so `__init__` can forward its full keyword dictionary
        """
        # ticks = []
        # for lim in axis_limits:
        #     ticks_range = np.asarray(lim) / tick_interval
        #     ticks.append(np.arange(np.ceil(ticks_range[0]), np.floor(ticks_range[1]) + 1) * tick_interval)
        ax = self.ax
        # Labels and limits follow the same permutation as `add_plot_func`:
        # matplotlib x <- data y, matplotlib z <- data x, matplotlib y <- data z.
        ax.set_xlabel('y')
        ax.set_zlabel('x')  # swap y and z, y downward
        ax.set_ylabel('z')
        ax.set_xlim(axis_limits[1])
        ax.set_zlim(axis_limits[0])
        ax.set_ylim(axis_limits[2])
        # ax.set_xticks(ticks[1])
        # ax.set_zticks(ticks[0])
        # ax.set_yticks(ticks[2])
        view = tuple(view) + (None,) * (3 - len(view))
        ax.view_init(elev=view[0], azim=view[1], roll=view[2])

    def add_plot_func(self, func_name):
        """Decorate axes plotting function to apply transform of axes

        Looks up `self.ax.<func_name>` and exposes a wrapper with the same
        name on this object that takes data (x, y, z) and calls the axes
        method with the arguments reordered to (y, z, x).
        """
        axes_func = getattr(self.ax, func_name)
        def func(x, y, z, *args, **kwargs):
            return axes_func(y, z, x, *args, **kwargs)
        setattr(self, func_name, func)

    def plot_landmarks(self, pos, vis_idx=None):
        """Plot landmarks and connections. Need to be called before adding other artists.
        pos: 3d coordinates of landmarks
        vis_idx: boolean array of visible landmarks

        All previously registered artists are removed first. A connection is
        drawn only when both of its landmarks are visible.
        """
        self.clear_artists()
        pos = np.asarray(pos)
        if vis_idx is None:
            visible = np.ones(len(pos), dtype=bool)
        else:
            # Accept any boolean sequence, not only numpy arrays.
            visible = np.asarray(vis_idx, dtype=bool)
        self.add('landmarks', self.plot(*pos[visible].T, 'b.')[0])
        for i, conn in enumerate(self.conns):
            if visible[conn].all():
                # Index the unfiltered positions: connection indices refer to
                # the original landmark order, not to the visible subset.
                self.add(i, self.plot(*pos[conn].T, 'k', linewidth=.5)[0])


#### Test SVD performance

In [2]:
# from scipy.linalg import svd

# X = lambda *_: np.random.rand(4, 4)

In [3]:
# %%timeit
# _, _, v = svd(X(), full_matrices=False, overwrite_a=True, check_finite=False, lapack_driver='gesvd')

In [4]:
# %%timeit
# _, _, v = svd(X(), full_matrices=False, overwrite_a=True, check_finite=False, lapack_driver='gesdd')

#### Display reconstruction results

In [5]:
# Load the two reconstructions to compare. ndmin=2 keeps a file holding a
# single row two-dimensional, so the reshape below still splits it per row.
kpts_A = np.loadtxt('kpts_3d_A.dat', ndmin=2)  # svd on A
kpts_B = np.loadtxt('kpts_3d_B.dat', ndmin=2)  # svd on A.T @ A
# Largest absolute element-wise deviation. A signed maximum would under-report
# the mismatch wherever kpts_B is larger than kpts_A.
print(np.max(np.abs(kpts_A - kpts_B)))

# Display reconstruction A: keep the first axis and split every row into 3D points.
kpts = kpts_A.reshape((kpts_A.shape[0], -1, 3))
# One [min, max] row per coordinate, taken over the first two axes -> shape (3, 2).
axis_limits = np.column_stack([np.min(kpts, axis=(0, 1)), np.max(kpts, axis=(0, 1))])

3.268496584496461e-13


In [6]:
from bodypose3d import pose_keypoints

print(pose_keypoints)
# Landmark ids in ascending order; a landmark's position in this list is the
# index used for it in `connections` below.
# NOTE(review): presumably the rows of the keypoint data follow this same
# ascending order - confirm against the code that wrote the .dat files.
kpt_ids = sorted(pose_keypoints)

# Connected landmark pairs, written with the ids found in `pose_keypoints`.
connections = [
    [11, 12], [11, 23], [12, 24], [23, 24],
    [11, 13], [13, 15], [12, 14], [14, 16],
    [23, 25], [25, 27], [24, 26], [26, 28],
]
# Translate every id into its position within kpt_ids.
connections = [[kpt_ids.index(start), kpt_ids.index(end)]
               for start, end in connections]

# (elev, azim) viewing angles for the 3D axes.
view = (15, -105)

[16, 14, 12, 11, 13, 15, 24, 23, 25, 26, 27, 28]


In [7]:
# Build the 3D landmark figure from the connections, view angles and axis
# limits computed in the cells above.
lm_fig = LandmarkFigure(
    pose_connections=connections,
    axis_limits=axis_limits,
    view=view,
)

# Show the window, then switch matplotlib to interactive mode.
plt.show()
plt.ion()

<contextlib.ExitStack at 0x1c4722be690>

In [8]:
# Stop flag for the animation loop; set by the event handlers below.
fig_closed = False
def close_fig(event):
    """`close_event` handler: request a stop when the figure window is closed."""
    global fig_closed
    fig_closed = True

def fig_key(event):
    """`key_press_event` handler: request a stop when 'q' is pressed.

    The flag is only ever set, never cleared, so a later press of another
    key cannot cancel a stop that was already requested.
    """
    global fig_closed
    if event.key == 'q':
        fig_closed = True

# Hook the stop handlers up to the figure's canvas.
canvas = lm_fig.fig.canvas
canvas.mpl_connect('key_press_event', fig_key)
canvas.mpl_connect('close_event', close_fig)

# Play the sequence back: draw one entry of `kpts` per iteration until the
# handlers set the stop flag.
for pt in kpts:
    lm_fig.plot_landmarks(pt)
    # plt.pause redraws the figure and runs the GUI event loop, which is
    # what lets the handlers above fire; no explicit plt.draw() is needed.
    plt.pause(0.1)
    if fig_closed:
        break