In [6]:
import os
import json
import numpy as np
import trimesh
import shutil
from tqdm import tqdm

In [7]:
# Config
# Input keypoint-annotation JSON; it is loaded with json.load and iterated
# as a sequence of samples, each read via "model_id" and "keypoints".
ORIGINAL_JSON = "knee_annotations/7-2-25/knee_points_4_5_flipped.json"
# Directory holding the source meshes, looked up as "<model_id>.stl".
MESH_DIR = "scans_3"
# Destination JSON that receives the original samples followed by the augmented ones.
OUTPUT_JSON = "knee_annotations/7-2-25/knee_points_4_5_aug.json"
# Directory that receives a copy of MESH_DIR plus every exported augmented mesh.
OUTPUT_MESH_DIR = "scans_3_aug"
NUM_AUGMENTATIONS = 6  # Number of augmented copies generated for each sample

In [None]:
# Seed the output mesh directory with the unaugmented scans.
# dirs_exist_ok=True (Python 3.8+) lets this cell be re-run: without it
# copytree raises FileExistsError as soon as OUTPUT_MESH_DIR already exists.
# Files already present in the destination are overwritten by same-named
# source files; anything else in the destination is left in place.
shutil.copytree(MESH_DIR, OUTPUT_MESH_DIR, dirs_exist_ok=True)

# Load original dataset
with open(ORIGINAL_JSON, 'r') as f:
    original_data = json.load(f)

In [9]:
# Dedicated, seeded generator: re-running this cell reproduces the same
# augmented meshes and keypoints instead of drawing fresh random transforms
# from the global NumPy state each time.
AUGMENTATION_SEED = 42
rng = np.random.default_rng(AUGMENTATION_SEED)


def _apply_augmentation(points, rotation, scale, shift_xy):
    """Rotate, uniformly scale, then XY-translate an (N, 3) array of points.

    Args:
        points: Float array of shape (N, 3); N may be 0.
        rotation: 3x3 rotation matrix applied to every point.
        scale: Scalar factor applied to all three coordinates.
        shift_xy: Length-2 offset added to the x and y columns only.

    Returns:
        A new (N, 3) array; ``points`` itself is not modified.
    """
    transformed = (points @ rotation.T) * scale
    transformed[:, :2] += shift_xy
    return transformed


augmented_data = []

for sample in tqdm(original_data, desc="Augmenting dataset with meshes"):
    model_id = sample["model_id"]
    mesh_path = os.path.join(MESH_DIR, model_id + ".stl")

    if not os.path.exists(mesh_path):
        print(f"Mesh {mesh_path} not found, skipping.")
        continue

    mesh = trimesh.load(mesh_path, force='mesh')
    if mesh.is_empty or len(mesh.faces) == 0:
        print(f"Mesh {model_id} is empty, skipping.")
        continue

    vertices = np.asarray(mesh.vertices, dtype=float)
    faces = mesh.faces

    keypoints = np.asarray([kp["xyz"] for kp in sample["keypoints"]], dtype=float)
    if keypoints.size == 0:
        # An empty keypoint list would otherwise produce a shape-(0,) array
        # and fail in the matrix product below; keep it as zero rows of xyz.
        keypoints = keypoints.reshape(0, 3)

    for aug_id in range(NUM_AUGMENTATIONS):
        # 1. Random rotation about the Z axis
        theta = rng.uniform(0, 2 * np.pi)
        cos_theta, sin_theta = np.cos(theta), np.sin(theta)
        Rz = np.array([[cos_theta, -sin_theta, 0],
                       [sin_theta,  cos_theta, 0],
                       [0, 0, 1]])

        # 2. Uniform scaling factor
        scale = rng.uniform(0.9, 1.1)

        # 3. XY translation (z is left untouched)
        shift_xy = rng.uniform(-0.05, 0.05, size=2)

        # The mesh and its keypoints must receive the exact same transform,
        # so both go through the same helper with the same parameters.
        vertices_aug = _apply_augmentation(vertices, Rz, scale, shift_xy)
        keypoints_aug = _apply_augmentation(keypoints, Rz, scale, shift_xy)

        # Save augmented mesh; process=False keeps trimesh from merging or
        # reordering vertices, so the original face indices stay valid.
        augmented_id = f"{model_id}_aug_{aug_id}"
        augmented_mesh = trimesh.Trimesh(vertices=vertices_aug, faces=faces, process=False)
        augmented_mesh.export(os.path.join(OUTPUT_MESH_DIR, augmented_id + ".stl"))

        # Save augmented keypoints. Only the coordinates are written; any
        # other per-keypoint fields of the source sample are not carried over.
        augmented_data.append({
            "model_id": augmented_id,
            "keypoints": [{"xyz": xyz.tolist()} for xyz in keypoints_aug],
        })

# Merge original + augmented samples
full_dataset = original_data + augmented_data



Augmenting dataset with meshes: 100%|██████████| 92/92 [01:42<00:00,  1.11s/it]


In [10]:
# Write the merged (original + augmented) samples out as indented JSON.
with open(OUTPUT_JSON, "w") as out_file:
    out_file.write(json.dumps(full_dataset, indent=2))

total_samples = len(full_dataset)
print(f"Augmented dataset with meshes saved to {OUTPUT_JSON}. Total samples: {total_samples}")


Augmented dataset with meshes saved to knee_annotations/7-2-25/knee_points_4_5_aug.json. Total samples: 644
