In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json
from torch.utils.data import Dataset, DataLoader
from sklearn.neighbors import NearestNeighbors
import os
from typing import List, Tuple, Dict
import open3d as o3d

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [17]:
class KneeKeypointDataset(Dataset):
    """Dataset that loads knee STL meshes together with keypoint annotations.

    Each item:
      1. samples ``2 * max_points`` points from the mesh surface,
      2. snaps every annotated keypoint to its nearest sampled point,
      3. centres and scales the cloud into the unit sphere,
      4. subsamples it to ``max_points`` points (keypoints always kept), and
      5. returns one-hot labels of shape ``(num_keypoints, max_points)``.
    """

    # Accepted values for ``surface_sampling_method``.
    SAMPLING_METHODS = ('uniform', 'poisson')

    def __init__(self, json_file: str, pointcloud_dir: str, max_points: int = 8192,
                 surface_sampling_method: str = 'uniform'):
        """
        Args:
            json_file: Path to a JSON file holding a list of annotations. Each
                entry is read with keys 'model_id' and 'keypoints', where
                'keypoints' is a list of dicts carrying an 'xyz' coordinate.
            pointcloud_dir: Directory containing ``<model_id>.stl`` mesh files.
            max_points: Number of points returned per sample.
            surface_sampling_method: 'uniform' or 'poisson' surface sampling.

        Raises:
            ValueError: If ``surface_sampling_method`` is not supported.
        """
        # Validate before any I/O so a typo fails immediately.
        if surface_sampling_method not in self.SAMPLING_METHODS:
            raise ValueError(
                f"surface_sampling_method must be one of {self.SAMPLING_METHODS}, "
                f"got {surface_sampling_method!r}"
            )

        with open(json_file, 'r') as f:
            self.annotations = json.load(f)

        self.pointcloud_dir = pointcloud_dir
        self.max_points = max_points
        self.sampling_method = surface_sampling_method

        # Expected 5 keypoints: front, left, right, thigh_center, shin_center.
        # NOTE(review): annotation['keypoints'] is assumed to be listed in this
        # same order - TODO confirm against the annotation tool.
        self.keypoint_names = ['front', 'left', 'right', 'thigh_center', 'shin_center']
        self.num_keypoints = len(self.keypoint_names)

    def __len__(self):
        return len(self.annotations)

    def load_point_cloud(self, model_id: str) -> np.ndarray:
        """Load ``<model_id>.stl`` and sample points from its surface.

        Samples twice ``max_points`` points so the later subsampling step has
        spare points to choose from. The method follows ``self.sampling_method``.

        Raises:
            ValueError: If the mesh is empty (failed read) or has no triangles.
        """
        mesh_path = os.path.join(self.pointcloud_dir, f"{model_id}.stl")
        mesh = o3d.io.read_triangle_mesh(mesh_path)

        # A failed read yields an empty mesh rather than an exception.
        if len(mesh.vertices) == 0:
            raise ValueError(f"Failed to load mesh from {mesh_path}")
        # Surface sampling needs faces; a vertex-only mesh cannot be sampled.
        if len(mesh.triangles) == 0:
            raise ValueError(f"Mesh {mesh_path} has no triangles to sample from")

        num_sample_points = self.max_points * 2
        if self.sampling_method == 'poisson':
            # Poisson disk sampling gives a more even point distribution.
            pcd = mesh.sample_points_poisson_disk(number_of_points=num_sample_points)
        else:
            pcd = mesh.sample_points_uniformly(number_of_points=num_sample_points)

        points = np.asarray(pcd.points)
        if len(points) < self.max_points:
            print(f"Warning: Only sampled {len(points)} points from mesh {model_id}")
        return points

    def find_nearest_point_indices(self, points: np.ndarray, keypoint_coords: List[List[float]]) -> List[int]:
        """Return, for each keypoint coordinate, the index of the closest row of ``points``.

        Args:
            points: ``(N, D)`` array of candidate points.
            keypoint_coords: K coordinates, each of length D.

        Returns:
            K Python ints indexing into ``points``.
        """
        coords = np.asarray(keypoint_coords, dtype=float).reshape(-1, points.shape[1])
        # (K, N) squared distances. A single vectorised scan is cheaper than
        # building a search tree per sample for only a handful of queries.
        sq_dists = ((coords[:, None, :] - points[None, :, :]) ** 2).sum(axis=-1)
        return sq_dists.argmin(axis=1).tolist()

    def normalize_point_cloud(self, points: np.ndarray) -> np.ndarray:
        """Centre the cloud on its centroid and scale it into the unit sphere."""
        centered = points - np.mean(points, axis=0)
        radius = np.max(np.linalg.norm(centered, axis=1))
        # A degenerate cloud (all points identical) has radius 0. Leave it
        # centred rather than dividing by zero and producing NaNs.
        if radius > 0:
            centered = centered / radius
        return centered

    def sample_points(self, points: np.ndarray, keypoint_indices: List[int]) -> Tuple[np.ndarray, List[int]]:
        """Subsample ``points`` to ``max_points`` rows, always keeping the keypoints.

        Args:
            points: ``(N, 3)`` point array.
            keypoint_indices: Row indices of the keypoints in ``points``.

        Returns:
            ``(sampled_points, new_keypoint_indices)``. The new indices point to
            the same coordinates inside ``sampled_points``.
        """
        n_points = len(points)

        if n_points <= self.max_points:
            # Too few points: zero-pad. The padding sits at the origin, which is
            # the centroid of a normalised cloud. Indices are unchanged.
            padded_points = np.zeros((self.max_points, 3))
            padded_points[:n_points] = points
            return padded_points, keypoint_indices

        # Keep each keypoint row once, even if several keypoints snapped to the
        # same point, so the output contains no duplicated points.
        unique_keypoints = list(dict.fromkeys(int(i) for i in keypoint_indices))

        is_candidate = np.ones(n_points, dtype=bool)
        is_candidate[unique_keypoints] = False
        candidates = np.flatnonzero(is_candidate)

        selected = np.asarray(unique_keypoints, dtype=np.int64)
        n_additional = self.max_points - len(unique_keypoints)
        if n_additional > 0:
            extra = np.random.choice(
                candidates,
                size=min(n_additional, len(candidates)),
                replace=False,
            )
            selected = np.concatenate([selected, extra])

        # Shuffle so keypoints do not always occupy the first output slots.
        # Otherwise an argmax over constant scores would land on a keypoint.
        selected = np.random.permutation(selected)

        old_to_new = {int(old): new for new, old in enumerate(selected)}
        new_keypoint_indices = [old_to_new[int(idx)] for idx in keypoint_indices]
        return points[selected], new_keypoint_indices

    def __getitem__(self, idx: int) -> Dict:
        """Return ``{'points', 'keypoint_labels', 'model_id'}`` for annotation ``idx``.

        Raises:
            ValueError: If the annotation does not list exactly ``num_keypoints`` keypoints.
        """
        annotation = self.annotations[idx]
        model_id = annotation['model_id']

        keypoint_coords = [kp['xyz'] for kp in annotation['keypoints']]
        # Check before the (slow) mesh load. A wrong count would otherwise crash
        # on label indexing or leave an all-zero label row, and the loss's
        # argmax silently turns that row into "point 0".
        if len(keypoint_coords) != self.num_keypoints:
            raise ValueError(
                f"Model {model_id}: expected {self.num_keypoints} keypoints, "
                f"got {len(keypoint_coords)}"
            )

        points = self.load_point_cloud(model_id)

        # Annotation coordinates are matched against the raw (un-normalised)
        # mesh samples, so snapping must happen before normalisation.
        keypoint_indices = self.find_nearest_point_indices(points, keypoint_coords)

        points = self.normalize_point_cloud(points)
        points, keypoint_indices = self.sample_points(points, keypoint_indices)

        # One-hot row per keypoint over the sampled points.
        keypoint_labels = np.zeros((self.num_keypoints, len(points)), dtype=np.float32)
        keypoint_labels[np.arange(self.num_keypoints), keypoint_indices] = 1.0

        return {
            'points': torch.FloatTensor(points),
            'keypoint_labels': torch.FloatTensor(keypoint_labels),
            'model_id': model_id
        }


class PointNetFeatureExtractor(nn.Module):
    """Point-wise PointNet encoder that appends a max-pooled global descriptor.

    Input ``(batch, input_dim, num_points)`` maps to output
    ``(batch, 2 * feature_dim, num_points)``. The first ``feature_dim``
    channels are per-point features. The last ``feature_dim`` channels are the
    same global (max over points) feature copied to every point.
    """

    def __init__(self, input_dim: int = 3, feature_dim: int = 1024):
        super().__init__()

        # Shared per-point MLP written as kernel-size-1 convolutions.
        self.conv1 = nn.Conv1d(input_dim, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, feature_dim, 1)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(feature_dim)

        # NOTE(review): registered but never applied in forward().
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        # Unpacking enforces a 3-D (batch, channels, num_points) input.
        _, _, n_pts = x.size()

        local = x
        for conv, bn in ((self.conv1, self.bn1),
                         (self.conv2, self.bn2),
                         (self.conv3, self.bn3)):
            local = F.relu(bn(conv(local)))

        # Max over the point axis gives an order-invariant shape descriptor.
        pooled = local.max(dim=2, keepdim=True).values
        return torch.cat([local, pooled.repeat(1, 1, n_pts)], dim=1)


class KneeKeypointModel(nn.Module):
    """Per-point keypoint scorer: one logit per (keypoint, point) pair.

    A PointNet encoder feeds a shared point-wise trunk. One small head per
    keypoint then produces a score for every point. ``forward`` maps
    ``(batch, num_points, 3)`` to ``(batch, num_keypoints, num_points)``.
    """

    def __init__(self, num_keypoints: int = 5, feature_dim: int = 1024):
        super().__init__()

        self.num_keypoints = num_keypoints
        self.feature_extractor = PointNetFeatureExtractor(feature_dim=feature_dim)

        # The encoder concatenates local + global features, doubling channels.
        trunk_in = 2 * feature_dim

        # Shared trunk
        self.shared_conv1 = nn.Conv1d(trunk_in, 512, 1)
        self.shared_conv2 = nn.Conv1d(512, 256, 1)
        self.shared_bn1 = nn.BatchNorm1d(512)
        self.shared_bn2 = nn.BatchNorm1d(256)

        # One independent scoring head per keypoint
        self.keypoint_heads = nn.ModuleList(
            [self._make_head() for _ in range(num_keypoints)]
        )

        self.dropout = nn.Dropout(0.3)

    @staticmethod
    def _make_head() -> nn.Sequential:
        """Build a 256 -> 128 -> 1 point-wise head producing one logit per point."""
        return nn.Sequential(
            nn.Conv1d(256, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Conv1d(128, 1, 1)
        )

    def forward(self, x):
        # Unpacking enforces a 3-D (batch, num_points, coords) input.
        _batch, _n_points, _coords = x.size()

        # Conv1d wants channels before points.
        encoded = self.feature_extractor(x.transpose(1, 2))

        trunk = F.relu(self.shared_bn1(self.shared_conv1(encoded)))
        trunk = F.relu(self.shared_bn2(self.shared_conv2(trunk)))
        trunk = self.dropout(trunk)

        # Each head yields (batch, 1, num_points). Concatenating along dim 1
        # gives (batch, num_keypoints, num_points).
        return torch.cat([head(trunk) for head in self.keypoint_heads], dim=1)


class FocalLoss(nn.Module):
    """Focal loss over point-classification logits.

    Every row along the last dimension (e.g. one keypoint's scores over all
    points) is treated as a softmax classification. Its target class is the
    argmax of the matching ``targets`` row. The loss is
    ``alpha * (1 - p_t) ** gamma * CE``, which down-weights rows the model
    already gets right.
    """

    def __init__(self, alpha: float = 1.0, gamma: float = 2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        """
        Args:
            inputs: Logits whose last dim indexes classes (points), e.g.
                ``(batch, num_keypoints, num_points)``.
            targets: Same shape as ``inputs``. The argmax along the last dim is
                the target class. NOTE: an all-zero row maps to class 0.

        Returns:
            Scalar mean focal loss.
        """
        num_classes = inputs.size(-1)
        # reshape (not view) so non-contiguous logits are accepted as well.
        ce_loss = F.cross_entropy(inputs.reshape(-1, num_classes),
                                  targets.argmax(dim=-1).reshape(-1),
                                  reduction='none')

        # ce = -log(p_t), so exp(-ce) recovers the true-class probability.
        pt = torch.exp(-ce_loss)

        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return focal_loss.mean()




In [18]:
def train_model(model, train_loader, val_loader, num_epochs: int = 100,
                learning_rate: float = 0.001, device: str = 'cuda',
                checkpoint_path: str = 'best_knee_keypoint_model.pth',
                log_every: int = 10):
    """Train the keypoint model with focal loss, checkpointing the best validation epoch.

    Args:
        model: Network mapping ``batch['points']`` to logits shaped like
            ``batch['keypoint_labels']``.
        train_loader: Iterable of dicts with 'points' and 'keypoint_labels'.
            Must contain at least one batch.
        val_loader: Same format as ``train_loader``. Must be non-empty.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: Initial Adam learning rate. It is halved every 30 epochs.
        device: Device the model and every batch are moved to.
        checkpoint_path: File that receives ``model.state_dict()`` whenever the
            mean validation loss improves.
        log_every: Print the batch loss every ``log_every`` batches. Use
            ``0`` to disable per-batch logging.

    Returns:
        The model after the LAST epoch (not the best one). Load
        ``checkpoint_path`` to get the lowest-validation-loss weights.

    Raises:
        ValueError: If either loader is empty.
    """
    # Fail fast: the per-epoch averages below divide by the loader lengths.
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each contain at least one batch")

    model = model.to(device)

    criterion = FocalLoss(alpha=1.0, gamma=2.0)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        # --- Training phase ---
        model.train()
        train_loss = 0.0

        for batch_idx, batch in enumerate(train_loader):
            points = batch['points'].to(device)
            labels = batch['keypoint_labels'].to(device)

            optimizer.zero_grad()
            loss = criterion(model(points), labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            if log_every > 0 and batch_idx % log_every == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}')

        # --- Validation phase (eval mode: BatchNorm running stats, no dropout) ---
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                points = batch['points'].to(device)
                labels = batch['keypoint_labels'].to(device)
                val_loss += criterion(model(points), labels).item()

        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)

        print(f'Epoch {epoch}: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

        # Checkpoint only on improvement.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), checkpoint_path)

        scheduler.step()

    return model


def evaluate_model(model, test_loader, device: str = 'cuda'):
    """Print and return the mean Euclidean distance between predicted and true keypoints.

    For every sample and keypoint, the predicted point is the argmax of the
    model's logits over points. The true point is the argmax of the label row.
    Distances are measured in the coordinates of ``batch['points']``. For
    ``KneeKeypointDataset`` those are per-sample unit-sphere-normalised
    coordinates, not physical units.

    Args:
        model: Network mapping ``(batch, num_points, 3)`` points to
            ``(batch, num_keypoints, num_points)`` logits.
        test_loader: Iterable of dicts with 'points' and 'keypoint_labels'.
        device: Device to run evaluation on.

    Returns:
        Mean distance over all (sample, keypoint) pairs.

    Raises:
        ValueError: If ``test_loader`` yields no keypoints to score.
    """
    model = model.to(device)
    model.eval()

    total_distance_error = 0.0
    total_samples = 0

    with torch.no_grad():
        for batch in test_loader:
            points = batch['points'].to(device)
            labels = batch['keypoint_labels'].to(device)

            outputs = model(points)

            pred_indices = outputs.argmax(dim=-1)   # (batch, num_keypoints)
            true_indices = labels.argmax(dim=-1)    # (batch, num_keypoints)

            # Gather (batch, num_keypoints, 3) coordinates in one shot instead of
            # a Python loop with a device sync per keypoint.
            batch_ids = torch.arange(points.size(0), device=points.device).unsqueeze(1)
            pred_points = points[batch_ids, pred_indices]
            true_points = points[batch_ids, true_indices]

            distances = torch.norm(pred_points - true_points, dim=-1)
            total_distance_error += distances.sum().item()
            total_samples += distances.numel()

    if total_samples == 0:
        raise ValueError("test_loader produced no keypoints to evaluate")

    avg_distance_error = total_distance_error / total_samples
    print(f'Average keypoint distance error: {avg_distance_error:.4f}')

    return avg_distance_error


In [19]:
# Initialize dataset. Paths are relative to the notebook's working directory.
ANNOTATIONS_JSON = 'knee_annotations/7-2-25/knee_points_4_5_flipped.json'
SCANS_DIR = 'scans_3/'
MAX_POINTS = 8192  # points per sample fed to the network

dataset = KneeKeypointDataset(
    json_file=ANNOTATIONS_JSON,
    pointcloud_dir=SCANS_DIR,
    max_points=MAX_POINTS,
)
# Fail here, not deep inside random_split/training, if the annotation file is empty.
assert len(dataset) > 0, f"No annotations found in {ANNOTATIONS_JSON}"

In [20]:
# Split dataset 80/20. A seeded generator keeps the same scans in train/val
# across kernel restarts, so validation losses from different runs are comparable.
# NOTE: the dataset re-samples each mesh surface on every access, so val loss
# still fluctuates somewhat between epochs.
SPLIT_SEED = 42
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

In [22]:
# Seed torch so weight initialisation (and later DataLoader shuffling) is reproducible.
MODEL_SEED = 0
torch.manual_seed(MODEL_SEED)

# Initialize model with one head per keypoint the dataset provides.
model = KneeKeypointModel(num_keypoints=dataset.num_keypoints)

# Select device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

Using device: cuda


In [23]:
trained_model = train_model(model, train_loader, val_loader, device=device)

# train_model returns the LAST-epoch weights. Evaluate the checkpoint it saved
# for the lowest validation loss instead.
trained_model.load_state_dict(torch.load('best_knee_keypoint_model.pth', map_location=device))

# Evaluate model. NOTE: val_loader also drove checkpoint selection, so this score
# is optimistic. Use a separate held-out test split for a final number.
evaluate_model(trained_model, val_loader, device=device)

Epoch 0, Batch 0, Loss: 9.0211
Epoch 0: Train Loss: 8.8080, Val Loss: 8.9963
Epoch 1, Batch 0, Loss: 8.7484
Epoch 1: Train Loss: 8.5513, Val Loss: 8.9119
Epoch 2, Batch 0, Loss: 8.3399
Epoch 2: Train Loss: 8.3537, Val Loss: 8.7145
Epoch 3, Batch 0, Loss: 8.0789
Epoch 3: Train Loss: 8.2572, Val Loss: 8.6199
Epoch 4, Batch 0, Loss: 8.0896
Epoch 4: Train Loss: 8.0814, Val Loss: 8.1946
Epoch 5, Batch 0, Loss: 7.5723
Epoch 5: Train Loss: 7.8103, Val Loss: 8.0154
Epoch 6, Batch 0, Loss: 7.5653
Epoch 6: Train Loss: 7.8511, Val Loss: 7.9865
Epoch 7, Batch 0, Loss: 7.4185
Epoch 7: Train Loss: 8.0430, Val Loss: 7.9969
Epoch 8, Batch 0, Loss: 7.3675
Epoch 8: Train Loss: 7.5014, Val Loss: 8.0364
Epoch 9, Batch 0, Loss: 7.4411
Epoch 9: Train Loss: 7.6627, Val Loss: 7.9548
Epoch 10, Batch 0, Loss: 7.0541
Epoch 10: Train Loss: 7.4648, Val Loss: 7.6892
Epoch 11, Batch 0, Loss: 6.9380
Epoch 11: Train Loss: 7.5327, Val Loss: 7.9900
Epoch 12, Batch 0, Loss: 7.3837
Epoch 12: Train Loss: 7.1681, Val Loss: 

KeyboardInterrupt: 