In [3]:
import json
import os
import csv
import matplotlib.pyplot as plt
import numpy as np
import copy
import pandas as pd

In [4]:
# Detection result files in the working directory; each file is processed as one view
# downstream. os.listdir() returns names in arbitrary order, so sort them to make the
# view index assigned to each file reproducible across machines and runs.
keypoint_json_list = sorted(
    name for name in os.listdir(os.getcwd()) if name.endswith('.json')
)

In [5]:
# Display the detection files found above (each one is treated as a separate view).
keypoint_json_list

['Laura_0_00_results.json']

In [6]:
#Creates folders to store frame by frame detections per view

for name in keypoint_json_list:
    # splitext strips only the trailing '.json' extension; exist_ok=True replaces the
    # exists()/makedirs() check-then-create race with a single call.
    os.makedirs(os.path.splitext(name)[0], exist_ok=True)

In [6]:
def calc_keypoint_overlap(key_a, key_b):
    '''
    Mean absolute element-wise difference between two equally long keypoint vectors.

    A value close to 0 means the two keypoint vectors most likely belong to the same person.
    '''
    assert len(key_a) == len(key_b)
    n_values = len(key_a)
    total = 0
    # Accumulate left to right so the floating-point result matches a plain running sum.
    for pos in range(n_values):
        total += abs(key_a[pos] - key_b[pos])
    return total / n_values


In [7]:
def plot_detections(det_list):
    '''Draw every series in `det_list` on the current pyplot figure, then show it.'''
    for series in det_list:
        plt.plot(series)
    plt.show()

def write_csv(keys, det_list, filename='detections.csv', verbose=True):
    '''
    Write one CSV column per series in `det_list`.

    keys: column names; keys[i] names det_list[i]. Extra names are ignored.
    det_list: equally long sequences (pandas raises ValueError otherwise).
    filename: output path (defaults to the previously hard-coded 'detections.csv').
    verbose: when True, print the complete table before writing it.

    Returns the DataFrame that was written.
    Raises IndexError when there are fewer names than series.
    '''
    if len(keys) < len(det_list):
        raise IndexError(f'{len(det_list)} series but only {len(keys)} column names')
    # Build the frame in one go instead of growing it column by column.
    df = pd.DataFrame({key: series for key, series in zip(keys, det_list)})
    if verbose:
        print(df.to_string())
    df.to_csv(filename, index=False)
    return df


def compute_detections(keys):
    '''
    Load every per-view detection file and group its detections into frames.

    keys: paths of detection JSON files, one per view. Each file holds a list of
          detection dicts with an 'image_id' such as '00042.png'; detections of the
          same image are expected to be consecutive in the file.

    Returns keypoint_list indexed as [view][frame][detection]. Frames keep file order and
    only images with at least one detection appear, so a frame's list index equals its
    image number only when no image is missing from the file. An empty file yields [].

    Note: agents are NOT matched here. Later steps must ensure that
        - keypoint_list[view][frame][0] / [1] each correspond to exclusively one agent,
        - that agent indexation is consistent across views.
    '''
    from itertools import groupby

    def frame_key(det):
        # Image name without the '.png' extension identifies the frame.
        return det['image_id'].split('.png')[0]

    keypoint_list = []
    for key in keys:
        with open(key) as f:
            detections = json.load(f)
        # groupby yields every run of consecutive same-image detections, including the
        # final run even when it holds a single detection.
        keypoint_list.append([list(frame) for _, frame in groupby(detections, key=frame_key)])

    return keypoint_list

        

In [9]:
def get_erronous_det(det_per_view, missing):
    '''
    Per view, list the indices of frames whose number of detections is not two.

    det_per_view: [view][frame][detection] lists; only len() of each frame is used.
    missing: True -> report frames with fewer than 2 detections,
             False -> report frames with more than 2 detections.
    '''
    def is_erroneous(frame_dets):
        return len(frame_dets) < 2 if missing else len(frame_dets) > 2

    return [
        [frame_no for frame_no, frame_dets in enumerate(view) if is_erroneous(frame_dets)]
        for view in det_per_view
    ]

In [10]:
def get_missing_detection_ranges(miss_det):
    '''
    Split each view's frame indices into runs of consecutive frames.

    miss_det: per-view lists of increasing frame indices (e.g. output of get_erronous_det).
    Returns, per view, the list of runs, e.g. [[3], [7, 8]] for [3, 7, 8]; an empty view
    gives [].
    '''
    miss_det_range = []

    for view_frames in miss_det:
        view_ranges = []
        for frame in view_frames:
            if view_ranges and view_ranges[-1][-1] + 1 == frame:
                # Continues the current run.
                view_ranges[-1].append(frame)
            else:
                # Starts a new run; appending it immediately guarantees the final run is kept.
                view_ranges.append([frame])
        miss_det_range.append(view_ranges)

    return miss_det_range

In [163]:
# Group detections into frames for every view, then flag frames with too few / too many people.
all_detections = compute_detections(keypoint_json_list)
missing_detections = get_erronous_det(all_detections, missing=True)
excess_detections = get_erronous_det(all_detections, missing=False)
missing_detection_ranges = get_missing_detection_ranges(missing_detections)

In [164]:
def identify_keypoint_index(current_det, prev_det):
    '''
    Decide which of two reference detections a single detection most resembles.

    current_det: list holding exactly one detection dict with 'keypoints'.
    prev_det: list holding exactly two detection dicts (from the previous/next frame).

    Returns 0 when current_det[0] is strictly closer (calc_keypoint_overlap) to prev_det[0],
    otherwise 1 (ties go to 1).
    '''
    assert len(current_det) == 1
    assert len(prev_det) == 2

    query = current_det[0]['keypoints']
    dist_first, dist_second = (calc_keypoint_overlap(query, ref['keypoints']) for ref in prev_det)
    return 0 if dist_first < dist_second else 1

In [165]:
a = [1, 2]
b = [3, 1]

# Scratch check: print the elements shared by both lists.
print(set(a).intersection(b))

{1}


In [166]:
def _select_by_overlap(prev_dets, candidates, overlap_thresh, log):
    '''Match each of the two previous detections to its closest candidate; return the pair or None.

    The pair (ordered like prev_dets) is returned only when the two matches are distinct
    candidates and both mean keypoint distances are below overlap_thresh.
    '''
    list_cand_overlap = []
    for prev_det in prev_dets:
        log('Prev idx: ', prev_det['idx'])
        list_cand_overlap.append(
            [calc_keypoint_overlap(prev_det['keypoints'], cand['keypoints']) for cand in candidates]
        )
        log('Candidate idx ', [cand['idx'] for cand in candidates])

    # Index of the current candidate that best overlaps each previous detection.
    best = [int(np.argmin(overlaps)) for overlaps in list_cand_overlap]
    for overlaps in list_cand_overlap:
        log(overlaps)
    log(best)
    log(candidates[best[0]]['idx'], candidates[best[1]]['idx'])

    if best[0] != best[1] and all(
        overlaps[b] < overlap_thresh for overlaps, b in zip(list_cand_overlap, best)
    ):
        return [candidates[best[0]], candidates[best[1]]]
    return None


def _select_by_idx(prev_dets, candidates):
    '''Keep the candidates whose 'idx' equals a previous detection's 'idx', if exactly two do.'''
    matched = [
        cand for prev_det in prev_dets for cand in candidates if prev_det['idx'] == cand['idx']
    ]
    return matched if len(matched) == 2 else None


def _top_two_by_score(candidates):
    '''Return the two highest-'score' detections, best first (stable for equal scores).'''
    return sorted(candidates, key=lambda d: d['score'], reverse=True)[:2]


def correct_excess_det(ex_det, key_view_list, overlap_thresh=15, verbose=True):
    '''
    Reduce every frame listed in `ex_det` to exactly two detections.

    ex_det: per-view lists of frame indices with more than two detections
            (output of get_erronous_det(..., False)).
    key_view_list: detections indexed as [view][frame][detection]; each detection is a dict
                   with at least 'idx', 'keypoints' and 'score'.
    overlap_thresh: maximum mean keypoint distance for an overlap match to be accepted.
    verbose: print a per-frame diagnostic trace.

    Strategy for each listed frame:
      1. If the previous frame has exactly two detections, keep for each of them the
         current candidate with the smallest keypoint distance, provided the two matches
         are different candidates and both are below overlap_thresh.
      2. Otherwise (same previous frame) keep the candidates whose 'idx' matches the
         previous frame's 'idx' values, when exactly two do.
      3. Otherwise, or when there is no usable previous frame, keep the two highest scores.

    Returns a deep copy of key_view_list with the listed frames corrected; the input is
    not modified. Raises AssertionError if a corrected frame does not hold two detections.
    '''
    log = print if verbose else (lambda *args, **kwargs: None)
    corrected_list = copy.deepcopy(key_view_list)

    assert len(ex_det) == len(corrected_list)

    for view, frame_list in enumerate(ex_det):
        for frame in frame_list:
            log('--------')
            log('Frame: ', frame)
            log('View: ', view)

            candidates = corrected_list[view][frame]
            # Frame 0 has no predecessor (index -1 would wrap to the view's last frame).
            # Frames are processed in order, so an excess predecessor is already corrected.
            prev_dets = corrected_list[view][frame - 1] if frame > 0 else []

            selected = None
            if len(prev_dets) == 2:
                selected = _select_by_overlap(prev_dets, candidates, overlap_thresh, log)
                if selected is None:
                    selected = _select_by_idx(prev_dets, candidates)
            else:
                log('Prev frame does not have 2 detections')

            if selected is None:
                log('Removing by score in frame')
                for d in candidates:
                    log(d)
                selected = _top_two_by_score(candidates)
                log([d['idx'] for d in selected])

            corrected_list[view][frame] = selected
            # Checked for every corrected frame of every view.
            assert len(corrected_list[view][frame]) == 2

    return corrected_list

In [167]:
corrected_key_view_list = correct_excess_det(ex_det=excess_detections, key_view_list=all_detections)

--------
Frame:  264
View:  0
Prev idx:  1
Candidate idx  [1, 3, 2]
Prev idx:  3
Candidate idx  [1, 3, 2]
[0.7586024961410425, 48.72624630404589, 48.00595924945978]
[49.90090217207296, 2.582310668503245, 6.3372292428349075]
[0, 1]
1 3
--------
Frame:  530
View:  0
Prev idx:  1
Candidate idx  [1, 2, 4]
Prev idx:  2
Candidate idx  [1, 2, 4]
[4.896273243121612, 49.93281988684948, 54.855284438301354]
[54.23177910844485, 7.55885659578519, 1.6475703895856173]
[0, 2]
1 4
--------
Frame:  531
View:  0
Prev idx:  1
Candidate idx  [1, 2, 4]
Prev idx:  4
Candidate idx  [1, 2, 4]
[0.0, 50.88264266114969, 55.04739683981125]
[55.04739683981125, 7.5081647004072485, 0.0]
[0, 2]
1 4
--------
Frame:  535
View:  0
Prev idx:  1
Candidate idx  [1, 4, 2]
Prev idx:  4
Candidate idx  [1, 4, 2]
[1.5734831549418278, 56.18803280095259, 52.18847855896904]
[55.27051319640417, 1.0053798178067574, 5.973515294014644]
[0, 1]
1 4
--------
Frame:  536
View:  0
Prev idx:  1
Candidate idx  [1, 4, 2]
Prev idx:  4
Candidate

In [168]:
def view_index_consistency(unstruct_det, view, overlap_thresh=10):
    '''
    Assign the detections of every frame of ONE view to two fixed agent slots.

    A detection keeps the slot of the previous-frame detection with the same 'idx'
    (tracker id); otherwise keypoint similarity to the previous frame decides
    (identify_keypoint_index / calc_keypoint_overlap).

    unstruct_det: list of frames for a single view; each frame is a list of detection
                  dicts with at least 'idx', 'keypoints' and 'score'.
    view: view index, only used in messages.
    overlap_thresh: mean keypoint distance below which a lone detection is treated as the
                    agent held in the only occupied slot of the previous frame.

    Returns one [agent_0, agent_1] list per frame; a slot is None when no detection was
    assigned to it in that frame.
    Raises AssertionError when a frame has no detections.
    '''
    indexed_det_list = []

    for i, detection in enumerate(unstruct_det):  # For each frame, get detections
        idx_frame_keypoint_list = [None, None]  # [agent_1, agent_2]
        assert len(detection) != 0, f'Cannot proceed! No people were detected in frame {i}, view {view}'

        # Previous frame's slots. A previous frame with both slots empty carries no identity
        # information, so it is handled like the start of the sequence.
        prev = indexed_det_list[-1] if indexed_det_list else [None, None]
        has_reference = prev[0] is not None or prev[1] is not None

        if len(detection) == 1:
            if not has_reference:
                if not indexed_det_list:
                    print('First frame only has one detection!')
                else:
                    print(f'No reference detections before frame {i}, view {view}; single detection left unassigned')
            elif None not in prev:
                match_id = identify_keypoint_index(detection, prev)
                idx_frame_keypoint_list[match_id] = detection[0]
            else:
                # Only one slot was occupied in the previous frame: keep the lone detection in
                # that slot if it has the same tracker id or similar keypoints.
                none_idx = prev.index(None)
                single_idx = 1 - none_idx
                if prev[single_idx]['idx'] == detection[0]['idx']:
                    idx_frame_keypoint_list[single_idx] = detection[0]
                elif calc_keypoint_overlap(prev[single_idx]['keypoints'], detection[0]['keypoints']) < overlap_thresh:
                    idx_frame_keypoint_list[single_idx] = detection[0]
                else:
                    idx_frame_keypoint_list[none_idx] = detection[0]

        elif len(detection) == 2:
            if not has_reference:
                idx_frame_keypoint_list[0] = detection[0]
                idx_frame_keypoint_list[1] = detection[1]
            elif None not in prev:
                idx_last_a = prev[0]['idx']
                idx_last_b = prev[1]['idx']
                unidentified_detections = []  # detections whose idx changed compared to the last frame
                for d in detection:
                    if d['idx'] == idx_last_a:
                        idx_frame_keypoint_list[0] = d
                    elif d['idx'] == idx_last_b:
                        idx_frame_keypoint_list[1] = d
                    else:
                        unidentified_detections.append(d)

                # Fall back to keypoint overlap; on a slot collision the higher score wins.
                for unident in unidentified_detections:
                    match_id = identify_keypoint_index([unident], prev)
                    current = idx_frame_keypoint_list[match_id]
                    if current is None or unident['score'] > current['score']:
                        idx_frame_keypoint_list[match_id] = unident
            else:
                # Only one slot was occupied: give it the closer detection, the other one
                # takes the empty slot.
                none_idx = prev.index(None)
                single_idx = 1 - none_idx
                match_id = identify_keypoint_index([prev[single_idx]], detection)
                idx_frame_keypoint_list[single_idx] = detection[match_id]
                idx_frame_keypoint_list[none_idx] = detection[1 - match_id]

        else:
            print('Should not happen')

        indexed_det_list.append(idx_frame_keypoint_list)
    return indexed_det_list


In [169]:
# Assign each view's detections to two consistent agent slots.
ice_keypoints = [
    view_index_consistency(det_view, id_v)
    for id_v, det_view in enumerate(corrected_key_view_list)
]


In [170]:

def get_missing_frames(key_list, search_idx):
    '''
    List the frames of one view in which agent slot `search_idx` is empty.

    key_list: frames of a single view, each a two-slot [agent_0, agent_1] list
              (as produced by view_index_consistency).
    search_idx: slot to inspect (0 or 1).

    Returns the indices of frames whose slot is None.
    Raises AssertionError when a frame does not have exactly two slots.
    '''
    missing_frames = []
    frame_no = 0
    for slots in key_list:
        assert len(slots) == 2
        if slots[search_idx] is None:
            missing_frames.append(frame_no)
        frame_no += 1
    return missing_frames

In [171]:
def get_missing_frame_ranges(miss_list):
    '''
    Group increasing frame indices into runs of consecutive frames.

    miss_list: increasing frame indices, e.g. from get_missing_frames.
    Returns the list of runs, e.g. [[1, 2], [4]] for [1, 2, 4]. Empty input gives []
    (not a single empty run, which callers would index with run[0]).
    '''
    range_list = []
    for f_miss in miss_list:
        if range_list and f_miss == range_list[-1][-1] + 1:
            # Continues the current run.
            range_list[-1].append(f_miss)
        else:
            range_list.append([f_miss])
    return range_list

In [156]:
# Scratch check: frames of view 1 where slot 0 is empty, and their consecutive runs.
a = get_missing_frames(ice_keypoints[1], search_idx=0)
z = get_missing_frame_ranges(a)
for shown in (a, z):
    print(shown)



[424]
[[424]]


In [205]:
def interpolate_keypoints(start_point, end_point, n_interpol):
    '''
    Linearly interpolate between two equally long keypoint vectors.

    start_point, end_point: sequences of numbers of the same length.
    n_interpol: number of equal steps between the two vectors; n_interpol - 1
                intermediate vectors are produced (both endpoints are excluded).

    Returns a list of n_interpol - 1 lists; list i-1 lies i steps of size
    |end - start| / n_interpol from start towards end, element-wise.
    '''
    assert (len(start_point) == len(end_point)), f'Interpolation lists need to be of the same size'

    def step_value(start, end, i):
        stride = abs(start - end) / n_interpol
        return start - stride * i if start > end else start + stride * i

    return [
        [step_value(start, end, i) for start, end in zip(start_point, end_point)]
        for i in range(1, n_interpol)
    ]


In [206]:
interpolate_keypoints(start_point=[10, 20, 30], end_point=[20, 30, 40], n_interpol=3)

[[13.333333333333334, 23.333333333333332, 33.333333333333336],
 [16.666666666666668, 26.666666666666668, 36.666666666666664]]

In [212]:
def interpolate_missing_detections(keypoints, n_people=2):
    '''
    Fill every empty agent slot by interpolating that agent's keypoints over time.

    keypoints: [view][frame][person] structure from view_index_consistency; a slot is a
               detection dict (with 'image_id' like '00042.png' and 'keypoints') or None.
    n_people: number of agent slots per frame.

    For each run of consecutive frames in which a person is missing, the keypoints are
    linearly interpolated between the detections just before and just after the run.
    A run touching the first or last frame has only one neighbour, whose pose is then
    held constant. A person missing in every frame of a view is left as None.
    Filled entries copy the neighbouring detection's other fields and get an 'image_id'
    offset from that neighbour's image number by the frame distance.

    Returns a deep copy of `keypoints` with the filled slots; the input is not modified.
    '''
    inter_keypoints = copy.deepcopy(keypoints)

    for person_id in range(n_people):
        for view in range(len(keypoints)):
            # Get missing detections of *person_id* in *view*
            print(f'View: {view} - Person {person_id}')
            missing_frame_list = get_missing_frames(keypoints[view], person_id)
            if not missing_frame_list:
                continue

            pers_miss_range = get_missing_frame_ranges(missing_frame_list)
            print(pers_miss_range)
            n_frames = len(keypoints[view])

            for range_f in pers_miss_range:
                start = range_f[0]
                end = range_f[-1]
                # Runs are maximal, so in-range neighbours are never None. Check bounds
                # explicitly: start - 1 == -1 would silently wrap to the view's last frame.
                before = keypoints[view][start - 1][person_id] if start > 0 else None
                after = keypoints[view][end + 1][person_id] if end + 1 < n_frames else None
                if before is None and after is None:
                    print(f'Person {person_id} has no detection in view {view}; frames left empty')
                    continue

                anchor_frame = start - 1 if before is not None else end + 1
                first = before if before is not None else after
                last = after if after is not None else before
                pooled_points = interpolate_keypoints(
                    first['keypoints'].copy(), last['keypoints'].copy(), len(range_f) + 1
                )

                base_inst = copy.deepcopy(keypoints[view][anchor_frame][person_id])
                base_img_name = int(base_inst['image_id'].split('.png')[0])

                assert len(range_f) == len(pooled_points)

                for offset, miss_frame in enumerate(range_f):
                    assert inter_keypoints[view][miss_frame][person_id] is None

                    filled = copy.deepcopy(base_inst)
                    filled['image_id'] = '{0:05d}'.format(base_img_name + (miss_frame - anchor_frame)) + '.png'
                    filled['keypoints'] = pooled_points[offset].copy()
                    inter_keypoints[view][miss_frame][person_id] = filled
    return inter_keypoints
            

In [213]:
interpool_keypoints = interpolate_missing_detections(keypoints=ice_keypoints)

View: 0 - Person 0
View: 1 - Person 0
[[424]]
View: 2 - Person 0
View: 3 - Person 0
[[278, 279]]
View: 0 - Person 1
[[565], [572, 573, 574, 575, 576], [655], [659, 660, 661, 662, 663, 664], [677], [688], [691, 692, 693, 694], [871, 872, 873], [899], [901, 902, 903], [929], [936, 937, 938, 939], [942], [952], [959, 960, 961, 962, 963], [966, 967, 968, 969], [979]]
View: 1 - Person 1
View: 2 - Person 1
[[40, 41, 42, 43, 44, 45, 46, 47, 48], [197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 3

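Agent slot to role mapping per view (in view 1 the two slots are swapped relative to the other views):<br>
View 0 -> 0:Parent, 1:Child<br>
View 1 -> 0:Child,  1:Parent<br>
View 2 -> 0:Parent, 1:Child<br>
View 3 -> 0:Parent, 1:Child<br>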

In [209]:
# Slot 1 of view 0 around the interpolated run of frames 659-664.
for frame_no in range(658, 666):
    print(interpool_keypoints[0][frame_no][1])

{'image_id': '00659.png', 'category_id': 1, 'keypoints': [408.39886474609375, 221.78453063964844, 0.8888577222824097, 406.6169128417969, 214.65667724609375, 0.8678443431854248, 403.052978515625, 218.22061157226562, 0.8416237831115723, 379.8874816894531, 228.91236877441406, 0.3592364490032196, 385.2333679199219, 230.6943359375, 0.6958558559417725, 390.5792541503906, 236.04022216796875, 0.6171960830688477, 374.5415954589844, 252.077880859375, 0.8521766662597656, 395.9251403808594, 264.5516052246094, 0.5419926047325134, 385.2333679199219, 278.8072814941406, 0.9980472326278687, 415.5267333984375, 275.2433776855469, 0.4356701076030731, 415.5267333984375, 271.679443359375, 0.9244371652603149, 397.70709228515625, 293.06298828125, 0.7367123365402222, 379.8874816894531, 307.31866455078125, 0.8320669531822205, 431.5643615722656, 285.9351501464844, 0.6628293991088867, 404.8349609375, 314.446533203125, 0.6772988438606262, 463.6396789550781, 293.06298828125, 0.6470603346824646, 436.9102478027344, 3

In [267]:
# JSON template for one exported frame: two entries in "people" (one per agent slot),
# each with an empty "pose_keypoints_2d" list that the export cell fills in.
# NOTE(review): the layout appears to follow OpenPose's per-frame keypoint JSON — confirm
# against whatever consumes these files.
keypoint_string='''
{
    "version":1.1,
    "people":[
    {"pose_keypoints_2d":[]
    },
    {"pose_keypoints_2d":[]
    }
    ]
}
'''

def save_json(folder, file, temp):
    '''
    Serialize `temp` to `<folder>/<file>_keypoints.json`, creating `folder` if needed.

    folder: output directory ('' writes into the current directory).
    file: file stem, e.g. a frame number such as '00042'.
    temp: JSON-serializable object (here the filled-in keypoint template).
    '''
    if folder:
        # Create the output folder on demand so exporting does not depend on it existing.
        os.makedirs(folder, exist_ok=True)
    # The with-block closes the file; no explicit close() is needed.
    with open(os.path.join(folder, f'{file}_keypoints.json'), 'w') as f:
        json.dump(temp, f)

In [270]:
# Parse the template and write it once as a smoke test of save_json.
# NOTE(review): this leaves a scratch file 00/asd_keypoints.json next to the real exports.
json_string = json.loads(keypoint_string)
save_json(folder='00', file='asd', temp=json_string)

In [276]:
def export_view_keypoints(keypoints_per_view, view, folder=None):
    '''
    Write one JSON file per frame of `view`, filled from the `keypoint_string` template.

    keypoints_per_view: [view][frame][person] structure whose slots are detection dicts
                        (with 'image_id' and 'keypoints') or None.
    view: index of the view to export.
    folder: output folder; defaults to the zero-padded view index (e.g. '03').

    Each file is named after the frame's image id without '.png'. A None slot is written
    with its empty 'pose_keypoints_2d'; a frame where every slot is None is skipped,
    because no image id is available to name its file.
    '''
    if folder is None:
        folder = f'{view:02d}'
    for frame in keypoints_per_view[view]:
        json_string = json.loads(keypoint_string)  # fresh template for every frame
        frame_name = None
        for i_d, det in enumerate(frame):
            if det is None:
                # Agent unavailable in this frame: keep the template's empty list.
                continue
            frame_name = det['image_id'].split('.png')[0]
            json_string['people'][i_d]['pose_keypoints_2d'] = copy.deepcopy(det['keypoints'])
        if frame_name is None:
            print(f'View {view}: frame without any detection skipped (no image id to name it)')
            continue
        save_json(folder, frame_name, json_string)


# Only view 3 is exported here (to folder '03'); call export_view_keypoints for other views.
export_view_keypoints(interpool_keypoints, 3)


