# In [None]:
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from camera import CameraPoses

def create_point_cloud_file(points, colors, filename):
    """Write a coloured point cloud to ``filename`` as an ASCII PLY file.

    Args:
        points: Array-like that can be reshaped to ``(N, 3)``; written as the
            ``x y z`` float columns.
        colors: Array-like that can be reshaped to ``(N, 3)``; written as the
            ``red green blue`` integer columns.
        filename: Destination path. An existing file is overwritten.
    """
    colors = np.asarray(colors).reshape(-1, 3)
    vertices = np.hstack([np.asarray(points).reshape(-1, 3), colors])

    # The header is assembled line by line so that no source-code indentation
    # leaks into the file: every PLY keyword must start its own line and the
    # first vertex row must directly follow ``end_header``.
    header_lines = [
        'ply',
        'format ascii 1.0',
        'element vertex %d' % len(vertices),
        'property float x',
        'property float y',
        'property float z',
        'property uchar red',
        'property uchar green',
        'property uchar blue',
        'end_header',
    ]
    with open(filename, 'w') as f:
        f.write('\n'.join(header_lines) + '\n')
        np.savetxt(f, vertices, '%f %f %f %d %d %d')

# Load the MiDaS depth-estimation model from PyTorch Hub.
model_type = "DPT_Hybrid"
midas = torch.hub.load("intel-isl/MiDaS", model_type)

# Select the device the model runs on (GPU when available, CPU otherwise).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
midas.to(device)
# The model is only used for prediction, so switch it to inference mode.
midas.eval()

# Load the preprocessing transforms that go with the MiDaS models and pick
# the DPT one.
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
transform = midas_transforms.dpt_transform

# Default camera matrix (configurable), laid out as
# [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]. These values are the parameters of
# a PC camera.
intrinsic = np.array(
    (
        (527.70703125, 0.0, 335.40994459),
        (0.0, 525.81518555, 253.63149488),
        (0.0, 0.0, 1.0),
    )
)


def reconstruction(path_to_video, path_to_save_ply, intrinsic=intrinsic, skip_frames=50, log=False, fx=480):
    """Build a coloured point cloud from a video and save it as a PLY file.

    Every frame after the first one read in the loop is matched against the
    last reference frame through ``CameraPoses`` to update the running pose.
    One frame out of ``skip_frames`` is additionally passed through MiDaS,
    reprojected to 3D with the current matrix, filtered on normalised depth
    and appended to the global cloud, which is then rewritten to
    ``path_to_save_ply``.

    Args:
        path_to_video: Source handed to ``cv2.VideoCapture``
            (e.g. ``'./deer_vr_slow.mp4'``).
        path_to_save_ply: Output PLY path; rewritten after every kept frame.
        intrinsic: Camera matrix handed to ``CameraPoses``.
        skip_frames: Sampling period; frames 1, 1 + skip_frames,
            1 + 2 * skip_frames, ... are turned into points.
        log: When true, print the frame counter and the current matrix and
            show the image / depth figures for every kept frame.
        fx: Value stored in the last column of the third row of the initial
            3x4 matrix (presumably a focal length in pixels; confirm).

    Raises:
        IOError: If no frame can be read from ``path_to_video``.
    """
    cap = cv2.VideoCapture(path_to_video)
    try:
        # The very first frame is consumed only to learn the image size.
        ret, img = cap.read()
        if not ret:
            raise IOError("Cannot read any frame from video: %r" % (path_to_video,))

        # NOTE(review): the first row is offset by shape[0] / 2 (rows) and the
        # second by shape[1] / 2 (columns). Kept exactly as it was; confirm
        # whether the two were meant to be swapped.
        cur_pose = np.array([
            [1., 0., 0., -img.shape[0] / 2],
            [0., 1., 0., -img.shape[1] / 2],
            [0., 0., 0., fx]])

        # Fourth row appended to the 3x4 pose to obtain the 4x4 matrix used
        # for reprojection.
        hom_array = np.array([[0, 0, 1.0 / 90.0, 0]])
        extrinsic = np.concatenate((cur_pose, hom_array), axis=0)

        cam = CameraPoses(intrinsic)
        prev_img = None               # last frame that served as match reference
        points_cloud_global = None    # accumulated 3D points
        image_colors_global = None    # accumulated RGB colours

        frame = 0
        while cap.isOpened():
            ret, img = cap.read()
            if not ret:
                # End of the video: leave the loop.
                break

            frame += 1
            if log:
                print(frame)  # current frame counter

            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if prev_img is None:
                # Nothing to match against yet: this frame becomes the reference.
                prev_img = img
            else:
                # Find points shared with the reference frame.
                q1, q2 = cam.get_matches(prev_img, img)
                if q1 is not None and len(q1) > 20 and len(q2) > 20:
                    transf = cam.get_pose(q1, q2)   # relative motion
                    cur_pose = cur_pose @ transf    # accumulate into the pose
                    # The reference only advances when the match was usable.
                    prev_img = img
                extrinsic = np.concatenate((cur_pose, hom_array), axis=0)
                extrinsic = np.round(extrinsic, 8)

            # Only keep one frame out of skip_frames so that the cloud is not
            # flooded with near-duplicate points. Written so that frame 1 is
            # always kept and skip_frames == 1 keeps every frame.
            if (frame - 1) % skip_frames != 0:
                continue

            img_trans = transform(img).to(device)

            # Depth prediction, resized back to the frame resolution.
            with torch.no_grad():
                prediction = midas(img_trans)
                prediction = torch.nn.functional.interpolate(
                    prediction.unsqueeze(1),
                    size=img.shape[:2],
                    mode="bicubic",
                    align_corners=False,
                ).squeeze()

            depth_image = prediction.cpu().numpy()

            # Rescale the prediction to the [0, 1] range.
            depth_image = cv2.normalize(depth_image, None, 0, 1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)

            # Conversion to a point cloud.
            points_3D = cv2.reprojectImageTo3D(depth_image * 200, extrinsic, handleMissingValues=False)
            points_3D = cv2.perspectiveTransform(points_3D.reshape(-1, 1, 3), extrinsic)
            points_3D = points_3D.reshape(img.shape)

            # Drop the pixels whose normalised depth value is 0.1 or less.
            filtre = depth_image > 0.1

            # Apply the mask to both the points and their colours.
            points_cloud = points_3D[filtre]
            image_colors = img[filtre]

            if points_cloud_global is None:
                points_cloud_global = points_cloud
                image_colors_global = image_colors
            else:
                points_cloud_global = np.concatenate((points_cloud_global, points_cloud), axis=0)
                image_colors_global = np.concatenate((image_colors_global, image_colors), axis=0)

            # Rewritten after every kept frame so that a partial result is
            # on disk if processing stops early.
            create_point_cloud_file(points_cloud_global, image_colors_global, path_to_save_ply)

            if log:
                fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
                ax1.imshow(img)
                ax2.imshow((depth_image * 255).astype(np.uint8))
                # Third panel: same depth map with the filtered-out pixels zeroed.
                depth_image[depth_image < 0.1] = 0
                ax3.imshow((depth_image * 255).astype(np.uint8))
                plt.show()
                print(extrinsic)

                # Stop when Escape (key code 27) is reported by OpenCV.
                if cv2.waitKey(5) & 0xFF == 27:
                    break
    finally:
        # Always give the video handle back, even on error or early exit.
        cap.release()