In [None]:
import os
import sys
import glob  # to pick random image from test
# Make the project root (the parent of the current working directory)
# importable, so the local `data`, `pose`, `classifier` and `gan` packages
# imported in the next cell can be resolved. Guarded to avoid duplicate
# entries when this cell is re-run.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [12, 8]

from data.data_loading import TRANSFORM, create_angle_features
from pose.pose_utils import TESTPATH, CLASS_MAPPINGS_NAMES, LANDMARKS_ANGLES_DICT, LANDMARK_DICT, calc_limb_lengths
from pose.plot import plot_image, plot_3d_keypoints, plot_distribution_with_image
from pose.pose_mediapipe import pose_landmarks_to_list, estimate_poses
from classifier.classify import classify_image
from pose.nearest_neighbor_correction import get_angle_confidence_intervals, create_pose_df, warrior_pose_front_back, \
get_annotated_img, compare_two_figures, select_correct_closest_image, which_leg_front, plot_3D, normalize_on_right_hip
from pose.decision_tree import decision_tree_correct
from gan.results_cLimbGAN import generate_coords_given_limb_lengths
from classifier.classify_pose_quality import classify_correct

## Read in Image

Note that these images were all unseen by the system during training (i.e., they are test images). They can show either correct or incorrect poses.

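You can add your own photos (`.jpg`/`.jpeg`) to the `data/test/` folder; the code snippet below picks one of them at random. If an image sits in a subfolder named `<digit>_<pose>` (a leading `0` marks an incorrect pose), its human-labeled ground truth is printed as well. Alternatively, set `CPATH` manually to the path of your image.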

In [None]:
# Select a random test image (searching TESTPATH and all its subfolders) and show it.
path = str(TESTPATH)
# Match lower- and upper-case extensions. The set removes duplicates because on
# case-insensitive filesystems both patterns can return the same file; sorting
# makes the candidate order (and therefore a seeded choice) reproducible.
types = ('*.jpg', '*.jpeg', '*.JPG', '*.JPEG')
tests_images = sorted({
    image_path
    for pattern in types
    for image_path in glob.glob(os.path.join(path, '**', pattern), recursive=True)
})
if not tests_images:
    raise FileNotFoundError(f"No .jpg/.jpeg images found under {path}")
CPATH = np.random.choice(tests_images)
test_image = TRANSFORM(Image.open(CPATH))
plot_image(test_image, dataloader=True)

# Print the human-labelled characteristics of the pose. They are encoded in
# the parent folder name as "<digit>_<pose>", where a leading 0 marks an
# incorrect pose. os.path is used instead of splitting on '/' so this also
# works with Windows path separators. Images whose folder does not follow
# that convention (e.g. placed directly in the test folder) carry no label.
label_folder = os.path.basename(os.path.dirname(CPATH))
secret_truth = label_folder.split('_')
if len(secret_truth) < 2 or not secret_truth[0].isdigit():
    print(f"No ground-truth label available for {CPATH}")
else:
    print(f"Ground Truth: \nThis is a {secret_truth[1]} pose")
    if secret_truth[0] == '0':
        print('This pose is an INCORRECT pose')
    else:
        print('This pose is a Correct pose')

### Apply Pose Estimation

In [None]:
# Run MediaPipe pose estimation on the transformed test image and display the
# annotated image it returns (also reused by the nearest-neighbour comparison).
# NOTE(review): assumes `test_result` exposes `pose_world_landmarks` (read in the
# next cell) and that it may be empty if no person is detected — confirm in
# pose.pose_mediapipe.estimate_poses.
test_result, annotated_test_image = estimate_poses(test_image, CPATH, skip_image_annotation=False)
plot_image(annotated_test_image, dataloader=False)

In [None]:
# Convert the world landmarks into (a) a one-row DataFrame with named landmark
# columns plus the derived angle features, and (b) a numeric array for the models.
val, _, nump = pose_landmarks_to_list(test_result, 'pose_world_landmarks')

df_test = pd.DataFrame.from_records([val]).rename(LANDMARK_DICT, axis=1)
create_angle_features(df_test)  # return value unused; features are added to df_test

# Wrap the single pose in a leading batch axis before right-hip normalization;
# this batched array is what the classifiers and correction steps below consume.
np_test = normalize_on_right_hip(np.array([nump]))

# Transpose the single pose so the first three rows are the x, y and z coordinates.
x, y, z = np_test[0].T[:3]

# NOTE(review): -70 / 270 are presumably the 3D view angles — confirm in pose.plot.
plot_3d_keypoints(x, y, z, -70, 270)

### Pose Classification

In [None]:
# Classify which pose this is, then score whether it is performed well.
label = classify_image(np_test).item()
print(f"Image Classified as : {CLASS_MAPPINGS_NAMES[label]}")

# Quality scores below this threshold are reported as a bad pose.
GOOD_POSE_THRESHOLD = .5
correct = classify_correct(np_test).item()
if correct < GOOD_POSE_THRESHOLD:
    print("Image is a Bad Pose")
else:
    print("Image is a Good Pose")

df_w1, df_w2, df_dd = create_pose_df()
df_test_handedness = df_test.copy()

# Pick the reference poses (`df`) and the angle landmarks used by the corrections below.
# NOTE: for labels 1 and 2, df_test is replaced by which_leg_front's output, so
# re-running this cell without re-running the landmark cell above (which rebuilds
# df_test) feeds which_leg_front its own output.
warrior_reference_dfs = {1: df_w1, 2: df_w2}
if label == 0:
    df = df_dd
    LANDMARKS = LANDMARKS_ANGLES_DICT.keys()
elif label in warrior_reference_dfs:
    df, LANDMARKS = warrior_pose_front_back(warrior_reference_dfs[label])
    df_test = which_leg_front(df_test)
else:
    # Fail loudly: otherwise `df` / `LANDMARKS` would be undefined on a fresh
    # kernel, or silently keep stale values from a previously analysed image.
    raise ValueError(f"Unexpected pose label {label!r}; expected 0, 1 or 2")

# Pose Correction

### The Learned Angle Distribution Correction

In [None]:
# Compute angle confidence intervals over the selected reference poses for the
# chosen landmarks, then plot them together with the test pose's angles.
ANGLE_INTERVAL_PERCENT = .15
distribution_angles = get_angle_confidence_intervals(
    df,
    LANDMARKS,
    percent=ANGLE_INTERVAL_PERCENT,
)
plot_distribution_with_image(df, df_test, distribution_angles, LANDMARKS)

### The Nearest Neighbor Correction

The image on the left is the input image; the image on the right is its nearest neighbor among the correct reference poses.

Red shows the original pose; green shows the output correction.

In [None]:
# Find the reference pose closest to the test pose and load its annotated image.
ground_truth, ground_truth_indx = select_correct_closest_image(np_test, df)
ground_truth_img = get_annotated_img(ground_truth_indx)

# Drop only the leading batch axis of the test pose (as the GAN cell does), so
# no other size-1 axis of the landmark array can be squeezed away by accident.
test_pose = np.squeeze(np_test, axis=0)
compare_two_figures(test_pose, np.squeeze(ground_truth), annotated_test_image, ground_truth_img, plot=True)

### The Generative (GAN) Correction

In [None]:
# Generate a corrected pose from the test pose's limb lengths and its classified label.
test_limb_lengths = calc_limb_lengths(np.squeeze(np_test, axis=0))
# NOTE(review): version=780 presumably selects a saved GAN checkpoint — confirm in gan.results_cLimbGAN.
generate_coords_given_limb_lengths(test_limb_lengths, label, version=780, plot=True)