In [None]:
"""
MVF-Net Improvement Strategies WITHOUT Pretrained Models
=========================================================

These are post-processing improvements that enhance MVF-Net output
without requiring additional trained models.

Time estimate: 1-2 days to implement
Results: Visibly improved quality with measurable metrics
"""

import numpy as np
import trimesh
import torch
import cv2, os
from scipy.ndimage import gaussian_filter
from scipy.spatial import cKDTree
from scipy.sparse import csr_matrix
import open3d as o3d
from PIL import Image
from preprocessing import crop_image, FaceDetector
from reconstruction import ShapeReconstructor, write_ply
import torchvision.transforms as transforms
from models.vgg_encoder import VggEncoder

In [13]:
# ============================================================================
# IMPROVEMENT 1: Laplacian Mesh Smoothing with Detail Preservation
# ============================================================================

def laplacian_smoothing_adaptive(vertices, faces, iterations=5, lambda_smooth=0.5):
    """
    Smooth the mesh while preserving high-frequency details
    This reduces MVF-Net's typical over-smoothing artifacts
    
    Args:
        vertices: (N, 3) vertex positions
        faces: (M, 3) face indices
        iterations: number of smoothing iterations
        lambda_smooth: smoothing strength (0=no smooth, 1=max smooth)
    
    Returns:
        smoothed_vertices: (N, 3) improved vertices
    """
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
    
    # Compute vertex normals to detect high-curvature regions
    normals = mesh.vertex_normals
    neighbors = mesh.vertex_neighbors
    
    # Detect high-curvature areas (preserve these)
    curvature = np.zeros(len(vertices))
    for i, nbrs in enumerate(neighbors):
        if len(nbrs) > 0:
            # Normal variation indicates curvature
            normal_diff = np.linalg.norm(normals[nbrs] - normals[i], axis=1)
            curvature[i] = np.mean(normal_diff)
    
    # Normalize curvature to [0, 1]
    curvature = curvature / (curvature.max() + 1e-8)
    
    # Adaptive smoothing: smooth less in high-curvature areas
    smoothed = vertices.copy()
    for iteration in range(iterations):
        new_verts = smoothed.copy()
        for i, nbrs in enumerate(neighbors):
            if len(nbrs) > 0:
                # Weight by inverse curvature (smooth flat areas more)
                weight = lambda_smooth * (1.0 - curvature[i])
                neighbor_mean = np.mean(smoothed[nbrs], axis=0)
                new_verts[i] = (1 - weight) * smoothed[i] + weight * neighbor_mean
        smoothed = new_verts
    
    return smoothed

In [14]:
# ============================================================================
# IMPROVEMENT 2: Multi-View Consistency Enhancement
# ============================================================================

def enforce_multiview_consistency(vertices_list, faces, confidence_weights=None):
    """
    Given multiple MVF-Net outputs from slightly different views,
    combine them with consistency constraints
    
    Args:
        vertices_list: List of (N, 3) vertex arrays from different runs
        faces: (M, 3) face indices (same topology)
        confidence_weights: Optional weights for each view
    
    Returns:
        consistent_vertices: (N, 3) improved vertices
    """
    if confidence_weights is None:
        confidence_weights = np.ones(len(vertices_list)) / len(vertices_list)
    
    # Weighted average
    consistent = np.zeros_like(vertices_list[0])
    for verts, weight in zip(vertices_list, confidence_weights):
        consistent += weight * verts
    
    # Enforce mesh quality constraints
    mesh = trimesh.Trimesh(vertices=consistent, faces=faces)
    
    # Remove self-intersections (common MVF-Net artifact)
    if mesh.is_watertight and mesh.body_count == 1:
        # Resolve self-intersections
        consistent = remove_self_intersections(consistent, faces)
    
    return consistent


import numpy as np

def closest_point_on_triangles(P, Tris):
    """
    Compute closest points from points P (N,3)
    to triangles Tris (M,3,3).
    Returns:
        CP: (N,M,3) closest points
        D2: (N,M) squared distances
    """

    A = Tris[:, 0]  # (M,3)
    B = Tris[:, 1]
    C = Tris[:, 2]

    AB = B - A      # (M,3)
    AC = C - A
    AP = P[:, None, :] - A[None, :, :]  # (N, M, 3)

    # Dot products
    d1 = np.sum(AB * AP, axis=2)
    d2 = np.sum(AC * AP, axis=2)

    # Edges
    ABAB = np.sum(AB * AB, axis=1)
    ACAC = np.sum(AC * AC, axis=1)
    ABAC = np.sum(AB * AC, axis=1)

    denom = ABAB * ACAC - ABAC * ABAC
    denom = denom + 1e-12

    v = (ACAC * d1 - ABAC * d2) / denom
    w = (ABAB * d2 - ABAC * d1) / denom

    # Clamp barycentric coords to triangle
    v_clamped = np.clip(v, 0, 1)
    w_clamped = np.clip(w, 0, 1 - v_clamped)

    CP = A[None,:,:] + v_clamped[:,:,None] * AB[None,:,:] + w_clamped[:,:,None] * AC[None,:,:]
    D2 = np.sum((CP - P[:,None,:])**2, axis=2)

    return CP, D2


def remove_self_intersections(vertices, faces, iterations=3, threshold=0.005, chunk=300):
    """
    NumPy-only implementation matching trimesh.proximity behavior.
    No PyTorch3D / Open3D / trimesh RTree required.
    """

    vertices = vertices.copy()
    V = len(vertices)
    F = len(faces)

    # adjacency: vertex -> list of face ids
    vertex_to_faces = [[] for _ in range(V)]
    for f_idx, (a,b,c) in enumerate(faces):
        vertex_to_faces[a].append(f_idx)
        vertex_to_faces[b].append(f_idx)
        vertex_to_faces[c].append(f_idx)

    for iteration in range(iterations):
        print(f"[Iteration {iteration+1}/{iterations}] computing normals...")

        # face normals
        face_normals = np.cross(
            vertices[faces[:,1]] - vertices[faces[:,0]],
            vertices[faces[:,2]] - vertices[faces[:,0]]
        )
        # vertex normals = sum of face normals
        normals = np.zeros_like(vertices)
        for f_idx, (a,b,c) in enumerate(faces):
            n = face_normals[f_idx]
            normals[a] += n
            normals[b] += n
            normals[c] += n

        normals /= (np.linalg.norm(normals, axis=1, keepdims=True) + 1e-12)

        # Pre-extract triangles
        Tris = vertices[faces]   # (F,3,3)

        # per-vertex tracking
        closest_face = np.zeros(V, dtype=np.int32)
        closest_dist = np.full(V, np.inf)

        # print(f"[Iteration {iteration+1}] computing closest points (chunked)...")

        # Chunked closest-point computation
        for start in range(0, F, chunk):
            end = min(start + chunk, F)
            tris_chunk = Tris[start:end]  # (chunk, 3,3)

            CP, D2 = closest_point_on_triangles(vertices, tris_chunk)
            # D2: (V, chunk)

            local_min_idx = np.argmin(D2, axis=1)
            local_min_dist = D2[np.arange(V), local_min_idx]

            mask = local_min_dist < closest_dist
            closest_dist[mask] = local_min_dist[mask]
            closest_face[mask] = start + local_min_idx[mask]

        # print(f"[Iteration {iteration+1}] pushing intersecting vertices...")

        moved = 0
        for i in range(V):
            if closest_dist[i] < threshold:
                f = closest_face[i]
                if f not in vertex_to_faces[i]:
                    vertices[i] += threshold * normals[i]
                    moved += 1

        # print(f"  moved {moved} vertices.")

    return vertices

In [15]:
# ============================================================================
# IMPROVEMENT 3: Normal-Based Detail Enhancement
# ============================================================================

def enhance_normals_from_image(vertices, faces, image_rgb, camera_params):
    """
    Use input image gradients to enhance mesh normals
    This adds back detail that MVF-Net smoothed out
    
    Args:
        vertices: (N, 3) vertex positions
        faces: (M, 3) face indices
        image_rgb: (H, W, 3) input image
        camera_params: dict with intrinsics/extrinsics
    
    Returns:
        enhanced_vertices: (N, 3) vertices with enhanced details
    """
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
    
    # Compute image gradients (proxy for surface detail)
    gray = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)
    grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
    gradient_magnitude = np.sqrt(grad_x**2 + grad_y**2)
    gradient_magnitude = gradient_magnitude / (gradient_magnitude.max() + 1e-8)
    
    # Project vertices to image
    # (Simplified - assumes orthographic projection)
    vertices_2d = vertices[:, :2]  # x, y coordinates
    
    # For each vertex, sample gradient at its projected location
    h, w = image_rgb.shape[:2]
    vertex_gradients = np.zeros(len(vertices))
    
    for i, v2d in enumerate(vertices_2d):
        # Convert to pixel coordinates
        px = int((v2d[0] + 1) * w / 2)
        py = int((v2d[1] + 1) * h / 2)
        
        if 0 <= px < w and 0 <= py < h:
            vertex_gradients[i] = gradient_magnitude[py, px]
    
    # Enhance vertex positions along normals based on gradient
    normals = mesh.vertex_normals
    enhancement_scale = 0.002  # Small displacement
    
    enhanced_vertices = vertices + enhancement_scale * vertex_gradients[:, np.newaxis] * normals
    
    return enhanced_vertices

In [16]:
def compute_laplacian_matrix(vertices, faces):
    """
    Compute Laplacian matrix manually (avoids trimesh caching issues)
    
    Returns sparse matrix L of shape (N, N) where N is number of vertices
    """
    N = len(vertices)
    
    # Build adjacency information
    edges = {}
    
    for face in faces:
        for i in range(3):
            v1 = face[i]
            v2 = face[(i+1) % 3]
            
            # Store edge (always smaller index first)
            edge = tuple(sorted([v1, v2]))
            if edge not in edges:
                edges[edge] = []
            edges[edge].append(face)
    
    # Build Laplacian matrix
    # L[i,j] = -1 if i and j are neighbors
    # L[i,i] = degree of vertex i
    
    rows = []
    cols = []
    data = []
    
    # Count neighbors for each vertex
    degree = [0] * N
    
    for (v1, v2) in edges.keys():
        # Off-diagonal entries
        rows.append(v1)
        cols.append(v2)
        data.append(-1.0)
        
        rows.append(v2)
        cols.append(v1)
        data.append(-1.0)
        
        # Update degree
        degree[v1] += 1
        degree[v2] += 1
    
    # Diagonal entries (degree)
    for i in range(N):
        rows.append(i)
        cols.append(i)
        data.append(float(degree[i]))
    
    # Create sparse matrix
    L = csr_matrix((data, (rows, cols)), shape=(N, N))
    
    return L


def verify_mesh_consistency(vertices_before, faces_before, 
                            vertices_after, faces_after,
                            step_name=""):
    """
    Debug function to verify mesh dimensions stay consistent
    
    Call this after each processing step to catch dimension changes
    """
    print(f"\n[{step_name}] Mesh consistency check:")
    print(f"  Vertices: {len(vertices_before)} → {len(vertices_after)}")
    print(f"  Faces: {len(faces_before)} → {len(faces_after)}")
    
    if len(vertices_before) != len(vertices_after):
        print(f"  ⚠️  WARNING: Vertex count changed!")
    
    if len(faces_before) != len(faces_after):
        print(f"  ⚠️  WARNING: Face count changed!")
    
    # Check for invalid face indices
    max_idx = faces_after.max()
    if max_idx >= len(vertices_after):
        print(f"  ⚠️  ERROR: Invalid faces (max index {max_idx} >= {len(vertices_after)} vertices)")
        return False
    
    print(f"  ✓ Mesh is consistent")
    return True

In [17]:
# ============================================================================
# IMPROVEMENT 4: Mesh Quality Metrics (For Evaluation)
# ============================================================================

def compute_mesh_quality_metrics(vertices, faces):
    """
    Compute metrics to quantify improvement
    
    Returns:
        dict with metrics like:
        - edge_uniformity: How uniform edge lengths are
        - triangle_quality: Aspect ratios of triangles
        - smoothness: Laplacian smoothness
        - volume: Mesh volume (should be stable)
    """

    assert vertices.shape[1] == 3, f"Vertices must be (N, 3), got {vertices.shape}"
    assert faces.shape[1] == 3, f"Faces must be (M, 3), got {faces.shape}"

    # Ensure faces reference valid vertices
    max_face_idx = faces.max()
    if max_face_idx >= len(vertices):
        print(f"  Warning: Invalid faces detected (max index {max_face_idx} >= {len(vertices)} vertices)")
        # Filter out invalid faces
        valid_mask = (faces[:, 0] < len(vertices)) & \
                     (faces[:, 1] < len(vertices)) & \
                     (faces[:, 2] < len(vertices))
        faces = faces[valid_mask]
        print(f"  Filtered to {len(faces)} valid faces")

    mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False)
    
    metrics = {}
    
    # ========================================================================
    # 1. Edge length uniformity
    # ========================================================================
    try:
        edge_lengths = mesh.edges_unique_length
        if len(edge_lengths) > 0:
            mean_len = np.mean(edge_lengths)
            std_len = np.std(edge_lengths)
            if mean_len > 0:
                metrics['edge_uniformity'] = 1.0 - (std_len / mean_len)
            else:
                metrics['edge_uniformity'] = 0.0
        else:
            metrics['edge_uniformity'] = 0.0
    except Exception as e:
        print(f"  Warning: Could not compute edge uniformity: {e}")
        metrics['edge_uniformity'] = 0.0
    
    # ========================================================================
    # 2. Triangle quality (aspect ratio)
    # ========================================================================
    triangle_qualities = []
    for i, face in enumerate(faces):
        try:
            v0, v1, v2 = vertices[face]
            a = np.linalg.norm(v1 - v0)
            b = np.linalg.norm(v2 - v1)
            c = np.linalg.norm(v0 - v2)
            
            if a > 0 and b > 0 and c > 0:
                s = (a + b + c) / 2
                area_sq = s * (s-a) * (s-b) * (s-c)
                if area_sq > 0:
                    area = np.sqrt(area_sq)
                    # Quality metric: closer to 1 is better (equilateral triangle)
                    quality = 4 * np.sqrt(3) * area / (a**2 + b**2 + c**2)
                    triangle_qualities.append(quality)
        except Exception as e:
            # Skip degenerate triangles
            continue
    
    if len(triangle_qualities) > 0:
        metrics['triangle_quality'] = np.mean(triangle_qualities)
    else:
        metrics['triangle_quality'] = 0.0
    
    # ========================================================================
    # 3. Smoothness (Laplacian energy) - FIXED VERSION
    # ========================================================================
    try:
        # Compute Laplacian manually to avoid trimesh caching issues
        laplacian = compute_laplacian_matrix(vertices, faces)
        
        # Verify dimensions
        if laplacian.shape[0] == len(vertices) and laplacian.shape[1] == len(vertices):
            # Compute Laplacian energy: ||L * V||
            laplacian_coords = laplacian @ vertices  # (N, 3)
            laplacian_energy = np.linalg.norm(laplacian_coords)
            
            # Normalize by mesh size
            bbox_size = np.linalg.norm(vertices.max(axis=0) - vertices.min(axis=0))
            if bbox_size > 0:
                normalized_energy = laplacian_energy / (len(vertices) * bbox_size)
                metrics['laplacian_smoothness'] = 1.0 / (1.0 + normalized_energy)
            else:
                metrics['laplacian_smoothness'] = 0.0
        else:
            print(f"  Warning: Laplacian dimension mismatch: {laplacian.shape} vs {len(vertices)} vertices")
            metrics['laplacian_smoothness'] = 0.0
    except Exception as e:
        print(f"  Warning: Could not compute Laplacian smoothness: {e}")
        metrics['laplacian_smoothness'] = 0.0
    
    # ========================================================================
    # 4. Volume (should remain stable)
    # ========================================================================
    try:
        if mesh.is_watertight:
            metrics['volume'] = abs(mesh.volume)
        else:
            # For non-watertight meshes, compute unsigned volume
            metrics['volume'] = 0.0
    except Exception as e:
        print(f"  Warning: Could not compute volume: {e}")
        metrics['volume'] = 0.0
    
    # ========================================================================
    # 5. Self-intersection check
    # ========================================================================
    try:
        metrics['has_self_intersections'] = 0.0 if mesh.is_watertight else 1.0
    except:
        metrics['has_self_intersections'] = 1.0
    
    return metrics

In [18]:
# ============================================================================
# IMPROVEMENT 5: Bilateral Filter for Mesh (Edge-Preserving Smooth)
# ============================================================================

def bilateral_mesh_filtering(vertices, faces, iterations=3, sigma_spatial=0.1, sigma_range=0.1):
    """
    Apply bilateral filtering to mesh vertices
    Smooths while preserving edges/details (better than Laplacian)
    
    This is the KEY improvement - reduces noise while keeping features
    """
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
    neighbors = mesh.vertex_neighbors
    
    filtered_vertices = vertices.copy()
    
    for iteration in range(iterations):
        new_vertices = filtered_vertices.copy()
        
        for i, nbrs in enumerate(neighbors):
            if len(nbrs) == 0:
                continue
            
            # Current vertex position
            v_i = filtered_vertices[i]
            
            # Neighbor positions
            v_neighbors = filtered_vertices[nbrs]
            
            # Spatial weights (based on distance)
            spatial_dists = np.linalg.norm(v_neighbors - v_i, axis=1)
            spatial_weights = np.exp(-(spatial_dists**2) / (2 * sigma_spatial**2))
            
            # Range weights (based on normal similarity)
            # normals_i = mesh.vertex_normals[i]
            # normals_neighbors = mesh.vertex_normals[nbrs]
            # normal_diffs = np.linalg.norm(normals_neighbors - normals_i, axis=1)
            # range_weights = np.exp(-(normal_diffs**2) / (2 * sigma_range**2))
            
            # Combined weights
            # weights = spatial_weights * range_weights
            weights = spatial_weights
            weights = weights / (weights.sum() + 1e-8)
            
            # Filtered position
            new_vertices[i] = np.sum(weights[:, np.newaxis] * v_neighbors, axis=0)
        
        filtered_vertices = new_vertices
    
    return filtered_vertices

In [19]:
# ============================================================================
# IMPROVEMENT 6: Complete Pipeline
# ============================================================================

def improve_mvfnet_output(
    vertices_mvf, 
    faces_mvf, 
    input_image=None,
    output_path='mesh_improv.ply',
    enable_smoothing=True,
    enable_bilateral=True,
    enable_normal_enhancement=False,  # Requires camera params
    smoothing_iterations=3,
    bilateral_iterations=5
):
    """
    Complete improvement pipeline for MVF-Net output
    
    Args:
        vertices_mvf: (N, 3) MVF-Net output vertices
        faces_mvf: (M, 3) face indices
        input_image: Optional (H, W, 3) RGB image for normal enhancement
        output_path: Where to save improved mesh
        
    Returns:
        improved_vertices: (N, 3) enhanced vertices
        metrics_before: Quality metrics before improvement
        metrics_after: Quality metrics after improvement
    """
    print("="*60)
    print("MVF-Net Mesh Improvement Pipeline")
    print("="*60)
    
    # Compute metrics before
    print("\n[1/5] Computing baseline metrics...")
    metrics_before = compute_mesh_quality_metrics(vertices_mvf, faces_mvf)
    print("  Baseline quality:")
    for key, val in metrics_before.items():
        print(f"    {key}: {val:.4f}")
    
    vertices_before = vertices_mvf.copy()
    
    # Step 1: Bilateral filtering (best quality improvement)
    if enable_bilateral:
        print("\n[2/5] Applying bilateral filtering...")
        improved_vertices = bilateral_mesh_filtering(
            vertices_before, 
            faces_mvf, 
            iterations=bilateral_iterations
        )
        verify_mesh_consistency(vertices_before, faces_mvf, 
                               improved_vertices, faces_mvf, 
                               "After Bilateral")
        print("  ✓ Bilateral filter applied")
    
    # Step 2: Adaptive Laplacian smoothing
    if enable_smoothing:
        print("\n[3/5] Applying adaptive smoothing...")
        vertices_before = improved_vertices.copy()
        improved_vertices = laplacian_smoothing_adaptive(
            improved_vertices, 
            faces_mvf, 
            iterations=smoothing_iterations
        )
        verify_mesh_consistency(vertices_before, faces_mvf,
                               improved_vertices, faces_mvf,
                               "After Smoothing")
        print("  ✓ Adaptive smoothing applied")
    
    # Step 3: Self-intersection removal
    print("\n[4/5] Removing self-intersections...")
    # vertices_before = improved_vertices.copy()
    improved_vertices = remove_self_intersections(improved_vertices, faces_mvf) # nhớ đổi lại thành improved_vertices
    is_consistent = verify_mesh_consistency(vertices_before, faces_mvf,
                                            improved_vertices, faces_mvf,
                                            "After Self-Intersection Removal")
    
    if not is_consistent:
        print("  ⚠️  WARNING: Mesh became inconsistent! Using pre-removal version.")
        improved_vertices = vertices_before
    else:
        print("  ✓ Self-intersections removed")
    
    print(f"\nFinal mesh: {len(improved_vertices)} vertices, {len(faces_mvf)} faces")
    
    # Step 4: Normal enhancement (optional, requires image)
    if enable_normal_enhancement and input_image is not None:
        print("\n[5/5] Enhancing from image gradients...")
        # Would need camera parameters - skip for now
        print("  ⚠ Skipped (requires camera parameters)")
    else:
        print("\n[5/5] Normal enhancement skipped")

    # Save improved mesh
    write_ply(
        filename=output_path,
        points=improved_vertices,  # (N, 3) numpy array
        mesh=faces_mvf,            # (M, 3) numpy array
        as_text=True               # ASCII format (easier to debug)
    )
    print(f"\n✓ Improved mesh saved: {output_path}")
    
    # Compute metrics after
    print("\nComputing improved metrics...")
    metrics_after = compute_mesh_quality_metrics(improved_vertices, faces_mvf)
    print("  Improved quality:")
    for key, val in metrics_after.items():
        print(f"    {key}: {val:.4f}")

    # Save improved mesh
    write_ply(
        filename=output_path,
        points=improved_vertices,  # (N, 3) numpy array
        mesh=faces_mvf,            # (M, 3) numpy array
        as_text=True               # ASCII format (easier to debug)
    )
    print(f"\n✓ Improved mesh saved: {output_path}")
    
    # Show improvements
    print("\n" + "="*60)
    print("IMPROVEMENTS:")
    print("="*60)
    for key in metrics_before.keys():
        if key in metrics_after:
            before = metrics_before[key]
            after = metrics_after[key]
            change = ((after - before) / (abs(before) + 1e-8)) * 100
            print(f"  {key}:")
            print(f"    Before: {before:.4f}")
            print(f"    After:  {after:.4f}")
            print(f"    Change: {change:+.1f}%")
    print("="*60)
    
    return improved_vertices, metrics_before, metrics_after

In [None]:
def load_model(ckpt="data/weights/net.pth"):
    model = VggEncoder()

    ckpt = torch.load('data/weights/net.pth', map_location=torch.device('cpu'))
    state = ckpt   # extract the real weights

    # Remove "module." prefix
    new_state = {}
    for k,v in state.items():
        new_state[k.replace("module.", "")] = v

    model.load_state_dict(new_state)

    model.eval()
    print("✓ MVF-Net model loaded")
    return model

def run_model_inference_simple(
        img1_path, img2_path, img3_path,
        output_path="mesh_original.ply",
        device="cpu"
        ):
    print("\n" + "="*60)
    print("Running MVF-Net Inference")
    print("="*60)
    
    # Load images
    img1 = Image.open(img1_path).convert('RGB')
    img2 = Image.open(img2_path).convert('RGB')
    img3 = Image.open(img3_path).convert('RGB')
    
    # Crop using author's function (detects face + crops)
    print("\nCropping images...")
    img1_cropped = crop_image(img1, res=224)
    img2_cropped = crop_image(img2, res=224)
    img3_cropped = crop_image(img3, res=224)
    print("✓ Images cropped")
    
    # Convert to tensors
    img1_tensor = transforms.functional.to_tensor(img1_cropped)
    img2_tensor = transforms.functional.to_tensor(img2_cropped)
    img3_tensor = transforms.functional.to_tensor(img3_cropped)
    
    # Stack: [batch, 9 channels, 224, 224]
    input_tensor = torch.cat([img1_tensor, img2_tensor, img3_tensor], 0).view(1, 9, 224, 224)
    
    # Run model
    model = load_model(ckpt='./data/weights/net.pth')
    print("\nRunning model forward pass...")
    with torch.no_grad():
        preds = model(input_tensor)
    
    # Convert to numpy
    preds_np = preds[0].cpu().numpy()  # Shape: (257,) or similar
    
    print(f"Model predictions shape: {preds_np.shape}")
    
    # Use author's preds_to_shape function
    print("\nReconstructing 3D mesh...")
    result = preds_to_shape(preds_np)
    
    # Parse author's output format
    # result = [face_shape, faces, kptA, kptB, kptC]
    vertices = result[0]  # (N, 3)
    faces = result[1]     # (M, 3)
    kptA = result[2]      # (68, 2) - keypoints for view A
    kptB = result[3]      # (68, 2) - keypoints for view B
    kptC = result[4]      # (68, 2) - keypoints for view C
    
    print(f"\n✓ Reconstruction complete:")
    print(f"  - Vertices: {vertices.shape}")
    print(f"  - Faces: {faces.shape}")
    print(f"  - Keypoints: {kptA.shape}")
    print("="*60)

    # Save original mesh
    write_ply(os.path.join('./result', output_path), result[0], result[1])
    
    return vertices, faces, [kptA, kptB, kptC]

In [21]:
def run_model_inference_with_improvement(
        img1_path, img2_path, img3_path,
        output_original='mvfnet_original.ply',
        output_improved='mvfnet_improv.ply',
        device='cpu'
    ):
    """
    Run MVF-Net + apply your post-processing improvements
    
    This combines:
    1. MVF-Net inference (baseline)
    2. Your bilateral filtering improvements
    """
    # Get baseline MVF-Net result
    vertices_orig, faces, _ = run_model_inference_simple(
        img1_path, img2_path, img3_path,
        output_path=output_original,
        device=device
    )
    
    # Apply your improvements (from previous artifact)
    print("\n" + "="*60)
    print("Applying Post-Processing Improvements")
    print("="*60)
    
    improved_verts, metrics_before, metrics_after = improve_mvfnet_output(
        vertices_orig,
        faces,
        output_path=output_improved,
        enable_smoothing=False,
        enable_bilateral=True,
        smoothing_iterations=3,
        bilateral_iterations=1
    )
    
    print("\n" + "="*60)
    print("RESULTS SUMMARY")
    print("="*60)
    print(f"✓ Original mesh: {output_original}")
    print(f"✓ Improved mesh: {output_improved}")
    print("\nQuality Improvements:")
    for key in metrics_before.keys():
        before = metrics_before[key]
        after = metrics_after[key]
        change = ((after - before) / (abs(before) + 1e-8)) * 100
        print(f"  {key}: {change:+.1f}%")
    print("="*60)

In [22]:
# ============================================================================
# Example Usage
# ============================================================================

if __name__ == "__main__":
    
    run_model_inference_with_improvement(
        img1_path='./data/imgs/front.jpg',
        img2_path='./data/imgs/left.jpg',
        img3_path='./data/imgs/right.jpg',
        output_original='mesh_original.ply',
        output_improved='mesh_improv.ply',
        device='cuda'
    )
    
    # For your presentation, you can show:
    # 1. Side-by-side visualization (before/after)
    # 2. Quantitative metrics improvement
    # 3. Processing time (should be <1 second per mesh)
    
    print("\n✓ DONE! You now have:")
    print("  - Original MVF-Net mesh")
    print("  - Improved mesh with better quality")
    print("  - Quantitative metrics showing improvement")
    print("  - Fast processing (no training needed!)")


# ============================================================================
# BONUS: Visualization Comparison
# ============================================================================

def visualize_before_after(mesh_before_path, mesh_after_path):
    """
    Visualize original and improved meshes side-by-side
    """
    mesh_before = o3d.io.read_triangle_mesh(mesh_before_path)
    mesh_after = o3d.io.read_triangle_mesh(mesh_after_path)
    
    # Color differently
    mesh_before.paint_uniform_color([0.7, 0.7, 0.7])  # Gray
    mesh_after.paint_uniform_color([0.3, 0.7, 0.9])   # Blue
    
    # Offset for side-by-side
    mesh_after.translate([0.3, 0, 0])
    
    o3d.visualization.draw_geometries(
        [mesh_before, mesh_after],
        window_name="Before (left) vs After (right)"
    )


Running MVF-Net Inference

Cropping images...
✓ Images cropped


  nn.init.normal(m.weight, 0.0, 0.0001)
  nn.init.constant(m.bias, 0)


✓ MVF-Net model loaded

Running model forward pass...
Model predictions shape: (249,)

Reconstructing 3D mesh...

✓ Reconstruction complete:
  - Vertices: (53215, 3)
  - Faces: (105840, 3)
  - Keypoints: (68, 2)

Applying Post-Processing Improvements
MVF-Net Mesh Improvement Pipeline

[1/5] Computing baseline metrics...
  Baseline quality:
    edge_uniformity: 0.5862
    triangle_quality: 0.7830
    laplacian_smoothness: 1.0000
    volume: 0.0000
    has_self_intersections: 1.0000

[2/5] Applying bilateral filtering...

[After Bilateral] Mesh consistency check:
  Vertices: 53215 → 53215
  Faces: 105840 → 105840
  ✓ Mesh is consistent
  ✓ Bilateral filter applied

[4/5] Removing self-intersections...
[Iteration 1/3] computing normals...
[Iteration 2/3] computing normals...
[Iteration 3/3] computing normals...

[After Self-Intersection Removal] Mesh consistency check:
  Vertices: 53215 → 53215
  Faces: 105840 → 105840
  ✓ Mesh is consistent
  ✓ Self-intersections removed

Final mesh: 532