# In [17]:  (Jupyter cell marker, kept as a comment so the file is valid Python)
import os
import re

import numpy as np
import open3d as o3d

def load_obj_points(file_path):
    """Read the vertex positions from a Wavefront OBJ file.

    Only geometric-vertex records (``v``) are read; normals (``vn``),
    texture coordinates (``vt``), faces and comments are skipped. The first
    three numeric fields (x, y, z) of each vertex are kept. Any optional
    trailing components, such as the homogeneous ``w`` or the per-vertex RGB
    colours some exporters append, are ignored, so every row has length 3.

    Args:
        file_path: Path to the ``.obj`` file.

    Returns:
        ``np.ndarray`` of dtype float64 and shape ``(N, 3)``. A file with no
        vertices yields an empty ``(0, 3)`` array rather than ``(0,)``.

    Raises:
        ValueError: If a ``v`` record has fewer than three coordinates.
    """
    points = []
    # Vertex lines are plain ASCII; replacing undecodable bytes only affects
    # comment/material lines, which are skipped anyway.
    with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
        for line_no, line in enumerate(f, start=1):
            fields = line.split()
            if not fields or fields[0] != 'v':  # vertex data only
                continue
            coords = fields[1:4]  # x, y, z; drop optional w / colour fields
            if len(coords) < 3:
                raise ValueError(
                    f"{file_path}:{line_no}: vertex has fewer than 3 coordinates"
                )
            points.append([float(x) for x in coords])
    # reshape keeps the (N, 3) contract even when N == 0.
    return np.asarray(points, dtype=np.float64).reshape(-1, 3)

def create_point_cloud(points, color=(0.5, 0.5, 0.5)):
    """Build an Open3D point cloud painted in a single colour.

    Args:
        points: Array-like of shape ``(N, 3)`` with xyz coordinates.
        color: RGB triple with components in [0, 1]. The default is mid grey.
            It is an immutable tuple so the default cannot be shared and
            mutated across calls.

    Returns:
        ``o3d.geometry.PointCloud`` with every point set to *color*.
    """
    pcd = o3d.geometry.PointCloud()
    # Open3D's Eigen-backed bindings expect float64 data; convert explicitly
    # so integer arrays or plain lists/tuples are accepted as well.
    pcd.points = o3d.utility.Vector3dVector(np.asarray(points, dtype=np.float64))
    pcd.paint_uniform_color(np.asarray(color, dtype=np.float64))
    return pcd

def visualize_point_cloud(points, window_name="Point Cloud", point_size=1.0):
    """Show a single point cloud in an interactive Open3D window.

    Blocks until the user closes the window.

    Args:
        points: Array-like of shape ``(N, 3)`` with xyz coordinates.
        window_name: Title of the viewer window.
        point_size: Rendered point size in pixels.
    """
    pcd = create_point_cloud(points)
    viewer = o3d.visualization.Visualizer()
    viewer.create_window(window_name=window_name)
    try:
        viewer.add_geometry(pcd)
        opt = viewer.get_render_option()
        opt.point_size = point_size
        # Open3D colours are floats in [0, 1]; white is (1, 1, 1), not 255s.
        opt.background_color = np.array([1.0, 1.0, 1.0])
        viewer.run()
    finally:
        # Release the window/GL context even if rendering fails.
        viewer.destroy_window()

def visualize_sample_pair(sample_idx=0, batch_idx=0,
                          base_path="log/PointAttN_cd_debug_pcn/all",
                          show_input=False):
    """Overlay one sample's ground truth and intermediate output in one window.

    Files are read from *base_path* and named
    ``batch{batch_idx}_sample{sample_idx}_<suffix>.obj``:

    * ``_gt``           : ground truth, drawn in green
    * ``_output_inter`` : intermediate output, drawn in blue
    * ``_input``        : input, drawn in red, loaded only when
      ``show_input`` is true

    Blocks until the user closes the window.

    Args:
        sample_idx: Sample index within the batch (first positional arg).
        batch_idx: Batch index (second positional arg).
        base_path: Directory containing the exported ``.obj`` files.
        show_input: Also load and draw the ``_input`` cloud.
    """
    prefix = f"batch{batch_idx}_sample{sample_idx}"

    def _load(suffix):
        # One file per suffix, e.g. "batch3_sample2_gt.obj".
        return load_obj_points(os.path.join(base_path, f"{prefix}_{suffix}.obj"))

    geometries = [
        create_point_cloud(_load("gt"), [0, 1, 0]),            # Green: ground truth
        create_point_cloud(_load("output_inter"), [0, 0, 1]),  # Blue: intermediate output
    ]
    if show_input:
        geometries.insert(0, create_point_cloud(_load("input"), [1, 0, 0]))  # Red: input

    viewer = o3d.visualization.Visualizer()
    viewer.create_window(window_name=f"Sample {sample_idx} Visualization")
    try:
        for geometry in geometries:
            viewer.add_geometry(geometry)
        # Open3D colours are floats in [0, 1]; white is (1, 1, 1), not 255s.
        viewer.get_render_option().background_color = np.array([1.0, 1.0, 1.0])
        viewer.run()
    finally:
        viewer.destroy_window()

def visualize_all_samples(base_path="log/PointAttN_cd_debug_pcn"):
    """Visualize every sample that has a ``batch<B>_sample<S>_output.obj`` file.

    Scans *base_path* for files named exactly ``batch<B>_sample<S>_output.obj``
    and calls ``visualize_sample_pair`` for each ``(B, S)`` in ascending
    order. Each window blocks until it is closed. Files that do not match the
    pattern are skipped rather than crashing the parser.

    NOTE(review): this scans *base_path*, while ``visualize_sample_pair``
    reads from its own default directory (``.../all``). Confirm both contain
    the same samples.

    Args:
        base_path: Directory to scan for ``*_output.obj`` files.
    """
    name_re = re.compile(r"batch(\d+)_sample(\d+)_output\.obj")

    indices = set()
    for name in os.listdir(base_path):
        match = name_re.fullmatch(name)
        if match is None:
            continue
        indices.add((int(match.group(1)), int(match.group(2))))

    for batch_idx, sample_idx in sorted(indices):
        print(f"Visualizing batch {batch_idx}, sample {sample_idx}")
        visualize_sample_pair(sample_idx, batch_idx)

# Example usage. Guarded so importing this module does not open a GUI window;
# in a notebook kernel __name__ is "__main__", so the cell still runs.
if __name__ == "__main__":
    # Keyword args: the positional order of visualize_sample_pair is (sample, batch).
    visualize_sample_pair(sample_idx=2, batch_idx=3)  # Visualize a single sample
    # visualize_all_samples()  # Visualize all samples