In [1]:
import pickle
import random
import numpy as np
import pandas as pd
import os
import os.path as osp
import plotly.express as px
import plotly.graph_objects as go

In [42]:
# Path to the pickled dict of skeleton sequences. Other exports used during development:
#   './NTU/X.pkl'
#   "C:\\Users\\tcarr23\\Downloads\\X_resnet_file.pkl"
#   "C:\\Users\\tcarr23\\Downloads\\X_unet_file.pkl"
# Set the SGN_X_PATH environment variable to load a different file without editing this cell.
# (The previous version assigned `x_path` but opened `X_path`, which raised NameError on a fresh kernel.)
X_path = os.environ.get("SGN_X_PATH", "C:\\Users\\tcarr23\\Downloads\\X_SGN_FileNameKey.pkl")

# NOTE(review): pickle.load can execute arbitrary code -- only load files you produced or trust.
with open(X_path, 'rb') as f:
    X = pickle.load(f)

In [29]:
def anonymizer_to_sgn(t, max_frames=300):
    """Convert a (xyz, frames, joints, actors) skeleton array into an SGN-style 2-D matrix.

    Parameters
    ----------
    t : np.ndarray
        Array of shape (xyz, frames, joints, actors), e.g. (3, 300, 25, 2).
    max_frames : int, default 300
        Number of rows in the output. Shorter sequences are zero-padded at the end;
        longer sequences are truncated to their first ``max_frames`` frames.

    Returns
    -------
    np.ndarray
        float32 array of shape (max_frames, xyz * joints * actors). Each row holds one
        frame laid out actor-major, then joint, then coordinate:
        [a0j0x, a0j0y, a0j0z, a0j1x, ..., a1j0x, ...]. So ``[:, :xyz * joints]`` is
        the first actor, and ``row.reshape(-1, 3)`` yields one (x, y, z) per joint.
    """
    xyz, frames, joints, actors = t.shape

    out = np.zeros((max_frames, xyz * joints * actors), dtype=np.float32)

    # Move frames to the front and coordinates to the back BEFORE flattening.
    # A bare reshape of (xyz, frames, ...) would interleave values from different frames.
    per_frame = np.transpose(t, (1, 3, 2, 0)).reshape(frames, -1)

    # Clip so sequences longer than max_frames do not raise a broadcast error.
    n = min(frames, max_frames)
    out[:n] = per_frame[:n]

    return out

In [43]:
# Normalise each entry of X to a 2-D (frames, features) array of joint coordinates.
# - One-element lists are unwrapped. (3, 300, 25, 2) arrays are converted with
#   anonymizer_to_sgn, keeping the first 75 columns (first actor, 25 joints x xyz).
# - (frames, 25, 7) arrays keep the first 3 of the 7 per-joint values and are flattened
#   to (frames, 75). NOTE(review): assumes those first 3 values are x, y, z -- confirm.
# - Empty lists carry no data. They are recorded in `bad_files` and removed from X
#   after the loop so later cells never sample them.
bad_files = []
for key in X:
    value = X[key]
    if isinstance(value, list):
        if len(value) == 0:
            bad_files.append(key)
        elif len(value) == 1:
            X[key] = np.array(value[0])
            if X[key].shape == (3, 300, 25, 2):
                # Anonymization for Skeleton Action Recognition
                X[key] = anonymizer_to_sgn(X[key])[:, :75]
        else:
            # NOTE(review): multi-element lists are left as-is -- decide how they should be handled.
            print(key, len(value))
        continue
    # Guard against empty arrays before peeking at the first frame.
    if len(value) > 0 and value[0].shape == (25, 7):
        X[key] = value[:, :, :3].reshape(value.shape[0], -1)
    else:
        print(key, value.shape)

# Dict keys cannot be removed while iterating, so drop the bad ones afterwards.
for key in bad_files:
    del X[key]

In [44]:
# Pick a random sequence to visualise, restricted to entries the renderers can draw.
# Both renderers iterate rows, reshape each into (x, y, z) triples, and index joints 0-24.
# So require a non-empty 2-D array whose rows hold at least 25 xyz triples.
# Empty or multi-element lists left in X would otherwise crash on `d.shape`.
renderable_keys = [
    key for key, value in X.items()
    if isinstance(value, np.ndarray)
    and value.ndim == 2
    and value.shape[0] > 0
    and value.shape[1] >= 75
    and value.shape[1] % 3 == 0
]
d = X[random.choice(renderable_keys)]

In [45]:
def render_frame(d, connections=None):
    """Show one skeleton frame as an interactive Plotly 3-D scatter with bones drawn.

    Parameters
    ----------
    d : array-like
        A single frame. It is flattened into consecutive (x, y, z) triples, one per
        joint (e.g. a 75-element row for 25 joints).
    connections : list of [int, int], optional
        Pairs of joint indices to join with a line. Defaults to the 25-joint bone list
        below (presumably the NTU RGB+D / Kinect v2 layout -- confirm).

    Raises
    ------
    ValueError
        If the frame has fewer joints than the connections reference.
    """
    if connections is None:
        connections = [[0, 1], [1, 20], [20, 2], [2, 3], [20, 8], [8, 9], [9, 10], [10, 11], [11, 23], [11, 24],
                       [20, 4], [4, 5], [5, 6], [6, 7], [7, 21], [7, 22], [0, 16], [16, 17], [17, 18], [18, 19],
                       [0, 12], [12, 13], [13, 14], [14, 15]]

    points = np.asarray(d).reshape(-1, 3)
    x, y, z = points[:, 0], points[:, 1], points[:, 2]

    # Fail with a clear message rather than an IndexError inside the bone loop.
    needed = 1 + max((max(pair) for pair in connections), default=-1)
    if len(points) < needed:
        raise ValueError(f"frame has {len(points)} joints but connections reference joint {needed - 1}")

    df = pd.DataFrame({'x': x, 'y': y, 'z': z})

    # Colour ramps from 1 to 25 across the joints so joint order is visible.
    fig = px.scatter_3d(df, x='x', y='y', z='z', color=np.linspace(1, 25, len(x)),
                        color_continuous_scale='Rainbow', title='Interactive 3D Scatter Plot')
    fig.update_traces(marker=dict(size=2))

    # One line trace per bone. These are hidden from the legend; otherwise Plotly lists
    # every bone as "trace N" next to the colour bar.
    for a, b in connections:
        fig.add_trace(go.Scatter3d(x=[x[a], x[b]], y=[y[a], y[b]], z=[z[a], z[b]],
                                   mode='lines', line=dict(color='black', width=2),
                                   showlegend=False))

    fig.show()

render_frame(d[random.choice(range(d.shape[0]))])

In [46]:
def _skeleton_traces(points, connections):
    """Build the joint-marker trace followed by one line trace per bone for a single frame."""
    x, y, z = points[:, 0], points[:, 1], points[:, 2]
    traces = [go.Scatter3d(x=x, y=y, z=z, mode='markers',
                           marker=dict(size=2, color=np.linspace(1, 25, len(x)), colorscale='Rainbow'),
                           showlegend=False)]
    for a, b in connections:
        traces.append(go.Scatter3d(x=[x[a], x[b]], y=[y[a], y[b]], z=[z[a], z[b]],
                                   mode='lines', line=dict(color='black', width=2),
                                   showlegend=False))
    return traces


def _centered_scene(points, axis_bound):
    """Scene axis ranges of half-width ``axis_bound`` centred on joint 0 of one frame."""
    cx, cy, cz = points[0]
    return dict(
        xaxis=dict(range=[cx - axis_bound, cx + axis_bound]),
        yaxis=dict(range=[cy - axis_bound, cy + axis_bound]),
        zaxis=dict(range=[cz - axis_bound, cz + axis_bound]),
    )


def render_video(d, connections=None, axis_bound=2, frame_duration=100):
    """Animate a skeleton sequence as a Plotly 3-D scatter with bones, Play button and slider.

    Parameters
    ----------
    d : sequence of array-like
        One entry per frame. Each entry is flattened into (x, y, z) triples, one per joint.
    connections : list of [int, int], optional
        Joint-index pairs drawn as bones. Defaults to the 25-joint bone list below
        (presumably the NTU RGB+D / Kinect v2 layout -- confirm).
    axis_bound : float, default 2
        Half-width of each axis range. The view in every frame is centred on that frame's joint 0.
    frame_duration : int, default 100
        Milliseconds per frame when the Play button is pressed.

    Raises
    ------
    ValueError
        If ``d`` contains no frames.
    """
    if connections is None:
        connections = [[0, 1], [1, 20], [20, 2], [2, 3], [20, 8], [8, 9], [9, 10], [10, 11], [11, 23], [11, 24],
                       [20, 4], [4, 5], [5, 6], [6, 7], [7, 21], [7, 22], [0, 16], [16, 17], [17, 18], [18, 19],
                       [0, 12], [12, 13], [13, 14], [14, 15]]

    frames_points = [np.asarray(row).reshape(-1, 3) for row in d]
    if not frames_points:
        raise ValueError("render_video needs at least one frame")

    # Each animation frame carries its own traces AND its own axis ranges. Previously the
    # loop mutated the base figure instead, so the first view mixed last-frame markers with
    # first-frame bones, and only the last frame's centring survived.
    frame_names = [f'Frame {i}' for i in range(len(frames_points))]
    plot_frames = [
        go.Frame(data=_skeleton_traces(points, connections),
                 layout=go.Layout(scene=_centered_scene(points, axis_bound)),
                 name=name)
        for points, name in zip(frames_points, frame_names)
    ]

    # One slider step per frame (the previous slider had no steps and did nothing).
    slider_steps = [
        dict(method='animate', label=str(i),
             args=[[name], dict(mode='immediate', frame=dict(duration=0, redraw=True),
                                transition=dict(duration=0))])
        for i, name in enumerate(frame_names)
    ]

    layout = go.Layout(
        updatemenus=[dict(type='buttons', showactive=False,
                          buttons=[dict(label='Play',
                                        method='animate',
                                        args=[None, dict(frame=dict(duration=frame_duration, redraw=True),
                                                         fromcurrent=True)])])],
        sliders=[dict(steps=slider_steps)],
        scene=_centered_scene(frames_points[0], axis_bound),
        title="Animated 3D Scatter Plot with Connections",
    )

    fig = go.Figure(data=_skeleton_traces(frames_points[0], connections), layout=layout, frames=plot_frames)
    fig.show()

render_video(d)