# Applying OpenSfM to find the depth for points in an image

OpenSfM is able to fully construct a point cloud for a collection of images that captures 3D information on the scene. We want to leverage this point cloud to create a 3D parametric representation. As such, we need to be able to go from detected wireframe features to 3D points in the point cloud.

The general idea is to project out from a detected line endpoint/junction in an image to an approximate location in the depth map found by OpenSfM. We can then "interpolate" the depth at that location from the depth values of the points that project close to it, either by assuming those points are coplanar or by simply averaging their depths.

We should investigate how to access the resulting `merged.ply` point cloud data as well as the camera poses in order to accomplish the above.

## Follow the resources in SfM.ipynb for setting up and running the OpenSfM pipeline

For this, using the OpenSfM `opensfm_run_all` executable should be good enough to generate your data.

In [None]:
import sys, os
import utils
import numpy as np
import yaml, json
import cv2
from plyfile import PlyData
import matplotlib.pyplot as plt

sys.path.append('../OpenSfM')
from opensfm import features, config
from sfm import utils as sfm_util

from scipy.spatial.transform import Rotation

In [None]:
# Intrinsics and distortion coefficients produced by the calibration steps in BasicOpenCV.ipynb.
# NOTE: Camera calibration parameters will not work with video frames!
intrinsic_mat = np.load(utils.data("numpy/intrinsic_mat.npy"))
distortion_mat = np.load(utils.data("numpy/distortion_mat.npy"))
# Best-guess focal length: the mean of the two diagonal focal entries divided by 1920.0
# (presumably the calibration image width in pixels -- TODO confirm).
f = (intrinsic_mat[0, 0] + intrinsic_mat[1, 1]) / (2.0 * 1920.0)
# The first two entries of the distortion matrix are k1 and k2.
k1, k2 = distortion_mat[0, 0], distortion_mat[0, 1]

In [None]:
# Sanity check: show the derived focal length and the two distortion coefficients, one per line.
print("{}\n{}\n{}".format(f, k1, k2))

In [None]:
# IMPORTANT: Set the project directory path here
project_dir = utils.data("door_closed/")


conf = config.load_config(os.path.join(project_dir, "config.yaml"))
depthmaps_dir = os.path.join(project_dir, "undistorted/depthmaps")
# Use a dedicated handle name here: `f` already holds the focal length computed
# from the calibration, and reusing it would replace that value with a closed file.
with open(os.path.join(project_dir, "reports/reconstruction.json")) as report_file:
    reconstruction_report = json.load(report_file)
image_dir = os.path.join(project_dir, "images")

# The merged.ply contains the depth information for points in the images (probably redundant with reconstruction.meshed.json)
# The x/y/z columns are cached as an .npy file next to the project so the PLY only has to be parsed once.
numpy_merged_points = os.path.join(project_dir, "merged_points.npy")
try:
    points = np.load(numpy_merged_points)
except FileNotFoundError:
    merged_ply = PlyData.read(os.path.join(depthmaps_dir, "merged.ply"))
    element = merged_ply.elements[0]
    # Stack the three coordinate columns and transpose so each row is one (x, y, z) point.
    points = np.vstack((element.data['x'], element.data['y'], element.data['z'])).transpose()
    np.save(numpy_merged_points, points)
    print("Created numpy file for merged.ply points")


# reconstruction.meshed.json contains the rotations and translations of each camera, along with the mesh points in the image
with open(os.path.join(project_dir, "reconstruction.meshed.json")) as meshed_file:
    reconstruction_meshed = json.load(meshed_file)

In [None]:
# Images that OpenSfM could not place in the reconstruction.
print(reconstruction_report["not_reconstructed_images"])

# Names of the images that have a non-empty pruned depthmap.
other_list = []

# Sorted so the processing order (and the order of other_list) is reproducible;
# os.listdir returns entries in arbitrary order.
for imname in sorted(os.listdir(image_dir)):
    print("Processing {}...".format(imname))

    prunedname = imname + ".pruned.npz"
    try:
        pruned = np.load(os.path.join(depthmaps_dir, prunedname))
    except FileNotFoundError:
        print("Skipping {}: No depthmap found".format(imname))
        continue
    # An .npz archive keeps its file handle open until closed, so release it
    # as soon as the point count has been read. The count is kept under its own
    # name so the merged point cloud stored in `points` is not overwritten.
    with pruned:
        num_pruned_points = pruned["points"].shape[0]
    if num_pruned_points == 0:
        print("Skipping {}: No points in pruned mesh".format(imname))
        continue
    other_list.append(imname)

In [None]:
from wireframe import Wireframe, WireframeGraph

# Make sure to put your pretrained model data in the data directory!
# Paths to the wireframe detector's configuration and pretrained weights.
config_file = utils.data("wireframe.yaml")
model_file = utils.data("pretrained_lcnn.pth.tar")

# Global wireframe detector; process_image() calls w.parse() on each image path.
# NOTE(review): the meaning of the empty third argument is not visible here -- confirm against Wireframe.
w = Wireframe(config_file, model_file, "")

# setup() reports failure through its return value; the reason is read from w.error.
# A failure is only printed, so execution continues with a detector that is not set up.
if not w.setup():
    print("An error occured trying to setup the wireframe: {}".format(w.error))

In [None]:
def get_K_dist(camera):
    """
    Build the pixel-space intrinsic matrix and distortion vector for a camera.

    Arguments:
    camera -- dictionary mapping a camera name to its parameter dictionary, which
              must provide "width", "height", "focal", "k1" and "k2". Only the
              first entry is used.

    Returns:
    (K, distortion) -- K is the 3x3 intrinsic matrix; distortion is the 5-element
    array [k1, k2, 0, 0, 0].

    Raises:
    ValueError -- if camera is empty.
    """
    if not camera:
        raise ValueError("camera dictionary is empty; expected at least one camera")
    # Read the parameters from the argument itself (first camera in the dict).
    camera_params = next(iter(camera.values()))
    width = camera_params["width"]
    height = camera_params["height"]
    k1 = camera_params["k1"]
    k2 = camera_params["k2"]

    # The stored focal is scaled by the larger image dimension to get the matrix entry.
    focal_px = camera_params["focal"] * max(width, height)

    # The principal point is placed at the image centre.
    K = np.array([[focal_px, 0, 0.5 * (width - 1)],
                  [0, focal_px, 0.5 * (height - 1)],
                  [0, 0, 1]])
    # Only the two radial terms are known; the remaining coefficients are zero.
    distortion = np.array([k1, k2, 0, 0, 0])
    return K, distortion

# Reproject points of the point cloud into pixel coordinates of the original image
def project_point(points, R, T, K, distortion):
    """
    Project 3D points into the image with cv2.projectPoints.

    Arguments:
    points -- 3D points to project
    R -- camera rotation, passed to OpenCV as the rotation vector
    T -- camera translation
    K -- 3x3 intrinsic matrix
    distortion -- distortion coefficients

    Returns the projected image points; the Jacobian that OpenCV also
    returns is discarded.
    """
    projected, _jacobian = cv2.projectPoints(points, R, T, K, distortion)
    return projected

In [None]:
def is_point_near_line(p, l, dist=20.0):
    """
    Returns true if point is close to line and doesn't extend past the endpoints of the line.

    Arguments:
    p -- 2D point; singleton dimensions are squeezed away before use
    l -- 2x2 array holding the two endpoints, one per row. Each row is reversed
         before use, so the endpoints are stored in the opposite component
         order to p.
    dist -- maximum allowed perpendicular distance from the line

    A zero-length line (both endpoints equal) is treated as a single point:
    the result is whether p lies within dist of that point.
    """
    start = l[0, ::-1]
    end = l[1, ::-1]
    direction = end - start
    p = np.squeeze(p)
    offset = p - start
    length = np.linalg.norm(direction)
    if length == 0:
        # Without this guard the normalisation below divides by zero, the
        # perpendicular distance becomes NaN, and every point is reported as near.
        return bool(np.linalg.norm(offset) <= dist)

    # Unit normal of the line: the direction rotated by 90 degrees.
    perp_dir = np.array([-direction[1], direction[0]]) / length
    if np.abs(np.dot(perp_dir, offset)) > dist:
        return False

    # Unnormalised projection onto the segment; it runs from 0 at the start
    # to length squared at the end.
    t = np.dot(direction, offset)
    return bool(0 <= t <= length ** 2)

In [None]:
def write_points_to_ply(filepath, points, c=np.array([255, 255, 255]), r=50.0):
    """
    Write 3D points to an ASCII PLY file, giving every vertex the same colour and radius.

    Arguments:
    filepath -- destination path; an existing file is overwritten
    points -- sequence of (x, y, z) triples
    c -- colour whose first three entries are written as red, green, blue
    r -- radius value written for every vertex
    """
    with open(filepath, "w") as ply:
        header = [
            "ply",
            "format ascii 1.0",
            "element vertex {}".format(len(points)),
            "property float x",
            "property float y",
            "property float z",
            "property uchar red",
            "property uchar green",
            "property uchar blue",
            "property float radius",
            "end_header",
        ]
        ply.write("\n".join(header) + "\n")
        # One vertex per line: position, then the shared colour and radius.
        ply.writelines(
            "{} {} {} {} {} {} {}".format(x, y, z, c[0], c[1], c[2], r) + "\n"
            for x, y, z in points
        )


In [None]:
def process_image(imname, info, camera, debug=True):
    """
    Serves to run on the 3D information given by the reconstruction meshed json file.

    The shot's mesh vertices are projected into the image, wireframe lines are
    detected in the image, and for every detected line the vertices that
    project close to it are written to their own PLY file at
    <project_dir>/my_ply/<image name without extension>_lines/<imname>.line_<n>.ply

    Relies on the module-level names image_dir, project_dir and w (the wireframe detector).

    Arguments:
    imname -- string image name to process (original file found in images dir)
    info -- dictionary retrieved from reconstructed_meshed; its "vertices",
            "rotation" and "translation" entries are read
    camera -- cameras dictionary of the reconstruction, handed to get_K_dist
    debug -- when true, print diagnostics and show the projected points, the
             wireframe graph and the per-line matches with matplotlib

    Raises:
    FileNotFoundError -- if debug is true and the image cannot be read for plotting
    """
    print("Processing {}...".format(imname))
    impath = str(os.path.join(image_dir, imname))
    im = cv2.imread(impath)
    points = np.array(info["vertices"])
    rotation = np.array(info["rotation"])
    translation = np.array(info["translation"])
    K, distortion = get_K_dist(camera)

    im_rgb = None
    if debug:
        # The image itself is only needed for plotting, so only debug mode requires it.
        if im is None:
            raise FileNotFoundError("Could not read image for debug plots: {}".format(impath))
        # Image bounds come from the loaded image (rows, columns, channels).
        height, width = im.shape[:2]
        # cv2.imread yields BGR channel order while matplotlib expects RGB.
        im_rgb = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        print("Rotation:\n{}\nTranslation:\n{}".format(rotation,translation))
        print("Depth map consists of {} points".format(points.shape[0]))

    impoints = project_point(points, rotation, translation, K, distortion)
    if debug:
        # Plot every projected vertex that lands inside the image; report the rest.
        for p in impoints:
            p = np.squeeze(p)
            if p[0] < 0.0 or p[0] > width:
                print("Width out of bounds: {}".format(p))
            elif p[1] < 0.0 or p[1] > height:
                print("Height out of bounds: {}".format(p))
            else:
                plt.scatter([p[0]], [p[1]])
        plt.imshow(im_rgb)
        plt.show()

    rec = w.parse(impath)
    nlines, nscores = rec.postprocess(threshold=0.9)
    if debug:
        print("Wireframe found {} lines with score passing threshold".format(nlines.shape[0]))
        graph = WireframeGraph(rec, threshold=0.9)
        graph.plot_graph(graph.g, im)

    # Strip the extension properly; slicing off three characters left the dot
    # behind and mishandled extensions that are not three characters long.
    image_stem = os.path.splitext(imname)[0]
    res_ply_dir = os.path.join(project_dir, "my_ply/{}_lines/".format(image_stem))

    for l_num, l in enumerate(nlines):
        print("Working on line num {}".format(l_num))
        close_points = []
        world_points = []
        # impoints[i] is the projection of points[i], so the two are walked in lockstep.
        for projected, world_point in zip(impoints, points):
            p = np.squeeze(projected)
            if is_point_near_line(p, l, dist=20.0):
                close_points.append(p)
                world_points.append(world_point)
        if debug:
            # Line endpoints are indexed [row, 1] for the horizontal and [row, 0] for the vertical axis.
            plt.plot([l[0, 1], l[1, 1]], [l[0, 0], l[1, 0]], c='r')
            for p in close_points:
                plt.scatter([p[0]], [p[1]])
            plt.imshow(im_rgb)
            plt.show()
        output_ply_file = imname + ".line_{}.ply".format(l_num)
        # Created lazily so no directory appears for images without any line.
        os.makedirs(res_ply_dir, exist_ok=True)
        write_points_to_ply(os.path.join(res_ply_dir, output_ply_file), world_points, c=np.array([255, 0, 0]))
    

In [None]:
# Cameras of the first reconstruction, kept available at module level.
camera = reconstruction_meshed[0]['cameras']
# Run the per-image processing on every shot of every reconstruction,
# passing along the cameras of the reconstruction the shot belongs to.
for reconstruction in reconstruction_meshed:
    for imname, shot_info in reconstruction['shots'].items():
        process_image(imname, shot_info, reconstruction['cameras'])