In [2]:
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import trimesh
from datetime import datetime 

In [3]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


Using device: cuda


In [202]:
# Parse the keypoint annotation file.
with open("guitar.json") as f:
    annotations = json.load(f)

# Index the annotations by model id: one float32 array per model, built from
# the 'xyz' field of each of that model's keypoints.
model_id_to_keypoints = {
    entry['model_id']: np.array([kp['xyz'] for kp in entry['keypoints']], dtype=np.float32)
    for entry in annotations
}



In [4]:
# Using minimal PointNet (not full PointNet++)
class PointNetKeypointRegressor(nn.Module):
    """Minimal PointNet-style network that regresses 3D keypoints from a point cloud.

    A shared per-point MLP (1x1 convolutions) lifts every point to a
    1024-dimensional feature, a max over the point axis pools these into one
    global feature, and a fully connected head maps it to
    ``num_keypoints * 3`` coordinates.

    Args:
        num_keypoints: Number of keypoints predicted per input cloud.
    """

    def __init__(self, num_keypoints):
        super().__init__()
        # Remember the output size so forward() reshapes with the value this
        # instance was built with, not with a notebook-level global.
        self.num_keypoints = num_keypoints
        # Attribute names `sa1` / `fc` are part of the state_dict keys; keep
        # them so existing checkpoints continue to load.
        self.sa1 = nn.Sequential(
            nn.Conv1d(3, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(128, 1024, 1),
            nn.BatchNorm1d(1024),
            nn.ReLU()
        )
        self.fc = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, num_keypoints * 3)
        )

    def forward(self, x):
        """Predict keypoints for a batch of point clouds.

        Args:
            x: Tensor of shape [B, N, 3].

        Returns:
            Tensor of shape [B, num_keypoints, 3].
        """
        x = x.permute(0, 2, 1)  # [B, 3, N] — Conv1d expects channels first
        x = self.sa1(x)  # [B, 1024, N]
        x = torch.max(x, 2)[0]  # [B, 1024] — order-invariant pooling over points
        x = self.fc(x)  # [B, num_keypoints * 3]
        return x.view(-1, self.num_keypoints, 3)




In [5]:
class ShapeNetKeypointDataset(Dataset):
    """Dataset of (sampled surface points, keypoints) pairs for annotated meshes.

    Each annotation entry supplies a ``model_id`` and a list of keypoints with
    an ``xyz`` field; the mesh is read from ``<mesh_dir>/<model_id>.ply``.
    Entries whose keypoint count differs from ``num_keypoints`` are skipped so
    every target has the same shape.

    Args:
        mesh_dir: Directory containing the ``.ply`` meshes.
        annotation_json: Path to the JSON annotation file.
        num_points: Points sampled from each mesh surface. Defaults to the
            notebook-level ``NUM_POINTS``.
        num_keypoints: Required keypoint count per sample. Defaults to the
            notebook-level ``NUM_KEYPOINTS``.
    """

    def __init__(self, mesh_dir, annotation_json, num_points=None, num_keypoints=None):
        self.mesh_dir = mesh_dir
        # Resolve the globals lazily (only when no explicit value is given) so
        # the class can be used without the notebook's constants cell.
        self.num_points = NUM_POINTS if num_points is None else num_points
        self.num_keypoints = NUM_KEYPOINTS if num_keypoints is None else num_keypoints
        self.samples = []

        with open(annotation_json) as f:
            annotations = json.load(f)

        for entry in annotations:
            keypoints = np.array([kp['xyz'] for kp in entry['keypoints']], dtype=np.float32)
            if keypoints.shape[0] != self.num_keypoints:
                continue  # keep only meshes with exactly `num_keypoints` keypoints
            self.samples.append((entry['model_id'], keypoints))

    def __len__(self):
        """Number of annotation entries that passed the keypoint-count filter."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(points, keypoints)`` float tensors for sample ``idx``.

        Points are sampled from the mesh surface; both points and keypoints are
        centred on the point centroid and divided by the largest point distance
        from it. If the mesh cannot be loaded, the error is printed and an
        all-zero sample is returned instead of raising.
        """
        model_id, keypoints = self.samples[idx]
        mesh_path = os.path.join(self.mesh_dir, model_id + ".ply")

        try:
            mesh = trimesh.load(mesh_path, force='mesh')
            if mesh.is_empty or len(mesh.faces) == 0:
                raise ValueError("Empty mesh")

            points, _ = trimesh.sample.sample_surface(mesh, self.num_points)
            if points.shape[0] < self.num_points:
                # Pad by repeating the first point so every sample has the same size.
                pad_size = self.num_points - points.shape[0]
                pad = np.repeat(points[0:1, :], pad_size, axis=0)
                points = np.vstack((points, pad))
        except Exception as e:
            # Best-effort: keep the batch shape intact with a placeholder sample.
            # NOTE(review): placeholder samples still contribute to the loss;
            # consider filtering unreadable meshes out up front.
            print(f"Error loading mesh {model_id}: {e}")
            points = np.zeros((self.num_points, 3), dtype=np.float32)
            keypoints = np.zeros((self.num_keypoints, 3), dtype=np.float32)

        # Normalize: centre on the centroid, scale to the unit sphere.
        centroid = np.mean(points, axis=0)
        scale = np.max(np.linalg.norm(points - centroid, axis=1))
        points = points - centroid
        keypoints = keypoints - centroid
        if scale > 0:
            # A degenerate cloud (e.g. the all-zero placeholder) has scale 0;
            # dividing would fill the sample with NaNs and poison training.
            points = points / scale
            keypoints = keypoints / scale

        return torch.from_numpy(points).float(), torch.from_numpy(keypoints).float()


In [None]:
# Notebook-wide hyperparameters, read by the dataset, model and prediction cells.
# Points sampled from each mesh surface (512 is a smaller setting worth trying).
NUM_POINTS = 4096
# Keypoints per sample; annotations with a different count are skipped by the
# dataset (two meshes reportedly carry only 5 keypoints).
NUM_KEYPOINTS = 9
# Samples per training batch.
BATCH_SIZE = 128

In [20]:
# Training Loop
# Builds the dataset/model/optimizer, optionally resumes from a checkpoint,
# trains for `total_epoch` additional epochs, writes a checkpoint every 10
# epochs and a final model file at the end.
dataset = ShapeNetKeypointDataset(mesh_dir="Guitars", annotation_json="guitar.json")
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)

print(f"Dataset size: {len(dataset)}")

model = PointNetKeypointRegressor(NUM_KEYPOINTS).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Both output directories must exist before torch.save is called.
os.makedirs("train_checkpoints", exist_ok=True)
os.makedirs("saved_models", exist_ok=True)

# Number of epochs already completed; stays 0 for a fresh run so the epoch
# arithmetic below works whether or not a checkpoint is found.
start_epoch = 0
resume_checkpoint_path = "train_checkpoints/guitarnet_epoch_45.pth"
if os.path.exists(resume_checkpoint_path):
    # map_location lets a checkpoint written on GPU be resumed on CPU (and vice versa).
    checkpoint = torch.load(resume_checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # The final-model files written below carry no optimizer state; resume
    # from them with a fresh optimizer instead of failing with a KeyError.
    if 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resuming training from epoch {start_epoch}")


total_epoch = 50 # Additional epochs to train
for epoch in range(total_epoch):
    model.train()
    total_loss = 0
    num_batches = 0

    for batch_idx, (pts, kps) in enumerate(dataloader):
        # Best-effort: a failing batch is reported and skipped so one bad
        # batch does not abort a long run.
        try:
            pts, kps = pts.to(device), kps.to(device)

            # Forward pass
            preds = model(pts)
            loss = F.mse_loss(preds, kps)

            # Back pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate loss
            total_loss += loss.item()
            num_batches += 1

            # One-off sanity print of shapes and value ranges.
            if epoch == 0 and batch_idx == 0:
                print(f"First batch - Input shape: {pts.shape}, Target shape: {kps.shape}")
                print(f"First batch loss: {loss.item():.6f}")
                print(f"Pred range: [{preds.min():.3f}, {preds.max():.3f}]")
                print(f"Target range: [{kps.min():.3f}, {kps.max():.3f}]")

        except Exception as e:
            print(f"Error in batch {batch_idx}: {e}")
            continue

    # Average loss over the batches that completed (0 if none did)
    avg_loss = total_loss / num_batches if num_batches > 0 else 0

    if (epoch + 1) % 10 == 0:
        checkpoint_path = f"train_checkpoints/guitarnet_epoch_{start_epoch + epoch+1}.pth"
        torch.save({
            'epoch': start_epoch + epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, checkpoint_path)
        print(f"Saved checkpoint: {checkpoint_path}")

    print(f"Epoch {epoch+ start_epoch + 1}/{start_epoch + total_epoch} - Loss: {avg_loss:.6f}")

# Final model file. 'epoch' is the absolute (0-based) epoch index, matching the
# periodic checkpoints, so "epoch + 1" is the total number of epochs trained.
torch.save({
    'model_state_dict': model.state_dict(),
    'epoch': start_epoch + epoch,
    'loss': avg_loss,
    'config': {'num_keypoints': NUM_KEYPOINTS, 'num_points': NUM_POINTS}
}, "saved_models/guitarnet_final_2.pth")

print("Training completed!")

Dataset size: 587
Resuming training from epoch 45
First batch - Input shape: torch.Size([128, 4096, 3]), Target shape: torch.Size([128, 9, 3])
First batch loss: 0.001150
Pred range: [-0.816, 1.070]
Target range: [-0.831, 1.008]
Epoch 46/95 - Loss: 0.000974
Epoch 47/95 - Loss: 0.000977
Epoch 48/95 - Loss: 0.000943
Epoch 49/95 - Loss: 0.000963
Epoch 50/95 - Loss: 0.000947
Epoch 51/95 - Loss: 0.000884
Epoch 52/95 - Loss: 0.000897
Epoch 53/95 - Loss: 0.000903
Epoch 54/95 - Loss: 0.000859
Saved checkpoint: train_checkpoints/guitarnet_epoch_55.pth
Epoch 55/95 - Loss: 0.000909
Epoch 56/95 - Loss: 0.000975
Epoch 57/95 - Loss: 0.000906
Epoch 58/95 - Loss: 0.000824
Epoch 59/95 - Loss: 0.000783
Epoch 60/95 - Loss: 0.000766
Epoch 61/95 - Loss: 0.000803
Epoch 62/95 - Loss: 0.000760
Epoch 63/95 - Loss: 0.000781
Epoch 64/95 - Loss: 0.000728
Saved checkpoint: train_checkpoints/guitarnet_epoch_65.pth
Epoch 65/95 - Loss: 0.000745
Epoch 66/95 - Loss: 0.000768
Epoch 67/95 - Loss: 0.001082
Epoch 68/95 - Lo

In [None]:
# Manually (re-)write the final model file from the current kernel state, e.g.
# after interrupting the training cell before it reached its own final save.
os.makedirs("saved_models", exist_ok=True)  # torch.save does not create directories
torch.save({
    'model_state_dict': model.state_dict(),
    # Absolute (0-based) epoch index, consistent with the periodic checkpoints.
    # `start_epoch` may be missing if the training cell never defined it, so
    # fall back to 0 rather than raising.
    'epoch': globals().get('start_epoch', 0) + epoch,
    'loss': avg_loss,
    'config': {'num_keypoints': NUM_KEYPOINTS, 'num_points': NUM_POINTS}
}, "saved_models/guitarnet_final_2.pth")

In [174]:
# NOTE(review): save_model_complete is not defined in any cell visible here —
# presumably it came from a cell that was later removed; confirm it exists
# before relying on Restart & Run All.
# `loss` is whatever the training loop last assigned to that name (a single
# batch's loss), not the per-epoch average `avg_loss`.
save_model_complete(model, optimizer, epoch, loss, save_dir="saved_models/")

Model saved successfully!
Full checkpoint: saved_models/capnet_keypoint_model_20250620_125801.pth
Model only: saved_models/capnet_keypoint_model_20250620_125801_model_only.pth
Info file: saved_models/capnet_keypoint_model_20250620_125801_info.json


('saved_models/capnet_keypoint_model_20250620_125801.pth',
 'saved_models/capnet_keypoint_model_20250620_125801_model_only.pth',
 'saved_models/capnet_keypoint_model_20250620_125801_info.json')

In [15]:
def load_saved_model(model_path):
    """Load a saved PointNetKeypointRegressor checkpoint ready for inference.

    Args:
        model_path: Path to a checkpoint dict containing 'model_state_dict',
            'epoch' and 'loss', and optionally 'config' with 'num_keypoints'.

    Returns:
        The model with weights loaded, moved to CUDA when available (CPU
        otherwise) and switched to eval mode.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load checkpoint
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    checkpoint = torch.load(model_path, map_location=device)

    # Prefer the keypoint count recorded with the weights; only fall back to
    # the notebook-level constant for checkpoints saved without a config.
    config = checkpoint.get('config') or {}
    num_keypoints = config.get('num_keypoints')
    if num_keypoints is None:
        num_keypoints = NUM_KEYPOINTS

    # Create model with same architecture
    model = PointNetKeypointRegressor(num_keypoints)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    print(f"Model loaded! Trained for {checkpoint['epoch']+1} epochs, final loss: {checkpoint['loss']:.6f}")
    return model

In [11]:
def predict_keypoints(model, ply_file_path, num_points=None):
    """Predict keypoints for a single .ply mesh file.

    The mesh is sampled and normalized the same way as in training (centred on
    the point centroid, scaled by the largest distance from it), passed through
    the model, and the prediction is mapped back to the mesh's own coordinates.

    Args:
        model: Trained keypoint regressor; inference runs on the device of its
            parameters.
        ply_file_path: Path to the mesh file.
        num_points: Points to sample from the mesh surface. Defaults to the
            notebook-level ``NUM_POINTS``.

    Returns:
        ``(original_keypoints, normalized_keypoints)`` as NumPy arrays, or
        ``(None, None)`` if anything fails (the error is printed).
    """
    device = next(model.parameters()).device
    # Remember the caller's mode so it can be restored afterwards.
    was_training = model.training

    try:
        if num_points is None:
            num_points = NUM_POINTS

        # Load and process mesh (same as the training preprocessing)
        mesh = trimesh.load(ply_file_path, force='mesh')
        if mesh.is_empty or len(mesh.faces) == 0:
            raise ValueError("Empty mesh")

        # Sample points, padding with the first point if too few came back
        points, _ = trimesh.sample.sample_surface(mesh, num_points)
        if points.shape[0] < num_points:
            pad_size = num_points - points.shape[0]
            pad = np.repeat(points[0:1, :], pad_size, axis=0)
            points = np.vstack((points, pad))

        # Normalize (exactly like training); skip the division for a
        # degenerate cloud whose scale is 0
        centroid = np.mean(points, axis=0)
        scale = np.max(np.linalg.norm(points - centroid, axis=1))
        normalized_points = (points - centroid) / scale if scale > 0 else points - centroid

        # Convert to tensor and add batch dimension
        points_tensor = torch.from_numpy(normalized_points).float().unsqueeze(0).to(device)

        # Predict in eval mode so BatchNorm uses its running statistics rather
        # than the statistics of this single sample.
        model.eval()
        with torch.no_grad():
            normalized_keypoints = model(points_tensor).cpu().numpy().squeeze(0)

        # Denormalize keypoints back to original scale
        original_keypoints = normalized_keypoints * scale + centroid if scale > 0 else normalized_keypoints + centroid

        return original_keypoints, normalized_keypoints

    except Exception as e:
        print(f"Error processing {ply_file_path}: {e}")
        return None, None
    finally:
        model.train(was_training)

In [12]:
def predict_multiple(model, ply_folder):
    """Predict keypoints for all .ply files in a folder.

    Args:
        model: Trained keypoint regressor, passed through to predict_keypoints.
        ply_folder: Directory to scan (non-recursively) for ``.ply`` files.

    Returns:
        Dict mapping each successfully processed file name to
        ``{'original_keypoints': ..., 'normalized_keypoints': ...}``. Files
        that fail are reported and left out.
    """
    results = {}

    # Sorted so the processing order (and result order) is reproducible;
    # os.listdir makes no ordering guarantee.
    ply_files = sorted(
        os.path.join(ply_folder, f) for f in os.listdir(ply_folder) if f.endswith('.ply')
    )
    print(f"Found {len(ply_files)} .ply files")

    for ply_file in ply_files:
        file_name = os.path.basename(ply_file)
        print(f"Processing {file_name}...")
        keypoints, normalized_kp = predict_keypoints(model, ply_file)

        if keypoints is not None:
            results[file_name] = {
                'original_keypoints': keypoints,
                'normalized_keypoints': normalized_kp
            }
            print(f"  Success! Predicted {len(keypoints)} keypoints")
        else:
            print(f"  Failed to process {ply_file}")

    return results

In [13]:
# Rebind `model` to the weights stored in the final model file;
# load_saved_model returns it on the available device, already in eval mode.
model = load_saved_model("saved_models/guitarnet_final.pth")

Model loaded! Trained for 32 epochs, final loss: 0.001158


In [14]:
# Run the loaded model on one mesh and list the predicted keypoints in the
# mesh's original coordinates.
keypoints, norm_kp = predict_keypoints(model, "a7dd.ply")
# predict_keypoints returns (None, None) on failure, so only print on success.
if keypoints is not None:
    print("Predicted keypoints:")
    for i, kp in enumerate(keypoints):
        # 1-based keypoint number followed by its x, y, z coordinates.
        print(f"  Keypoint {i+1}: [{kp[0]:.3f}, {kp[1]:.3f}, {kp[2]:.3f}]")

Predicted keypoints:
  Keypoint 1: [-0.005, 0.345, 0.030]
  Keypoint 2: [0.001, -0.181, 0.000]
  Keypoint 3: [0.099, -0.168, 0.018]
  Keypoint 4: [-0.108, -0.169, 0.023]
  Keypoint 5: [0.153, -0.557, 0.023]
  Keypoint 6: [-0.138, -0.552, 0.031]
  Keypoint 7: [-0.012, -0.456, -0.006]
  Keypoint 8: [0.081, -0.336, 0.017]
  Keypoint 9: [-0.079, -0.325, 0.018]
