In [None]:
from preprocessing.slice_generator import generate_z_slices

# Slice one car mesh and save the result under outputs/slices/.
# NOTE(review): assumes generate_z_slices both returns the slices and writes
# them to save_path (the later max-points cell reads .npy files from that
# directory) — confirm against preprocessing.slice_generator.
slice_kwargs = dict(
    mesh_path="data/meshes/car_0001.stl",
    num_slices=32,   # number of slices to cut
    res=720,         # resolution passed through to the slicer
    save_path="outputs/slices/car_0001.npy",
)
slices = generate_z_slices(**slice_kwargs)


In [None]:
# Count design IDs in each split (train / val / test) and the combined total
import os

def split_ids_path(split, dir_path="data/subset_dir"):
    """Return the path of the ID list file for `split` inside `dir_path`."""
    return os.path.join(dir_path, f"{split}_design_ids.txt")


def read_split_ids(split, dir_path="data/subset_dir"):
    """Read the design IDs listed for one split.

    Args:
        split: Split name, e.g. "train"; reads `<dir_path>/<split>_design_ids.txt`.
        dir_path: Directory holding the per-split ID files.

    Returns:
        List of non-blank, whitespace-stripped lines in file order, or None
        if the file does not exist (or is not a regular file).
    """
    path = split_ids_path(split, dir_path)
    if not os.path.isfile(path):
        return None
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


def count_ids(split, dir_path="data/subset_dir"):
    """Print and return the number of design IDs listed for one split.

    Args:
        split: Split name, e.g. "train".
        dir_path: Directory holding the per-split ID files.

    Returns:
        Number of non-blank lines in the split's ID file, or 0 (after printing
        a "File not found" message) when the file is missing.
    """
    ids = read_split_ids(split, dir_path)
    if ids is None:
        print(f"[{split.upper()}] File not found: {split_ids_path(split, dir_path)}")
        return 0
    print(f"[{split.upper()}] Count: {len(ids)}")
    return len(ids)


if __name__ == "__main__":
    total = 0
    unique_ids = set()
    for split in ("train", "val", "test"):
        total += count_ids(split)
        unique_ids.update(read_split_ids(split) or [])
    print(f"[TOTAL] Combined Cars: {total}")
    # `total` counts list entries, not distinct cars: an ID repeated within a
    # split inflates it, and one repeated across splits is train/val/test leakage.
    n_duplicates = total - len(unique_ids)
    if n_duplicates:
        print(f"[WARNING] {n_duplicates} duplicate design ID(s) within/across splits; "
              f"unique cars: {len(unique_ids)}")


In [None]:
import os
import numpy as np

# Directory holding one `<car_id>.npy` slice file per car (cell 1 saves to
# this directory). Change this path to your actual slice directory.
SLICE_DIR = "outputs/slices/"

# Track max Mi (number of points in a single slice) for each car.
car_max_points = {}

# sorted() makes the per-car report order deterministic across filesystems.
for file in sorted(os.listdir(SLICE_DIR)):
    # splitext strips only the final extension (str.replace would remove
    # every ".npy" occurrence in the name).
    car_id, ext = os.path.splitext(file)
    if ext != ".npy":
        continue
    path = os.path.join(SLICE_DIR, file)

    # NOTE(review): allow_pickle=True is presumably needed because slices are
    # ragged (an object array of per-slice point arrays) — confirm. Unpickling
    # can execute arbitrary code, so only load trusted, locally produced files.
    slices = np.load(path, allow_pickle=True)

    # Points per slice = leading dimension of each slice array; a file with
    # no slices contributes 0 instead of crashing max().
    max_mi = max((slice_pts.shape[0] for slice_pts in slices), default=0)
    car_max_points[car_id] = max_mi
    print(f"{car_id}: max points in a slice = {max_mi}")

if not car_max_points:
    raise ValueError(f"No .npy slice files found in {SLICE_DIR!r}; cannot compute a padding cap")

# Global max Mi across all cars = padding cap for fixed-size slice tensors.
global_max = max(car_max_points.values())
print("\n✅ Global max number of points in any slice (padding cap) =", global_max)


In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt

# Directory of padded + masked slice files (.npz); adjust this path if needed.
PADDED_DIR = "outputs/pad_masked_slices/"

# Pick a test file: the first .npz in sorted order. Non-.npz entries
# (e.g. .DS_Store) are skipped. To pin a specific file, set test_file
# manually, e.g. 'DrivAer_F_D_WM_WW_0001_axis-x.npz'.
npz_files = sorted(f for f in os.listdir(PADDED_DIR) if f.endswith(".npz"))
if not npz_files:
    raise FileNotFoundError(f"No .npz files found in {PADDED_DIR!r}")
test_file = npz_files[0]
path = os.path.join(PADDED_DIR, test_file)

# np.load on an .npz returns a lazy NpzFile that keeps the file open; read the
# arrays inside `with` so the handle is closed afterwards.
# NOTE(review): shapes below are examples from an earlier run — confirm.
with np.load(path) as data:
    slices = data["slices"]          # e.g. (80, 6500, 2): (slice, point, xy)
    point_mask = data["point_mask"]  # e.g. (80, 6500)
    slice_mask = data["slice_mask"]  # e.g. (80,)

print(f"✅ File: {test_file}")
print("slices shape:", slices.shape)
print("point_mask shape:", point_mask.shape)
print("slice_mask shape:", slice_mask.shape)

# Interpret both masks the same way everywhere (nonzero = valid).
point_valid = point_mask.astype(bool)
slice_valid = slice_mask.astype(bool)

# -------------------------------
# Determine global axis limits
# -------------------------------
# Shared limits over every valid point so all panels use the same scale.
all_x = slices[:, :, 0][point_valid]
all_y = slices[:, :, 1][point_valid]
if all_x.size == 0:
    raise ValueError(f"{test_file} has no valid points (point_mask is all zero)")
xmin, xmax = all_x.min(), all_x.max()
ymin, ymax = all_y.min(), all_y.max()

# 2% margin so points on the extremes are not clipped by the frame.
pad_x = 0.02 * (xmax - xmin)
pad_y = 0.02 * (ymax - ymin)
xmin, xmax = xmin - pad_x, xmax + pad_x
ymin, ymax = ymin - pad_y, ymax + pad_y

# -------------------------------
# Visualize a few valid slices
# -------------------------------
n_cols = 5
# squeeze=False keeps `axes` indexable even when n_cols == 1.
fig, axes = plt.subplots(1, n_cols, figsize=(3 * n_cols, 3), squeeze=False)
axes = axes[0]
valid_indices = np.flatnonzero(slice_valid)[:n_cols]

for ax, idx in zip(axes, valid_indices):
    points = slices[idx]
    mask = point_valid[idx]
    ax.scatter(points[mask, 0], points[mask, 1], s=1, c='k')
    ax.set_title(f"Slice {idx}")
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    ax.set_aspect("equal", adjustable="box")
    ax.axis("off")

# Hide panels left over when the file has fewer than n_cols valid slices.
for ax in axes[len(valid_indices):]:
    ax.axis("off")

fig.suptitle(test_file)
plt.tight_layout()
plt.show()
