In [None]:
import sys
sys.path.append('/lfs/1/danfu/metal')
import metal

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import pickle
import rekall
from rekall.video_interval_collection import VideoIntervalCollection
from rekall.interval_list import IntervalList
from rekall.temporal_predicates import *
import numpy as np
import matplotlib.pyplot as plt
from scipy.sparse import csr_matrix
import os
from tqdm import tqdm
import random

from metal.analysis import lf_summary
from metal.label_model.baselines import MajorityLabelVoter
from metal.label_model import LabelModel

# Load Shot Data

In [None]:
# Cross-validation folds. Later cells index this as shot_detection_folds[i] for
# i in 0..4 and test video ids for membership in each fold.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('../../data/shot_detection_folds.pkl', 'rb') as f:
    shot_detection_folds = pickle.load(f)

In [None]:
# Manually annotated shots, wrapped in a rekall VideoIntervalCollection.
# These are the source of all ground-truth labels in this notebook.
with open('../../data/manually_annotated_shots.pkl', 'rb') as f:
    shots = VideoIntervalCollection(pickle.load(f))

In [None]:
# Grow every shot by one frame, merge the shots that now touch, then shrink back,
# so runs of back-to-back annotated shots become one contiguous clip.
# NOTE(review): relies on rekall's dilate/coalesce semantics — confirm.
clips = shots.dilate(1).coalesce().dilate(-1)

In [None]:
# Each shot contributes two single-frame boundary intervals: its first frame and
# the frame right after its last frame. The union is coalesced afterwards.
shot_start_frames = shots.map(
    lambda shot: (shot.start, shot.start, shot.payload)
)
frames_after_shot_end = shots.map(
    lambda shot: (shot.end + 1, shot.end + 1, shot.payload)
)
shot_boundaries = shot_start_frames.set_union(frames_after_shot_end).coalesce()

In [None]:
# video_id -> list of the start frame of every shot-boundary interval.
boundary_frames = {}
for video_id in shot_boundaries.get_allintervals():
    boundary_frames[video_id] = [
        boundary.start
        for boundary in shot_boundaries.get_intervallist(video_id).get_intervals()
    ]

In [None]:
# Ids of the videos that have ground-truth clips, in ascending order.
video_ids = sorted(clips.get_allintervals().keys())

In [None]:
# video_id -> sorted list of every frame covered by a ground-truth clip.
# The range runs to clip.end + 1 inclusive, so the frame right after each clip
# (where the closing shot boundary sits) is part of the evaluated frames.
frames_per_video = {}
for video_id in video_ids:
    clip_frames = []
    for clip in clips.get_intervallist(video_id).get_intervals():
        clip_frames.extend(range(clip.start, clip.end + 2))
    clip_frames.sort()
    frames_per_video[video_id] = clip_frames

In [None]:
# video_id -> per-frame ground-truth label, aligned with frames_per_video:
# 1 for a shot-boundary frame, 2 for any other frame.
ground_truth = {}
for video_id in video_ids:
    # Set lookup keeps this linear in the number of frames; the list in
    # boundary_frames would be scanned once per frame otherwise.
    boundary_set = set(boundary_frames[video_id])
    ground_truth[video_id] = [
        1 if f in boundary_set else 2
        for f in frames_per_video[video_id]
    ]

# Load Weak Labels

In [None]:
# video_id -> number of frames in the video (used as frame_counts[video_id] below).
with open('../../data/frame_counts.pkl', 'rb') as f:
    frame_counts = pickle.load(f)

In [None]:
# One folder per labeling function; each holds one '<video_id>.pkl' per video.
# The order here fixes the column order of every label matrix built below and
# must match the lf_names lists passed to lf_summary.
labeling_function_folders = [
    '../../data/shot_detection_weak_labels/rgb_hists',
    '../../data/shot_detection_weak_labels/hsv_hists',
#     '../../data/shot_detection_weak_labels/flow_hists_magnitude', # disabled: this labeling function performs very poorly
    '../../data/shot_detection_weak_labels/flow_hists_diffs',
    '../../data/shot_detection_weak_labels/face_counts',
    '../../data/shot_detection_weak_labels/face_positions'
]

In [None]:
weak_labels_all = []
weak_labels_gt_only = []

In [None]:
# Build per-frame weak labels for every labeling function.
#   weak_labels_all[k][video_id]     -> label of every frame 1..frame_count
#   weak_labels_gt_only[k][video_id] -> labels of the frames in frames_per_video
# Payloads: 0 = no vote, 1 = positive boundary vote, 2 = negative vote.
#
# Start from empty accumulators so that re-running this cell does not append a
# second copy of every labeling function.
weak_labels_all = []
weak_labels_gt_only = []

for folder in labeling_function_folders:
    labels_for_function_all = {}
    labels_for_function_gt_only = {}
    for video_id in tqdm(video_ids):
        n_frames = frame_counts[video_id]
        # Every frame 1..n_frames starts out with payload 0.
        all_frames = IntervalList([
            (frame, frame, 0)
            for frame in range(1, n_frames + 1)
        ])
        with open(os.path.join(folder, '{}.pkl'.format(video_id)), 'rb') as f:
            positive_boundaries, negative_boundaries = pickle.load(f)

        # Votes past the end of the video are dropped.
        positive_frames = IntervalList([
            (frame, frame, 1)
            for frame in positive_boundaries if frame <= n_frames
        ])
        negative_frames = IntervalList([
            (frame, frame, 2)
            for frame in negative_boundaries if frame <= n_frames
        ])
        # max() lets any vote override the 0 default; a frame carrying both a
        # positive and a negative vote ends up as 2.
        frames_w_labels = all_frames.set_union(
            positive_frames
        ).set_union(
            negative_frames
        ).coalesce(payload_merge_op = lambda p1, p2: max(p1, p2))

        # Fetch the interval list once instead of once per ground-truth frame.
        labeled_intervals = frames_w_labels.get_intervals()

        labels_for_function_all[video_id] = [
            intrvl.payload
            for intrvl in labeled_intervals
        ]

        # Frame f is stored at index f - 1.
        labels_for_function_gt_only[video_id] = [
            labeled_intervals[frame - 1].payload
            for frame in frames_per_video[video_id]
        ]

    weak_labels_all.append(labels_for_function_all)
    weak_labels_gt_only.append(labels_for_function_gt_only)

In [None]:
# Flatten the per-video ground-truth labels into one vector (videos in video_ids order).
gt_labels_flat = []
for vid in video_ids:
    gt_labels_flat.extend(ground_truth[vid])
Y = np.array(gt_labels_flat)

In [None]:
Y.shape

In [None]:
# One row per labeling function, one column per ground-truth frame; the transpose
# gives one row per frame and one column per labeling function.
lf_rows_gt = [
    [frame_label for vid in video_ids for frame_label in lf_labels[vid]]
    for lf_labels in weak_labels_gt_only
]
L = csr_matrix(lf_rows_gt).transpose()

In [None]:
L.shape

In [None]:
# Per-labeling-function summary against ground truth. The names follow the order of
# labeling_function_folders; 'flow hist' is the flow_hists_diffs folder.
lf_summary(L, Y=Y, lf_names = ['RGB hist', 'HSV hist', 'flow hist', 'face counts', 'face positions'])

# Train Label Model

## Part 0: Majority Vote

In [None]:
# 5-fold evaluation of the majority-vote baseline on ground-truth frames.
for fold_idx in range(5):
    test_fold = shot_detection_folds[fold_idx]
    train_videos = [vid for vid in video_ids if vid not in test_fold]
    test_videos = [vid for vid in video_ids if vid in test_fold]

    Y_test = np.array([
        frame_label
        for vid in test_videos
        for frame_label in ground_truth[vid]
    ])
    lf_rows_test = [
        [frame_label for vid in test_videos for frame_label in lf_labels[vid]]
        for lf_labels in weak_labels_gt_only
    ]
    L_test = csr_matrix(lf_rows_test).transpose()

    mv = MajorityLabelVoter(seed=123)
    scores = mv.score((L_test, Y_test), metric=['accuracy','precision', 'recall', 'f1'])

```
Accuracy: 0.988
Precision: 0.475
Recall: 0.970
F1: 0.637
        y=1    y=2   
 l=1    131    145   
 l=2     4    12392  
 
Accuracy: 0.991
Precision: 0.760
Recall: 0.941
F1: 0.841
        y=1    y=2   
 l=1    95     30    
 l=2     6    3996   
 
Accuracy: 0.995
Precision: 0.639
Recall: 0.994
F1: 0.778
        y=1    y=2   
 l=1    172    97    
 l=2     1    21015  
 
Accuracy: 0.986
Precision: 0.452
Recall: 0.980
F1: 0.619
        y=1    y=2   
 l=1    99     120   
 l=2     2    8586   
 
Accuracy: 0.992
Precision: 0.621
Recall: 0.937
F1: 0.747
        y=1    y=2   
 l=1    133    81    
 l=2     9    11039  
 
Average F1: .724
```

## Part 1: Train only on frames that we have gold labels for

In [None]:
# 5-fold evaluation of a LabelModel trained only on frames that have gold labels.
for fold_idx in range(5):
    test_fold = shot_detection_folds[fold_idx]
    train_videos = [vid for vid in video_ids if vid not in test_fold]
    test_videos = [vid for vid in video_ids if vid in test_fold]

    # Training matrix: weak labels of the ground-truth frames of the training videos.
    lf_rows_train = [
        [frame_label for vid in train_videos for frame_label in lf_labels[vid]]
        for lf_labels in weak_labels_gt_only
    ]
    L_train = csr_matrix(lf_rows_train).transpose()

    Y_test = np.array([
        frame_label
        for vid in test_videos
        for frame_label in ground_truth[vid]
    ])
    lf_rows_test = [
        [frame_label for vid in test_videos for frame_label in lf_labels[vid]]
        for lf_labels in weak_labels_gt_only
    ]
    L_test = csr_matrix(lf_rows_test).transpose()

    label_model = LabelModel(k=2, seed=123)
    label_model.train_model(L_train, class_balance=(0.01, 0.99), n_epochs=500, log_train_every=50)
    label_model.score((L_test, Y_test), metric=['accuracy','precision', 'recall', 'f1'])

Results per fold:

```
Accuracy: 0.997
Precision: 0.887
Recall: 0.874
F1: 0.881
        y=1    y=2   
 l=1    118    15    
 l=2    17    12522  

Accuracy: 0.990
Precision: 0.984
Recall: 0.594
F1: 0.741
        y=1    y=2   
 l=1    60      1    
 l=2    41    4025   

Accuracy: 0.999
Precision: 0.953
Recall: 0.942
F1: 0.948
        y=1    y=2   
 l=1    163     8    
 l=2    10    21104  

Accuracy: 0.996
Precision: 0.802
Recall: 0.842
F1: 0.821
        y=1    y=2   
 l=1    85     21    
 l=2    16    8685   

Accuracy: 0.996
Precision: 0.946
Recall: 0.746
F1: 0.835
        y=1    y=2   
 l=1    106     6    
 l=2    36    11114  
 
Average F1: .845
```

## Part 2: Train on entire movies

In [None]:
prediction_probabilities = []

In [None]:
# 5-fold evaluation of a LabelModel trained on every frame of the training videos
# and tested on the ground-truth frames of the held-out videos. Misclassified
# frames of each fold are pickled for failure analysis.
#
# Reset the accumulator so re-running this cell does not stack a second set of folds.
prediction_probabilities = []

for i in range(5):
    test_fold = shot_detection_folds[i]
    train_videos = [
        video_id
        for video_id in video_ids if video_id not in test_fold
    ]
    test_videos = [
        video_id
        for video_id in video_ids if video_id in test_fold
    ]

    L_train = csr_matrix([
        [
            label
            for video_id in train_videos
            for label in lf[video_id]
        ]
        for lf in weak_labels_all
    ]).transpose()

    Y_test = np.array([
        label
        for video_id in test_videos
        for label in ground_truth[video_id]
    ])
    L_test = csr_matrix([
        [
            label
            for video_id in test_videos
            for label in lf[video_id]
        ]
        for lf in weak_labels_gt_only
    ]).transpose()

    label_model = LabelModel(k=2, seed=123)
    label_model.train_model(L_train, class_balance=(0.01, 0.99), n_epochs=500, log_train_every=50)
    label_model.score((L_test, Y_test), metric=['accuracy','precision', 'recall', 'f1'])

    # (video_id, frame, gold label) for every row of L_test. ground_truth is aligned
    # with frames_per_video, so zipping the two avoids re-scanning boundary_frames
    # for every frame.
    Y_readable = [
        (video_id, f, label)
        for video_id in test_videos
        for f, label in zip(frames_per_video[video_id], ground_truth[video_id])
    ]

    predictions = label_model.predict(L_test)
    prediction_probs = label_model.predict_proba(L_test)
    # Keep only the first probability column of each row for the histograms.
    prediction_probabilities.append([p[0] for p in prediction_probs])

    wrong_predictions = np.where(predictions != Y_test)[0]

    wrong_interval_preds = [
        (Y_readable[int(wp)], prediction_probs[int(wp)].tolist())
        for wp in wrong_predictions
    ]

    with open('../../data/failure_cases/metal_frame_only/{}_fold.pkl'.format(i + 1), 'wb') as f:
        pickle.dump(wrong_interval_preds, f)

```
Accuracy: 0.997
Precision: 0.887
Recall: 0.874
F1: 0.881
        y=1    y=2   
 l=1    118    15    
 l=2    17    12522  

Accuracy: 0.990
Precision: 0.984
Recall: 0.594
F1: 0.741
        y=1    y=2   
 l=1    60      1    
 l=2    41    4025   

Accuracy: 0.999
Precision: 0.953
Recall: 0.942
F1: 0.948
        y=1    y=2   
 l=1    163     8    
 l=2    10    21104  

Accuracy: 0.996
Precision: 0.802
Recall: 0.842
F1: 0.821
        y=1    y=2   
 l=1    85     21    
 l=2    16    8685   

Accuracy: 0.996
Precision: 0.946
Recall: 0.746
F1: 0.835
        y=1    y=2   
 l=1    106     6    
 l=2    36    11114  
 
Average F1: .845
```

In [None]:
# One log-scaled histogram per fold of the stored frame-level probabilities.
# Requires matplotlib.pyplot as `plt` from the import cell. Each figure is created
# explicitly and closed after display so repeated runs do not accumulate figures.
for i, problist in enumerate(prediction_probabilities):
    fig, ax = plt.subplots()
    ax.hist(problist, log=True)
    ax.set_title('Probability histogram for fold {}'.format(i + 1))
    ax.set_xlabel('Probability')
    ax.set_ylabel('Count')
    plt.show()
    plt.close(fig)

## Part 3: Classify windows of 16 frames

### Labeling Functions for windows of 16 frames

In [None]:
# Build 16-frame windows with a stride of 8 for every ground-truth video.
# Each window's payload is its video_id so later maps can look up labels.
windows_by_video = {}
for video_id in video_ids:
    windows_by_video[video_id] = [
        (window_start, window_start + 16, video_id)
        for window_start in range(0, frame_counts[video_id] - 16, 8)
    ]
windows = VideoIntervalCollection(windows_by_video)

In [None]:
# Next, intersect the windows with ground truth and get ground truth labels for the windows
# Step 1: every window that overlaps a ground-truth clip gets the default label 2.
windows_intersecting_ground_truth = windows.filter_against(
    clips,
    predicate=overlaps()
).map(lambda intrvl: (intrvl.start, intrvl.end, 2))
# Step 2: of those, windows that contain a shot boundary get label 1. The window is
# treated as half-open: window.start <= boundary < window.end.
windows_with_shot_boundaries = windows_intersecting_ground_truth.filter_against(
    shot_boundaries,
    predicate = lambda window, shot_boundary:
        shot_boundary.start >= window.start and shot_boundary.start < window.end
).map(
    lambda intrvl: (intrvl.start, intrvl.end, 1)
)
# Step 3: union both sets and merge identical windows; min() makes label 1 win
# over label 2 when a window appears in both.
windows_with_labels = windows_with_shot_boundaries.set_union(
    windows_intersecting_ground_truth
).coalesce(
    predicate = equal(),
    payload_merge_op = lambda p1, p2: min(p1, p2)
)

In [None]:
# Label windows with the weak labels in our labeling functions
def label_window(per_frame_weak_labels):
    """Collapse the per-frame weak labels of one window into a single window label.

    Args:
        per_frame_weak_labels: sequence of per-frame labels (0, 1 or 2).

    Returns:
        1 if any frame is labeled 1; otherwise 2 if at least half of the frames
        are labeled 2; otherwise 0. An empty sequence yields 2.
    """
    negative_votes = 0
    for frame_label in per_frame_weak_labels:
        if frame_label == 1:
            # A single positive frame decides the whole window.
            return 1
        if frame_label == 2:
            negative_votes += 1
    # Same as negative_votes >= len / 2, written without the division.
    return 2 if 2 * negative_votes >= len(per_frame_weak_labels) else 0

# Replace each window's payload with one weak window label per labeling function
# (same order as weak_labels_all).
# Frame f is stored at index f - 1 of the per-video label list. Windows start at
# frame 0, which has no label; without the max(..., 1) guard the first window would
# read index -1, i.e. the label of the LAST frame of the video.
windows_with_weak_labels = windows.map(
    lambda window: (
        window.start,
        window.end,
        [
            label_window([
                lf[window.payload][f-1]
                for f in range(max(window.start, 1), window.end)
            ])
            for lf in weak_labels_all
        ]
    )
)

In [None]:
# Ground-truth label of every window that overlaps a clip, videos in video_ids order.
window_gt_labels = []
for vid in video_ids:
    for window in windows_with_labels.get_intervallist(vid).get_intervals():
        window_gt_labels.append(window.payload)
Y_windows = np.array(window_gt_labels)

In [None]:
Y_windows.shape

In [None]:
len([y for y in Y_windows if y == 1])

In [None]:
# Weak-label matrix for the windows that overlap a ground-truth clip: one row per
# window, one column per labeling function.
# The filter does not depend on the video, so it is computed once here instead of
# once per video inside the comprehension.
windows_with_weak_labels_in_clips = windows_with_weak_labels.filter_against(
    clips, predicate=overlaps(), working_window=1
)
L_windows = csr_matrix([
    intrvl.payload
    for video_id in video_ids
    for intrvl in windows_with_weak_labels_in_clips.get_intervallist(video_id).get_intervals()
])

In [None]:
L_windows.shape

In [None]:
# Per-labeling-function summary at window level; names follow labeling_function_folders order.
lf_summary(L_windows, Y=Y_windows, lf_names = ['RGB hist', 'HSV hist', 'flow hist', 'face counts', 'face positions'])

In [None]:
csr_matrix([
    intrvl.payload
    for video_id in video_ids
    for intrvl in windows_with_weak_labels.get_intervallist(video_id).get_intervals()
]).shape

### Part 0: Majority Vote

In [None]:
# 5-fold evaluation of the majority-vote baseline on windows that overlap a clip.
windows_with_weak_labels_gt_only = windows_with_weak_labels.filter_against(
    clips, predicate=overlaps(), working_window=1
)
for fold_idx in range(5):
    test_fold = shot_detection_folds[fold_idx]
    train_videos = [vid for vid in video_ids if vid not in test_fold]
    test_videos = [vid for vid in video_ids if vid in test_fold]

    Y_test = np.array([
        window.payload
        for vid in test_videos
        for window in windows_with_labels.get_intervallist(vid).get_intervals()
    ])
    weak_rows_test = [
        window.payload
        for vid in test_videos
        for window in windows_with_weak_labels_gt_only.get_intervallist(vid).get_intervals()
    ]
    L_test = csr_matrix(weak_rows_test)

    mv = MajorityLabelVoter(seed=123)
    scores = mv.score((L_test, Y_test), metric=['accuracy','precision', 'recall', 'f1'])

```
Accuracy: 0.947
Precision: 0.785
Recall: 0.940
F1: 0.856
        y=1    y=2   
 l=1    252    69    
 l=2    16    1269   
Accuracy: 0.912
Precision: 0.918
Recall: 0.845
F1: 0.880
        y=1    y=2   
 l=1    169    15    
 l=2    31     306   
Accuracy: 0.976
Precision: 0.861
Recall: 0.968
F1: 0.911
        y=1    y=2   
 l=1    334    54    
 l=2    11    2268   
Accuracy: 0.914
Precision: 0.707
Recall: 0.906
F1: 0.794
        y=1    y=2   
 l=1    183    76    
 l=2    19     827   
Accuracy: 0.957
Precision: 0.874
Recall: 0.905
F1: 0.889
        y=1    y=2   
 l=1    249    36    
 l=2    26    1116   

Average F1: .866
```

### Part 1: Train LabelModel on windows that we have GT for

In [None]:
# 5-fold evaluation of a LabelModel trained only on windows that overlap a
# ground-truth clip, using a fixed class balance.
windows_with_weak_labels_gt_only = windows_with_weak_labels.filter_against(
    clips, predicate=overlaps(), working_window=1
)
for fold_idx in range(5):
    test_fold = shot_detection_folds[fold_idx]
    train_videos = [vid for vid in video_ids if vid not in test_fold]
    test_videos = [vid for vid in video_ids if vid in test_fold]

    weak_rows_train = [
        window.payload
        for vid in train_videos
        for window in windows_with_weak_labels_gt_only.get_intervallist(vid).get_intervals()
    ]
    L_train = csr_matrix(weak_rows_train)

    Y_test = np.array([
        window.payload
        for vid in test_videos
        for window in windows_with_labels.get_intervallist(vid).get_intervals()
    ])
    weak_rows_test = [
        window.payload
        for vid in test_videos
        for window in windows_with_weak_labels_gt_only.get_intervallist(vid).get_intervals()
    ]
    L_test = csr_matrix(weak_rows_test)

    label_model = LabelModel(k=2, seed=123)
    label_model.train_model(L_train, class_balance=(0.15, 0.85), n_epochs=500, log_train_every=50)
    label_model.score((L_test, Y_test), metric=['accuracy','precision', 'recall', 'f1'])

In [None]:
# Same 5-fold window experiment as the previous cell, but instead of a fixed
# class_balance the gold labels of the TRAINING windows are passed as Y_dev.
# NOTE(review): presumably MeTaL uses Y_dev to estimate the class balance — confirm.
windows_with_weak_labels_gt_only = windows_with_weak_labels.filter_against(
    clips, predicate=overlaps(), working_window=1
)
for i in range(5):
    test_fold = shot_detection_folds[i]
    train_videos = [
        video_id
        for video_id in video_ids if video_id not in test_fold
    ]
    test_videos = [
        video_id
        for video_id in video_ids if video_id in test_fold
    ]
    
    # Gold labels and weak labels of the training windows; both iterate the same
    # videos in the same order, so their rows are expected to line up.
    Y_train = np.array([
        intrvl.payload
        for video_id in train_videos
        for intrvl in windows_with_labels.get_intervallist(video_id).get_intervals()
    ])
    L_train = csr_matrix([
        intrvl.payload
        for video_id in train_videos
        for intrvl in windows_with_weak_labels_gt_only.get_intervallist(video_id).get_intervals()
    ])
    
    Y_test = np.array([
        intrvl.payload
        for video_id in test_videos
        for intrvl in windows_with_labels.get_intervallist(video_id).get_intervals()
    ])
    L_test = csr_matrix([
        intrvl.payload
        for video_id in test_videos
        for intrvl in windows_with_weak_labels_gt_only.get_intervallist(video_id).get_intervals()
    ])
    
    label_model = LabelModel(k=2, seed=123)
    label_model.train_model(L_train, Y_dev=Y_train, n_epochs=500, log_train_every=50)
    label_model.score((L_test, Y_test), metric=['accuracy','precision', 'recall', 'f1'])

```
Accuracy: 0.927
Precision: 0.709
Recall: 0.882
F1: 0.786
        y=1    y=2   
 l=1    217    89    
 l=2    29    1271 
 
Accuracy: 0.893
Precision: 0.911
Recall: 0.789
F1: 0.845
        y=1    y=2   
 l=1    153    15    
 l=2    41     312  

Accuracy: 0.969
Precision: 0.868
Recall: 0.903
F1: 0.885
        y=1    y=2   
 l=1    316    48    
 l=2    34    2269

Accuracy: 0.886
Precision: 0.657
Recall: 0.763
F1: 0.706
        y=1    y=2   
 l=1    151    79    
 l=2    47     828

Accuracy: 0.941
Precision: 0.805
Recall: 0.881
F1: 0.842
        y=1    y=2   
 l=1    223    54    
 l=2    30    1120 

Average F1: .813
```

### Part 2: Train LabelModel on entire videos

In [None]:
# 5-fold evaluation of a LabelModel trained on ALL windows of the training videos
# and tested on the windows that overlap a ground-truth clip.
windows_with_weak_labels_gt_only = windows_with_weak_labels.filter_against(
    clips, predicate=overlaps(), working_window=1
)
prediction_probabilities_windows = []
for fold_idx in range(5):
    test_fold = shot_detection_folds[fold_idx]
    train_videos = [vid for vid in video_ids if vid not in test_fold]
    test_videos = [vid for vid in video_ids if vid in test_fold]

    # Unfiltered windows: training does not need gold labels.
    weak_rows_train = [
        window.payload
        for vid in train_videos
        for window in windows_with_weak_labels.get_intervallist(vid).get_intervals()
    ]
    L_train = csr_matrix(weak_rows_train)

    Y_test = np.array([
        window.payload
        for vid in test_videos
        for window in windows_with_labels.get_intervallist(vid).get_intervals()
    ])
    weak_rows_test = [
        window.payload
        for vid in test_videos
        for window in windows_with_weak_labels_gt_only.get_intervallist(vid).get_intervals()
    ]
    L_test = csr_matrix(weak_rows_test)

    label_model = LabelModel(k=2, seed=123)
    label_model.train_model(L_train, class_balance=(0.15, 0.85), n_epochs=500, log_train_every=50)
    label_model.score((L_test, Y_test), metric=['accuracy','precision', 'recall', 'f1'])

    predictions = label_model.predict(L_test)
    prediction_probs = label_model.predict_proba(L_test)
    # First probability column of every test window, kept for the histograms.
    first_column_probs = [row[0] for row in prediction_probs]
    prediction_probabilities_windows.append(first_column_probs)

```
Accuracy: 0.934
Precision: 0.765
Recall: 0.873
F1: 0.815
        y=1    y=2   
 l=1    234    72    
 l=2    34    1266   

Accuracy: 0.910
Precision: 0.932
Recall: 0.825
F1: 0.875
        y=1    y=2   
 l=1    165    12    
 l=2    35     309   

Accuracy: 0.979
Precision: 0.877
Recall: 0.971
F1: 0.922
        y=1    y=2   
 l=1    335    47    
 l=2    10    2275   

Accuracy: 0.898
Precision: 0.682
Recall: 0.827
F1: 0.747
        y=1    y=2   
 l=1    167    78    
 l=2    35     825   

Accuracy: 0.955
Precision: 0.881
Recall: 0.887
F1: 0.884
        y=1    y=2   
 l=1    244    33    
 l=2    31    1119   

Average F1: .849
```

In [None]:
# One log-scaled histogram per fold of the stored window-level probabilities.
# Requires matplotlib.pyplot as `plt` from the import cell. Each figure is created
# explicitly and closed after display so repeated runs do not accumulate figures.
for i, problist in enumerate(prediction_probabilities_windows):
    fig, ax = plt.subplots()
    ax.hist(problist, log=True)
    ax.set_title('Probability histogram for fold {}'.format(i + 1))
    ax.set_xlabel('Probability')
    ax.set_ylabel('Count')
    plt.show()
    plt.close(fig)

# Part 4: Training on the entire dataset

In [None]:
# Every video that has a frame count, in ascending id order; noisy labels are
# loaded for all of them below.
video_ids_all = sorted(frame_counts.keys())

In [None]:
# Per-frame weak labels for EVERY video in the dataset, one dict per labeling
# function: weak_labels_all_movies[k][video_id] -> label of frames 1..frame_count.
# Payloads: 0 = no vote, 1 = positive boundary vote, 2 = negative vote.
weak_labels_all_movies = []
for folder in labeling_function_folders:
    labels_for_function_all = {}
    # Iterate the full dataset (video_ids_all). Using video_ids here would load
    # only the ground-truth videos, and the training splits below — which
    # deliberately exclude those videos — would fail on lookup.
    for video_id in tqdm(video_ids_all):
        n_frames = frame_counts[video_id]
        # Every frame 1..n_frames starts out with payload 0.
        all_frames = IntervalList([
            (frame, frame, 0)
            for frame in range(1, n_frames + 1)
        ])
        with open(os.path.join(folder, '{}.pkl'.format(video_id)), 'rb') as f:
            positive_boundaries, negative_boundaries = pickle.load(f)

        # Votes past the end of the video are dropped.
        positive_frames = IntervalList([
            (frame, frame, 1)
            for frame in positive_boundaries if frame <= n_frames
        ])
        negative_frames = IntervalList([
            (frame, frame, 2)
            for frame in negative_boundaries if frame <= n_frames
        ])
        # max() lets any vote override the 0 default; a frame carrying both a
        # positive and a negative vote ends up as 2.
        frames_w_labels = all_frames.set_union(
            positive_frames
        ).set_union(
            negative_frames
        ).coalesce(payload_merge_op = lambda p1, p2: max(p1, p2))

        labels_for_function_all[video_id] = [
            intrvl.payload
            for intrvl in frames_w_labels.get_intervals()
        ]

    weak_labels_all_movies.append(labels_for_function_all)

In [None]:
# Save weak labels
# Caches weak_labels_all_movies so the loading loop above can be skipped next time.
with open('../../data/shot_detection_weak_labels/all_labels.pkl', 'wb') as f:
    pickle.dump(weak_labels_all_movies, f)

In [None]:
# Or load weak labels
# Alternative to recomputing: overwrites weak_labels_all_movies with the cached copy.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('../../data/shot_detection_weak_labels/all_labels.pkl', 'rb') as f:
    weak_labels_all_movies = pickle.load(f)

In [None]:
# Re-derive the ground-truth-frame weak labels from the dataset-wide label lists.
# Frame f is stored at index f - 1.
gt_video_ids = sorted(clips.get_allintervals().keys())
weak_labels_gt_only = [
    {
        vid: [lf_labels[vid][frame - 1] for frame in frames_per_video[vid]]
        for vid in gt_video_ids
    }
    for lf_labels in weak_labels_all_movies
]

## 100 Movies

In [None]:
random.seed(0)

In [None]:
# Candidate training videos: every video in the dataset except those with ground
# truth. Sorted so the later shuffle is reproducible for a given seed.
gt_video_id_set = set(clips.get_allintervals().keys())
vid_candidates = sorted(set(video_ids_all) - gt_video_id_set)

In [None]:
random.shuffle(vid_candidates)

In [None]:
train_split = sorted(vid_candidates[:100])

In [None]:
# Save train split
# Persists the sampled 100-video split so later runs can reuse exactly the same videos.
with open('../../data/shot_detection_weak_labels/train_split_100.pkl', 'wb') as f:
    pickle.dump(train_split, f)

In [None]:
# or load train split
# Alternative to resampling: overwrites train_split with the saved copy.
with open('../../data/shot_detection_weak_labels/train_split_100.pkl', 'rb') as f:
    train_split = pickle.load(f)

### Frame-Based model

In [None]:
# All ground-truth videos form the test set for the models trained on unlabeled movies.
test_videos = sorted(clips.get_allintervals().keys())

In [None]:
# Frame-level matrices for the 100-movie experiment.
# Training: every frame of the 100 sampled videos (rows) x labeling functions (columns).
lf_rows_train_100 = [
    [frame_label for vid in train_split for frame_label in lf_labels[vid]]
    for lf_labels in weak_labels_all_movies
]
L_train_100_frames = csr_matrix(lf_rows_train_100).transpose()

# Test: ground-truth frames of the annotated videos.
Y_test = np.array([
    frame_label
    for vid in test_videos
    for frame_label in ground_truth[vid]
])
lf_rows_test = [
    [frame_label for vid in test_videos for frame_label in lf_labels[vid]]
    for lf_labels in weak_labels_gt_only
]
L_test = csr_matrix(lf_rows_test).transpose()

In [None]:
MajorityLabelVoter(seed=123).score((L_test, Y_test), metric=['accuracy','precision', 'recall', 'f1'])

In [None]:
# Frame-level LabelModel trained on the 100 sampled movies, scored on the GT frames.
# NOTE(review): Y_dev = Y_test hands the test labels to train_model; the other
# frame-level runs in this notebook pass class_balance=(0.01, 0.99) instead.
# Confirm this is intended, since it exposes test information during training.
label_model_100_frames = LabelModel(k=2, seed=123)
label_model_100_frames.train_model(L_train_100_frames, Y_dev = Y_test, n_epochs=5000, log_train_every=50)
label_model_100_frames.score((L_test, Y_test), metric=['accuracy','precision', 'recall', 'f1'])

### Window-based model

In [None]:
# Build 16-frame windows with a stride of 8 for each of the 100 training videos.
# Each window's payload is its video_id.
train_windows_by_video = {}
for video_id in train_split:
    train_windows_by_video[video_id] = [
        (window_start, window_start + 16, video_id)
        for window_start in range(0, frame_counts[video_id] - 16, 8)
    ]
windows_train = VideoIntervalCollection(train_windows_by_video)

In [None]:
# Same 16-frame / stride-8 windows for the ground-truth (test) videos.
test_windows_by_video = {}
for video_id in test_videos:
    test_windows_by_video[video_id] = [
        (window_start, window_start + 16, video_id)
        for window_start in range(0, frame_counts[video_id] - 16, 8)
    ]
windows_test = VideoIntervalCollection(test_windows_by_video)

In [None]:
# Label windows with the weak labels in our labeling functions
def label_window(per_frame_weak_labels):
    """Collapse the per-frame weak labels of one window into a single window label.

    Args:
        per_frame_weak_labels: sequence of per-frame labels (0, 1 or 2).

    Returns:
        1 if any frame is labeled 1; otherwise 2 if at least half of the frames
        are labeled 2; otherwise 0. An empty sequence yields 2.
    """
    negative_votes = 0
    for frame_label in per_frame_weak_labels:
        if frame_label == 1:
            # A single positive frame decides the whole window.
            return 1
        if frame_label == 2:
            negative_votes += 1
    # Same as negative_votes >= len / 2, written without the division.
    return 2 if 2 * negative_votes >= len(per_frame_weak_labels) else 0

# Replace each training window's payload with one weak window label per labeling
# function (same order as weak_labels_all_movies).
# Frame f is stored at index f - 1 of the per-video label list. Windows start at
# frame 0, which has no label; without the max(..., 1) guard the first window would
# read index -1, i.e. the label of the LAST frame of the video.
windows_with_weak_labels_train = windows_train.map(
    lambda window: (
        window.start,
        window.end,
        [
            label_window([
                lf[window.payload][f-1]
                for f in range(max(window.start, 1), window.end)
            ])
            for lf in weak_labels_all_movies
        ]
    )
)

In [None]:
# Replace each test window's payload with one weak window label per labeling
# function (same order as weak_labels_all_movies).
# Frame f is stored at index f - 1 of the per-video label list. Windows start at
# frame 0, which has no label; without the max(..., 1) guard the first window would
# read index -1, i.e. the label of the LAST frame of the video.
windows_with_weak_labels_test = windows_test.map(
    lambda window: (
        window.start,
        window.end,
        [
            label_window([
                lf[window.payload][f-1]
                for f in range(max(window.start, 1), window.end)
            ])
            for lf in weak_labels_all_movies
        ]
    )
)

In [None]:
# Window-level matrices for the 100-movie experiment.
# Training matrix: every window of the 100 sampled videos. This definition must be
# active, because the training cell below reads L_train_100_windows.
L_train_100_windows = csr_matrix([
    intrvl.payload
    for video_id in train_split
    for intrvl in windows_with_weak_labels_train.get_intervallist(video_id).get_intervals()
])

Y_test_windows = np.array([
    intrvl.payload
    for video_id in test_videos
    for intrvl in windows_with_labels.get_intervallist(video_id).get_intervals()
])
# Restrict the test windows built in this section to those overlapping a
# ground-truth clip, rather than relying on a variable left over from an earlier
# section of the notebook.
windows_with_weak_labels_test_gt_only = windows_with_weak_labels_test.filter_against(
    clips, predicate=overlaps(), working_window=1
)
L_test_windows = csr_matrix([
    intrvl.payload
    for video_id in test_videos
    for intrvl in windows_with_weak_labels_test_gt_only.get_intervallist(video_id).get_intervals()
])

In [None]:
MajorityLabelVoter(seed=123).score((L_test_windows, Y_test_windows), metric=['accuracy','precision', 'recall', 'f1'])

In [None]:
# Window-level LabelModel trained on the windows of the 100 sampled movies with a
# fixed class balance, then scored on the ground-truth windows.
label_model_100_windows = LabelModel(k=2, seed=123)
label_model_100_windows.train_model(L_train_100_windows, class_balance=(0.15, 0.85), n_epochs=10000, log_train_every=50)
label_model_100_windows.score((L_test_windows, Y_test_windows), metric=['accuracy','precision', 'recall', 'f1'])

## All Movies

### Frame based

In [None]:
# Training set for the "all movies" models: every video without ground truth.
gt_video_id_set = set(clips.get_allintervals().keys())
train_movies_all = sorted(set(video_ids_all) - gt_video_id_set)

In [None]:
# Frame-level matrices for the all-movies experiment.
# Training: every frame of every video without ground truth.
lf_rows_train_everything = [
    [frame_label for vid in train_movies_all for frame_label in lf_labels[vid]]
    for lf_labels in weak_labels_all_movies
]
L_train_everything = csr_matrix(lf_rows_train_everything).transpose()

# Test: ground-truth frames of the annotated videos.
Y_test = np.array([
    frame_label
    for vid in test_videos
    for frame_label in ground_truth[vid]
])
lf_rows_test = [
    [frame_label for vid in test_videos for frame_label in lf_labels[vid]]
    for lf_labels in weak_labels_gt_only
]
L_test = csr_matrix(lf_rows_test).transpose()

In [None]:
# Frame-level LabelModel trained on all non-ground-truth movies with a fixed class
# balance, then scored on the ground-truth frames.
label_model_everything = LabelModel(k=2, seed=123)
label_model_everything.train_model(L_train_everything, class_balance=(0.01, 0.99), n_epochs=5000, log_train_every=50)
label_model_everything.score((L_test, Y_test), metric=['accuracy','precision', 'recall', 'f1'])

In [None]:
# Weak-label matrix over every frame of every video (ground-truth videos included),
# used to produce noisy labels for the whole dataset.
lf_rows_everything = [
    [frame_label for vid in sorted(video_ids_all) for frame_label in lf_labels[vid]]
    for lf_labels in weak_labels_all_movies
]
L_everything_frame = csr_matrix(lf_rows_everything).transpose()

In [None]:
len(weak_labels_all_movies[1][1])

In [None]:
frame_counts[1]

In [None]:
L_everything_frame.shape

In [None]:
frame_predictions_everything = label_model_everything.predict_proba(L_everything_frame)

In [None]:
# (video_id, frame number) for every frame, in the same order as the rows of
# L_everything_frame. Frame numbers run from 1 to frame_counts[video_id].
video_frame_nums = []
for vid in sorted(video_ids_all):
    for frame in range(1, frame_counts[vid] + 1):
        video_frame_nums.append((vid, frame))

In [None]:
frame_predictions_everything.shape

In [None]:
video_frame_nums[-10:]

In [None]:
len(video_frame_nums)

In [None]:
# Pair each (video_id, frame) with its predicted probabilities as a plain list.
predictions_to_save = []
for frame_info, prediction in zip(video_frame_nums, frame_predictions_everything):
    predictions_to_save.append((frame_info, prediction.tolist()))

In [None]:
predictions_to_save[:10]

In [None]:
# Convert the (frame_info, probabilities) pairs to a numpy array for saving.
# NOTE(review): frame_info has two entries; if predict_proba also yields two values
# per row, numpy will build a numeric (N, 2, 2) array here rather than an object
# array — confirm the consumer of the .npy file expects that layout.
preds_np = np.array(predictions_to_save)

In [None]:
preds_np.shape

In [None]:
# save predictions to disk
# Frame-level noisy labels for the whole dataset, written in .npy format.
with open('../../data/shot_detection_weak_labels/noisy_labels_all_frame.npy', 'wb') as f:
    np.save(f, preds_np)

### Window based

In [None]:
# Build 16-frame windows with a stride of 8 for every video without ground truth.
# Each window's payload is its video_id.
all_train_windows_by_video = {}
for video_id in train_movies_all:
    all_train_windows_by_video[video_id] = [
        (window_start, window_start + 16, video_id)
        for window_start in range(0, frame_counts[video_id] - 16, 8)
    ]
windows_train_all = VideoIntervalCollection(all_train_windows_by_video)

In [None]:
# Label windows with the weak labels in our labeling functions
def label_window(per_frame_weak_labels):
    """Collapse the per-frame weak labels of one window into a single window label.

    Args:
        per_frame_weak_labels: sequence of per-frame labels (0, 1 or 2).

    Returns:
        1 if any frame is labeled 1; otherwise 2 if at least half of the frames
        are labeled 2; otherwise 0. An empty sequence yields 2.
    """
    negative_votes = 0
    for frame_label in per_frame_weak_labels:
        if frame_label == 1:
            # A single positive frame decides the whole window.
            return 1
        if frame_label == 2:
            negative_votes += 1
    # Same as negative_votes >= len / 2, written without the division.
    return 2 if 2 * negative_votes >= len(per_frame_weak_labels) else 0

# Replace each training window's payload with one weak window label per labeling
# function (same order as weak_labels_all_movies).
# Frame f is stored at index f - 1 of the per-video label list. Windows start at
# frame 0, which has no label; without the max(..., 1) guard the first window would
# read index -1, i.e. the label of the LAST frame of the video.
windows_with_weak_labels_train_all = windows_train_all.map(
    lambda window: (
        window.start,
        window.end,
        [
            label_window([
                lf[window.payload][f-1]
                for f in range(max(window.start, 1), window.end)
            ])
            for lf in weak_labels_all_movies
        ]
    )
)

In [None]:
# Weak window labels for the ground-truth videos' windows, computed from the
# dataset-wide label lists (same order as weak_labels_all_movies).
# Frame f is stored at index f - 1 of the per-video label list. Windows start at
# frame 0, which has no label; without the max(..., 1) guard the first window would
# read index -1, i.e. the label of the LAST frame of the video.
windows_with_weak_labels_test_all = windows.map(
    lambda window: (
        window.start,
        window.end,
        [
            label_window([
                lf[window.payload][f-1]
                for f in range(max(window.start, 1), window.end)
            ])
            for lf in weak_labels_all_movies
        ]
    )
)

In [None]:
# Keep only the test windows that overlap a ground-truth clip. This rebinds
# windows_with_weak_labels_gt_only, which earlier cells also assign.
windows_with_weak_labels_gt_only = windows_with_weak_labels_test_all.filter_against(
    clips, predicate=overlaps(), working_window=1
)

In [None]:
# Weakly labeled windows for the whole dataset: non-ground-truth videos plus
# ground-truth videos (the latter unfiltered, i.e. all of their windows).
windows_with_weak_labels_all = windows_with_weak_labels_train_all.set_union(
    windows_with_weak_labels_test_all
)

In [None]:
# Window-level test set: gold labels and weak labels of the windows that overlap a
# ground-truth clip.
Y_test_windows = np.array([
    window.payload
    for vid in test_videos
    for window in windows_with_labels.get_intervallist(vid).get_intervals()
])
weak_rows_test_windows = [
    window.payload
    for vid in test_videos
    for window in windows_with_weak_labels_gt_only.get_intervallist(vid).get_intervals()
]
L_test_windows = csr_matrix(weak_rows_test_windows)

In [None]:
# Training matrix for the all-movies window model: every window of every video
# without ground truth. Iterates train_movies_all — the set the windows above were
# built for — rather than the 100-video train_split, which would silently train
# the "all movies" model on the smaller sample.
L_train_windows_all = csr_matrix([
    intrvl.payload
    for video_id in train_movies_all
    for intrvl in windows_with_weak_labels_train_all.get_intervallist(video_id).get_intervals()
])

In [None]:
# Window-level LabelModel for the all-movies experiment, trained with a fixed class
# balance for 20000 epochs. Scoring happens in the next cell.
label_model_everything_windows = LabelModel(k=2, seed=123)
label_model_everything_windows.train_model(L_train_windows_all, class_balance=(0.15, 0.85), n_epochs=20000, log_train_every=50)

In [None]:
label_model_everything_windows.score((L_test_windows, Y_test_windows), metric=['accuracy','precision', 'recall', 'f1'])

In [None]:
# Weak-label matrix over every window of every video, in sorted video order.
weak_rows_everything_windows = [
    window.payload
    for vid in sorted(video_ids_all)
    for window in windows_with_weak_labels_all.get_intervallist(vid).get_intervals()
]
L_everything_windows = csr_matrix(weak_rows_everything_windows)

In [None]:
window_predictions_everything = label_model_everything_windows.predict_proba(L_everything_windows)

In [None]:
# (video_id, start, end) of every window, in the same order as the rows of
# L_everything_windows.
window_nums = []
for vid in sorted(video_ids_all):
    for window in windows_with_weak_labels_all.get_intervallist(vid).get_intervals():
        window_nums.append((vid, window.start, window.end))

In [None]:
# Pair each (video_id, start, end) with its predicted probability row.
predictions_to_save_windows = list(zip(window_nums, window_predictions_everything))

In [None]:
# Each entry pairs a 3-element window tuple with a probability row, so the nested
# structure is ragged. np.array(...) on ragged input raises ValueError on recent
# NumPy releases (older ones only warned), so build the (N, 2) object array
# explicitly: column 0 holds the window info, column 1 the probability row.
preds_np_windows = np.empty((len(predictions_to_save_windows), 2), dtype=object)
for row_idx, (window_info, prediction) in enumerate(predictions_to_save_windows):
    preds_np_windows[row_idx, 0] = window_info
    preds_np_windows[row_idx, 1] = prediction

In [None]:
# save predictions to disk
# Window-level noisy labels for the whole dataset, written in .npy format.
with open('../../data/shot_detection_weak_labels/noisy_labels_all_windows.npy', 'wb') as f:
    np.save(f, preds_np_windows)