In [1]:
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import trimesh
from datetime import datetime 

from pointnet2_utils import PointNetSetAbstractionMsg, PointNetSetAbstraction
from pointnet2_keypoint_regressor import get_model

from torch.utils.data import random_split

In [2]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")


Using device: cuda


In [202]:
# Read the raw keypoint annotations from disk.
with open("guitar.json") as annotation_file:
    annotations = json.load(annotation_file)

# Map every model id to its keypoint coordinates, stored as a float32 array
# with one row per annotated keypoint.
model_id_to_keypoints = {}
for entry in annotations:
    xyz_rows = [kp['xyz'] for kp in entry['keypoints']]
    model_id_to_keypoints[entry['model_id']] = np.array(xyz_rows, dtype=np.float32)



In [3]:
# Using minimal PointNet (not full PointNet++)
class PointNetKeypointRegressor(nn.Module):
    """Minimal PointNet-style regressor mapping a point cloud to 3D keypoints.

    A shared per-point MLP (1x1 convolutions) lifts every point to a
    1024-dimensional feature, a max over the point axis pools these into one
    global feature, and a fully connected head regresses the keypoints.

    Args:
        num_keypoints: number of 3D keypoints predicted per cloud.
    """

    def __init__(self, num_keypoints):
        super().__init__()
        # Remember the head size so forward() reshapes to what this instance
        # was built with, rather than to a notebook-level global.
        self.num_keypoints = num_keypoints
        self.sa1 = nn.Sequential(
            nn.Conv1d(3, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(128, 1024, 1),
            nn.BatchNorm1d(1024),
            nn.ReLU()
        )
        self.fc = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, num_keypoints * 3)
        )

    def forward(self, x):
        """Predict keypoints for a batch of point clouds.

        Args:
            x: tensor of shape [B, N, 3].

        Returns:
            Tensor of shape [B, num_keypoints, 3].
        """
        x = x.permute(0, 2, 1)  # [B, 3, N] -- Conv1d expects channels first
        x = self.sa1(x)  # [B, 1024, N]
        x = torch.max(x, 2)[0]  # [B, 1024] -- order-invariant pooling over points
        x = self.fc(x)  # [B, num_keypoints * 3]
        return x.view(-1, self.num_keypoints, 3)




In [4]:
class ShapeNetKeypointDataset(Dataset):
    """Dataset of (sampled surface points, keypoints) pairs for annotated meshes.

    Each item is built by loading ``<mesh_dir>/<model_id>.ply``, sampling
    points from the mesh surface, and normalizing both the points and the
    annotated keypoints by the centroid and maximum radius of the sampled
    points.

    Args:
        mesh_dir: directory containing the ``.ply`` meshes.
        annotation_json: path to a JSON list of entries, each with a
            ``model_id`` and a ``keypoints`` list whose items carry ``xyz``.
        num_points: number of surface points to sample per mesh. Defaults to
            the notebook-level ``NUM_POINTS``.
        num_keypoints: required number of keypoints per mesh; entries with a
            different count are skipped. Defaults to the notebook-level
            ``NUM_KEYPOINTS``.
    """

    def __init__(self, mesh_dir, annotation_json, num_points=None, num_keypoints=None):
        self.mesh_dir = mesh_dir
        # Resolve the notebook-level defaults once, at construction time.
        self.num_points = NUM_POINTS if num_points is None else num_points
        self.num_keypoints = NUM_KEYPOINTS if num_keypoints is None else num_keypoints
        self.samples = []

        # Load annotations
        with open(annotation_json) as f:
            annotations = json.load(f)

        for entry in annotations:
            keypoints = np.array([kp['xyz'] for kp in entry['keypoints']], dtype=np.float32)
            if keypoints.shape[0] != self.num_keypoints:
                continue  # keep only meshes with exactly `num_keypoints` keypoints
            self.samples.append((entry['model_id'], keypoints))

    def __len__(self):
        """Return the number of annotated meshes that passed the keypoint-count filter."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(points, keypoints)`` float tensors for sample ``idx``.

        ``points`` has shape [num_points, 3] and ``keypoints`` has shape
        [num_keypoints, 3]. If the mesh cannot be loaded or sampled, the error
        is printed and an all-zero sample is returned instead of raising.
        """
        model_id, keypoints = self.samples[idx]
        mesh_path = os.path.join(self.mesh_dir, model_id + ".ply")

        try:
            mesh = trimesh.load(mesh_path, force='mesh')
            if mesh.is_empty or len(mesh.faces) == 0:
                raise ValueError("Empty mesh")

            points, _ = trimesh.sample.sample_surface(mesh, self.num_points)
            if points.shape[0] < self.num_points:
                # Pad by repeating the first point so every item has the same size.
                pad_size = self.num_points - points.shape[0]
                pad = np.repeat(points[0:1, :], pad_size, axis=0)
                points = np.vstack((points, pad))
        except Exception as e:
            # Best effort: keep the batch well-formed with a zero placeholder.
            print(f"Error loading mesh {model_id}: {e}")
            points = np.zeros((self.num_points, 3), dtype=np.float32)
            keypoints = np.zeros((self.num_keypoints, 3), dtype=np.float32)

        # Normalize to a centered cloud whose farthest point lies at radius 1.
        centroid = np.mean(points, axis=0)
        scale = np.max(np.linalg.norm(points - centroid, axis=1))
        if not scale > 0:
            # Degenerate cloud (e.g. the zero placeholder): dividing by zero
            # would turn the whole sample into NaN and poison the loss.
            scale = 1.0
        points = (points - centroid) / scale
        keypoints = (keypoints - centroid) / scale

        return torch.from_numpy(points).float(), torch.from_numpy(keypoints).float()


In [5]:
# Notebook-wide configuration read by the dataset, the data loaders and the prediction helpers.
NUM_POINTS = 8192 # surface points sampled per mesh; TODO: maybe try 512
NUM_KEYPOINTS = 9 # required keypoints per mesh (other counts are skipped by the dataset); original note: 2 meshes have 5 keypoints, reason unknown
BATCH_SIZE = 32 # batch size of the train and validation DataLoaders

In [19]:
# Training Loop
#
# Fits the keypoint regressor on an 80/20 train/validation split, prints the
# per-epoch losses, writes a checkpoint every 10 epochs and a final model file.

SPLIT_SEED = 0  # fixed so the train/validation partition is the same on every run and on resume

dataset = ShapeNetKeypointDataset(mesh_dir="Guitars/9_points/", annotation_json="guitar_train.json")

total_size = len(dataset)
val_size = int(0.2 * total_size)
train_size = total_size - val_size

# Seed the split: an unseeded split changes between runs, so a resumed run
# would train on samples that were previously held out for validation.
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

print(f"Dataset size: {len(dataset)}")

#model = PointNetKeypointRegressor(NUM_KEYPOINTS).to(device)
model = get_model(num_keypoints=NUM_KEYPOINTS, normal_channel=False).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Both output directories must exist before torch.save is called; otherwise
# the run fails at the first checkpoint or, worse, at the final save.
os.makedirs("train_checkpoints", exist_ok=True)
os.makedirs("saved_models", exist_ok=True)

start_epoch = 0

resume_training = False  # set to True to continue from resume_checkpoint_path
if resume_training:
    #resume_checkpoint_path = "saved_models/guitarnet_final.pth"
    resume_checkpoint_path = "train_checkpoints/guitarnet_epoch_45.pth"
    if os.path.exists(resume_checkpoint_path):
        # map_location lets a checkpoint written on GPU be resumed on CPU and vice versa.
        checkpoint = torch.load(resume_checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"Resuming training from epoch {start_epoch}")
    else:
        print(f"Checkpoint {resume_checkpoint_path} not found; training from scratch")


total_epoch = 50 # Additional epochs to train
for epoch in range(total_epoch):
    model.train()
    total_train_loss = 0.0
    num_train_batches = 0  # only batches that completed an optimizer step

    for batch_idx, (pts, kps) in enumerate(train_loader):
        try:
            pts, kps = pts.to(device), kps.to(device)
            pts = pts.permute(0, 2, 1) #Permute to fit input size

            # Forward pass
            preds = model(pts)
            loss = F.mse_loss(preds, kps)

            # Back pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate loss
            total_train_loss += loss.item()
            num_train_batches += 1

            if epoch == 0 and batch_idx == 0:
                print(f"First batch - Input shape: {pts.shape}, Target shape: {kps.shape}")
                print(f"First batch loss: {loss.item():.6f}")
                print(f"Pred range: [{preds.min():.3f}, {preds.max():.3f}]")
                print(f"Target range: [{kps.min():.3f}, {kps.max():.3f}]")

        except Exception as e:
            # Best effort: report the failing batch and keep training.
            print(f"Error in batch {batch_idx}: {e}")
            continue

    # Average over the batches that actually contributed, so skipped batches
    # do not drag the reported loss towards zero.
    avg_train_loss = total_train_loss / num_train_batches if num_train_batches else float("nan")

    model.eval()
    total_val_loss = 0.0
    with torch.no_grad():
        for pts, kps in val_loader:
            pts, kps = pts.to(device), kps.to(device)
            pts = pts.permute(0, 2, 1)
            preds = model(pts)
            val_loss = F.mse_loss(preds, kps)
            total_val_loss += val_loss.item()

    # An empty validation split (very small dataset) would otherwise divide by zero.
    avg_val_loss = total_val_loss / len(val_loader) if len(val_loader) else float("nan")

    if (epoch + 1) % 10 == 0:
        checkpoint_path = f"train_checkpoints/guitarnet++_epoch_{start_epoch + epoch+1}.pth"
        torch.save({
            'epoch': start_epoch + epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_train_loss,
        }, checkpoint_path)
        print(f"Saved checkpoint: {checkpoint_path}")

    print(f"Epoch {epoch+ start_epoch + 1}/{start_epoch + total_epoch}\nTrain Loss: {avg_train_loss:.6f}; Validation Loss: {avg_val_loss:.6f}")

torch.save({
    'model_state_dict': model.state_dict(),
    # Absolute epoch index, consistent with the periodic checkpoints above.
    'epoch': start_epoch + epoch,
    'loss': avg_train_loss,
    'config': {'num_keypoints': NUM_KEYPOINTS, 'num_points': NUM_POINTS}
}, "saved_models/guitarnet++_final.pth")

print("Training completed!")

Dataset size: 469
First batch - Input shape: torch.Size([32, 3, 8192]), Target shape: torch.Size([32, 9, 3])
First batch loss: 0.402841
Pred range: [-1.905, 1.477]
Target range: [-0.749, 1.001]
Epoch 1/50
Train Loss: 0.283798; Validation Loss: 0.035835
Epoch 2/50
Train Loss: 0.172636; Validation Loss: 0.042654
Epoch 3/50
Train Loss: 0.121219; Validation Loss: 0.051110
Epoch 4/50
Train Loss: 0.093375; Validation Loss: 0.024118
Epoch 5/50
Train Loss: 0.077583; Validation Loss: 0.013700
Epoch 6/50
Train Loss: 0.061444; Validation Loss: 0.010143
Epoch 7/50
Train Loss: 0.048456; Validation Loss: 0.008139
Epoch 8/50
Train Loss: 0.048267; Validation Loss: 0.012978
Epoch 9/50
Train Loss: 0.036501; Validation Loss: 0.008710
Saved checkpoint: train_checkpoints/guitarnet++_epoch_10.pth
Epoch 10/50
Train Loss: 0.030453; Validation Loss: 0.005229
Epoch 11/50
Train Loss: 0.022284; Validation Loss: 0.005536
Epoch 12/50
Train Loss: 0.020704; Validation Loss: 0.004808
Epoch 13/50
Train Loss: 0.018637; 

In [20]:
# Release cached, currently unused GPU memory held by PyTorch's CUDA allocator.
torch.cuda.empty_cache()

In [14]:
# Manual snapshot of the model currently in memory.
# `avg_train_loss` is the name the training loop assigns after each epoch's
# training batches; the previously referenced `avg_loss` is not defined in
# this notebook.
os.makedirs("saved_models", exist_ok=True)  # torch.save does not create missing directories
torch.save({
    'model_state_dict': model.state_dict(),
    'epoch': epoch,
    'loss': avg_train_loss,
    'config': {'num_keypoints': NUM_KEYPOINTS, 'num_points': NUM_POINTS}
}, "saved_models/guitarnet++_epoch_12.pth")

In [33]:
def load_saved_model(model_path):
    """Load a saved keypoint model and return it in eval mode.

    The network is rebuilt with the ``num_keypoints`` recorded under the
    checkpoint's ``config`` entry when present, so the architecture always
    matches the stored weights. Checkpoints without a ``config`` entry fall
    back to the notebook-level ``NUM_KEYPOINTS``.

    Args:
        model_path: path to a checkpoint containing ``model_state_dict``,
            ``epoch`` and ``loss``.

    Returns:
        The model, moved to the GPU when available (otherwise the CPU).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load checkpoint first: it may describe the architecture to build.
    # NOTE: torch.load unpickles the file, so only open trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)

    config = checkpoint.get('config') or {}
    if 'num_keypoints' in config:
        num_keypoints = config['num_keypoints']
    else:
        num_keypoints = NUM_KEYPOINTS

    # Create model with the same architecture as the one that was saved.
    model = get_model(num_keypoints=num_keypoints, normal_channel=False).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print(f"Model loaded! Trained for {checkpoint['epoch']+1} epochs, final loss: {checkpoint['loss']:.6f}")
    return model

In [None]:
def predict_keypoints(model, ply_file_path):
    """Predict keypoints for the mesh stored in a single .ply file.

    The mesh is sampled and normalized the same way as in training, the model
    is run on the resulting cloud, and the prediction is mapped back into the
    mesh's own coordinate frame.

    NOTE: this notebook defines ``predict_keypoints`` again in a later cell;
    after running all cells the later definition is the one in effect.

    Returns:
        ``(original_keypoints, normalized_keypoints)`` as numpy arrays, or
        ``(None, None)`` when anything fails (the error is printed).
    """
    target_device = next(model.parameters()).device

    try:
        # Load the mesh and reject anything without faces to sample from.
        mesh = trimesh.load(ply_file_path, force='mesh')
        if mesh.is_empty or len(mesh.faces) == 0:
            raise ValueError("Empty mesh")

        # Sample the surface; repeat the first point if too few came back.
        sampled, _ = trimesh.sample.sample_surface(mesh, NUM_POINTS)
        missing = NUM_POINTS - sampled.shape[0]
        if missing > 0:
            filler = np.repeat(sampled[0:1, :], missing, axis=0)
            sampled = np.vstack((sampled, filler))

        # Center the cloud and scale its farthest point to radius 1.
        center = np.mean(sampled, axis=0)
        radius = np.max(np.linalg.norm(sampled - center, axis=1))
        if radius > 0:
            unit_points = (sampled - center) / radius
        else:
            unit_points = sampled - center

        # Shape [1, 3, NUM_POINTS]: a batch of one, channels first.
        batch = torch.from_numpy(unit_points).float().unsqueeze(0).to(target_device)
        batch = batch.permute(0, 2, 1)

        with torch.no_grad():
            unit_keypoints = model(batch).cpu().numpy().squeeze(0)

        # Undo the normalization so keypoints live in mesh coordinates.
        if radius > 0:
            mesh_keypoints = unit_keypoints * radius + center
        else:
            mesh_keypoints = unit_keypoints + center

        return mesh_keypoints, unit_keypoints

    except Exception as e:
        print(f"Error processing {ply_file_path}: {e}")
        return None, None

In [38]:
def predict_keypoints(model, ply_file_path, num_points=None):
    """Predict keypoints for the mesh stored in a single .ply file.

    The mesh surface is sampled and normalized (centered, farthest point at
    radius 1) exactly as in training. The model is run in eval mode on the
    resulting cloud and the prediction is mapped back into the mesh's own
    coordinate frame.

    Args:
        model: trained network taking a [1, 3, num_points] tensor and
            returning a [1, K, 3] tensor.
        ply_file_path: path to the mesh file.
        num_points: number of surface points to sample. Defaults to the
            notebook-level ``NUM_POINTS``.

    Returns:
        ``(original_keypoints, normalized_keypoints)`` as numpy arrays, or
        ``(None, None)`` when anything fails (the error is printed).
    """
    device = next(model.parameters()).device
    was_training = model.training

    try:
        if num_points is None:
            num_points = NUM_POINTS

        mesh = trimesh.load(ply_file_path, force='mesh')
        if mesh.is_empty or len(mesh.faces) == 0:
            raise ValueError("Empty mesh")

        points, _ = trimesh.sample.sample_surface(mesh, num_points)
        if points.shape[0] < num_points:
            # Pad by repeating the first point so the input size is fixed.
            pad_size = num_points - points.shape[0]
            pad = np.repeat(points[0:1, :], pad_size, axis=0)
            points = np.vstack((points, pad))

        centroid = np.mean(points, axis=0)
        scale = np.max(np.linalg.norm(points - centroid, axis=1))
        normalized_points = (points - centroid) / scale if scale > 0 else points - centroid

        points_tensor = torch.from_numpy(normalized_points).float().unsqueeze(0).to(device)
        points_tensor = points_tensor.permute(0, 2, 1)  # [1, 3, num_points]

        # Inference must run in eval mode: with a batch of one, train-mode
        # BatchNorm raises (or uses meaningless batch statistics) and Dropout
        # would randomize the prediction.
        model.eval()
        with torch.no_grad():
            normalized_keypoints = model(points_tensor).cpu().numpy().squeeze(0)

        # Undo the normalization so keypoints live in mesh coordinates.
        original_keypoints = normalized_keypoints * scale + centroid if scale > 0 else normalized_keypoints + centroid

        return original_keypoints, normalized_keypoints

    except Exception as e:
        print(f"Error processing {ply_file_path}: {e}")
        return None, None

    finally:
        # Leave the caller's model in whatever mode it was handed to us.
        model.train(was_training)


In [32]:
def predict_multiple(model, ply_folder):
    """Predict keypoints for all .ply files in a folder.

    Args:
        model: trained network, passed through to ``predict_keypoints``.
        ply_folder: directory scanned (non-recursively) for ``.ply`` files.

    Returns:
        Dict mapping each successfully processed file name to a dict with
        ``original_keypoints`` and ``normalized_keypoints``. Files that fail
        are reported and left out.
    """
    results = {}

    # Sorted so the processing order and log output are the same on every run.
    ply_files = [os.path.join(ply_folder, f) for f in sorted(os.listdir(ply_folder)) if f.endswith('.ply')]
    print(f"Found {len(ply_files)} .ply files")

    for ply_file in ply_files:
        file_name = os.path.basename(ply_file)
        print(f"Processing {file_name}...")
        # Use the predictor defined in this notebook; the previously called
        # `predict_cap_keypoints` does not exist here.
        keypoints, normalized_kp = predict_keypoints(model, ply_file)

        if keypoints is not None:
            results[file_name] = {
                'original_keypoints': keypoints,
                'normalized_keypoints': normalized_kp
            }
            print(f"  Success! Predicted {len(keypoints)} keypoints")
        else:
            print(f"  Failed to process {ply_file}")

    return results

In [39]:
model = load_saved_model("saved_models/guitarnet++_final.pth")

Model loaded! Trained for 50 epochs, final loss: 0.004393


In [43]:
keypoints, norm_kp = predict_keypoints(model, "Guitars/9_points/ef1c22bd3b74953689f0379846507dd3.ply")
if keypoints is not None:
    #print("Predicted keypoints:")
    for i, kp in enumerate(keypoints):
        print(f"{kp[0]:.3f} {kp[1]:.3f} {kp[2]:.3f}")

0.010 0.354 0.010
0.012 -0.211 -0.001
0.107 -0.164 0.013
-0.086 -0.198 0.010
0.125 -0.490 0.009
-0.095 -0.488 0.019
0.010 -0.514 0.009
0.099 -0.321 0.012
-0.078 -0.331 0.007
