In [22]:
import sys
sys.path.append('/lfs/1/danfu/metal')
import metal

In [3]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [4]:
import pickle
import rekall
from rekall.video_interval_collection import VideoIntervalCollection
from rekall.interval_list import IntervalList
from rekall.temporal_predicates import *
import numpy as np
from scipy.sparse import csr_matrix
import os
from tqdm import tqdm
import random

from metal.analysis import lf_summary
from metal.label_model.baselines import MajorityLabelVoter
from metal.label_model import LabelModel

# Load Shot Data

In [5]:
with open('../../data/shot_detection_folds.pkl', 'rb') as f:
    shot_detection_folds = pickle.load(f)

In [6]:
with open('../../data/manually_annotated_shots.pkl', 'rb') as f:
    shots = VideoIntervalCollection(pickle.load(f))

In [7]:
clips = shots.dilate(1).coalesce().dilate(-1)

100%|██████████| 28/28 [00:00<00:00, 11953.23it/s]
100%|██████████| 28/28 [00:00<00:00, 46109.35it/s]


In [8]:
shot_boundaries = shots.map(
    lambda intrvl: (intrvl.start, intrvl.start, intrvl.payload)
).set_union(
    shots.map(lambda intrvl: (intrvl.end + 1, intrvl.end + 1, intrvl.payload))
).coalesce()

In [9]:
boundary_frames = {
    video_id: [
        intrvl.start
        for intrvl in shot_boundaries.get_intervallist(video_id).get_intervals()
    ]
    for video_id in shot_boundaries.get_allintervals()
}

In [10]:
video_ids = sorted(list(clips.get_allintervals().keys()))

In [11]:
frames_per_video = {
    video_id: sorted([
        f
        for interval in clips.get_intervallist(video_id).get_intervals()
        for f in range(interval.start, interval.end + 2)
    ])
    for video_id in video_ids
}

In [12]:
ground_truth = {
    video_id: [
        1 if f in boundary_frames[video_id] else 2
        for f in frames_per_video[video_id]
    ] 
    for video_id in video_ids
}

# Load Weak Labels

In [13]:
with open('../../data/frame_counts.pkl', 'rb') as f:
    frame_counts = pickle.load(f)

In [15]:
with open('../../data/shot_detection_weak_labels/all_labels.pkl', 'rb') as f:
    all_labels = pickle.load(f)

In [23]:
labeling_function_folders = [
    '../../data/shot_detection_weak_labels/rgb_hists',
    '../../data/shot_detection_weak_labels/hsv_hists',
#     '../../data/shot_detection_weak_labels/flow_hists_magnitude', # this is just really really bad
    '../../data/shot_detection_weak_labels/flow_hists_diffs',
    '../../data/shot_detection_weak_labels/face_counts',
    '../../data/shot_detection_weak_labels/face_positions'
]

In [24]:
weak_labels_all = []
weak_labels_gt_only = []

In [25]:
for folder in labeling_function_folders:
    labels_for_function_all = {}
    labels_for_function_gt_only = {}
    for video_id in tqdm(video_ids):
        all_frames = IntervalList([
            (f+1, f+1, 0)
            for f in range(frame_counts[video_id])
        ])
        with open(os.path.join(folder, '{}.pkl'.format(video_id)), 'rb') as f:
            positive_boundaries, negative_boundaries = pickle.load(f)
            positive_frames = IntervalList([
                (f, f, 1)
                for f in positive_boundaries if f <= frame_counts[video_id]
            ])
            negative_frames = IntervalList([
                (f, f, 2)
                for f in negative_boundaries if f <= frame_counts[video_id]
            ])
            frames_w_labels = all_frames.set_union(
                positive_frames
            ).set_union(
                negative_frames
            ).coalesce(payload_merge_op = lambda p1, p2: max(p1, p2))
            
            labels_for_function_all[video_id] = [
                intrvl.payload
                for intrvl in frames_w_labels.get_intervals()
            ]
            
            labels_for_function_gt_only[video_id] = [
                frames_w_labels.get_intervals()[f-1].payload
                for f in frames_per_video[video_id]
            ]
            
    weak_labels_all.append(labels_for_function_all)
    weak_labels_gt_only.append(labels_for_function_gt_only)

100%|██████████| 28/28 [03:15<00:00,  6.52s/it]
100%|██████████| 28/28 [02:59<00:00,  6.36s/it]
  4%|▎         | 1/28 [00:06<03:06,  6.92s/it]

KeyboardInterrupt: 

In [None]:
Y = np.array([
    label
    for video_id in video_ids
    for label in ground_truth[video_id]
])

In [None]:
Y.shape

In [None]:
L = csr_matrix([
    [
        label
        for video_id in video_ids
        for label in lf[video_id]
    ]
    for lf in weak_labels_gt_only
]).transpose()

In [None]:
L.shape

In [None]:
lf_summary(L, Y=Y, lf_names = ['RGB hist', 'HSV hist', 'flow hist', 'face counts', 'face positions'])

# Train Label Model

## Part 0: Majority Vote

In [None]:
for i in range(5):
    test_fold = shot_detection_folds[i]
    train_videos = [
        video_id
        for video_id in video_ids if video_id not in test_fold
    ]
    test_videos = [
        video_id
        for video_id in video_ids if video_id in test_fold
    ]
    
    Y_test = np.array([
        label
        for video_id in test_videos
        for label in ground_truth[video_id]
    ])
    L_test = csr_matrix([
        [
            label
            for video_id in test_videos
            for label in lf[video_id]
        ]
        for lf in weak_labels_gt_only
    ]).transpose()
    
    mv = MajorityLabelVoter(seed=123)
    scores = mv.score((L_test, Y_test), metric=['accuracy','precision', 'recall', 'f1'])

```
Accuracy: 0.988
Precision: 0.475
Recall: 0.970
F1: 0.637
        y=1    y=2   
 l=1    131    145   
 l=2     4    12392  
 
Accuracy: 0.991
Precision: 0.760
Recall: 0.941
F1: 0.841
        y=1    y=2   
 l=1    95     30    
 l=2     6    3996   
 
Accuracy: 0.995
Precision: 0.639
Recall: 0.994
F1: 0.778
        y=1    y=2   
 l=1    172    97    
 l=2     1    21015  
 
Accuracy: 0.986
Precision: 0.452
Recall: 0.980
F1: 0.619
        y=1    y=2   
 l=1    99     120   
 l=2     2    8586   
 
Accuracy: 0.992
Precision: 0.621
Recall: 0.937
F1: 0.747
        y=1    y=2   
 l=1    133    81    
 l=2     9    11039  
 
Average F1: .724
```

## Part 1: Train only on frames that we have gold labels for

In [None]:
for i in range(5):
    test_fold = shot_detection_folds[i]
    train_videos = [
        video_id
        for video_id in video_ids if video_id not in test_fold
    ]
    test_videos = [
        video_id
        for video_id in video_ids if video_id in test_fold
    ]
    
    L_train = csr_matrix([
        [
            label
            for video_id in train_videos
            for label in lf[video_id]
        ]
        for lf in weak_labels_gt_only
    ]).transpose()
    
    Y_test = np.array([
        label
        for video_id in test_videos
        for label in ground_truth[video_id]
    ])
    L_test = csr_matrix([
        [
            label
            for video_id in test_videos
            for label in lf[video_id]
        ]
        for lf in weak_labels_gt_only
    ]).transpose()
    
    label_model = LabelModel(k=2, seed=123)
    label_model.train_model(L_train, class_balance=(0.01, 0.99), n_epochs=500, log_train_every=50)
    label_model.score((L_test, Y_test), metric=['accuracy','precision', 'recall', 'f1'])

Results per fold:

```
Accuracy: 0.997
Precision: 0.887
Recall: 0.874
F1: 0.881
        y=1    y=2   
 l=1    118    15    
 l=2    17    12522  

Accuracy: 0.990
Precision: 0.984
Recall: 0.594
F1: 0.741
        y=1    y=2   
 l=1    60      1    
 l=2    41    4025   

Accuracy: 0.999
Precision: 0.953
Recall: 0.942
F1: 0.948
        y=1    y=2   
 l=1    163     8    
 l=2    10    21104  

Accuracy: 0.996
Precision: 0.802
Recall: 0.842
F1: 0.821
        y=1    y=2   
 l=1    85     21    
 l=2    16    8685   

Accuracy: 0.996
Precision: 0.946
Recall: 0.746
F1: 0.835
        y=1    y=2   
 l=1    106     6    
 l=2    36    11114  
 
Average F1: .845
```

## Part 2: Train on entire movies

In [None]:
prediction_probabilities = []

In [None]:
for i in range(5):
    test_fold = shot_detection_folds[i]
    train_videos = [
        video_id
        for video_id in video_ids if video_id not in test_fold
    ]
    test_videos = [
        video_id
        for video_id in video_ids if video_id in test_fold
    ]
    
    L_train = csr_matrix([
        [
            label
            for video_id in train_videos
            for label in lf[video_id]
        ]
        for lf in weak_labels_all
    ]).transpose()
    
    Y_test = np.array([
        label
        for video_id in test_videos
        for label in ground_truth[video_id]
    ])
    L_test = csr_matrix([
        [
            label
            for video_id in test_videos
            for label in lf[video_id]
        ]
        for lf in weak_labels_gt_only
    ]).transpose()
    
    label_model = LabelModel(k=2, seed=123)
    label_model.train_model(L_train, class_balance=(0.01, 0.99), n_epochs=500, log_train_every=50)
    label_model.score((L_test, Y_test), metric=['accuracy','precision', 'recall', 'f1'])
    
    Y_readable = [
        (video_id, f, 1 if f in boundary_frames[video_id] else 2)
        for video_id in test_videos
        for f in frames_per_video[video_id]
    ]
    
    predictions = label_model.predict(L_test)
    prediction_probs = label_model.predict_proba(L_test)
    prediction_probabilities.append([p[0] for p in prediction_probs])
    
    wrong_predictions = np.where(predictions != Y_test)[0]
    
    wrong_interval_preds = [
        (Y_readable[int(wp)], prediction_probs[int(wp)].tolist())
        for wp in wrong_predictions
    ]

    with open('../../data/failure_cases/metal_frame_only/{}_fold.pkl'.format(i + 1), 'wb') as f:
        pickle.dump(wrong_interval_preds, f)

```
Accuracy: 0.997
Precision: 0.887
Recall: 0.874
F1: 0.881
        y=1    y=2   
 l=1    118    15    
 l=2    17    12522  

Accuracy: 0.990
Precision: 0.984
Recall: 0.594
F1: 0.741
        y=1    y=2   
 l=1    60      1    
 l=2    41    4025   

Accuracy: 0.999
Precision: 0.953
Recall: 0.942
F1: 0.948
        y=1    y=2   
 l=1    163     8    
 l=2    10    21104  

Accuracy: 0.996
Precision: 0.802
Recall: 0.842
F1: 0.821
        y=1    y=2   
 l=1    85     21    
 l=2    16    8685   

Accuracy: 0.996
Precision: 0.946
Recall: 0.746
F1: 0.835
        y=1    y=2   
 l=1    106     6    
 l=2    36    11114  
 
Average F1: .845
```

In [None]:
for i, problist in enumerate(prediction_probabilities):
    plt.hist(
        problist,
        log=True)
    plt.title('Probability histogram for fold {}'.format(i + 1))
    plt.xlabel('Probability')
    plt.ylabel('Count')
    plt.show()

## Part 3: Classify windows of 16 frames

### Labeling Functions for windows of 16 frames

In [38]:
# First, construct windows of 16 frames for each video
windows = VideoIntervalCollection({
    video_id: [
        (f, f + 16, video_id)
        for f in range(0, frame_counts[video_id] - 16, 8)
    ]
    for video_id in video_ids
})

In [43]:
# Next, intersect the windows with ground truth and get ground truth labels for the windows
windows_intersecting_ground_truth = windows.filter_against(
    clips,
    predicate=overlaps()
).map(lambda intrvl: (intrvl.start, intrvl.end, 2))
windows_with_shot_boundaries = windows_intersecting_ground_truth.filter_against(
    shot_boundaries,
    predicate = lambda window, shot_boundary:
        shot_boundary.start >= window.start and shot_boundary.start < window.end
).map(
    lambda intrvl: (intrvl.start, intrvl.end, 1)
)
windows_with_labels = windows_with_shot_boundaries.set_union(
    windows_intersecting_ground_truth
).coalesce(
    predicate = equal(),
    payload_merge_op = lambda p1, p2: min(p1, p2)
)

In [None]:
# Label windows with the weak labels in our labeling functions
def label_window(per_frame_weak_labels):
    if 1 in per_frame_weak_labels:
        return 1
    if len([l for l in per_frame_weak_labels if l == 2]) >= len(per_frame_weak_labels) / 2:
        return 2
    return 0

windows_with_weak_labels = windows.map(
    lambda window: (
        window.start,
        window.end,
        [
            label_window([
                lf[window.payload][f-1]
                for f in range(window.start, window.end)
            ])
            for lf in weak_labels_all
        ]
    )
)

In [None]:
Y_windows = np.array([
    intrvl.payload
    for video_id in video_ids
    for intrvl in windows_with_labels.get_intervallist(video_id).get_intervals()
])

In [None]:
Y_windows.shape

In [None]:
len([y for y in Y_windows if y == 1])

In [None]:
L_windows = csr_matrix([
    intrvl.payload
    for video_id in video_ids
    for intrvl in windows_with_weak_labels.filter_against(
        clips, predicate=overlaps(), working_window=1
    ).get_intervallist(video_id).get_intervals()
])

In [None]:
L_windows.shape

In [None]:
lf_summary(L_windows, Y=Y_windows, lf_names = ['RGB hist', 'HSV hist', 'flow hist', 'face counts', 'face positions'])

In [None]:
csr_matrix([
    intrvl.payload
    for video_id in video_ids
    for intrvl in windows_with_weak_labels.get_intervallist(video_id).get_intervals()
]).shape

### Part 0: Majority Vote

In [None]:
windows_with_weak_labels_gt_only = windows_with_weak_labels.filter_against(
    clips, predicate=overlaps(), working_window=1
)
for i in range(5):
    test_fold = shot_detection_folds[i]
    train_videos = [
        video_id
        for video_id in video_ids if video_id not in test_fold
    ]
    test_videos = [
        video_id
        for video_id in video_ids if video_id in test_fold
    ]
    
    Y_test = np.array([
        intrvl.payload
        for video_id in test_videos
        for intrvl in windows_with_labels.get_intervallist(video_id).get_intervals()
    ])
    L_test = csr_matrix([
        intrvl.payload
        for video_id in test_videos
        for intrvl in windows_with_weak_labels_gt_only.get_intervallist(video_id).get_intervals()
    ])
    
    mv = MajorityLabelVoter(seed=123)
    scores = mv.score((L_test, Y_test), metric=['accuracy','precision', 'recall', 'f1'])

```
Accuracy: 0.947
Precision: 0.785
Recall: 0.940
F1: 0.856
        y=1    y=2   
 l=1    252    69    
 l=2    16    1269   
Accuracy: 0.912
Precision: 0.918
Recall: 0.845
F1: 0.880
        y=1    y=2   
 l=1    169    15    
 l=2    31     306   
Accuracy: 0.976
Precision: 0.861
Recall: 0.968
F1: 0.911
        y=1    y=2   
 l=1    334    54    
 l=2    11    2268   
Accuracy: 0.914
Precision: 0.707
Recall: 0.906
F1: 0.794
        y=1    y=2   
 l=1    183    76    
 l=2    19     827   
Accuracy: 0.957
Precision: 0.874
Recall: 0.905
F1: 0.889
        y=1    y=2   
 l=1    249    36    
 l=2    26    1116   

Average F1: .866
```

### Part 1: Train LabelModel on frames that we have GT for

In [None]:
windows_with_weak_labels_gt_only = windows_with_weak_labels.filter_against(
    clips, predicate=overlaps(), working_window=1
)
for i in range(5):
    test_fold = shot_detection_folds[i]
    train_videos = [
        video_id
        for video_id in video_ids if video_id not in test_fold
    ]
    test_videos = [
        video_id
        for video_id in video_ids if video_id in test_fold
    ]
    
    L_train = csr_matrix([
        intrvl.payload
        for video_id in train_videos
        for intrvl in windows_with_weak_labels_gt_only.get_intervallist(video_id).get_intervals()
    ])
    
    Y_test = np.array([
        intrvl.payload
        for video_id in test_videos
        for intrvl in windows_with_labels.get_intervallist(video_id).get_intervals()
    ])
    L_test = csr_matrix([
        intrvl.payload
        for video_id in test_videos
        for intrvl in windows_with_weak_labels_gt_only.get_intervallist(video_id).get_intervals()
    ])
    
    label_model = LabelModel(k=2, seed=123)
    label_model.train_model(L_train, class_balance=(0.15, 0.85), n_epochs=500, log_train_every=50)
    label_model.score((L_test, Y_test), metric=['accuracy','precision', 'recall', 'f1'])

In [None]:
windows_with_weak_labels_gt_only = windows_with_weak_labels.filter_against(
    clips, predicate=overlaps(), working_window=1
)
for i in range(5):
    test_fold = shot_detection_folds[i]
    train_videos = [
        video_id
        for video_id in video_ids if video_id not in test_fold
    ]
    test_videos = [
        video_id
        for video_id in video_ids if video_id in test_fold
    ]
    
    Y_train = np.array([
        intrvl.payload
        for video_id in train_videos
        for intrvl in windows_with_labels.get_intervallist(video_id).get_intervals()
    ])
    L_train = csr_matrix([
        intrvl.payload
        for video_id in train_videos
        for intrvl in windows_with_weak_labels_gt_only.get_intervallist(video_id).get_intervals()
    ])
    
    Y_test = np.array([
        intrvl.payload
        for video_id in test_videos
        for intrvl in windows_with_labels.get_intervallist(video_id).get_intervals()
    ])
    L_test = csr_matrix([
        intrvl.payload
        for video_id in test_videos
        for intrvl in windows_with_weak_labels_gt_only.get_intervallist(video_id).get_intervals()
    ])
    
    label_model = LabelModel(k=2, seed=123)
    label_model.train_model(L_train, Y_dev=Y_train, n_epochs=500, log_train_every=50)
    label_model.score((L_test, Y_test), metric=['accuracy','precision', 'recall', 'f1'])

```
Accuracy: 0.927
Precision: 0.709
Recall: 0.882
F1: 0.786
        y=1    y=2   
 l=1    217    89    
 l=2    29    1271 
 
Accuracy: 0.893
Precision: 0.911
Recall: 0.789
F1: 0.845
        y=1    y=2   
 l=1    153    15    
 l=2    41     312  

Accuracy: 0.969
Precision: 0.868
Recall: 0.903
F1: 0.885
        y=1    y=2   
 l=1    316    48    
 l=2    34    2269

Accuracy: 0.886
Precision: 0.657
Recall: 0.763
F1: 0.706
        y=1    y=2   
 l=1    151    79    
 l=2    47     828

Accuracy: 0.941
Precision: 0.805
Recall: 0.881
F1: 0.842
        y=1    y=2   
 l=1    223    54    
 l=2    30    1120 

Average F1: .813
```

### Part 2: Train LabelModel on entire videos

In [None]:
windows_with_weak_labels_gt_only = windows_with_weak_labels.filter_against(
    clips, predicate=overlaps(), working_window=1
)
prediction_probabilities_windows = []
for i in range(5):
    test_fold = shot_detection_folds[i]
    train_videos = [
        video_id
        for video_id in video_ids if video_id not in test_fold
    ]
    test_videos = [
        video_id
        for video_id in video_ids if video_id in test_fold
    ]
    
    L_train = csr_matrix([
        intrvl.payload
        for video_id in train_videos
        for intrvl in windows_with_weak_labels.get_intervallist(video_id).get_intervals()
    ])
    
    Y_test = np.array([
        intrvl.payload
        for video_id in test_videos
        for intrvl in windows_with_labels.get_intervallist(video_id).get_intervals()
    ])
    L_test = csr_matrix([
        intrvl.payload
        for video_id in test_videos
        for intrvl in windows_with_weak_labels_gt_only.get_intervallist(video_id).get_intervals()
    ])
    
    label_model = LabelModel(k=2, seed=123)
    label_model.train_model(L_train, class_balance=(0.15, 0.85), n_epochs=500, log_train_every=50)
    label_model.score((L_test, Y_test), metric=['accuracy','precision', 'recall', 'f1'])
    
    predictions = label_model.predict(L_test)
    prediction_probs = label_model.predict_proba(L_test)
    prediction_probabilities_windows.append([p[0] for p in prediction_probs])

```
Accuracy: 0.934
Precision: 0.765
Recall: 0.873
F1: 0.815
        y=1    y=2   
 l=1    234    72    
 l=2    34    1266   

Accuracy: 0.910
Precision: 0.932
Recall: 0.825
F1: 0.875
        y=1    y=2   
 l=1    165    12    
 l=2    35     309   

Accuracy: 0.979
Precision: 0.877
Recall: 0.971
F1: 0.922
        y=1    y=2   
 l=1    335    47    
 l=2    10    2275   

Accuracy: 0.898
Precision: 0.682
Recall: 0.827
F1: 0.747
        y=1    y=2   
 l=1    167    78    
 l=2    35     825   

Accuracy: 0.955
Precision: 0.881
Recall: 0.887
F1: 0.884
        y=1    y=2   
 l=1    244    33    
 l=2    31    1119   

Average F1: .849
```

In [None]:
for i, problist in enumerate(prediction_probabilities_windows):
    plt.hist(
        problist,
        log=True)
    plt.title('Probability histogram for fold {}'.format(i + 1))
    plt.xlabel('Probability')
    plt.ylabel('Count')
    plt.show()

# Part 4: Training on the entire dataset

In [26]:
# First, load noisy labels from the entire dataset
video_ids_all = sorted(list(frame_counts.keys()))

In [None]:
weak_labels_all_movies = []
for folder in labeling_function_folders:
    labels_for_function_all = {}
    for video_id in tqdm(video_ids):
        all_frames = IntervalList([
            (f+1, f+1, 0)
            for f in range(frame_counts[video_id])
        ])
        with open(os.path.join(folder, '{}.pkl'.format(video_id)), 'rb') as f:
            positive_boundaries, negative_boundaries = pickle.load(f)
            positive_frames = IntervalList([
                (f, f, 1)
                for f in positive_boundaries if f <= frame_counts[video_id]
            ])
            negative_frames = IntervalList([
                (f, f, 2)
                for f in negative_boundaries if f <= frame_counts[video_id]
            ])
            frames_w_labels = all_frames.set_union(
                positive_frames
            ).set_union(
                negative_frames
            ).coalesce(payload_merge_op = lambda p1, p2: max(p1, p2))
            
            labels_for_function_all[video_id] = [
                intrvl.payload
                for intrvl in frames_w_labels.get_intervals()
            ]
            
    weak_labels_all_movies.append(labels_for_function_all)

In [None]:
# Save weak labels
with open('../../data/shot_detection_weak_labels/all_labels.pkl', 'wb') as f:
    pickle.dump(weak_labels_all_movies, f)

In [28]:
# Or load weak labels
with open('../../data/shot_detection_weak_labels/all_labels.pkl', 'rb') as f:
    weak_labels_all_movies = pickle.load(f)

In [29]:
weak_labels_gt_only = [
    {
        video_id: [
            lf[video_id][f-1]
            for f in frames_per_video[video_id]
        ]
        for video_id in sorted(list(clips.get_allintervals().keys()))
    }
    for lf in weak_labels_all_movies
]

## 100 Movies

In [None]:
random.seed(0)

In [None]:
# randomly choose 100 movies to train on; do not choose any movies that we have GT for
vid_candidates = sorted(list(set(video_ids_all).difference(set(clips.get_allintervals().keys()))))

In [None]:
random.shuffle(vid_candidates)

In [None]:
train_split = sorted(vid_candidates[:100])

In [None]:
# Save train split
with open('../../data/shot_detection_weak_labels/train_split_100.pkl', 'wb') as f:
    pickle.dump(train_split, f)

In [49]:
# or load train split
with open('../../data/shot_detection_weak_labels/train_split_100.pkl', 'rb') as f:
    train_split = pickle.load(f)

### Frame-Based model

In [32]:
test_videos = sorted(list(clips.get_allintervals().keys()))

In [None]:
L_train_100_frames = csr_matrix([
    [
        label
        for video_id in train_split
        for label in lf[video_id]
    ]
    for lf in weak_labels_all_movies
]).transpose()

Y_test = np.array([
    label
    for video_id in test_videos
    for label in ground_truth[video_id]
])
L_test = csr_matrix([
    [
        label
        for video_id in test_videos
        for label in lf[video_id]
    ]
    for lf in weak_labels_gt_only
]).transpose()

In [None]:
MajorityLabelVoter(seed=123).score((L_test, Y_test), metric=['accuracy','precision', 'recall', 'f1'])

In [None]:
label_model_100_frames = LabelModel(k=2, seed=123)
label_model_100_frames.train_model(L_train_100_frames, Y_dev = Y_test, n_epochs=5000, log_train_every=50)
label_model_100_frames.score((L_test, Y_test), metric=['accuracy','precision', 'recall', 'f1'])

### Window-based model

In [None]:
# First, construct windows of 16 frames for each video
windows_train = VideoIntervalCollection({
    video_id: [
        (f, f + 16, video_id)
        for f in range(0, frame_counts[video_id] - 16, 8)
    ]
    for video_id in train_split
})

In [None]:
windows_test = VideoIntervalCollection({
    video_id: [
        (f, f + 16, video_id)
        for f in range(0, frame_counts[video_id] - 16, 8)
    ]
    for video_id in test_videos
})

In [None]:
# Label windows with the weak labels in our labeling functions
def label_window(per_frame_weak_labels):
    if 1 in per_frame_weak_labels:
        return 1
    if len([l for l in per_frame_weak_labels if l == 2]) >= len(per_frame_weak_labels) / 2:
        return 2
    return 0

windows_with_weak_labels_train = windows_train.map(
    lambda window: (
        window.start,
        window.end,
        [
            label_window([
                lf[window.payload][f-1]
                for f in range(window.start, window.end)
            ])
            for lf in weak_labels_all_movies
        ]
    )
)

In [None]:
windows_with_weak_labels_test = windows_test.map(
    lambda window: (
        window.start,
        window.end,
        [
            label_window([
                lf[window.payload][f-1]
                for f in range(window.start, window.end)
            ])
            for lf in weak_labels_all_movies
        ]
    )
)

In [None]:
# L_train_100_windows = csr_matrix([
#     intrvl.payload
#     for video_id in train_split
#     for intrvl in windows_with_weak_labels_train.get_intervallist(video_id).get_intervals()
# ])

Y_test_windows = np.array([
    intrvl.payload
    for video_id in test_videos
    for intrvl in windows_with_labels.get_intervallist(video_id).get_intervals()
])
L_test_windows = csr_matrix([
    intrvl.payload
    for video_id in test_videos
    for intrvl in windows_with_weak_labels_gt_only.get_intervallist(video_id).get_intervals()
])

In [None]:
MajorityLabelVoter(seed=123).score((L_test_windows, Y_test_windows), metric=['accuracy','precision', 'recall', 'f1'])

In [None]:
label_model_100_windows = LabelModel(k=2, seed=123)
label_model_100_windows.train_model(L_train_100_windows, class_balance=(0.15, 0.85), n_epochs=10000, log_train_every=50)
label_model_100_windows.score((L_test_windows, Y_test_windows), metric=['accuracy','precision', 'recall', 'f1'])

## All Movies

### Frame based

In [30]:
train_movies_all = sorted(list(set(video_ids_all).difference(set(clips.get_allintervals().keys()))))

In [33]:
L_train_everything = csr_matrix([
    [
        label
        for video_id in train_movies_all
        for label in lf[video_id]
    ]
    for lf in weak_labels_all_movies
]).transpose()

Y_test = np.array([
    label
    for video_id in test_videos
    for label in ground_truth[video_id]
])
L_test = csr_matrix([
    [
        label
        for video_id in test_videos
        for label in lf[video_id]
    ]
    for lf in weak_labels_gt_only
]).transpose()

In [34]:
label_model_everything = LabelModel(k=2, seed=123)
label_model_everything.train_model(L_train_everything, class_balance=(0.01, 0.99), n_epochs=5000, log_train_every=50)
label_model_everything.score((L_test, Y_test), metric=['accuracy','precision', 'recall', 'f1'])

Computing O...
Estimating \mu...
[50 epo]: TRAIN:[loss=0.055]
[100 epo]: TRAIN:[loss=0.052]
[150 epo]: TRAIN:[loss=0.052]
[200 epo]: TRAIN:[loss=0.052]
[250 epo]: TRAIN:[loss=0.052]
[300 epo]: TRAIN:[loss=0.052]
[350 epo]: TRAIN:[loss=0.052]
[400 epo]: TRAIN:[loss=0.052]
[450 epo]: TRAIN:[loss=0.052]
[500 epo]: TRAIN:[loss=0.052]
[550 epo]: TRAIN:[loss=0.052]
[600 epo]: TRAIN:[loss=0.052]
[650 epo]: TRAIN:[loss=0.052]
[700 epo]: TRAIN:[loss=0.052]
[750 epo]: TRAIN:[loss=0.052]
[800 epo]: TRAIN:[loss=0.052]
[850 epo]: TRAIN:[loss=0.052]
[900 epo]: TRAIN:[loss=0.052]
[950 epo]: TRAIN:[loss=0.052]
[1000 epo]: TRAIN:[loss=0.052]
[1050 epo]: TRAIN:[loss=0.052]
[1100 epo]: TRAIN:[loss=0.052]
[1150 epo]: TRAIN:[loss=0.052]
[1200 epo]: TRAIN:[loss=0.052]
[1250 epo]: TRAIN:[loss=0.052]
[1300 epo]: TRAIN:[loss=0.052]
[1350 epo]: TRAIN:[loss=0.052]
[1400 epo]: TRAIN:[loss=0.052]
[1450 epo]: TRAIN:[loss=0.052]
[1500 epo]: TRAIN:[loss=0.052]
[1550 epo]: TRAIN:[loss=0.052]
[1600 epo]: TRAIN:[loss=0.

KeyboardInterrupt: 

In [None]:
L_everything_frame = csr_matrix([
    [
        label
        for video_id in sorted(list(video_ids_all))
        for label in lf[video_id]
    ]
    for lf in weak_labels_all_movies
]).transpose()

In [None]:
len(weak_labels_all_movies[1][1])

In [None]:
frame_counts[1]

In [None]:
L_everything_frame.shape

In [None]:
frame_predictions_everything = label_model_everything.predict_proba(L_everything_frame)

In [None]:
video_frame_nums = [
    (video_id, f+1)
    for video_id in sorted(list(video_ids_all))
    for f in range(frame_counts[video_id])
]

In [None]:
frame_predictions_everything.shape

In [None]:
video_frame_nums[-10:]

In [None]:
len(video_frame_nums)

In [None]:
predictions_to_save = [
    (frame_info, prediction.tolist())
    for frame_info, prediction in zip(video_frame_nums, frame_predictions_everything)
]

In [None]:
predictions_to_save[:10]

In [None]:
preds_np = np.array(predictions_to_save)

In [None]:
preds_np.shape

In [None]:
# save predictions to disk
with open('../../data/shot_detection_weak_labels/noisy_labels_all_frame.npy', 'wb') as f:
    np.save(f, preds_np)

### Window based

In [52]:
# First, construct windows of 16 frames for each video
windows_train_all = VideoIntervalCollection({
    video_id: [
        (f, f + 16, video_id)
        for f in range(0, frame_counts[video_id] - 16, 8)
    ]
    for video_id in train_movies_all
})

In [53]:
# Label windows with the weak labels in our labeling functions
def label_window(per_frame_weak_labels):
    if 1 in per_frame_weak_labels:
        return 1
    if len([l for l in per_frame_weak_labels if l == 2]) >= len(per_frame_weak_labels) / 2:
        return 2
    return 0

windows_with_weak_labels_train_all = windows_train_all.map(
    lambda window: (
        window.start,
        window.end,
        [
            label_window([
                lf[window.payload][f-1]
                for f in range(window.start, window.end)
            ])
            for lf in weak_labels_all_movies
        ]
    )
)

In [54]:
windows_with_weak_labels_test_all = windows.map(
    lambda window: (
        window.start,
        window.end,
        [
            label_window([
                lf[window.payload][f-1]
                for f in range(window.start, window.end)
            ])
            for lf in weak_labels_all_movies
        ]
    )
)

In [55]:
windows_with_weak_labels_gt_only = windows_with_weak_labels_test_all.filter_against(
    clips, predicate=overlaps(), working_window=1
)

In [56]:
windows_with_weak_labels_all = windows_with_weak_labels_train_all.set_union(
    windows_with_weak_labels_test_all
)

In [44]:
Y_test_windows = np.array([
    intrvl.payload
    for video_id in test_videos
    for intrvl in windows_with_labels.get_intervallist(video_id).get_intervals()
])
L_test_windows = csr_matrix([
    intrvl.payload
    for video_id in test_videos
    for intrvl in windows_with_weak_labels_gt_only.get_intervallist(video_id).get_intervals()
])

In [57]:
L_train_windows_all = csr_matrix([
    intrvl.payload
    for video_id in train_split
    for intrvl in windows_with_weak_labels_train_all.get_intervallist(video_id).get_intervals()
])

In [58]:
label_model_everything_windows = LabelModel(k=2, seed=123)
label_model_everything_windows.train_model(L_train_windows_all, Y_dev = Y_test_windows,
                                           n_epochs=40000, log_train_every=50)

Computing O...
Estimating \mu...
[50 epo]: TRAIN:[loss=0.064]
[100 epo]: TRAIN:[loss=0.048]
[150 epo]: TRAIN:[loss=0.045]
[200 epo]: TRAIN:[loss=0.043]
[250 epo]: TRAIN:[loss=0.043]
[300 epo]: TRAIN:[loss=0.043]
[350 epo]: TRAIN:[loss=0.043]
[400 epo]: TRAIN:[loss=0.043]
[450 epo]: TRAIN:[loss=0.043]
[500 epo]: TRAIN:[loss=0.043]
[550 epo]: TRAIN:[loss=0.043]
[600 epo]: TRAIN:[loss=0.043]
[650 epo]: TRAIN:[loss=0.043]
[700 epo]: TRAIN:[loss=0.043]
[750 epo]: TRAIN:[loss=0.043]
[800 epo]: TRAIN:[loss=0.043]
[850 epo]: TRAIN:[loss=0.043]
[900 epo]: TRAIN:[loss=0.043]
[950 epo]: TRAIN:[loss=0.043]
[1000 epo]: TRAIN:[loss=0.043]
[1050 epo]: TRAIN:[loss=0.043]
[1100 epo]: TRAIN:[loss=0.043]
[1150 epo]: TRAIN:[loss=0.043]
[1200 epo]: TRAIN:[loss=0.043]
[1250 epo]: TRAIN:[loss=0.043]
[1300 epo]: TRAIN:[loss=0.043]
[1350 epo]: TRAIN:[loss=0.043]
[1400 epo]: TRAIN:[loss=0.043]
[1450 epo]: TRAIN:[loss=0.043]
[1500 epo]: TRAIN:[loss=0.043]
[1550 epo]: TRAIN:[loss=0.043]
[1600 epo]: TRAIN:[loss=0.

[13150 epo]: TRAIN:[loss=0.043]
[13200 epo]: TRAIN:[loss=0.043]
[13250 epo]: TRAIN:[loss=0.043]
[13300 epo]: TRAIN:[loss=0.043]
[13350 epo]: TRAIN:[loss=0.043]
[13400 epo]: TRAIN:[loss=0.043]
[13450 epo]: TRAIN:[loss=0.043]
[13500 epo]: TRAIN:[loss=0.043]
[13550 epo]: TRAIN:[loss=0.043]
[13600 epo]: TRAIN:[loss=0.043]
[13650 epo]: TRAIN:[loss=0.043]
[13700 epo]: TRAIN:[loss=0.043]
[13750 epo]: TRAIN:[loss=0.043]
[13800 epo]: TRAIN:[loss=0.043]
[13850 epo]: TRAIN:[loss=0.043]
[13900 epo]: TRAIN:[loss=0.043]
[13950 epo]: TRAIN:[loss=0.043]
[14000 epo]: TRAIN:[loss=0.043]
[14050 epo]: TRAIN:[loss=0.043]
[14100 epo]: TRAIN:[loss=0.043]
[14150 epo]: TRAIN:[loss=0.043]
[14200 epo]: TRAIN:[loss=0.043]
[14250 epo]: TRAIN:[loss=0.043]
[14300 epo]: TRAIN:[loss=0.043]
[14350 epo]: TRAIN:[loss=0.043]
[14400 epo]: TRAIN:[loss=0.043]
[14450 epo]: TRAIN:[loss=0.043]
[14500 epo]: TRAIN:[loss=0.043]
[14550 epo]: TRAIN:[loss=0.043]
[14600 epo]: TRAIN:[loss=0.043]
[14650 epo]: TRAIN:[loss=0.043]
[14700 e

[26100 epo]: TRAIN:[loss=0.043]
[26150 epo]: TRAIN:[loss=0.043]
[26200 epo]: TRAIN:[loss=0.043]
[26250 epo]: TRAIN:[loss=0.043]
[26300 epo]: TRAIN:[loss=0.043]
[26350 epo]: TRAIN:[loss=0.043]
[26400 epo]: TRAIN:[loss=0.043]
[26450 epo]: TRAIN:[loss=0.043]
[26500 epo]: TRAIN:[loss=0.043]
[26550 epo]: TRAIN:[loss=0.043]
[26600 epo]: TRAIN:[loss=0.043]
[26650 epo]: TRAIN:[loss=0.043]
[26700 epo]: TRAIN:[loss=0.043]
[26750 epo]: TRAIN:[loss=0.043]
[26800 epo]: TRAIN:[loss=0.043]
[26850 epo]: TRAIN:[loss=0.043]
[26900 epo]: TRAIN:[loss=0.043]
[26950 epo]: TRAIN:[loss=0.043]
[27000 epo]: TRAIN:[loss=0.043]
[27050 epo]: TRAIN:[loss=0.043]
[27100 epo]: TRAIN:[loss=0.043]
[27150 epo]: TRAIN:[loss=0.043]
[27200 epo]: TRAIN:[loss=0.043]
[27250 epo]: TRAIN:[loss=0.043]
[27300 epo]: TRAIN:[loss=0.043]
[27350 epo]: TRAIN:[loss=0.043]
[27400 epo]: TRAIN:[loss=0.043]
[27450 epo]: TRAIN:[loss=0.043]
[27500 epo]: TRAIN:[loss=0.043]
[27550 epo]: TRAIN:[loss=0.043]
[27600 epo]: TRAIN:[loss=0.043]
[27650 e

[39050 epo]: TRAIN:[loss=0.043]
[39100 epo]: TRAIN:[loss=0.043]
[39150 epo]: TRAIN:[loss=0.043]
[39200 epo]: TRAIN:[loss=0.043]
[39250 epo]: TRAIN:[loss=0.043]
[39300 epo]: TRAIN:[loss=0.043]
[39350 epo]: TRAIN:[loss=0.043]
[39400 epo]: TRAIN:[loss=0.043]
[39450 epo]: TRAIN:[loss=0.043]
[39500 epo]: TRAIN:[loss=0.043]
[39550 epo]: TRAIN:[loss=0.043]
[39600 epo]: TRAIN:[loss=0.043]
[39650 epo]: TRAIN:[loss=0.043]
[39700 epo]: TRAIN:[loss=0.043]
[39750 epo]: TRAIN:[loss=0.043]
[39800 epo]: TRAIN:[loss=0.043]
[39850 epo]: TRAIN:[loss=0.043]
[39900 epo]: TRAIN:[loss=0.043]
[39950 epo]: TRAIN:[loss=0.043]
[40000 epo]: TRAIN:[loss=0.043]
Finished Training


In [59]:
label_model_everything_windows.score((L_test_windows, Y_test_windows),
                                     metric=['accuracy','precision', 'recall', 'f1'])

Accuracy: 0.940
Precision: 0.819
Recall: 0.844
F1: 0.831
        y=1    y=2   
 l=1   1089    241   
 l=2    201   5795   


[0.9396669396669397, 0.818796992481203, 0.8441860465116279, 0.8312977099236641]

In [127]:
label_model_everything_windows.score((L_test_windows, Y_test_windows),
                                     metric=['accuracy','precision', 'recall', 'f1'])

Accuracy: 0.940
Precision: 0.819
Recall: 0.844
F1: 0.831
        y=1    y=2   
 l=1   1089    241   
 l=2    201   5795   


[0.9396669396669397, 0.818796992481203, 0.8441860465116279, 0.8312977099236641]

In [None]:
L_everything_windows = csr_matrix([
    intrvl.payload
    for video_id in sorted(list(video_ids_all))
    for intrvl in windows_with_weak_labels_all.get_intervallist(video_id).get_intervals()
])

In [None]:
window_predictions_everything = label_model_everything_windows.predict_proba(L_everything_windows)

In [None]:
window_nums = [
    (video_id, intrvl.start, intrvl.end)
    for video_id in sorted(list(video_ids_all))
    for intrvl in windows_with_weak_labels_all.get_intervallist(video_id).get_intervals()
]

In [None]:
predictions_to_save_windows = [
    (window_info, prediction)
    for window_info, prediction in zip(window_nums, window_predictions_everything)
]

In [None]:
preds_np_windows = np.array(predictions_to_save_windows)

In [None]:
# save predictions to disk
with open('../../data/shot_detection_weak_labels/noisy_labels_all_windows.npy', 'wb') as f:
    np.save(f, preds_np_windows)

In [60]:
with open('../../data/shot_detection_weak_labels/noisy_labels_all_windows.npy', 'rb') as f:
    preds_np_windows = np.load(f)

In [61]:
preds_np_windows.shape

(12350523, 2)

In [62]:
preds_np_windows[0]

array([(1, 0, 16), array([0.0032828, 0.9967172])], dtype=object)

In [63]:
preds_np_windows[-1]

array([(642, 154264, 154280), array([0.33286079, 0.66713921])],
      dtype=object)

In [65]:
sorted(list(windows_with_labels.get_allintervals().keys()))

[23,
 34,
 54,
 65,
 104,
 116,
 123,
 144,
 148,
 172,
 178,
 179,
 181,
 201,
 226,
 248,
 308,
 315,
 339,
 359,
 370,
 411,
 504,
 515,
 557,
 574,
 577,
 585]

In [72]:
covered_clips = windows_with_labels.coalesce()

In [86]:
Y_predicted = []
for pred in tqdm(preds_np_windows):
    video_id, start, end = pred[0]
    if video_id not in covered_clips.get_allintervals().keys():
        continue
    covered = False
    for clip in covered_clips.get_intervallist(video_id).get_intervals():
        if start >= clip.start and end <= clip.end:
            covered = True
            break
    if covered:
        Y_predicted.append(np.argmax(pred[1]) + 1)


  0%|          | 0/12350523 [00:00<?, ?it/s][A
  1%|          | 75824/12350523 [00:00<00:16, 758238.92it/s][A
  1%|▏         | 176620/12350523 [00:00<00:14, 819118.03it/s][A
  2%|▏         | 281123/12350523 [00:00<00:13, 875922.17it/s][A
  3%|▎         | 386126/12350523 [00:00<00:12, 921774.28it/s][A
  4%|▍         | 478550/12350523 [00:00<00:12, 922509.58it/s][A
  5%|▍         | 581854/12350523 [00:00<00:12, 953100.57it/s][A
  5%|▌         | 674539/12350523 [00:00<00:12, 945070.17it/s][A
  6%|▋         | 802463/12350523 [00:00<00:11, 1025428.58it/s][A
  8%|▊         | 931397/12350523 [00:00<00:10, 1092515.39it/s][A
  9%|▊         | 1053967/12350523 [00:01<00:10, 1129325.11it/s][A
 10%|▉         | 1194936/12350523 [00:01<00:09, 1200980.80it/s][A
 11%|█         | 1316494/12350523 [00:01<00:09, 1196412.78it/s][A
 12%|█▏        | 1459726/12350523 [00:01<00:08, 1258600.09it/s][A
 13%|█▎        | 1602699/12350523 [00:01<00:08, 1305476.87it/s][A
 14%|█▍        | 1747265/12350

In [87]:
len(Y_predicted)

7326

In [88]:
Y_predicted

[1,
 1,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 1,
 1,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 1,
 1,
 2,
 2,
 2,
 2,
 1,
 1,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 1,
 1,
 2,
 1,
 1,
 2,
 2,
 1,
 1,
 2,
 1,
 1,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 1,
 1,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 1,
 1,
 2,
 2,
 2,
 2,
 2,
 1,
 2,
 2,
 1,
 1,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 1,
 1,
 2,
 2,
 2,
 2,
 1,
 1,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 1,
 2,
 1,
 2,
 2,
 2,
 1,
 1,
 1,
 1,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 1,
 2,
 2,
 1,
 1,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 1,
 1,
 2,
 2,
 2,
 2,
 2,
 2,
 1,
 1,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,


In [80]:
Y_test_windows.shape

(7326,)

In [83]:
Y_test_windows[:15]

array([1, 1, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2])

In [89]:
tp = 0
tn = 0
fp = 0
fn = 0
for pred, gt in zip(Y_predicted, Y_test_windows):
    if pred == gt:
        if pred == 1:
            tp += 1
        else:
            tn += 1
    else:
        if pred == 1:
            fp += 1
        else:
            fn += 1
pre = tp / (tp + fp)
rec = tp / (tp + fn)
f1 = 2 * pre * rec / (pre + rec)
print(pre, rec, f1)

0.941351888667992 0.734108527131783 0.8249128919860627


In [68]:
clips.get_allintervals()

{65: [<Interval start:52470 end:53794 payload:1311898>],
 515: [<Interval start:19454 end:20064 payload:1311919>],
 577: [<Interval start:82661 end:84011 payload:1312046>],
 585: [<Interval start:39573 end:40914 payload:1311988>],
 34: [<Interval start:9325 end:10513 payload:1311847>],
 144: [<Interval start:74379 end:75003 payload:1311918>],
 504: [<Interval start:125357 end:126764 payload:1311921>],
 339: [<Interval start:14113 end:15187 payload:1312103>],
 148: [<Interval start:157178 end:158348 payload:1312001>],
 23: [<Interval start:79934 end:81255 payload:1311868>],
 411: [<Interval start:99219 end:99817 payload:1311812>],
 226: [<Interval start:8057 end:9179 payload:1312042>],
 123: [<Interval start:14455 end:21546 payload:61584>, <Interval start:86359 end:93503 payload:61666>, <Interval start:129517 end:136561 payload:61638>],
 359: [<Interval start:39211 end:40513 payload:1311784>],
 104: [<Interval start:78664 end:80019 payload:1312067>],
 370: [<Interval start:45535 end:530

# Tune the everything model

In [192]:
from metal.tuners.hyperband_tuner import HyperbandTuner
label_model_everything_windows_tuned = LabelModel(k=2, seed=123)
hb_tuner = HyperbandTuner(label_model_everything_windows_tuned, hyperband_epochs_budget=200,
                          seed=123, validation_metric="f1")

|           Hyperband Schedule          |
Table consists of tuples of (num configs, num_resources_per_config) which specify how many configs to run and for how many epochs. 
Each bracket starts with a list of random configurations which is successively halved according the schedule.
See the Hyperband paper (https://arxiv.org/pdf/1603.06560.pdf) for more details.
-----------------------------------------
Bracket 0: (9, 2) (3, 8) (1, 26)
Bracket 1: (3, 8) (1, 26)
Bracket 2: (3, 26)
-----------------------------------------


In [196]:
from metal.tuners.random_tuner import RandomSearchTuner
#label_model_everything_windows_tuned = LabelModel(k=2, seed=123)
random_tuner = RandomSearchTuner(LabelModel, seed=123, validation_metric='f1')

In [197]:
search_space = {
    'seed' : [123],
    'n_epochs': list(range(1000, 40000, 1000)),
    'lr': {'range': [1e-5, 1], 'scale': 'log'},
    'l2': {'range': [1e-5, 1], 'scale': 'log'},
    'log_train_every': [1000],
    'class_balance': [
        (i * .1, 1 - i * .1)
        for i in range(1, 10)
    ]
#     'Y_dev': [Y_test_windows]
}

In [198]:
best_random_model = random_tuner.search(search_space,
                                (L_test_windows, Y_test_windows),
                               train_args= [L_train_windows_all],
                               train_kwargs = {
#                                    'Y_dev': Y_test_windows
#                                    'class_balance': (0.2, 0.8)
                               },
                               init_kwargs={
                                   'k': 2
                               }, verbose=False)

PicklingError: Can't pickle <class 'metal.label_model.label_model.LabelModel'>: it's not the same object as metal.label_model.label_model.LabelModel

In [135]:
best_hb_model = hb_tuner.search(search_space,
                                (L_test_windows, Y_test_windows),
                               train_args= [L_train_windows_all],
                               init_kwargs={
                                   'k': 2
                               })

TypeError: forward() got an unexpected keyword argument 'k'