In [None]:
from  IPython.display import clear_output
import matplotlib.pyplot as plt
import json
import random
import os
import cv2
import numpy as np
from sklearn.cluster import KMeans
import warnings
warnings.filterwarnings('ignore')

In [None]:
class generate_triplets:
    """Build (Anchor, Positive, Negative) image triplets from pairs of aligned videos.

    For every alignment file found in ``path_json`` the anchor video is paired
    with each aligned video listed in that file. Frames are sampled along the
    alignment path; the negative frame is taken from the positive video, from a
    different k-means cluster of its normalised vertical bar trajectory.
    """

    def __init__(self, path_json, path_videos, path_bar_trajectory, save_path):
        """Store the locations used by ``get_triplets``.

        Args:
            path_json: Directory with one alignment JSON file per anchor video.
            path_videos: Directory holding the ``<name>.mp4`` videos.
            path_bar_trajectory: Directory holding the ``<name>.json`` detections.
            save_path: Root directory; one sub-directory per triplet role is
                created below it when images are saved.
        """
        self.path_json = path_json
        self.path_videos = path_videos
        self.path_bar_trajectory = path_bar_trajectory
        self.save_path = save_path

    def visualize(self, **images):
        """Show the given images side by side, titled after their keyword names."""
        n = len(images)
        plt.figure(figsize=(16, 5))
        for i, (name, image) in enumerate(images.items()):
            plt.subplot(1, n, i + 1)
            plt.xticks([])
            plt.yticks([])
            plt.title(' '.join(name.split('_')).title())
            plt.imshow(image)
        plt.show()

    def plot_sync_videos(self, path, cap1, cap2, name1 = "", name2 = "",save_video = False, fps = 30):
        """Display two videos side by side following an alignment path.

        Args:
            path: Iterable of ``(i, j)`` frame-index pairs (frame ``i`` of
                ``cap1`` is shown next to frame ``j`` of ``cap2``).
            cap1, cap2: Opened ``cv2.VideoCapture`` objects.
            name1, name2: Used only to build the output file name.
            save_video: When true, also write ``./<name1>VS<name2>.mp4``.
            fps: Frame rate of the written video.
        """
        width = int(cap1.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap1.get(cv2.CAP_PROP_FRAME_HEIGHT))
        out = None
        if save_video:
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter("./"+name1+"VS"+name2+".mp4", fourcc, fps, (width*2,height))

        # Placeholder for frames that cannot be read; already at the target size.
        blank = np.zeros((height, width, 3), dtype=np.uint8)
        try:
            for i, j in path:
                cap1.set(cv2.CAP_PROP_POS_FRAMES, i)
                ret1, frame1 = cap1.read()
                cap2.set(cv2.CAP_PROP_POS_FRAMES, j)
                ret2, frame2 = cap2.read()
                # Both halves must share one size, otherwise np.hstack fails.
                frame1 = cv2.resize(frame1, (width, height)) if ret1 else blank
                frame2 = cv2.resize(frame2, (width, height)) if ret2 else blank
                frame = np.hstack((frame1, frame2))
                clear_output(wait=True)
                # OpenCV decodes to BGR while matplotlib expects RGB.
                plt.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                plt.show()
                if out is not None:
                    out.write(frame)
        finally:
            # Release the writer even if reading or plotting fails midway.
            if out is not None:
                out.release()

    def get_frame(self, cap, i):
        """Return frame ``i`` of ``cap`` converted from BGR to RGB.

        Raises:
            AssertionError: If the frame cannot be read.
        """
        cap.set(cv2.CAP_PROP_POS_FRAMES, i)
        res, frame = cap.read()
        if not res:
            # Explicit raise keeps the historical exception type but is not
            # stripped when Python runs with -O.
            raise AssertionError(f"Could not read frame {i} from the video capture")
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    def normalizar_posiciones_y(self, posiciones_y):
        """Min-max scale a sequence of positions to the range [0, 1].

        Args:
            posiciones_y: 1-D sequence of numeric positions.

        Returns:
            Array of shape ``(n, 1)``. A constant input maps to all zeros and an
            empty input yields an empty ``(0, 1)`` array.
        """
        posiciones = np.asarray(posiciones_y)
        if posiciones.size == 0:
            return np.empty((0, 1))
        pos_min = posiciones.min()
        amplitud_vertical = posiciones.max() - pos_min
        # A flat trajectory has no amplitude; avoid dividing by zero.
        factor_escala = 1 if amplitud_vertical == 0 else 1 / amplitud_vertical
        return ((posiciones - pos_min) * factor_escala).reshape(-1, 1)

    def get_trajectory(self, data, add_negative = False):
        """Build the normalised vertical trajectory of the first detection per frame.

        Each element of ``data`` is indexed as ``frame[0]`` to obtain the list of
        detections of that frame; the first detection must unpack into five
        values ``(x1, y1, x2, y2, extra)`` (the fifth is ignored here -
        presumably a confidence score, verify against the detector output).

        Frames without a detection reuse the most recent detected centre;
        frames before the first detection reuse the first detected centre so
        that they do not distort the min/max used for normalisation.

        Args:
            data: Per-frame detection records as described above.
            add_negative: When true, frames without a detection are marked with
                ``-1`` in the returned array (valid values lie in [0, 1]).

        Returns:
            Array of shape ``(len(data), 1)``.
        """
        vertical_trajectory = np.zeros(len(data))
        negativos = []
        last_center_y = None
        first_detected = None
        for cnt, frame_data in enumerate(data):
            detections = frame_data[0]
            if len(detections) > 0:
                _, y1, _, y2, _ = detections[0]
                last_center_y = int(y1 + (y2 - y1) / 2)
                vertical_trajectory[cnt] = last_center_y
                if first_detected is None:
                    first_detected = cnt
            else:
                # Forward-fill with the latest detection seen so far, if any.
                if last_center_y is not None:
                    vertical_trajectory[cnt] = last_center_y
                negativos.append(cnt)

        # Back-fill the leading gap; leaving it at 0 would become the minimum
        # of the trajectory and compress every real value after normalisation.
        if first_detected:
            vertical_trajectory[:first_detected] = vertical_trajectory[first_detected]

        trayectoria_normalizada = self.normalizar_posiciones_y(vertical_trajectory)
        if add_negative:
            for x in negativos:
                trayectoria_normalizada[x] = -1
        return trayectoria_normalizada

    def k_means_on_trajectory(self, i_frame, trajectory, k=4, max = False):
        """Pick a negative frame index from a different trajectory cluster.

        The trajectory is clustered with k-means and the returned index belongs
        to a cluster other than the one containing ``i_frame``.

        Args:
            i_frame: Index of the reference (positive) frame.
            trajectory: Array-like of shape ``(n_frames, n_features)``.
            k: Requested number of clusters (capped at the number of frames).
            max: When true, return the candidate farthest from ``i_frame`` in
                trajectory space instead of a random candidate.

        Returns:
            Index of the selected negative frame.

        Raises:
            ValueError: If no frame lies outside the cluster of ``i_frame``.
        """
        trajectory = np.array(trajectory)
        if len(trajectory) == 0:
            raise ValueError("Cannot select a negative frame from an empty trajectory")
        # KMeans rejects n_clusters larger than the number of samples.
        n_clusters = min(k, len(trajectory))
        kmedias = KMeans(n_clusters=n_clusters, random_state=0)
        kmedias.fit(trajectory)
        etiquetas = kmedias.labels_
        indices_cluster_diferente = np.flatnonzero(etiquetas != etiquetas[i_frame])
        if indices_cluster_diferente.size == 0:
            raise ValueError(
                "No frame belongs to a cluster different from that of frame "
                f"{i_frame}; cannot select a negative frame"
            )
        if max:
            distancias = np.linalg.norm(
                trajectory[indices_cluster_diferente] - trajectory[i_frame], axis=1
            )
            return indices_cluster_diferente[np.argmax(distancias)]
        return np.random.choice(indices_cluster_diferente)

    def dividir_y_seleccionar_valores(self,arr):
        """Split ``arr`` into four consecutive segments and draw one value from each.

        Segments that turn out empty (inputs shorter than four elements) are
        skipped, so the result holds between zero and four values.
        """
        n = len(arr)
        segmentos = np.linspace(0, n, num=5, dtype=int)
        valores_seleccionados = []
        for inicio, fin in zip(segmentos[:-1], segmentos[1:]):
            if fin > inicio:
                valores_seleccionados.append(np.random.choice(arr[inicio:fin]))
        return valores_seleccionados

    def save_triplet(self, images, name1, name2, cnt, size = (320, 320)):
        """Write every image of a triplet to ``<save_path>/<key>/``.

        Args:
            images: Mapping from role name to an RGB image.
            name1, name2: Video names used in the file name.
            cnt: Running triplet counter used in the file name.
            size: ``(width, height)`` the images are resized to before saving.

        Raises:
            OSError: If OpenCV reports that an image could not be written.
        """
        for key, image in images.items():
            target_dir = os.path.join(self.save_path, key)
            # cv2.imwrite does not create missing directories; it just fails.
            os.makedirs(target_dir, exist_ok=True)
            save_filepath = os.path.join(target_dir, f"{name1}+{name2}+{key}_{cnt}.jpg")
            # Images arrive as RGB (see get_frame); imwrite expects BGR.
            image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            image_bgr = cv2.resize(image_bgr, size)
            if not cv2.imwrite(save_filepath, image_bgr):
                raise OSError(f"Could not write image to {save_filepath}")

    def _valid_indices(self, trajectory):
        """Return the set of frame indices not marked with the -1 sentinel."""
        return set(np.flatnonzero(np.ravel(trajectory) != -1).tolist())

    def _load_json(self, filepath):
        """Read and decode one JSON file."""
        with open(filepath) as f:
            return json.load(f)

    def get_triplets(self, plot_triplet = False, plot_video = False):
        """Generate and save triplets for every alignment file in ``path_json``.

        The anchor name is the part of the file name before the first ``-``.
        Each entry of an alignment file is read as ``entry[0][0]`` (name of the
        aligned video) and ``entry[0][1]`` (list of ``(anchor_frame,
        other_frame)`` pairs). Up to four pairs are sampled per aligned video,
        one per quarter of the alignment path. A sampled pair whose anchor or
        positive frame has no detection is replaced by the first pair of the
        path where both frames have one; if no such pair exists the sample is
        skipped.

        Args:
            plot_triplet: Show every saved triplet.
            plot_video: Play each pair of aligned videos before sampling.
        """
        # Skip directories (e.g. notebook checkpoint folders) inside path_json.
        files = sorted(
            f for f in os.listdir(self.path_json)
            if os.path.isfile(os.path.join(self.path_json, f))
        )
        for j, file in enumerate(files):
            name = file.split("-")[0]
            best = self._load_json(os.path.join(self.path_json, file))

            anchor_data = self._load_json(os.path.join(self.path_bar_trajectory, name+".json"))
            trajectory_anchor = self.get_trajectory(anchor_data, add_negative = True)
            valid_anchor = self._valid_indices(trajectory_anchor)

            cnt = 0
            cap1 = cv2.VideoCapture(os.path.join(self.path_videos, name+'.mp4'))
            try:
                for video_data in best:
                    another_name = video_data[0][0]
                    alignment = video_data[0][1]

                    data = self._load_json(os.path.join(self.path_bar_trajectory, another_name+".json"))
                    trajectory_positive = self.get_trajectory(data, add_negative = True)
                    trajectory_ = self.get_trajectory(data)
                    valid_positive = self._valid_indices(trajectory_positive)

                    # First aligned pair where both frames have a detection.
                    fallback = next(
                        ((a, b) for a, b in alignment
                         if a in valid_anchor and b in valid_positive),
                        None,
                    )

                    cap2 = cv2.VideoCapture(os.path.join(self.path_videos, another_name+'.mp4'))
                    try:
                        if plot_video:
                            self.plot_sync_videos(alignment, cap1, cap2)
                        arreglo = np.arange(len(alignment))
                        for value in self.dividir_y_seleccionar_valores(arreglo):
                            x, y = alignment[value]
                            if x not in valid_anchor or y not in valid_positive:
                                if fallback is None:
                                    continue
                                x, y = fallback
                            z = self.k_means_on_trajectory(y,trajectory_)
                            images = {'Anchor':self.get_frame(cap1,x),'Positive':self.get_frame(cap2,y),'Negative':self.get_frame(cap2,z)}
                            self.save_triplet(images, name, another_name, cnt)
                            if plot_triplet:
                                self.visualize(**images)
                            cnt+=1
                    finally:
                        cap2.release()
            finally:
                cap1.release()

            print("El video {} {} ha sido procesado".format(name,j))

In [None]:
# Input and output locations, relative to the notebook's working directory.
path = "./alinged_videos_OHP_top/"  # alignment JSON files, one per anchor video
path_videos = "./videos/"  # source .mp4 videos
path_bar_trajectory = "./bar_trajectories_raw/"  # per-video detection JSON files
save_path = "./ssl_triplets_dataset_ohp/"  # root folder for the saved triplets

# Build the generator and write the triplets for every alignment file.
g = generate_triplets(
    path_json=path,
    path_videos=path_videos,
    path_bar_trajectory=path_bar_trajectory,
    save_path=save_path,
)
g.get_triplets()