In [None]:
import os
import pickle

import rekall
from rekall.video_interval_collection import VideoIntervalCollection
from rekall.merge_ops import payload_plus
from rekall.parsers import in_array, bbox_payload_parser
# The star imports presumably supply the predicates used below (e.g. position, scene_graph).
from rekall.bbox_predicates import *
from rekall.spatial_predicates import *
from django.db.models import F
from tqdm import tqdm

from query.models import Face, Video

In [None]:
# All faces from videos whose ignore_film flag is False. Each face row is annotated
# with a one-frame [min_frame, max_frame] span plus its video_id -- presumably the
# fields VideoIntervalCollection.from_django_qs reads (TODO confirm against rekall).
faces_qs = (
    Face.objects
    .filter(frame__video__ignore_film=False)
    .annotate(
        min_frame=F('frame__number'),
        max_frame=F('frame__number'),
        video_id=F('frame__video_id'),
    )
    .all()
)

In [None]:
# One interval per face row, each with a constant payload of 1, so that summing
# payloads per frame (payload_plus) yields a face count.
faces = VideoIntervalCollection.from_django_qs(
    faces_qs,
    with_payload=lambda face: 1,
    progress=True,
)

In [None]:
# Merge the per-face intervals of each frame; payload_plus adds their payloads of 1 into a face count.
face_counts = faces.coalesce(payload_merge_op=payload_plus)

In [None]:
def get_weak_labels_from_face_counts(intrvllist, stride, min_count_change=2):
    """Derive weak shot-boundary labels from per-frame face counts.

    Consecutive intervals are compared pairwise, but only when their starts are
    exactly `stride` frames apart; other pairs (presumably where an intermediate
    sampled frame had no faces and therefore no interval) are skipped. For a
    compared pair (prev, cur), every frame in range(prev.start, cur.start) is:
      * a negative label (no boundary) if the face counts are equal;
      * a positive label (boundary) if the counts differ by >= min_count_change;
      * left unlabelled otherwise.

    Args:
        intrvllist: rekall interval list for one video whose payloads are face
            counts; assumed ordered by start (TODO confirm).
        stride: frame distance between consecutive sampled frames.
        min_count_change: smallest absolute change in face count treated as
            evidence of a shot boundary (default 2, the original threshold).

    Returns:
        (pos_boundaries, neg_boundaries): two lists of frame numbers. Both are
        empty when `intrvllist` is empty.
    """
    pos_boundaries = []
    neg_boundaries = []
    if intrvllist.size() == 0:
        # Keep the (pos, neg) shape so callers can always unpack the result.
        return pos_boundaries, neg_boundaries

    intervals = intrvllist.get_intervals()
    for prev, cur in zip(intervals, intervals[1:]):
        if cur.start - prev.start != stride:
            continue
        frames = list(range(prev.start, cur.start))
        count_change = abs(cur.payload - prev.payload)
        if count_change == 0:
            neg_boundaries += frames
        elif count_change >= min_count_change:
            pos_boundaries += frames

    return pos_boundaries, neg_boundaries

In [None]:
# video_id -> (pos_boundaries, neg_boundaries) derived from face counts.
face_count_labels_pos_neg = dict()

In [None]:
# Label every video from its per-frame face counts. This must read the coalesced
# `face_counts` (one interval per frame, payload = summed face count), not the raw
# `faces` collection (one interval per face, payload always 1), which would make
# every adjacent-frame comparison see equal counts and never yield a positive label.
for video_id in tqdm(face_counts.get_allintervals()):
    stride = Video.objects.get(id=video_id).get_stride()
    face_count_labels_pos_neg[video_id] = get_weak_labels_from_face_counts(
        face_counts.get_intervallist(video_id), stride
    )

In [None]:
# Save these weak labels to disk
for video_id in tqdm(face_count_labels_pos_neg):
    with open('/app/data/shot_detection_weak_labels/face_counts/{}.pkl'.format(video_id), 'wb') as f:
        pickle.dump(face_count_labels_pos_neg[video_id], f)

In [None]:
# Same faces as above, but each payload is the face's bounding box wrapped in a
# list (in_array); coalescing with payload_plus presumably concatenates those lists,
# leaving one list of boxes per frame.
faces_with_bboxes = (
    VideoIntervalCollection
    .from_django_qs(
        faces_qs,
        with_payload=in_array(bbox_payload_parser(VideoIntervalCollection.django_accessor)),
        progress=True,
    )
    .coalesce(payload_merge_op=payload_plus)
)

In [None]:
def _face_positions_match(prev_faces, next_faces, epsilon):
    """Return the rekall scene-graph match of `next_faces` against `prev_faces`.

    Builds a graph with one node per face in `prev_faces`, each constrained by
    `position(...)` to that face's box within `epsilon`, and evaluates it with
    exact=True on the boxes of `next_faces` (exact presumably requires the two
    face sets to correspond -- confirm in rekall's scene_graph docs). The result
    is used only for its truthiness.
    """
    graph = {
        'nodes': [
            {
                'name': 'face{}'.format(idx),
                'predicates': [
                    position(face['x1'], face['y1'], face['x2'], face['y2'], epsilon=epsilon)
                ],
            }
            for idx, face in enumerate(prev_faces)
        ],
        'edges': [],
    }
    candidate_boxes = [
        {'x1': face['x1'], 'y1': face['y1'], 'x2': face['x2'], 'y2': face['y2']}
        for face in next_faces
    ]
    return scene_graph(graph, exact=True)(candidate_boxes)


def get_weak_labels_from_face_positions(intrvllist, stride, epsilon=.05):
    """Derive weak shot-boundary labels from face positions on consecutive frames.

    Consecutive intervals are compared pairwise, but only when their starts are
    exactly `stride` frames apart; other pairs are skipped. For a compared pair
    (prev, cur), every frame in range(prev.start, cur.start) is labelled negative
    (no boundary) if cur's faces match prev's faces by position, else positive.

    Args:
        intrvllist: rekall interval list for one video whose payloads are lists
            of face dicts with 'x1', 'y1', 'x2', 'y2' keys; assumed ordered by
            start (TODO confirm).
        stride: frame distance between consecutive sampled frames.
        epsilon: position tolerance passed to rekall's `position` predicate
            (default .05, the original value).

    Returns:
        (pos_boundaries, neg_boundaries): two lists of frame numbers. Both are
        empty when `intrvllist` is empty.
    """
    pos_boundaries = []
    neg_boundaries = []
    if intrvllist.size() == 0:
        # Keep the (pos, neg) shape so callers can always unpack the result.
        return pos_boundaries, neg_boundaries

    intervals = intrvllist.get_intervals()
    for prev, cur in zip(intervals, intervals[1:]):
        if cur.start - prev.start != stride:
            continue
        frames = list(range(prev.start, cur.start))
        if _face_positions_match(prev.payload, cur.payload, epsilon):
            neg_boundaries += frames
        else:
            pos_boundaries += frames

    return pos_boundaries, neg_boundaries

In [None]:
# video_id -> (pos_boundaries, neg_boundaries) derived from face positions.
face_position_labels_pos_neg = dict()

In [None]:
# Sanity check: show the first coalesced interval of a video that is actually in
# the collection (a hard-coded id such as 1 may have been filtered out by
# ignore_film and would raise on Restart & Run All).
sample_video_id = next(iter(faces_with_bboxes.get_allintervals()))
faces_with_bboxes.get_intervallist(sample_video_id).get_intervals()[0]

In [None]:
# Label every video in faces_with_bboxes from how its face boxes move between
# consecutive sampled frames.
for video_id in tqdm(faces_with_bboxes.get_allintervals()):
    stride = Video.objects.get(id=video_id).get_stride()
    intervals_for_video = faces_with_bboxes.get_intervallist(video_id)
    labels = get_weak_labels_from_face_positions(intervals_for_video, stride)
    face_position_labels_pos_neg[video_id] = labels

In [None]:
# Save these weak labels to disk
for video_id in tqdm(face_position_labels_pos_neg):
    with open('/app/data/shot_detection_weak_labels/face_positions/{}.pkl'.format(video_id), 'wb') as f:
        pickle.dump(face_position_labels_pos_neg[video_id], f)