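# Majority Vote Model

Combine the shot-detection weak labels from all labeling functions by majority vote, both per frame and per 16-frame window, and save the resulting per-class probabilities to disk.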

In [25]:
import numpy as np
from scipy.sparse import csr_matrix
import scipy.sparse as sparse
import pickle
import rekall
from rekall.video_interval_collection import VideoIntervalCollection
from rekall.interval_list import IntervalList
from rekall.temporal_predicates import *
from metal.label_model.baselines import MajorityLabelVoter

# Load manually annotated data

In [2]:
# Manually annotated shots, wrapped in a VideoIntervalCollection.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open('../../data/manually_annotated_shots.pkl', 'rb') as shots_file:
    annotated_shots_raw = pickle.load(shots_file)
shots = VideoIntervalCollection(annotated_shots_raw)

In [3]:
# Fold definitions for shot detection (not used by the cells visible in this notebook).
folds_path = '../../data/shot_detection_folds.pkl'
with open(folds_path, 'rb') as folds_file:
    shot_detection_folds = pickle.load(folds_file)

In [4]:
# Merge the annotated shots into contiguous clips: dilate by one frame,
# coalesce, then undo the dilation (presumably so abutting shots get joined).
clips = (
    shots
    .dilate(1)
    .coalesce()
    .dilate(-1)
)

100%|██████████| 28/28 [00:00<00:00, 8930.16it/s]
100%|██████████| 28/28 [00:00<00:00, 32451.09it/s]


In [5]:
# Boundary frames as zero-length intervals: the first frame of every shot and
# the frame just after every shot's last frame (end + 1), merged together.
shot_start_points = shots.map(
    lambda shot: (shot.start, shot.start, shot.payload)
)
shot_end_points = shots.map(
    lambda shot: (shot.end + 1, shot.end + 1, shot.payload)
)
shot_boundaries = shot_start_points.set_union(shot_end_points).coalesce()

In [6]:
# video_id -> list of boundary frame numbers (start of each boundary interval).
boundary_frames = {
    vid: [
        boundary.start
        for boundary in shot_boundaries.get_intervallist(vid).get_intervals()
    ]
    for vid in shot_boundaries.get_allintervals().keys()
}

In [7]:
# Sorted ids of the videos that have manual annotations.
video_ids = sorted(clips.get_allintervals())

In [8]:
# Every frame covered by the annotated clips, per video, in sorted order.
# The range runs to `end + 2` so that frame `end + 1` is included; that is
# where `shot_boundaries` places the boundary following a shot's last frame.
frames_per_video = {
    vid: sorted(
        frame
        for clip in clips.get_intervallist(vid).get_intervals()
        for frame in range(clip.start, clip.end + 2)
    )
    for vid in video_ids
}

In [9]:
def _frame_labels(frames, boundary_frame_list):
    """Label each frame 1 if it is a shot boundary, otherwise 2.

    Args:
        frames: iterable of frame numbers to label.
        boundary_frame_list: frame numbers that are shot boundaries.

    Returns:
        list of 1/2 labels, aligned with `frames`.
    """
    # Convert once to a set: list membership inside the loop made this
    # O(len(frames) * len(boundaries)) per video.
    boundary_set = set(boundary_frame_list)
    return [1 if frame in boundary_set else 2 for frame in frames]


# Per-frame ground truth for the annotated videos: 1 = shot boundary, 2 = not.
ground_truth = {
    video_id: _frame_labels(frames_per_video[video_id], boundary_frames[video_id])
    for video_id in video_ids
}

## Load label matrix with all the frames in it

In [10]:
# Weak labels from every labeling function: one entry per LF, each indexable
# by video id to give that LF's per-frame labels (see how they are indexed
# when building the windows and label matrices below).
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
weak_labels_path = '../../data/shot_detection_weak_labels/all_labels.pkl'
with open(weak_labels_path, 'rb') as labels_file:
    weak_labels_all_movies = pickle.load(labels_file)

## Load videos and number of frames per video

In [12]:
# Number of frames per video, keyed by video id.
frame_counts_path = '../../data/frame_counts.pkl'
with open(frame_counts_path, 'rb') as counts_file:
    frame_counts = pickle.load(counts_file)

In [13]:
# Every video id we have a frame count for, in sorted order.
video_ids_all = sorted(frame_counts)

In [14]:
# Videos without manual annotations, in sorted order.
video_ids_train = sorted(set(video_ids_all) - set(video_ids))

## Construct windows for each video

In [15]:
# First, construct sliding windows over every video (annotated or not).
# Each window is the interval (start, start + WINDOW_SIZE) with the video id as
# its payload; a new window starts every WINDOW_STRIDE frames.
# NOTE(review): the range stop `frame_counts - WINDOW_SIZE` is exclusive, so
# trailing window start(s) that would still fit inside the video are never
# generated. Confirm this is intended; changing it alters the saved windows.
WINDOW_SIZE = 16
WINDOW_STRIDE = 8
windows = VideoIntervalCollection({
    video_id: [
        (f, f + WINDOW_SIZE, video_id)
        for f in range(0, frame_counts[video_id] - WINDOW_SIZE, WINDOW_STRIDE)
    ]
    for video_id in video_ids_all
})

# Get ground truth labels for windows

In [16]:
# Next, derive ground-truth labels for the windows that overlap annotated clips:
#   - every window overlapping a clip gets label 2;
#   - the subset containing a shot boundary frame is relabelled 1;
#   - after the union, duplicate intervals (merged via the `equal()` predicate)
#     keep the smaller payload, so a boundary label (1) wins over 2.
def _window_contains_boundary(window, shot_boundary):
    """True if the boundary frame lies in the half-open range [window.start, window.end)."""
    return window.start <= shot_boundary.start < window.end


windows_intersecting_ground_truth = (
    windows
    .filter_against(clips, predicate=overlaps())
    .map(lambda window: (window.start, window.end, 2))
)
windows_with_shot_boundaries = (
    windows_intersecting_ground_truth
    .filter_against(shot_boundaries, predicate=_window_contains_boundary)
    .map(lambda window: (window.start, window.end, 1))
)
windows_with_labels = (
    windows_with_shot_boundaries
    .set_union(windows_intersecting_ground_truth)
    .coalesce(
        predicate=equal(),
        payload_merge_op=lambda left, right: min(left, right),
    )
)

# Get weak labels for all windows

In [17]:
# Label windows with the weak labels in our labeling functions
def label_window(per_frame_weak_labels):
    """Collapse one labeling function's per-frame votes into a single window vote.

    Args:
        per_frame_weak_labels: sized sequence of that function's labels for
            the frames of one window.

    Returns:
        1 if any frame is labelled 1; otherwise 2 if at least half of the
        frames are labelled 2; otherwise 0 (presumably "abstain", per the
        MeTaL label-matrix convention).
    """
    if any(vote == 1 for vote in per_frame_weak_labels):
        return 1
    twos = sum(1 for vote in per_frame_weak_labels if vote == 2)
    # twos >= len/2, written in integer arithmetic
    return 2 if 2 * twos >= len(per_frame_weak_labels) else 0

# For each window, compute one weak label per labeling function by collapsing
# that function's per-frame labels over the window's frames with label_window.
# The payload becomes a list with one entry per labeling function.
#
# Frame numbers are 1-indexed: the label for frame f is stored at index f - 1
# (cf. the `(video_id, f + 1)` frame numbering used when saving predictions).
# Frame 0 therefore has no label. Without the max(..., 1) guard, a window
# starting at 0 would read index -1, i.e. the label of the *last* frame of the
# video, and silently mix it into the first window of every video.
windows_with_weak_labels = windows.map(
    lambda window: (
        window.start,
        window.end,
        [
            label_window([
                lf[window.payload][f - 1]
                for f in range(max(window.start, 1), window.end)
            ])
            for lf in weak_labels_all_movies
        ]
    )
)

# Prepare the label matrices (L): one row per frame or window, one column per labeling function

In [18]:
# Frame-level label matrix: one row per frame (videos concatenated in sorted
# id order), one column per labeling function. It is built LF-major and then
# transposed; note SciPy returns the transpose of a CSR matrix in CSC format.
L_everything_frame = csr_matrix([
    [label for video_id in sorted(video_ids_all) for label in lf[video_id]]
    for lf in weak_labels_all_movies
]).T

In [23]:
# Window-level label matrix: one row per window (videos in sorted id order,
# windows in interval-list order), one column per labeling function.
L_everything_windows = csr_matrix([
    window.payload
    for vid in sorted(video_ids_all)
    for window in windows_with_weak_labels.get_intervallist(vid).get_intervals()
])

# Run majority vote

In [26]:
# Majority-vote baseline from MeTaL. The seed is fixed for reproducibility
# (presumably it drives random tie-breaking — confirm in the MeTaL docs).
mv = MajorityLabelVoter(seed=123)

In [27]:
# Per-class probabilities for every row of the frame-level label matrix
# (expected: one row per frame across all videos, in sorted video-id order).
frame_predictions_everything = mv.predict_proba(L_everything_frame)

In [34]:
# Per-class probabilities for every row of the window-level label matrix
# (expected: one row per window, in the same order as L_everything_windows).
window_predictions_everything = mv.predict_proba(L_everything_windows)

# Save per-frame labels to disk

In [35]:
# (video_id, frame_number) for every frame, 1-indexed, in the same video order
# used to build the rows of L_everything_frame.
video_frame_nums = [
    (vid, frame_number)
    for vid in sorted(video_ids_all)
    for frame_number in range(1, frame_counts[vid] + 1)
]

In [36]:
# Pair each (video_id, frame_number) with its predicted class distribution.
# zip() silently stops at the shorter input. If the label matrix did not have
# exactly one row per frame in `frame_counts`, predictions would be misaligned
# with frames (or dropped) without any error, so check the lengths first.
assert len(video_frame_nums) == len(frame_predictions_everything), (
    f"{len(video_frame_nums)} frames vs "
    f"{len(frame_predictions_everything)} predictions"
)
predictions_to_save = [
    (frame_info, prediction.tolist())
    for frame_info, prediction in zip(video_frame_nums, frame_predictions_everything)
]

In [38]:
# Each element is ((video_id, frame_number), [p_1, ..., p_k]). With two classes,
# both inner sequences have length 2, so NumPy builds a numeric array of shape
# (n_frames, 2, 2). NOTE(review): with k != 2 classes the pairs would be
# ragged and this conversion would fail on recent NumPy.
preds_np = np.asarray(predictions_to_save)

In [39]:
# Save per-frame predictions to disk (.npy; the path already carries the
# extension, so np.save writes to it unchanged).
frame_preds_path = '../../data/shot_detection_weak_labels/majority_vote_labels_all_frame.npy'
np.save(frame_preds_path, preds_np)

# Save per-window labels to disk

In [40]:
# (video_id, window_start, window_end) for every window, in exactly the order
# used to build the rows of L_everything_windows.
window_nums = [
    (vid, window.start, window.end)
    for vid in sorted(video_ids_all)
    for window in windows_with_weak_labels.get_intervallist(vid).get_intervals()
]

In [41]:
# Pair each window's identity with its predicted class distribution.
# zip() silently stops at the shorter input, so first check that window
# identities and prediction rows line up one-to-one.
assert len(window_nums) == len(window_predictions_everything), (
    f"{len(window_nums)} windows vs "
    f"{len(window_predictions_everything)} predictions"
)
predictions_to_save_windows = list(zip(window_nums, window_predictions_everything))

In [42]:
# Build an (n_windows, 2) object array explicitly: column 0 holds the
# (video_id, start, end) tuple and column 1 holds the prediction array.
# Calling np.array() on these ragged (3-tuple, k-array) pairs relied on NumPy's
# implicit object-dtype fallback, which NumPy >= 1.24 turns into a ValueError.
# With k == 3 classes it would also silently produce a different shape.
preds_np_windows = np.empty((len(predictions_to_save_windows), 2), dtype=object)
for row, (window_info, prediction) in enumerate(predictions_to_save_windows):
    preds_np_windows[row, 0] = window_info
    preds_np_windows[row, 1] = prediction

In [43]:
# Save per-window predictions to disk. This is an object array, so np.save
# pickles it; reading it back requires np.load(..., allow_pickle=True).
window_preds_path = '../../data/shot_detection_weak_labels/majority_vote_labels_all_windows.npy'
np.save(window_preds_path, preds_np_windows)

In [44]:
# Spot-check the first 10 saved (window, prediction) rows.
preds_np_windows[:10]

array([[(1, 0, 16), array([0., 1.])],
       [(1, 8, 24), array([0., 1.])],
       [(1, 16, 32), array([0., 1.])],
       [(1, 24, 40), array([0., 1.])],
       [(1, 32, 48), array([0., 1.])],
       [(1, 40, 56), array([1., 0.])],
       [(1, 48, 64), array([1., 0.])],
       [(1, 56, 72), array([1., 0.])],
       [(1, 64, 80), array([0., 1.])],
       [(1, 72, 88), array([0., 1.])]], dtype=object)