In [1]:
import os
import json
import pandas as pd
from PIL import Image
from tqdm import *
import numpy as np
import pickle
from shutil import copyfile
from mmdet.apis import init_detector, inference_detector, show_result
import glob

import torch.utils.data as data

from PIL import Image
import os
import os.path

import cv2
import sys
from copy import deepcopy
import torchvision
import torch
from mmdet.ops.nms import nms_cpu, nms
from skimage.feature import match_template
from skimage.color import rgb2gray
import numpy as np
import time

from psina_track_now_staff import *

In [2]:
IMAGES_DATA_PATH = "/home/mml6/IceChallenge/test/"
PREDICTIONS_PATH = "Final_model/full_skolkovo_final_stage0_842.pkl"
PRED_THRESHOLD = 0.55
IOU_NMS_THRESHOLD = 0.05
MIN_BBOX_SQUARE = 100
TRACKING_MIN_IOU = 0.001
TRACKING_IOU_INTERPOLATE = 0.9
TRACKING_MIN_CORRELATION = 0.5
TRACKING_MAX_DISTANCE_PIXELS = 60
BIG_CROP_PADDING = 5

In [3]:
with open(PREDICTIONS_PATH, 'rb') as fin:  
    detector_predictions = filter_all_predictions(
        pickle.load(fin),
        PRED_THRESHOLD,
        IOU_NMS_THRESHOLD,
        MIN_BBOX_SQUARE)

In [4]:
file_names = []

for cur_img in glob.glob(IMAGES_DATA_PATH + "**", recursive=True):
    if not ".jpg" in cur_img:
        continue
    cur_img = '/'.join(cur_img.split('/')[-2:])
    file_names.append(cur_img)

file_names = np.array(file_names)
video_seq = np.argsort(file_names)[::-1]
selected_filenames = file_names[video_seq]

In [5]:
with open('all_classes.txt') as f:
    all_pos_classes = f.read().split()

convert_class = lambda x: '.'.join(str(x).split('.')[:2])

all_pos_classes = [convert_class(sign) for sign in all_pos_classes]

valid_classes = sorted(
    ['2.1',
     '2.4',
     '3.1',
     '3.24',
     '3.27',
     '4.1',
     '4.2',
     '5.19',
     '5.20',
     '8.22'])

In [11]:
def match_bboxes_iou(start_frame_bboxes, finish_frame_bboxes):
    # Returns list of pairs (start_bbox, finish_bbox, start_bbox_index, finish_bbox_index)
    # where each bbox is a pair (bbox, class_id)
    
    start_to_finish_iou = []
    for i, start_bbox in enumerate(start_frame_bboxes):
        for j, finish_bbox in enumerate(finish_frame_bboxes):
            score = iou(start_bbox[0], finish_bbox[0])
            if start_bbox[1] != finish_bbox[1]:
                score = 0
            
            start_to_finish_iou.append((score, i, j))
    
    start_to_finish_iou.sort(key=lambda x: x[0], reverse=True)
    
    matched_pairs = []
    used_start_bboxes = set()
    used_finish_bboxes = set()
    
    for score, i, j in start_to_finish_iou:
        if score < TRACKING_MIN_IOU:
            break
    
        if i in used_start_bboxes or j in used_finish_bboxes:
            continue
        
        start_bbox = start_frame_bboxes[i][:-1]
        finish_bbox = finish_frame_bboxes[j][:-1]
        matched_pairs.append((start_bbox, finish_bbox, i, j))
        
        used_start_bboxes.add(i)
        used_finish_bboxes.add(j)

    return matched_pairs

def get_distance_between_bboxes(start_bbox, finish_bbox):
    distance = (start_bbox[0] - finish_bbox[0]) ** 2 + (start_bbox[1] - finish_bbox[1]) ** 2
    return distance ** (1 / 2.)

def match_bboxes_template(start_frame, finish_frame):
    # Returns list of pairs (start_bbox, finish_bbox, start_bbox_index, finish_bbox_index)
    # where each bbox is a pair (bbox, class_id)
    
    start_frame_bboxes = start_frame[0]
    finish_frame_bboxes = finish_frame[0]
    
    start_frame_img = start_frame[1]
    finish_frame_img = finish_frame[1]
    
    start_to_finish_iou = []
    for i, start_bbox in enumerate(start_frame_bboxes):
        for j, finish_bbox in enumerate(finish_frame_bboxes):
            if start_bbox[1] != finish_bbox[1]:
                continue

            distance = get_distance_between_bboxes(start_bbox[0], finish_bbox[0])
            if distance > TRACKING_MAX_DISTANCE_PIXELS:
                continue

            start_crop = start_frame_img[
                int(start_bbox[0][1]):int(start_bbox[0][1] + start_bbox[0][3]),
                int(start_bbox[0][0]):int(start_bbox[0][0] + start_bbox[0][2])]

            finish_crop = finish_frame_img[
                int(finish_bbox[0][1]):int(finish_bbox[0][1] + finish_bbox[0][3]),
                int(finish_bbox[0][0]):int(finish_bbox[0][0] + finish_bbox[0][2])]
            
            h = min(start_crop.shape[0], finish_crop.shape[0])
            w = min(start_crop.shape[1], finish_crop.shape[1])
            start_crop = cv2.resize(start_crop, (w, h))
            finish_crop = cv2.resize(finish_crop, (w, h))

            score = cv2.matchTemplate(start_crop, finish_crop, cv2.TM_CCOEFF_NORMED)[0][0]
            
            start_to_finish_iou.append((score, i, j))
    
    start_to_finish_iou.sort(key=lambda x: x[0], reverse=True)
    
    matched_pairs = []
    used_start_bboxes = set()
    used_finish_bboxes = set()
    
    for score, i, j in start_to_finish_iou:
        if score < TRACKING_MIN_CORRELATION:
            break
    
        if i in used_start_bboxes or j in used_finish_bboxes:
            continue
        
        start_bbox = start_frame_bboxes[i][:-1]
        finish_bbox = finish_frame_bboxes[j][:-1]
        matched_pairs.append((start_bbox, finish_bbox, i, j))
        
        used_start_bboxes.add(i)
        used_finish_bboxes.add(j)

    return matched_pairs


def match_bboxes(start_frame, finish_frame):
    # Returns list of pairs (start_bbox, finish_bbox)
    # where each bbox is a pair (bbox, class_id)
    
    matched_pairs = []
    
    matched_pairs_iou = match_bboxes_iou(start_frame[0], finish_frame[0])
    start_bbox_indices_to_remove = set()
    finish_bbox_indices_to_remove = set()
    
    for start_bbox, finish_bbox, start_bbox_index, finish_bbox_index in matched_pairs_iou:
        start_bbox_indices_to_remove.add(start_bbox_index)
        finish_bbox_indices_to_remove.add(finish_bbox_index)
        
        matched_pairs.append((start_bbox, finish_bbox))
    
    new_start_frame_bboxes = []
    new_finish_frame_bboxes = []
    
    for i, bbox in enumerate(start_frame[0]):
        if i not in start_bbox_indices_to_remove:
            new_start_frame_bboxes.append(bbox)

    for j, bbox in enumerate(finish_frame[0]):
        if j not in finish_bbox_indices_to_remove:
            new_finish_frame_bboxes.append(bbox)
    
    start_frame = [new_start_frame_bboxes, start_frame[1]]
    finish_frame = [new_finish_frame_bboxes, finish_frame[1]]

    matched_pairs_template = match_bboxes_template(start_frame, finish_frame)
    
    for start_bbox, finish_bbox, start_bbox_index, finish_bbox_index in matched_pairs_template:
        matched_pairs.append((start_bbox, finish_bbox))

    return matched_pairs

In [12]:
def run_tracking(sequence_tracking):
    result = []

    start_frame = sequence_tracking[0]
    finish_frame = sequence_tracking[-1]
    
    start_frame_img = start_frame[1]
    finish_frame_img = finish_frame[1]

    interpolate_frames = sequence_tracking[1:-1]

    matched_bboxes = match_bboxes(start_frame, finish_frame)

    for frame_num, (_, image) in enumerate(interpolate_frames):
        frame_bboxes = []
        for start_bbox, finish_bbox in matched_bboxes:
            class_id = start_bbox[1]
            start_bbox = start_bbox[0]
            finish_bbox = finish_bbox[0]
            
            start_bbox_size = [start_bbox[2], start_bbox[3]]
            finish_bbox_size = [finish_bbox[2], finish_bbox[3]]
            
            bbox_diff = [
                finish_bbox[0] - start_bbox[0],
                finish_bbox[1] - start_bbox[1],
                finish_bbox[2] - start_bbox[2],
                finish_bbox[3] - start_bbox[3]]
            
            new_coarse_bbox = start_bbox[:]
            for i in range(4):
                new_coarse_bbox[i] += bbox_diff[i] / (len(interpolate_frames) + 1) * (frame_num + 1)
                
            
            if iou(start_bbox, finish_bbox) > TRACKING_IOU_INTERPOLATE:
                frame_bboxes.append((new_coarse_bbox, class_id, 1))
                continue

            start_bbox = hw_to_min_max(start_bbox)
            finish_bbox = hw_to_min_max(finish_bbox)

            bbox_find_area = [
                min(start_bbox[0], finish_bbox[0]) - BIG_CROP_PADDING,
                min(start_bbox[1], finish_bbox[1]) - BIG_CROP_PADDING,
                max(start_bbox[2], finish_bbox[2]) + BIG_CROP_PADDING,
                max(start_bbox[3], finish_bbox[3]) + BIG_CROP_PADDING
            ]
            
            bbox_find_area = [
                max(bbox_find_area[0], 0),
                max(bbox_find_area[1], 0),
                min(bbox_find_area[2], image.shape[1]),
                min(bbox_find_area[3], image.shape[0]),
            ]
            
            crop_proposals = []

            start_crop = start_frame_img[
                int(start_bbox[1]):int(start_bbox[3]),
                int(start_bbox[0]):int(start_bbox[2])]
#             crop_proposals.append((start_crop, start_bbox_size[0], start_bbox_size[1]))

#             finish_crop = finish_frame_img[
#                 int(finish_bbox[1]):int(finish_bbox[3]),
#                 int(finish_bbox[0]):int(finish_bbox[2])]
#             crop_proposals.append((finish_crop, finish_bbox_size[0], finish_bbox_size[1]))
            
            estimated_crop = cv2.resize(start_crop, (int(new_coarse_bbox[2]), int(new_coarse_bbox[3])))
            crop_proposals.append((estimated_crop, estimated_crop.shape[1], estimated_crop.shape[0]))
            
            find_area_crop = image[
                int(bbox_find_area[1]):int(bbox_find_area[3]),
                int(bbox_find_area[0]):int(bbox_find_area[2])]
            

            crop_proposals_scores = []
            for crop_proposal in crop_proposals:
                score = cv2.matchTemplate(find_area_crop, crop_proposal[0], cv2.TM_CCOEFF_NORMED)
                score_argmax = np.argmax(score)
                
                crop_proposals_scores.append((score_argmax, score, crop_proposal[1], crop_proposal[2]))
            
            
            best_proposal = max(crop_proposals_scores, key=lambda x: x[0])
            
            match_result = np.unravel_index(best_proposal[0], best_proposal[1].shape)
            new_y_min = float(match_result[0] + int(bbox_find_area[1])) # TODO test it!
            new_x_min = float(match_result[1] + int(bbox_find_area[0]))
            new_y_max = new_y_min + best_proposal[3]
            new_x_max = new_x_min + best_proposal[2]
            
            new_bbox = [new_x_min, new_y_min, new_x_max, new_y_max]
            new_bbox = min_max_to_hw(new_bbox)
            
            frame_bboxes.append((new_bbox, class_id, 1))
        
        result.append(frame_bboxes)

    result.append(finish_frame[0])
    return result

In [13]:
DETECTOR_FREQUENCY = 3

dataset = ImageFilelist(IMAGES_DATA_PATH, selected_filenames)
loader = torch.utils.data.DataLoader(dataset,
                                     shuffle=False,
                                     num_workers=12)
loader = iter(loader)

final_boxes = []
start_time = time.time()

current_sequence_tracking = []
current_sequence_index = 0

for ind in range(len(selected_filenames))[:]:
    if ind == 0 or selected_filenames[ind].split('/')[0] != selected_filenames[ind-1].split('/')[0]:
        # NEW SEQUENCE starts!!!
        if len(current_sequence_tracking) > 0:
            for boxes, _ in current_sequence_tracking[1:]:
                final_boxes.append(boxes)
        
        current_sequence_tracking = []
        current_sequence_index = 0

    cur_img_gray = next(loader)[0][0].data.numpy()
    #cur_img_gray = cv2.cvtColor(cur_img, cv2.COLOR_BGR2GRAY)

    #continue
    frame_predictions = detector_predictions[ind]
    
    frame_final_boxes = []    
    if (not current_sequence_tracking) or ((len(current_sequence_tracking) + 1) % DETECTOR_FREQUENCY == 0):
        frame_final_boxes = []
        for prediction in frame_predictions:
            bbox = prediction[:4]
            class_id = int(prediction[4])
            frame_final_boxes.append((bbox, class_id, 0))
    
    current_sequence_tracking.append((frame_final_boxes, cur_img_gray))
    
    if len(current_sequence_tracking) == DETECTOR_FREQUENCY:
        if current_sequence_index == 0:
            final_boxes.append(current_sequence_tracking[0][0])
            current_sequence_index += 1
        for boxes in run_tracking(current_sequence_tracking):
            final_boxes.append(boxes)
        
        current_sequence_tracking = [current_sequence_tracking[-1]]

    current_time = int(time.time() - start_time)
    time_per_iter = (time.time() - start_time) / (ind + 1)
    eta_time = int((len(selected_filenames) - ind) * time_per_iter)
    print("\rIndex: {}. Time: {} s. ETA: {} s".format(ind, current_time, eta_time), end="")

if len(current_sequence_tracking) > 0:
    for boxes, _ in current_sequence_tracking[1:]:
        final_boxes.append(boxes)

Index: 14999. Time: 63 s. ETA: 0 s

In [14]:
with open('tracking_check.tsv', 'w') as f:
    f.write('\t'.join(['frame', 'xtl', 'ytl', 'xbr', 'ybr', 'class']) + '\n')
    
    for ind in range(len(selected_filenames)):
        img_name = selected_filenames[ind]
        img_name = img_name.replace('.jpg', '')
        for bbox, class_id, _ in final_boxes[ind]:
            class_id = int(class_id)
            class_name = all_pos_classes[class_id]
            if class_name not in valid_classes:
                continue
            bbox = hw_to_min_max(bbox)
            bbox = list(map(str, bbox))
            f.write('\t'.join([img_name, *bbox, class_name]) + '\n')

In [15]:
!./score/target/release/icevision-score /media/mml6/HDD/Ice/annotations/final ~/IceChallenge/tracking_check.tsv

Total score:	787.791
Total penalty:	124.000
Score 2.1:	41.580
Score 2.4:	104.835
Score 3.1:	53.158
Score 3.24:	-5.280
Score 3.27:	56.303
Score 4.1:	53.973
Score 4.2:	55.698
Score 5.19:	255.439
Score 5.20:	140.381
Score 8.22:	31.703
Penalty 2.1:	12.000
Penalty 2.4:	26.000
Penalty 3.1:	6.000
Penalty 3.24:	36.000
Penalty 3.27:	6.000
Penalty 4.1:	10.000
Penalty 4.2:	4.000
Penalty 5.19:	24.000
Penalty 5.20:	0.000
Penalty 8.22:	0.000


In [None]:
# REFERENCE SCORE

# Total score:	819.665
# Total penalty:	142.000
# Score 2.1:	43.085
# Score 2.4:	102.378
# Score 3.1:	57.288
# Score 3.24:	-2.041
# Score 3.27:	54.694
# Score 4.1:	58.172
# Score 4.2:	54.263
# Score 5.19:	283.166
# Score 5.20:	135.837
# Score 8.22:	32.821
# Penalty 2.1:	14.000
# Penalty 2.4:	28.000
# Penalty 3.1:	8.000
# Penalty 3.24:	36.000
# Penalty 3.27:	8.000
# Penalty 4.1:	10.000
# Penalty 4.2:	4.000
# Penalty 5.19:	24.000
# Penalty 5.20:	10.000
# Penalty 8.22:	0.000

In [25]:
ind = 0
for cur_name, frame_final_predictions in zip(selected_filenames[:1000], final_boxes):
    image = cv2.imread(os.path.join(IMAGES_DATA_PATH, cur_name))
    for bbox, class_id, source in frame_final_predictions:
        class_name = all_pos_classes[int(class_id)]
        if class_name not in valid_classes:
            continue
        bbox = hw_to_min_max(bbox)
        
        if source == 0: # detector
            cv2.rectangle(image, (int(bbox[0]), int(bbox[1])),
                          (int(bbox[2]), int(bbox[3])),
                          (0, 255, 0), 2)
        elif source == 1: # tracking
            cv2.rectangle(image, (int(bbox[0]), int(bbox[1])),
                          (int(bbox[2]), int(bbox[3])), (0, 0, 255), 2)
    
        cv2.putText(image,
                        class_name,# + ' ' + str(cur_prob)[:4],
                        (int(bbox[0]), int(bbox[1])),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        1.5,
                        (0, 255, 0),
                        thickness=3,
                        lineType=cv2.LINE_AA) 
    
    cv2.imwrite('VisualTracking/' + cur_name.split('/')[1], image)
    ind += 1