# 3D Pose Viz

In [None]:
import pyglet
from pyglet.gl import *
from pyglet.window import key
from pyglet import shapes, text

import numpy as np
import cv2

import os
import json
import math
import yaml
from pathlib import Path
from collections import Counter

%reload_ext autoreload
%autoreload 2

## Set the environment parameters

In [None]:
# --- Behaviour toggles ------------------------------------------------------
# Play the animation automatically.
AUTO_PLAY = False
# Overlay the keystroke help text.
SHOW_INFO = True
# Record the sample to a video.
RECORD_MODE = False

# --- Rendering --------------------------------------------------------------
WINDOW_WIDTH = 1920
WINDOW_HEIGHT = 1080
RENDER_FPS = 50
RENDER_INTERVAL = 1 / RENDER_FPS  # seconds between two rendered frames

# --- Input sample / output video --------------------------------------------
# Both assignments are kept as a manual switch between samples: the last one
# is the one that takes effect.
FILE_PATH = "./samples/running_pose_smooth.json"
FILE_PATH = "./samples/throwing_pose_smooth.json"

# The video path is the sample path with its extension replaced by ".mp4".
# NOTE(review): the writer is constructed as soon as this cell runs, whether or
# not RECORD_MODE is ever enabled.
file_root, _ = os.path.splitext(FILE_PATH)
VIDEO_WRITER = cv2.VideoWriter(
    file_root + ".mp4",
    cv2.VideoWriter_fourcc(*'mp4v'),
    RENDER_FPS,
    (WINDOW_WIDTH, WINDOW_HEIGHT),
)



In [None]:
# Joint names in the order the sample files store them: a joint's position in
# this array is its index into each keypoint group of a pose.
ASPSET_KEYPOINT_NAMES = np.array([
    'right_ankle', 'right_knee', 'right_hip', 'right_wrist', 'right_elbow', 'right_shoulder',
    'left_ankle', 'left_knee', 'left_hip', 'left_wrist', 'left_elbow', 'left_shoulder',
    'head_top', 'head', 'neck', 'spine', 'pelvis', 'right_toe_base', 'right_heel', 'left_toe_base', 'left_heel'
])

# Pairs of joint names that are connected by a limb segment when drawing.
ASPSET_JOINT_PAIRS = np.array([
    ['right_ankle', 'right_knee'],
    ['right_knee', 'right_hip'],
    ['right_hip', 'pelvis'],
    ['left_hip', 'pelvis'],
    ['left_knee', 'left_hip'],
    ['left_ankle', 'left_knee'],
    ['pelvis', 'spine'],
    ['spine', 'neck'],
    ['neck', 'head'],
    ['head', 'head_top'],
    ['right_wrist', 'right_elbow'],
    ['right_elbow', 'right_shoulder'],
    ['right_shoulder', 'neck'],
    ['left_shoulder', 'neck'],
    ['left_wrist', 'left_elbow'],
    ['left_elbow', 'left_shoulder'],
    ['right_ankle', 'right_heel'],
    ['right_heel', 'right_toe_base'],
    ['left_ankle', 'left_heel'],
    ['left_heel', 'left_toe_base'],
])

# Index into pose_data of the pose currently being shown.
frame = 0

# Parse straight from the file handle; the explicit encoding keeps the read
# independent of the platform's default text encoding.
with open(FILE_PATH, 'r', encoding='utf-8') as f:
    pose_data = json.load(f)

# Gather the keypoint groups of every pose that has any, so the statistics
# below cover the whole sample.
all_keypoints = []
for pose in pose_data:
    keypoints = pose['data']['data']
    if len(keypoints) > 0:
        all_keypoints.extend(keypoints)

# Without this guard the axis reductions below fail with an opaque numpy
# axis error when the sample contains no keypoints at all.
if not all_keypoints:
    raise ValueError(f"No keypoints found in {FILE_PATH}")

# NOTE(review): assumes every group has the same number of (x, y, z) joints so
# this becomes a regular (groups, joints, 3) array - confirm against the data.
data_np = np.array(all_keypoints)

# Per-joint mean over all groups, then mean/max/min per axis over all joints.
joint_wise_axis_means = data_np.mean(axis=0)
global_axis_means = np.mean(data_np, axis=(0, 1))
global_axis_max = np.max(data_np, axis=(0, 1))
global_axis_min = np.min(data_np, axis=(0, 1))

print("Joint_wise_axis_means (x, y, z):")
print(joint_wise_axis_means)
print("Global_axis_means (x, y, z):")
print(global_axis_means)
print("Global_axis_min (x, y, z):")
print(global_axis_min)
print("Global_axis_max (x, y, z):")
print(global_axis_max)


# Create OpenGL instance

In [None]:
# Base camera offsets; update_camera() scales them by the zoom factor and the
# rotation values below.
camera_position = [5, -15, 5]

# Preset camera settings selectable at runtime.
viewpoints = [{"zoom_factor": 44.1, "rotation_angle_horizontal": 0.23, "rotation_angle_vertical": -0.27},
              {"zoom_factor": 34.2, "rotation_angle_horizontal": 2.54, "rotation_angle_vertical": -0.4},
              {"zoom_factor": 34.2, "rotation_angle_horizontal": -0.46, "rotation_angle_vertical": -3.0},
              {"zoom_factor": 72.7, "rotation_angle_horizontal": -3.12, "rotation_angle_vertical": -0.17}
             ]

# Start from the first preset. The keys are read explicitly so the result does
# not depend on the order in which a preset's entries happen to be written.
viewpoint = viewpoints[0]
zoom_factor = viewpoint["zoom_factor"]
rotation_angle_horizontal = viewpoint["rotation_angle_horizontal"]
rotation_angle_vertical = viewpoint["rotation_angle_vertical"]


SURFACE_LENGTH = 100
SURFACE_WIDTH = 100

MAX_VERTICAL_ANGLE = math.pi / 2 - 0.1  # Just below straight up
MIN_VERTICAL_ANGLE = -MAX_VERTICAL_ANGLE  # Just above straight down

window = pyglet.window.Window(width=WINDOW_WIDTH, height=WINDOW_HEIGHT, resizable=False)

# Capabilities are per-context OpenGL state, so they are enabled only after the
# window (and with it the context that will actually be rendered to) exists.
glEnable(GL_DEPTH_TEST)
glEnable(GL_LINE_SMOOTH)


# Sample file name, centred near the top edge of the window.
sample_label = pyglet.text.Label(
    os.path.basename(FILE_PATH),
    font_name='Arial',
    font_size=20,
    x= window.width // 2,
    y= window.height - 30,
    anchor_x='center',
    anchor_y='center',
    color=(255, 255, 255, 255)
)

# Keyboard help, anchored to the bottom-left corner.
instructions_label = pyglet.text.Label(
    "Play/Pause: SPACE\nFrame Forward: RIGHT ARROW\nFrame Backward: LEFT ARROW\n"
    "Step Forward: SHIFT+RIGHT ARROW\nStep Backward: SHIFT+LEFT ARROW\n"
    "Record Video: R\nViewpoints: [0-3]\nClose: ESC",
    font_size=12,
    x=10,
    y=10,
    multiline=True,
    width=400,
    anchor_x="left",
    anchor_y="bottom",
)

def time_label_with_value(t):
    """Build the bottom-right "Rally Time" label for the value ``t``.

    ``t`` is formatted with two decimals. A new label object is created on
    every call.
    """
    caption = f"Rally Time: {t:.2f}"
    # Anchored by its bottom-right corner, 10 px in from the window's right
    # and bottom edges.
    placement = dict(
        x=window.width - 10,
        y=10,
        anchor_x="right",
        anchor_y="bottom",
    )
    return text.Label(
        caption,
        font_name='Arial',
        font_size=20,
        multiline=True,
        width=300,
        **placement,
    )
    
def draw_surface():
    """Draw the ground grid on the z = 0 plane as blended line segments.

    Lines are spaced a quarter of SURFACE_WIDTH apart and extend to twice
    SURFACE_LENGTH on either side of the origin. Each line's alpha falls
    linearly with its distance from the origin and reaches zero at twice
    SURFACE_LENGTH.
    """
    spacing = SURFACE_WIDTH / 4
    reach = SURFACE_LENGTH * 2
    fade = SURFACE_LENGTH
    steps = int(reach / spacing)
    far = float(reach)

    def line_alpha(offset):
        # Never negative; for lines closer than `fade` the value exceeds 1.0.
        return max(0.0, 1.0 - (abs(offset) - fade) / fade)

    glColor3f(0.6, 0.6, 0.6)
    glEnable(GL_BLEND)
    glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)

    glBegin(GL_LINES)

    # Lines running along the y axis, one per x offset.
    for column in range(-steps, steps + 1):
        offset = column * spacing
        glColor4f(0.25, 0.25, 0.25, line_alpha(offset))
        glVertex3f(offset, -far, 0)
        glVertex3f(offset, far, 0)

    # Lines running along the x axis, one per y offset.
    for row in range(-steps, steps + 1):
        offset = row * spacing
        glColor4f(0.25, 0.25, 0.25, line_alpha(offset))
        glVertex3f(-far, offset, 0)
        glVertex3f(far, offset, 0)

    glEnd()
    glDisable(GL_BLEND)


class PoseRenderer:
    """Draws the pose of a single frame as limb lines and joint spheres.

    ``pose`` is the frame's list of keypoint groups. Each group is indexed by
    joint in ASPSET_KEYPOINT_NAMES order, and every joint is an (x, y, z)
    triple in data coordinates.
    """

    def __init__(self, pose):
        self.pose = pose
        self.origin = [SURFACE_WIDTH / 2, SURFACE_LENGTH / 2]

    @staticmethod
    def _to_scene(x, y, z):
        """Map a data-space point to scene coordinates.

        Offsets each axis by the global statistics, scales by 1/20 and negates
        y. Limbs and joints both go through this single mapping so they cannot
        drift apart.
        """
        scene_x = (x + global_axis_means[0]) / 20
        scene_y = -((y - global_axis_means[1]) + global_axis_min[1]) / 20
        scene_z = (z - global_axis_means[2]) / 20
        return scene_x, scene_y, scene_z

    def draw_joint_center(self, x, y, z, radius=1.5, slices=16, stacks=16, quadric=None):
        """Draw a yellow sphere at the data-space point (x, y, z).

        Args:
            x, y, z: joint position in data coordinates.
            radius, slices, stacks: forwarded to gluSphere.
            quadric: optional GLU quadric to reuse. When omitted, one is
                created for this call and deleted again before returning.
        """
        owns_quadric = quadric is None
        if owns_quadric:
            quadric = gluNewQuadric()

        glPushMatrix()
        glTranslatef(*self._to_scene(x, y, z))
        glColor3f(1.0, 1.0, 0.0)  # Yellow color for the spheres
        gluSphere(quadric, radius, slices, stacks)
        glPopMatrix()

        if owns_quadric:
            gluDeleteQuadric(quadric)

    def draw_limb_length(self, from_tuple, to_tuple, line_width=2.0):
        """Draw one line segment between two data-space (x, y, z) points.

        The line width is set to ``line_width`` for the segment and reset to
        1.0 afterwards.
        """
        start = self._to_scene(*from_tuple)
        end = self._to_scene(*to_tuple)

        glColor3f(0.5, 0.5, 0.35)
        glLineWidth(line_width)
        glBegin(GL_LINES)
        glVertex3f(*start)
        glVertex3f(*end)
        glEnd()
        glLineWidth(1.0)  # Reset line width to default

    def draw(self):
        """Draw limbs and joints for every keypoint group of the pose.

        Does nothing when the pose has no groups. Limbs are drawn unlit;
        lighting is switched on only around the joint spheres and is switched
        off again even if drawing raises.
        """
        if len(self.pose) == 0:
            return

        keypoint_index = {name: idx for idx, name in enumerate(ASPSET_KEYPOINT_NAMES)}
        limb_indices = [
            (keypoint_index[from_keypoint], keypoint_index[to_keypoint])
            for from_keypoint, to_keypoint in ASPSET_JOINT_PAIRS
        ]

        # Limbs for every group, matching the joints below which are also
        # drawn for every group.
        for pos_group in self.pose:
            # A group lacking some joints cannot be indexed by name; skip its
            # limbs instead of failing the whole frame.
            if len(pos_group) < len(ASPSET_KEYPOINT_NAMES):
                continue
            for from_index, to_index in limb_indices:
                self.draw_limb_length(pos_group[from_index], pos_group[to_index])

        enable_lighting()
        # One quadric is shared by all spheres of this frame.
        quadric = gluNewQuadric()
        try:
            for pos_group in self.pose:
                for pos in pos_group:
                    self.draw_joint_center(pos[0], pos[1], pos[2], quadric=quadric)
        finally:
            gluDeleteQuadric(quadric)
            disable_lighting()

def enable_lighting():
    """Turn on fixed-function lighting with two positional lights.

    GL_LIGHT0 is the brighter light placed on the +z axis; GL_LIGHT1 is a
    dimmer one placed at (-1000, -1000, 1000). Color-material tracking is
    enabled for the front face's ambient and diffuse terms.
    """
    def as_glfloat4(values):
        # glLightfv expects a C array of four GLfloat values.
        return (GLfloat * 4)(*values)

    # (light, diffuse, ambient, position)
    light_setups = (
        (GL_LIGHT0,
         (1.0, 1.0, 1.0, 1.0),
         (0.1, 0.1, 0.1, 1.0),
         (0.0, 0.0, 1000.0, 1.0)),
        (GL_LIGHT1,
         (0.5, 0.5, 0.5, 1.0),
         (0.05, 0.05, 0.05, 1.0),
         (-1000.0, -1000.0, 1000.0, 1.0)),
    )

    glEnable(GL_LIGHTING)

    for light, diffuse, ambient, position in light_setups:
        glEnable(light)
        glLightfv(light, GL_DIFFUSE, as_glfloat4(diffuse))
        glLightfv(light, GL_AMBIENT, as_glfloat4(ambient))
        glLightfv(light, GL_POSITION, as_glfloat4(position))

    glEnable(GL_COLOR_MATERIAL)
    glColorMaterial(GL_FRONT, GL_AMBIENT_AND_DIFFUSE)
    

def disable_lighting():
    """Switch off lighting, both light sources and color-material tracking."""
    for capability in (GL_LIGHTING, GL_LIGHT0, GL_LIGHT1, GL_COLOR_MATERIAL):
        glDisable(capability)
    

In [None]:
# Number keys 0-3 map to the preset viewpoint with the same index.
keystroke_to_index = {
    getattr(pyglet.window.key, f"_{digit}"): digit
    for digit in range(4)
}

def on_mouse_scroll(x, y, scroll_x, scroll_y):
    """Adjust the global zoom factor by 1.1 per vertical scroll step.

    The result is floored at 0.1 so the zoom factor never reaches zero or
    goes negative.
    """
    global zoom_factor
    zoom_factor = max(0.1, zoom_factor + scroll_y * 1.1)
    
def on_mouse_drag(x, y, dx, dy, buttons, modifiers):
    """Turn a mouse drag into changes of the two global rotation values.

    Horizontal movement changes rotation_angle_horizontal and vertical
    movement changes rotation_angle_vertical, both by 0.01 per unit of
    movement.
    """
    global rotation_angle_horizontal, rotation_angle_vertical
    sensitivity = 0.01
    horizontal_delta = dx * sensitivity
    vertical_delta = dy * sensitivity
    rotation_angle_horizontal = rotation_angle_horizontal + horizontal_delta
    rotation_angle_vertical = rotation_angle_vertical + vertical_delta
        
def on_key_press(symbol, modifiers):
    """Handle the viewer's keyboard shortcuts.

    LEFT / RIGHT step one frame (RENDER_FPS frames with SHIFT), clamped to the
    valid frame range. SPACE toggles playback, I toggles the help overlay,
    R toggles recording, ESC or Q closes the window, and the keys listed in
    keystroke_to_index switch to a preset viewpoint.
    """
    global frame
    global zoom_factor, rotation_angle_horizontal, rotation_angle_vertical
    global AUTO_PLAY
    global SHOW_INFO
    global RECORD_MODE

    keys = pyglet.window.key
    # Highest index that can be used to look up a pose; stepping is clamped to
    # it so the frame never points one past the end of pose_data.
    last_frame = max(len(pose_data) - 1, 0)
    step = RENDER_FPS if modifiers & keys.MOD_SHIFT else 1

    if symbol == keys.LEFT:
        frame = max(0, frame - step)

    elif symbol == keys.RIGHT:
        frame = min(last_frame, frame + step)

    elif symbol == keys.SPACE:
        AUTO_PLAY = not AUTO_PLAY
        # Clear any existing schedule first so toggling can never leave
        # update() scheduled twice.
        pyglet.clock.unschedule(update)
        if AUTO_PLAY:
            if frame > last_frame:
                frame = 0
            pyglet.clock.schedule_interval(update, RENDER_INTERVAL)

    elif symbol == keys.I:
        SHOW_INFO = not SHOW_INFO

    elif symbol == keys.R:
        RECORD_MODE = not RECORD_MODE
        if RECORD_MODE:
            # Restart from the first frame with the help overlay hidden.
            SHOW_INFO = False
            AUTO_PLAY = False
            frame = 0
            # Stop interval playback so it does not advance frames behind the
            # recorder's back. While recording, on_draw schedules each frame
            # advance itself after writing the current frame.
            pyglet.clock.unschedule(update)

    elif symbol == keys.ESCAPE or symbol == keys.Q:
        window.close()

    elif symbol in keystroke_to_index:
        viewpoint = viewpoints[keystroke_to_index[symbol]]
        # Read by key so the result does not depend on the order of the
        # preset's entries.
        zoom_factor = viewpoint["zoom_factor"]
        rotation_angle_horizontal = viewpoint["rotation_angle_horizontal"]
        rotation_angle_vertical = viewpoint["rotation_angle_vertical"]

        
def update_camera():
    """Reset the current matrix and aim the camera at the origin.

    The eye position is derived from camera_position, the zoom factor and the
    two rotation values. Its height never drops below 30, and +z is the up
    direction.
    """
    glLoadIdentity()

    base_x, base_y, base_z = camera_position
    eye_x = base_x * zoom_factor * math.cos(rotation_angle_horizontal)
    eye_y = base_y * zoom_factor * math.sin(rotation_angle_horizontal)
    # The vertical value acts as a plain multiplier on the height here.
    eye_z = max(base_z * zoom_factor * -rotation_angle_vertical, 30)

    gluLookAt(
        eye_x, eye_y, eye_z,  # camera position
        0, 0, 0,              # look at the origin
        0, 0, 1,              # up vector
    )

@window.event
def on_draw():
    """Render the current frame and, while recording, append it to the video.

    Draws the ground grid and the pose for the global ``frame`` in a
    perspective view, then the text overlays in a 2D orthographic view. In
    record mode the rendered image is written to VIDEO_WRITER and the next
    frame advance is scheduled; after the last pose has been written the
    writer is released and record mode is switched off.
    """
    global RECORD_MODE

    glClearColor(0.2, 0.2, 0.2, 1)  # Dark grey background color
    glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)

    glMatrixMode(GL_PROJECTION)
    glLoadIdentity()
    gluPerspective(45, window.width / window.height, 10, 10000)
    glMatrixMode(GL_MODELVIEW)

    update_camera()

    # The grid is drawn before the extra model rotation below is applied.
    draw_surface()

    # Rotate 90 degrees around the x-axis before drawing the pose.
    glRotatef(90, 1, 0, 0)

    last_frame = len(pose_data) - 1
    if last_frame >= 0:
        # Clamp so a frame counter that ran past the end cannot raise here.
        keypoints = pose_data[min(frame, last_frame)]['data']['data']
        if len(keypoints) > 0:
            PoseRenderer(keypoints).draw()

    # Switch to 2D mode for on-screen info
    glMatrixMode(GL_PROJECTION)
    glLoadIdentity()
    gluOrtho2D(0, window.width, 0, window.height)
    glMatrixMode(GL_MODELVIEW)
    glLoadIdentity()

    sample_label.draw()
    time_label_with_value(frame/RENDER_FPS).draw()

    if SHOW_INFO:
        instructions_label.draw()

    if RECORD_MODE:
        # Convert the window buffer to a numpy array.
        # NOTE(review): the reshape assumes a 4-channel buffer of exactly
        # WINDOW_WIDTH x WINDOW_HEIGHT pixels - confirm on scaled displays.
        buffer = pyglet.image.get_buffer_manager().get_color_buffer()
        image_data = buffer.get_image_data()
        image_as_np = np.frombuffer(image_data.get_data(), dtype=np.uint8).reshape(WINDOW_HEIGHT, WINDOW_WIDTH, 4)

        # Convert RGBA to BGR (OpenCV uses BGR) and flip the rows vertically.
        frame_bgr = cv2.flip(cv2.cvtColor(image_as_np, cv2.COLOR_RGBA2BGR), 0)

        # Write frame to video
        VIDEO_WRITER.write(frame_bgr)

        if frame >= last_frame:
            # The last pose has just been written. update() would wrap back to
            # frame 0 from here, so recording has to stop now rather than wait
            # for a frame index past the end.
            VIDEO_WRITER.release()
            RECORD_MODE = False
        else:
            pyglet.clock.schedule_once(update, RENDER_INTERVAL)
        
def update(dt):
    """Advance the global frame counter by one, wrapping after the last pose.

    Args:
        dt: value passed by the clock that calls this function; not used.
    """
    global frame

    # No per-tick printing here: this runs once per rendered frame during
    # playback and recording.
    if frame < len(pose_data) - 1:
        frame += 1
    else:
        frame = 0
        


In [None]:
# Event handlers..
window.on_mouse_scroll = on_mouse_scroll
window.on_mouse_drag = on_mouse_drag
window.on_draw = on_draw
window.on_key_press = on_key_press

# Playback always starts paused: the flag is reset here, so the scheduling
# below does not run unless this assignment is changed.
AUTO_PLAY = False

if AUTO_PLAY:
    pyglet.clock.schedule_interval(update, RENDER_INTERVAL)

# Release the video writer even when the event loop exits through an
# exception, so the output file is always finalized.
try:
    pyglet.app.run()
finally:
    VIDEO_WRITER.release()
