In [1]:
%matplotlib inline

import cv2
import functools as ft
import itertools as it
import matplotlib.pyplot as plt
import mxnet as mx
import numpy as np
import operator as op
import pickle
import random
import sys
import hdbscan
from datetime import datetime, timedelta
from sklearn.preprocessing import StandardScaler

def show_and_wait(frame, title='tesst'):
    cv2.imshow(title, frame)
    return chr(cv2.waitKey(0))

def read_and_wait(video, title='tesst'):
    result, frame = video.read()
    if result:
        return show_and_wait(frame, title)

def pairwise(iterable):
    "s -> (s0,s1), (s1,s2), (s2, s3), ..."
    a, b = it.tee(iterable)
    next(b, None)
    return zip(a, b)

def construct_data(file_path, extent, known_file_path=None, known_frame=-1, modify_frame=None):
    start_at = datetime.now()
    video = cv2.VideoCapture(file_path, cv2.CAP_MSMF)
    print(video.get(cv2.CAP_PROP_FRAME_COUNT), video.get(cv2.CAP_PROP_FPS))
    n = min(max(1, round(video.get(cv2.CAP_PROP_FRAME_COUNT) / 2000)), 30)
    top, left, bottom, right = extent
    def fn(frame):
        frame = frame[top:bottom, left:right]
        return modify_frame(frame) if modify_frame else frame
    def yield_frames():
        while True:
            result, frame = video.read()
            if not result:
                break
            yield frame
    frames = [fn(f) for i, f in enumerate(yield_frames()) if i % n == 0]
    video.release()

    if known_file_path:
        # Add a known positive frame.
        video = cv2.VideoCapture(known_file_path, cv2.CAP_MSMF)
        print(video.get(cv2.CAP_PROP_FRAME_COUNT), video.get(cv2.CAP_PROP_FPS))
        print(video.set(cv2.CAP_PROP_POS_FRAMES, video.get(cv2.CAP_PROP_FRAME_COUNT) if known_frame < 0 else known_frame))
        _, frame = video.read()
        frames.append(fn(frame))
        video.release()

    data = StandardScaler().fit_transform(np.stack([f.reshape(-1) for f in frames]))
    print(data.shape)
    print('total time:', (datetime.now() - start_at).total_seconds(), 'seconds')
    return frames, data

def print_clusters(frames, cluster_labels, has_known_cluster=False):
    print('cluster count:', len(np.unique(cluster_labels)))
    print(np.unique(cluster_labels))
    print(cluster_labels[:22])
    print(*enumerate(cluster_labels[:22]))
    print(*((k, len(list(v))) for k, v in it.groupby(np.sort(cluster_labels))))
    if has_known_cluster:
        if cluster_labels[-1] < 0:
            print('desired cluster unknown')
        else:
            print('desired cluster:', cluster_labels[-1])
    def fn():
        for n in np.unique(cluster_labels):
            g = (i for i, j in enumerate(cluster_labels) if j == n)
            g = it.islice(g, 5)
            stack = [frames[i] for i in g]
            yield np.hstack(stack)
    plt.imshow(np.vstack(list(fn())).take([2,1,0], axis=2))

def predict_cluster_labels(data, min_cluster_size=22):
    start_at = datetime.now()
    clusterer = hdbscan.HDBSCAN(min_cluster_size=min_cluster_size)
    cluster_labels = clusterer.fit_predict(data)
    print('total time:', (datetime.now() - start_at).total_seconds(), 'seconds')
    return cluster_labels

In [4]:
# This is a test for Champion clustering.
# This takes about 7 seconds to run.
file_path = r"C:\Users\cidzerda\Documents\GitHub\strevr-data\tesst\staycationyoutube1.mp4"
known_file_path = r"C:\Users\cidzerda\Documents\GitHub\strevr-data\champion\xednim.mp4"
def modify_frame(frame):
    return cv2.resize(frame, (0, 0), fx=.125, fy=.125, interpolation=cv2.INTER_AREA)
frames, data = construct_data(file_path, (184, 213, 259, 621), known_file_path=known_file_path, modify_frame=modify_frame)

5845.0 29.229851863110756
23522.0 29.97002997002997
True
(1950, 1377)
total time: 7.168676 seconds


In [5]:
last = (0, 0)
low, high = 1, len(data) // 2
while last != (low, high):
    last = (low, high)
    n = (low + high) // 2
    print(n)
    clusterer = hdbscan.HDBSCAN(min_cluster_size=n)
    cluster_labels = clusterer.fit_predict(data)
    print('cluster count:', len(np.unique(cluster_labels)))
    print(np.unique(cluster_labels))
    print(cluster_labels[:22])
    print(*enumerate(cluster_labels[:22]))
    print(*((k, len(list(v))) for k, v in it.groupby(np.sort(cluster_labels))))
    if len(np.unique(cluster_labels)) == 1:
        high = n - 1
    else:
        low = n

488
cluster count: 1
[-1]
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1]
(0, -1) (1, -1) (2, -1) (3, -1) (4, -1) (5, -1) (6, -1) (7, -1) (8, -1) (9, -1) (10, -1) (11, -1) (12, -1) (13, -1) (14, -1) (15, -1) (16, -1) (17, -1) (18, -1) (19, -1) (20, -1) (21, -1)
(-1, 1950)
244
cluster count: 1
[-1]
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1]
(0, -1) (1, -1) (2, -1) (3, -1) (4, -1) (5, -1) (6, -1) (7, -1) (8, -1) (9, -1) (10, -1) (11, -1) (12, -1) (13, -1) (14, -1) (15, -1) (16, -1) (17, -1) (18, -1) (19, -1) (20, -1) (21, -1)
(-1, 1950)
122
cluster count: 3
[-1  0  1]
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1  0  0 -1 -1]
(0, -1) (1, -1) (2, -1) (3, -1) (4, -1) (5, -1) (6, -1) (7, -1) (8, -1) (9, -1) (10, -1) (11, -1) (12, -1) (13, -1) (14, -1) (15, -1) (16, -1) (17, -1) (18, 0) (19, 0) (20, -1) (21, -1)
(-1, 1345) (0, 481) (1, 124)
182
cluster count: 1
[-1]
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1]
(0, -

In [9]:
n = 1
for i in range([i - 9 for i, k in enumerate(cluster_labels) if k == n][0], 9999):
    cv2.imshow('tesst', frames[i])
    if chr(cv2.waitKey(0)) == 'q':
        break
cv2.destroyAllWindows()

In [10]:
show_and_wait(np.vstack([np.hstack([frames[i] for i in it.islice((i for i, j in enumerate(cluster_labels) if j == n), 5)]) for n in np.unique(cluster_labels)]))
cv2.destroyAllWindows()