# Majority Vote Model

In [1]:
import numpy as np
from scipy.sparse import csr_matrix
import scipy.sparse as sparse
import pickle
import rekall
from rekall.video_interval_collection import VideoIntervalCollection
from rekall.interval_list import IntervalList
from rekall.temporal_predicates import *
from metal.label_model.baselines import MajorityLabelVoter
from metal.metrics import metric_score

# Load manually annotated data

In [2]:
# Read the pickled manual shot annotations and wrap them in a
# VideoIntervalCollection.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('../../data/manually_annotated_shots.pkl', 'rb') as shots_file:
    raw_shots = pickle.load(shots_file)
shots = VideoIntervalCollection(raw_shots)

In [3]:
# Unpickle the shot-detection folds.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('../../data/shot_detection_folds.pkl', 'rb') as folds_file:
    shot_detection_folds = pickle.load(folds_file)

In [4]:
# Derive clips from the annotated shots: dilate by 1, coalesce, then dilate
# by -1 (presumably this merges shots that touch or nearly touch into one
# contiguous clip -- TODO confirm against rekall's dilate/coalesce semantics).
dilated_shots = shots.dilate(1)
merged_shots = dilated_shots.coalesce()
clips = merged_shots.dilate(-1)

100%|██████████| 28/28 [00:00<00:00, 8632.16it/s]
100%|██████████| 28/28 [00:00<00:00, 30800.03it/s]


In [5]:
# Build single-frame intervals for every shot boundary: one at each shot's
# first frame and one at the frame right after each shot's last frame, then
# take their union and coalesce it.
shot_starts = shots.map(
    lambda shot: (shot.start, shot.start, shot.payload)
)
shot_ends = shots.map(
    lambda shot: (shot.end + 1, shot.end + 1, shot.payload)
)
shot_boundaries = shot_starts.set_union(shot_ends).coalesce()

In [6]:
# Map each video id to the list of start frames of its boundary intervals.
boundary_frames = {
    vid: [
        boundary.start
        for boundary in shot_boundaries.get_intervallist(vid).get_intervals()
    ]
    for vid in shot_boundaries.get_allintervals()
}

In [7]:
# Sorted ids of the videos that have at least one clip.
video_ids = sorted(clips.get_allintervals().keys())

In [8]:
# For every video, list (in ascending order) each frame covered by a clip.
# The range runs through clip.end + 1 inclusive, so the frame just past the
# end of each clip is included as well.
frames_per_video = {
    vid: sorted(
        frame
        for clip in clips.get_intervallist(vid).get_intervals()
        for frame in range(clip.start, clip.end + 2)
    )
    for vid in video_ids
}

In [9]:
# Per-video ground-truth labels, aligned index-for-index with
# frames_per_video: 1 if the frame is a shot boundary, 2 otherwise.
#
# boundary_frames holds plain lists, so testing membership directly would scan
# the whole list for every frame. Convert each list to a set once so each
# lookup is a constant-time hash probe; the resulting labels are unchanged.
boundary_frame_sets = {
    video_id: set(boundary_frames[video_id])
    for video_id in video_ids
}
ground_truth = {
    video_id: [
        1 if f in boundary_frame_sets[video_id] else 2
        for f in frames_per_video[video_id]
    ]
    for video_id in video_ids
}

## Load label matrix with all the frames in it

In [10]:
# Unpickle the weak labels covering all movies.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('../../data/shot_detection_weak_labels/all_labels.pkl', 'rb') as labels_file:
    weak_labels_all_movies = pickle.load(labels_file)

## Load videos and number of frames per video

In [11]:
# Unpickle the frame counts (keyed by video id, judging by how the keys are
# used in the following cell).
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('../../data/frame_counts.pkl', 'rb') as counts_file:
    frame_counts = pickle.load(counts_file)

In [12]:
# Sorted ids of every video that has a frame count.
video_ids_all = sorted(frame_counts.keys())

In [13]:
# Training videos are all videos minus the manually annotated ones.
video_ids_train = sorted(set(video_ids_all) - set(video_ids))

# Load Y_val, Y_test

In [15]:
# Validation labels; np.load accepts the path directly and handles the file
# open/close itself.
Y_val = np.load('../../data/shot_detection_weak_labels/Y_val_windows_downsampled_same_val_test.npy')

In [16]:
# Test labels; np.load accepts the path directly and handles the file
# open/close itself.
Y_test = np.load('../../data/shot_detection_weak_labels/Y_test_windows_downsampled_same_val_test.npy')

In [17]:
# Sanity check: first ten labels of each split, then the shape of each split.
print(Y_val[:10], Y_test[:10], sep='\n')
print(Y_val.shape, Y_test.shape, sep='\n')

[1 2 2 2 2 1 2 2 2 2]
[2 2 2 2 2 2 1 2 2 1]
(1817,)
(1846,)


# Load L_val, L_test

In [18]:
# Validation label matrix (sparse); load_npz accepts the path directly.
L_val = sparse.load_npz('../../data/shot_detection_weak_labels/L_val_windows_downsampled_same_val_test.npz')

In [19]:
# Test label matrix (sparse); load_npz accepts the path directly.
L_test = sparse.load_npz('../../data/shot_detection_weak_labels/L_test_windows_downsampled_same_val_test.npz')

In [20]:
# Preview the first ten rows of the validation label matrix in dense form.
# NOTE(review): entries are 0/1/2; 1 and 2 match the Y labels, 0 presumably
# means the labeling source abstained -- TODO confirm.
L_val[:10].todense()

matrix([[1, 1, 1, 2, 2],
        [2, 2, 1, 0, 0],
        [2, 2, 2, 0, 0],
        [2, 2, 2, 0, 0],
        [2, 2, 2, 0, 0],
        [1, 1, 1, 0, 0],
        [2, 2, 2, 0, 0],
        [2, 2, 1, 0, 0],
        [2, 2, 2, 0, 0],
        [2, 2, 2, 0, 0]])

In [21]:
# Preview the first ten rows of the test label matrix in dense form.
# NOTE(review): entries are 0/1/2; 1 and 2 match the Y labels, 0 presumably
# means the labeling source abstained -- TODO confirm.
L_test[:10].todense()

matrix([[2, 2, 2, 2, 1],
        [1, 2, 1, 0, 0],
        [2, 2, 1, 0, 0],
        [2, 2, 1, 0, 0],
        [1, 1, 1, 0, 0],
        [2, 2, 1, 0, 0],
        [1, 1, 1, 0, 1],
        [2, 2, 2, 2, 1],
        [2, 2, 2, 0, 1],
        [2, 1, 1, 0, 0]])

In [22]:
# Shape of the validation label matrix; the row count should match Y_val.
L_val.shape

(1817, 5)

In [23]:
# Shape of the test label matrix; the row count should match Y_test.
L_test.shape

(1846, 5)

# Performance of Majority Vote on Validation and Test Sets

In [24]:
# Majority-vote baseline; the seed is fixed so repeated runs give the same
# predictions.
MV_SEED = 123
mv = MajorityLabelVoter(seed=MV_SEED)

In [25]:
# Majority-vote predictions for the validation label matrix.
window_predictions_val = mv.predict(L_val)

In [26]:
# Report each validation metric of the majority-vote predictions against
# Y_val, formatted to three decimals.
for metric_name in ('accuracy', 'f1', 'recall', 'precision'):
    value = metric_score(Y_val, window_predictions_val, metric_name)
    print(f"{metric_name.capitalize()}: {value:.3f}")

Accuracy: 0.948
F1: 0.872
Recall: 0.933
Precision: 0.819


In [27]:
# Majority-vote predictions for the test label matrix.
window_predictions_test = mv.predict(L_test)

In [28]:
# Report each test metric of the majority-vote predictions against Y_test,
# formatted to three decimals.
for metric_name in ('accuracy', 'f1', 'recall', 'precision'):
    value = metric_score(Y_test, window_predictions_test, metric_name)
    print(f"{metric_name.capitalize()}: {value:.3f}")

Accuracy: 0.952
F1: 0.862
Recall: 0.924
Precision: 0.809
