In [1]:
%matplotlib inline

import ast
import functools as ft
import itertools as it
import operator as op
import pickle
import random
import sys
from datetime import datetime, timedelta

import cv2
import hdbscan
import matplotlib.pyplot as plt
import mxnet as mx
import numpy as np
from sklearn.preprocessing import StandardScaler

def show_and_wait(frame, title='tesst', wait_time=0):
    """Display ``frame`` in an OpenCV window and wait for a key press.

    Args:
        frame: Image passed straight to ``cv2.imshow``.
        title: Name of the window to draw into.
        wait_time: Value forwarded to ``cv2.waitKey``.

    Returns:
        The pressed key as a one-character string, or ``None`` when
        ``cv2.waitKey`` reports a negative key code.
    """
    cv2.imshow(title, frame)
    pressed = cv2.waitKey(wait_time)
    if pressed < 0:
        # A negative code means no key was reported.
        return None
    return chr(pressed)

def read_and_wait(video, title='tesst', wait_time=0):
    """Read the next frame from ``video`` and display it via ``show_and_wait``.

    Args:
        video: Object whose ``read()`` returns a ``(success, frame)`` pair.
        title: Window title forwarded to ``show_and_wait``.
        wait_time: Wait time forwarded to ``show_and_wait``.

    Returns:
        Whatever ``show_and_wait`` returns for the frame, or ``None`` when
        the read did not succeed.
    """
    succeeded, image = video.read()
    if not succeeded:
        return None
    return show_and_wait(image, title, wait_time)

def pairwise(iterable):
    """Yield overlapping pairs: s -> (s0, s1), (s1, s2), (s2, s3), ...

    Returns a ``zip`` iterator; inputs with fewer than two items produce
    no pairs.
    """
    current, following = it.tee(iterable, 2)
    # Step the second copy forward by one item so it runs one ahead.
    for _ in it.islice(following, 1):
        pass
    return zip(current, following)

def construct_data(file_path, extent, known_file_path=None, known_frame=-1, modify_frame=None):
    """Sample cropped frames from a video and build a standardized feature matrix.

    Every ``n``-th frame of ``file_path`` is kept, where ``n`` is the frame
    count divided by 2000, rounded and clamped to the range 1..30. Each kept
    frame is cropped to ``extent`` and optionally passed through
    ``modify_frame``.

    Args:
        file_path: Path of the video to sample.
        extent: ``(top, left, bottom, right)`` crop bounds used as
            ``frame[top:bottom, left:right]``.
        known_file_path: Optional path of a second video from which one
            additional frame is appended as the last element of ``frames``.
        known_frame: Index of the frame to take from ``known_file_path``.
            A negative value selects the last frame of that video.
        modify_frame: Optional callable applied to each cropped frame.

    Returns:
        ``(frames, data)`` where ``frames`` is the list of processed frames
        and ``data`` is the ``StandardScaler``-transformed matrix with one
        flattened frame per row.

    Raises:
        ValueError: If no frame could be read from ``file_path`` or the
            requested known frame could not be read.
    """
    start_at = datetime.now()
    top, left, bottom, right = extent

    def fn(frame):
        frame = frame[top:bottom, left:right]
        return modify_frame(frame) if modify_frame else frame

    video = cv2.VideoCapture(file_path, cv2.CAP_MSMF)
    try:
        print(video.get(cv2.CAP_PROP_FRAME_COUNT), video.get(cv2.CAP_PROP_FPS))
        n = min(max(1, round(video.get(cv2.CAP_PROP_FRAME_COUNT) / 2000)), 30)
        frames = []
        index = 0
        while True:
            result, frame = video.read()
            if not result:
                break
            if index % n == 0:
                frames.append(fn(frame))
            index += 1
    finally:
        # Release the capture even if cropping or modify_frame raises.
        video.release()

    if not frames:
        raise ValueError('no frames could be read from {!r}'.format(file_path))

    if known_file_path:
        # Add a known positive frame.
        video = cv2.VideoCapture(known_file_path, cv2.CAP_MSMF)
        try:
            frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
            print(frame_count, video.get(cv2.CAP_PROP_FPS))
            # Frame positions are 0-based, so the last frame sits at
            # frame_count - 1; seeking to frame_count leaves nothing to read.
            target = max(frame_count - 1, 0) if known_frame < 0 else known_frame
            print(video.set(cv2.CAP_PROP_POS_FRAMES, target))
            result, frame = video.read()
        finally:
            video.release()
        if not result:
            raise ValueError(
                'could not read frame {} from {!r}'.format(target, known_file_path))
        frames.append(fn(frame))

    data = StandardScaler().fit_transform(np.stack([f.reshape(-1) for f in frames]))
    print(data.shape)
    print('total time:', (datetime.now() - start_at).total_seconds(), 'seconds')
    return frames, data

def print_clusters(frames, cluster_labels, has_known_cluster=False):
    """Print a summary of ``cluster_labels`` and plot sample frames per cluster.

    Prints the number of distinct labels, the distinct labels, the first 22
    labels (plain and enumerated) and the size of every cluster. The plot
    shows one row per distinct label holding up to five member frames side
    by side.

    Args:
        frames: Sequence of image arrays indexed in step with
            ``cluster_labels``. They are presumably BGR images of identical
            shape, since the channel axis is reversed before plotting.
        cluster_labels: One label per frame.
        has_known_cluster: When true, the label of the last frame is reported
            as the desired cluster (or as unknown if that label is negative).
    """
    unique_labels = np.unique(cluster_labels)
    print('cluster count:', len(unique_labels))
    print(unique_labels)
    print(cluster_labels[:22])
    print(*enumerate(cluster_labels[:22]))
    print(*((k, len(list(v))) for k, v in it.groupby(np.sort(cluster_labels))))
    if has_known_cluster:
        if cluster_labels[-1] < 0:
            print('desired cluster unknown')
        else:
            print('desired cluster:', cluster_labels[-1])

    samples_per_cluster = 5
    label_array = np.asarray(cluster_labels)
    rows = []
    for label in unique_labels:
        members = np.flatnonzero(label_array == label)[:samples_per_cluster]
        rows.append(np.hstack([frames[i] for i in members]))
    if not rows:
        # Nothing to plot when there are no labels at all.
        return

    # A cluster with fewer than `samples_per_cluster` members gives a narrower
    # row; pad with zeros on the right so that all rows can be stacked.
    width = max(row.shape[1] for row in rows)
    padded = [
        np.pad(row, [(0, 0), (0, width - row.shape[1])] + [(0, 0)] * (row.ndim - 2))
        for row in rows
    ]
    # Reverse the channel axis before handing the mosaic to matplotlib.
    plt.imshow(np.vstack(padded).take([2, 1, 0], axis=2))

def predict_cluster_labels(data, min_cluster_size=22):
    """Cluster ``data`` with HDBSCAN and return the predicted labels.

    Args:
        data: Input handed to ``HDBSCAN.fit_predict``.
        min_cluster_size: Forwarded to the ``hdbscan.HDBSCAN`` constructor.

    Returns:
        The result of ``fit_predict``. The elapsed wall-clock time is
        printed as a side effect.
    """
    began = datetime.now()
    labels = hdbscan.HDBSCAN(min_cluster_size=min_cluster_size).fit_predict(data)
    elapsed = datetime.now() - began
    print('total time:', elapsed.total_seconds(), 'seconds')
    return labels

In [11]:
video_file_path = r"C:\Users\cidzerda\Documents\GitHub\strevr-data\unsupervised\unsupervised.mp4"
# Open the video only long enough to report how many frames it has.
video = cv2.VideoCapture(video_file_path)
frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
print(frame_count)
video.release()
cv2.destroyAllWindows()

196367.0


In [12]:
pickle_file_path = r'C:\Users\cidzerda\Documents\GitHub\strevr-data\unsupervised.pickle'
label_file_path = r'C:\Users\cidzerda\Documents\GitHub\strevr-data\unsupervised\unsupervised.txt'
# NOTE(review): pickle.load can execute arbitrary code while loading; only
# point this at pickle files produced by a trusted source.
# The pickle holds two consecutive objects: the frames, then their cluster labels.
with open(pickle_file_path, 'rb') as fin:
    frames = pickle.load(fin)
    clusters = pickle.load(fin)
# The label file holds a Python literal (a list of (frame_index, label)
# tuples), so parse it with ast.literal_eval rather than eval: literal_eval
# accepts only literals and cannot run code embedded in the file.
with open(label_file_path, 'rt') as fin:
    labels = ast.literal_eval(fin.read())
print(len(frames), frames[0].shape, len(clusters), clusters.shape, len(labels))
print(labels)

6547 (30, 35, 3) 6547 (6547,) 292
[(0, 'z'), (114, 'a'), (116, 'z'), (155, 'a'), (220, 'z'), (221, 'a'), (294, 'z'), (304, 'a'), (308, 'z'), (310, 'a'), (334, 'z'), (336, 'a'), (360, 'z'), (361, 'a'), (472, 'z'), (475, 'a'), (520, 'z'), (522, 'a'), (546, 'z'), (547, 'a'), (643, 'z'), (714, 'a'), (716, 'z'), (755, 'a'), (759, 'z'), (783, 'a'), (822, 'z'), (824, 'a'), (826, 'z'), (828, 'a'), (841, 'z'), (842, 'a'), (877, 'z'), (879, 'a'), (903, 'z'), (906, 'a'), (957, 'z'), (959, 'a'), (960, 'z'), (962, 'a'), (964, 'c'), (995, 'z'), (1004, 'c'), (1010, 'z'), (1013, 'c'), (1018, 'z'), (1020, 'c'), (1074, 'z'), (1080, 'c'), (1090, 'z'), (1093, 'c'), (1165, 'z'), (1169, 'c'), (1182, 'z'), (1183, 'c'), (1210, 'z'), (1213, 'c'), (1220, 'z'), (1222, 'c'), (1225, 'z'), (1229, 'c'), (1483, 'z'), (1624, 'b'), (1669, 'z'), (1670, 'b'), (1695, 'z'), (1696, 'b'), (1708, 'z'), (1710, 'b'), (1723, 'z'), (1724, 'b'), (1764, 'z'), (1769, 'b'), (1771, 'z'), (1774, 'b'), (1823, 'z'), (1836, 'b'), (1841, '

In [31]:
# Look at a single frame together with the cluster it was assigned to.
frame_index = 116
print(clusters[frame_index])
show_and_wait(frames[frame_index])
cv2.destroyAllWindows()

19


In [None]:
# Play back each labelled segment (from one label's frame index up to the
# next label's), print the segment's label, then wait for a key press.
# Pressing 'q' stops the review early.
try:
    for (start, label), (stop, _) in pairwise(labels):
        for i in range(start, stop):
            show_and_wait(frames[i], wait_time=25)
        print(label)
        # Compare key codes instead of calling chr(): chr() raises ValueError
        # if waitKey returns -1, whereas the masked value just won't equal 'q'.
        if cv2.waitKey(0) & 0xFF == ord('q'):
            break
finally:
    # Close the preview window however the loop ends, not only on 'q'.
    cv2.destroyAllWindows()