In [None]:
!git clone https://github.com/a3d2ddk/Prenith-and-Drew-Estimate-Camera-Poses.git

In [192]:
import scipy
import numpy as np
import cv2 as cv
import gradio as gr
import numpy as np
import pytransform3d.camera as pc
import pytransform3d.transformations as pt
import matplotlib.pyplot as plt
import threading, json, os, io

from PIL import Image, ImageDraw, ImageFont

In [193]:
def get_homography(X, W):
    """Estimate the 3x3 homography that maps planar points ``W`` onto ``X``.

    A linear (DLT) solution is computed first from the null space of the
    stacked constraint matrix and is then refined by minimising the
    reprojection error ``obj_funct`` with ``scipy.optimize.minimize``.

    Parameters
    ----------
    X : array_like, shape (2, N) or larger in the first axis
        Observed points; row 0 holds x, row 1 holds y.
    W : array_like, shape (2, N) or (3, N)
        Planar reference points; row 0 holds u, row 1 holds v. Any further
        rows (e.g. a z row) are ignored.

    Returns
    -------
    numpy.ndarray, shape (3, 3)
        Homography such that ``[x, y, 1]^T ~ phi @ [u, v, 1]^T``. The overall
        scale (and sign) of the matrix is arbitrary.

    Raises
    ------
    ValueError
        If ``X`` and ``W`` hold a different number of points or fewer than
        four correspondences are supplied.
    """
    # Imported here so the submodule is guaranteed to be loaded even when the
    # file only did a bare ``import scipy``.
    import scipy.optimize

    X = np.asarray(X, dtype=np.float64)
    W = np.asarray(W, dtype=np.float64)

    u, v = W[0, :], W[1, :]
    x, y = X[0, :], X[1, :]

    n_points = u.shape[0]
    if x.shape[0] != n_points:
        raise ValueError(
            f"X has {x.shape[0]} points but W has {n_points}; they must match"
        )
    if n_points < 4:
        raise ValueError(
            f"At least 4 correspondences are required, got {n_points}"
        )

    # Two linear constraints per correspondence; every supplied point is used
    # (the count is no longer fixed at four).
    A = np.zeros((2 * n_points, 9))
    for i in range(n_points):
        A[2 * i, :] = [0, 0, 0, -u[i], -v[i], -1, y[i] * u[i], y[i] * v[i], y[i]]
        A[2 * i + 1, :] = [u[i], v[i], 1, 0, 0, 0, -x[i] * u[i], -x[i] * v[i], -x[i]]

    # numpy returns V^H; its last row is the right singular vector belonging
    # to the smallest singular value, i.e. the least-squares null vector of A.
    _, _, Vh = np.linalg.svd(A)
    phi_init = Vh[-1, :]

    # Non-linear refinement of the reprojection error, seeded with the DLT result.
    refined = scipy.optimize.minimize(obj_funct, x0=phi_init, args=(x, y, u, v))

    return np.reshape(refined.x, (3, 3))


def obj_funct(phi, x, y, u, v):
    """Sum of squared reprojection errors of a homography.

    Parameters
    ----------
    phi : array_like, 9 elements
        Homography entries in row-major order.
    x, y : array_like, shape (N,)
        Observed point coordinates.
    u, v : array_like, shape (N,) or longer
        Reference point coordinates; only the first ``len(x)`` are used.

    Returns
    -------
    float
        ``sum_i (x_i - x_model_i)^2 + (y_i - y_model_i)^2`` where the model
        point is ``phi`` applied to ``(u_i, v_i, 1)`` and dehomogenised.
        Points whose projective denominator is exactly zero are skipped.
    """
    phi = np.asarray(phi, dtype=np.float64).ravel()
    x = np.asarray(x, dtype=np.float64)
    n_points = x.shape[0]
    y = np.asarray(y, dtype=np.float64)[:n_points]
    u = np.asarray(u, dtype=np.float64)[:n_points]
    v = np.asarray(v, dtype=np.float64)[:n_points]

    denom = phi[6] * u + phi[7] * v + phi[8]

    # Points that project to infinity contribute nothing (avoids division by zero).
    valid = denom != 0
    if not np.any(valid):
        return 0.0

    u, v, x, y, denom = u[valid], v[valid], x[valid], y[valid], denom[valid]

    x_model = (phi[0] * u + phi[1] * v + phi[2]) / denom
    y_model = (phi[3] * u + phi[4] * v + phi[5]) / denom

    return float(np.sum((x - x_model) ** 2 + (y - y_model) ** 2))

def get_pose_hom(lam, dist, X, W):
    """Recover a camera pose from a plane-to-image homography.

    Parameters
    ----------
    lam : array_like, shape (3, 3)
        Camera intrinsic matrix.
    dist : array_like
        Lens distortion coefficients in the layout OpenCV expects.
    X : array_like, reshapeable to (N, 2)
        Observed pixel coordinates, one point per row.
    W : array_like, shape (3, N) or (2, N)
        Planar reference points, one point per column (rows 0 and 1 are used
        by ``get_homography``).

    Returns
    -------
    rotation : numpy.ndarray, shape (3, 3)
        Proper rotation matrix (orthonormal, determinant +1).
    translation : numpy.ndarray, shape (3,)
        Translation expressed in the units of ``W``.

    Raises
    ------
    ValueError
        If the estimated homography is degenerate (its first two columns
        vanish), or propagated from ``get_homography``.
    """
    # Without a new projection matrix, undistortPoints returns *normalized*
    # image coordinates: both the distortion and the intrinsics are removed.
    pts = np.asarray(X, dtype=np.float64).reshape(-1, 1, 2)
    pts = cv.undistortPoints(pts, lam, dist)
    pts = pts.reshape(-1, 2).T

    # Because the points are already normalized, this homography equals
    # s * [r1 r2 t] directly; applying inv(lam) again would remove the
    # intrinsics twice.
    hom = get_homography(pts, W)

    # A homography is only defined up to sign. hom[2, 2] is s * t_z, so force
    # it positive to place the plane origin in front of the camera.
    if hom[2, 2] < 0:
        hom = -hom

    # Closest pair of orthonormal columns to the first two homography columns
    # (orthogonal Procrustes on the 3x2 block only; the third column is the
    # translation and must not take part).
    U, S, Vh = np.linalg.svd(hom[:, :2], full_matrices=False)
    r12 = U @ Vh

    # The third axis completes a right-handed frame, so det(rotation) == +1.
    r3 = np.cross(r12[:, 0], r12[:, 1])
    rotation = np.column_stack([r12, r3])

    # Least-squares scale between hom[:, :2] and r12 is the mean singular
    # value; this avoids dividing by individual rotation entries that may be 0.
    scale = S.mean()
    if scale == 0:
        raise ValueError("Degenerate homography: cannot recover scale")

    translation = hom[:, 2] / scale

    return rotation, translation


def get_pose_cv(lam, dist, X, W):
    """Estimate the camera pose with OpenCV's ``solvePnP``.

    Parameters
    ----------
    lam : array_like, shape (3, 3)
        Camera intrinsic matrix.
    dist : array_like
        Lens distortion coefficients.
    X : array_like, reshapeable to (N, 2)
        Observed pixel coordinates, one point per row.
    W : array_like, shape (3, N) or (N, 3)
        3-D reference points. A leading dimension of 3 is interpreted as
        "one point per column" (the layout ``get_pose_hom`` uses) and is
        transposed; anything else is treated as one point per row.

    Returns
    -------
    rotation : numpy.ndarray, shape (3, 3)
        Rotation matrix obtained from the Rodrigues vector.
    translation : numpy.ndarray
        Translation vector exactly as returned by ``solvePnP``.

    Raises
    ------
    ValueError
        If the number of image points and object points differ.
    RuntimeError
        If ``solvePnP`` reports failure.
    """
    image_points = np.ascontiguousarray(
        np.asarray(X, dtype=np.float64).reshape(-1, 2)
    )

    object_points = np.asarray(W, dtype=np.float64)
    # A plain reshape(-1, 3) of a 3xN array interleaves x/y/z values of
    # different points; column-major input has to be transposed instead.
    if object_points.ndim == 2 and object_points.shape[0] == 3:
        object_points = object_points.T
    object_points = np.ascontiguousarray(object_points.reshape(-1, 3))

    if len(object_points) != len(image_points):
        raise ValueError(
            f"Got {len(image_points)} image points but "
            f"{len(object_points)} object points"
        )

    ok, rvec, translation = cv.solvePnP(object_points, image_points, lam, dist)
    if not ok:
        raise RuntimeError("cv.solvePnP failed to find a pose")

    rotation, _ = cv.Rodrigues(rvec)

    return rotation, translation


def img_with_axis(img, lam, rot, tvec, dist):
    """Draw the projected world coordinate axes on a copy of ``img``.

    Parameters
    ----------
    img : numpy.ndarray
        Image to annotate. It is NOT modified; drawing happens on a copy.
    lam : array_like, shape (3, 3)
        Camera intrinsic matrix.
    rot : array_like, shape (3, 3)
        Rotation matrix of the pose.
    tvec : array_like
        Translation vector of the pose.
    dist : array_like
        Lens distortion coefficients.

    Returns
    -------
    PIL.Image.Image
        The annotated image.
    """
    rvec, _ = cv.Rodrigues(rot)

    # Origin followed by the tips of the x, y and z axes, each 2 units long.
    axis_points = 2 * np.array([
        [0, 0, 0],
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
    ], dtype=np.float64)

    projected, _ = cv.projectPoints(axis_points, rvec, tvec, lam, dist)
    projected = projected.reshape(-1, 2)

    # Plain Python ints (truncated like astype(int)) are accepted by every
    # OpenCV drawing call.
    origin, x_tip, y_tip, z_tip = [
        (int(px), int(py)) for px, py in projected
    ]

    # Draw on a private, contiguous copy so the caller's array stays clean.
    canvas = np.array(img, order='C', copy=True)

    cv.circle(canvas, origin, 9, (0, 0, 0), -1)
    cv.arrowedLine(canvas, origin, x_tip, (255, 0, 0), 5)
    cv.arrowedLine(canvas, origin, y_tip, (0, 255, 0), 5)

    # The z arrow is drawn mirrored through the origin in image space
    # (origin + (origin - z_tip)), i.e. pointing opposite to the projected +z tip.
    z_mirrored = (2 * origin[0] - z_tip[0], 2 * origin[1] - z_tip[1])
    cv.arrowedLine(canvas, origin, z_mirrored, (0, 0, 255), 5)

    return Image.fromarray(canvas)

def get_camera_pose_plot(lam, img, rot, tvec):
    """Render a 3-D plot of the camera pose and return it as an image.

    Parameters
    ----------
    lam : array_like, shape (3, 3)
        Camera intrinsic matrix, passed to ``plot_camera`` as ``M``.
    img : numpy.ndarray
        Image the pose was estimated from; only its height and width are used.
    rot : numpy.ndarray, shape (3, 3)
        World-to-camera rotation matrix.
    tvec : array_like, 3 elements
        World-to-camera translation.

    Returns
    -------
    PIL.Image.Image
        The rendered figure, as produced by ``fig2img``.
    """
    # shape[:2] also works for single-channel (2-D) images.
    n_rows, n_cols = img.shape[:2]
    sensor_size = np.array([n_cols, n_rows])

    # Invert the world-to-camera pose [R | t] to get camera-to-world:
    # [R^T | -R^T t].
    rot = np.asarray(rot)
    cam2world = np.eye(4)
    cam2world[:3, :3] = rot.T
    cam2world[:3, 3] = -rot.T @ np.asarray(tvec).reshape(3)

    # plot_transform creates its own 3-D axes when none is given; use that
    # figure directly instead of opening an extra, unused one.
    ax = pt.plot_transform(A2B=cam2world, s=2)
    fig = ax.figure
    try:
        pc.plot_camera(ax, cam2world=cam2world, M=lam, sensor_size=sensor_size,
                       virtual_image_distance=0.8)

        # Set limits and view angle
        ax.set_xlim(-10, 10)
        ax.set_ylim(-10, 10)
        ax.set_zlim(-30, 30)
        ax.view_init(30, 70)
        ax.grid(True)

        image = fig2img(fig)
    finally:
        # Figures are kept alive by pyplot until closed; without this every
        # call would leak one.
        plt.close(fig)

    return image
    

def read_json(file_path):
    """Load camera calibration data from a JSON file.

    Parameters
    ----------
    file_path : str or path-like
        Path of a JSON document containing the keys ``'lambda'`` and
        ``'distortion'``.

    Returns
    -------
    lam : numpy.ndarray
        Contents of ``'lambda'`` as a float64 array.
    dist : numpy.ndarray
        Contents of ``'distortion'`` as a float64 array.

    Raises
    ------
    OSError
        If the file cannot be opened.
    KeyError
        If one of the two keys is missing.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    # Force a floating-point dtype so integer-only JSON values do not yield
    # integer arrays.
    lam = np.array(data['lambda'], dtype=np.float64)
    dist = np.array(data['distortion'], dtype=np.float64)

    return lam, dist


def fig2img(fig):
    """Save ``fig`` into an in-memory buffer and open that buffer as a PIL image.

    ``fig`` only needs a ``savefig`` method accepting a file-like object; the
    figure is written with ``savefig``'s default settings.
    """
    buffer = io.BytesIO()
    fig.savefig(buffer)
    # Rewind so the reader starts at the first byte that was written.
    buffer.seek(0)
    return Image.open(buffer)

    
class Clicks:
    """Helpers for collecting clicked image points and drawing them."""

    @staticmethod
    def capture(points_list):
        """Convert a sequence of ``[x, y]`` points to a float32 array.

        ``None`` or an empty sequence/array yields an empty array of shape
        ``(0, 2)``; the dtype is float32 in every case.
        """
        empty = np.empty((0, 2), dtype=np.float32)
        if points_list is None:
            return empty

        points = np.array(points_list, dtype=np.float32)
        # Checking the array size (rather than the truthiness of the input)
        # also works when the caller passes a numpy array.
        if points.size == 0:
            return empty
        return points

    @staticmethod
    def draw_numbered(image, points, radius=8):
        """Return a copy of ``image`` with each point drawn as a numbered marker.

        Parameters
        ----------
        image : numpy.ndarray or PIL.Image.Image
            Source image; it is never modified.
        points : iterable of (x, y)
            Marker centres, numbered from 1 in iteration order.
        radius : int, optional
            Marker radius in pixels.

        Returns
        -------
        PIL.Image.Image
            The annotated copy.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        img = image.copy()
        draw = ImageDraw.Draw(img)

        # Arial is not installed everywhere (and FreeType support may be
        # missing); fall back to PIL's built-in font in those cases only.
        try:
            font = ImageFont.truetype("arial.ttf", 16)
        except (OSError, ImportError):
            font = ImageFont.load_default()

        for number, (x, y) in enumerate(points, start=1):
            draw.ellipse(
                [x - radius, y - radius, x + radius, y + radius],
                outline=(255, 0, 0), width=2, fill=(255, 255, 255),
            )
            draw.text((x - 5, y - 8), str(number), fill=(0, 0, 0), font=font)
        return img

In [194]:
def create_pose_estimation_demo():
    """Build the Gradio UI for interactive camera pose estimation.

    The user uploads an image, clicks the reference points in order, and
    presses "Process Points" to run both the homography-based and the
    ``solvePnP``-based pose estimators.

    Returns
    -------
    gr.Blocks
        The assembled (not yet launched) demo.

    Notes
    -----
    Click state lives in a closure-level dict, so it is shared by everyone
    connected to the same running demo.
    """
    # 'image' is the clean, unannotated picture; 'points' are the clicks so far.
    state = {'points': [], 'image': None}

    def on_upload(img):
        # A new picture invalidates earlier clicks and the stored original;
        # without this reset the previous image kept being redrawn.
        state['points'].clear()
        state['image'] = None if img is None else img.copy()
        return "Image loaded", "[]"

    def on_image_removed():
        state['points'].clear()
        state['image'] = None
        return "No image", "[]"

    def on_click(evt: gr.SelectData, img):
        if img is None:
            return img, "No image", ""

        x, y = evt.index
        state['points'].append([int(x), int(y)])

        # Fallback for the case where no upload event stored the original.
        if state['image'] is None:
            state['image'] = img.copy()

        # Always draw on the clean original so markers never accumulate.
        updated = Clicks.draw_numbered(state['image'], state['points'])
        return np.array(updated), f"Selected {len(state['points'])} points", json.dumps(state['points'])

    def undo_last(img):
        if state['points']:
            state['points'].pop()  # Remove the last point
        # state['image'] is still None when nothing has been clicked yet.
        if img is not None and state['image'] is not None:
            updated = Clicks.draw_numbered(state['image'], state['points'])
            return np.array(updated), f"Selected {len(state['points'])} points", json.dumps(state['points'])
        return img, "No points selected", "[]"

    def clear_points(img):
        state['points'].clear()  # Clear all points
        if img is not None and state['image'] is not None:
            updated = Clicks.draw_numbered(state['image'], [])
            return np.array(updated), "Cleared", "[]"
        return img, "No points selected", "[]"

    def estimate_poses():
        img = state['image']

        # Reference rectangle corners, one point per column (3 x 4).
        REF_POINTS = np.array([[0, 0, 0], [5, 0, 0], [5, 8, 0], [0, 8, 0]], dtype=np.float32).T
        n_required = REF_POINTS.shape[1]

        if img is None:
            raise gr.Error("Upload an image and select points first")
        if len(state['points']) != n_required:
            raise gr.Error(
                f"Select exactly {n_required} points "
                f"(currently {len(state['points'])})"
            )

        img_points = np.array(state['points'], dtype=np.float32)

        lam, dist = read_json('calibration.json')

        rh, th = get_pose_hom(lam, dist, img_points, REF_POINTS)
        rc, tc = get_pose_cv(lam, dist, img_points, REF_POINTS)

        # Hand over a copy so drawing the axes can never touch the stored
        # original used for redrawing markers.
        axes = img_with_axis(img.copy(), lam, rc, tc, dist)

        pose = get_camera_pose_plot(lam, img, rc, tc)

        tc = tc.T
        homo_str = "R: " + np.array_str(rh) + "\n\nT: " + np.array_str(th)
        cv_str = "R: " + np.array_str(rc) + "\n\nT: " + np.array_str(tc)

        return homo_str, cv_str, axes, pose

    with gr.Blocks(title="Pose Estimation Demo") as demo:
        gr.Markdown("# Pose Estimation from 2D-3D Correspondences")

        with gr.Row():
            with gr.Column(scale=2):
                img = gr.Image(label="Upload Image (click to add points)", type="numpy", interactive=True)
                with gr.Row():
                    undo_btn = gr.Button("Undo", size="sm")
                    clear_btn = gr.Button("Clear", size="sm")
                status = gr.Textbox(label="Status", interactive=False)
                points_json = gr.Textbox(label="Selected Points", value="[]", interactive=False)

        with gr.Row():
            estimate_btn = gr.Button("Process Points", variant="primary", size="lg")

        with gr.Row():
            with gr.Column():
                homo_result = gr.Textbox(label="Homography to Pose", lines=6, interactive=False)
                cv_result = gr.Textbox(label="OpenCV solvePnP", lines=6, interactive=False)
            with gr.Column():
                axes_img = gr.Image(label="Coordinate Axes Overlay", interactive=False)
                plot_3d = gr.Image(label="3D Camera Pose", interactive=False)

        # Event handlers
        img.upload(on_upload, inputs=[img], outputs=[status, points_json])
        img.clear(on_image_removed, outputs=[status, points_json])
        img.select(on_click, inputs=[img], outputs=[img, status, points_json])
        undo_btn.click(undo_last, inputs=[img], outputs=[img, status, points_json])
        clear_btn.click(clear_points, inputs=[img], outputs=[img, status, points_json])
        estimate_btn.click(estimate_poses, outputs=[homo_result, cv_result, axes_img, plot_3d])

    return demo

In [195]:
# Launch the demo
# Build the Blocks UI first; nothing is served until launch() is called.
demo = create_pose_estimation_demo()
# share=True requests a public share link in addition to the local URL
# (both URLs appear in this cell's output).
demo.launch(debug=False, share=True)

* Running on local URL:  http://127.0.0.1:7930
* Running on public URL: https://85203c927f3d62752b.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


