In [4]:
import trimesh
import numpy as np
import glob
import os
import json

In [3]:
def compute_global_min_max(file_list):
    """
    Compute the global minimum and maximum vertex coordinate across all shapes.

    The min/max are taken over every coordinate (x, y and z together) of every
    vertex, so one scalar range is returned rather than per-axis ranges.

    Parameters
    ----------
    file_list : iterable of str
        Paths of mesh files readable by ``trimesh.load``. Any iterable works
        (a list, or a generator such as ``glob.iglob``).

    Returns
    -------
    global_min, global_max
        Smallest / largest vertex coordinate seen in any mesh.
    global_avg
        Unweighted mean of the per-mesh mean coordinate: each mesh counts
        equally regardless of its vertex count.

    Raises
    ------
    ValueError
        If ``file_list`` yields no files (the average would be undefined).
    """
    global_min = np.inf
    global_max = -np.inf
    mean_sum = 0.0

    count = 0  # stays 0 if the iterable is empty
    for count, file in enumerate(file_list, start=1):
        vertices = trimesh.load(file).vertices

        # Update global min and max
        global_min = min(global_min, vertices.min())
        global_max = max(global_max, vertices.max())
        mean_sum += vertices.mean()

        if count % 50 == 0:
            print(f"Processed {count} shapes.")

    if count == 0:
        raise ValueError("compute_global_min_max: no mesh files were given")

    # Divide by the number actually processed so non-sized iterables work too.
    global_avg = mean_sum / count

    return global_min, global_max, global_avg

def scale_mesh_to_uniform_range(mesh, global_min, global_max, target_min=-0.90, target_max=0.90, verbose=True):
    """
    Linearly map a mesh's vertex coordinates from the dataset-wide range
    [global_min, global_max] onto [target_min, target_max] (default [-0.90, 0.90]).

    One scalar factor and offset are applied to x, y and z alike, so the shape is
    scaled uniformly (no per-axis distortion). Every mesh scaled with the same
    global_min/global_max shares the same transform.

    Parameters
    ----------
    mesh
        Object with a ``vertices`` array attribute (e.g. a trimesh.Trimesh).
    global_min, global_max : float
        Source range, typically from ``compute_global_min_max``.
    target_min, target_max : float
        Destination range.
    verbose : bool
        Print the resulting vertex range (default True, as before).

    Returns
    -------
    The same ``mesh`` object; its ``vertices`` are replaced in place.

    Raises
    ------
    ValueError
        If ``global_max - global_min`` is not strictly positive. This would
        otherwise divide by zero or produce inf/nan vertices silently.
    """
    global_range = global_max - global_min
    if not global_range > 0:  # also rejects NaN ranges
        raise ValueError(
            f"Invalid global range [{global_min}, {global_max}]; cannot scale."
        )

    # Same factor for every axis -> uniform scaling.
    scaling_factor = (target_max - target_min) / global_range

    scaled_vertices = (mesh.vertices - global_min) * scaling_factor + target_min
    if verbose:
        print(f"Scaled vertices to range [{scaled_vertices.min()}, {scaled_vertices.max()}]")

    # Update the mesh vertices
    mesh.vertices = scaled_vertices

    return mesh


In [10]:
# Input meshes (one .obj per shape) and destination folder for scaled copies.
input_folder = "[path to mesh folder: mesh_minimal_obj]"
output_folder = "[path to mesh folder: mesh_minimal_scaled_obj_files]"

# Glob the .obj files *inside* the folder: globbing the bare folder path would
# only return the folder itself. Sorting makes the processing order reproducible.
input_files = sorted(glob.glob(os.path.join(input_folder, "*.obj")))
assert input_files, f"No .obj files found in {input_folder!r}"

# Target coordinate range for every scaled mesh.
target_min = -0.90
target_max = 0.90

# Create output folder if it doesn't exist
os.makedirs(output_folder, exist_ok=True)

In [None]:
# Dataset-wide coordinate range over all input meshes.
global_min, global_max, global_avg = compute_global_min_max(input_files)
print(f"Global min: {global_min}, Global max: {global_max}, Global avg: {global_avg}")

# Persist the single shared affine transform so scaled coordinates can be mapped
# back to the originals: original = (scaled - target_min) / scaling_factor + global_min.
scaling_factor = (target_max - target_min) / (global_max - global_min)
scale_info_path = os.path.join(output_folder, "scale_info.json")
scale_info = dict(
    global_min=float(global_min),
    global_max=float(global_max),
    target_min=float(target_min),
    target_max=float(target_max),
    scaling_factor=float(scaling_factor),
)
with open(scale_info_path, "w") as fh:
    json.dump(scale_info, fh, indent=2)

print(f"Scaling factor: {scaling_factor}")
print(f"Saved scale info to: {scale_info_path}")


In [None]:
# Process and save scaled meshes: every mesh gets the same dataset-wide transform
# (global_min/global_max from the cell above).
for count, input_file in enumerate(input_files, start=1):
    mesh = trimesh.load(input_file)
    file_name = os.path.basename(input_file)
    print(f"{file_name}: vertices {len(mesh.vertices)}, faces {len(mesh.faces)}")
    scaled_mesh = scale_mesh_to_uniform_range(mesh, global_min, global_max, target_min, target_max)

    # os.path.join inserts the separator. Plain string concatenation would write
    # "<folder><name>" next to the folder whenever output_folder lacks a
    # trailing slash. basename() is also portable, unlike split('/').
    output_file = os.path.join(output_folder, file_name)
    print(f"Saving scaled mesh to: {output_file}")
    scaled_mesh.export(output_file)

    if count % 50 == 0:
        print(f"Processed {count} meshes.")

print("All meshes scaled and saved successfully.")

In [None]:
# Sanity check: recompute the coordinate range over the *scaled* meshes.
# Separate names are used so the dataset range (global_min/global_max, which the
# scaling cell depends on) is not clobbered if cells are re-run out of order.
scaled_folder = "[path to mesh folder: mesh_minimal_scaled_obj_files]"
input_files_scaled = sorted(glob.glob(os.path.join(scaled_folder, "*.obj")))
scaled_min, scaled_max, scaled_avg = compute_global_min_max(input_files_scaled)
print(f"Scaled global min: {scaled_min}, Scaled global max: {scaled_max}, Scaled global avg: {scaled_avg}")

# Small tolerance for float round-off and the text precision of exported .obj files.
range_tol = 1e-4
assert scaled_min >= target_min - range_tol and scaled_max <= target_max + range_tol, (
    f"Scaled range [{scaled_min}, {scaled_max}] exceeds target [{target_min}, {target_max}]"
)

In [13]:
# Report how many of the scaled meshes are watertight (closed surfaces).
scaled_dir = '[path to mesh folder: mesh_minimal_scaled_obj_files]'
mesh_paths = sorted(glob.glob(os.path.join(scaled_dir, '*.obj')))

# force='mesh' requests a single Trimesh (rather than a Scene) from trimesh.load.
non_watertight = [
    path
    for path in mesh_paths
    if not trimesh.load(path, force='mesh').is_watertight
]

n_total = len(mesh_paths)
print(f"Watertight meshes: {n_total - len(non_watertight)}/{n_total}")
if non_watertight:
    print('Non-watertight samples:')
    print(*non_watertight[:10], sep='\n')



Watertight meshes: 234/234


In [None]:
import pandas as pd
import numpy as np
import torch

# Read the labels CSV: the file's own header row is skipped and replaced by
# these fixed column names.
csv_path = "[path to labels file]"
label_columns = ["patient_id", "disease_diagnosis", "kl_grade", "gender", "age"]
df = pd.read_csv(csv_path, names=label_columns, skiprows=1)

def normalize_patient_id(x):
    """
    Return a canonical string form of a patient ID, or None if it is missing.

    Integral values are rendered without a decimal part, whether they arrive as
    numbers or as numeric text. For example, 9478504, 9478504.0 and
    " 9478504.0 " all become "9478504". This lets CSV IDs be compared with IDs
    parsed from mesh filenames. Any other value is returned as its
    whitespace-stripped string.
    """
    if pd.isna(x):
        return None

    numeric_input = isinstance(x, (int, float))
    if numeric_input and float(x).is_integer():
        # int(x) (not int(float(x))) keeps very large integer IDs exact.
        return str(int(x))

    text = str(x).strip()
    try:
        as_number = float(text)
    except ValueError:
        return text
    return str(int(as_number)) if as_number.is_integer() else text

def to_num(x):
    """
    Convert a raw CSV cell to a float.

    Missing values become NaN. Gender strings map to 0.0 ("m"/"male") and
    1.0 ("f"/"female"), ignoring case and surrounding whitespace. Every other
    value goes through float(), so unrecognised text raises ValueError.
    """
    if pd.isna(x):
        return np.nan

    gender_codes = {"m": 0.0, "male": 0.0, "f": 1.0, "female": 1.0}
    if isinstance(x, str):
        code = gender_codes.get(x.strip().lower())
        if code is not None:
            return code

    return float(x)

# Create labels dictionary: normalized patient_id -> float32 array
# [disease_diagnosis, kl_grade, gender, age]
labels = {}
duplicate_ids = []
for _, row in df.iterrows():
    patient_id = normalize_patient_id(row["patient_id"])
    if patient_id is None:
        continue  # skip rows with a missing patient id
    if patient_id in labels:
        # A later row would silently overwrite an earlier one; record it.
        duplicate_ids.append(patient_id)
    labels[patient_id] = np.array(
        [
            to_num(row["disease_diagnosis"]),  # 0/1
            to_num(row["kl_grade"]),
            to_num(row["gender"]),
            to_num(row["age"]),
        ],
        dtype=np.float32,
    )

print(f"Created labels for {len(labels)} patients")
if duplicate_ids:
    print(f"Warning: {len(duplicate_ids)} duplicate patient ids (last row kept): {duplicate_ids[:10]}")
print("Sample labels:")
for key, value in list(labels.items())[:5]:
    print(f"{key}: {value}")

# Save as .pt file
output_path = "[path to labels file]"
torch.save(labels, output_path)
print(f"\nLabels saved to: {output_path}")

# Verify the saved file. The values are numpy arrays, which torch.load refuses
# to unpickle under weights_only=True (the default from PyTorch 2.6). The file
# was written just above, so full unpickling is safe here.
loaded_labels = torch.load(output_path, weights_only=False)
print(f"\nVerification - loaded {len(loaded_labels)} labels")
assert loaded_labels.keys() == labels.keys(), "Saved and reloaded label ids differ"
print("Sample loaded labels:")
for key, value in list(loaded_labels.items())[:3]:
    print(f"{key}: {value}")



In [None]:
# Re-load the saved labels. The values are numpy arrays, which torch.load rejects
# under weights_only=True (the default from PyTorch 2.6). This notebook wrote
# the file, so full unpickling is acceptable.
test_labels = torch.load(output_path, weights_only=False)
print(f"\nVerification - loaded {len(test_labels)} labels")
print("Sample loaded labels:")
# Show the entries after the first 101 (insertion indices > 100).
for key, value in list(test_labels.items())[101:]:
    print(f"{key}: {value}")

In [None]:
import os
import torch

mesh_dir = "[path to mesh folder: mesh_minimal]"  # <-- update if different
labels_path = "[path to mesh folder: mesh_minimal]_scaled_obj_files/labels.pt"

# load labels (numpy-array values need weights_only=False on PyTorch >= 2.6;
# the file is produced by this notebook)
labels = torch.load(labels_path, weights_only=False)
label_ids = set(labels.keys())

# collect mesh ids
mesh_ids = set()
for name in os.listdir(mesh_dir):
    # Skip hidden entries and sub-directories (e.g. diseased_tagged/, healthy_tagged/),
    # whose names would otherwise show up as bogus "mesh ids".
    if name.startswith(".") or not os.path.isfile(os.path.join(mesh_dir, name)):
        continue
    stem = os.path.splitext(name)[0]  # remove extension if any
    # handle names like 9478504_femur or 9478504_femur.obj
    mesh_id = stem.split("_femur")[0]
    if mesh_id:
        mesh_ids.add(mesh_id)

missing_in_labels = sorted(mesh_ids - label_ids)
extra_in_labels = sorted(label_ids - mesh_ids)

print(f"Mesh files found: {len(mesh_ids)}")
print(f"Labels found: {len(label_ids)}")
print(f"Missing in labels: {len(missing_in_labels)}")
print(f"Extra in labels: {len(extra_in_labels)}")

if missing_in_labels:
    print("Sample missing in labels:", missing_in_labels[:10])
if extra_in_labels:
    print("Sample extra in labels:", extra_in_labels[:10])



In [None]:
import os
import torch

# Mesh tree root (update if different) and the labels file that will be pruned.
base_dir = "[path to mesh folder: mesh_minimal]"
labels_path = "[path to mesh folder: mesh_minimal]_scaled_obj_files/labels.pt"
# Sub-folders holding the tagged meshes whose ids get excluded.
diseased_dir, healthy_dir = (os.path.join(base_dir, sub) for sub in ("diseased_tagged", "healthy_tagged"))

def extract_tagged_ids(folder, suffix):
    """
    Return the patient ids of the non-hidden files in ``folder`` whose name
    (minus extension) ends with ``suffix`` (e.g. "_less" / "_more").

    The suffix is removed, then anything from the first "_femur" onward is
    dropped, so "9478504_femur_less.obj" yields "9478504". Empty ids are ignored.
    """
    stems = (
        os.path.splitext(entry)[0]
        for entry in os.listdir(folder)
        if not entry.startswith(".")
    )
    found = {
        stem[:-len(suffix)].partition("_femur")[0]
        for stem in stems
        if stem.endswith(suffix)
    }
    found.discard("")
    return found

less_ids = extract_tagged_ids(diseased_dir, "_less")
more_ids = extract_tagged_ids(healthy_dir, "_more")
tagged_ids = less_ids | more_ids

print(f"Tagged (_less) ids: {len(less_ids)}")
print(f"Tagged (_more) ids: {len(more_ids)}")
print(f"Total tagged ids: {len(tagged_ids)}")

# labels.pt holds numpy-array values, which torch.load rejects under
# weights_only=True (the default from PyTorch 2.6). This notebook wrote the
# file, so full unpickling is acceptable.
labels = torch.load(labels_path, weights_only=False)
before = len(labels)

# remove tagged ids, then overwrite labels.pt in place
labels = {k: v for k, v in labels.items() if k not in tagged_ids}
after = len(labels)

torch.save(labels, labels_path)

print(f"Labels before: {before}")
print(f"Labels after: {after}")
print(f"Removed from labels: {before - after}")

# tests
still_present = [pid for pid in tagged_ids if pid in labels]
assert not still_present, f"Tagged IDs still present in labels: {still_present[:10]}"
print("Test OK: no tagged patient IDs remain in labels.pt")



In [None]:
import os
import math
import torch
import numpy as np

base_dir = "[path to mesh folder: mesh_minimal]"  # update if needed
labels_path = "[path to labels file]"
# Tagged sub-folders of base_dir; their patient ids are removed from the labels.
diseased_dir, healthy_dir = [
    os.path.join(base_dir, f"{kind}_tagged") for kind in ("diseased", "healthy")
]

def extract_tagged_ids(folder, suffix):
    """
    Collect patient ids from files in ``folder`` tagged with ``suffix``.

    Hidden files are skipped. For each remaining file, the extension is dropped.
    The name must end with ``suffix``, which is then stripped together with
    anything from "_femur" onward. Empty ids are not included.
    """
    ids = set()
    for entry in os.listdir(folder):
        stem = os.path.splitext(entry)[0]
        if entry.startswith(".") or not stem.endswith(suffix):
            continue
        patient_id = stem[:-len(suffix)].partition("_femur")[0]  # remove _less / _more
        if patient_id:
            ids.add(patient_id)
    return ids

# build tagged set
less_ids = extract_tagged_ids(diseased_dir, "_less")
more_ids = extract_tagged_ids(healthy_dir, "_more")
tagged_ids = less_ids | more_ids

print(f"Tagged (_less) ids: {len(less_ids)}")
print(f"Tagged (_more) ids: {len(more_ids)}")
print(f"Total tagged ids: {len(tagged_ids)}")

# Label values are numpy arrays, which torch.load rejects under weights_only=True
# (the default from PyTorch 2.6). This notebook wrote the file, so full
# unpickling is acceptable.
labels = torch.load(labels_path, weights_only=False)
label_ids = set(labels.keys())

# explicit checks BEFORE removal
intersect_before = label_ids & tagged_ids
missing_in_labels = tagged_ids - label_ids
print(f"Tagged ids present in labels (before): {len(intersect_before)}")
print(f"Tagged ids missing in labels (before): {len(missing_in_labels)}")

# remove tagged ids, then overwrite labels.pt in place
labels = {k: v for k, v in labels.items() if k not in tagged_ids}
torch.save(labels, labels_path)

# explicit checks AFTER removal
label_ids_after = set(labels.keys())
intersect_after = label_ids_after & tagged_ids
print(f"Tagged ids present in labels (after): {len(intersect_after)}")

# tests
assert len(intersect_after) == 0, f"Tagged IDs still in labels: {list(intersect_after)[:10]}"
assert len(intersect_before) == (len(label_ids) - len(label_ids_after)), "Removed count mismatch"
print("Tests OK: tagged IDs removed and counts consistent.")

# count healthy/diseased in labels (disease_diagnosis is the first element)
healthy = diseased = unknown = 0
for v in labels.values():
    # values may be numpy arrays or torch tensors
    val = float(v[0].item()) if isinstance(v, torch.Tensor) else float(v[0])
    if val == 0:
        healthy += 1
    elif val == 1:
        diseased += 1
    else:
        # NaN compares unequal to everything, so missing diagnoses land here too.
        unknown += 1

print(f"Labels total: {len(labels)}")
print(f"Healthy (0): {healthy}")
print(f"Diseased (1): {diseased}")
print(f"Unknown/other: {unknown}")



In [None]:
import os
import torch
import math

# Labels file checked against the tagged folders below.
labels_path = "[path to mesh folder: mesh_minimal]_scaled_obj_files/labels.pt"
base_dir = "[path to mesh folder: mesh_minimal]"  # update if needed
healthy_dir = os.path.join(base_dir, "healthy_tagged")
diseased_dir = os.path.join(base_dir, "diseased_tagged")

def extract_ids_from_folder(folder):
    """
    Return the patient ids of every non-hidden entry in ``folder``.

    Unlike a suffix-filtered scan, tagged and untagged files are both included.
    After dropping the extension, a trailing "_less" and then a trailing "_more"
    are removed (in that order), followed by anything from "_femur" onward.
    Empty ids are ignored.
    """
    ids = set()
    for entry in os.listdir(folder):
        if entry.startswith("."):
            continue
        stem = os.path.splitext(entry)[0]
        for tag in ("_less", "_more"):
            if stem.endswith(tag):
                stem = stem[:-len(tag)]
        stem = stem.partition("_femur")[0]
        if stem:
            ids.add(stem)
    return ids

healthy_folder_ids = extract_ids_from_folder(healthy_dir)
diseased_folder_ids = extract_ids_from_folder(diseased_dir)

# Label values are numpy arrays, which torch.load rejects under weights_only=True
# (the default from PyTorch 2.6). This notebook wrote the file, so full
# unpickling is acceptable.
labels = torch.load(labels_path, weights_only=False)

# Partition label ids by disease_diagnosis (first element of each value).
healthy_ids = set()
diseased_ids = set()
unknown_ids = set()

for k, v in labels.items():
    val = float(v[0].item()) if isinstance(v, torch.Tensor) else float(v[0])
    if math.isnan(val):
        unknown_ids.add(k)
    elif val == 0:
        healthy_ids.add(k)
    elif val == 1:
        diseased_ids.add(k)
    else:
        unknown_ids.add(k)

# Checks
missing_healthy = healthy_ids - healthy_folder_ids
missing_diseased = diseased_ids - diseased_folder_ids
wrong_healthy_side = healthy_ids & diseased_folder_ids
wrong_diseased_side = diseased_ids & healthy_folder_ids

print(f"Labels total: {len(labels)}")
print(f"Healthy labels: {len(healthy_ids)}")
print(f"Diseased labels: {len(diseased_ids)}")
print(f"Unknown labels: {len(unknown_ids)}")

print(f"Healthy IDs missing in healthy_tagged: {len(missing_healthy)}")
print(f"Diseased IDs missing in diseased_tagged: {len(missing_diseased)}")
print(f"Healthy IDs found in diseased_tagged (wrong side): {len(wrong_healthy_side)}")
print(f"Diseased IDs found in healthy_tagged (wrong side): {len(wrong_diseased_side)}")

if missing_healthy:
    print("Sample missing healthy:", sorted(missing_healthy)[:10])
if missing_diseased:
    print("Sample missing diseased:", sorted(missing_diseased)[:10])
if wrong_healthy_side:
    print("Sample wrong-side healthy:", sorted(wrong_healthy_side)[:10])
if wrong_diseased_side:
    print("Sample wrong-side diseased:", sorted(wrong_diseased_side)[:10])

# Set to True to turn the report above into hard failures
# (replaces the previously commented-out asserts).
STRICT_CHECKS = False
if STRICT_CHECKS:
    assert not missing_healthy, "Some healthy label IDs are not in healthy_tagged"
    assert not missing_diseased, "Some diseased label IDs are not in diseased_tagged"
    assert not wrong_healthy_side, "Some healthy IDs appear in diseased_tagged"
    assert not wrong_diseased_side, "Some diseased IDs appear in healthy_tagged"

