# Weak labels for timeseries model

To use Shiori's timeseries model for shot detection, we need to make the following changes to the label matrix:
* Each window of 16 frames is a single datapoint
* The windows are **not** overlapping, so the first window is frames `[1, 2, ..., 16]`, second is `[17, 18, ..., 32]`, etc
* Ground truth: our annotated shot boundaries are the first frames of new shots, so a window contains a shot boundary f only if frames f and f - 1 are in the window
* Frames are 0-indexed!

In [1]:
import numpy as np
from scipy.sparse import csr_matrix
import scipy.sparse as sparse
import pickle
import rekall
from rekall.video_interval_collection import VideoIntervalCollection
from rekall.interval_list import IntervalList
from rekall.temporal_predicates import *

# Load manually annotated data

In [2]:
with open('../../data/manually_annotated_shots.pkl', 'rb') as f:
    shots = VideoIntervalCollection(pickle.load(f))

In [3]:
with open('../../data/shot_detection_folds.pkl', 'rb') as f:
    shot_detection_folds = pickle.load(f)

In [4]:
clips = shots.dilate(1).coalesce().dilate(-1)

100%|██████████| 28/28 [00:00<00:00, 9384.73it/s]
100%|██████████| 28/28 [00:00<00:00, 34119.85it/s]


In [5]:
shot_boundaries = shots.map(
    lambda intrvl: (intrvl.start, intrvl.start, intrvl.payload)
).set_union(
    shots.map(lambda intrvl: (intrvl.end + 1, intrvl.end + 1, intrvl.payload))
).coalesce()

In [6]:
boundary_frames = {
    video_id: [
        intrvl.start
        for intrvl in shot_boundaries.get_intervallist(video_id).get_intervals()
    ]
    for video_id in shot_boundaries.get_allintervals()
}

In [7]:
video_ids = sorted(list(clips.get_allintervals().keys()))

In [8]:
frames_per_video = {
    video_id: sorted([
        f
        for interval in clips.get_intervallist(video_id).get_intervals()
        for f in range(interval.start, interval.end + 2)
    ])
    for video_id in video_ids
}

In [9]:
ground_truth = {
    video_id: [
        1 if f in boundary_frames[video_id] else 2
        for f in frames_per_video[video_id]
    ] 
    for video_id in video_ids
}

In [10]:
val_set = shot_detection_folds[2] + shot_detection_folds[3]

In [11]:
test_set = shot_detection_folds[0] + shot_detection_folds[1] + shot_detection_folds[4]

## Load label matrix with all the frames in it

In [12]:
with open('../../data/shot_detection_weak_labels/all_labels_high_pre.pkl', 'rb') as f:
    weak_labels_all_movies = pickle.load(f)

## Load videos and number of frames per video

In [13]:
with open('../../data/frame_counts.pkl', 'rb') as f:
    frame_counts = pickle.load(f)

In [14]:
video_ids_all = sorted(list(frame_counts.keys()))

In [15]:
video_ids_train = sorted(list(set(video_ids_all).difference(set(video_ids))))

## Construct windows for each video

In [16]:
# First, construct windows of 16 frames for each video
windows = VideoIntervalCollection({
    video_id: [
        (f, f + 16, video_id)
        for f in range(0, frame_counts[video_id] - 16, 16)
    ]
    for video_id in video_ids_all
})

# Get ground truth labels for windows

In [17]:
# Next, intersect the windows with ground truth and get ground truth labels for the windows
windows_intersecting_ground_truth = windows.filter_against(
    clips,
    predicate=overlaps()
).map(lambda intrvl: (intrvl.start, intrvl.end, 2))
windows_with_shot_boundaries = windows_intersecting_ground_truth.filter_against(
    shot_boundaries,
    predicate = lambda window, shot_boundary:
        shot_boundary.start - 1 >= window.start and shot_boundary.start <= window.end
).map(
    lambda intrvl: (intrvl.start, intrvl.end, 1)
)
windows_with_labels = windows_with_shot_boundaries.set_union(
    windows_intersecting_ground_truth
).coalesce(
    predicate = equal(),
    payload_merge_op = lambda p1, p2: min(p1, p2)
)

# Get weak labels for all windows

In [18]:
# Label windows with the weak labels in our labeling functions
def label_window(per_frame_weak_labels):
    if 1 in per_frame_weak_labels:
        return 1
    if len([l for l in per_frame_weak_labels if l == 2]) >= len(per_frame_weak_labels) / 2:
        return 2
    return 0

windows_with_weak_labels = windows.map(
    lambda window: (
        window.start,
        window.end,
        [
            label_window([
                lf[window.payload][f]
                for f in range(window.start, window.end)
            ])
            for lf in weak_labels_all_movies
        ]
    )
)

# Y_val, Y_test

In [19]:
Y_val = np.array([
    intrvl.payload
    for video_id in val_set
    for intrvl in windows_with_labels.get_intervallist(video_id).get_intervals()
])

In [20]:
Y_test = np.array([
    intrvl.payload
    for video_id in test_set
    for intrvl in windows_with_labels.get_intervallist(video_id).get_intervals()
])

In [21]:
Y_val[:10]

array([1, 2, 2, 2, 2, 2, 2, 1, 2, 2])

In [22]:
Y_test[:10]

array([1, 2, 1, 2, 2, 2, 2, 2, 1, 2])

In [23]:
Y_val.shape

(1886,)

In [24]:
Y_test.shape

(1778,)

# L_val, L_test

In [25]:
L_val = csr_matrix([
    intrvl.payload
    for video_id in val_set
    for intrvl in windows_with_weak_labels.filter_against(
        clips, predicate=overlaps(), working_window=1
    ).get_intervallist(video_id).get_intervals()
])

In [26]:
L_test = csr_matrix([
    intrvl.payload
    for video_id in test_set
    for intrvl in windows_with_weak_labels.filter_against(
        clips, predicate=overlaps(), working_window=1
    ).get_intervallist(video_id).get_intervals()
])

In [27]:
L_val[:10].todense()

matrix([[1, 1, 1, 0, 1],
        [2, 2, 2, 2, 1],
        [2, 2, 2, 0, 1],
        [2, 2, 2, 0, 0],
        [2, 2, 2, 0, 0],
        [2, 2, 2, 0, 0],
        [2, 2, 2, 0, 0],
        [1, 1, 1, 2, 2],
        [0, 2, 0, 2, 2],
        [0, 0, 0, 2, 2]])

In [28]:
L_test[:10].todense()

matrix([[1, 1, 1, 0, 0],
        [2, 2, 2, 2, 2],
        [2, 1, 1, 0, 0],
        [2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2],
        [2, 2, 2, 2, 1],
        [2, 2, 2, 2, 1],
        [2, 2, 2, 2, 1],
        [1, 1, 1, 0, 1],
        [2, 2, 2, 2, 2]])

In [29]:
L_val.shape

(1886, 5)

In [30]:
L_test.shape

(1778, 5)

# L_train 100 movies

In [31]:
# or load train split
with open('../../data/shot_detection_weak_labels/train_split_100.pkl', 'rb') as f:
    train_split_100 = pickle.load(f)

In [32]:
L_train_100 = csr_matrix([
    intrvl.payload
    for video_id in train_split_100
    for intrvl in windows_with_weak_labels.get_intervallist(
        video_id
    ).get_intervals()
])

# L_train all movies

In [33]:
L_train_all = csr_matrix([
    intrvl.payload
    for video_id in video_ids_train
    for intrvl in windows_with_weak_labels.get_intervallist(
        video_id
    ).get_intervals()
])

In [34]:
L_train_all.shape

(5879519, 5)

# Save them all to disk

In [35]:
with open('../../data/shot_detection_weak_labels/Y_val_windows_high_pre_downsampled.npy', 'wb') as f:
    np.save(f, Y_val)

In [36]:
with open('../../data/shot_detection_weak_labels/Y_test_windows_high_pre_downsampled.npy', 'wb') as f:
    np.save(f, Y_test)

In [37]:
with open('../../data/shot_detection_weak_labels/L_val_windows_high_pre_downsampled.npz', 'wb') as f:
    sparse.save_npz(f, L_val)

In [38]:
with open('../../data/shot_detection_weak_labels/L_test_windows_high_pre_downsampled.npz', 'wb') as f:
    sparse.save_npz(f, L_test)

In [39]:
with open('../../data/shot_detection_weak_labels/L_train_100_windows_high_pre_downsampled.npz', 'wb') as f:
    sparse.save_npz(f, L_train_100)

In [40]:
with open('../../data/shot_detection_weak_labels/L_train_all_windows_high_pre_downsampled.npz', 'wb') as f:
    sparse.save_npz(f, L_train_all)