In [None]:
from esper.table_tennis.utils import *
from esper.table_tennis.pose_utils import * 

import cv2
import random
import pickle
import pycocotools.mask as mask_util
from scipy import ndimage
import numpy as np
from scipy.signal import savgol_filter

In [None]:
# Create the client handle `sc` that is passed to the motion generation / rendering
# calls further down.
sc = Client()
# NOTE(review): the lines that would define `video_id` and `video` are commented out,
# yet later cells read both names. On a fresh kernel they are presumably undefined
# unless one of the star-imports provides them — TODO confirm and restore if needed.
# video_id = 65
# video = Video.objects.filter(id=video_id)[0]
# video_ids = [video_id]
# video.item_name()

In [None]:
# Display the item name of the current video (requires `video` to already be defined).
video.item_name()

# load data

In [None]:
# Load the per-player match-scene intervals and sort each list in place.
# NOTE: pickle can execute arbitrary code on load — only read trusted files.
with open('/app/data/pkl/match_scene_intervals_dict.pkl', 'rb') as pkl_file:
    # The context manager guarantees the file handle is closed after loading.
    match_intervals_all = pickle.load(pkl_file)
match_intervals_A = match_intervals_all['HW_foreground']
match_intervals_B = match_intervals_all['JZ_foreground']
match_intervals_A.sort()
match_intervals_B.sort()

In [None]:
# Convert both interval lists with list_to_IntervalSetMapping (A first, then B).
match_ism_A, match_ism_B = (
    list_to_IntervalSetMapping(interval_list)
    for interval_list in (match_intervals_A, match_intervals_B)
)

In [None]:
# Display the sum of count_duration over both players' interval mappings.
count_duration(match_ism_A) + count_duration(match_ism_B)

In [None]:
# collect all openpose for foreground players
# group_pose_from_interval returns a (foreground, background) pair of mappings;
# only the foreground mapping is merged into the per-player lookup, which later
# cells index by frame id.
fid2openpose_A = {}
fid2openpose_B = {}
for interval in match_intervals_A:
    fid2pose_fg, fid2pose_bg = group_pose_from_interval(interval)
    # Merge in place rather than rebuilding the whole dict on every iteration.
    fid2openpose_A.update(fid2pose_fg)
for interval in match_intervals_B:
    fid2pose_fg, fid2pose_bg = group_pose_from_interval(interval)
    fid2openpose_B.update(fid2pose_fg)

In [None]:
# Load the cached per-frame results. Each file is opened with a context manager so
# its handle is closed as soon as the load finishes.
# NOTE: pickle can execute arbitrary code on load — only read trusted files.
with open('/app/data/pkl/match_scene_cls.pkl', 'rb') as pkl_file:
    match_scene_cls = pickle.load(pkl_file)
with open('/app/data/pkl/densepose_result.pkl', 'rb') as pkl_file:
    # latin1 decoding is kept from the original call.
    densepose_result = pickle.load(pkl_file, encoding='latin1')
with open('/app/data/pkl/maskrcnn_result.pkl', 'rb') as pkl_file:
    maskrcnn_result = pickle.load(pkl_file)

In [None]:
# Index the densepose entries by their 'fid' field.
fid2densepose = {}
for player in densepose_result:
    fid2densepose[player['fid']] = player
# Index the maskrcnn entries by their position in the loaded sequence.
# presumably the position is the frame id — TODO confirm
fid2maskrcnn = {}
for fid, bbox in enumerate(maskrcnn_result):
    fid2maskrcnn[fid] = bbox

# Find clean sport field background

In [None]:
# Print every frame index that is flagged in match_scene_cls but has no Pose row.
# The video id 65 is hard-coded in the query.
for fid in range(len(match_scene_cls)):
    if match_scene_cls[fid]:
        poses = Pose.objects.filter(frame__video_id=65, frame__number=fid)
        if len(poses) == 0:
            print(fid)

In [None]:
# clean background in 65: 39050
# Load frame 39050 and write it to disk as the background image.
background = load_frame(video, 39050, [])
# imshow(background)
cv2.imwrite('/app/tmp/background.jpg', background)

# Hand annotate ball hits

In [None]:
# Select the interactive `notebook` matplotlib backend for the cells below.
%matplotlib notebook

In [None]:
# Display match_intervals_B (the 'JZ_foreground' intervals).
match_intervals_B

In [None]:
# Step the frame cursor back by 5; `fid` must already exist from an earlier cell.
fid -= 5

In [None]:
# Frame stepper: load and show frame `fid`, then advance the cursor so that
# re-running this cell shows the next frame. The printed value is the frame shown.
# fid = 37287
frame = load_frame(video, fid, [])
print(fid)
fid += 1
imshow(frame)

In [None]:
# Hand-annotated ball hits: one inner list per trajectory, one tuple per hit.
# Each tuple is (fid, x, y, fg), the order in which convert_hit_annotation unpacks it.
hit_annotation = [
    [(35583, 806, 398, 0), (35599, 1067, 587, 1),  (35621, 937, 418, 0), (35638, 863, 537, 1), (35644, 1306, 369, 0)],
    [(36506, 802, 537, 1), (36526, 745, 469, 0), (36543, 1198, 612, 1), (36551, 946, 366, 0), 
    (36563, 854, 635, 1), (36577, 1077, 274, 0), (36590, 779, 790, 1), (36611, 464, 285, 0)],
    [(37839, 1214, 385, 0), (37858, 1108, 567, 1), (37866, 1148, 415, 0), (37878, 920, 603, 1)],
    [(37304, 740, 612, 1), (37322, 1058, 447, 0), (37332, 857, 600, 1), (37339, 1110, 408, 0), 
    (37351, 877, 665, 1), (37363, 1157, 357, 0), (37372, 726, 619, 1)]
]
def convert_hit_annotation(annot):
    """Turn raw hit tuples into per-hit dictionaries.

    Args:
        annot: iterable of trajectories, each an iterable of
            (fid, x, y, fg) tuples.

    Returns:
        A list with one list per trajectory, in input order; every hit becomes
        {'fid': fid, 'hit': (x, y), 'fg': fg}.
    """
    return [
        [{'fid': fid, 'hit': (x, y), 'fg': fg} for fid, x, y, fg in traj]
        for traj in annot
    ]
# Replace the raw tuple annotation with its dictionary form.
hit_annotation = convert_hit_annotation(hit_annotation)

# draw ball trajectories

In [None]:
# Draw the interpolated ball position of every annotated trajectory onto the
# corresponding source frame and write the frames to an MJPG video at 8 fps.
videowriter = cv2.VideoWriter('/app/result/visualize_trajectory.avi', cv2.VideoWriter_fourcc('M','J','P','G'), 8, (video.width, video.height))
try:
    for hit_traj in hit_annotation:
        print(hit_traj)
        ball_traj = interpolate_trajectory_from_hit(hit_traj)
        for ball in ball_traj:
            # frame = background.copy()
            frame = load_frame(video, ball['fid'], [])
            # Thickness -1 draws a filled circle of radius 8.
            cv2.circle(frame, ball['pt'], 8, (0, 0, 255), -1)
            videowriter.write(frame)
finally:
    # Always finalize the output file, even if loading or drawing a frame fails.
    videowriter.release()

# Generate motion from simple left/right control

## Label clip for left/right control demo
According to the x position, segment the clip into three types of motion: moving left, still, and moving right.

In [None]:
# Pool the left / right / still motion segments found in every player-A interval.
motion_dict = {'left': [], 'right': [], 'still': []}
for interval in match_intervals_A:
    motion_dict_i = group_motion(interval, fid2openpose_A)
    for direction in ('left', 'right', 'still'):
        motion_dict[direction] += motion_dict_i[direction]
# Counts are printed in left, still, right order.
print(*(len(motion_dict[direction]) for direction in ('left', 'still', 'right')))

In [None]:
# visualize labeled clips
# One marker per clip at (start_x, end_x); marker area scales with clip duration
# and colour encodes the motion label (red = left, blue = right, green = still).
plt.figure(figsize=(10, 10))

for direction, color in (('left', 'r'), ('right', 'b'), ('still', 'g')):
    for motion in motion_dict[direction]:
        plt.scatter(motion['start_x'], motion['end_x'], c=color, s=motion['duration']*50)
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.xlabel('Foreground player start position X', fontsize=22)
plt.ylabel('Foreground player end position X', fontsize=22)
plt.show()

## Generate motion matching left/right control

In [None]:
# Query one clip per control ('right', 'left', 'still') from the labeled motions.
# NOTE(review): the meaning of the numeric arguments and of the trailing tuple is
# defined by find_motion, which is not visible here — TODO confirm.
motion_1 = find_motion(motion_dict, 0.35, 0.65, 1, 'right', (1,1,0))
motion_2 = find_motion(motion_dict, 0.65, 0.35, 1, 'left', (1,1,0))
motion_3 = find_motion(motion_dict, 0.35, 0.35, 3, 'still', (1,1,1))

print(motion_1, motion_2, motion_3)
def motion2interval(motion):
    """Convert a motion dict into an interval tuple.

    Args:
        motion: mapping with 'start_fid', 'end_fid' and 'duration' keys.

    Returns:
        (video_id, start_fid, end_fid, duration), where video_id is read from
        the notebook's global namespace.
    """
    fields = ('start_fid', 'end_fid', 'duration')
    return (video_id, *(motion[name] for name in fields))
# Convert the three selected motions, in order, into interval tuples.
searched_intervals = [motion2interval(m) for m in (motion_1, motion_2, motion_3)]
print(searched_intervals)

In [None]:
# Stitch the selected intervals into a single video file.
from esper.supercut import stitch_video_temporal
stitch_video_temporal(searched_intervals, out_path='/app/result/naive_control.mp4')

# Generate motion without hit label
Using the hand-annotated ball trajectory, search the motion database for any interval.

In [None]:
# Build one {fid: joint positions} dict per player-A interval. Keypoint coordinates
# are scaled by the frame width / height and truncated to integers.
H, W = video.height, video.width
tracked_joints = ('Neck', 'RWrist', 'LAnkle', 'RAnkle')
motion_dict_A = []
for _, sfid, efid, _ in match_intervals_A:
    motion_traj = {}
    for fid in range(sfid, efid):
        # Frames without a foreground pose are left out of the trajectory.
        if fid not in fid2openpose_A:
            continue
        pose = fid2openpose_A[fid]
        kp = pose._format_keypoints()
        motion_traj[fid] = {
            joint: (int(kp[getattr(Pose, joint)][0]*W), int(kp[getattr(Pose, joint)][1]*H))
            for joint in tracked_joints
        }
    motion_dict_A.append(motion_traj)

In [None]:
# Use the entry at index 3 of hit_annotation as the query trajectory.
hit_traj = hit_annotation[3]

In [None]:
# Generate a motion video for the selected trajectory from the player-A motion tracks.
generate_motion_without_hitlabel(sc, video, fid2densepose, motion_dict_A, hit_traj, 
                                 out_path='/app/result/naive_control.avi')

# Generate motion with hit label

In [None]:
# prepare motion database
# NOTE: this rebinds `hit_annotation` and `motion_dict`, which earlier cells used for
# the hand-typed annotation and the left/right/still clip lists respectively.
# NOTE: pickle can execute arbitrary code on load — only read trusted files.
with open('/app/data/pkl/hit_annotation.pkl', 'rb') as pkl_file:
    # The context manager closes the file handle once loading is done.
    hit_annotation = pickle.load(pkl_file)
# Query trajectories: the 'JZ' entry of one specific video.
hit_dict = hit_annotation['Tabletennis_2012_Olympics_men_single_final_gold']['JZ']
# Motion database: the 'HW' entry of every video in the annotation.
motion_dict = {k: v['HW'] for k, v in hit_annotation.items()}

## Plot the distribution of the labeled hit

In [None]:
# For every pair of consecutive hits inside a point, record
# (fid of the later hit, seconds elapsed since the previous hit).
# The 'fg' flag of the earlier hit decides which player's list receives the entry:
# in hit_dict_A a truthy flag goes to A, in hit_dict_B a truthy flag goes to B.
# NOTE(review): hit_dict_A / hit_dict_B are not defined in any cell visible here —
# TODO confirm where they come from.
intervals_A = []
intervals_B = []
for points, fg_intervals, bg_intervals in (
        (hit_dict_A, intervals_A, intervals_B),
        (hit_dict_B, intervals_B, intervals_A)):
    for point in points:
        for idx in range(1, len(point)):
            interval = (point[idx]['fid'], (point[idx]['fid'] - point[idx-1]['fid']) / video.fps)
            if point[idx-1]['fg']:
                fg_intervals.append(interval)
            else:
                # The hit_dict_B pass used to append to the undefined name
                # `time_intervals_A`, which raised NameError; it now targets intervals_A.
                bg_intervals.append(interval)
intervals_A.sort()
intervals_B.sort()

In [None]:
# plot
# Bar chart of the time between consecutive hits, one series per player.
# The series are drawn as narrow bars offset by 0.5 so that neither hides the other
# (with identical x positions and default width the second series covered the first).
plt.bar(np.arange(len(intervals_A)), [t[1] for t in intervals_A], width=0.3, label='HW')
plt.bar(np.arange(len(intervals_B))+0.5, [t[1] for t in intervals_B], width=0.3, label='JZ')
plt.ylabel('Time between hits(s)')
plt.title('Length of time intervals between hits')
plt.legend()
plt.show()

In [None]:
# Collect the width-normalised x coordinate of every hit that has a position.
# A hit with a truthy 'fg' flag is stored as x / W for the owner of the dictionary;
# otherwise it is stored mirrored (1 - x / W) in the other player's list.
H, W = video.height, video.width
hit_location_A = []
hit_location_B = []
for points, fg_locations, bg_locations in (
        (hit_dict_A, hit_location_A, hit_location_B),
        (hit_dict_B, hit_location_B, hit_location_A)):
    for point in points:
        for hit in point:
            if hit['pos'] is None:
                continue
            x_norm = 1. * hit['pos'][0] / W
            if hit['fg']:
                fg_locations.append(x_norm)
            else:
                bg_locations.append(1 - x_norm)
hit_location_A.sort()
hit_location_B.sort()

In [None]:
# plot
# Sorted, width-normalised hit x positions of both players as offset bar series.
plt.bar(np.arange(len(hit_location_A)), hit_location_A, width=0.3, label='HW')
plt.bar(np.arange(len(hit_location_B))+0.5, hit_location_B, width=0.3, label='JZ')
plt.ylabel('X position')
plt.title('Distribution of X position of two players')
plt.legend()
# Mid-line at 0.5, extended to cover whichever series is longer.
plt.plot([0, max(len(hit_location_A), len(hit_location_B))], [0.5, 0.5], 'k')
plt.show()

## Generate motion with triangle query offline

In [None]:
# Pull in the offline motion-control helpers, then pick the first trajectory of
# hit_dict as the query and display it.
from esper.table_tennis.motion_control_offline import *
# select ball trajectory
hit_traj = hit_dict[0]
hit_traj

In [None]:
# generate motion for a single point
# Only the Dijkstra variant is active; the local / global variants are kept
# commented out. The returned value is stored in query2result for the render step.
# generate_motion_local(sc, video, motion_dict, hit_traj, 
#                                out_path='/app/result/motion_generation/local_label_JZ_0.avi')
# generate_motion_global(sc, video, motion_dict, hit_traj, 
#                                out_path='/app/result/motion_generation/greedy_label_JZ_0.avi')
query2result = generate_motion_dijkstra(sc, motion_dict, hit_traj, 
                         out_path='/app/result/motion_generation/dijkstra_full_stick_0.avi',
                         interpolation=False, draw_stick=True)

In [None]:
# Render the stored query result to a video, this time with interpolation enabled.
render_motion(sc, video, query2result, 
              out_path='/app/result/motion_generation/global_dijkstra_interpolation_label_JZ_0.avi',
              interpolation=True)

In [None]:
# generate motion for a list of points
# For every trajectory with at least 6 hits, generate a clip and queue it, followed
# by the separator clip at interval_path, for concatenation.
# NOTE(review): `hit_traj_dict` is not defined in any cell visible here — presumably
# a leftover from a deleted cell (perhaps `hit_dict`); TODO confirm.
# NOTE(review): the loop variable rebinds the global `hit_traj`, which a later cell
# pickles to disk — verify that is intended.
interval_path = '/app/result/interval.avi'
clips_list = []
for idx, hit_traj in enumerate(hit_traj_dict):
    if len(hit_traj) >= 6:
        tmp_path = '/app/result/motion_generation/dijkstra_full_stick_{}.avi'.format(idx)

        generate_motion_dijkstra(sc, motion_dict, hit_traj,
                                 out_path=tmp_path, interpolation=False, draw_stick=True)

#         render_motion(sc, video, query2result, 
#               out_path=dijkstra_path,
#               interpolation=)

        clips_list.append(tmp_path)
        clips_list.append(interval_path)
        print(idx)

In [None]:
# Concatenate the queued clips (with separators) into one output video.
concat_videos_simple(clips_list, '/app/result/motion_generation/dijkstra_full_stick_all.mp4')

In [None]:
# Re-key the annotation: keep only the text after the last '/' and before the
# first '.' of each key, printing every new key along the way.
hit_annotation_new = {}
for k, v in hit_annotation.items():
    newk = k.rpartition('/')[2].partition('.')[0]
    print(newk)
    hit_annotation_new[newk] = v

In [None]:
# Persist the re-keyed annotation.
# NOTE: this overwrites, in place, the same file that the motion-database cell loads.
with open('/app/data/pkl/hit_annotation.pkl', 'wb') as pkl_file:
    # The context manager flushes and closes the file even if dumping fails.
    pickle.dump(hit_annotation_new, pkl_file)

## Generate motion online

In [None]:
from esper.table_tennis.motion_control_online import *

In [None]:
# collect ball trajectory as a long point
# The commented block below builds a shuffled candidate list from hit_dict; the
# active code instead loads candidates from a cached pickle.
# hit_traj = []
# for point in hit_dict:
#     for idx, hit in enumerate(point):
#         if idx > 0 and idx+1 < len(point)-1 and not hit['fg'] :
#             next_hit = point[idx+1]
#             if not next_hit['pos'] is None:
#                 hit_traj += [{'pos': next_hit['pos'], 'nframes': next_hit['fid'] - hit['fid']}]
# random.shuffle(hit_traj)
# len(hit_traj)
# NOTE: pickle can execute arbitrary code on load — only read trusted files.
with open('/app/data/pkl/hit_traj.pkl', 'rb') as pkl_file:
    # The context manager closes the file handle once loading is done.
    hit_candidates = pickle.load(pkl_file)

In [None]:
# Write the current `hit_traj` to the candidate cache.
# NOTE(review): this overwrites the file the previous cell loads hit_candidates from,
# using whatever `hit_traj` is bound to at this point — presumably only meant to run
# right after rebuilding the candidate list; TODO confirm before running.
with open('/app/data/pkl/hit_traj.pkl', 'wb') as pkl_file:
    # The context manager flushes and closes the file even if dumping fails.
    pickle.dump(hit_traj, pkl_file)

In [None]:
# Generate motion online from the loaded hit candidates (5 hits, no interpolation,
# no stick drawing) and write the result to out_path.
generate_motion_online(sc, motion_dict, hit_candidates, 
                       out_path='/app/result/motion_generation/online_test_30.avi',
                       interpolation=False, draw_stick=False, num_hits=5)