In [7]:
import json
import numpy as np
import open3d as o3d
from pathlib import Path
from torch.utils.data import DataLoader, Dataset

### Convert GC json to projected json

In [23]:
import json
import numpy as np
import open3d as o3d
from pathlib import Path

In [2]:
# === CONFIG ===
annotation_file = "knee_annotations/7-2-25/knee_points_4_5_flipped.json"
mesh_dir = Path("scans_3")  # adjust to your path
output_file = "knee_annotations/7-8-25/knee_projected_4_5.json"
num_points = 50000

In [None]:
# === LOAD annotation ===
with open(annotation_file) as f:
    data = json.load(f)

output_list = []

for entry in data:
    model_id = entry["model_id"]
    keypoints = entry["keypoints"]

    print(f"Processing {model_id}...")

    # === Load mesh ===
    mesh_file = mesh_dir / f"{model_id}.stl"
    mesh = o3d.io.read_triangle_mesh(str(mesh_file))
    if not mesh.has_triangles():
        raise ValueError(f"Mesh not found or invalid: {mesh_file}")

    # === Sample surface ===
    pcd = mesh.sample_points_uniformly(number_of_points=num_points)
    pcd_tree = o3d.geometry.KDTreeFlann(pcd)

    # === Find indices ===
    kp_indices = []
    for kp in keypoints:
        xyz = np.array(kp["xyz"])
        [_, idx, _] = pcd_tree.search_knn_vector_3d(xyz, 1)
        kp_indices.append({"index": int(idx[0])})

    output_list.append({
        "model_id": model_id,
        "keypoints": kp_indices
    })


In [5]:
# === Save ===
with open(output_file, "w") as f:
    json.dump(output_list, f, indent=2)

print(f"Saved to {output_file}")

Saved to knee_annotations/7-8-25/knee_projected_4_5.json


### Define New Dataset Class

In [2]:
import os
import json
import numpy as np
import torch
from torch.utils.data import Dataset,DataLoader
import trimesh
from sklearn.neighbors import KDTree  # Faster than Open3D for small clouds

NUM_POINTS = 50000  # Same as before
NUM_KEYPOINTS = 2   # For your 2-point determiner

class ShapeNetKeypointDatasetPointSelection(Dataset):
    def __init__(self, mesh_dir, annotation_json):
        self.mesh_dir = mesh_dir
        self.samples = []

        with open(annotation_json) as f:
            annotations = json.load(f)

        count_valid = 0
        count_total = 0

        for entry in annotations:
            model_id = entry['model_id']
            keypoints = np.array([kp['xyz'] for kp in entry['keypoints']], dtype=np.float32)
            count_total += 1
            if keypoints.shape[0] != NUM_KEYPOINTS:
                continue  # Filter to correct count
            self.samples.append((model_id, keypoints))
            count_valid += 1

        print(f"Total meshes: {count_total}, Valid: {count_valid}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        model_id, keypoints = self.samples[idx]

        mesh_path = os.path.join(self.mesh_dir, model_id + ".stl")

        try:
            mesh = trimesh.load(mesh_path, force='mesh')
            if mesh.is_empty or len(mesh.faces) == 0:
                raise ValueError("Empty mesh")

            points, _ = trimesh.sample.sample_surface(mesh, NUM_POINTS)
            if points.shape[0] < NUM_POINTS:
                pad_size = NUM_POINTS - points.shape[0]
                pad = np.repeat(points[0:1, :], pad_size, axis=0)
                points = np.vstack((points, pad))

            # Normalize points and GT keypoints together
            centroid = np.mean(points, axis=0)
            scale = np.max(np.linalg.norm(points - centroid, axis=1))
            points = (points - centroid) / scale
            keypoints = (keypoints - centroid) / scale

            # Build KD-tree to find nearest point indices
            tree = KDTree(points)
            gt_indices = []
            for kp in keypoints:
                dist, idxs = tree.query([kp], k=1)
                gt_indices.append(int(idxs[0][0]))

            gt_indices = np.array(gt_indices, dtype=np.int64)

        except Exception as e:
            print(f"Error loading mesh {model_id}: {e}")
            points = np.zeros((NUM_POINTS, 3), dtype=np.float32)
            gt_indices = np.zeros(NUM_KEYPOINTS, dtype=np.int64)

        return torch.from_numpy(points).float(), torch.from_numpy(gt_indices).long()


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from pointnet2_utils import PointNetSetAbstraction

class PointSelectorModel(nn.Module):
    def __init__(self, num_keypoints, normal_channel=True):
        super(PointSelectorModel, self).__init__()
        in_channel = 6 if normal_channel else 3
        self.normal_channel = normal_channel
        self.num_keypoints = num_keypoints

        self.sa1 = PointNetSetAbstraction(
            npoint=512, radius=0.2, nsample=32,
            in_channel=in_channel,
            mlp=[64, 128, 256],
            group_all=False
        )
        self.sa2 = PointNetSetAbstraction(
            npoint=128, radius=0.4, nsample=64,
            in_channel=256 + 3,
            mlp=[256, 512, 1024],
            group_all=False
        )
        self.sa3 = PointNetSetAbstraction(
            npoint=None, radius=None, nsample=None,
            in_channel=1024 + 3,
            mlp=[512, 1024, 2048],
            group_all=True
        )

        # Classifier head: per point scores
        # For point selection, you want to keep per-point features
        self.fp1 = nn.Conv1d(256, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)

        self.fp2 = nn.Conv1d(128, num_keypoints, 1)  # output logits per point

    def forward(self, xyz):
        B, _, N = xyz.shape

        if self.normal_channel:
            norm = xyz[:, 3:, :]
            xyz = xyz[:, :3, :]
        else:
            norm = None

        # PointNet++ backbone
        l1_xyz, l1_points = self.sa1(xyz, norm)  # (B, 3, 512), (B, C, 512)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)  # (B, 3, 128), (B, C, 128)
        # Feature Propagation: interpolate back to original points
        # For point selection, you might want to do upsampling or skip connections
        # Here's a simple skip:
        points = l1_points  # (B, C, N)
        x = F.relu(self.bn1(self.fp1(points)))
        logits = self.fp2(x)  # (B, num_keypoints, N)

        return logits  # shape: [B, num_keypoints, N]


In [4]:
class PointSelectionLoss(nn.Module):
    def __init__(self):
        super(PointSelectionLoss, self).__init__()
        self.ce = nn.CrossEntropyLoss()

    def forward(self, logits, gt_indices):
        """
        logits: (B, K, N)
        gt_indices: (B, K)
        """
        total_loss = 0.0
        B, K, N = logits.shape

        for k in range(K):
            pred_k = logits[:, k, :]  # (B, N)
            target_k = gt_indices[:, k]  # (B,)
            total_loss += self.ce(pred_k, target_k)

        total_loss /= K  # average over keypoints
        return total_loss


### T-Loop

In [5]:
NUM_POINTS = 8192 # Minimum 512 for input, ~1000 for neighborhood density; 2^14 = 16384
NUM_KEYPOINTS = 2 
BATCH_SIZE = 64 # Bigger is better (always)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

total_epoch = 10


Using device: cuda


In [8]:
import torch.nn.functional as F

# === Replace dataset ===
dataset = ShapeNetKeypointDatasetPointSelection(
    mesh_dir="scans_3",
    annotation_json="7-8-25/knee_points_4_5_flipped.json"
)

# === Keep same split ===
total_size = len(dataset)
val_size = int(0.2 * total_size)
train_size = total_size - val_size

train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

# === New model ===
model = PointSelectorModel(num_keypoints=NUM_KEYPOINTS, normal_channel=False).to(device)

# === Optimizer same ===
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

# === New loss ===
criterion = PointSelectionLoss()

start_epoch = 0
avg_val_loss = 1

# === Resume checkpoint logic stays the same ===
# (make sure your checkpoint saves model_state_dict, optimizer_state_dict, etc.)

for epoch in range(total_epoch):
    model.train()
    total_train_loss = 0

    for batch_idx, (pts, gt_indices) in enumerate(train_loader):
        pts, gt_indices = pts.to(device), gt_indices.to(device)
        pts = pts.permute(0, 2, 1)  # [B, 3, N]

        logits = model(pts)  # logits: (B, K, N)

        loss = criterion(logits, gt_indices)  # CE loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()

        if batch_idx == 0:
            print(f"Logits shape: {logits.shape}")
            print(f"GT indices shape: {gt_indices.shape}")

    avg_train_loss = total_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # === Validation ===
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for pts, gt_indices in val_loader:
            pts, gt_indices = pts.to(device), gt_indices.to(device)
            pts = pts.permute(0, 2, 1)

            logits = model(pts)
            val_loss = criterion(logits, gt_indices)
            total_val_loss += val_loss.item()

    avg_val_loss_new = total_val_loss / len(val_loader)
    val_losses.append(avg_val_loss_new)

    scheduler.step(avg_val_loss)

    # === Save checkpoints ===
    if (epoch + start_epoch + 1) % 10 == 0:
        checkpoint_path = f"New_Architecture/train_checkpoints/2_points_v2_epoch_{start_epoch + epoch + 1}.pth"
        torch.save({
            'epoch': start_epoch + epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_losses': train_losses,
            'val_losses': val_losses
        }, checkpoint_path)
        print(f"Saved checkpoint: {checkpoint_path}")

    if epoch == 0:
        avg_val_loss = avg_val_loss_new

    print(f"Epoch {epoch + start_epoch + 1}/{start_epoch + total_epoch} | "
          f"Train Loss: {avg_train_loss:.6f} | Val Loss: {(avg_val_loss + avg_val_loss_new) / 2:.6f}")

    avg_val_loss = avg_val_loss_new

print("Training completed!")


Total meshes: 92, Valid: 92


NameError: name 'random_split' is not defined