In [None]:
# Load the annotation file
# Consider a particular annotation
# Load the corresponding gif
# Track the bounding boxes
# Repurpose the IKEA-ASM feature extraction code to extract the features
# Will need to implement the I3D network or repurpose existing I3D code

In [3]:
# Load the annotation file
anno_path = '/workspace/work/O2ONet/data/annotations_minus_unavailable_yt_vids.pkl'

import pickle as pkl

# `with` guarantees the file is closed even if unpickling raises.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted annotation files.
with open(anno_path, 'rb') as f:
    anno = pkl.load(f)

In [23]:

def _create_cv_tracker(cv2, tracker_type):
    """Build a fresh OpenCV single-object tracker of ``tracker_type``.

    Args:
        cv2: the imported OpenCV module.
        tracker_type: one of 'BOOSTING', 'MIL', 'KCF', 'TLD', 'MEDIANFLOW',
            'GOTURN', 'MOSSE', 'CSRT'.

    Raises:
        ValueError: for an unknown ``tracker_type``.
    """
    version_parts = cv2.__version__.split('.')
    major_ver, minor_ver = int(version_parts[0]), int(version_parts[1])

    # Only OpenCV 3.0-3.2 exposed the generic string-based factory; checking
    # the minor version alone would wrongly select it on OpenCV 4.0-4.2.
    if major_ver == 3 and minor_ver < 3:
        return cv2.Tracker_create(tracker_type)

    # Lambdas delay attribute lookup, so only the requested factory has to
    # exist in the installed OpenCV build.
    factories = {
        'BOOSTING': lambda: cv2.TrackerBoosting_create(),
        'MIL': lambda: cv2.TrackerMIL_create(),
        'KCF': lambda: cv2.TrackerKCF_create(),
        'TLD': lambda: cv2.TrackerTLD_create(),
        'MEDIANFLOW': lambda: cv2.TrackerMedianFlow_create(),
        'GOTURN': lambda: cv2.TrackerGOTURN_create(),
        'MOSSE': lambda: cv2.legacy_TrackerMOSSE.create(),
        'CSRT': lambda: cv2.TrackerCSRT_create(),
    }
    if tracker_type not in factories:
        raise ValueError('Unknown tracker type: {!r}'.format(tracker_type))
    return factories[tracker_type]()


def _track_direction(cv_tracker, init_frame, init_bbox_wh, frames):
    """Initialise ``cv_tracker`` on ``init_frame`` and run it over ``frames`` in order.

    Returns:
        A list of ``[x1, y1, x2, y2]`` boxes, one per element of ``frames``,
        or ``None`` as soon as the tracker reports a failure.
    """
    cv_tracker.init(init_frame, init_bbox_wh)
    boxes = []
    for frame in frames:
        ok, bbox_wh = cv_tracker.update(frame)
        if not ok:
            return None
        x, y, w, h = bbox_wh[0], bbox_wh[1], bbox_wh[2], bbox_wh[3]
        boxes.append([x, y, x + w, y + h])
    return boxes


def tracker(frames, bbox_tb, tracker_type='MOSSE'):
    """Track a bounding box outwards from the central frame of a clip.

    A tracker is initialised on the middle frame with ``bbox_tb`` and run
    forwards to the last frame. A second, independent tracker is run
    backwards to the first frame. The two passes are stitched into one
    box per frame.

    Args:
        frames: indexable sequence of images accepted by the OpenCV tracker
            (presumably BGR arrays from ``cv2.VideoCapture.read`` -- confirm).
        bbox_tb: box on the central frame as ``(x1, y1, x2, y2)``
            (top-left and bottom-right corners).
        tracker_type: OpenCV tracker to use; defaults to 'MOSSE', the type
            previously hard-coded.

    Returns:
        A list of ``[x1, y1, x2, y2]`` boxes in frame order with
        ``len(result) == len(frames)``, or ``0`` if either pass loses the target.

    Raises:
        ValueError: if ``frames`` is empty or ``tracker_type`` is unknown.
    """
    import cv2

    num_frames = len(frames)
    if num_frames == 0:
        raise ValueError('frames must not be empty')

    central_index = (num_frames - 1) // 2
    central_frame = frames[central_index]
    central_bbox_wh = (bbox_tb[0], bbox_tb[1], bbox_tb[2] - bbox_tb[0], bbox_tb[3] - bbox_tb[1])

    forward_tracker = _create_cv_tracker(cv2, tracker_type)
    backward_tracker = _create_cv_tracker(cv2, tracker_type)

    # Frames after the centre, in playback order.
    bboxes_forward = _track_direction(
        forward_tracker, central_frame, central_bbox_wh, frames[central_index + 1:])
    if bboxes_forward is None:
        print("Tracking Failure")
        return 0

    # Frames before the centre, walked from the centre back to frame 0.
    # Slicing with [:central_index] first avoids the negative-index wrap-around
    # that would otherwise pull in the last frame.
    bboxes_backward = _track_direction(
        backward_tracker, central_frame, central_bbox_wh, frames[:central_index][::-1])
    if bboxes_backward is None:
        print("Tracking Failure")
        return 0

    # Backward boxes were produced centre -> start; flip them to playback order.
    return bboxes_backward[::-1] + [list(bbox_tb)] + bboxes_forward

In [24]:
def visualise_tracking(frames, bboxes, out_path='./visualisation.gif', fps=4):
    """Draw one box per frame and save the sequence as an animated GIF.

    Args:
        frames: sequence of BGR images (as read by ``cv2.VideoCapture``).
            They are left untouched; boxes are drawn on copies.
        bboxes: at least one ``(x1, y1, x2, y2)`` box per frame; coordinates
            are truncated to ``int`` for drawing and extra boxes are ignored.
        out_path: destination GIF path.
        fps: playback rate forwarded to ``imageio.mimsave``.
            NOTE(review): newer imageio/pillow GIF writers prefer ``duration``;
            confirm ``fps`` is honoured by the installed version.

    Raises:
        ValueError: if there are fewer boxes than frames.
    """
    import cv2
    import imageio

    if len(bboxes) < len(frames):
        raise ValueError('Expected at least {} bboxes, got {}'.format(len(frames), len(bboxes)))

    vis_frames = []
    for frame, bbox in zip(frames, bboxes):
        # cv2.rectangle draws in place; work on a copy so the caller's frames stay clean.
        canvas = frame.copy()
        p1 = (int(bbox[0]), int(bbox[1]))
        p2 = (int(bbox[2]), int(bbox[3]))
        cv2.rectangle(canvas, p1, p2, (255, 0, 0), 2, 1)
        # imageio expects RGB while OpenCV frames are BGR.
        vis_frames.append(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))

    imageio.mimsave(out_path, vis_frames, fps=fps)


def track_bbox(anno, gif_folder, bbox_key='4', window_size=5):
    """Track one annotated box through its GIF clip and save a visualisation.

    The clip is expected at ``<gif_folder>/<yt_id>_<frame no.>_<window_size>.gif``.

    Args:
        anno: one annotation record; reads ``anno['bboxes'][bbox_key]['bbox']``,
            ``anno['metadata']['yt_id']`` and ``anno['metadata']['frame no.']``.
        gif_folder: directory containing the pre-extracted GIF clips.
        bbox_key: which entry of ``anno['bboxes']`` to track (default ``'4'``,
            the previously hard-coded key).
        window_size: window size encoded in the GIF filename (default 5).

    Raises:
        IOError: if the GIF cannot be opened or yields no frames.
    """
    import os
    import cv2

    bbox = anno['bboxes'][bbox_key]['bbox']
    yt_id = anno['metadata']['yt_id']
    frame_index = anno['metadata']['frame no.']

    filename = yt_id + '_' + str(frame_index) + '_' + str(window_size) + '.gif'
    file_location = os.path.join(gif_folder, filename)

    vid = cv2.VideoCapture(file_location)
    if not vid.isOpened():
        raise IOError('Could not open video: {}'.format(file_location))

    frames = []
    try:
        frame_count = int(vid.get(cv2.CAP_PROP_FRAME_COUNT))
        for _ in range(frame_count):
            success, frame = vid.read()
            # Stop at the first failed read instead of appending None frames.
            if not success:
                break
            frames.append(frame)
    finally:
        vid.release()

    if not frames:
        raise IOError('No frames could be read from: {}'.format(file_location))

    bboxes = tracker(frames, bbox)
    # tracker() returns 0 on tracking failure; there is nothing to draw then.
    if not bboxes:
        return
    visualise_tracking(frames, bboxes)
    return

# Run the tracking + visualisation pipeline on a single sample annotation.
gif_path = '/workspace/data/data_folder/o2o/gifs_11'
track_bbox(
    anno[156],
    gif_path,
)