In [None]:
import torch
import torch.nn as nn
import numpy as np
import open3d as o3d
from autoencoder import PointCloudAE  # Import your model definition
import os
from sklearn.metrics.pairwise import cosine_similarity

def shape_encoder(train_data_path, ply_query_path, model_path, latent_size, point_size, train_ply_text, num_query_points=10000):
    """Find the training shape whose latent code is most similar to a query mesh.

    The autoencoder checkpoint is loaded, every training point cloud and the
    query (sampled from its mesh) are passed through ``net.encoder``, and the
    training shape with the highest cosine similarity to the query is returned.

    Note: the checkpoint and the whole training set are loaded and re-encoded
    on every call.

    Args:
        train_data_path: Path to a ``.npy`` file of training point clouds. The
            array is permuted with ``(0, 2, 1)``, so it must be 3-D
            (presumably ``(N, num_points, 3)`` -- confirm against the data).
        ply_query_path: Path to the query mesh read with Open3D.
        model_path: Path to a checkpoint containing a ``'model_state_dict'`` entry.
        latent_size: Latent size forwarded to ``PointCloudAE``.
        point_size: Point size forwarded to ``PointCloudAE``.
        train_ply_text: Text file with one path per line; line ``i`` is taken
            to describe training shape ``i``.
        num_query_points: Number of points sampled uniformly from the query
            mesh surface. Defaults to 10000, the value previously hard-coded.

    Returns:
        A ``(ply_query_path, best_match_path)`` tuple, where
        ``best_match_path`` is the line of ``train_ply_text`` at the index of
        the most similar training shape.

    Raises:
        ValueError: If the training array contains no point clouds.
        IndexError: If ``train_ply_text`` has no line for the best-matching
            training index.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    net = PointCloudAE(point_size, latent_size)
    checkpoint = torch.load(model_path, map_location=device)
    net.load_state_dict(checkpoint['model_state_dict'])

    net.eval()
    net = net.to(device)

    def sample_point_cloud(ply_path, num_points=num_query_points):
        """Sample ``num_points`` points uniformly from the mesh stored at ``ply_path``."""
        mesh = o3d.io.read_triangle_mesh(ply_path)
        mesh.compute_vertex_normals()
        pcd = mesh.sample_points_uniformly(number_of_points=num_points)
        return np.asarray(pcd.points)

    def encode(point_clouds):
        """Encode a 3-D batch of point clouds into a NumPy array of latent codes."""
        # The last two axes are swapped before encoding, exactly as the
        # original code did for both the training and the query batch.
        batch = torch.tensor(point_clouds, dtype=torch.float32).permute(0, 2, 1).to(device)
        with torch.no_grad():
            return net.encoder(batch).cpu().numpy()

    train_pcs = np.load(train_data_path)
    if len(train_pcs) == 0:
        raise ValueError(f"No training point clouds found in {train_data_path}")
    train_latents = encode(train_pcs)

    query_pc = sample_point_cloud(ply_query_path)
    # Add a leading batch axis so the query goes through the same path as the training set.
    query_latent = encode(query_pc[np.newaxis, ...])

    similarities = cosine_similarity(query_latent, train_latents)[0]
    # Keep the argsort-based ranking (rather than argmax) so that ties are
    # resolved exactly as before.
    best_idx = int(similarities.argsort()[::-1][0])

    with open(train_ply_text, "r") as f:
        all_ply_files = [line.strip() for line in f]

    if best_idx >= len(all_ply_files):
        raise IndexError(
            f"Best match is training shape {best_idx}, but {train_ply_text} "
            f"lists only {len(all_ply_files)} path(s)"
        )

    return ply_query_path, all_ply_files[best_idx]
    

In [None]:
def run_shape_encoder_on_folder(deformed_root, train_data_path, model_path, latent_size, point_size, train_ply_text, output_txt_path):
    """Run ``shape_encoder`` on every ``.ply`` file under a folder and save the matches.

    The folder is walked recursively in sorted order so the output file is
    reproducible across runs and filesystems. Each successful query adds one
    ``"<query path> <matched path>"`` line to the output file. Queries that
    raise are reported on stdout and skipped, so one bad mesh does not abort
    the whole run.

    Args:
        deformed_root: Directory searched recursively for ``.ply`` files
            (the extension is matched case-insensitively).
        train_data_path: Forwarded to ``shape_encoder``.
        model_path: Forwarded to ``shape_encoder``.
        latent_size: Forwarded to ``shape_encoder``.
        point_size: Forwarded to ``shape_encoder``.
        train_ply_text: Forwarded to ``shape_encoder``.
        output_txt_path: Text file to (over)write with one result per line.
            Its parent directory is created if it does not exist.

    Returns:
        None. The results are written to ``output_txt_path``.
    """
    # Create the destination folder before doing any work, so a run over many
    # shapes is not lost to a missing directory when the file is finally written.
    output_dir = os.path.dirname(output_txt_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    result_lines = []

    for root, dirs, files in os.walk(deformed_root):
        # Sorting ``dirs`` in place fixes the order in which os.walk descends.
        dirs.sort()
        for file_name in sorted(files):
            if not file_name.lower().endswith('.ply'):
                continue
            query_path = os.path.join(root, file_name)
            try:
                query, match = shape_encoder(
                    train_data_path, query_path,
                    model_path, latent_size,
                    point_size, train_ply_text
                )
            except Exception as e:
                # Deliberate best-effort: report the failure and keep going.
                print(f"Failed to process {query_path}: {e}")
            else:
                result_lines.append(f"{query} {match}")

    with open(output_txt_path, "w") as f:
        f.writelines(f"{line}\n" for line in result_lines)

# Example usage: look up the most similar train/val femur for every test
# osteophytic femur and write the query/match pairs to a text file.
run_shape_encoder_on_folder(
    # Query side: directory containing the test osteophytic femurs.
    deformed_root="split_data_new/test/deformed",
    # Reference side: .npy file of training point clouds, plus the txt file
    # listing the corresponding osteophytic femur paths.
    train_data_path="femur_train_val.npy",
    train_ply_text="train_val_ply_files.txt",
    # Model: checkpoint of the HFPR_autoencoder and the sizes it is built with.
    model_path="output1/checkpoint_epoch_1500.pt",
    point_size=5000,
    latent_size=128,
    # Result: txt file pairing each test osteophytic femur with its similar mesh.
    output_txt_path="output_results_train_val.txt",
)
