In [2]:
import random
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import trimesh
from datetime import datetime 

from pointnet2_utils import PointNetSetAbstractionMsg, PointNetSetAbstraction
from pointnet2_keypoint_regressor import get_model#, get_model_msg

from torch.utils.data import random_split

import matplotlib.pyplot as plt 
from mpl_toolkits.mplot3d import Axes3D

In [3]:
class KeypointPredictor:
    """Run a trained PointNet++ keypoint regressor on mesh files.

    Per mesh: sample a fixed-size point cloud from the surface, centre it on its
    centroid and scale it into the unit sphere, run the network, then map the
    predicted keypoints back into the mesh's original coordinate frame.
    """

    def __init__(self, model_path, device='cuda', model_fn=None):
        """
        Initialize the keypoint predictor

        Args:
            model_path: path to saved model (.pth file). Either a dict holding
                'model_state_dict' (and optionally 'config'), or a bare state_dict.
            device: 'cuda' or 'cpu'; falls back to 'cpu' when CUDA is unavailable.
            model_fn: optional factory used to build the network, called as
                model_fn(num_keypoints=..., normal_channel=False). Defaults to
                get_model. Supply a factory whose layer widths match the
                checkpoint when the default definition has diverged from it.

        Raises:
            RuntimeError: if the checkpoint's weights do not fit the network
                built by model_fn (e.g. mismatched layer widths).
        """
        # Seed the global RNGs so repeated runs give the same results
        # (assumes downstream sampling draws from these global generators).
        random.seed(0)
        np.random.seed(0)
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

        # torch.load unpickles the file: only load checkpoints from trusted sources.
        checkpoint = torch.load(model_path, map_location=self.device)

        # Fill in any keys the saved config lacks so later lookups cannot KeyError.
        default_config = {'num_keypoints': 9, 'num_points': 1024}
        self.config = {**default_config, **(checkpoint.get('config') or {})}
        # Accept both wrapped checkpoints and a bare state_dict.
        state_dict = checkpoint.get('model_state_dict', checkpoint)

        build_model = get_model if model_fn is None else model_fn
        self.model = build_model(
            num_keypoints=self.config['num_keypoints'],
            normal_channel=False
        ).to(self.device)

        try:
            self.model.load_state_dict(state_dict)
        except RuntimeError as err:
            # Usually "size mismatch": the network definition's layer widths
            # differ from the ones the checkpoint was trained with.
            raise RuntimeError(
                f"Checkpoint {model_path!r} does not fit the network built by "
                f"{getattr(build_model, '__name__', build_model)!r}. Make sure the model "
                f"definition (layer widths / MLP sizes) matches the one used for "
                f"training, or pass a matching `model_fn`.\n{err}"
            ) from err
        self.model.eval()

        print(f"Loaded model with {self.config['num_keypoints']} keypoints")
        print(f"Using device: {self.device}")

    def load_and_sample_mesh(self, mesh_path, num_points=None):
        """
        Load mesh and sample points from surface

        Args:
            mesh_path: path to mesh file (.ply, .obj, .stl, ...)
            num_points: number of points to sample (default: from config)

        Returns:
            points: float32 numpy array of shape (num_points, 3), or None if the
                mesh could not be loaded or sampled (the error is printed).
        """
        if num_points is None:
            num_points = self.config['num_points']

        try:
            mesh = trimesh.load(mesh_path, force='mesh')
            if mesh.is_empty or len(mesh.faces) == 0:
                raise ValueError("Empty mesh")

            points, _ = trimesh.sample.sample_surface(mesh, num_points)
            # An empty sample would make the centroid NaN during normalization.
            if points.shape[0] == 0:
                raise ValueError("No points could be sampled from mesh")

            # Defensive: if fewer than num_points come back, pad by repeating
            # the first point so the network always sees a fixed-size cloud.
            if points.shape[0] < num_points:
                pad_size = num_points - points.shape[0]
                pad = np.repeat(points[0:1, :], pad_size, axis=0)
                points = np.vstack((points, pad))

            return points.astype(np.float32)

        except Exception as e:
            # Best-effort: report and let the caller skip this mesh.
            print(f"Error loading mesh {mesh_path}: {e}")
            return None

    def normalize_points(self, points):
        """
        Centre points on their centroid and scale them into the unit sphere
        (intended to match the preprocessing used in training).

        Args:
            points: numpy array of shape (num_points, 3)

        Returns:
            normalized_points: numpy array of shape (num_points, 3)
            centroid: numpy array of shape (3,) - for denormalization
            scale: float - max distance from the centroid, for denormalization
        """
        centroid = np.mean(points, axis=0)
        scale = np.max(np.linalg.norm(points - centroid, axis=1))

        # All points coincide: avoid division by zero.
        if scale == 0:
            scale = 1.0

        normalized_points = (points - centroid) / scale
        return normalized_points, centroid, scale

    def denormalize_keypoints(self, keypoints, centroid, scale):
        """
        Convert normalized keypoints back to the original coordinate system
        (inverse of normalize_points).

        Args:
            keypoints: numpy array of shape (num_keypoints, 3)
            centroid: numpy array of shape (3,)
            scale: float

        Returns:
            denormalized_keypoints: numpy array of shape (num_keypoints, 3)
        """
        return keypoints * scale + centroid

    def predict_keypoints(self, mesh_path, return_normalized=False):
        """
        Predict keypoints for a single mesh

        Args:
            mesh_path: path to mesh file
            return_normalized: if True, return keypoints (and points) in
                normalized coordinates

        Returns:
            keypoints: numpy array of shape (num_keypoints, 3)
            points: numpy array of shape (num_points, 3) - sampled points, in the
                same frame as the keypoints
            metadata: dict with normalization info
            All three are None if the mesh could not be loaded.
        """
        points = self.load_and_sample_mesh(mesh_path)
        if points is None:
            return None, None, None

        normalized_points, centroid, scale = self.normalize_points(points)

        # (num_points, 3) -> (1, 3, num_points): batch of one, channels first.
        points_tensor = torch.from_numpy(normalized_points).float()
        points_tensor = points_tensor.unsqueeze(0).permute(0, 2, 1).to(self.device)

        with torch.no_grad():
            predicted_keypoints, _ = self.model(points_tensor)

        # reshape(-1, 3) yields (num_keypoints, 3) whether the head emits
        # (1, K, 3) or a flat (1, K*3) vector.
        predicted_keypoints = predicted_keypoints.squeeze(0).cpu().numpy().reshape(-1, 3)

        if not return_normalized:
            predicted_keypoints = self.denormalize_keypoints(predicted_keypoints, centroid, scale)
            points_for_vis = points  # Original points
        else:
            points_for_vis = normalized_points

        metadata = {
            'centroid': centroid,
            'scale': scale,
            'mesh_path': mesh_path
        }

        return predicted_keypoints, points_for_vis, metadata

    def predict_batch(self, mesh_paths, return_normalized=False):
        """
        Predict keypoints for multiple meshes, one at a time

        Args:
            mesh_paths: list of paths to mesh files
            return_normalized: if True, return keypoints in normalized coordinates

        Returns:
            results: list of (keypoints, points, metadata) tuples, in input order;
                failed meshes yield (None, None, None)
        """
        return [self.predict_keypoints(path, return_normalized) for path in mesh_paths]



def save_keypoints_to_file(keypoints, output_path, mesh_path=None):
    """
    Write keypoints to a plain-text file, one "x y z" row per keypoint.

    A '#'-prefixed header records the source mesh (only when mesh_path is
    truthy), the column format, and the number of keypoints. Coordinates are
    written with six decimal places. A confirmation line is printed afterwards.

    Args:
        keypoints: numpy array of shape (num_keypoints, 3)
        output_path: destination path (overwritten if it already exists)
        mesh_path: original mesh path (for reference in the header)
    """
    with open(output_path, 'w') as out:
        if mesh_path:
            out.write(f"# Keypoints for mesh: {mesh_path}\n")
        out.write("# Format: x y z\n")
        out.write(f"# Number of keypoints: {len(keypoints)}\n")
        out.writelines(
            f"{row[0]:.6f} {row[1]:.6f} {row[2]:.6f}\n" for row in keypoints
        )

    print(f"Keypoints saved to: {output_path}")
def visualize_keypoints_3d(points, keypoints, title="Predicted Keypoints", figsize=(12, 8),
                           elev=-90, azim=90):
    """
    Visualize point cloud with predicted keypoints

    Args:
        points: numpy array of shape (num_points, 3)
        keypoints: numpy array of shape (num_keypoints, 3)
        title: string
        figsize: tuple for figure size
        elev: camera elevation passed to ax.view_init (default reproduces the
            original fixed view)
        azim: camera azimuth passed to ax.view_init
    """
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection='3d')

    # Point cloud: small, translucent so keypoints stay visible.
    ax.scatter(points[:, 0], points[:, 1], points[:, 2],
               c='lightblue', alpha=0.6, s=1, label='Point Cloud')

    ax.scatter(keypoints[:, 0], keypoints[:, 1], keypoints[:, 2],
               c='red', s=100, label='Predicted Keypoints', marker='o')

    # Label each keypoint with its index.
    for i, kp in enumerate(keypoints):
        ax.text(kp[0], kp[1], kp[2], f'  {i}', fontsize=10)

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title(title)
    ax.legend()

    # Equal-aspect cube around points AND keypoints, so predictions landing
    # outside the sampled cloud's extent are not clipped from view.
    all_xyz = np.vstack((points, keypoints))
    lo = all_xyz.min(axis=0)
    hi = all_xyz.max(axis=0)
    mid = (lo + hi) * 0.5
    half_range = (hi - lo).max() / 2.0
    if half_range == 0:
        # Degenerate (single-location) data: avoid identical axis limits.
        half_range = 1.0
    ax.set_xlim(mid[0] - half_range, mid[0] + half_range)
    ax.set_ylim(mid[1] - half_range, mid[1] + half_range)
    ax.set_zlim(mid[2] - half_range, mid[2] + half_range)

    ax.view_init(elev=elev, azim=azim)

    plt.show()


def predict_points(model_path, input_ply, save=False, predictor=None):
    """
    Predict, print, plot and optionally save keypoints for one mesh.

    Args:
        model_path: checkpoint used to build a KeypointPredictor (ignored when
            `predictor` is given).
        input_ply: path to the mesh; any format the predictor can load
            (e.g. .ply or .stl), despite the parameter name.
        save: if True, write keypoints next to the mesh as
            '<mesh path without extension>_keypoints.txt'.
        predictor: optional already-constructed KeypointPredictor to reuse, so
            the checkpoint is not reloaded on every call.
    """
    if predictor is None:
        predictor = KeypointPredictor(model_path)

    mesh_path = input_ply
    keypoints, points, _ = predictor.predict_keypoints(mesh_path)

    if keypoints is None:
        print(f"Failed to process {mesh_path}")
        return

    print(f"Predicted {len(keypoints)} keypoints for {mesh_path}")
    print("Keypoints coordinates:")
    for kp in keypoints:
        print(f"{kp[0]:.3f} {kp[1]:.3f} {kp[2]:.3f}")

    visualize_keypoints_3d(points, keypoints, f"Keypoints for {os.path.basename(mesh_path)}")

    if save:
        # splitext works for any extension. The previous str.replace('.ply', ...)
        # left e.g. '.stl' paths unchanged, so the output path equalled the
        # input path and the source mesh would have been overwritten.
        output_path = os.path.splitext(mesh_path)[0] + '_keypoints.txt'
        save_keypoints_to_file(keypoints, output_path, mesh_path)

    

# NOTE(review): with the current `get_model` definition this call fails with the
# size-mismatch RuntimeError shown in the output below. The checkpoint's
# set-abstraction layers are narrower than the network `get_model` builds
# (e.g. sa3 last conv 1024 vs 2048 channels, fc1 input 1024 vs 2048), so the
# model definition imported from pointnet2_keypoint_regressor must be restored to
# the layer widths this checkpoint was trained with before this cell can run.
predict_points("saved_models/kneenet++_4_5_final_2.pth", "scans_2/12252.stl")

RuntimeError: Error(s) in loading state_dict for get_model:
	size mismatch for sa1.mlp_convs.1.weight: copying a param with shape torch.Size([64, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 64, 1, 1]).
	size mismatch for sa1.mlp_convs.1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for sa1.mlp_convs.2.weight: copying a param with shape torch.Size([128, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 128, 1, 1]).
	size mismatch for sa1.mlp_convs.2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for sa1.mlp_bns.1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for sa1.mlp_bns.1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for sa1.mlp_bns.1.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for sa1.mlp_bns.1.running_var: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for sa1.mlp_bns.2.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for sa1.mlp_bns.2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for sa1.mlp_bns.2.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for sa1.mlp_bns.2.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for sa2.mlp_convs.0.weight: copying a param with shape torch.Size([128, 131, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 259, 1, 1]).
	size mismatch for sa2.mlp_convs.0.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for sa2.mlp_convs.1.weight: copying a param with shape torch.Size([128, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 256, 1, 1]).
	size mismatch for sa2.mlp_convs.1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for sa2.mlp_convs.2.weight: copying a param with shape torch.Size([256, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1024, 512, 1, 1]).
	size mismatch for sa2.mlp_convs.2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for sa2.mlp_bns.0.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for sa2.mlp_bns.0.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for sa2.mlp_bns.0.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for sa2.mlp_bns.0.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for sa2.mlp_bns.1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for sa2.mlp_bns.1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for sa2.mlp_bns.1.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for sa2.mlp_bns.1.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for sa2.mlp_bns.2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for sa2.mlp_bns.2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for sa2.mlp_bns.2.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for sa2.mlp_bns.2.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for sa3.mlp_convs.0.weight: copying a param with shape torch.Size([256, 259, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 1027, 1, 1]).
	size mismatch for sa3.mlp_convs.0.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for sa3.mlp_convs.1.weight: copying a param with shape torch.Size([512, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([1024, 512, 1, 1]).
	size mismatch for sa3.mlp_convs.1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for sa3.mlp_convs.2.weight: copying a param with shape torch.Size([1024, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([2048, 1024, 1, 1]).
	size mismatch for sa3.mlp_convs.2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([2048]).
	size mismatch for sa3.mlp_bns.0.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for sa3.mlp_bns.0.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for sa3.mlp_bns.0.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for sa3.mlp_bns.0.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for sa3.mlp_bns.1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for sa3.mlp_bns.1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for sa3.mlp_bns.1.running_mean: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for sa3.mlp_bns.1.running_var: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for sa3.mlp_bns.2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([2048]).
	size mismatch for sa3.mlp_bns.2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([2048]).
	size mismatch for sa3.mlp_bns.2.running_mean: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([2048]).
	size mismatch for sa3.mlp_bns.2.running_var: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([2048]).
	size mismatch for fc1.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([512, 2048]).