In [27]:
import requests
import pickle
import copy
import json
from collections import defaultdict


track_names = list(map(str,[1]))
min_iou = 0.3
fit_method = "xy"
override_segment = True
interval_to_change = [20,50]

res = requests.get('http://localhost:8080/api/v1/tasks/16/annotations', auth=('admin', 'mjuzik'))
cvat_annotations = res.json()
shapes = cvat_annotations['shapes']

with open('car_video/annotations.pkl','rb') as f:
    tracked_annotations = pickle.load(f)


def get_iou(bb1, bb2):
    """
    Calculate the Intersection over Union (IoU) of two bounding boxes.

    Parameters
    ----------
    bb1 : dict
        Keys: {'x1', 'x2', 'y1', 'y2'}
        The (x1, y1) position is at the top left corner,
        the (x2, y2) position is at the bottom right corner
    bb2 : dict
        Keys: {'x1', 'x2', 'y1', 'y2'}
        The (x, y) position is at the top left corner,
        the (x2, y2) position is at the bottom right corner

    Returns
    -------
    float
        in [0, 1]
    """
    assert bb1['x1'] < bb1['x2']
    assert bb1['y1'] < bb1['y2']
    assert bb2['x1'] < bb2['x2']
    assert bb2['y1'] < bb2['y2']

    # determine the coordinates of the intersection rectangle
    x_left = max(bb1['x1'], bb2['x1'])
    y_top = max(bb1['y1'], bb2['y1'])
    x_right = min(bb1['x2'], bb2['x2'])
    y_bottom = min(bb1['y2'], bb2['y2'])

    if x_right < x_left or y_bottom < y_top:
        return 0.0

    # The intersection of two axis-aligned bounding boxes is always an
    # axis-aligned bounding box
    intersection_area = (x_right - x_left) * (y_bottom - y_top)

    # compute the area of both AABBs
    bb1_area = (bb1['x2'] - bb1['x1']) * (bb1['y2'] - bb1['y1'])
    bb2_area = (bb2['x2'] - bb2['x1']) * (bb2['y2'] - bb2['y1'])

    # compute the intersection over union by taking the intersection
    # area and dividing it by the sum of prediction + ground-truth
    # areas - the interesection area
    iou = intersection_area / float(bb1_area + bb2_area - intersection_area)
    assert iou >= 0.0
    assert iou <= 1.0
    return iou

def get_xs_and_ys(shape):
    points = shape['points']
    l = len(points)//2
    xs = []
    ys = []
    for i in range(l):
        xs.append(points[i*2])
        ys.append(points[i*2+1])  
    return (xs,ys)

def check_if_shape_has_track(shape, tracked_annotations):   
    xs,ys = get_xs_and_ys(shape)
    x1,x2,y1,y2 = min(xs),max(xs),min(ys),max(ys)
    rect4shape = {'x1':x1,'x2':x2,'y1':y1,'y2':y2}
    id_iou_tuples = []
    for k,rect in tracked_annotations.items():
        tracked_rect = {'x1':rect[0],'x2':rect[0]+rect[2],'y1':rect[1],'y2':rect[1]+rect[3]}
        iou = get_iou(tracked_rect, rect4shape)
        if iou > min_iou:
            id_iou_tuples.append((k,iou))
    if len(id_iou_tuples) == 0:
        return None
    # get the key of the track with the highest iou
    else: return sorted(id_iou_tuples,key=lambda x:x[1], reverse=True)[0][0] 

def get_center(rect,axis):
    center = rect[axis] + rect[axis+2]/2
    return center  


def get_scaled_components(components,center,scale):
    scaled_components = []
    for c in components:
        new_c = center + (c - center)*scale
        scaled_components.append(new_c)
    return scaled_components

def get_moved_points(tracked_rect_start,tracked_rect_current,xs,ys,fit_method):
    tracked_center_start_x = get_center(tracked_rect_start,0)
    tracked_center_start_y = get_center(tracked_rect_start,1)
    tracked_center_current_x = get_center(tracked_rect_current,0)
    tracked_center_current_y = get_center(tracked_rect_current,1)
    delta_x,delta_y = tracked_center_current_x - tracked_center_start_x,tracked_center_current_y - tracked_center_start_y
    if fit_method == "xy":
        xs,ys = xs,ys
    elif fit_method == 'xyscale':
        x_center = (min(xs) + max(xs))/2
        y_center = (min(ys) + max(ys))/2
        scale_x = tracked_rect_current[2]/tracked_rect_start[2]
        scale_y = tracked_rect_current[3]/tracked_rect_start[3]
        xs = get_scaled_components(xs,x_center,scale_x)
        ys = get_scaled_components(ys,y_center,scale_y) 
    else:
        raise Exception(f'Fit method not implemented: {fit_method}')
    moved_points = []
    for j in range(len(xs)):
        moved_points.append(xs[j]+delta_x)
        moved_points.append(ys[j]+delta_y)
    return moved_points

def get_tracked_labels(frames,key,index,fit_method):
    tracked_labels = []
    ix2frame =  {frame['frame']:frame for frame in frames}
    ixs = ix2frame.keys()
    if override_segment:
        ixs = list(set(ixs)-set(range(interval_to_change[0]+1,interval_to_change[1]+1)))
    xs,ys = [],[]
    idx = index
    first_frame_ix = min(ixs)
    for i,tr in enumerate(tracked_annotations):
        if i<first_frame_ix: continue
            
        if override_segment and (i < interval_to_change[0] or i > interval_to_change[1]):
            continue
        if i in ixs:
            frame = ix2frame[i]
            xs,ys = get_xs_and_ys(frame)
            tracked_rect_start = tr[key]
        else:
            tracked_rect_current = tr[key]
            moved_points = get_moved_points(tracked_rect_start,tracked_rect_current,xs,ys,fit_method)
            new_frame = copy.deepcopy(frame)
            new_frame['frame']=i
            idx += 1
            new_frame['id']=idx
            new_frame['points']=moved_points
            tracked_labels.append(new_frame)
    return tracked_labels

matches = defaultdict(list)
indexes = []
original_frames = []
for shape in shapes:
    indexes.append(shape['id'])
    frame_ix = shape['frame']
    annotation4frame = tracked_annotations[frame_ix]
    matched_track = check_if_shape_has_track(shape, annotation4frame)
    if override_segment and (shape['frame'] not in list(range(interval_to_change[0]+1,interval_to_change[1]+1))):
        original_frames.append(shape)
    if matched_track is None:
        continue
    else:
        if matched_track not in track_names:
            print(matched_track)
            continue
        matches[matched_track].append(shape)
    
all_tracked_labels = []
max_index = max(indexes) + 100000 # otherwise create conflicts

for k, frames in matches.items():
    sorted_frames = sorted(frames, key=lambda x: x['frame'])
    tracked_labels = get_tracked_labels(frames,k,max_index,fit_method)
    all_tracked_labels += tracked_labels
    max_index += len(tracked_labels)
    
all_frames = original_frames + all_tracked_labels


cvat_annotations['shapes'] = all_frames
value = json.dumps(cvat_annotations)
headers = {'Content-Type':'application/json'}
res = requests.put('http://localhost:8080/api/v1/tasks/16/annotations', data=value, headers=headers, auth=('admin', 'mjuzik'))
res.text


'{"shapes":[{"type":"polyline","occluded":false,"z_order":1,"points":[526.79296875,472.6796875,531.1995071360943,478.8484536882206,538.7000000000007,480.6000000000022,546.6000000000022,481.5,558.9614294190578,478.8484536882206,560.2834257182476,472.23847219227537,554.9954405214921,467.8318178616464,541.334812096542,467.39115242858315,531.1995071360943,466.069156129397,526.7928528054654,470.9164758930892],"id":28,"frame":0,"label_id":43,"group":0,"attributes":[]},{"type":"polyline","occluded":false,"z_order":1,"points":[523.29296875,470.1796875,527.6995071360943,476.3484536882206,535.2000000000007,478.1000000000022,543.1000000000022,479.0,555.4614294190578,476.3484536882206,556.7834257182476,469.73847219227537,551.4954405214921,465.3318178616464,537.834812096542,464.89115242858315,527.6995071360943,463.569156129397,523.2928528054654,468.4164758930892],"id":100030,"frame":1,"label_id":43,"group":0,"attributes":[]},{"type":"polyline","occluded":false,"z_order":1,"points":[520.29296875,470

In [25]:
%debug


> [0;32m<ipython-input-24-b85ea44a7929>[0m(169)[0;36m<module>[0;34m()[0m
[0;32m    167 [0;31m    [0mannotation4frame[0m [0;34m=[0m [0mtracked_annotations[0m[0;34m[[0m[0mframe_ix[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    168 [0;31m    [0mmatched_track[0m [0;34m=[0m [0mcheck_if_shape_has_track[0m[0;34m([0m[0mshape[0m[0;34m,[0m [0mannotation4frame[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 169 [0;31m    [0;32mif[0m [0moverride_segment[0m [0;32mand[0m [0;34m([0m[0mshape[0m[0;34m[[0m[0;34m'frame'[0m[0;34m][0m [0;32mnot[0m [0;32min[0m [0mlist[0m[0;34m([0m[0mrange[0m[0;34m[[0m[0minterval_to_change[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m+[0m[0;36m1[0m[0;34m,[0m[0minterval_to_change[0m[0;34m[[0m[0;36m1[0m[0;34m][0m[0;34m+[0m[0;36m1[0m[0;34m][0m[0;34m)[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    170 [0;31m        [0moriginal_frames[0m[0;34m.[0m[0mappe