# Train GCN for Heat Stake Recognition

In [9]:
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

In [10]:
SEED = 60
EPOCHS = 200
BATCH_SIZE = 16
LR = 0.001
DROPOUT = 0.3
VAL_SPLIT = 0.2

random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

print(torch.cuda.get_arch_list())
print(torch.cuda.is_available())
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")


['sm_50', 'sm_60', 'sm_61', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90']
True
Using device: cuda


# Create Dataset

In [None]:
import sys
import cadquery as cq
from pathlib import Path
import torch


# Resolve paths relative to this notebook's folder when possible
BASE_DIR = Path.cwd().parent
if str(BASE_DIR) not in sys.path:
    sys.path.insert(0, str(BASE_DIR))
    
from preprocessing.graphs import build_brep_graph, nx_to_PyG

DATA_DIR = BASE_DIR / "GCN" / "training_data"
HEATSTAKE_DIR = DATA_DIR / "allheatstakes"
OTHER_DIR = DATA_DIR / "allother"
DATASET_FILE = BASE_DIR / "GCN" / "training_ready_dataset.pt"

if True:
    dataset = []

    def iter_step_files(folder: Path):
        return [p for p in folder.rglob('*') if p.suffix.lower() in {'.stp', '.step'}]

    possible_heatstakes = iter_step_files(HEATSTAKE_DIR)
    possible_others = iter_step_files(OTHER_DIR)

    print(f"Found {len(possible_heatstakes)} heatstake STEP files and {len(possible_others)} other STEP files.")
    for heatstake_path in possible_heatstakes:
        solids = cq.importers.importStep(str(heatstake_path)).faces()
        G = build_brep_graph(solids)
        data = nx_to_PyG([G])
        data[0].y = torch.tensor([1], dtype=torch.long)  # class 1 = heatstake
        dataset.append(data[0])
    for other_path in possible_others:
        solids = cq.importers.importStep(str(other_path)).faces()    
        G = build_brep_graph(solids)
        data = nx_to_PyG([G])
        data[0].y = torch.tensor([0], dtype=torch.long)  # class 0 = other
        dataset.append(data[0])

    if len(dataset) == 0:
        print("No graphs were created. Ensure your folders contain .stp/.step files and your preprocessing functions are available.")
    else:
        torch.save(dataset, DATASET_FILE)
        print(f"Saved dataset with {len(dataset)} graphs to {DATASET_FILE}")


Found 150 heatstake STEP files and 158 other STEP files.
Saved dataset with 308 graphs to c:\Users\A01369877\Documents\GM\3d-part-localization\GCN\training_ready_dataset.pt


In [3]:
# Load dataset (expects a single .pt file saved as a list of PyG Data objects)
from pathlib import Path
import torch

BASE_DIR = Path.cwd().parent
DATASET_FILE = BASE_DIR / "GCN" / "training_ready_dataset.pt"

if DATASET_FILE.exists():
    dataset = torch.load(DATASET_FILE, weights_only=False)
    print(f"Loaded dataset with {len(dataset)} graphs from {DATASET_FILE}")
else:
    dataset = []
    print(f"Dataset file not found at {DATASET_FILE}. Add data or build dataset first.")

# Basic sanity check
if len(dataset) > 0:
    assert hasattr(dataset[0], 'x') and hasattr(dataset[0], 'edge_index') and hasattr(dataset[0], 'y'), \
        "Each Data must have x, edge_index, and y"

Loaded dataset with 308 graphs from c:\Users\A01369877\Documents\GM\3d-part-localization\GCN\training_ready_dataset.pt


In [4]:
# Split into train/val and create loaders
if len(dataset) > 0:
    labels = [int(d.y.item()) for d in dataset]
    train_idx, val_idx = train_test_split(
        list(range(len(dataset))),
        test_size=VAL_SPLIT,
        random_state=SEED,
        stratify=labels if len(set(labels)) > 1 else None,
    )
    train_dataset = [dataset[i] for i in train_idx]
    val_dataset = [dataset[i] for i in val_idx]

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
    print(f"Train graphs: {len(train_dataset)} | Val graphs: {len(val_dataset)}")
else:
    train_loader = None
    val_loader = None


Train graphs: 246 | Val graphs: 62


In [5]:
from GCN import GCN2

# Create model, criterion, optimizer
if train_loader is not None:
    in_channels = train_dataset[0].x.size(-1)
    #model = GCN(feature_dim_size=in_channels, num_classes=2, dropout=DROPOUT).to(DEVICE)
    model = GCN2(feature_dim_size=in_channels, num_classes=2, dropout=DROPOUT).to(DEVICE)
    criterion = nn.NLLLoss()
    optimizer = optim.Adam(model.parameters(), lr=LR)

    print(model)
else:
    model = None


GCN2(
  (convs): ModuleList(
    (0): GCNConv(3, 64)
    (1-2): 2 x GCNConv(64, 64)
  )
  (norms): ModuleList(
    (0-2): 3 x LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  )
  (classifier): Sequential(
    (0): Linear(in_features=128, out_features=64, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=64, out_features=2, bias=True)
  )
)


In [6]:
# Train/eval helpers

def train_one_epoch(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for batch in loader:
        # The provided GCN does not aggregate per-graph using the batch vector, so
        # we process each graph in the batch individually.
        data_list = batch.to_data_list()

        optimizer.zero_grad()
        batch_loss = 0.0
        batch_correct = 0
        batch_total = 0

        for data in data_list:
            data = data.to(DEVICE)
            out = model(adj=data.edge_index, features=data.x)  # shape [1, 2]
            loss = criterion(out, data.y.long())
            loss.backward()
            batch_loss += loss.item()

            preds = out.argmax(dim=1)
            batch_correct += int((preds == data.y).sum().item())
            batch_total += data.y.size(0)

        optimizer.step()

        total_loss += batch_loss
        correct += batch_correct
        total += batch_total

    avg_loss = total_loss / max(1, len(loader))
    acc = correct / max(1, total)
    return avg_loss, acc


def evaluate(model, loader, criterion):
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in loader:
            data_list = batch.to_data_list()
            batch_loss = 0.0
            batch_correct = 0
            batch_total = 0
            for data in data_list:
                data = data.to(DEVICE)
                out = model(adj=data.edge_index, features=data.x)
                loss = criterion(out, data.y.long())
                batch_loss += loss.item()
                preds = out.argmax(dim=1)
                batch_correct += int((preds == data.y).sum().item())
                batch_total += data.y.size(0)
            total_loss += batch_loss
            correct += batch_correct
            total += batch_total
    avg_loss = total_loss / max(1, len(loader))
    acc = correct / max(1, total)
    return avg_loss, acc


In [7]:
# Training loop with plots
if train_loader is not None and model is not None:
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
    for epoch in range(1, EPOCHS + 1):
        tr_loss, tr_acc = train_one_epoch(model, train_loader, optimizer, criterion)
        vl_loss, vl_acc = evaluate(model, val_loader, criterion)

        history["train_loss"].append(tr_loss)
        history["train_acc"].append(tr_acc)
        history["val_loss"].append(vl_loss)
        history["val_acc"].append(vl_acc)

        print(f"Epoch {epoch:03d} | Train Loss: {tr_loss:.4f} Acc: {tr_acc:.3f} | Val Loss: {vl_loss:.4f} Acc: {vl_acc:.3f}")

    # Plot loss and accuracy
    fig, axs = plt.subplots(1, 2, figsize=(12, 4))
    axs[0].plot(history["train_loss"], label="train")
    axs[0].plot(history["val_loss"], label="val")
    axs[0].set_title("Loss")
    axs[0].set_xlabel("Epoch")
    axs[0].set_ylabel("NLLLoss")
    axs[0].legend()

    axs[1].plot(history["train_acc"], label="train")
    axs[1].plot(history["val_acc"], label="val")
    axs[1].set_title("Accuracy")
    axs[1].set_xlabel("Epoch")
    axs[1].set_ylabel("Accuracy")
    axs[1].legend()

    plt.tight_layout()
    plt.show()
else:
    print("No dataset loaded. Build or load heatstake_dataset.pt first.")


Epoch 001 | Train Loss: 11.9014 Acc: 0.500 | Val Loss: 10.6723 Acc: 0.516
Epoch 002 | Train Loss: 10.9981 Acc: 0.520 | Val Loss: 10.6824 Acc: 0.516
Epoch 003 | Train Loss: 10.5924 Acc: 0.541 | Val Loss: 10.7756 Acc: 0.516
Epoch 004 | Train Loss: 10.7622 Acc: 0.520 | Val Loss: 10.6867 Acc: 0.516
Epoch 005 | Train Loss: 10.5378 Acc: 0.581 | Val Loss: 10.6342 Acc: 0.516
Epoch 006 | Train Loss: 10.5434 Acc: 0.553 | Val Loss: 10.6443 Acc: 0.516
Epoch 007 | Train Loss: 10.6237 Acc: 0.549 | Val Loss: 10.6476 Acc: 0.516
Epoch 008 | Train Loss: 10.6386 Acc: 0.516 | Val Loss: 10.6375 Acc: 0.516
Epoch 009 | Train Loss: 10.6017 Acc: 0.500 | Val Loss: 10.6684 Acc: 0.516
Epoch 010 | Train Loss: 10.4195 Acc: 0.577 | Val Loss: 10.6237 Acc: 0.516
Epoch 011 | Train Loss: 10.3560 Acc: 0.598 | Val Loss: 10.6285 Acc: 0.516
Epoch 012 | Train Loss: 10.3250 Acc: 0.602 | Val Loss: 10.4921 Acc: 0.516
Epoch 013 | Train Loss: 10.3109 Acc: 0.618 | Val Loss: 10.3939 Acc: 0.516
Epoch 014 | Train Loss: 10.2364 Acc: 0

KeyboardInterrupt: 

In [8]:
torch.save(model.state_dict(), BASE_DIR / "GCN" / "heatstake_classifier.pth")