In [2]:
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import trimesh
from datetime import datetime 

from pointnet2_utils import PointNetSetAbstractionMsg, PointNetSetAbstraction
from pointnet2_keypoint_regressor import get_model#, get_model_msg

from torch.utils.data import random_split

In [3]:
# Select the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device: " + str(device))


Using device: cuda


In [7]:
# Points sampled per mesh. At least 512 are needed for the network input, and
# roughly 1000 keep local neighbourhoods dense (2**14 = 16384 was noted as a larger option).
NUM_POINTS = 1024

# Keypoints regressed per guitar.
# NOTE(review): two source meshes reportedly carry only 5 keypoints — TODO investigate.
NUM_KEYPOINTS = 9

# Samples per optimisation step.
BATCH_SIZE = 128

In [10]:
# Training Loop

dataset = ShapeNetKeypointDataset(mesh_dir="Guitars/9_Augmented", annotation_json="guitar_9_augmented.json")  

total_size = len(dataset)
val_size = int(0.2 * total_size)
train_size = total_size - val_size

train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

#dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(val_dataset)}")

train_losses = []
val_losses = []



#model = PointNetKeypointRegressor(NUM_KEYPOINTS).to(device)
model = get_model(num_keypoints=NUM_KEYPOINTS, normal_channel=False).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

start_epoch = 0

if False: # For resuming.
    os.makedirs("train_checkpoints", exist_ok=True)

    #resume_checkpoint_path = "saved_models/guitarnet_final.pth"
    resume_checkpoint_path = "train_checkpoints/guitarnet++_epoch_50.pth"
    if os.path.exists(resume_checkpoint_path):
        checkpoint = torch.load(resume_checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        train_losses = checkpoint.get('train_losses', [])
        val_losses = checkpoint.get('val_losses', [])

        print(f"Resuming training from epoch {start_epoch}")


total_epoch = 50 # Additional epochs to train 
for epoch in range(total_epoch):
    model.train()
    total_train_loss = 0
    #num_batches = 0
    
    for batch_idx, (pts, kps) in enumerate(train_loader):
        try:
            pts, kps = pts.to(device), kps.to(device)
            pts = pts.permute(0, 2, 1) #Permute to fit input size
            
            # Forward pass
            preds, _ = model(pts)
            loss = F.mse_loss(preds, kps)
            #loss = nn.MSELoss() # Not sure why this one is better tbh
            
            # Back pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # Accumulate loss
            total_train_loss += loss.item()
            #num_batches += 1
            
            #if epoch == 0 and batch_idx == 0:
            if batch_idx == 0:
                #print(f"First batch - Input shape: {pts.shape}, Target shape: {kps.shape}")
                #print(f"First batch loss: {loss.item():.6f}")
                print(f"Pred range: [{preds.min():.3f}, {preds.max():.3f}]")
                print(f"Target range: [{kps.min():.3f}, {kps.max():.3f}]")
            
        except Exception as e:
            print(f"Error in batch {batch_idx}: {e}")
            continue
    
    # Average train loss
    avg_train_loss = total_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)


    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for pts, kps in val_loader:
            pts, kps = pts.to(device), kps.to(device)
            pts = pts.permute(0,2,1)
            preds, _ = model(pts)
            val_loss = F.mse_loss(preds, kps)
            total_val_loss += val_loss.item()

    avg_val_loss = total_val_loss/len(val_loader)
    val_losses.append(avg_val_loss)

    if True:
        if epoch + start_epoch + 1 % 5 == 0:
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= 0.8
                print("Force-reduced LR to 80%")

    if False:
        if epoch + start_epoch + 1 == 20:
            for param_group in optimizer.param_groups:
                param_group['lr'] = 1e-4
            print("Reduced LR to 1e-4")
    if False: #For checkpointing
        if (epoch + 1) % 10 == 0:
            checkpoint_path = f"train_checkpoints/guitarnet++_lin_3_epoch_{start_epoch + epoch+1}.pth"
            torch.save({
                'epoch': start_epoch + epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_train_loss,
                'train_losses': train_losses,
                'val_losses': val_losses
            }, checkpoint_path)
            print(f"Saved checkpoint: {checkpoint_path}")
    
    print(f"Epoch {epoch+ start_epoch + 1}/{start_epoch + total_epoch}\nTrain Loss: {avg_train_loss:.6f}; Validation Loss: {avg_val_loss:.6f}")
if False: #For save when done
    torch.save({
        'model_state_dict': model.state_dict(),
        'epoch': epoch,
        'loss': avg_train_loss,
        'config': {'num_keypoints': NUM_KEYPOINTS, 'num_points': NUM_POINTS}
    }, "saved_models/guitarnet++_lin_3_final.pth")

print("Training completed!")

Train dataset size: 2348
Test dataset size: 587
Pred range: [-0.562, 0.643]
Target range: [-1.006, 1.010]
Epoch 1/50
Train Loss: 0.061864; Validation Loss: 0.078836
Pred range: [-1.589, 1.309]
Target range: [-1.019, 1.024]
Epoch 2/50
Train Loss: 0.009446; Validation Loss: 0.076760
Pred range: [-1.165, 1.176]
Target range: [-1.004, 1.019]
Epoch 3/50
Train Loss: 0.005098; Validation Loss: 0.078935
Pred range: [-1.122, 1.128]
Target range: [-1.008, 1.017]
Epoch 4/50
Train Loss: 0.004269; Validation Loss: 0.006506
Pred range: [-1.094, 1.233]
Target range: [-1.000, 1.026]
Epoch 5/50
Train Loss: 0.004929; Validation Loss: 0.003952
Pred range: [-1.178, 1.042]
Target range: [-1.002, 1.058]
Epoch 6/50
Train Loss: 0.003864; Validation Loss: 0.003656
Pred range: [-0.999, 0.992]
Target range: [-1.014, 1.022]
Epoch 7/50
Train Loss: 0.003378; Validation Loss: 0.003265
Pred range: [-1.072, 1.074]
Target range: [-1.008, 1.045]
Epoch 8/50
Train Loss: 0.003724; Validation Loss: 0.003114
Pred range: [-1.

In [11]:
# Persist the trained weights plus enough metadata to reload, or resume, this run.
# NOTE(review): if training was interrupted, `epoch` is the epoch the loop had entered,
# while avg_train_loss / avg_val_loss may still hold the previous epoch's values.
os.makedirs("saved_models", exist_ok=True)  # torch.save does not create parent directories
torch.save({
    'model_state_dict': model.state_dict(),
    # Same key the training cell's resume branch reads, so this file can also
    # be used as a resume checkpoint, not only for inference.
    'optimizer_state_dict': optimizer.state_dict(),
    # Global 0-based epoch index, matching the periodic training checkpoints.
    'epoch': start_epoch + epoch,
    'loss': [avg_train_loss, avg_val_loss],  # [last train MSE, last validation MSE]
    'train_losses': train_losses,
    'val_losses': val_losses,
    'config': {'num_keypoints': NUM_KEYPOINTS, 'num_points': NUM_POINTS}
}, "saved_models/guitarnet++_lin_4_final.pth")