In [None]:
import os
from os import listdir
import shutil
import json
import time
from tqdm import tqdm
import pickle

import torch
import torchvision
from torchvision.io import read_video
from torchvision.ops import box_convert

from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights
from torchvision.utils import draw_bounding_boxes
import torchvision.transforms.functional as F

import cv2 as cv
import numpy as np
import pandas as pd
from scipy.spatial import distance

import matplotlib.pyplot as plt

In [None]:
# Root of the DVU Challenge development split; every path below is derived from it,
# so the notebook can be pointed at another copy by editing this single line.
dataset_root = '../DVUChallenge/dev_dataset/'

# Directory paths end with '/'.
scene_graph_path = f'{dataset_root}scenes_knowledge_graphs/'
shots_data_path = f'{dataset_root}movie.shots/'

# NOTE(review): presumably a frame-decimation rate (keep one frame in every 50) — confirm at the use site.
dec_rate = 50

# Film names = entries of movie_knowledge_graph/ that contain no '.' (presumably per-film
# directories rather than files — confirm). Sorted because os.listdir returns entries in
# arbitrary order, which would otherwise make iteration order differ between runs.
filmnames = sorted(entry for entry in listdir(f'{dataset_root}movie_knowledge_graph/') if '.' not in entry)

### Object-based tracking

In [None]:
# Fetch third-party repositories (scene-graph benchmark and face recognition) into the working directory.
# NOTE(review): clones are unpinned (no fixed commit), and re-running this cell makes git report that
# the destination path already exists — consider guarding on the directory and checking out a known commit.
!git clone https://github.com/KaihuaTang/Scene-Graph-Benchmark.pytorch
!git clone https://github.com/paul-pias/Face-Recognition

In [None]:
# Class names indexed by the label ids emitted by torchvision's maskrcnn_resnet50_fpn.
# torchvision uses the original COCO category ids (1..90 with gaps), not a contiguous
# 1..80 range, so the unused ids are kept as 'N/A' placeholders. Without them every
# label id above 11 would map to the wrong name, and ids above 80 would be out of range.
# Index 0 is the background class.
mask_rcnn_labels = np.array([
    'BG', 'person', 'bicycle', 'car', 'motorcycle',                           # 0-4
    'airplane', 'bus', 'train', 'truck', 'boat',                              # 5-9
    'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter',     # 10-14
    'bench', 'bird', 'cat', 'dog', 'horse',                                   # 15-19
    'sheep', 'cow', 'elephant', 'bear', 'zebra',                              # 20-24
    'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A',                          # 25-29
    'N/A', 'handbag', 'tie', 'suitcase', 'frisbee',                           # 30-34
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',               # 35-39
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',   # 40-44
    'N/A', 'wine glass', 'cup', 'fork', 'knife',                              # 45-49
    'spoon', 'bowl', 'banana', 'apple', 'sandwich',                           # 50-54
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',                       # 55-59
    'donut', 'cake', 'chair', 'couch', 'potted plant',                        # 60-64
    'bed', 'N/A', 'dining table', 'N/A', 'N/A',                               # 65-69
    'toilet', 'N/A', 'tv', 'laptop', 'mouse',                                 # 70-74
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven',                  # 75-79
    'toaster', 'sink', 'refrigerator', 'N/A', 'book',                         # 80-84
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',                  # 85-89
    'toothbrush',                                                             # 90
])

In [None]:
# Use the GPU when present; fall back to CPU so the cell still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pretrained torchvision Mask R-CNN, used for inference only, so switch it to eval mode.
model = maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT).to(device)
model.eval();

In [None]:
def bbox_2_mask(bbox, dims=(1,720,1280)):
    """Rasterise an axis-aligned box into a binary float mask.

    Args:
        bbox: indexable ``(x_min, y_min, x_max, y_max)`` in pixel coordinates.
            The max edges are exclusive (plain slice semantics). Float or tensor
            coordinates are truncated to int.
        dims: shape of the returned array; its last two axes are (height, width).
            Defaults to a single 720x1280 frame with a leading axis of 1.

    Returns:
        ``np.ndarray`` of shape ``dims`` (float64) holding 1.0 inside the box and
        0.0 elsewhere.
    """
    x_min, y_min, x_max, y_max = (int(coord) for coord in bbox[:4])
    base = np.zeros(shape=dims)
    # The ellipsis fills the box on every leading axis, so both 2-D and 3-D dims work.
    base[..., y_min:y_max, x_min:x_max] = 1
    return base

# ================

def getLastCentersForAllEntities(entityList):
    """Return the most recent center of every tracked entity, in list order.

    Each entity is a dict whose ``'centers'`` list must be non-empty (the last
    element is read). An empty ``entityList`` yields an empty list.
    """
    return [tracked['centers'][-1] for tracked in entityList]

def getClosestEntity(center, entityCenters, max_dist=100):
    """Pick the tracked entity whose last center is nearest to ``center``.

    Args:
        center: query point, either a torch tensor (possibly on the GPU) or any array-like.
        entityCenters: sequence of points in the same format, one per entity.
        max_dist: Euclidean distance above which the nearest entity is rejected.
            Defaults to 100, the previously hard-coded value. A distance exactly
            equal to ``max_dist`` still matches.

    Returns:
        Integer index into ``entityCenters`` of the nearest point, or ``None``
        when the list is empty or the nearest point is farther than ``max_dist``.
    """
    def _as_array(point):
        # Tensors may live on the GPU; SciPy needs host memory.
        return np.asarray(point.cpu()) if hasattr(point, 'cpu') else np.asarray(point)

    query = _as_array(center)
    dists = [distance.euclidean(query, _as_array(other)) for other in entityCenters]

    print('Distances - ', dists)
    if not dists:
        return None
    best = int(np.argmin(dists))
    return None if dists[best] > max_dist else best

def getCenter(box):
    """Return the ``(cx, cy)`` center of an ``xyxy`` box.

    Converts the box to ``cxcywh`` with torchvision's ``box_convert`` and keeps
    the first two entries. Assumes ``box`` is a single 1-D tensor of 4
    coordinates. For an (N, 4) batch, the ``[:2]`` slice would select the first
    two rows instead.
    """
    cxcywh = box_convert(box, in_fmt='xyxy', out_fmt='cxcywh')
    center_xy = cxcywh[:2]
    return center_xy

def overlapFaceIndex(mask, faceBoxes, mask_threshold=0.01, min_overlap=0.5):
    """Find the face box that lies mostly inside a detection mask.

    For each face box, the overlap rate is (foreground mask pixels inside the box)
    divided by (box area). The mask is binarised with ``mask > mask_threshold``.

    Args:
        mask: soft mask tensor, indexed as ``mask[0, y, x]`` (presumably the
            (1, H, W) Mask R-CNN output — confirm); may live on the GPU.
        faceBoxes: sequence of ``(x_min, y_min, x_max, y_max)`` pixel boxes.
        mask_threshold: pixel values above this count as foreground.
        min_overlap: the best rate must strictly exceed this to count as a match.

    Returns:
        Index into ``faceBoxes`` of the best-overlapping face, or ``None`` when
        there are no faces, every box is empty, or the best rate does not exceed
        ``min_overlap``. Ties resolve to the later index.
    """
    mask_b = (mask > mask_threshold).int().cpu()

    best_index, best_rate = None, None
    for i, faceBox in enumerate(faceBoxes):
        x_min, y_min, x_max, y_max = (int(coord) for coord in faceBox[:4])
        # Intersection with the box == foreground pixels inside the box slice,
        # so no full-frame box mask needs to be rasterised.
        region = mask_b[0, y_min:y_max, x_min:x_max]
        area = region.numel()
        if area == 0:
            # Degenerate or out-of-frame box: the rate would be 0/0 = NaN, which
            # makes the ranking unreliable, so skip it.
            continue
        rate = region.sum().item() / area
        # ">=" lets a later box with an equal rate win the tie.
        if best_rate is None or rate >= best_rate:
            best_index, best_rate = i, rate

    if best_index is None or best_rate <= min_overlap:
        return None
    return best_index

In [None]:
def _record_observation(entity, box, center, frameID, name):
    """Append one detection to an entity dict in place; ``name`` is skipped when None."""
    entity['boxes'].append(box)
    entity['centers'].append(center)
    entity['frameID'].append(frameID)
    if name is not None:
        entity['names'].append(name)


def update_entity_list(frameID, entityList, boxes, masks, face_results=None):
    """Assign every detection of one frame to a tracked entity.

    Each detection is matched to the entity whose most recent center is nearest
    (via ``getClosestEntity``). If none is close enough, a new entity is started.
    A face name is attached when ``overlapFaceIndex`` finds a face box inside the
    detection mask.

    Args:
        frameID: key used to look up this frame in ``boxes``, ``masks`` and the face results.
        entityList: list of entity dicts with keys ``'centers'``, ``'boxes'``,
            ``'frameID'`` and ``'names'``; mutated in place.
        boxes: per-frame container of detection boxes (``xyxy`` tensors), indexed by ``frameID``.
        masks: per-frame container of detection masks, indexed by ``frameID``.
        face_results: per-frame ``(nameList, faceBoxes)`` pairs, indexed by ``frameID``.
            Defaults to the notebook-global ``face_rec_res`` for backward
            compatibility. NOTE(review): that global is not defined in this part
            of the notebook.

    Returns:
        ``entityList`` — the same object, after mutation.
    """
    if face_results is None:
        face_results = face_rec_res  # legacy hidden global

    boundingBoxes, maskList = boxes[frameID], masks[frameID]
    print('Person bounding boxes - ', boundingBoxes)

    nameList, faceBoxes = face_results[frameID]
    print('Name list - ', nameList, 'Face boxes - ', faceBoxes)

    # Snapshot taken before the loop: entities created for this frame are not
    # candidates for the frame's remaining detections. The indices returned by
    # getClosestEntity stay valid because new entities are only appended.
    entityCenters = getLastCentersForAllEntities(entityList)

    for box, mask in zip(boundingBoxes, maskList):
        print('   The current bbox is - ', box)
        box = box.detach()
        center = getCenter(box)

        # Call overlapFaceIndex once per detection and reuse the result.
        faceIndex = overlapFaceIndex(mask, faceBoxes)
        correctName = nameList[faceIndex] if faceIndex is not None else None
        print('   Face Index - ', faceIndex, ' with corresp. name ', correctName)

        print('   Bbox center is: ', center, '; Last entity centers are: ', entityCenters)
        correctEntity = getClosestEntity(center, entityCenters)
        print('   Closest entity index is ', correctEntity)

        if correctEntity is None:
            newEntity = {'centers': [], 'boxes': [], 'frameID': [], 'names': []}
            _record_observation(newEntity, box, center, frameID, correctName)
            print('   Creating new entity - ', newEntity)
            entityList.append(newEntity)
        else:
            print('   Adding data to existing entity.')
            _record_observation(entityList[correctEntity], box, center, frameID, correctName)

    return entityList

### Pose-based tracking 