In [None]:
# installs

In [1]:
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import json
import matplotlib.pyplot as plt
import io

# ============================================================================
# STATELESS HELPER CLASSES - IMPLEMENT THE CORE FUNCTIONS
# ============================================================================

class Calib:
    """Stateless loader for camera calibration data."""

    @staticmethod
    def load_K(json_data):
        """Parse camera intrinsics from JSON text or an already-decoded mapping.

        Args:
            json_data: JSON string, or a mapping, optionally holding the keys
                'K' (3x3 camera matrix) and 'distCoeffs' (distortion vector).

        Returns:
            tuple: ``(K, distCoeffs)`` as float32 arrays. A missing 'K' falls
            back to the 3x3 identity and a missing 'distCoeffs' to five zeros.

        Raises:
            ValueError: if the text is not valid JSON, the payload is not a
                mapping, or 'K' is not a 3x3 matrix.
        """
        data = json.loads(json_data) if isinstance(json_data, str) else json_data
        if not hasattr(data, 'get'):
            raise ValueError(
                f"Intrinsics must be a JSON object with 'K'/'distCoeffs', got {type(data).__name__}"
            )

        K = np.array(data.get('K', np.eye(3)), dtype=np.float32)
        # Reject malformed matrices here instead of letting them fail later
        # inside the pose solvers with a less helpful message.
        if K.shape != (3, 3):
            raise ValueError(f"K must be a 3x3 matrix, got shape {K.shape}")

        dist = np.array(data.get('distCoeffs', [0] * 5), dtype=np.float32)
        return K, dist

class Model:
    """Stateless loaders for model (object-space) points."""

    @staticmethod
    def _as_point_array(data):
        """Decode JSON text if needed and convert the result to a float32 array."""
        raw = json.loads(data) if isinstance(data, str) else data
        return np.array(raw, dtype=np.float32)

    @staticmethod
    def load_points(data):
        """Load planar model points as an (N, 2) float32 array.

        Args:
            data: JSON string or nested sequence of 2D or 3D points.

        Returns:
            np.ndarray: shape (N, 2). When 3-column points are supplied the
            third column is dropped. Empty input yields a (0, 2) array.

        Raises:
            ValueError: if the text is not valid JSON or the points are not
                an N x 2 / N x 3 table.
        """
        points = Model._as_point_array(data)
        if points.size == 0:
            # Previously this path crashed with IndexError on points.shape[1].
            return np.empty((0, 2), dtype=np.float32)
        if points.ndim != 2 or points.shape[1] not in (2, 3):
            raise ValueError(f"Expected an Nx2 or Nx3 list of points, got shape {points.shape}")
        return points[:, :2] if points.shape[1] == 3 else points

    @staticmethod
    def load_points_3d(data):
        """Load 3D model points as an (N, 3) float32 array.

        Args:
            data: JSON string or nested sequence of 3D points.

        Returns:
            np.ndarray: shape (N, 3). Empty input yields a (0, 3) array.

        Raises:
            ValueError: if the text is not valid JSON or the points are not
                an N x 3 table.
        """
        points = Model._as_point_array(data)
        if points.size == 0:
            return np.empty((0, 3), dtype=np.float32)
        if points.ndim != 2 or points.shape[1] != 3:
            raise ValueError(f"Expected an Nx3 list of points, got shape {points.shape}")
        return points

class Clicks:
    """Stateless helpers for user-clicked image points."""

    @staticmethod
    def capture(points_list):
        """Convert clicked points into a float32 array.

        Args:
            points_list: sequence (list or array) of [x, y] pairs, or None.

        Returns:
            np.ndarray: float32 array of the points; a (0, 2) float32 array
            when there are no points, so the dtype is the same in both cases.
        """
        # len() instead of truthiness so NumPy arrays are accepted as input too.
        if points_list is None or len(points_list) == 0:
            return np.empty((0, 2), dtype=np.float32)
        return np.array(points_list, dtype=np.float32)

    @staticmethod
    def draw_numbered(image, points, radius=8):
        """Draw a numbered marker (1, 2, 3, ...) at every point on a copy of the image.

        Args:
            image: PIL image or NumPy array; arrays are converted with Image.fromarray.
            points: iterable of (x, y) pairs.
            radius: marker radius in pixels.

        Returns:
            PIL.Image.Image: annotated copy; the input image is not modified.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        img = image.copy()
        draw = ImageDraw.Draw(img)
        try:
            font = ImageFont.truetype("arial.ttf", 16)
        except Exception:
            # Font file unavailable on this system: fall back to PIL's built-in
            # font. (Not a bare except, so KeyboardInterrupt still propagates.)
            font = ImageFont.load_default()
        for i, (x, y) in enumerate(points):
            draw.ellipse([x-radius, y-radius, x+radius, y+radius], outline=(255, 0, 0), width=2, fill=(255, 255, 255))
            draw.text((x-5, y-8), str(i+1), fill=(0, 0, 0), font=font)
        return img

class Pose:
    """Pose estimators. Both methods are still stubs that ignore their inputs."""

    @staticmethod
    def from_homography(K, model_2D, image_2D):
        """TODO (IMPLEMENT): recover the pose by decomposing a homography.

        Currently a placeholder: returns a 3x3 float32 identity rotation and a
        (3, 1) float32 zero translation regardless of the arguments.
        """
        rotation = np.eye(3, dtype=np.float32)
        translation = np.zeros((3, 1), dtype=np.float32)
        return rotation, translation  # Placeholder

    @staticmethod
    def from_opencv(K, distCoeffs, model_3D, image_2D):
        """TODO (IMPLEMENT): recover the pose with OpenCV solvePnP.

        Currently a placeholder: returns a 3x3 float32 identity rotation and a
        (3, 1) float32 zero translation regardless of the arguments.
        """
        rotation = np.eye(3, dtype=np.float32)
        translation = np.zeros((3, 1), dtype=np.float32)
        return rotation, translation  # Placeholder

class Viz:
    """Stateless visualisation helpers (the overlays are still placeholders)."""

    @staticmethod
    def overlay_axes(image, K, distCoeffs, R, t):
        """TODO (IMPLEMENT): draw the coordinate axes on the image.

        Placeholder: NumPy input is converted to a PIL image, anything else is
        returned unchanged. K, distCoeffs, R and t are not used yet.
        """
        return Image.fromarray(image) if isinstance(image, np.ndarray) else image  # Placeholder

    @staticmethod
    def overlay_reprojection(image, model_3D, image_2D, K, distCoeffs, R, t):
        """TODO (IMPLEMENT): show the reprojection overlay.

        Placeholder: only outlines the clicked points (image_2D) in red on a
        copy of the image; model_3D, K, distCoeffs, R and t are not used yet.

        Returns:
            PIL.Image.Image: annotated copy of the input image.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        img = image.copy()
        draw = ImageDraw.Draw(img)
        for x, y in image_2D:
            draw.ellipse([x-5, y-5, x+5, y+5], outline=(255, 0, 0), width=2)
        return img

    @staticmethod
    def plot_camera_3d(R, t):
        """Render the world origin and the camera centre in a 3D scatter plot.

        Args:
            R: 3x3 rotation matrix.
            t: translation with three elements (any shape reshapeable to (3, 1)).

        Returns:
            PIL.Image.Image: the plot rendered as a PNG image.
        """
        fig = plt.figure(figsize=(6, 5))
        try:
            ax = fig.add_subplot(111, projection='3d')
            ax.scatter([0], [0], [0], color='black', s=100, label='World Origin')
            # Camera centre computed as -R^T t, flattened to three scalars.
            cam_pos = (-np.asarray(R).T @ np.asarray(t).reshape(3, 1)).ravel()
            ax.scatter([cam_pos[0]], [cam_pos[1]], [cam_pos[2]], color='red', s=100, label='Camera')
            ax.set_xlabel('X'); ax.set_ylabel('Y'); ax.set_zlabel('Z'); ax.set_title('Camera Pose'); ax.legend()
            buf = io.BytesIO()
            # Save through the figure object: plt.savefig() writes whichever
            # figure is "current" in pyplot's global state, which need not be
            # this one when several handlers run close together.
            fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        finally:
            # Release the figure even when plotting or saving raises.
            plt.close(fig)
        buf.seek(0)
        img = Image.open(buf)
        img.load()  # Image.open is lazy; decode now so the result is self-contained
        return img




In [2]:
# ============================================================================
# SINGLE-CELL GRADIO DEMO
# ============================================================================

def create_pose_estimation_demo():
    """Build (but do not launch) the Gradio UI for the pose-estimation demo.

    The UI lets the user upload an image, click points on it, load camera
    intrinsics and 2D/3D model points from JSON, and then run the pose
    estimators and the overlay visualisations.

    Returns:
        gr.Blocks: the assembled demo; call ``.launch()`` on it to serve it.
    """
    # State captured by every handler below.
    # NOTE(review): this single dict belongs to the demo object, so all browser
    # sessions connected to it read and write the same points/intrinsics.
    # Consider gr.State if per-session isolation is needed.
    # 'image' always holds the CLEAN image (no markers) so markers can be redrawn.
    state = {'points': [], 'model_2d': None, 'model_3d': None, 'K': None, 'dist': None, 'image': None}

    def _remember_image(img):
        """Store a private copy of a new clean image and drop points from the previous one."""
        state['image'] = np.array(img, copy=True)
        state['points'].clear()

    def _points_summary():
        """Status text and JSON dump for the currently selected points."""
        return f"Selected {len(state['points'])} points", json.dumps(state['points'], indent=2)

    def on_upload(img):
        """A new image was uploaded: remember it as the clean base image."""
        if img is None:
            state['image'] = None
            state['points'].clear()
            return "No image", ""
        _remember_image(img)
        return "Image loaded", "[]"

    def on_image_cleared():
        """The image was removed from the component: forget it and its points."""
        state['image'] = None
        state['points'].clear()
        return "No image", "[]"

    def on_click(evt: gr.SelectData, img):
        """Record the clicked pixel and redraw all markers on the clean image."""
        if img is None:
            return img, "No image", ""
        base = state['image']
        # Fallback for images that arrived without an upload event: treat an
        # unknown or differently sized image as a new clean base image.
        if base is None or base.shape != np.shape(img):
            _remember_image(img)
        x, y = evt.index
        state['points'].append([int(x), int(y)])
        # Draw on the clean base, not on `img`: the displayed image already
        # carries earlier markers, which made Undo/Clear unable to remove them.
        updated = Clicks.draw_numbered(state['image'], state['points'])
        return (np.array(updated), *_points_summary())

    def undo_last(img):
        """Remove the most recent point and redraw the remaining markers."""
        if state['points']:
            state['points'].pop()
        base = state['image'] if state['image'] is not None else img
        if base is not None and state['points']:
            updated = Clicks.draw_numbered(base, state['points'])
            return (np.array(updated), *_points_summary())
        return base, "No points selected", "[]"

    def clear_points(img):
        """Remove every point and show the clean image again."""
        state['points'].clear()
        base = state['image'] if state['image'] is not None else img
        return base, "Cleared", "[]"

    def load_intrinsics(text):
        """Parse the intrinsics JSON into state; return a status message."""
        try:
            state['K'], state['dist'] = Calib.load_K(text)
            return f"Loaded K: {state['K'].shape}, dist: {state['dist'].shape}"
        except Exception as e:
            return f"Error: {e}"

    def load_2d_points(text):
        """Parse the planar model points into state; return a status message."""
        try:
            state['model_2d'] = Model.load_points(text)
            return f"Loaded {len(state['model_2d'])} 2D points"
        except Exception as e:
            return f"Error: {e}"

    def load_3d_points(text):
        """Parse the 3D model points into state; return a status message."""
        try:
            state['model_3d'] = Model.load_points_3d(text)
            return f"Loaded {len(state['model_3d'])} 3D points"
        except Exception as e:
            return f"Error: {e}"

    def estimate_poses():
        """Run both estimators on the clicked points.

        Returns:
            tuple: (homography text, OpenCV text, 3D plot image or None).
        """
        if not state['points'] or state['K'] is None:
            return "No points/intrinsics", "No points/intrinsics", None

        image_pts = Clicks.capture(state['points'])
        homo_result = cv_result = "No model points"
        plot = None

        if state['model_2d'] is not None:
            if len(image_pts) == len(state['model_2d']):
                try:
                    R_h, t_h = Pose.from_homography(state['K'], state['model_2d'], image_pts)
                    homo_result = f"Homography:\nR=\n{R_h}\nt={t_h.flatten()}"
                except Exception as e:
                    homo_result = f"Homography failed: {e}"
            else:
                homo_result = f"Point mismatch: {len(image_pts)} vs {len(state['model_2d'])}"

        if state['model_3d'] is not None:
            if len(image_pts) == len(state['model_3d']):
                try:
                    R_cv, t_cv = Pose.from_opencv(state['K'], state['dist'], state['model_3d'], image_pts)
                    cv_result = f"OpenCV:\nR=\n{R_cv}\nt={t_cv.flatten()}"
                    plot = Viz.plot_camera_3d(R_cv, t_cv)
                except Exception as e:
                    cv_result = f"OpenCV failed: {e}"
            else:
                cv_result = f"Point mismatch: {len(image_pts)} vs {len(state['model_3d'])}"

        return homo_result, cv_result, plot

    def _overlay_inputs_ready():
        """True when image, points, 3D model and intrinsics are all available.

        Uses explicit `is not None` checks: taking the truth value of a NumPy
        array (as `all([...])` did) raises ValueError for multi-element arrays.
        """
        return (
            state['image'] is not None
            and bool(state['points'])
            and state['model_3d'] is not None
            and state['K'] is not None
        )

    def _pose_overlay(render):
        """Solve the pose and hand it to `render(image_pts, R, t)`.

        Falls back to the clean image when inputs are missing or rendering fails.
        """
        if not _overlay_inputs_ready():
            return state['image']
        try:
            image_pts = Clicks.capture(state['points'])
            R, t = Pose.from_opencv(state['K'], state['dist'], state['model_3d'], image_pts)
            return np.array(render(image_pts, R, t))
        except Exception:
            # Best effort: keep showing the plain image if the overlay cannot be drawn.
            return state['image']

    def reprojection_overlay():
        """Image with the reprojection overlay, or the plain image on failure."""
        return _pose_overlay(
            lambda image_pts, R, t: Viz.overlay_reprojection(
                state['image'], state['model_3d'], image_pts, state['K'], state['dist'], R, t
            )
        )

    def axes_overlay():
        """Image with the coordinate-axes overlay, or the plain image on failure."""
        return _pose_overlay(
            lambda image_pts, R, t: Viz.overlay_axes(state['image'], state['K'], state['dist'], R, t)
        )

    with gr.Blocks(title="Pose Estimation Demo") as demo:
        gr.Markdown("# Pose Estimation from 2D-3D Correspondences")

        with gr.Row():
            with gr.Column(scale=2):
                img = gr.Image(label="Upload Image (click to add points)", type="numpy", interactive=True)
                with gr.Row():
                    undo_btn = gr.Button("Undo", size="sm")
                    clear_btn = gr.Button("Clear", size="sm")
                status = gr.Textbox(label="Status", interactive=False)
                points_json = gr.Textbox(label="Selected Points", interactive=False, lines=3)

            with gr.Column(scale=1):
                gr.Markdown("### Camera Intrinsics")
                intrinsics_txt = gr.Textbox(label="JSON format", lines=3,
                    placeholder='{"K": [[800,0,320],[0,800,240],[0,0,1]], "distCoeffs": [0,0,0,0,0]}')
                load_k_btn = gr.Button("Load Intrinsics")
                k_status = gr.Textbox(label="Status", interactive=False)

                gr.Markdown("### Model Points")
                model_2d_txt = gr.Textbox(label="2D Points (Homography)", lines=2,
                    placeholder='[[0,0],[100,0],[100,100],[0,100]]')
                load_2d_btn = gr.Button("Load 2D")
                status_2d = gr.Textbox(label="Status", interactive=False)

                model_3d_txt = gr.Textbox(label="3D Points (OpenCV)", lines=2,
                    placeholder='[[0,0,0],[1,0,0],[1,1,0],[0,1,0]]')
                load_3d_btn = gr.Button("Load 3D")
                status_3d = gr.Textbox(label="Status", interactive=False)

        with gr.Row():
            estimate_btn = gr.Button("Estimate Pose", variant="primary", size="lg")
            overlay_btn = gr.Button("Reprojection Overlay", size="lg")
            axes_btn = gr.Button("Coordinate Axes", size="lg")

        with gr.Row():
            with gr.Column():
                homo_result = gr.Textbox(label="Homography to Pose", lines=6, interactive=False)
                cv_result = gr.Textbox(label="OpenCV solvePnP", lines=6, interactive=False)
            with gr.Column():
                plot_3d = gr.Image(label="3D Camera Pose", interactive=False)

        with gr.Row():
            overlay_img = gr.Image(label="Reprojection Overlay", interactive=False)
            axes_img = gr.Image(label="Coordinate Axes Overlay", interactive=False)

        # Event handlers
        img.upload(on_upload, inputs=[img], outputs=[status, points_json])
        img.clear(on_image_cleared, outputs=[status, points_json])
        img.select(on_click, inputs=[img], outputs=[img, status, points_json])
        undo_btn.click(undo_last, inputs=[img], outputs=[img, status, points_json])
        clear_btn.click(clear_points, inputs=[img], outputs=[img, status, points_json])
        load_k_btn.click(load_intrinsics, inputs=[intrinsics_txt], outputs=[k_status])
        load_2d_btn.click(load_2d_points, inputs=[model_2d_txt], outputs=[status_2d])
        load_3d_btn.click(load_3d_points, inputs=[model_3d_txt], outputs=[status_3d])
        estimate_btn.click(estimate_poses, outputs=[homo_result, cv_result, plot_3d])
        overlay_btn.click(reprojection_overlay, outputs=[overlay_img])
        axes_btn.click(axes_overlay, outputs=[axes_img])

    return demo

In [4]:
# Launch the demo
# Build the Blocks app first, then start serving it.
demo = create_pose_estimation_demo()
# share=True additionally requests a temporary public *.gradio.live URL (see the
# "Running on public URL" line in the output below), so the app is reachable
# from outside this machine while that link is valid.
demo.launch(debug=False, share=True)

* Running on local URL:  http://127.0.0.1:7861
* Running on public URL: https://2ac7fc07f64aeab73c.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


