In [1]:
import os
import sys

# Make a local MeTaL checkout importable. The location defaults to the original
# author's path; set the METAL_PATH environment variable to point at another
# checkout instead of editing this cell.
METAL_PATH = os.environ.get('METAL_PATH', '/lfs/1/danfu/metal')
sys.path.append(METAL_PATH)
import metal

In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
import pickle
import rekall
from rekall.video_interval_collection import VideoIntervalCollection
from rekall.interval_list import IntervalList
from rekall.temporal_predicates import *
import numpy as np
from scipy.sparse import csr_matrix
import os
from tqdm import tqdm
import random

from metal.analysis import lf_summary
from metal.label_model.baselines import MajorityLabelVoter
from metal.label_model import LabelModel

# Load Shot Data

In [4]:
# Pickled shot-detection folds.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('../../data/shot_detection_folds.pkl', mode='rb') as folds_file:
    shot_detection_folds = pickle.load(folds_file)

In [5]:
# Hand-annotated shots, wrapped in a rekall VideoIntervalCollection.
with open('../../data/manually_annotated_shots.pkl', mode='rb') as shots_file:
    annotated_shots_raw = pickle.load(shots_file)
shots = VideoIntervalCollection(annotated_shots_raw)

In [6]:
# Group shots into clips: grow every shot by one frame, merge what now
# overlaps, then shrink back by one frame. Shots that sit (nearly) next to each
# other therefore land in the same clip; the exact gap tolerance depends on
# rekall's coalesce semantics.
clips = (
    shots
    .dilate(1)
    .coalesce()
    .dilate(-1)
)

100%|██████████| 28/28 [00:00<00:00, 13469.49it/s]
100%|██████████| 28/28 [00:00<00:00, 48629.61it/s]


In [7]:
# Shot boundaries as single-frame intervals: each shot's first frame plus the
# frame right after its last frame. The final coalesce() should merge points
# that coincide (e.g. one shot's end + 1 equal to the next shot's start).
shot_start_points = shots.map(lambda intrvl: (intrvl.start, intrvl.start, intrvl.payload))
post_shot_points = shots.map(lambda intrvl: (intrvl.end + 1, intrvl.end + 1, intrvl.payload))
shot_boundaries = shot_start_points.set_union(post_shot_points).coalesce()

In [8]:
# Map each video id to the start frame of each of its boundary point-intervals.
boundary_frames = dict(
    (vid, [point.start for point in shot_boundaries.get_intervallist(vid).get_intervals()])
    for vid in shot_boundaries.get_allintervals()
)

In [9]:
# Sorted ids of the videos in `clips` (sorted() over a dict iterates its keys).
video_ids = sorted(clips.get_allintervals())

In [10]:
# Every frame covered by each video's clips, in ascending order. A clip adds
# frames clip.start .. clip.end + 1 (range's stop is exclusive, hence + 2), so
# the frame just after a clip, where an `end + 1` shot boundary falls, is
# included as well.
frames_per_video = {}
for vid in video_ids:
    clip_frames = []
    for clip in clips.get_intervallist(vid).get_intervals():
        clip_frames.extend(range(clip.start, clip.end + 2))
    frames_per_video[vid] = sorted(clip_frames)

In [11]:
# Per-frame ground truth for every video: 1 if the frame is an annotated shot
# boundary, 2 otherwise (presumably MeTaL's 1..k class convention, with
# class 1 = boundary).
ground_truth = {}
for vid in video_ids:
    # Set membership is O(1); testing `in` against the boundary *list* made
    # this O(frames * boundaries) per video.
    boundary_set = set(boundary_frames[vid])
    ground_truth[vid] = [1 if frame in boundary_set else 2 for frame in frames_per_video[vid]]

# Load Weak Labels

In [12]:
# Frame count for each video id.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('../../data/frame_counts.pkl', mode='rb') as counts_file:
    frame_counts = pickle.load(counts_file)

In [26]:
# Directory with one '<video_id>.pkl' per video, each holding a pair of
# (positive boundaries, negative boundaries) from the cinematic_shots labeler.
labeling_function_folder = '/'.join(['../../data/shot_detection_weak_labels', 'cinematic_shots'])

In [31]:
# Sliding windows over every video: WINDOW_SIZE frames long, advancing by
# WINDOW_STRIDE frames, payload 0. The weak-label loop treats each window as
# half-open [start, start + WINDOW_SIZE) when matching boundary frames.
WINDOW_SIZE = 16
WINDOW_STRIDE = 8

# NOTE(review): range() stops *before* frame_count - WINDOW_SIZE, so the last
# few frames of a video may not fall in any window; confirm this is intended.
windows = VideoIntervalCollection({
    video_id: [
        (start, start + WINDOW_SIZE, 0)
        for start in range(0, frame_counts[video_id] - WINDOW_SIZE, WINDOW_STRIDE)
    ]
    for video_id in frame_counts
})

In [63]:
# Turn each video's labeling-function output into weak labels at two granularities:
#   preds_frames:  ((video_id, frame), label) for each interval of frames_with_labels
#   preds_windows: ((video_id, start, end), label) for each window of `windows`
# A label is (1.0, 0.0) for a positive (boundary) and (0.0, 1.0) for a negative;
# presumably (P(boundary), P(not boundary)), matching the ground_truth ordering
# where class 1 = boundary and class 2 = not a boundary.
preds_frames = []
preds_windows = []
for video_id in tqdm(sorted(list(frame_counts.keys()))):
    # Each pickle unpacks into (positive boundaries, negative boundaries).
    # NOTE(review): pickle.load can execute arbitrary code; only use trusted label files.
    # neg_boundaries is not used anywhere below.
    with open(os.path.join(labeling_function_folder, '{}.pkl'.format(video_id)), 'rb') as f:
        pos_boundaries, neg_boundaries = pickle.load(f)
    # Positive boundaries as single-frame intervals with payload 1.
    pos_intervallist = IntervalList([
        (b, b, 1) for b in pos_boundaries
    ])
    # Every frame 1..frame_count as a single-frame interval with payload 0.
    all_frames = IntervalList([
        (f, f, 0) for f in range(1, frame_counts[video_id] + 1)
    ])
    # Merge only identical intervals, keeping the larger payload, so a frame that is
    # also a positive boundary should end up as one interval with payload 1.
    # Positive boundaries outside 1..frame_count would survive as extra entries.
    frames_with_labels = all_frames.set_union(pos_intervallist).coalesce(
        predicate = equal(),
        payload_merge_op = lambda p1, p2: max(p1, p2)
    )
    
    for f in frames_with_labels.get_intervals():
        preds_frames.append(((video_id, f.start), (0.0, 1.0) if f.payload == 0 else (1.0, 0.0)))
    
    # Windows that contain at least one positive boundary frame, using half-open
    # containment (window.start <= frame < window.end). Each match re-emits the window
    # with payload 1, so a window may appear once per contained boundary.
    # NOTE(review): working_window=1 limits which boundary intervals rekall compares
    # against each window; confirm against rekall's join docs that it cannot skip
    # boundaries that lie inside a window.
    window_intervallist = windows.get_intervallist(video_id)
    windows_with_boundaries = window_intervallist.join(
        pos_intervallist,
        predicate = lambda window, frame: frame.start >= window.start and frame.end < window.end,
        merge_op = lambda window, frame: [(window.start, window.end, 1)],
        working_window = 1
    )
    
    # Add back every window (payload 0) and merge identical windows, keeping the max
    # payload: each window should end with 1 if it holds a positive boundary, else 0.
    windows_with_labels = windows_with_boundaries.set_union(
        window_intervallist
    ).coalesce(
        predicate = equal(),
        payload_merge_op = lambda p1, p2: max(p1, p2)
    )
    # Windows starting at frame 0 are always labeled negative, whatever their payload.
    # NOTE(review): the reason for this override is not visible here; confirm intended.
    for window in windows_with_labels.get_intervals():
        preds_windows.append(((video_id, window.start, window.end),
                              (0.0, 1.0) if window.payload == 0 or window.start == 0 else (1.0, 0.0)))


  0%|          | 0/589 [00:00<?, ?it/s][A
  0%|          | 1/589 [00:00<09:07,  1.07it/s][A
  0%|          | 2/589 [00:01<08:12,  1.19it/s][A
  1%|          | 3/589 [00:02<08:12,  1.19it/s][A
  1%|          | 4/589 [00:03<08:04,  1.21it/s][A
  1%|          | 5/589 [00:03<07:26,  1.31it/s][A
  1%|          | 6/589 [00:04<06:45,  1.44it/s][A
  1%|          | 7/589 [00:04<06:09,  1.57it/s][A
  1%|▏         | 8/589 [00:05<05:53,  1.64it/s][A
  2%|▏         | 9/589 [00:07<09:11,  1.05it/s][A
  2%|▏         | 10/589 [00:07<07:48,  1.23it/s][A
  2%|▏         | 11/589 [00:08<07:44,  1.24it/s][A
  2%|▏         | 12/589 [00:09<07:58,  1.21it/s][A
  2%|▏         | 13/589 [00:10<08:00,  1.20it/s][A
  2%|▏         | 14/589 [00:10<06:08,  1.56it/s][A
  3%|▎         | 15/589 [00:11<06:11,  1.54it/s][A
  3%|▎         | 16/589 [00:11<06:10,  1.55it/s][A
  3%|▎         | 17/589 [00:12<06:21,  1.50it/s][A
  3%|▎         | 18/589 [00:14<09:53,  1.04s/it][A
  3%|▎         | 19/589 [00:1

 26%|██▋       | 156/589 [02:11<04:57,  1.46it/s][A
 27%|██▋       | 157/589 [02:11<04:28,  1.61it/s][A
 27%|██▋       | 158/589 [02:12<04:27,  1.61it/s][A
 27%|██▋       | 159/589 [02:12<04:41,  1.53it/s][A
 27%|██▋       | 160/589 [02:15<08:08,  1.14s/it][A
 27%|██▋       | 161/589 [02:16<07:51,  1.10s/it][A
 28%|██▊       | 162/589 [02:17<07:56,  1.12s/it][A
 28%|██▊       | 163/589 [02:17<06:56,  1.02it/s][A
 28%|██▊       | 164/589 [02:18<06:06,  1.16it/s][A
 28%|██▊       | 165/589 [02:19<05:25,  1.30it/s][A
 28%|██▊       | 166/589 [02:19<05:29,  1.28it/s][A
 28%|██▊       | 167/589 [02:20<05:14,  1.34it/s][A
 29%|██▊       | 168/589 [02:23<09:12,  1.31s/it][A
 29%|██▊       | 169/589 [02:24<08:19,  1.19s/it][A
 29%|██▉       | 170/589 [02:24<07:21,  1.05s/it][A
 29%|██▉       | 171/589 [02:25<06:31,  1.07it/s][A
 29%|██▉       | 172/589 [02:26<05:49,  1.19it/s][A
 29%|██▉       | 173/589 [02:26<05:23,  1.29it/s][A
 30%|██▉       | 174/589 [02:27<05:30,  1.26it

 53%|█████▎    | 310/589 [04:52<06:56,  1.49s/it][A
 53%|█████▎    | 311/589 [04:52<05:29,  1.19s/it][A
 53%|█████▎    | 312/589 [04:53<04:26,  1.04it/s][A
 53%|█████▎    | 313/589 [04:53<03:43,  1.24it/s][A
 53%|█████▎    | 314/589 [04:54<03:29,  1.31it/s][A
 53%|█████▎    | 315/589 [04:54<03:19,  1.37it/s][A
 54%|█████▎    | 316/589 [04:55<03:47,  1.20it/s][A
 54%|█████▍    | 317/589 [04:56<03:43,  1.22it/s][A
 54%|█████▍    | 318/589 [04:57<03:45,  1.20it/s][A
 54%|█████▍    | 319/589 [04:58<03:29,  1.29it/s][A
 54%|█████▍    | 320/589 [05:01<06:47,  1.52s/it][A
 54%|█████▍    | 321/589 [05:02<05:40,  1.27s/it][A
 55%|█████▍    | 322/589 [05:02<04:47,  1.08s/it][A
 55%|█████▍    | 323/589 [05:03<04:01,  1.10it/s][A
 55%|█████▌    | 324/589 [05:04<03:45,  1.18it/s][A
 55%|█████▌    | 325/589 [05:04<03:38,  1.21it/s][A
 55%|█████▌    | 326/589 [05:05<03:20,  1.31it/s][A
 56%|█████▌    | 327/589 [05:06<03:35,  1.22it/s][A
 56%|█████▌    | 328/589 [05:07<03:22,  1.29it

 79%|███████▉  | 464/589 [07:37<01:39,  1.26it/s][A
 79%|███████▉  | 465/589 [07:38<01:38,  1.25it/s][A
 79%|███████▉  | 466/589 [07:38<01:21,  1.51it/s][A
 79%|███████▉  | 467/589 [07:39<01:24,  1.44it/s][A
 79%|███████▉  | 468/589 [07:39<01:25,  1.41it/s][A
 80%|███████▉  | 469/589 [07:44<03:51,  1.93s/it][A
 80%|███████▉  | 470/589 [07:45<03:15,  1.64s/it][A
 80%|███████▉  | 471/589 [07:46<02:36,  1.32s/it][A
 80%|████████  | 472/589 [07:47<02:21,  1.21s/it][A
 80%|████████  | 473/589 [07:48<02:11,  1.13s/it][A
 80%|████████  | 474/589 [07:48<01:54,  1.00it/s][A
 81%|████████  | 475/589 [07:49<01:53,  1.00it/s][A
 81%|████████  | 476/589 [07:50<01:44,  1.08it/s][A
 81%|████████  | 477/589 [07:51<01:35,  1.17it/s][A
 81%|████████  | 478/589 [07:55<03:26,  1.86s/it][A
 81%|████████▏ | 479/589 [07:56<02:50,  1.55s/it][A
 81%|████████▏ | 480/589 [07:57<02:22,  1.30s/it][A
 82%|████████▏ | 481/589 [07:57<02:04,  1.16s/it][A
 82%|████████▏ | 482/589 [07:58<01:56,  1.09s/

In [64]:
# NOTE(review): with numeric video ids every entry is ((id, frame), (p, q)), so
# numpy infers a regular float array of shape (N, 2, 2), not an object array.
preds_np_frames = np.asarray(preds_frames)

In [65]:
# Persist the per-frame weak labels. np.save opens the path in binary write
# mode itself and keeps a filename that already ends in '.npy' unchanged.
np.save('../../data/shot_detection_weak_labels/noisy_labels_heuristics_frame.npy', preds_np_frames)

In [66]:
# Each entry is ((video_id, start, end), (p, q)): a 3-tuple next to a 2-tuple,
# i.e. a ragged nested sequence. NumPy >= 1.24 raises ValueError for ragged input
# unless dtype=object is requested explicitly; older NumPy produced this same
# (N, 2) object array of tuples, only with a VisibleDeprecationWarning.
preds_np_windows = np.array(preds_windows, dtype=object)

In [67]:
# Persist the per-window weak labels. If preds_np_windows is an object array,
# reading it back requires np.load(..., allow_pickle=True).
np.save('../../data/shot_detection_weak_labels/noisy_labels_heuristics_windows.npy', preds_np_windows)