In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json
from torch.utils.data import Dataset, DataLoader
from sklearn.neighbors import NearestNeighbors
import os
from typing import List, Tuple, Dict
import open3d as o3d

In [26]:
class KneeKeypointDataset(Dataset):
    """Dataset that turns annotated knee meshes into point clouds with per-point keypoint labels.

    The annotation file is a JSON list; each entry is read as a dict with a
    ``'model_id'`` (the STL file stem) and a ``'keypoints'`` list whose items carry
    an ``'xyz'`` coordinate. Each item yields a sampled, normalised point cloud and
    a one-hot label row per keypoint marking the sampled point nearest to it.
    """

    # Sampling strategies understood by ``load_point_cloud``.
    _SAMPLING_METHODS = ('uniform', 'poisson')

    def __init__(self, json_file: str, pointcloud_dir: str, max_points: int = 8192,
                 surface_sampling_method: str = 'uniform', num_keypoints: int = 5):
        """
        Args:
            json_file: Path to JSON file with keypoint annotations
            pointcloud_dir: Directory containing STL mesh files
            max_points: Maximum number of points to sample from each mesh
            surface_sampling_method: 'uniform' or 'poisson' for surface sampling
            num_keypoints: Number of keypoints to use (useful for data cleaning with fewer keypoints)

        Raises:
            ValueError: If ``surface_sampling_method`` is not a supported strategy.
        """
        if surface_sampling_method not in self._SAMPLING_METHODS:
            raise ValueError(
                f"Unknown surface_sampling_method {surface_sampling_method!r}; "
                f"expected one of {self._SAMPLING_METHODS}"
            )

        with open(json_file, 'r') as f:
            self.annotations = json.load(f)

        self.pointcloud_dir = pointcloud_dir
        self.max_points = max_points
        self.sampling_method = surface_sampling_method
        self.num_keypoints = num_keypoints

        # Validate that annotations have the expected number of keypoints
        # (only the first entry is inspected; __getitem__ checks every sample).
        if len(self.annotations) > 0:
            expected_keypoints = len(self.annotations[0]['keypoints'])
            if expected_keypoints != num_keypoints:
                print(f"Warning: Expected {num_keypoints} keypoints but found {expected_keypoints} in annotation file")

        # Default keypoint names (you can modify these based on your specific needs)
        self.keypoint_names = ['front', 'left', 'right', 'thigh_center', 'shin_center'][:num_keypoints]

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self.annotations)

    def load_point_cloud(self, model_id: str) -> np.ndarray:
        """Load ``<pointcloud_dir>/<model_id>.stl`` and sample points from its surface.

        Twice ``max_points`` points are drawn so that ``sample_points`` can later
        subsample while keeping the keypoint-nearest points.

        Raises:
            ValueError: If the mesh has no vertices (missing or unreadable file).
        """
        mesh_path = os.path.join(self.pointcloud_dir, f"{model_id}.stl")

        mesh = o3d.io.read_triangle_mesh(mesh_path)

        # Open3D returns an empty mesh instead of raising when loading fails.
        if len(mesh.vertices) == 0:
            raise ValueError(f"Failed to load mesh from {mesh_path}")

        # Use more points than needed for better coverage
        num_sample_points = self.max_points * 2

        # Honour the strategy chosen at construction time.
        if self.sampling_method == 'poisson':
            pcd = mesh.sample_points_poisson_disk(number_of_points=num_sample_points)
        else:
            pcd = mesh.sample_points_uniformly(number_of_points=num_sample_points)

        points = np.asarray(pcd.points)

        if len(points) < self.max_points:
            print(f"Warning: Only sampled {len(points)} points from mesh {model_id}")

        return points

    def find_nearest_point_indices(self, points: np.ndarray, keypoint_coords: List[List[float]]) -> List[int]:
        """Return, for each keypoint coordinate, the index of the closest row in ``points``."""
        if len(keypoint_coords) == 0:
            return []

        nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(points)

        # One batched query instead of one query per keypoint.
        _, indices = nbrs.kneighbors(np.asarray(keypoint_coords, dtype=float))
        return [int(i) for i in indices[:, 0]]

    def normalize_point_cloud(self, points: np.ndarray) -> np.ndarray:
        """Centre the cloud on its centroid and scale it to fit the unit sphere.

        A degenerate cloud whose points all coincide is only centred, so the
        result never contains NaN/inf from a division by zero.
        """
        centred = points - np.mean(points, axis=0)

        max_distance = np.max(np.linalg.norm(centred, axis=1))
        if max_distance > 0:
            centred = centred / max_distance

        return centred

    def sample_points(self, points: np.ndarray, keypoint_indices: List[int]) -> Tuple[np.ndarray, List[int]]:
        """Subsample to ``max_points`` rows while keeping every keypoint row.

        Returns:
            ``(sampled_points, new_keypoint_indices)``. When the cloud has at most
            ``max_points`` rows it is zero-padded to ``(max_points, 3)`` and the
            indices are returned unchanged. Otherwise the keypoint rows come first
            (in the order given) followed by a random draw of the remaining rows,
            and the indices are remapped to the new positions.
        """
        n_points = len(points)

        if n_points <= self.max_points:
            # Pad with zeros if needed
            padded_points = np.zeros((self.max_points, 3))
            padded_points[:n_points] = points
            return padded_points, keypoint_indices

        # Rows that are not keypoints form the pool for the random draw.
        is_keypoint = np.zeros(n_points, dtype=bool)
        is_keypoint[np.asarray(keypoint_indices, dtype=np.intp)] = True
        candidate_indices = np.flatnonzero(~is_keypoint)

        n_additional = self.max_points - len(keypoint_indices)
        all_indices = list(keypoint_indices)
        if n_additional > 0:
            sampled_indices = np.random.choice(
                candidate_indices,
                size=min(n_additional, len(candidate_indices)),
                replace=False
            )
            all_indices += sampled_indices.tolist()

        # Map original row index -> position in the subsampled cloud.
        old_to_new = {old_idx: new_idx for new_idx, old_idx in enumerate(all_indices)}
        new_keypoint_indices = [old_to_new[kp_idx] for kp_idx in keypoint_indices]

        return points[np.asarray(all_indices, dtype=np.intp)], new_keypoint_indices

    def __getitem__(self, idx: int) -> Dict:
        """Build one training sample.

        Returns:
            Dict with ``'points'`` (float tensor of sampled, normalised points),
            ``'keypoint_labels'`` (float tensor, one one-hot row per keypoint over
            the sampled points) and ``'model_id'``.

        Raises:
            ValueError: If the annotation has fewer than ``num_keypoints`` keypoints;
                the missing rows would otherwise stay all-zero and be read as
                "point 0" by an argmax over the labels.
        """
        annotation = self.annotations[idx]
        model_id = annotation['model_id']

        points = self.load_point_cloud(model_id)

        # Only the first num_keypoints annotated keypoints are used.
        keypoint_coords = [kp['xyz'] for kp in annotation['keypoints'][:self.num_keypoints]]
        if len(keypoint_coords) < self.num_keypoints:
            raise ValueError(
                f"Annotation for model {model_id!r} has {len(keypoint_coords)} keypoints "
                f"but num_keypoints={self.num_keypoints}"
            )

        # Match keypoints against the raw cloud, i.e. in the annotation's coordinate frame,
        # before normalisation changes the coordinates.
        keypoint_indices = self.find_nearest_point_indices(points, keypoint_coords)

        points = self.normalize_point_cloud(points)
        points, keypoint_indices = self.sample_points(points, keypoint_indices)

        # One-hot label row per keypoint.
        keypoint_labels = np.zeros((self.num_keypoints, len(points)))
        for row, point_idx in enumerate(keypoint_indices):
            if point_idx < len(points):  # Safety check
                keypoint_labels[row, point_idx] = 1

        return {
            'points': torch.FloatTensor(points),
            'keypoint_labels': torch.FloatTensor(keypoint_labels),
            'model_id': model_id
        }


class PointNetFeatureExtractor(nn.Module):
    """Per-point PointNet-style encoder that appends a max-pooled global descriptor.

    Three 1x1 convolutions (each followed by batch norm and ReLU) lift every point
    to ``feature_dim`` channels. The channel-wise maximum over all points is then
    tiled back onto every point, so the output carries ``2 * feature_dim`` channels.
    """

    def __init__(self, input_dim: int = 3, feature_dim: int = 1024):
        """
        Args:
            input_dim: Number of input channels per point.
            feature_dim: Number of channels produced by the last point-wise layer.
        """
        super().__init__()

        # Shared (point-wise) MLP implemented as kernel-size-1 convolutions.
        self.conv1 = nn.Conv1d(input_dim, 64, kernel_size=1)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=1)
        self.conv3 = nn.Conv1d(128, feature_dim, kernel_size=1)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(feature_dim)

        # Registered on the module but not applied anywhere in forward().
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, x):
        """Encode ``x`` of shape (batch, input_dim, num_points) into (batch, 2 * feature_dim, num_points)."""
        _, _, n_points = x.size()

        # Point-wise feature extraction: conv -> batch norm -> ReLU, three times.
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        per_point = x
        for conv, norm in stages:
            per_point = F.relu(norm(conv(per_point)))

        # Global descriptor: channel-wise max over the point axis, tiled to every point.
        pooled = per_point.max(dim=2, keepdim=True).values
        tiled = pooled.repeat(1, 1, n_points)

        # Local features first, global features second.
        return torch.cat((per_point, tiled), dim=1)


class KneeKeypointModel(nn.Module):
    """Per-point scoring network with one output head per keypoint.

    A ``PointNetFeatureExtractor`` produces local+global features for every point,
    two shared 1x1 convolution layers compress them, and each keypoint head emits
    one logit per point.
    """

    def __init__(self, num_keypoints: int = 5, feature_dim: int = 1024):
        """
        Args:
            num_keypoints: Number of independent keypoint heads.
            feature_dim: Channel width passed to the feature extractor.
        """
        super().__init__()

        self.num_keypoints = num_keypoints
        self.feature_extractor = PointNetFeatureExtractor(feature_dim=feature_dim)

        # The extractor concatenates local and global features, doubling the channels.
        fused_channels = 2 * feature_dim

        # Trunk shared by all heads.
        self.shared_conv1 = nn.Conv1d(fused_channels, 512, kernel_size=1)
        self.shared_conv2 = nn.Conv1d(512, 256, kernel_size=1)
        self.shared_bn1 = nn.BatchNorm1d(512)
        self.shared_bn2 = nn.BatchNorm1d(256)

        # One small per-point classifier per keypoint.
        heads = []
        for _ in range(num_keypoints):
            heads.append(
                nn.Sequential(
                    nn.Conv1d(256, 128, kernel_size=1),
                    nn.BatchNorm1d(128),
                    nn.ReLU(),
                    nn.Dropout(p=0.5),
                    nn.Conv1d(128, 1, kernel_size=1),
                )
            )
        self.keypoint_heads = nn.ModuleList(heads)

        self.dropout = nn.Dropout(p=0.3)

    def forward(self, x):
        """Map points of shape (batch, num_points, 3) to logits of shape (batch, num_keypoints, num_points)."""
        # Unpacking also rejects inputs that are not 3-dimensional.
        _batch, _n_points, _channels = x.size()

        # Conv1d wants channels before the point axis: (batch, 3, num_points).
        channels_first = x.transpose(1, 2)

        features = self.feature_extractor(channels_first)

        # Shared trunk followed by dropout.
        shared = F.relu(self.shared_bn1(self.shared_conv1(features)))
        shared = F.relu(self.shared_bn2(self.shared_conv2(shared)))
        shared = self.dropout(shared)

        # Each head yields (batch, 1, num_points); drop the singleton channel axis.
        per_keypoint = [head(shared).squeeze(1) for head in self.keypoint_heads]

        return torch.stack(per_keypoint, dim=1)  # (batch, num_keypoints, num_points)


class FocalLoss(nn.Module):
    """Focal loss over the point axis, for handling class imbalance.

    Each (sample, keypoint) row is treated as one classification problem whose
    classes are the points: the target class is the argmax of the one-hot label
    row, and the cross-entropy of that row is down-weighted by ``(1 - pt) ** gamma``
    where ``pt`` is the predicted probability of the target point.
    """

    def __init__(self, alpha: float = 1.0, gamma: float = 2.0):
        """
        Args:
            alpha: Constant multiplier applied to every loss term.
            gamma: Focusing exponent; 0 reduces the loss to ``alpha`` times cross-entropy.
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the mean focal loss.

        Args:
            inputs: Logits of shape (batch, num_keypoints, num_points).
            targets: One-hot labels of shape (batch, num_keypoints, num_points).
        """
        num_points = inputs.size(-1)

        # reshape (not view) so non-contiguous logits, e.g. after a transpose, are accepted.
        flat_logits = inputs.reshape(-1, num_points)
        flat_targets = targets.argmax(dim=-1).reshape(-1)

        ce_loss = F.cross_entropy(flat_logits, flat_targets, reduction='none')

        # Cross-entropy is -log(pt), so exp(-ce) recovers the target-class probability.
        pt = torch.exp(-ce_loss)

        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        return focal_loss.mean()




In [72]:
def train_model(model, 
                train_loader, 
                val_loader, 
                num_epochs: int = 100,
                learning_rate: float = 0.001, 
                device: str = 'cuda',
                weight_decay: float = 1e-4,
                focal_alpha: float = 1.0,
                focal_gamma: float = 2.0,
                save_path: str = 'best_knee_keypoint_model.pth',
                checkpoint_path: str = 'checkpoint.pth',
                load_path: str = 'best_knee_keypoint_model.pth',
                resume: bool = False,
                checkpoint_every: int = 20):
    """Train ``model`` with focal loss and return it with the best weights loaded.

    Uses Adam with a StepLR schedule (halving the learning rate every 30 epochs).
    Whenever the validation loss improves, a full checkpoint is written to
    ``checkpoint_path`` and the bare model weights to ``save_path``. In addition,
    a numbered checkpoint is written every ``checkpoint_every`` epochs.

    Args:
        model: Network to train; moved to ``device``.
        train_loader: Iterable of batches with ``'points'`` and ``'keypoint_labels'`` tensors.
        val_loader: Same batch format, used for model selection.
        num_epochs: Epoch index at which training stops (exclusive).
        learning_rate: Initial Adam learning rate.
        device: Device string passed to ``.to()`` and ``torch.load``.
        weight_decay: Adam weight decay.
        focal_alpha: ``alpha`` of the focal loss.
        focal_gamma: ``gamma`` of the focal loss.
        save_path: Where the best model ``state_dict`` is written.
        checkpoint_path: Where the best full checkpoint is written; numbered
            periodic checkpoints are derived from it.
        load_path: Checkpoint to resume from when ``resume`` is true.
        resume: Restore model/optimizer/scheduler/epoch from ``load_path`` if it exists.
        checkpoint_every: Period, in epochs, of the numbered checkpoints; 0 disables them.

    Returns:
        The model, carrying the weights stored at ``save_path`` when that file
        exists, otherwise the weights from the last epoch.
    """

    model = model.to(device)
    criterion = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

    start_epoch = 0
    best_val_loss = float('inf')

    def _ensure_parent_dir(path: str) -> None:
        # torch.save does not create missing directories.
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    def _snapshot(epoch: int) -> dict:
        # Everything needed to resume training after ``epoch``.
        return {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_loss': best_val_loss
        }

    # Resume logic
    if resume and os.path.exists(load_path):
        print("Resuming Training From Checkpoint Saved at: ", load_path)
        checkpoint = torch.load(load_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_val_loss = checkpoint['best_val_loss']
        print("Loading Checkpoint with best val loss: ", best_val_loss)
        print(f"[INFO] Resumed training from epoch {start_epoch}")

    for epoch in range(start_epoch, num_epochs):
        model.train()
        train_loss = 0.0

        for batch_idx, batch in enumerate(train_loader):
            points = batch['points'].to(device)
            labels = batch['keypoint_labels'].to(device)

            optimizer.zero_grad()
            outputs = model(points)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            if batch_idx % 10 == 0:
                print(f"Epoch [{epoch}/{num_epochs}] Batch [{batch_idx}] Loss: {loss.item():.4f}")

        avg_train_loss = train_loss / len(train_loader)

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                points = batch['points'].to(device)
                labels = batch['keypoint_labels'].to(device)
                outputs = model(points)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)

        print(f"Epoch [{epoch}/{num_epochs}] Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

        # Save best model separately
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss

            _ensure_parent_dir(checkpoint_path)
            torch.save(_snapshot(epoch), checkpoint_path) # save checkpoint
            print(f"[INFO] Checkpoint saved at epoch {epoch} with val loss {best_val_loss:.4f} at location: {checkpoint_path}")

            _ensure_parent_dir(save_path)
            torch.save(model.state_dict(), save_path) # save model
            print(f"[INFO] Best model saved at epoch {epoch} with val loss {best_val_loss:.4f} at location: {save_path}")

        # Periodic numbered checkpoint. The parentheses matter: `epoch + 1 % 20`
        # parses as `epoch + (1 % 20)` and is never 0 for epoch >= 0.
        if checkpoint_every and (epoch + 1) % checkpoint_every == 0:
            print("Saving period checkpoint at epoch: ", epoch)
            root, ext = os.path.splitext(checkpoint_path)
            periodic_path = f"{root}-{epoch}{ext or '.pth'}"
            _ensure_parent_dir(periodic_path)
            torch.save(_snapshot(epoch), periodic_path)

        scheduler.step()

    # Load best weights before returning. The file may be missing when no epoch
    # ran or none improved on a resumed best_val_loss.
    if os.path.exists(save_path):
        model.load_state_dict(torch.load(save_path, map_location=device))
        print(f"[INFO] Training completed. Best model loaded from {save_path}")
    else:
        print(f"[WARN] Training completed but no best model found at {save_path}; returning last-epoch weights")
    return model



def evaluate_model(model, test_loader, device: str = 'cuda'):
    """Report the mean distance between predicted and annotated keypoints.

    For every keypoint the predicted point is the argmax of the model's logits
    over the points and the true point is the argmax of the one-hot label row;
    the error is the Euclidean distance between the two points' coordinates as
    found in ``batch['points']``.

    Args:
        model: Network mapping (batch, num_points, 3) to (batch, num_keypoints, num_points).
            If it has a ``num_keypoints`` attribute, only that many keypoints are scored.
        test_loader: Iterable of batches with ``'points'`` and ``'keypoint_labels'`` tensors.
        device: Device string the model and batches are moved to.

    Returns:
        The average distance over all (sample, keypoint) pairs, or ``nan`` when
        the loader yields no samples.
    """
    model.eval()
    model = model.to(device)

    total_distance_error = 0.0
    total_samples = 0

    with torch.no_grad():
        for batch in test_loader:
            points = batch['points'].to(device)
            labels = batch['keypoint_labels'].to(device)

            outputs = model(points)

            num_keypoints = getattr(model, 'num_keypoints', outputs.size(1))

            # Get predicted keypoint indices
            pred_indices = torch.argmax(outputs, dim=-1)[:, :num_keypoints]  # (batch, num_keypoints)
            true_indices = torch.argmax(labels, dim=-1)[:, :num_keypoints]   # (batch, num_keypoints)

            # Look up all coordinates at once; the column of batch indices
            # broadcasts against the (batch, num_keypoints) point indices.
            batch_index = torch.arange(points.size(0), device=points.device).unsqueeze(1)
            pred_points = points[batch_index, pred_indices]  # (batch, num_keypoints, 3)
            true_points = points[batch_index, true_indices]

            distances = (pred_points - true_points).norm(dim=-1)

            # One host transfer per batch instead of one per keypoint.
            total_distance_error += distances.sum().item()
            total_samples += distances.numel()

    if total_samples == 0:
        print('Warning: no samples to evaluate')
        return float('nan')

    avg_distance_error = total_distance_error / total_samples
    print(f'Average keypoint distance error: {avg_distance_error:.4f}')

    return avg_distance_error


In [73]:
# Input/output locations for this run.
JSON_FILE = 'knee_annotations/7-2-25/knee_points_4_5_flipped.json'
STL_DIR = 'scans_3/'
MODEL_SAVE_PATH = 'kp-selector-1.pth'

# Build the dataset. MAX_POINTS, SURFACE_SAMPLING_METHOD and NUM_KEYPOINTS are
# assigned in the hyperparameter cell further down the notebook, so that cell
# has to be executed before this one.
dataset = KneeKeypointDataset(
    JSON_FILE,
    STL_DIR,
    max_points=MAX_POINTS,
    surface_sampling_method=SURFACE_SAMPLING_METHOD,
    num_keypoints=NUM_KEYPOINTS,
)

print(f"Dataset loaded with {len(dataset)} samples")

Dataset loaded with 92 samples


In [74]:
# Split dataset.
# The split is seeded so that every run, and in particular a run resumed from a
# checkpoint, sees the same train/validation partition. With an unseeded split a
# resumed model could be validated on samples it was trained on before.
SPLIT_SEED = 42
train_size = int(TRAIN_SPLIT * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Create data loaders (only the training loader is shuffled)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Training samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")

Training samples: 73, Validation samples: 19


In [75]:
# Build the network with one head per keypoint
model = KneeKeypointModel(num_keypoints=NUM_KEYPOINTS)

# Report the total number of parameter elements
total_params = sum(map(torch.Tensor.numel, model.parameters()))
print(f"Model has {total_params:,} parameters")

# Pick the device used for training: GPU when available, CPU otherwise
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Training on device: {device}")

Model has 1,391,618 parameters
Training on device: cuda


In [78]:
# Hyperparameters

# -- Data ----------------------------------------------------------------
MAX_POINTS = 8_192                   # points kept per cloud
NUM_KEYPOINTS = 2                    # keypoints (and model heads) in use
TRAIN_SPLIT = 0.8                    # fraction of samples used for training
SURFACE_SAMPLING_METHOD = 'uniform'  # uniform/poisson

# -- Optimisation --------------------------------------------------------
BATCH_SIZE = 16
NUM_EPOCHS = 200
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.0001

# -- Focal loss ----------------------------------------------------------
FOCAL_ALPHA = 1.0
FOCAL_GAMMA = 2.0

In [79]:
# The same file is written as the best checkpoint and read back when resuming,
# so name it once instead of repeating the literal.
CHECKPOINT_PATH = 'Checkpoints/checkpoint-1.pth'

# torch.save does not create missing directories.
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)

trained_model = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    device=device,
    weight_decay=WEIGHT_DECAY,
    focal_alpha=FOCAL_ALPHA,
    focal_gamma=FOCAL_GAMMA,
    save_path=MODEL_SAVE_PATH,  # defined next to the dataset paths
    checkpoint_path=CHECKPOINT_PATH,
    load_path=CHECKPOINT_PATH,
    resume=True 
)

# Evaluate model on the validation loader
avg_error = evaluate_model(trained_model, val_loader, device=device)
print(f"Final evaluation - Average keypoint distance error: {avg_error:.4f}")

Resuming Training From Checkpoint Saved at:  Checkpoints/checkpoint-1.pth
Loading Checkpoint with best val loss:  5.858627796173096
[INFO] Resumed training from epoch 34
Epoch [34/200] Batch [0] Loss: 5.0415
Epoch [34/200] Train Loss: 5.1113 | Val Loss: 6.0056
Epoch [35/200] Batch [0] Loss: 4.7832
Epoch [35/200] Train Loss: 4.7589 | Val Loss: 6.1259
Epoch [36/200] Batch [0] Loss: 4.5543
Epoch [36/200] Train Loss: 4.9222 | Val Loss: 6.4147
Epoch [37/200] Batch [0] Loss: 4.7836
Epoch [37/200] Train Loss: 4.5800 | Val Loss: 6.2031
Epoch [38/200] Batch [0] Loss: 4.5194
Epoch [38/200] Train Loss: 4.7765 | Val Loss: 6.2109
Epoch [39/200] Batch [0] Loss: 4.5774
Epoch [39/200] Train Loss: 4.5285 | Val Loss: 6.4493
Epoch [40/200] Batch [0] Loss: 4.9143
Epoch [40/200] Train Loss: 4.7642 | Val Loss: 5.9857
Epoch [41/200] Batch [0] Loss: 4.2307
Epoch [41/200] Train Loss: 4.4736 | Val Loss: 5.9712
Epoch [42/200] Batch [0] Loss: 4.6682
Epoch [42/200] Train Loss: 4.4881 | Val Loss: 6.2509
Epoch [43/2

In [23]:
# Save hyperparameters for reference
import json

# Collect the run's settings under stable, lower-case keys.
hyperparams = dict(
    batch_size=BATCH_SIZE,
    max_points=MAX_POINTS,
    num_keypoints=NUM_KEYPOINTS,
    learning_rate=LEARNING_RATE,
    num_epochs=NUM_EPOCHS,
    focal_alpha=FOCAL_ALPHA,
    focal_gamma=FOCAL_GAMMA,
    weight_decay=WEIGHT_DECAY,
    surface_sampling_method=SURFACE_SAMPLING_METHOD,
)

# Persist them next to the notebook for later reference.
with open('hyperparameters.json', 'w') as f:
    f.write(json.dumps(hyperparams, indent=2))

print("Training completed! Hyperparameters saved to hyperparameters.json")

Training completed! Hyperparameters saved to hyperparameters.json


In [None]:
# Save current checkpoint every epoch
# NOTE(review): in the visible notebook `epoch`, `optimizer`, `scheduler` and
# `best_val_loss` are only assigned inside train_model(), so this cell looks
# like a scratch draft of that function's checkpointing and would presumably
# raise NameError when run on its own — confirm before relying on it.
checkpoint_path = 'checkpoint.pth'

# Full training state: enough to restore model, optimizer and LR schedule.
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'best_val_loss': best_val_loss
}
torch.save(checkpoint, checkpoint_path)


In [None]:
# Restore training state from `checkpoint_path` when resuming, otherwise start fresh.
# NOTE(review): in the visible notebook `resume`, `optimizer` and `scheduler`
# exist only as a parameter/locals of train_model(), so this cell looks like a
# scratch draft of that function's resume logic and would presumably raise
# NameError when run on its own — confirm before relying on it.
if resume and os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    # The stored epoch is the one that finished, so continue with the next.
    start_epoch = checkpoint['epoch'] + 1
    best_val_loss = checkpoint['best_val_loss']
    print(f"Resumed training from epoch {start_epoch}")
else:
    start_epoch = 0
    best_val_loss = float('inf')
