In [1]:
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import trimesh
from datetime import datetime 

from pointnet2_utils import PointNetSetAbstractionMsg, PointNetSetAbstraction
from pointnet2_keypoint_regressor import get_model, get_model_msg

from torch.utils.data import random_split

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cuda


In [202]:
with open("guitar.json") as f:
    annotations = json.load(f)

model_id_to_keypoints = {}
for entry in annotations:
    model_id = entry['model_id']
    keypoints = [kp['xyz'] for kp in entry['keypoints']]
    keypoints = np.array(keypoints, dtype=np.float32)
    model_id_to_keypoints[model_id] = keypoints



In [3]:
# Using minimal PointNet (not full PointNet++)
class PointNetKeypointRegressor(nn.Module):
    def __init__(self, num_keypoints):
        super().__init__()
        self.sa1 = nn.Sequential(
            nn.Conv1d(3, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(128, 1024, 1),
            nn.BatchNorm1d(1024),
            nn.ReLU()
        )
        self.fc = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, num_keypoints * 3)
        )

    def forward(self, x):
        # x: [B, N, 3]
        x = x.permute(0, 2, 1)  # [B, 3, N]
        x = self.sa1(x)  # [B, 1024, N]
        x = torch.max(x, 2)[0]  # [B, 1024]
        x = self.fc(x)
        x = x.view(-1, NUM_KEYPOINTS, 3)
        return x




In [3]:
class ShapeNetKeypointDataset(Dataset):
    def __init__(self, mesh_dir, annotation_json):
        self.mesh_dir = mesh_dir
        self.samples = []

        #count_valid = 0
        #count_total = 0

        # Load annotations
        with open(annotation_json) as f:
            annotations = json.load(f)

        for entry in annotations:
            model_id = entry['model_id']
            #print(model_id)
            keypoints = np.array([kp['xyz'] for kp in entry['keypoints']], dtype=np.float32)
            #count_total += 1
            if keypoints.shape[0] != NUM_KEYPOINTS:
                continue  # Strictly filter only 6-keypoint meshes
            self.samples.append((model_id, keypoints))
            #count_valid += 1

        #print(f"Total meshes: {count_total}, Valid 6-keypoint meshes: {count_valid}")
        
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        model_id, keypoints = self.samples[idx]

        #print("Inside __getitem__: keypoints.shape =", keypoints.shape) 

        mesh_path = os.path.join(self.mesh_dir, model_id + ".ply")

        try:
            mesh = trimesh.load(mesh_path, force='mesh')
            if mesh.is_empty or len(mesh.faces) == 0:
                raise ValueError("Empty mesh")

            points, _ = trimesh.sample.sample_surface(mesh, NUM_POINTS)
            if points.shape[0] < NUM_POINTS:
                pad_size = NUM_POINTS - points.shape[0]
                pad = np.repeat(points[0:1, :], pad_size, axis=0)
                points = np.vstack((points, pad))
        except Exception as e:
            print(f"Error loading mesh {model_id}: {e}")
            points = np.zeros((NUM_POINTS, 3), dtype=np.float32)
            keypoints = np.zeros((NUM_KEYPOINTS, 3), dtype=np.float32)

        # Normalize
        #print("Points Min:", np.min(points), "Max:", np.max(points))

        centroid = np.mean(points, axis=0)
        scale = np.max(np.linalg.norm(points - centroid, axis=1))
        points = (points - centroid) / scale
        keypoints = (keypoints - centroid) / scale


        return torch.from_numpy(points).float(), torch.from_numpy(keypoints).float()


In [5]:
device

device(type='cuda')

In [4]:
NUM_POINTS = 1024 # Minimum 512 for input, ~1000 for neighborhood density; 2^14 = 16384
NUM_KEYPOINTS = 9 # 2 meshes have 5 keypoints idk why
BATCH_SIZE = 128

In [10]:
# Training Loop

dataset = ShapeNetKeypointDataset(mesh_dir="Guitars/9_Augmented", annotation_json="guitar_9_augmented.json")  

total_size = len(dataset)
val_size = int(0.2 * total_size)
train_size = total_size - val_size

train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

#dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(val_dataset)}")

train_losses = []
val_losses = []



#model = PointNetKeypointRegressor(NUM_KEYPOINTS).to(device)
model = get_model(num_keypoints=NUM_KEYPOINTS, normal_channel=False).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

start_epoch = 0

if False: # For resuming.
    os.makedirs("train_checkpoints", exist_ok=True)

    #resume_checkpoint_path = "saved_models/guitarnet_final.pth"
    resume_checkpoint_path = "train_checkpoints/guitarnet++_epoch_50.pth"
    if os.path.exists(resume_checkpoint_path):
        checkpoint = torch.load(resume_checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        train_losses = checkpoint.get('train_losses', [])
        val_losses = checkpoint.get('val_losses', [])

        print(f"Resuming training from epoch {start_epoch}")


total_epoch = 50 # Additional epochs to train 
for epoch in range(total_epoch):
    model.train()
    total_train_loss = 0
    #num_batches = 0
    
    for batch_idx, (pts, kps) in enumerate(train_loader):
        try:
            pts, kps = pts.to(device), kps.to(device)
            pts = pts.permute(0, 2, 1) #Permute to fit input size
            
            # Forward pass
            preds, _ = model(pts)
            loss = F.mse_loss(preds, kps)
            #loss = nn.MSELoss() # Not sure why this one is better tbh
            
            # Back pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # Accumulate loss
            total_train_loss += loss.item()
            #num_batches += 1
            
            #if epoch == 0 and batch_idx == 0:
            if batch_idx == 0:
                #print(f"First batch - Input shape: {pts.shape}, Target shape: {kps.shape}")
                #print(f"First batch loss: {loss.item():.6f}")
                print(f"Pred range: [{preds.min():.3f}, {preds.max():.3f}]")
                print(f"Target range: [{kps.min():.3f}, {kps.max():.3f}]")
            
        except Exception as e:
            print(f"Error in batch {batch_idx}: {e}")
            continue
    
    # Average train loss
    avg_train_loss = total_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)


    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for pts, kps in val_loader:
            pts, kps = pts.to(device), kps.to(device)
            pts = pts.permute(0,2,1)
            preds, _ = model(pts)
            val_loss = F.mse_loss(preds, kps)
            total_val_loss += val_loss.item()

    avg_val_loss = total_val_loss/len(val_loader)
    val_losses.append(avg_val_loss)

    if True:
        if epoch + start_epoch + 1 % 5 == 0:
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= 0.8
                print("Force-reduced LR to 80%")

    if False:
        if epoch + start_epoch + 1 == 20:
            for param_group in optimizer.param_groups:
                param_group['lr'] = 1e-4
            print("Reduced LR to 1e-4")
    if False: #For checkpointing
        if (epoch + 1) % 10 == 0:
            checkpoint_path = f"train_checkpoints/guitarnet++_lin_3_epoch_{start_epoch + epoch+1}.pth"
            torch.save({
                'epoch': start_epoch + epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_train_loss,
                'train_losses': train_losses,
                'val_losses': val_losses
            }, checkpoint_path)
            print(f"Saved checkpoint: {checkpoint_path}")
    
    print(f"Epoch {epoch+ start_epoch + 1}/{start_epoch + total_epoch}\nTrain Loss: {avg_train_loss:.6f}; Validation Loss: {avg_val_loss:.6f}")
if False: #For save when done
    torch.save({
        'model_state_dict': model.state_dict(),
        'epoch': epoch,
        'loss': avg_train_loss,
        'config': {'num_keypoints': NUM_KEYPOINTS, 'num_points': NUM_POINTS}
    }, "saved_models/guitarnet++_lin_3_final.pth")

print("Training completed!")

Train dataset size: 2348
Test dataset size: 587
Pred range: [-0.562, 0.643]
Target range: [-1.006, 1.010]
Epoch 1/50
Train Loss: 0.061864; Validation Loss: 0.078836
Pred range: [-1.589, 1.309]
Target range: [-1.019, 1.024]
Epoch 2/50
Train Loss: 0.009446; Validation Loss: 0.076760
Pred range: [-1.165, 1.176]
Target range: [-1.004, 1.019]
Epoch 3/50
Train Loss: 0.005098; Validation Loss: 0.078935
Pred range: [-1.122, 1.128]
Target range: [-1.008, 1.017]
Epoch 4/50
Train Loss: 0.004269; Validation Loss: 0.006506
Pred range: [-1.094, 1.233]
Target range: [-1.000, 1.026]
Epoch 5/50
Train Loss: 0.004929; Validation Loss: 0.003952
Pred range: [-1.178, 1.042]
Target range: [-1.002, 1.058]
Epoch 6/50
Train Loss: 0.003864; Validation Loss: 0.003656
Pred range: [-0.999, 0.992]
Target range: [-1.014, 1.022]
Epoch 7/50
Train Loss: 0.003378; Validation Loss: 0.003265
Pred range: [-1.072, 1.074]
Target range: [-1.008, 1.045]
Epoch 8/50
Train Loss: 0.003724; Validation Loss: 0.003114
Pred range: [-1.

In [None]:
"""Train dataset size: 2348
Test dataset size: 587
First batch - Input shape: torch.Size([256, 3, 1024]), Target shape: torch.Size([256, 9, 3])
First batch loss: 0.288146
Pred range: [-1.969, 1.959]
Target range: [-1.001, 1.021]
Epoch 1/50
Train Loss: 0.150820; Validation Loss: 0.103047
Epoch 2/50
Train Loss: 0.066115; Validation Loss: 0.124546
Epoch 3/50
Train Loss: 0.043645; Validation Loss: 0.078913
Epoch 4/50
Train Loss: 0.030031; Validation Loss: 0.087326
Epoch 5/50
Train Loss: 0.022437; Validation Loss: 0.085578
Epoch 6/50
Train Loss: 0.019087; Validation Loss: 0.046889
Epoch 7/50
Train Loss: 0.015242; Validation Loss: 0.013829
Epoch 8/50
Train Loss: 0.012592; Validation Loss: 0.006803
Epoch 9/50
Train Loss: 0.011893; Validation Loss: 0.005015
Epoch 10/50
Train Loss: 0.011156; Validation Loss: 0.009370"""

NameError: name 'num_keypoints' is not defined

In [None]:
torch.cuda.empty_cache()

In [11]:
torch.save({
    'model_state_dict': model.state_dict(),
    'epoch': epoch,
    'loss': [avg_train_loss, avg_val_loss],
    'config': {'num_keypoints': NUM_KEYPOINTS, 'num_points': NUM_POINTS}
}, "saved_models/guitarnet++_lin_4_final.pth")

In [8]:
model = get_model_msg(num_keypoints=NUM_KEYPOINTS, normal_channel=False)

TypeError: super(type, obj): obj must be an instance or subtype of type

In [6]:
def load_saved_model(model_path):
    """Load your saved model"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Create model with same architecture
    model = get_model_msg(num_keypoints=NUM_KEYPOINTS, normal_channel=False).to(device)
    
    # Load checkpoint
    checkpoint = torch.load(model_path, map_location=device)
    #model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    
    #print(f"Model loaded! Trained for {checkpoint['epoch']+1} epochs, final loss: {checkpoint['loss'][0]:.6f}, {checkpoint['loss'][1]:.6f}")
    return model

In [7]:
model = load_saved_model("saved_models/guitarnet++_final.pth")

TypeError: super(type, obj): obj must be an instance or subtype of type

In [41]:
def predict_keypoints(model, ply_file_path):
    """Predict keypoints for a single cap .ply file"""
    device = next(model.parameters()).device
    
    try:
        # Load and process mesh (same as your training preprocessing)
        mesh = trimesh.load(ply_file_path, force='mesh')
        if mesh.is_empty or len(mesh.faces) == 0:
            raise ValueError("Empty mesh")
        
        # Sample points
        points, _ = trimesh.sample.sample_surface(mesh, NUM_POINTS)
        if points.shape[0] < NUM_POINTS:
            pad_size = NUM_POINTS - points.shape[0]
            pad = np.repeat(points[0:1, :], pad_size, axis=0)
            points = np.vstack((points, pad))
        
        # Normalize (exactly like training)
        centroid = np.mean(points, axis=0)
        scale = np.max(np.linalg.norm(points - centroid, axis=1))
        normalized_points = (points - centroid) / scale if scale > 0 else points - centroid
        
        # Convert to tensor and add batch dimension
        points_tensor = torch.from_numpy(normalized_points).float().unsqueeze(0).to(device)
        points_tensor = points_tensor.permute(0, 2, 1)
        
        # Predict
        with torch.no_grad():
            normalized_keypoints = model(points_tensor).cpu().numpy().squeeze(0)
        
        # Denormalize keypoints back to original scale
        original_keypoints = normalized_keypoints * scale + centroid if scale > 0 else normalized_keypoints + centroid
        
        return original_keypoints, normalized_keypoints
        
    except Exception as e:
        print(f"Error processing {ply_file_path}: {e}")
        return None, None

In [66]:
def predict_keypoints(model, ply_file_path):
    device = next(model.parameters()).device

   
    mesh = trimesh.load(ply_file_path, force='mesh')
    if mesh.is_empty or len(mesh.faces) == 0:
        raise ValueError("Empty mesh")

    points, _ = trimesh.sample.sample_surface(mesh, NUM_POINTS)
    if points.shape[0] < NUM_POINTS:
        pad_size = NUM_POINTS - points.shape[0]
        pad = np.repeat(points[0:1, :], pad_size, axis=0)
        points = np.vstack((points, pad))

    centroid = np.mean(points, axis=0)
    scale = np.max(np.linalg.norm(points - centroid, axis=1))
    normalized_points = (points - centroid) / scale if scale > 0 else points - centroid

    points_tensor = torch.from_numpy(normalized_points).float().unsqueeze(0).to(device)
    points_tensor = points_tensor.permute(0, 2, 1)

    with torch.no_grad():
        #normalized_keypoints = model(points_tensor).cpu().numpy().squeeze(0)
        output = model(points_tensor)
        normalized_keypoints = output[0].cpu().numpy().squeeze(0)
        

    original_keypoints = normalized_keypoints * scale + centroid if scale > 0 else normalized_keypoints + centroid

    return original_keypoints, normalized_keypoints

In [68]:
def predict_keypoints_no_normalization(model, ply_file_path):
    device = next(model.parameters()).device

    try:
        mesh = trimesh.load(ply_file_path, force='mesh')
        if mesh.is_empty or len(mesh.faces) == 0:
            raise ValueError("Empty mesh")

        points, _ = trimesh.sample.sample_surface(mesh, NUM_POINTS)
        if points.shape[0] < NUM_POINTS:
            pad_size = NUM_POINTS - points.shape[0]
            pad = np.repeat(points[0:1, :], pad_size, axis=0)
            points = np.vstack((points, pad))

        points_tensor = torch.from_numpy(points).float().unsqueeze(0).to(device)
        points_tensor = points_tensor.permute(0, 2, 1)

        with torch.no_grad():
            predicted_keypoints = model(points_tensor)[0].cpu().numpy().squeeze(0)

        return predicted_keypoints, predicted_keypoints  # same, no normalization

    except Exception as e:
        print(f"Error: {e}")
        return None, None


In [40]:
def predict_multiple(model, ply_folder):
    """Predict keypoints for all .ply files in a folder"""
    results = {}
    
    ply_files = [os.path.join(ply_folder, f) for f in os.listdir(ply_folder) if f.endswith('.ply')]
    print(f"Found {len(ply_files)} .ply files")
    
    for ply_file in ply_files:
        print(f"Processing {os.path.basename(ply_file)}...")
        keypoints, normalized_kp = predict_keypoints(model, ply_file)
        
        if keypoints is not None:
            results[os.path.basename(ply_file)] = {
                'original_keypoints': keypoints,
                'normalized_keypoints': normalized_kp
            }
            print(f"  Success! Predicted {len(keypoints)} keypoints")
        else:
            print(f"  Failed to process {ply_file}")
    
    return results

In [65]:
predict_keypoints_no_normalization(model, "Guitars/9_points/ef1c22bd3b74953689f0379846507dd3.ply")

Error: 'tuple' object has no attribute 'cpu'


(None, None)

In [71]:
model = load_saved_model("saved_models/guitarnet++_2_final.pth")

In [72]:
keypoints, norm_kp = predict_keypoints(model, "Guitars/9_points/ef1c22bd3b74953689f0379846507dd3.ply")
if keypoints is not None:
    #print("Predicted keypoints:")
    for i, kp in enumerate(keypoints):
        print(f"{kp[0]:.3f} {kp[1]:.3f} {kp[2]:.3f}")

-0.024 -0.234 0.012
0.020 -0.227 0.003
0.037 -0.202 0.002
0.035 -0.208 0.046
-0.018 -0.232 0.033
0.015 -0.222 0.023
-0.017 -0.235 -0.006
-0.015 -0.241 0.039
0.041 -0.233 -0.015


In [9]:
NUM_POINTS = 2048  # or match what you used in training

In [10]:
def predict_keypoints(model, ply_file_path, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """
    Load a .ply mesh, sample and normalize points, and predict keypoints using the trained model.
    Returns: (N, 3) keypoint array in mesh coordinates
    """
    # Load and check mesh
    mesh = trimesh.load(ply_file_path, force='mesh')
    if mesh.is_empty or len(mesh.faces) == 0:
        raise ValueError("Empty or invalid mesh.")

    # Sample surface points
    points, _ = trimesh.sample.sample_surface(mesh, NUM_POINTS)
    if points.shape[0] < NUM_POINTS:
        pad_size = NUM_POINTS - points.shape[0]
        pad = np.repeat(points[0:1, :], pad_size, axis=0)
        points = np.vstack((points, pad))

    # Normalize (same as training)
    centroid = np.mean(points, axis=0)
    scale = np.max(np.linalg.norm(points - centroid, axis=1))
    normalized_points = (points - centroid) / scale

    # Convert to model input
    points_tensor = torch.from_numpy(normalized_points).float().unsqueeze(0)  # shape (1, N, 3)
    points_tensor = points_tensor.permute(0, 2, 1).to(device)  # (B, 3, N)

    # Predict
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        keypoints_normalized, _ = model(points_tensor)  # (1, K, 3)

    keypoints_normalized = keypoints_normalized.squeeze(0).cpu().numpy()
    keypoints = keypoints_normalized * scale + centroid  # restore original coordinates

    return keypoints


In [13]:
model = load_saved_model("saved_models/guitarnet++_lin_4_final.pth")

TypeError: super(type, obj): obj must be an instance or subtype of type

In [None]:
predict_keypoints()