### After embeddings are created /project/part_embeddings_struct.py

In [1]:
import torch
from assembly_graph_pyg import AssemblyGraphDataset

part_embeddings = torch.load("part_embeddings.pt", map_location="cpu")

ds = AssemblyGraphDataset(
    "assembly_filter_out/assembly_graphs_v1.jsonl",
    part_embeddings=part_embeddings
)

print(ds[0])


  return torch._C._show_config()
  from .autonotebook import tqdm as notebook_tqdm


Data(x=[15, 8], edge_index=[2, 28], edge_attr=[28, 1], part_ids=[15], assembly_path='/media/swapnil/3f73cc1a-8f9d-4c19-87af-99b3512ff5b2/MK_S/Automate/assemblies/asm_00000/01984c6ecc641bbea1b098c1_7cf8f9bc95fd3a123d51b6d4_d8e061ba74a1304ea959b3f8_default.json', assembly_id='01984c6ecc641bbea1b098c1_7cf8f9bc95fd3a123d51b6d4_d8e061ba74a1304ea959b3f8_default')


### plug embeddings into the AssemblyGraphDataset

In [2]:
import torch
from assembly_graph_pyg import AssemblyGraphDataset

part_embeddings = torch.load("part_embeddings.pt", map_location="cpu")

ds = AssemblyGraphDataset(
    "assembly_filter_out/assembly_graphs_v1.jsonl",
    part_embeddings=part_embeddings
)

d0 = ds[0]
print(d0)
print("x dim:", d0.x.shape[1])


Data(x=[15, 8], edge_index=[2, 28], edge_attr=[28, 1], part_ids=[15], assembly_path='/media/swapnil/3f73cc1a-8f9d-4c19-87af-99b3512ff5b2/MK_S/Automate/assemblies/asm_00000/01984c6ecc641bbea1b098c1_7cf8f9bc95fd3a123d51b6d4_d8e061ba74a1304ea959b3f8_default.json', assembly_id='01984c6ecc641bbea1b098c1_7cf8f9bc95fd3a123d51b6d4_d8e061ba74a1304ea959b3f8_default')
x dim: 8


### do all assemblies have embeddings coverage?

If any part id is missing in embeddings, dataset loading will throw an error.
Run this quick coverage test on the first 200 assemblies:

In [3]:
import json
from pathlib import Path

GRAPHS = Path("assembly_filter_out/assembly_graphs_v1.jsonl")

missing_parts = 0
checked = 0

with GRAPHS.open() as f:
    for i, line in enumerate(f):
        if i >= 200:
            break
        rec = json.loads(line)
        assembly_path = rec["assembly_path"]

        assembly = json.load(open(assembly_path))
        for p in assembly.get("parts", []):
            pid = p.get("id")
            if isinstance(pid, str) and pid not in part_embeddings:
                missing_parts += 1

        checked += 1

print("Checked assemblies:", checked)
print("Missing part embeddings (count):", missing_parts)


Checked assemblies: 200
Missing part embeddings (count): 0


### forward pass again (now meaningful features)

Run the same forward-pass cell, just adjust in_dim=8 automatically:

In [4]:
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

loader = DataLoader(ds, batch_size=8, shuffle=True)

class TinyGNN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim=64, out_dim=32):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.lin = torch.nn.Linear(hidden_dim, out_dim)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)
        return self.lin(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TinyGNN(in_dim=ds[0].x.shape[1]).to(device)

batch = next(iter(loader)).to(device)
out = model(batch)

print("Batch graphs:", batch.num_graphs)
print("Batch nodes:", batch.num_nodes)
print("Output shape:", out.shape)
print("Example output row:", out[0].detach().cpu()[:10])


Batch graphs: 8
Batch nodes: 20
Output shape: torch.Size([8, 32])
Example output row: tensor([ 14.8924, -10.6971, -12.8809, -10.0885,  11.0672,   2.8783,   8.8975,
         -8.1643,   4.9358,  -8.4923])


### Experiment 1: Do assembly embeddings correlate with assembly complexity?

We’ll compute, for a sample of assemblies:

num_parts

num_edges (undirected and directed)

mean and standard deviation of node features

graph embedding norm (from the model output)

simple correlations

In [5]:
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
import pandas as pd
import numpy as np

# -------- Settings --------
N_GRAPHS = 1000
BATCH_SIZE = 32

# -------- DataLoader subset --------
subset = [ds[i] for i in range(min(N_GRAPHS, len(ds)))]
loader = DataLoader(subset, batch_size=BATCH_SIZE, shuffle=False)

# -------- Model (same as before) --------
class TinyGNN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim=64, out_dim=32):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.lin = torch.nn.Linear(hidden_dim, out_dim)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)
        return self.lin(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TinyGNN(in_dim=ds[0].x.shape[1]).to(device)
model.eval()

rows = []
with torch.no_grad():
    idx_base = 0
    for batch in loader:
        batch = batch.to(device)
        out = model(batch)  # [B, 32]

        # Split batch graphs
        # batch.ptr gives node pointer per graph; batch.batch gives graph assignment per node
        num_graphs = batch.num_graphs

        # For each graph in the batch, compute simple stats
        for g_idx in range(num_graphs):
            # mask nodes belonging to this graph
            node_mask = (batch.batch == g_idx)
            xg = batch.x[node_mask]  # [Ng, 8]

            # edges: edge_index is directed (we doubled edges earlier)
            # approximate per-graph directed edges count by filtering edges whose source node is in this graph
            src = batch.edge_index[0]
            edge_mask = node_mask[src]
            directed_edges = int(edge_mask.sum().item())
            undirected_edges = directed_edges // 2

            emb = out[g_idx]
            rows.append({
                "graph_idx": idx_base + g_idx,
                "num_nodes": int(xg.shape[0]),
                "num_edges_undirected": undirected_edges,
                "num_edges_directed": directed_edges,
                "x_mean": float(xg.mean().item()),
                "x_std": float(xg.std(unbiased=False).item()),
                "emb_norm": float(emb.norm().item()),
                "emb_mean": float(emb.mean().item()),
                "emb_std": float(emb.std(unbiased=False).item()),
            })

        idx_base += num_graphs

df = pd.DataFrame(rows)
print(df.head())

print("\n=== Correlations (Pearson) ===")
cols = ["num_nodes", "num_edges_undirected", "x_mean", "x_std", "emb_norm", "emb_mean", "emb_std"]
corr = df[cols].corr(numeric_only=True)
print(corr)

print("\nQuick summaries:")
print(df[cols].describe().T[["mean","std","min","max"]])

   graph_idx  num_nodes  num_edges_undirected  num_edges_directed      x_mean  \
0          0         15                    14                  28  579.688232   
1          1          2                     1                   2   12.266267   
2          2          9                     8                  16   44.318012   
3          3          1                     0                   0  563.009460   
4          4          4                     3                   6   88.136536   

         x_std     emb_norm   emb_mean     emb_std  
0  1723.321899  1026.231812  37.043854  177.591507  
1    14.322392    13.313927   0.094850    2.351680  
2   138.155579    66.811562   3.557670   11.262161  
3  1081.587769   931.203430  -5.437477  164.525238  
4   141.435913   132.843216   1.175073   23.454168  

=== Correlations (Pearson) ===
                      num_nodes  num_edges_undirected    x_mean     x_std  \
num_nodes              1.000000              0.994369  0.000566  0.002286   
num_edges

### Experiment: feature normalization (one cell)
Notebook cell: compute normalization stats + wrap dataset

In [7]:
import torch
from torch_geometric.loader import DataLoader

# Collect feature statistics over a sample (or full dataset)
N_STATS = 5000  # set 10920 for full usable set
subset = [ds[i] for i in range(min(N_STATS, len(ds)))]

# Stack all node features (may be big; this is manageable for 5000)
X_all = torch.cat([d.x for d in subset], dim=0).float()

# Log transform heavy-tailed dims (all dims here are heavy-tailed in practice)
X_log = torch.log1p(torch.clamp(X_all, min=0))  # safe for non-negative stats

mu = X_log.mean(dim=0)
sigma = X_log.std(dim=0, unbiased=False) + 1e-8

print("mu:", mu)
print("sigma:", sigma)

def normalize_x(x):
    x = x.float()
    x = torch.log1p(torch.clamp(x, min=0))
    return (x - mu) / sigma

# Build a normalized view of ds without rewriting files
class NormalizedWrapper(torch.utils.data.Dataset):
    def __init__(self, base_ds):
        self.base = base_ds
    def __len__(self):
        return len(self.base)
    def __getitem__(self, idx):
        d = self.base[idx]
        d.x = normalize_x(d.x)
        return d

ds_norm = NormalizedWrapper(ds)

print(ds_norm[0].x[:3])  # sanity

mu: tensor([3.0167, 4.4207, 1.6734, 2.4733, 0.2953, 2.7081, 3.6028, 5.3299])
sigma: tensor([1.1093e+00, 1.2978e+00, 2.3341e-01, 9.2471e-01, 1.8325e-01, 1.0000e-08,
        2.4772e+00, 2.5690e+00])
tensor([[ 0.3466,  0.5821,  1.2096,  0.5095, -0.5409,  0.0000,  0.8889,  0.8992],
        [ 0.5357,  0.8005,  1.4631,  0.6178, -0.6819,  0.0000,  0.6156,  0.6658],
        [ 0.5357,  0.8005,  1.4631,  0.6178, -0.6819,  0.0000,  0.5667,  0.6663]])


In [6]:
import torch
from copy import deepcopy

# ---------------------------------------
# 1) Compute normalization stats (mu/sigma)
# ---------------------------------------
N_STATS = 10920  # use full usable set; change if you want a smaller sample
subset = [ds[i] for i in range(min(N_STATS, len(ds)))]

# Stack all node features
X_all = torch.cat([d.x for d in subset], dim=0).float()

# Log transform (features are non-negative stats; clamp for safety)
X_log = torch.log1p(torch.clamp(X_all, min=0))

mu = X_log.mean(dim=0)
sigma = X_log.std(dim=0, unbiased=False) + 1e-8

print("mu:", mu)
print("sigma:", sigma)

# Save stats for reuse (recommended)
torch.save({"mu": mu, "sigma": sigma}, "x_norm_stats.pt")
print("✅ Saved normalization stats to x_norm_stats.pt")

# -------------------------------
# 2) Normalization function
# -------------------------------
def normalize_x(x, mu, sigma):
    x = x.float()
    x = torch.log1p(torch.clamp(x, min=0))
    return (x - mu) / sigma

# ---------------------------------------
# 3) Safe dataset wrapper (no in-place mutation)
# ---------------------------------------
class NormalizedWrapper(torch.utils.data.Dataset):
    def __init__(self, base_ds, mu, sigma):
        self.base = base_ds
        self.mu = mu
        self.sigma = sigma

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        d = deepcopy(self.base[idx])  # prevents modifying cached objects
        d.x = normalize_x(d.x, self.mu, self.sigma)
        return d

ds_norm = NormalizedWrapper(ds, mu, sigma)

# Sanity check
print("Before (first node):", ds[0].x[0])
print("After  (first node):", ds_norm[0].x[0])
print("After stats (mean≈0, std≈1) on sample:",
      ds_norm[0].x.mean().item(), ds_norm[0].x.std(unbiased=False).item())

mu: tensor([3.0935, 4.4779, 1.6569, 2.5016, 0.2800, 2.7080, 3.4804,    nan])
sigma: tensor([1.1754e+00, 1.3934e+00, 2.6066e-01, 9.9911e-01, 1.9047e-01, 1.0000e-08,
        2.5162e+00,        nan])
✅ Saved normalization stats to x_norm_stats.pt
Before (first node): tensor([2.9000e+01, 1.7600e+02, 6.0690e+00, 1.8000e+01, 2.1675e-01, 1.4000e+01,
        3.3086e+02, 2.0788e+03])
After  (first node): tensor([ 0.2617,  0.5011,  1.1464,  0.4433, -0.4402, 47.6837,  0.9237,     nan])
After stats (mean≈0, std≈1) on sample: nan nan
