In [38]:
import os, json, argparse, random, sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")) + '/lib/')

import numpy as np
import torch
from torch import nn
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader
import matplotlib.pyplot as plt

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
os.environ["OMP_NUM_THREADS"]      = "1"
os.environ["MKL_NUM_THREADS"]      = "1"

import matplotlib
matplotlib.use("Agg")

from lib.losses_metrics import masked_l1_loss, compute_metrics
from lib.net1 import GNN_Net   

Utility functions

In [39]:

def set_seed(seed=42):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)

def ensure_2d(t: torch.Tensor, feat_dim: int) -> torch.Tensor:
    if t.dim() == 2: return t
    if t.dim() == 1: return t.view(-1, feat_dim)
    raise RuntimeError(f"Unsupported t.dim() = {t.dim()}")


In [40]:
#Load training, validation, and test datasets from disk
def load_lists(train_pt, test_pt, map_json, val_ratio=0.1, seed=42):
    train_val = torch.load(train_pt)
    test_list = torch.load(test_pt)

    with open(map_json, "r", encoding="utf-8") as f:
        meta = json.load(f)
    num_classes = int(meta["num_classes"])


    n_total = len(train_val); n_val = max(1, int(round(n_total * val_ratio)))
    n_train = n_total - n_val
    gen = torch.Generator().manual_seed(seed)
    train_list, val_list = random_split(train_val, [n_train, n_val], generator=gen)
    return train_list, val_list, test_list, num_classes

In [41]:
#Create DataLoaders
def make_loaders(train_list, val_list, test_list, batch_size=32, num_workers=0):
    train_loader = DataLoader(train_list, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=(num_workers>0))
    val_loader   = DataLoader(val_list,   batch_size=batch_size, shuffle=False,
                              num_workers=num_workers, pin_memory=(num_workers>0))
    test_loader  = DataLoader(test_list,  batch_size=batch_size, shuffle=False,
                              num_workers=num_workers, pin_memory=(num_workers>0))
    return train_loader, val_loader, test_loader

In [42]:
#Compute inverse frequency class weights
def compute_class_weights(dataset, num_classes):
    counts = np.zeros(num_classes, dtype=np.int64)
    for d in dataset:
        counts[int(d.y)] += 1
    counts = np.maximum(counts, 1)
    inv = 1.0 / counts
    w = inv * (num_classes / inv.sum())  
    return torch.tensor(w, dtype=torch.float32)

In [43]:
#Compute mean/std
def compute_label_stats(dataset):
    p2, p6, dp = [], [], []
    for d in dataset:
        t = d.t.view(-1) if d.t.dim() > 1 else d.t
        v2, v6, vd = float(t[0]), float(t[1]), float(t[2])
        if np.isfinite(v2): p2.append(v2)
        if np.isfinite(v6): p6.append(v6)
        if np.isfinite(vd): dp.append(vd)
    def _ms(arr):
        if len(arr)==0: return 0.0, 1.0
        return float(np.mean(arr)), float(np.std(arr) + 1e-8)
    m2,s2 = _ms(p2); m6,s6 = _ms(p6); md,sd = _ms(dp)
    return {'pose2':(m2,s2), 'pose6':(m6,s6), 'depth':(md,sd)}

Multitask Losses

In [44]:
def multitask_losses(outputs, batch, lambdas, stats, class_weights=None):
    
    t = ensure_2d(batch.t, 4)

    # targets
    pose_t  = t[:, 0:2]         # (B,2)
    depth_t = t[:, 2:3]         # (B,1)
    y       = batch.y

    # normalize preds & targets
    device = pose_t.device
    pose_mean = torch.tensor([stats['pose2'][0], stats['pose6'][0]], device=device).view(1,2)
    pose_std  = torch.tensor([stats['pose2'][1], stats['pose6'][1]], device=device).view(1,2)
    depth_mean= torch.tensor([stats['depth'][0]], device=device).view(1,1)
    depth_std = torch.tensor([stats['depth'][1]], device=device).view(1,1)

    pose_pred_n  = (outputs['pose']  - pose_mean)  / pose_std
    pose_tgt_n   = (pose_t           - pose_mean)  / pose_std
    depth_pred_n = (outputs['depth'] - depth_mean) / depth_std
    depth_tgt_n  = (depth_t          - depth_mean) / depth_std

    pose_loss  = masked_l1_loss(pose_pred_n,  pose_tgt_n)
    depth_loss = masked_l1_loss(depth_pred_n, depth_tgt_n)

    if class_weights is not None:
        ce = nn.CrossEntropyLoss(weight=class_weights.to(y.device))
    else:
        ce = nn.CrossEntropyLoss()
    cls_loss = ce(outputs['logits'], y)

    total = lambdas['pose']*pose_loss + lambdas['depth']*depth_loss + lambdas['cls']*cls_loss
    return total, {'pose': pose_loss.item(), 'depth': depth_loss.item(), 'cls': cls_loss.item()}

In [45]:
# Disable gradient computation
@torch.no_grad()
def evaluate(model, loader, device, lambdas, stats):
  
    model.eval()
    n_batches = 0; loss_sum = 0.0
    mae_p2 = mae_p6 = mae_depth = acc_sum = 0.0
    n_pose = n_depth = n_cls = 0

    for batch in loader:
        batch = batch.to(device)
        out = model(batch)
        loss, _ = multitask_losses(out, batch, lambdas, stats, class_weights=None)
        loss_sum += float(loss); n_batches += 1

        m = compute_metrics(out, batch)
        mae_p2   += m['mae_pose2_sum'];  mae_p6 += m['mae_pose6_sum']; n_pose  += m['n_pose']
        mae_depth+= m['mae_depth_sum'];  n_depth+= m['n_depth']
        acc_sum  += m['acc_sum'];        n_cls  += m['n_cls']

    return {
        'loss': loss_sum / max(1, n_batches),
        'mae_pose2': mae_p2 / max(1, n_pose),
        'mae_pose6': mae_p6 / max(1, n_pose),
        'mae_depth': mae_depth / max(1, n_depth),
        'acc_cls':   acc_sum / max(1, n_cls),
    }

In [46]:
#Evaluate the model batch-by-batch for test curve plotting
@torch.no_grad()
def evaluate_batches(model, loader, device, lambdas, stats):
   
    model.eval()
    loss_list, mae_p2_list, mae_p6_list, mae_depth_list = [], [], [], []

    for batch in loader:
        batch = batch.to(device)
        out = model(batch)

        # batch loss
        loss, _ = multitask_losses(out, batch, lambdas, stats, class_weights=None)
        loss_list.append(float(loss))

        # batch MAE
        m = compute_metrics(out, batch)
        mae_p2_list.append(m['mae_pose2_sum'] / max(1, m['n_pose']))
        mae_p6_list.append(m['mae_pose6_sum'] / max(1, m['n_pose']))
        mae_depth_list.append(m['mae_depth_sum'] / max(1, m['n_depth']))

    return loss_list, mae_p2_list, mae_p6_list, mae_depth_list


train loop

In [None]:
def train(args):
    set_seed(args.seed)
    train_pt = os.path.join(args.train_dir, "Train_val_data_list.pt")
    test_pt  = os.path.join(args.test_dir,  "Test_data_list.pt")
    map_json = os.path.join(args.train_dir, "material_id_to_idx.json")

    train_list, val_list, test_list, num_classes = load_lists(
        train_pt, test_pt, map_json, val_ratio=args.val_ratio, seed=args.seed
    )

  
    stats = compute_label_stats(train_list)
    
    class_weights = compute_class_weights(train_list, num_classes) if args.class_balanced else None
    if class_weights is not None:
        print("[Class Weights]", class_weights.tolist())

    # loaders
    train_loader, val_loader, test_loader = make_loaders(
        train_list, val_list, test_list, batch_size=args.batch_size, num_workers=args.num_workers
    )

    device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')
    print(f"Using device: {device}")

    # GNN model
    model = GNN_Net(num_classes=num_classes).to(device)

    optim = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=args.epochs)

    lambdas = {'pose': args.lambda_pose, 'depth': args.lambda_depth, 'cls': args.lambda_cls}
    best_val = float('inf'); best_state = None

    os.makedirs(args.out_dir, exist_ok=True)
    train_losses, val_losses = [], []

    for epoch in range(1, args.epochs + 1):
        model.train()
        total_loss = 0.0
        for batch in train_loader:
            batch = batch.to(device)
            optim.zero_grad()
            out = model(batch)
            loss, _ = multitask_losses(out, batch, lambdas, stats, class_weights=class_weights)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optim.step()
            total_loss += float(loss)
        sched.step()

        avg_train = total_loss / max(1, len(train_loader))
        val_res   = evaluate(model, val_loader, device, lambdas, stats)
        train_losses.append(avg_train); val_losses.append(val_res['loss'])


        if val_res['loss'] < best_val:
            best_val = val_res['loss']
            best_state = {k: v.cpu() for k, v in model.state_dict().items()}
            torch.save(best_state, os.path.join(args.out_dir, "best_model.pt"))

    if best_state is not None:
        model.load_state_dict(best_state)

    # TEST
    test_res = evaluate(model, test_loader, device, lambdas, stats)
    print("==== TEST ====")
    print(f"test_loss {test_res['loss']:.4f} | test_acc {test_res['acc_cls']:.3f} | "
          f"test_mae(p2,p6,depth)=({test_res['mae_pose2']:.3f},{test_res['mae_pose6']:.3f},{test_res['mae_depth']:.3f})")


    loss_list, mae_p2_list, mae_p6_list, mae_dp_list = evaluate_batches(
        model, test_loader, device, lambdas, stats
    )
    xs = np.arange(1, len(loss_list) + 1)

    # Plot loss curve
    plt.figure(figsize=(8,6))
    plt.plot(train_losses, label="Train Loss")
    plt.plot(val_losses,   label="Val Loss")
    plt.xlabel("Epoch"); plt.ylabel("Loss")
    plt.title("Training & Validation Loss")
    plt.grid(True); plt.legend()
    out_png = os.path.join(args.out_dir, "loss_curve.png")
    plt.savefig(out_png, dpi=150); plt.close()
    print(f"[Saved] {out_png}")

In [48]:
# Helper: detect Notebook
def _running_in_notebook() -> bool:
    try:
        from IPython import get_ipython
        ip = get_ipython()
        return ip is not None and hasattr(ip, "kernel")
    except Exception:
        return False


Main Function

In [None]:
if __name__ == "__main__":
    p = argparse.ArgumentParser()

    # directories
    p.add_argument("--train_dir", type=str, default=r"..\result\train")
    p.add_argument("--test_dir",  type=str, default=r"..\result\test")
    p.add_argument("--out_dir",   type=str, default=r"..\result\train")
    p.add_argument("--epochs", type=int, default=80)
    p.add_argument("--batch_size", type=int, default=256)
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--wd", type=float, default=1e-4)
    p.add_argument("--grad_clip", type=float, default=1.0)
    p.add_argument("--val_ratio", type=float, default=0.1)
    p.add_argument("--num_workers", type=int, default=0)  
    p.add_argument("--cpu", action="store_true")
    p.add_argument("--seed", type=int, default=42)

    # multi-task loss weights
    p.add_argument("--lambda_pose",  type=float, default=1.0)
    p.add_argument("--lambda_depth", type=float, default=0.7)
    p.add_argument("--lambda_cls",   type=float, default=1.0)

    # class imbalance option 
    p.add_argument("--class_balanced", action="store_true")

    if _running_in_notebook():
        args = p.parse_args(args=[])
        print("[Notebook]")
    else:
        args = p.parse_args()

    train(args)



[Notebook]
Using device: cpu
==== TEST ====
test_loss 0.6251 | test_acc 0.797 | test_mae(p2,p6,depth)=(1.570,0.1057,0.072) 
[Saved] ..\result\train\loss_curve.png
