In [5]:
import math
import os
from dataclasses import dataclass
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# ----------------------------
# 1) We build a closed quad mesh by subdividing each cube face into an n×n grid,
#    welding shared edges/vertices, then projecting to a sphere.
# ----------------------------

@dataclass
class Cage:
    V: torch.Tensor                 # (N,3) float32 vertices
    F: torch.Tensor                 # (M,4) int64 quad faces (indices into V)
    adjacency: List[List[int]]      # per-vertex neighbor list

def _face_grid(n: int) -> Tuple[torch.Tensor, torch.Tensor]:
#    Build a single (n+1)x(n+1) grid of vertices in (s,t) ∈ [-1,1]^2, and quad faces.
#    Returns: uv (Nv,2) and qf (n*n,4) indices into uv

    lin = torch.linspace(-1.0, 1.0, n + 1)
    ss, tt = torch.meshgrid(lin, lin, indexing="ij")
    uv = torch.stack([ss.reshape(-1), tt.reshape(-1)], dim=1)  # (Nv,2)

    def vid(i, j):  # i,j in [0..n]
        return i * (n + 1) + j

    faces = []
    for i in range(n):
        for j in range(n):
            v00 = vid(i, j)
            v10 = vid(i + 1, j)
            v11 = vid(i + 1, j + 1)
            v01 = vid(i, j + 1)
            faces.append([v00, v10, v11, v01])  # consistent winding on the grid
    qf = torch.tensor(faces, dtype=torch.long)
    return uv, qf

def build_cube_sphere_cage(n: int = 6, device: str = "cpu") -> Cage:
#      - 6 faces of a cube, each subdivided into n×n quads
#      - welded shared vertices
#      - projected to unit sphere (cube-sphere)

    uv, qf_local = _face_grid(n)  # same for all faces
    uv = uv.to(device)
    qf_local = qf_local.to(device)

    # For each cube face, map (s,t) to xyz on that face
    def face_xyz(face_id: int, s: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        one = torch.ones_like(s)
        if face_id == 0:   # +X
            return torch.stack([one, s, t], dim=1)
        if face_id == 1:   # -X
            return torch.stack([-one, s, -t], dim=1)
        if face_id == 2:   # +Y
            return torch.stack([s, one, t], dim=1)
        if face_id == 3:   # -Y
            return torch.stack([s, -one, -t], dim=1)
        if face_id == 4:   # +Z
            return torch.stack([s, t, one], dim=1)
        if face_id == 5:   # -Z
            return torch.stack([-s, t, -one], dim=1)
        raise ValueError("bad face_id")

    # Weld vertices using quantized keys in cube-space (before sphere projection).
    vert_map = {}
    verts = []
    faces = []

    def key_from_xyz(xyz: torch.Tensor) -> Tuple[int, int, int]:
        # quantize to avoid float hash issues; values are on {-1..1} grid anyway
        q = torch.round(xyz * 10_000).to(torch.int64).tolist()
        return (q[0], q[1], q[2])

    for f in range(6):
        xyz_face = face_xyz(f, uv[:, 0], uv[:, 1])  # (Nv,3)
        local_to_global = torch.empty((xyz_face.shape[0],), dtype=torch.long, device=device)

        for i in range(xyz_face.shape[0]):
            k = key_from_xyz(xyz_face[i])
            if k not in vert_map:
                vert_map[k] = len(verts)
                verts.append(xyz_face[i].detach().cpu())  # store on CPU for list
            local_to_global[i] = vert_map[k]

        # Add faces with global indices
        fg = local_to_global[qf_local]  # (n*n,4)
        faces.append(fg.detach().cpu())

    V = torch.stack(verts, dim=0).to(torch.float32).to(device)  # cube vertices
    Fq = torch.cat(faces, dim=0).to(torch.long).to(device)

    # Project cube to sphere (simple normalization).
    V = V / (V.norm(dim=1, keepdim=True).clamp_min(1e-12))

    # Build adjacency from faces
    adj = [[] for _ in range(V.shape[0])]
    for quad in Fq.tolist():
        a, b, c, d = quad
        edges = [(a, b), (b, c), (c, d), (d, a)]
        for u, v in edges:
            if v not in adj[u]:
                adj[u].append(v)
            if u not in adj[v]:
                adj[v].append(u)

    return Cage(V=V, F=Fq, adjacency=adj)


# ----------------------------
# 2) A tiny "cage generator" network
#    Input: low-dim design parameters (a,b,c,bump_amp)
#    Output: per-vertex positions of the SubD cage
# ----------------------------

class CageNet(nn.Module):
    """Predict per-vertex cage positions from a low-dimensional design vector.

    Topology is fixed: the MLP emits an (N, 3) offset field that is added to a
    base vertex layout supplied at call time.
    """

    def __init__(self, num_verts: int, cond_dim: int = 4, hidden: int = 256):
        super().__init__()
        self.num_verts = num_verts

        layers = [
            nn.Linear(cond_dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, num_verts * 3),
        ]
        self.mlp = nn.Sequential(*layers)

        # Down-scaled Xavier weights (gain 0.5) and zero biases: a small init
        # intended to help training stability.
        for linear in (m for m in self.mlp if isinstance(m, nn.Linear)):
            nn.init.xavier_uniform_(linear.weight, gain=0.5)
            nn.init.zeros_(linear.bias)

    def forward(self, base_V: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Return base_V (N,3) plus predicted offsets for cond (B,cond_dim), as (B,N,3)."""
        offsets = self.mlp(cond).view(cond.shape[0], self.num_verts, 3)
        return base_V.unsqueeze(0) + offsets


# ----------------------------
# 3) Synthetic target shapes
#    Teach the network: given parameters, output a cage that fits an ellipsoid
#    with an optional bump (for visual interest).
# ----------------------------
def target_shape(base_V: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
    """Synthesize target vertex positions for each design vector.

    Args:
        base_V: (N,3) base vertices (the unit-sphere cage in this script).
        cond: (B,4) rows of [a, b, c, bump_amp].

    Returns:
        (B,N,3): base_V scaled per axis by (a, b, c), plus a bump of amplitude
        bump_amp along base_V. The bump is weighted by exp(-25 * max(1 - z, 0)^2),
        so it is strongest where z is near 1.
    """
    batch = cond.shape[0]
    sphere = base_V.unsqueeze(0).expand(batch, -1, -1)  # (B,N,3)

    axis_scale = cond[:, :3].unsqueeze(1)   # (B,1,3)
    amplitude = cond[:, 3:4].unsqueeze(-1)  # (B,1,1)

    z = sphere[..., 2:3]
    weight = torch.exp(-25.0 * (1.0 - z).clamp_min(0.0) ** 2)

    return sphere * axis_scale + amplitude * weight * sphere



# ----------------------------
# 4) Regularizers for a "CAD-ish" cage
#    - edge length regularization
#    - Laplacian smoothness (keeps cage fair)
# ----------------------------

def edge_list_from_adjacency(adj: List[List[int]]) -> torch.Tensor:
    """Collect the unique undirected edges implied by a per-vertex neighbor list.

    Args:
        adj: adj[i] lists the neighbor indices of vertex i.

    Returns:
        (E,2) int64 tensor of (lo, hi) index pairs, sorted lexicographically.
        The shape is (0, 2) when there are no edges, so callers can always
        index edges[:, 0] / edges[:, 1].
    """
    edges = {(min(i, j), max(i, j)) for i, nbrs in enumerate(adj) for j in nbrs}
    # reshape keeps a 2-D shape for the empty case (torch.tensor([]) is 1-D).
    return torch.tensor(sorted(edges), dtype=torch.long).reshape(-1, 2)

def edge_length_loss(V: torch.Tensor, edges: torch.Tensor) -> torch.Tensor:
    """Penalize uneven edge lengths.

    Args:
        V: (B,N,3) vertex positions.
        edges: (E,2) vertex index pairs.

    Returns:
        Scalar: mean over batch and edges of (length - per-sample mean length)^2.
    """
    src, dst = edges.unbind(dim=1)
    lengths = (V[:, src] - V[:, dst]).norm(dim=-1)  # (B,E)
    centered = lengths - lengths.mean(dim=1, keepdim=True)
    return centered.pow(2).mean()

def laplacian_smoothness(V: torch.Tensor, adj: List[List[int]]) -> torch.Tensor:
    """Uniform-Laplacian smoothness penalty: ||v_i - avg(neighbors(i))||^2.

    All neighbor sums are accumulated with a single ``index_add``. This replaces
    a Python loop that built one small index tensor and ran one gather per
    vertex on every call.

    Args:
        V: (B,N,3) vertex positions.
        adj: adj[i] lists the neighbor indices of vertex i. Vertices with an
            empty list are skipped; repeated neighbors count once per occurrence.

    Returns:
        0-d tensor: the sum over non-isolated vertices of the batch-mean
        squared Laplacian, divided by N (the total vertex count). Zero if no
        vertex has neighbors.
    """
    B, N, _ = V.shape
    active = [(i, nbrs) for i, nbrs in enumerate(adj) if nbrs]
    if not active:
        return V.new_zeros(())

    dev = V.device
    lens = [len(nbrs) for _, nbrs in active]
    rows = torch.tensor([i for i, _ in active], dtype=torch.long, device=dev)
    cols = torch.tensor([j for _, nbrs in active for j in nbrs], dtype=torch.long, device=dev)
    # seg[k] = which active vertex the k-th entry of `cols` belongs to.
    seg = torch.repeat_interleave(
        torch.arange(len(active), device=dev),
        torch.tensor(lens, dtype=torch.long, device=dev),
    )
    counts = torch.tensor(lens, dtype=V.dtype, device=dev)

    nbr_sum = V.new_zeros((B, len(active), 3)).index_add(1, seg, V[:, cols, :])
    avg = nbr_sum / counts.view(1, -1, 1)
    per_vertex = (V[:, rows, :] - avg).pow(2).sum(dim=-1).mean(dim=0)  # (R,)
    return per_vertex.sum() / N


# ----------------------------
# 5) OBJ export (quads)
# ----------------------------

def write_obj_quads(path: str, V: torch.Tensor, Fq: torch.Tensor):
    """Write a quad mesh as a Wavefront OBJ text file.

    Args:
        path: Output file path (overwritten).
        V: (N,3) vertex positions, written with 8 decimal places.
        Fq: (M,4) 0-based quad indices; converted to OBJ's 1-based indices.
    """
    with open(path, "w", encoding="utf-8") as obj_file:
        obj_file.writelines(
            f"v {p[0]:.8f} {p[1]:.8f} {p[2]:.8f}\n" for p in V.tolist()
        )
        obj_file.writelines(
            f"f {a + 1} {b + 1} {c + 1} {d + 1}\n" for a, b, c, d in Fq.tolist()
        )


# ----------------------------
# 6) Demo training + inference
# ----------------------------

def main(
    steps: int = 1000,
    batch_size: int = 32,
    out_path: str = "predicted_subd_cage.obj",
    log_every: int = 200,
):
    """Train CageNet on synthetic ellipsoid targets, then export one predicted cage.

    Args:
        steps: Number of optimizer steps.
        batch_size: Random design vectors sampled per step.
        out_path: Destination OBJ file for the predicted cage.
        log_every: Print losses every this many steps and at the final step;
            a value <= 0 disables logging.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(0)

    cage = build_cube_sphere_cage(n=6, device=device)
    base_V = cage.V  # (N,3)
    edges = edge_list_from_adjacency(cage.adjacency).to(device)

    net = CageNet(num_verts=base_V.shape[0], cond_dim=4, hidden=256).to(device)
    opt = torch.optim.Adam(net.parameters(), lr=2e-3)

    # Train on random ellipsoid parameters
    for step in range(steps):
        # a, b, c (per-axis scales) in [0.6, 1.4]; bump amplitude in [0, 0.15]
        cond = torch.empty((batch_size, 4), device=device)
        cond[:, 0:3] = 0.6 + 0.8 * torch.rand((batch_size, 3), device=device)
        cond[:, 3:4] = 0.15 * torch.rand((batch_size, 1), device=device)

        V_pred = net(base_V, cond)               # (B,N,3)
        V_tgt = target_shape(base_V, cond)       # (B,N,3)

        # Fit loss (vertex-wise in this toy; real systems would use samples/SDF/Chamfer/etc.)
        fit = F.mse_loss(V_pred, V_tgt)

        # Regularize for fairness / "nice cage"
        reg_e = edge_length_loss(V_pred, edges)
        reg_l = laplacian_smoothness(V_pred, cage.adjacency)

        loss = fit + 0.1 * reg_e + 0.05 * reg_l

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

        if log_every > 0 and (step % log_every == 0 or step == steps - 1):
            print(f"step {step:4d}  loss={loss.item():.6f}  fit={fit.item():.6f}")

    # Inference: choose a specific "design"
    cond = torch.tensor([[1.2, 0.9, 0.7, 0.12]], device=device)  # (1,4)
    with torch.no_grad():
        V_out = net(base_V, cond)[0].cpu()
    F_out = cage.F.cpu()

    write_obj_quads(out_path, V_out, F_out)
    print(f"\nWrote SubD control cage OBJ (quads): {os.path.abspath(out_path)}")
    print("Load in Blender/Rhino/etc. and apply Catmull–Clark subdivision to see the smooth limit surface.")

# Run the training + export demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()


step    0  loss=0.019316  fit=0.019056


step  200  loss=0.000391  fit=0.000047


step  400  loss=0.000419  fit=0.000045


step  600  loss=0.000354  fit=0.000039


step  800  loss=0.000393  fit=0.000030



Wrote SubD control cage OBJ (quads): /home/user/predicted_subd_cage.obj
Load in Blender/Rhino/etc. and apply Catmull–Clark subdivision to see the smooth limit surface.
