In [2]:

import os
from pathlib import Path
import json
import random
import math
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split

from pytorch3d.datasets import ShapeNetCore
from pytorch3d.datasets.utils import collate_batched_meshes
from pytorch3d.ops import GraphConv
from collections import defaultdict
from contextlib import nullcontext

from pytorch3d.datasets.utils import collate_batched_meshes

from pytorch3d.structures import Meshes
import time




# ---------- Configuration ----------
# Root path to your local ShapeNetCore.v2 directory.
# Example folder structure: SHAPENET_PATH/02958343/<model_id>/* etc.
SHAPENET_PATH = Path("../data/shapeNetCore").expanduser().resolve()
assert SHAPENET_PATH.exists(), f"ShapeNetCore folder not found: {SHAPENET_PATH}"
print("Dataset root:", SHAPENET_PATH)

# Choose 5 synsets by default (airplane, chair, lamp, mug, table)
# Feel free to edit this to the categories you actually want to use.
CATEGORIES = {
    "02808440": "bathtub",
    "02992529": "cellphone",
    "03046257": "clock",
    "03211117": "display",
    "03642806": "laptop",
}

NUM_CLASSES = len(CATEGORIES)

# Data / training params
BATCH_SIZE = 4
VAL_SPLIT = 0.2
NUM_WORKERS = 2
EPOCHS = 10
LR = 1e-3
SEED = 42

device = (torch.device("mps") if torch.backends.mps.is_available()
          else torch.device("cuda") if torch.cuda.is_available()
          else torch.device("cpu"))
print("Device:", device)
# Reproducibility
random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)


Dataset root: /Users/brageramberg/Desktop/3DCNN/data/shapeNetCore
Device: mps


## Dataset & Dataloaders

In [4]:
def mean_pool(x: torch.Tensor, num_verts_per_mesh: torch.Tensor) -> torch.Tensor:
    """
    x: (sum(V_i), F) vertex features packed across the batch
    num_verts_per_mesh: (B,) number of vertices per mesh
    returns: (B, F) mean pooled features per mesh
    """
    # Build batch index: for mesh i, repeat index i num_verts_per_mesh[i] times
    batch_index = torch.cat([
        torch.full((int(n),), i, device=x.device, dtype=torch.long)
        for i, n in enumerate(num_verts_per_mesh)
    ], dim=0)  # (sum(V_i),)

    B = int(num_verts_per_mesh.numel())
    Fdim = x.size(1)
    out = x.new_zeros((B, Fdim))
    out.index_add_(0, batch_index, x)  # sum over vertices for each mesh
    out = out / num_verts_per_mesh.view(-1, 1).to(x.dtype)
    return out

In [5]:
# Load ShapeNetCore with selected synsets.
# Setting load_textures=False avoids texture loading overhead/issues.
dataset = ShapeNetCore(
    data_dir=str(SHAPENET_PATH),                # <-- was root=
    synsets=list(CATEGORIES.keys()),            # <-- pass a list of synset ids
    version=2,
    load_textures=False
)

print(f"Total samples found: {len(dataset)}")

# Split into train/val
val_size = int(len(dataset) * VAL_SPLIT)
train_size = len(dataset) - val_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

# Build synset->index mapping for labels
synset_list = list(CATEGORIES.keys())  # preserves your chosen order
synset_to_idx = {sid: i for i, sid in enumerate(synset_list)}
idx_to_synset = {i: sid for sid, i in synset_to_idx.items()}
idx_to_name = {i: CATEGORIES[sid] for i, sid in idx_to_synset.items()}


def normalize_verts(v):
    c = v.mean(0, keepdim=True)
    v = v - c
    s = v.norm(dim=1).max().clamp(min=1e-6)
    return v / s



def decimate_mesh(verts, faces, max_faces=1500):
    F = faces.shape[0]
    if F <= max_faces:
        return verts, faces
    idx = torch.randperm(F)[:max_faces]
    f_small = faces[idx]
    used, inv = torch.unique(f_small.reshape(-1), sorted=True, return_inverse=True)
    v_small = verts[used]
    f_small = inv.reshape(-1, 3)
    return v_small, f_small

def custom_collate_fn(batch):
    V, F = [], []
    for s in batch:
        v, f = decimate_mesh(s["verts"], s["faces"], max_faces=1500)
        V.append(normalize_verts(v))
        F.append(f)
    meshes = Meshes(verts=V, faces=F)
    labels = torch.tensor([synset_to_idx[s["synset_id"]] for s in batch], dtype=torch.long)
    return {"meshes": meshes, "labels": labels}

# Nuke old references so we don't accidentally use them
try:
    del train_loader
    del val_loader
except NameError:
    pass

# (Re)build loaders with num_workers=0 in notebooks/macOS
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=0, collate_fn=custom_collate_fn,
    pin_memory=(device.type == "cuda"),
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=0, collate_fn=custom_collate_fn,
    pin_memory=(device.type == "cuda"),
)

print("train_loader workers:", train_loader.num_workers)
print("val_loader workers:",   val_loader.num_workers)

Total samples found: 3891
Train: 3113, Val: 778
train_loader workers: 0
val_loader workers: 0




In [6]:
b = next(iter(val_loader))
print(type(b["meshes"]))         # should be <class 'pytorch3d.structures.meshes.Meshes'>
print(b["labels"].shape)         # torch.Size([batch_size])

<class 'pytorch3d.structures.meshes.Meshes'>
torch.Size([4])


## Model: Simple GraphCNN

In [7]:


class SimpleGraphCNN(nn.Module):
    """
    GraphConv classifier:
    - per-vertex input: (x, y, z)
    - 3 GraphConv layers
    - global mean pooling per mesh (pure PyTorch)
    - MLP head -> NUM_CLASSES
    """
    def __init__(self, num_classes: int):
        super().__init__()
        self.conv1 = GraphConv(3, 64)
        self.conv2 = GraphConv(64, 128)
        self.conv3 = GraphConv(128, 256)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, meshes):
        x = meshes.verts_packed()     # (sum(V_i), 3)
        edges = meshes.edges_packed() # (sum(E_i), 2) int64

        x = F.relu(self.conv1(x, edges))
        x = F.relu(self.conv2(x, edges))
        x = F.relu(self.conv3(x, edges))

        # Global mean pool per mesh w/o torch_scatter
        num_verts_per_mesh = meshes.num_verts_per_mesh()  # (B,)
        avg_features = mean_pool(x, num_verts_per_mesh)   # (B, 256)

        h = F.relu(self.fc1(avg_features))
        logits = self.fc2(h)
        return logits

model = SimpleGraphCNN(NUM_CLASSES).to(device)
print(model)


SimpleGraphCNN(
  (conv1): GraphConv(3 -> 64, directed=False)
  (conv2): GraphConv(64 -> 128, directed=False)
  (conv3): GraphConv(128 -> 256, directed=False)
  (fc1): Linear(in_features=256, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=5, bias=True)
)


## Training & Evaluation Loops

In [8]:
# --- Training setup (replace this whole block) ---
from contextlib import nullcontext

# Keep your existing 'device' from earlier; don't redefine it here
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)

use_cuda = (device.type == "cuda")
scaler = torch.amp.GradScaler(device_type="cuda") if use_cuda else None  # CUDA only

PRINT_EVERY = 25  # batches

def run_epoch(loader, train: bool = True):
    model.train(train)
    running_loss, running_correct, n_samples = 0.0, 0, 0

    autocast = (torch.amp.autocast(device_type="cuda", dtype=torch.float16)
                if device.type == "cuda" else nullcontext())

    start = time.time()
    for bi, batch in enumerate(loader, 1):
        t0 = time.time()
        meshes = batch["meshes"].to(device)
        labels = batch["labels"].to(device)
        bs = labels.size(0)
        n_samples += bs

        with autocast:
            logits = model(meshes)
            logits = torch.nan_to_num(logits, nan=0.0, posinf=1e4, neginf=-1e4)
            loss = criterion(logits, labels)

        optimizer.zero_grad(set_to_none=True)
        if device.type == "cuda" and scaler is not None:
            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

        running_loss += loss.item() * bs
        running_correct += (logits.argmax(1) == labels).sum().item()

        if bi % PRINT_EVERY == 0 or bi == len(loader):
            dt = time.time() - start
            avg_bt = dt / bi
            eta = avg_bt * (len(loader) - bi)
            print(f"[{'train' if train else 'val':5}] "
                  f"batch {bi:4d}/{len(loader)} | "
                  f"avg {avg_bt:.2f}s/batch | ETA {eta/60:.1f}m")

    epoch_loss = running_loss / max(1, n_samples)
    epoch_acc  = running_correct / max(1, n_samples)
    return epoch_loss, epoch_acc

def train_loop(epochs=EPOCHS):
    best_val_acc = 0.0
    for ep in range(1, epochs + 1):
        t0 = time.time()
        train_loss, train_acc = run_epoch(train_loader, train=True)
        val_loss,   val_acc   = run_epoch(val_loader,   train=False)
        dt = time.time() - t0

        print(f"Epoch {ep:02d}/{epochs} | "
              f"train_loss={train_loss:.4f} acc={train_acc:.3f} | "
              f"val_loss={val_loss:.4f} acc={val_acc:.3f} | "
              f"{dt:.1f}s")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                "model_state_dict": model.state_dict(),
                "synset_to_idx": synset_to_idx,
                "config": {"CATEGORIES": CATEGORIES, "NUM_CLASSES": NUM_CLASSES},
            }, "best_graphcnn.pt")
            print(f"  ✓ Saved checkpoint (val_acc={best_val_acc:.3f})")

## Inference / Evaluation on a Single Batch

In [7]:
model.eval()
batch = next(iter(val_loader))
meshes = batch["meshes"].to(device)   # now works (Meshes has .to)
labels = batch["labels"].to(device)

with torch.inference_mode():
    logits = model(meshes)
    preds = logits.argmax(dim=1)

print("Pred idx:", preds.tolist())
print("GT idx:  ", labels.tolist())
print("Pred synsets:", [idx_to_synset[i.item()] for i in preds])
print("GT synsets:  ", [idx_to_synset[i.item()] for i in labels])
print("Pred names:  ", [idx_to_name[i.item()]   for i in preds])
print("GT names:    ", [idx_to_name[i.item()]   for i in labels])

Pred idx: [3, 3, 3, 3]
GT idx:   [1, 3, 0, 1]
Pred synsets: ['03211117', '03211117', '03211117', '03211117']
GT synsets:   ['02992529', '03211117', '02808440', '02992529']
Pred names:   ['display', 'display', 'display', 'display']
GT names:     ['cellphone', 'display', 'bathtub', 'cellphone']


In [8]:
print("Model on:", next(model.parameters()).device)
b = next(iter(train_loader))
print("Meshes pre-move:", type(b["meshes"]))
meshes = b["meshes"].to(device)
labels = b["labels"].to(device)
print("Meshes moved to device OK; labels on", labels.device)

Model on: mps:0
Meshes pre-move: <class 'pytorch3d.structures.meshes.Meshes'>
Meshes moved to device OK; labels on mps:0


In [9]:
# --- SMOKE TEST ---

from torch.utils.data import Subset

SMOKE_N = 64  # try 64 first; bump later
train_subset = Subset(train_dataset, list(range(min(SMOKE_N, len(train_dataset)))))
val_subset   = Subset(val_dataset,   list(range(min(SMOKE_N//4, len(val_dataset)))))

train_loader = DataLoader(
    train_subset, batch_size=2, shuffle=True,
    num_workers=0, collate_fn=custom_collate_fn,
)
val_loader = DataLoader(
    val_subset, batch_size=2, shuffle=False,
    num_workers=0, collate_fn=custom_collate_fn,
)
print(len(train_loader), len(val_loader))

32 8


In [10]:
# Use full datasets? (set True when done debugging)
USE_FULL = True

# Per-class caps for the stratified subsets (tune as you like)
TAKE_PER_CLASS_TRAIN = 300   # try 200–500 depending on speed
TAKE_PER_CLASS_VAL   = 100   # make val bigger to stabilize accuracy

rng = random.Random(SEED)

def stratified_indices(ds, take_per_class):
    """Return a balanced list of indices for dataset ds based on 'synset_id'."""
    buckets = defaultdict(list)
    for i in range(len(ds)):
        s = ds[i]["synset_id"]
        buckets[s].append(i)
    sel = []
    for sid in CATEGORIES.keys():
        idxs = buckets.get(sid, [])
        rng.shuffle(idxs)
        sel.extend(idxs[:min(take_per_class, len(idxs))])
    rng.shuffle(sel)
    return sel, {sid: len([i for i in sel if ds[i]["synset_id"] == sid]) for sid in CATEGORIES.keys()}

if USE_FULL:
    train_dataset_ = train_dataset
    val_dataset_   = val_dataset
    train_counts   = {sid: None for sid in CATEGORIES.keys()}
    val_counts     = {sid: None for sid in CATEGORIES.keys()}
else:
    train_sel, train_counts = stratified_indices(train_dataset, TAKE_PER_CLASS_TRAIN)
    val_sel,   val_counts   = stratified_indices(val_dataset,   TAKE_PER_CLASS_VAL)
    train_dataset_ = Subset(train_dataset, train_sel)
    val_dataset_   = Subset(val_dataset,   val_sel)

print("Per-class train counts:", {CATEGORIES[sid]: train_counts[sid] for sid in CATEGORIES})
print("Per-class val counts:  ", {CATEGORIES[sid]: val_counts[sid]   for sid in CATEGORIES})

# Loader params (MPS/CPU-friendly)
BATCH_SIZE = 4
NUM_WORKERS = 0
PIN = (device.type == "cuda")

train_loader = DataLoader(
    train_dataset_, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=NUM_WORKERS, collate_fn=custom_collate_fn, pin_memory=PIN,
)
val_loader = DataLoader(
    val_dataset_, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, collate_fn=custom_collate_fn, pin_memory=PIN,
)

print(f"Train batches: {len(train_loader)} | Val batches: {len(val_loader)} | batch={BATCH_SIZE}")

Per-class train counts: {'bathtub': None, 'cellphone': None, 'clock': None, 'display': None, 'laptop': None}
Per-class val counts:   {'bathtub': None, 'cellphone': None, 'clock': None, 'display': None, 'laptop': None}
Train batches: 779 | Val batches: 195 | batch=4


In [11]:
train_loop(epochs=4)   # try 10–30 on MPS

[train] batch   25/779 | avg 0.39s/batch | ETA 4.9m
[train] batch   50/779 | avg 0.40s/batch | ETA 4.8m
[train] batch   75/779 | avg 0.42s/batch | ETA 5.0m
[train] batch  100/779 | avg 0.40s/batch | ETA 4.5m
[train] batch  125/779 | avg 0.40s/batch | ETA 4.4m
[train] batch  150/779 | avg 0.40s/batch | ETA 4.2m
[train] batch  175/779 | avg 0.40s/batch | ETA 4.1m
[train] batch  200/779 | avg 0.40s/batch | ETA 3.9m
[train] batch  225/779 | avg 0.42s/batch | ETA 3.9m
[train] batch  250/779 | avg 0.43s/batch | ETA 3.8m
[train] batch  275/779 | avg 0.43s/batch | ETA 3.6m
[train] batch  300/779 | avg 0.45s/batch | ETA 3.6m
[train] batch  325/779 | avg 0.45s/batch | ETA 3.4m
[train] batch  350/779 | avg 0.46s/batch | ETA 3.3m
[train] batch  375/779 | avg 0.47s/batch | ETA 3.1m
[train] batch  400/779 | avg 0.48s/batch | ETA 3.0m
[train] batch  425/779 | avg 0.49s/batch | ETA 2.9m
[train] batch  450/779 | avg 0.50s/batch | ETA 2.7m
[train] batch  475/779 | avg 0.51s/batch | ETA 2.6m
[train] batc

In [9]:
ckpt = torch.load("best_graphcnn.pt", map_location=device)
model.load_state_dict(ckpt["model_state_dict"])
print("Loaded model weights from checkpoint.")

# Full validation accuracy
model.eval()
correct = 0
total = 0
with torch.inference_mode():
    for batch in val_loader:
        meshes = batch["meshes"].to(device)
        labels = batch["labels"].to(device)
        logits = model(meshes)
        preds = logits.argmax(1)
        correct += (preds == labels).sum().item()
        total += labels.numel()
val_acc = correct / max(1, total)
print(f"Validation accuracy: {val_acc:.3f}")

Loaded model weights from checkpoint.
Validation accuracy: 0.733


In [10]:
# Confusion matrix

num_classes = NUM_CLASSES
cm = torch.zeros((num_classes, num_classes), dtype=torch.long)
model.eval()
with torch.inference_mode():
    for batch in val_loader:
        meshes = batch["meshes"].to(device)
        labels = batch["labels"]
        preds = model(meshes).argmax(1).cpu()
        for t, p in zip(labels, preds):
            cm[t, p] += 1

print("Confusion matrix rows=GT, cols=Pred:\n", cm)
print("Per-class accuracy:")
for i in range(num_classes):
    denom = cm[i].sum().item()
    acc_i = (cm[i, i].item() / denom) if denom > 0 else float("nan")
    print(f"  {i}: {idx_to_name[i]}  acc={acc_i:.3f}")

Confusion matrix rows=GT, cols=Pred:
 tensor([[145,   0,   0,   9,   7],
        [  0, 137,   4,  33,   0],
        [  8,  33,  45,  45,   1],
        [  6,   9,   3, 198,   7],
        [ 31,   0,   0,   8,  49]])
Per-class accuracy:
  0: bathtub  acc=0.901
  1: cellphone  acc=0.787
  2: clock  acc=0.341
  3: display  acc=0.888
  4: laptop  acc=0.557
