# Hourglass Network
## For keypoint extraction

In [None]:
import os
from collections.abc import Iterator

import cv2
import numpy as np


### Paths to data

In [None]:
IMAGES_PATH = "YCB-Video_data/keypoints/0010_gt_keypoints"
# os.listdir returns entries in arbitrary order, but IMAGES[i] is later paired
# by index with KEYPOINTS[i], so the listing must be sorted.
# NOTE(review): assumes file names sort in frame order (e.g. zero-padded) — confirm.
IMAGES = [os.path.join(IMAGES_PATH, name) for name in sorted(os.listdir(IMAGES_PATH))]
if not IMAGES:
    raise FileNotFoundError(f"no images found in {IMAGES_PATH}")

sample = cv2.imread(IMAGES[0])
if sample is None:  # cv2.imread signals failure by returning None, not raising
    raise FileNotFoundError(f"could not read image {IMAGES[0]}")
# Shape of the first image, (height, width, channels) for cv2's default color read.
DIM = sample.shape

KEYPOINTS_PATH = "YCB-Video_data/keypoints/0010_gt_keypoints2d.npy"
KEYPOINTS = np.load(KEYPOINTS_PATH)


In [None]:
def batchgen(images: list[str],
             keypoints: np.ndarray,
             dataset_size: int,
             batch_size: int,
             repeat: bool = False) -> Iterator[tuple[np.ndarray, np.ndarray]]:
    """Yield ``(x_batch, y_batch)`` pairs of images and their keypoint labels.

    Images are read from disk lazily, one batch at a time, with ``cv2.imread``.
    Labels are the matching rows of ``keypoints``.

    Args:
        images: Image file paths; ``images[i]`` is paired with ``keypoints[i]``.
        keypoints: Array whose first axis is indexed in step with ``images``.
        dataset_size: Number of samples to draw, starting from index 0.
        batch_size: Maximum samples per batch; the last batch may be smaller.
        repeat: If True, restart from index 0 after each pass and never stop.
            A single pass is exhausted after one epoch, so multi-epoch training
            loops that pull batches from this generator need ``repeat=True``.

    Yields:
        ``x_batch``, the batch's images stacked with ``np.array``, and
        ``y_batch = keypoints[start:start + batch_size]``.

    Raises:
        ValueError: If ``batch_size`` is not positive, or if ``repeat`` is set
            while ``dataset_size`` is not positive (it would loop forever
            without yielding).
        FileNotFoundError: If ``cv2.imread`` cannot read one of the images.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if repeat and dataset_size <= 0:
        raise ValueError("repeat=True requires a positive dataset_size")

    def _read_image(path: str) -> np.ndarray:
        """Read one image, turning cv2's silent ``None`` into an exception."""
        image = cv2.imread(path)
        if image is None:
            raise FileNotFoundError(f"could not read image {path}")
        return image

    while True:
        for start in range(0, dataset_size, batch_size):
            stop = start + batch_size
            x_batch = np.array([_read_image(path) for path in images[start:stop]])
            yield x_batch, keypoints[start:stop]
        if not repeat:
            return


In [None]:
from hourglass import create_hourglass_network


class HourglassNetwork:
    """Thin wrapper around the model returned by ``create_hourglass_network``.

    Args:
        num_classes: Forwarded to ``create_hourglass_network``.
        num_stacks: Forwarded to ``create_hourglass_network``.
        num_filters: Forwarded to ``create_hourglass_network``.
        inres: Input resolution; stored as ``self.inres`` and forwarded.
        outers: Output resolution; stored as ``self.outres`` only. The name is
            kept for backward compatibility (it appears to mean ``outres``).
    """

    def __init__(self,
                 num_classes: int,
                 num_stacks: int,
                 num_filters: int,
                 inres: tuple[int, int],
                 outers: tuple[int, int]) -> None:
        self.inres = inres
        self.outres = outers
        # NOTE(review): the output resolution is not passed to
        # create_hourglass_network; confirm that function does not expect it.
        self.model = create_hourglass_network(num_classes,
                                              num_stacks,
                                              num_filters,
                                              inres)

    def summary(self) -> None:
        """Print the wrapped model's summary."""
        self.model.summary()

    def fit(self,
            data_generator: Iterator[tuple[np.ndarray, np.ndarray]],
            dataset_size: int,
            batch_size: int,
            epochs: int) -> None:
        """Train the wrapped model on batches drawn from ``data_generator``.

        Args:
            data_generator: Source of ``(inputs, targets)`` batches, passed
                unchanged to ``self.model.fit``.
            dataset_size: Number of samples in one full pass over the data.
            batch_size: Samples per batch; the final batch may be partial.
            epochs: Number of epochs to train.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        # Ceiling division. A generator that emits a trailing partial batch
        # (as slicing in steps of batch_size does) produces ceil(n / b) batches
        # per pass. Floor division would also give 0 steps when
        # dataset_size < batch_size.
        steps_per_epoch = -(-dataset_size // batch_size)
        self.model.fit(data_generator,
                       steps_per_epoch=steps_per_epoch,
                       epochs=epochs)


In [None]:
num_classes = 25
num_stacks = 8
num_filters = 256
batch_size = 8
epochs = 10

# Images and keypoint rows are paired by index, so use one dataset size for
# both the generator and the step count, and require the counts to agree.
dataset_size = len(IMAGES)
if dataset_size == 0:
    raise ValueError("no training images found")
if KEYPOINTS.shape[0] != dataset_size:
    raise ValueError(
        f"{dataset_size} images but {KEYPOINTS.shape[0]} keypoint entries"
    )


def _repeating_batches():
    """Yield batches forever by restarting batchgen after each one-pass sweep.

    batchgen makes a single pass over the data, which would run dry after
    the first of several epochs.
    """
    while True:
        yield from batchgen(IMAGES, KEYPOINTS,
                            dataset_size=dataset_size, batch_size=batch_size)


# NOTE(review): inres (480, 480) should match the image shape in DIM, and raw
# KEYPOINTS are used as targets; confirm the model expects both.
net = HourglassNetwork(num_classes, num_stacks,
                       num_filters, (480, 480), (25, 25))
net.fit(data_generator=_repeating_batches(),
        dataset_size=dataset_size,
        batch_size=batch_size,
        epochs=epochs)
