In [None]:
# Install (or upgrade, via -U) deepposekit using the interpreter that runs this
# kernel (sys.executable), so pip targets the same Python the notebook uses.
import sys
!{sys.executable} -m pip install -U deepposekit

In [None]:
import glob
from os.path import expanduser

import cv2
import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tqdm

from deepposekit.annotate import KMeansSampler
from deepposekit.io import VideoReader, DataGenerator, initialize_dataset

# Detect Google Colab by trying to import its helper package.
# Only ImportError is caught: a bare `except:` would also swallow unrelated
# failures (e.g. KeyboardInterrupt) and silently report "not in Colab".
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

# Project root: relative to the working directory on Colab, otherwise under
# the current user's home directory.
HOME = expanduser("~") if not IN_COLAB else '.'
HOME = HOME +'/Documents/RamanLab'

In [None]:
# Gather every .mp4 directly under the project's Data folder, sorted by path.
videos = glob.glob(HOME + '/Data/*.mp4')
videos.sort()
videos

In [None]:
# Open the cropped video in grayscale and read its first entry as a sanity
# check on the data's shape.
reader = VideoReader(HOME + '/Data/crop.mp4', gray=True)
try:
    frame = reader[0] # read a frame
finally:
    # Close the reader even if the read above raises.
    reader.close()
frame.shape

In [None]:
# Display the frame read above. The indexing keeps the first element of the
# leading axis and of the trailing axis (presumably batch and channel —
# TODO confirm against VideoReader's output layout).
fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(frame[0, ..., 0])
plt.show()

In [None]:
# Re-open the video so that each index returns a batch (batch_size=100).
reader = VideoReader(HOME + '/Data/crop.mp4', batch_size=100, gray=True)

# Collect 10 distinct (replace=False) random frames from each batch. The
# per-batch samples go into their own list so `randomly_sampled_frames` is
# only ever bound to the final concatenated array.
sampled_batches = []
try:
    # The last batch is deliberately skipped — presumably because it may hold
    # fewer frames than a full batch; TODO confirm.
    for idx in tqdm.tqdm(range(len(reader)-1)):
        batch = reader[idx]
        random_sample = batch[np.random.choice(batch.shape[0], 10, replace=False)]
        sampled_batches.append(random_sample)
finally:
    # Close the reader even if reading or sampling raises part-way through.
    reader.close()

# Stack the per-batch samples along the first axis.
randomly_sampled_frames = np.concatenate(sampled_batches)
randomly_sampled_frames.shape

In [None]:
# Configure a 10-cluster KMeansSampler and fit it to the randomly sampled frames.
kmeans = KMeansSampler(
    n_clusters=10,
    max_iter=1000,
    n_init=10,
    batch_size=100,
    verbose=True,
)
kmeans.fit(randomly_sampled_frames)

In [None]:
# Visualise the fitted cluster centers, laid out over two rows (n_rows=2).
kmeans.plot_centers(n_rows=2)
plt.show()

In [None]:
# Draw frames per cluster label (n_samples_per_label=10) from the randomly
# sampled pool; returns the selected frames together with their labels.
kmeans_sampled_frames, kmeans_cluster_labels = kmeans.sample_data(
    randomly_sampled_frames,
    n_samples_per_label=10,
)
kmeans_sampled_frames.shape

In [None]:
# Preview the skeleton definition that the annotation set is initialised with.
# This reads the locust skeleton so the table shown matches the file passed to
# initialize_dataset (previously it displayed the unrelated fly skeleton).
skeleton = pd.read_csv(HOME + '/deepposekit-data/datasets/locust/skeleton.csv')
skeleton

In [None]:
# Create the annotation set at `datapath` from the k-means-selected frames and
# the locust skeleton file. Uncomment `overwrite=True` below to replace an
# existing file at that path.
initialize_dataset(
    datapath=HOME + '/deepposekit-data/datasets/locust/example_annotation_set.h5',
    skeleton=HOME + '/deepposekit-data/datasets/locust/skeleton.csv',
    images=kmeans_sampled_frames,
    # overwrite=True # This overwrites the existing datapath
)

In [None]:
# Load the freshly initialised annotation set and plot its first sample with
# the skeleton overlaid, as a visual check before annotating.
data_generator = DataGenerator(HOME + '/deepposekit-data/datasets/locust/example_annotation_set.h5', mode="full")

image, keypoints = data_generator[0]

plt.figure(figsize=(5,5))
# Decide colour handling once, from the last axis of the returned batch, and
# compare with `==`: `is` tests object identity, which only matches small ints
# by CPython implementation detail and raises a SyntaxWarning on Python 3.8+.
is_rgb = image.shape[-1] == 3
image = image[0] if is_rgb else image[0, ..., 0]
cmap = None if is_rgb else 'gray'
plt.imshow(image, cmap=cmap, interpolation='none')
# graph[idx] is used as the index of the keypoint that keypoint idx is joined
# to; negative entries are skipped (presumably "no parent" — TODO confirm).
for idx, jdx in enumerate(data_generator.graph):
    if jdx > -1:
        plt.plot(
            [keypoints[0, idx, 0], keypoints[0, jdx, 0]],
            [keypoints[0, idx, 1], keypoints[0, jdx, 1]],
            'r-'
        )
# Colour each keypoint by its index so individual joints can be told apart.
plt.scatter(keypoints[0, :, 0], keypoints[0, :, 1], c=np.arange(data_generator.keypoints_shape[0]), s=50, cmap=plt.cm.hsv, zorder=3)

plt.show()