
# **Graph-Based Neural Network: A Deep Learning Approach for Predicting Milling Process Times**

## Setup and Imports


In [None]:
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-2.6.0+cu124.html  #cu124
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-2.6.0+cu124.html  #cu124
!pip install -q torch-cluster -f https://data.pyg.org/whl/torch-2.6.0+cu124.html
!pip install -q pyg-lib -f https://data.pyg.org/whl/torch-2.6.0+cu124.html
!pip install -q torch-geometric
!pip install -q trimesh
!pip install -q fast_simplification
!pip install -q wandb
!pip install reportlab
!pip install optuna plotly
!pip install pymeshlab
!pip install optuna-dashboard

In [2]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool, FeaStConv
import torch.nn as nn
from torch_geometric.data import Data, Dataset, Batch
from torch_geometric.transforms import FaceToEdge
from torch_geometric.loader import DataLoader
from torch_geometric.utils import from_trimesh
import torch_geometric.transforms as T
import trimesh
import os
import random
import pandas as pd
from tqdm import tqdm
import time
import copy
from datetime import datetime
import numpy as np
import pickle
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
from sklearn.metrics import r2_score, mean_absolute_error
from sklearn.model_selection import KFold
import csv
from matplotlib.colors import Normalize
import wandb
import json
from torch_geometric.loader.dataloader import Collater
from sklearn.ensemble import RandomForestRegressor
from sklearn.svm import SVR
from sklearn.tree import DecisionTreeRegressor
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import pymeshlab

## Dataset definition

In [3]:
def compute_shape_features_milling(mesh_obj, theta_deg=45.0):
    """
    Compute scalar shape descriptors of a mesh for 3-axis milling.

    The function works on a copy of ``mesh_obj`` and is written to tolerate
    non-watertight or otherwise defective meshes: every guarded geometric
    query is replaced by a neutral default when it raises.

    Args:
        mesh_obj: Mesh object exposing the trimesh-style attributes used below
            (``copy``, ``area``, ``volume``, ``convex_hull``, ``extents``,
            ``bounds``, ``face_normals``, ...).
        theta_deg: Angle threshold in degrees. A face counts as "steep" for a
            given direction when the angle between its normal and that
            direction exceeds this threshold.

    Returns:
        dict: Flat mapping from feature name to a Python float, plus the
        boolean diagnostic ``used_convex_hull_volume``.
    """
    m = mesh_obj.copy()

    # Minimal cleaning for robustness. Each step runs on its own so that a
    # failing (or missing) method does not silently skip the remaining ones.
    cleanup_steps = (
        "remove_degenerate_faces",
        "remove_duplicate_faces",
        "remove_unreferenced_vertices",
        "remove_infinite_values",
        "rezero",
    )
    for step_name in cleanup_steps:
        try:
            getattr(m, step_name)()
        except Exception:
            # Best effort: cleaning is optional, the features below are
            # individually guarded.
            continue

    # --- Volume (fallback to the convex hull when the mesh volume is unusable) ---
    used_hull = False
    try:
        volume = float(m.volume)
        if not np.isfinite(volume) or volume <= 0:
            raise ValueError
    except Exception:
        used_hull = True
        try:
            volume = float(m.convex_hull.volume)
        except Exception:
            volume = 0.0

    area = float(m.area) if np.isfinite(m.area) else 0.0

    # --- Ratios and shape indices ---
    has_size = area > 0 and volume > 0
    sa_to_vol = (area / volume) if volume > 0 else 0.0
    # Isoperimetric quotient: 1 for a sphere, smaller for less compact shapes.
    isoperimetric_q = (36.0 * np.pi * (volume ** 2) / (area ** 3)) if has_size else 0.0
    # Inverse of the isoperimetric quotient, kept as an explicit feature.
    compactness_inv = (area ** 3) / (36.0 * np.pi * (volume ** 2)) if has_size else 0.0

    # --- Bounding box and aspect ratios ---
    try:
        bbox_l, bbox_w, bbox_h = [float(x) for x in m.extents]
    except Exception:
        bbox_l, bbox_w, bbox_h = 0.0, 0.0, 0.0
    ext = np.array([bbox_l, bbox_w, bbox_h], dtype=float)
    ext_ok = bool(np.all(np.isfinite(ext)))
    max_e = float(np.max(ext)) if ext_ok else 0.0
    min_e = float(np.min(ext)) if ext_ok else 0.0
    med_e = float(np.median(ext)) if ext_ok else 0.0
    aspect_max_min = (max_e / (min_e + 1e-9)) if min_e > 0 else 0.0
    aspect_max_med = (max_e / (med_e + 1e-9)) if med_e > 0 else 0.0

    # --- Principal inertia and radii of gyration ---
    try:
        I = np.array(m.principal_inertia_components, dtype=float)
        I.sort()  # ascending order
    except Exception:
        I = np.zeros(3, dtype=float)
    mass = volume  # unit density, so mass is taken equal to the volume
    if mass > 0:
        radii_g = np.sqrt(np.clip(I / (mass + 1e-12), 0.0, np.inf))
    else:
        radii_g = np.zeros(3, dtype=float)

    # --- Center of mass relative to the bounding box ---
    try:
        com = np.array(m.center_mass, dtype=float)
        bbox_min, bbox_max = m.bounds
        bbox_size = np.maximum(bbox_max - bbox_min, 1e-9)  # avoid division by zero
        com_rel = (com - bbox_min) / bbox_size
    except Exception:
        com_rel = np.zeros(3, dtype=float)

    # --- Convexity: ratio between the volume above and the convex hull volume ---
    try:
        hull_vol = float(m.convex_hull.volume)
        convexity = (volume / hull_vol) if hull_vol > 0 else 1.0
    except Exception:
        convexity = 1.0

    # --- Face normals and areas (for the steepness fractions) ---
    try:
        normals = m.face_normals  # (F, 3)
        areas = m.area_faces      # (F,)
        total_area = float(np.sum(areas))
    except Exception:
        normals = np.zeros((0, 3), dtype=float)
        areas = np.zeros((0,), dtype=float)
        total_area = 0.0

    def steep_fraction(direction, theta_rad):
        """Area fraction of faces whose normal is more than theta away from direction."""
        if total_area <= 0 or normals.shape[0] == 0:
            return 0.0
        c = normals @ direction          # cosine of the angle to the direction
        mask = c < np.cos(theta_rad)     # angle > theta
        return float(np.sum(areas[mask])) / total_area

    theta = np.deg2rad(theta_deg)
    axes = np.eye(3, dtype=float)
    dirs = np.vstack([axes, -axes])  # +X, +Y, +Z, -X, -Y, -Z
    if total_area > 0:
        steep_fracs = np.array([steep_fraction(d, theta) for d in dirs])
    else:
        steep_fracs = np.zeros(6, dtype=float)
    steep_best = float(np.min(steep_fracs)) if steep_fracs.size else 0.0
    steep_worst = float(np.max(steep_fracs)) if steep_fracs.size else 0.0

    # Row 2 of `dirs` is +Z; kept as a separate feature (former "overhang").
    steep_posZ = steep_fracs[2] if steep_fracs.size >= 3 else 0.0

    # --- Fraction of face adjacencies whose angle is below 30 degrees ---
    # NOTE(review): a small angle between adjacent faces means they are nearly
    # coplanar, so despite the feature name this looks like the share of
    # *smooth* adjacencies rather than sharp ones — confirm the intent.
    try:
        dihed = m.face_adjacency_angles  # radians
        if dihed is not None and len(dihed) > 0:
            sharp_fraction = float(np.mean(dihed < np.deg2rad(30.0)))
        else:
            sharp_fraction = 0.0
    except Exception:
        sharp_fraction = 0.0

    # Pack everything into a flat dictionary (no nested arrays)
    return {
        # base
        'volume': float(volume),
        'surface_area': float(area),
        'sa_to_vol': float(sa_to_vol),
        'isoperimetric_q': float(isoperimetric_q),
        'compactness_inv': float(compactness_inv),

        # bbox
        'bbox_len': float(bbox_l),
        'bbox_wid': float(bbox_w),
        'bbox_hei': float(bbox_h),
        'aspect_max_min': float(aspect_max_min),
        'aspect_max_med': float(aspect_max_med),

        # inertia and radii
        'principal_inertia_0': float(I[0]),
        'principal_inertia_1': float(I[1]),
        'principal_inertia_2': float(I[2]),
        'radius_gyr_0': float(radii_g[0]),
        'radius_gyr_1': float(radii_g[1]),
        'radius_gyr_2': float(radii_g[2]),

        # relative center of mass
        'center_mass_x_rel': float(com_rel[0]),
        'center_mass_y_rel': float(com_rel[1]),
        'center_mass_z_rel': float(com_rel[2]),

        # convexity
        'convexity': float(convexity),

        # milling-oriented steepness
        'steep_frac_best_axis': float(steep_best),
        'steep_frac_worst_axis': float(steep_worst),
        'steep_frac_posZ': float(steep_posZ),

        # face-adjacency angle statistic
        'sharp_edge_fraction': float(sharp_fraction),

        # diagnostics
        'used_convex_hull_volume': bool(used_hull),
    }


def normalize_only(mesh):
    """
    Return a centered, unit-radius copy of a mesh together with its scale.

    The input mesh is left untouched. The copy is translated so that the mean
    of its vertices sits at the origin and, when the mesh is not collapsed to
    a single point, divided by the largest vertex distance from the origin.

    Args:
        mesh (trimesh.Trimesh)
    Returns:
        (mesh_normalized, scale_factor), where scale_factor is the largest
        vertex distance from the origin measured after centering.
    """
    normalized = mesh.copy()

    # Translate: vertex mean goes to the origin
    normalized.vertices -= normalized.vertices.mean(axis=0)

    # Scale: farthest vertex ends up on the unit sphere
    radius = np.linalg.norm(normalized.vertices, axis=1).max()
    if radius > 0:
        normalized.vertices /= radius

    return normalized, radius


def z_score_norm_train(labels_df, target_cols=None, exclude_cols=None):
    """
    Z-score every numeric column of a DataFrame except the listed ones.

    Args:
        labels_df: Input DataFrame; it is copied, not modified.
        target_cols: Column names to leave untouched (targets).
        exclude_cols: Further column names to leave untouched.

    Returns:
        (normalized_df, stats) where ``stats[col]`` holds the ``'mean'`` and
        ``'std'`` that were applied to ``col``. A standard deviation that is
        not strictly positive is replaced by 1.0.
    """
    exclude_cols = [] if exclude_cols is None else exclude_cols
    target_cols = [] if target_cols is None else target_cols
    skipped = exclude_cols + target_cols

    normalized = labels_df.copy()
    stats = {}

    for col in normalized.columns:
        if col in skipped or not pd.api.types.is_numeric_dtype(normalized[col]):
            continue

        series = normalized[col]
        mean_val = series.mean()
        std_val = series.std()
        if not std_val > 0:
            # constant (or single-row) column: avoid dividing by zero/NaN
            std_val = 1.0

        normalized[col] = (series - mean_val) / std_val
        stats[col] = {'mean': mean_val, 'std': std_val}

    return normalized, stats


#train_df, stats = z_score_norm_train(csv_train, target_cols=["tempo_minuti"])
# test_df = z_score_norm_test(csv_test, stats)

def z_score_norm_test(test_data, train_stats, exclude_cols=None):
    """
    Normalize a DataFrame using previously computed training statistics.

    Args:
        test_data: DataFrame to normalize; it is copied, not modified.
        train_stats: Mapping ``column -> {'mean': ..., 'std': ...}`` as
            returned by ``z_score_norm_train``.
        exclude_cols: Optional column names that must be left untouched even
            if statistics for them are present. STLDataset passes this
            argument, so it has to be accepted here.

    Returns:
        The normalized copy. Columns without statistics, and statistics
        without a matching column, are ignored.
    """
    skipped = set(exclude_cols) if exclude_cols is not None else set()
    df = test_data.copy()
    for col, st in train_stats.items():
        if col in skipped or col not in df.columns:
            continue
        df[col] = (df[col] - st['mean']) / st['std']
    return df

In [4]:
class STLDataset(Dataset):
    """
    Dataset that pairs one processed mesh graph per piece with tabular rows.

    Each row of the labels CSV becomes one sample. ``get`` returns the graph
    stored for the row's ``pezzo_id`` together with the row's stock volume,
    its process index and the log-transformed time target.

    Args:
        root: Dataset root folder (default parent of ``raw``/``processed``).
        printcfg_labels_path: Path of the ``;``-separated labels CSV.
        raw_dir: Optional override for the folder holding the original meshes.
        raw_simplified_dir: Optional override for the folder holding the
            simplified meshes used to build the graphs (defaults to raw_dir).
        processed_dir: Optional override for the folder of processed ``.pt`` files.
        transform, pre_transform, pre_filter: Forwarded to the base ``Dataset``.
        raw_extension: Stored on the instance.
            NOTE(review): it is currently not used when building
            ``raw_file_names``, which hard-codes ``.stl``.
        dataset_norm_stats: ``None`` to compute normalization statistics from
            this data (training mode), or statistics from a training dataset
            to apply them unchanged.
        exclude_features: Mesh feature names to drop from ``active_features``.
        exclude_pieces: ``pezzo_id`` values whose rows are removed.
    """

    def __init__(self, root, printcfg_labels_path, raw_dir=None, raw_simplified_dir=None,
                 processed_dir=None, transform=None, pre_transform=None, pre_filter=None,
                 raw_extension='.stl', dataset_norm_stats=None,
                 exclude_features=None, exclude_pieces=None):

        self.exclude_features = exclude_features or []
        self.exclude_pieces = set(str(p) for p in (exclude_pieces or []))  # pieces to exclude

        # Load CSV
        self.print_cfg_labels = pd.read_csv(printcfg_labels_path, sep=";")

        # Piece identifiers are handled as strings everywhere
        if 'pezzo_id' in self.print_cfg_labels.columns:
            self.print_cfg_labels['pezzo_id'] = self.print_cfg_labels['pezzo_id'].astype(str)

        # Remove excluded pieces
        if self.exclude_pieces:
            keep = ~self.print_cfg_labels['pezzo_id'].isin(self.exclude_pieces)
            self.print_cfg_labels = self.print_cfg_labels[keep]

        # Encode the process ID as an integer category code (for the embedding)
        if 'id operazione' in self.print_cfg_labels.columns:
            original_cats = self.print_cfg_labels['id operazione'].astype('category')
            self.proc_id_to_name = dict(enumerate(original_cats.cat.categories))
            self.print_cfg_labels['id operazione'] = original_cats.cat.codes
            self.n_processi = len(self.proc_id_to_name)

        # Normalization of tabular features (excluding time and process)
        exclude_cols = ["time", "id operazione"]
        if dataset_norm_stats is None:
            self.print_cfg_labels, self.dataset_norm_stats = z_score_norm_train(
                self.print_cfg_labels,
                target_cols=["time"],
                exclude_cols=exclude_cols
            )
            self.training_dataset = True
        else:
            # z_score_norm_test takes only (data, stats); the excluded columns
            # are therefore removed from the statistics before the call
            # instead of being passed as an unsupported keyword argument.
            applicable_stats = {
                col: st for col, st in dataset_norm_stats.items()
                if col not in exclude_cols
            }
            self.print_cfg_labels = z_score_norm_test(
                self.print_cfg_labels,
                applicable_stats
            )
            self.dataset_norm_stats = dataset_norm_stats
            self.training_dataset = False

        # List of all available mesh features
        self.all_features = [
            'volume', 'surface_area', 'sa_to_vol', 'isoperimetric_q', 'compactness_inv',
            'bbox_len', 'bbox_wid', 'bbox_hei', 'aspect_max_min', 'aspect_max_med',
            'principal_inertia_0', 'principal_inertia_1', 'principal_inertia_2',
            'radius_gyr_0', 'radius_gyr_1', 'radius_gyr_2',
            'center_mass_x_rel', 'center_mass_y_rel', 'center_mass_z_rel',
            'convexity', 'steep_frac_best_axis', 'steep_frac_worst_axis', 'steep_frac_posZ',
            'sharp_edge_fraction', 'used_convex_hull_volume', 'scaling_factor',
            'mrv', 'removal_ratio', 'sa_to_vol_finale'
        ]

        self.active_features = [f for f in self.all_features if f not in self.exclude_features]

        # Directories. These must be set before super().__init__, which reads
        # the directory and file-name properties below.
        self._custom_raw_dir = raw_dir
        self._custom_raw_simplified_dir = raw_simplified_dir
        self._custom_processed_dir = processed_dir
        self.raw_extension = raw_extension
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_dir(self):
        """Folder with the original meshes (custom path or ``<root>/raw``)."""
        return self._custom_raw_dir or os.path.join(self.root, 'raw')

    @property
    def raw_file_names(self):
        """Sorted ``.stl`` file names, one per distinct ``pezzo_id``."""
        filenames = self.print_cfg_labels['pezzo_id'].unique()
        stl_filenames = [f"{fn}.stl" if not fn.endswith('.stl') else fn for fn in filenames]
        return sorted(stl_filenames)

    @property
    def processed_file_names(self):
        """
        Expected ``.pt`` file names.

        Read from ``processed_files.txt`` in the processed folder when that
        file exists, otherwise derived from ``raw_file_names``.
        """
        processed_files_path = os.path.join(self.processed_dir, "processed_files.txt")
        if os.path.exists(processed_files_path):
            with open(processed_files_path, 'r') as f:
                processed_files = [line.strip() for line in f.readlines()]
            return [f'{os.path.splitext(filename)[0]}.pt' for filename in processed_files]
        else:
            return [f'{os.path.splitext(filename)[0]}.pt' for filename in self.raw_file_names]

    @property
    def raw_simplified_dir(self):
        """Folder with the simplified meshes (custom path or ``raw_dir``)."""
        return self._custom_raw_simplified_dir or self.raw_dir

    @property
    def processed_dir(self):
        """Folder with the processed graphs (custom path or ``<root>/processed``)."""
        return self._custom_processed_dir or os.path.join(self.root, 'processed')

    def process(self):
        """
        Build and save one graph per mesh file listed in ``raw_file_names``.

        For each file the shape features are computed on the normalized
        original mesh, the graph is built from the simplified mesh, and the
        active features are z-scored and attached as ``data.mesh_feat``.
        Files that raise are logged to ``processing_errors.log`` and skipped.
        """
        print('Processing...')
        # NOTE(review): this list is filled but never written out, although
        # `processed_file_names` looks for a `processed_files.txt` — confirm
        # whether the file is supposed to be produced here.
        processed_files = []
        data_list = []
        all_features_dict = {feat: [] for feat in self.active_features}

        for stl_path in tqdm(self.raw_file_names, desc="Processing STL files"):
            try:
                # Original STL (for normalization and volumes)
                raw_path = os.path.join(self.raw_dir, stl_path)
                trimesh_raw = trimesh.load_mesh(raw_path)
                result = normalize_only(trimesh_raw)
                if result is None or result[0] is None:
                    print(f"Skip {stl_path}: normalize_only returned None")
                    continue
                trimesh_raw_norm, scaling_factor = result

                mesh_feat = compute_shape_features_milling(trimesh_raw_norm)
                mesh_feat['scaling_factor'] = float(scaling_factor or 0.0)

                # Stock volume from the CSV (first row of this piece).
                # NOTE(review): `print_cfg_labels` was already z-scored in
                # __init__ and the mesh volume below comes from the
                # unit-radius normalized mesh, so `mrv`/`removal_ratio` do
                # not combine two quantities in mm3 — confirm the intent.
                pezzo_id = os.path.splitext(os.path.basename(stl_path))[0]
                row = self.print_cfg_labels[self.print_cfg_labels['pezzo_id'] == pezzo_id].iloc[0]
                vol_grezzo = float(row['volume grezzo [mm3]'])

                # Derived features: removed volume, removal ratio, final SA/V
                vol_finale = mesh_feat.get("volume", 0.0)
                sa_finale = mesh_feat.get("surface_area", 0.0)
                mrv = max(0.0, vol_grezzo - vol_finale)
                removal_ratio = mrv / (vol_grezzo + 1e-9)
                sa_to_vol_finale = sa_finale / (vol_finale + 1e-9)

                mesh_feat["mrv"] = mrv
                mesh_feat["removal_ratio"] = removal_ratio
                mesh_feat["sa_to_vol_finale"] = sa_to_vol_finale

                # Save for global normalization
                for feat in self.active_features:
                    all_features_dict[feat].append(mesh_feat.get(feat, 0.0))

                # Load simplified mesh
                simp_path = os.path.join(self.raw_simplified_dir, stl_path)
                trimesh_simplified = trimesh.load_mesh(simp_path)
                trimesh_simplified.remove_degenerate_faces()
                trimesh_simplified.remove_duplicate_faces()
                trimesh_simplified.remove_unreferenced_vertices()

                data = from_trimesh(trimesh_simplified)
                if data is None or data.x is None:
                    # Build the graph by hand: vertices as node features and
                    # the unique undirected triangle edges as connectivity.
                    verts = trimesh_simplified.vertices
                    faces = trimesh_simplified.faces
                    edges = np.vstack([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]])
                    edges = np.unique(np.sort(edges, axis=1), axis=0)
                    edge_index = torch.tensor(edges.T, dtype=torch.long)
                    x = torch.tensor(verts, dtype=torch.float)
                    data = Data(x=x, edge_index=edge_index, face=torch.tensor(faces.T, dtype=torch.long))

                data.pos = data.x.clone()
                data.face_normals = torch.tensor(trimesh_simplified.face_normals, dtype=torch.float)
                data.mesh_feat = mesh_feat

                filename = os.path.basename(stl_path)
                data.filename = os.path.splitext(filename)[0]
                data_list.append(data)
                processed_files.append(filename)
            except Exception as e:
                with open("processing_errors.log", 'a') as f:
                    f.write(f"Error processing {stl_path}: {str(e)}\n")
                continue

        # Compute normalization statistics of the mesh features
        feature_stats = {}
        for feature in self.active_features:
            values = np.array(all_features_dict[feature])
            feature_stats[feature] = {'mean': values.mean(), 'std': values.std() if values.std() > 0 else 1.0}

        # Only a training dataset contributes its own mesh-feature statistics;
        # otherwise the provided statistics must already contain them.
        if self.training_dataset:
            self.dataset_norm_stats.update(feature_stats)

        # Normalize features and save one file per piece
        for idx, data in enumerate(data_list):
            for feature in self.active_features:
                data.mesh_feat[feature] = (
                    data.mesh_feat.get(feature, 0.0) - self.dataset_norm_stats[feature]['mean']
                ) / self.dataset_norm_stats[feature]['std']

            mesh_feat_values = [data.mesh_feat[f] for f in self.active_features]
            data.mesh_feat = torch.tensor(mesh_feat_values, dtype=torch.float32)
            torch.save(data, os.path.join(self.processed_dir, f'{data.filename}.pt'))

    def len(self):
        """Number of samples, i.e. rows of the (filtered) labels table."""
        return len(self.print_cfg_labels)

    def get(self, idx):
        """
        Return the sample for row ``idx`` of the labels table.

        Returns:
            (data, volume_grezzo, proc_id, labels): the stored graph, the
            stock volume as a float tensor of shape (1,), the process index
            as a 0-dim long tensor, and ``log1p(time)`` as a float tensor of
            shape (1,).
        """
        row = self.print_cfg_labels.iloc[idx]
        filename = row['pezzo_id']

        # Stock volume (as stored in the normalized labels table)
        volume_grezzo = torch.tensor([float(row['volume grezzo [mm3]'])], dtype=torch.float32)

        # Process as integer index (embedding)
        proc_id = torch.tensor(int(row['id operazione']), dtype=torch.long)

        # Target: log1p of the time. No per-process z-score is applied here.
        t_log = np.log1p(float(row["time"]))
        labels = torch.tensor([t_log], dtype=torch.float32)

        # Load corresponding graph
        data_path = os.path.join(self.processed_dir, f"{os.path.splitext(filename)[0]}.pt")
        # weights_only=False unpickles arbitrary objects: only load files
        # produced by `process`.
        data = torch.load(data_path, weights_only=False)

        return data, volume_grezzo, proc_id, labels

## Simplification of STL

In [None]:
# NOTE(review): this cell points at a different base folder and output folder
# ("raw_simplified2") than the dataset cell further down ("raw_simplified") —
# confirm which location the dataset is expected to read.
base_path = "/content/gdrive/MyDrive/GNN_Def"

root_dir = base_path
raw_dir = os.path.join(base_path, "raw")
simplified_dir = os.path.join(base_path, "raw_simplified2")

target_faces = 1024               # Desired number of faces in the simplified mesh

# Create output directory if it doesn't exist
os.makedirs(simplified_dir, exist_ok=True)

# Simplification loop (sorted so that the log order is reproducible)
for filename in sorted(os.listdir(raw_dir)):
    if not filename.lower().endswith(".stl"):
        continue  # Skip non-STL files

    input_path = os.path.join(raw_dir, filename)
    output_path = os.path.join(simplified_dir, filename)

    try:
        # --- Load the mesh ---
        ms = pymeshlab.MeshSet()
        ms.load_new_mesh(input_path)

        # --- Simplify using quadric edge collapse ---
        ms.meshing_decimation_quadric_edge_collapse(
            targetfacenum=target_faces,
            preservenormal=True,
            preservetopology=True,
            qualitythr=0.3
        )

        # The decimation is not guaranteed to reach the target, so report the
        # face count that was actually obtained.
        n_faces = ms.current_mesh().face_number()

        # --- Save the simplified mesh ---
        ms.save_current_mesh(output_path)
        print(f"Simplified: {filename} → {n_faces} faces "
              f"(target {target_faces}, saved to {output_path})")

    except Exception as e:
        # Best effort: report the failure and continue with the next file
        print(f"Error simplifying {filename}: {e}")

print("\n Simplification complete!")

## Path and Dataset Creation

In [12]:
# ===============================================================================================
# Files and directories. Adapt `base_path` to your environment; the commented
# lines below mount Google Drive when running on Google Colab.
# ===============================================================================================

#from google.colab import drive
#drive.mount('/content/gdrive', force_remount=True)

base_path = "/GNN_Milling"
root_dir = base_path

# Every other location is a direct child of `base_path`.
raw_dir, raw_simplified_dir, processed_dir, labels_path, save_dir = (
    os.path.join(base_path, leaf)
    for leaf in ("raw", "raw_simplified", "processed", "dataset.csv", "results")
)

In [6]:
# Full dataset. The listed mesh features are dropped from `active_features`
# and all rows of piece '4851D5091' are removed.
orig_dataset = STLDataset(
    root=base_path,
    printcfg_labels_path=labels_path,
    raw_dir=raw_dir,
    raw_simplified_dir=raw_simplified_dir,
    processed_dir=processed_dir,
    exclude_features=[
        'volume',
        'bbox_len', 'bbox_wid', 'bbox_hei',
        'aspect_max_min', 'aspect_max_med',
        'principal_inertia_0', 'principal_inertia_1', 'principal_inertia_2',
        'center_mass_x_rel', 'center_mass_y_rel', 'center_mass_z_rel',
        'steep_frac_best_axis', 'steep_frac_worst_axis', 'steep_frac_posZ',
        'sharp_edge_fraction',
        'used_convex_hull_volume',
        'scaling_factor',
    ],
    exclude_pieces=['4851D5091'],
)

## Neural Network

In [7]:
class FeaStNet_TimePred(nn.Module):
    """
    FeaStConv graph encoder fused with a tabular encoder for time regression.

    The mesh graph is embedded by three FeaStConv layers and mean-pooled per
    graph. The per-mesh features, the stock volume and a learned process
    embedding are encoded separately; both parts are concatenated and passed
    through a two-layer regression head.

    Args:
        in_channels: Number of input node features.
        out_channels: Number of regression outputs.
        heads: Number of attention heads of each FeaStConv layer.
        hidden_dim: Width of the tabular encoder output.
        dropout: Dropout probability in the regression head.
        n_processi: Number of distinct process IDs (embedding rows).
        proc_emb_dim: Size of the process embedding.
        n_mesh_feat: Number of per-mesh tabular features. When ``None`` the
            module-level ``N_MESH_FEAT`` is used, which must then be defined
            before the model is instantiated.
    """

    def __init__(self, in_channels=3, out_channels=1, heads=11,
                 hidden_dim=128, dropout=0.2,
                 n_processi=15, proc_emb_dim=8, n_mesh_feat=None):
        super(FeaStNet_TimePred, self).__init__()

        if n_mesh_feat is None:
            # Backward-compatible fallback to the notebook-level constant
            n_mesh_feat = N_MESH_FEAT

        # --- GNN ---
        self.fc0 = nn.Linear(in_channels, 16)
        self.conv1 = FeaStConv(16, 32, heads=heads)
        self.conv2 = FeaStConv(32, 64, heads=heads)
        self.conv3 = FeaStConv(64, 128, heads=heads)

        # --- Process embedding ---
        self.proc_emb = nn.Embedding(n_processi, proc_emb_dim)

        # --- Tabular encoder (mesh features + stock volume + process embedding) ---
        self.mesh_feat_encoder = nn.Sequential(
            nn.Linear(n_mesh_feat + 1 + proc_emb_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

        # --- Fusion + regression head ---
        self.norm_concat = nn.LayerNorm(128 + hidden_dim)
        self.fc1 = nn.Linear(128 + hidden_dim, 256)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(256, out_channels)

    def forward(self, pos, edge_index, batch, mesh_feat, volume_grezzo, proc_id):
        """
        Predict one output vector per graph of the batch.

        Args:
            pos: Node features fed to the first linear layer.
            edge_index: Graph connectivity passed to the FeaStConv layers.
            batch: Node-to-graph assignment used for mean pooling.
            mesh_feat: Per-graph tabular features; reshaped to (B, -1) when
                given flat or with a leading size different from B.
            volume_grezzo: Stock volume; a 1-D input is expanded to (B, 1).
            proc_id: Process indices for the embedding lookup.

        Returns:
            Tensor with one row per graph and ``out_channels`` columns.
        """
        # --- GNN embedding ---
        x = F.elu(self.fc0(pos))
        x = F.elu(self.conv1(x, edge_index))
        x = F.elu(self.conv2(x, edge_index))
        x = F.elu(self.conv3(x, edge_index))
        x = global_mean_pool(x, batch)  # (B, 128)

        B = x.size(0)

        # --- Fix tensor shapes ---
        # Batching may concatenate the per-graph feature vectors into one
        # flat tensor; bring them back to one row per graph.
        if mesh_feat.dim() == 1 or mesh_feat.size(0) != B:
            mesh_feat = mesh_feat.view(B, -1)

        if volume_grezzo.dim() == 1:
            volume_grezzo = volume_grezzo.unsqueeze(1)

        # --- Process embedding ---
        proc_vec = self.proc_emb(proc_id)  # (B, proc_emb_dim)

        # --- Concatenate tabular features ---
        extra_inputs = torch.cat([mesh_feat, volume_grezzo, proc_vec], dim=-1)
        feat_encoded = self.mesh_feat_encoder(extra_inputs)

        # --- Fusion with GNN output ---
        out = torch.cat([x, feat_encoded], dim=-1)
        out = self.norm_concat(out)
        out = F.elu(self.fc1(out))
        out = self.dropout(out)
        out = self.fc2(out)

        return out

## LOPO (Leave-One-Piece-Out) Evaluation

In [8]:
# ==============================
# GLOBAL PARAMETERS
# ==============================
# Number of per-mesh tabular features left after `exclude_features`;
# FeaStNet_TimePred reads this global to size its tabular encoder.
N_MESH_FEAT = len(orig_dataset.active_features)
# Output folder for results, created if it does not exist yet.
RESULTS_DIR = save_dir
os.makedirs(RESULTS_DIR, exist_ok=True)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Start time of this run as YYYYMMDD_HHMMSS (presumably used to tag outputs).
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Seed for reproducibility
def set_seed(seed=42):
    """
    Seed every random number source used here and make cuDNN deterministic.

    Sets ``PYTHONHASHSEED``, seeds ``random``, NumPy and PyTorch (CPU and all
    CUDA devices), and switches cuDNN to deterministic, non-benchmark mode.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
# Seed all random number generators once, right after the definition.
set_seed(42)

# Fixed generator for DataLoader: a dedicated torch.Generator with its own
# seed, passed to the loaders via their `generator` argument.
g = torch.Generator()
g.manual_seed(42)

# ==============================
# METRICS
# ==============================
def compute_metrics(preds, labels):
    """
    Compute regression metrics between two tensors.

    Both inputs are detached, moved to the CPU and converted to NumPy first.

    Returns:
        (mse, rmse, mae, mape, wape, r2); MAPE and WAPE are percentages and
        use a 1e-8 term in the denominator to avoid division by zero.
    """
    y_pred = preds.detach().cpu().numpy()
    y_true = labels.detach().cpu().numpy()

    residual = y_true - y_pred

    mse = np.mean((y_pred - y_true) ** 2)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(y_true, y_pred)
    mape = np.mean(np.abs(residual / (y_true + 1e-8))) * 100
    wape = (np.sum(np.abs(residual)) / (np.sum(np.abs(y_true)) + 1e-8)) * 100
    r2 = r2_score(y_true, y_pred)

    return mse, rmse, mae, mape, wape, r2

# ==============================
# LABEL NORMALIZATION BY PROCESS
# ==============================
def build_proc_stats(dataset):
    """
    Return mean/std of the labels per process ID.

    ``dataset`` yields 4-tuples whose third element is the process ID and
    whose fourth element is the (log) label, both as single-value tensors.
    Pass the training samples only, so that no held-out label leaks into the
    statistics. A standard deviation of at most 1e-8 is replaced by 1.0.
    """
    values_by_proc = {}
    for _graph, _volume, proc_id, labels in dataset:
        label_value = labels.item()  # log
        values_by_proc.setdefault(int(proc_id.item()), []).append(label_value)

    stats = {}
    for pid, values in values_by_proc.items():
        arr = np.array(values)
        spread = arr.std()
        stats[pid] = {
            "mean": arr.mean(),
            "std": spread if spread > 1e-8 else 1.0,
        }
    return stats

def normalize_labels(labels, proc_id, proc_stats):
    """
    Z-score each label with the mean/std of its own process.

    ``proc_id`` is iterated element-wise and every element is looked up in
    ``proc_stats``; the statistics are broadcast as (N, 1) float32 columns on
    the device of ``labels``.
    """
    per_sample = [proc_stats[int(pid.item())] for pid in proc_id]
    column = dict(device=labels.device, dtype=torch.float32)

    mean = torch.tensor([entry["mean"] for entry in per_sample], **column).view(-1, 1)
    std = torch.tensor([entry["std"] for entry in per_sample], **column).view(-1, 1)

    return (labels - mean) / std

def denormalize_preds(preds, proc_id, proc_stats):
    """
    Undo the per-process z-score: ``preds * std + mean``.

    Every element of ``proc_id`` is converted with ``int`` and looked up in
    ``proc_stats``; the statistics are broadcast as (N, 1) float32 columns on
    the device of ``preds``.
    """
    per_sample = [proc_stats[int(pid)] for pid in proc_id]
    column = dict(device=preds.device, dtype=torch.float32)

    mean = torch.tensor([entry["mean"] for entry in per_sample], **column).view(-1, 1)
    std = torch.tensor([entry["std"] for entry in per_sample], **column).view(-1, 1)

    return preds * std + mean

# ==============================
# AUXILIARY FUNCTIONS
# ==============================
def _unpack_batch(batch):
    data, volume_grezzo, proc_id, labels = batch
    if getattr(data, 'pos', None) is None and getattr(data, 'x', None) is not None:
        data.pos = data.x
    return data, volume_grezzo, proc_id, labels

# ==============================
# TRAIN + TEST
# ==============================
def train(model, loader, optimizer, criterion, proc_stats):
    """
    Run one training epoch and return the mean loss per batch.

    Labels arrive on the log scale and are z-scored per process with
    ``proc_stats`` before the loss is computed, so the model learns the
    normalized target.

    Args:
        model: Network called as ``model(pos, edge_index, batch, mesh_feat,
            volume_grezzo, proc_id)``.
        loader: Iterable of batches accepted by ``_unpack_batch``.
        optimizer: Optimizer over ``model.parameters()``.
        criterion: Loss taking (prediction, normalized label).
        proc_stats: Per-process label statistics (see ``build_proc_stats``).

    Returns:
        float: Sum of the batch losses divided by the number of batches
        (0.0 for an empty loader).
    """
    model.train()
    total_loss = 0.0
    for batch in loader:
        data, volume_grezzo, proc_id, labels = _unpack_batch(batch)
        data, volume_grezzo, proc_id, labels = (
            data.to(device), volume_grezzo.to(device),
            proc_id.to(device), labels.to(device).view(-1, 1)  # log
        )
        labels_norm = normalize_labels(labels, proc_id, proc_stats)

        optimizer.zero_grad()
        out = model(data.pos, data.edge_index, data.batch,
                    data.mesh_feat, volume_grezzo, proc_id)
        loss = criterion(out, labels_norm)
        loss.backward()
        # Clip AFTER backward: clipping before it would act on the gradients
        # just cleared by zero_grad and leave the real ones unbounded.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        total_loss += loss.item()
    # Avoid a ZeroDivisionError when the loader yields no batch
    return total_loss / max(len(loader), 1)

def test(model, loader, criterion, proc_stats):
    """
    Evaluate the model and return metrics on the original time scale.

    Predictions are de-normalized per process back to the log scale and then,
    together with the labels, mapped back with ``expm1`` before the metrics
    are computed. ``criterion`` is accepted for signature symmetry with
    ``train`` but is not used.

    Returns:
        (mse, rmse, mae, mape, wape, r2, preds, labels, procs) where the last
        three are NumPy arrays of predictions, labels and process IDs.
    """
    model.eval()
    pred_chunks, label_chunks, proc_chunks = [], [], []

    with torch.no_grad():
        for batch in loader:
            graph, stock_volume, proc_id, targets = _unpack_batch(batch)
            graph = graph.to(device)
            stock_volume = stock_volume.to(device)
            proc_id = proc_id.to(device)
            targets = targets.to(device).view(-1, 1)  # log

            preds_norm = model(graph.pos, graph.edge_index, graph.batch,
                               graph.mesh_feat, stock_volume, proc_id)

            proc_cpu = proc_id.cpu()
            pred_chunks.append(denormalize_preds(preds_norm, proc_cpu, proc_stats).cpu())
            label_chunks.append(targets.cpu())
            proc_chunks.append(proc_cpu)

    preds_log = torch.cat(pred_chunks, dim=0).numpy()
    labels_log = torch.cat(label_chunks, dim=0).numpy()
    procs = torch.cat(proc_chunks, dim=0).numpy()

    # Back from the log scale to the original scale
    preds = np.expm1(preds_log)
    labels = np.expm1(labels_log)

    residual = labels - preds
    mse = np.mean((preds - labels) ** 2)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(labels, preds)
    mape = np.mean(np.abs(residual / (labels + 1e-8))) * 100
    wape = (np.sum(np.abs(residual)) / (np.sum(np.abs(labels)) + 1e-8)) * 100
    r2 = r2_score(labels, preds)

    return mse, rmse, mae, mape, wape, r2, preds, labels, procs

def custom_collate(batch):
    """
    Collate (data, volume, proc_id, label) samples into one batch.

    The graphs are batched with torch_geometric's ``Collater``; the three
    tensor fields are flattened per sample and concatenated. Flattening first
    is required because ``torch.cat`` rejects 0-dim tensors, which is what
    ``STLDataset.get`` returns for the process ID.

    NOTE(review): torch_geometric's own ``DataLoader`` may replace a
    user-supplied ``collate_fn`` with its default collater — confirm that this
    function is actually the one being called.

    Returns:
        (batch_data, vols, procs, labels) with ``vols`` float of shape (B,),
        ``procs`` long of shape (B,) and ``labels`` float of shape (B, 1)
        for single-value samples.
    """
    datas, vols, procs, labels = zip(*batch)
    batch_data = Collater(dataset=None, follow_batch=[])(list(datas))
    vols = torch.cat([v.reshape(-1) for v in vols], dim=0).float()
    procs = torch.cat([p.reshape(-1) for p in procs], dim=0).long()
    labels = torch.cat([t.reshape(-1) for t in labels], dim=0).float().view(-1, 1)  # log
    return batch_data, vols, procs, labels

# ==============================
# TRAIN + TEST ON ONE PIECE (SWA) + PLOT
# ==============================

# Averaging window expressed as fractions of the total number of epochs
# ("SWA" presumably stands for stochastic weight averaging — confirm).
swa_start_perc = 0.8
swa_end_perc = 1.0
epochs = 200

# Same window as epoch numbers: 160 and 200 with the values above. These are
# the default arguments of run_one_piece.
swa_start = int(swa_start_perc * epochs)
swa_end = int(swa_end_perc * epochs)

def _swa_accumulate(swa_state, model):
    """Add a CPU snapshot of `model`'s state dict to the running SWA sum."""
    snapshot = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    if swa_state is None:
        return snapshot
    for key, value in snapshot.items():
        if torch.is_floating_point(value):
            swa_state[key] += value
        else:
            # Integer buffers (e.g. step counters) cannot be meaningfully
            # summed/averaged; keep the most recent value instead.
            swa_state[key] = value
    return swa_state


def _swa_finalize(swa_state, n_snapshots):
    """Turn the running SWA sum into a mean (floating-point entries only)."""
    for value in swa_state.values():
        if torch.is_floating_point(value):
            value.div_(n_snapshots)
    return swa_state


def run_one_piece(pezzo_escluso, orig_dataset, results_all, n_runs=3,
                  swa_start=swa_start, swa_end=swa_end, n_epochs=epochs):
    """Train on every piece except `pezzo_escluso`, then evaluate on it.

    For each of `n_runs` repetitions a fresh `FeaStNet_TimePred` is trained
    for `n_epochs` epochs; the weights of the epochs inside
    ``[swa_start, swa_end]`` (1-based, inclusive) are averaged and loaded
    before the final evaluation. The hold-out MAE is recorded after every
    epoch purely for the trend plot; it is not used for model selection.

    Side effects: writes a per-piece predictions CSV and an MAE-trend PNG
    into `RESULTS_DIR` (file names carry the module-level `timestamp`) and
    extends `results_all` with one metrics dict per run.

    Args:
        pezzo_escluso: value of ``data.filename`` identifying the held-out piece.
        orig_dataset: indexable dataset yielding
            ``(data, volume_grezzo, proc_id, labels)``; must expose
            ``n_processi`` and ``proc_id_to_name``.
        results_all: list that receives the per-run metric dicts.
        n_runs: number of independent training repetitions.
        swa_start: first epoch whose weights enter the SWA average.
        swa_end: last epoch whose weights enter the SWA average.
        n_epochs: number of training epochs per run. The SWA bounds are NOT
            rescaled automatically when this is changed.

    Raises:
        ValueError: if no sample of `orig_dataset` belongs to `pezzo_escluso`.
    """
    dataset_holdout, dataset_train = [], []
    for i in range(len(orig_dataset)):
        data, volume_grezzo, proc_id, labels = orig_dataset[i]
        sample = (data, volume_grezzo, proc_id, labels)
        if data.filename == pezzo_escluso:
            dataset_holdout.append(sample)
        else:
            dataset_train.append(sample)

    if not dataset_holdout:
        # Fail early with a clear message instead of deep inside evaluation.
        raise ValueError(f"No samples found for piece '{pezzo_escluso}'.")

    # Compute mean/std per process ONLY on the train set (no hold-out leakage).
    proc_stats = build_proc_stats(dataset_train)

    train_loader = DataLoader(dataset_train, batch_size=32, shuffle=True,
                              collate_fn=custom_collate, drop_last=False, generator=g)
    holdout_loader = DataLoader(dataset_holdout, batch_size=1, shuffle=False,
                                collate_fn=custom_collate, generator=g)

    config = {"hidden_dim": 128, "dropout": 0.5, "lr": 0.0005}
    all_runs_metrics = []
    rows_csv = []
    all_histories = []

    for run in range(n_runs):
        print(f"▶ Run {run+1}, piece {pezzo_escluso}, cfg={config}")

        model = FeaStNet_TimePred(
            in_channels=3, out_channels=1,
            hidden_dim=config["hidden_dim"], dropout=config["dropout"], heads=8,
            n_processi=orig_dataset.n_processi,
            proc_emb_dim=8
        ).to(device)

        optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=1e-4)
        criterion = torch.nn.MSELoss()

        history_mae = []
        swa_state, swa_n = None, 0
        for epoch in range(1, n_epochs + 1):
            train(model, train_loader, optimizer, criterion, proc_stats)
            # Hold-out MAE at each epoch, kept only for the trend plot.
            _, _, mae, _, _, _, _, _, _ = test(model, holdout_loader, criterion, proc_stats)
            history_mae.append(mae)

            if swa_start <= epoch <= swa_end:
                swa_state = _swa_accumulate(swa_state, model)
                swa_n += 1

        if swa_state is not None and swa_n > 0:
            model.load_state_dict(_swa_finalize(swa_state, swa_n), strict=True)

        mse, rmse, mae, mape, wape, r2, preds, labels, procs = test(
            model, holdout_loader, criterion, proc_stats
        )
        all_histories.append(history_mae)

        run_result = {"run": run+1, "piece": pezzo_escluso, "mse": mse, "rmse": rmse,
                      "mae": mae, "mape": mape, "wape": wape, "r2": r2}
        all_runs_metrics.append(run_result)

        # One CSV row per individual prediction.
        for pred_row, label_row, proc in zip(preds, labels, procs):
            predicted, real = pred_row[0], label_row[0]
            error = predicted - real
            rows_csv.append({
                "piece": pezzo_escluso, "run": run+1,
                "process": orig_dataset.proc_id_to_name[int(proc)],
                "real": f"{real:.2f}",
                "predicted": f"{predicted:.2f}",
                "error": f"{error:+.2f}",
                "error_%": f"{(error/real*100 if real != 0 else 0):+.2f}%",
            })

        # Summary row closing the run.
        rows_csv.append({
            "piece": pezzo_escluso,
            "run": run+1,
            "process": "----",
            "real": "----",
            "predicted": "----",
            "error": f"MAE={mae:.2f}",
            "error_%": f"MAPE={mape:.2f}%",
            "WAPE": f"{wape:.2f}%"
        })

    # Save the detailed per-prediction CSV.
    csv_path = os.path.join(RESULTS_DIR, f"predictions_{pezzo_escluso}_{timestamp}.csv")
    pd.DataFrame(rows_csv).to_csv(csv_path, index=False)
    print(f"📑 Predictions CSV saved: {csv_path}")
    results_all.extend(all_runs_metrics)

    # Save the hold-out MAE trend plot (one curve per run).
    plt.figure(figsize=(8,6))
    for run_idx, history in enumerate(all_histories):
        plt.plot(history, label=f"Run {run_idx+1}")
    plt.title(f"MAE Trend - Piece {pezzo_escluso}")
    plt.xlabel("Epoch")
    plt.ylabel("MAE (holdout)")
    plt.legend()
    plot_path = os.path.join(RESULTS_DIR, f"mae_plot_{pezzo_escluso}_{timestamp}.png")
    plt.savefig(plot_path)
    plt.close()
    print(f"📊 MAE plot saved: {plot_path}")

# ==============================
# LOOP OVER ALL PIECES
# ==============================
def run_all_pieces(orig_dataset):
    """Run leave-one-piece-out evaluation for every distinct piece.

    Each unique ``data.filename`` in `orig_dataset` is held out in turn via
    `run_one_piece` (5 runs each). The collected per-run metrics are written
    to ``global_metrics_<timestamp>.csv`` inside `RESULTS_DIR`.

    Pieces are processed in sorted order: iterating a plain set of strings
    yields an order that changes between interpreter runs, which would also
    change how the shared random generator is consumed across pieces.
    """
    pezzi_unici = sorted(
        {orig_dataset[i][0].filename for i in range(len(orig_dataset))}
    )
    results_all = []
    for pezzo in pezzi_unici:
        print(f"\n Leave-One-Piece-Out: excluded {pezzo}")
        run_one_piece(pezzo, orig_dataset, results_all, n_runs=5)

    global_csv = os.path.join(RESULTS_DIR, f"global_metrics_{timestamp}.csv")
    pd.DataFrame(results_all)[["piece", "run", "mae", "mape", "wape", "r2"]].to_csv(global_csv, index=False)
    print(f"Global CSV saved at {global_csv}")


def run_selected_pieces(orig_dataset, selected_pieces):
    """Hold out only the given pieces (2 runs each) and save their metrics.

    Results are written to ``selected_metrics_<now>.csv`` in `RESULTS_DIR`;
    the file name uses a timestamp taken when the loop has finished.
    """
    collected = []
    for piece in selected_pieces:
        print(f"\n Testing selected piece: {piece}")
        # n_runs can be raised (e.g. to 5) for more stable estimates.
        run_one_piece(piece, orig_dataset, collected, n_runs=2)

    # Persist metrics for the selected pieces only.
    timestamp_sel = datetime.now().strftime("%Y%m%d_%H%M%S")
    csv_path = os.path.join(RESULTS_DIR, f"selected_metrics_{timestamp_sel}.csv")
    metric_columns = ["piece", "run", "mae", "mape", "wape", "r2"]
    metrics_frame = pd.DataFrame(collected)
    metrics_frame[metric_columns].to_csv(csv_path, index=False)
    print(f" Selected metrics saved at {csv_path}")

In [None]:
# Launch the leave-one-piece-out evaluation over every piece in `orig_dataset`
# (trains several models per piece and writes CSVs/plots into RESULTS_DIR).
run_all_pieces(orig_dataset)

## Classical Evaluation

In [None]:
# ==============================
# CSV Export for Classical Models
# ==============================

def _to_float(x):
    """Safely convert a tensor or scalar to a float."""
    try:
        return float(x.item())
    except AttributeError:
        return float(x)

@torch.no_grad()
def export_mesh_features(
    dataset,
    output_path,
    labels_in_log1p=True,
    sort_output=True,
    print_every=10
):
    """
    Write one CSV row per dataset sample: mesh features, piece, process, volume, time.

    Args:
        dataset: indexable collection yielding ``(data, volume, proc_id, label)``
            and exposing a ``proc_id_to_name`` mapping.
        output_path: destination CSV path; parent folders are created if missing.
        labels_in_log1p: when True the label is mapped through ``np.expm1``
            before being stored as ``time_real``.
        sort_output: when True rows are stably sorted by piece, then process.
        print_every: progress is printed every this many samples (falsy disables it).

    Raises:
        ValueError: if ``data.mesh_feat`` is missing, is not 1-D/2-D, or its
            length differs between samples.
    """
    total = len(dataset)
    records = []
    feat_dim = None

    for idx in range(total):
        data, volume, proc_id, label = dataset[idx]

        # --- Mesh features ---
        mesh = getattr(data, "mesh_feat", None)
        if mesh is None:
            raise ValueError("data.mesh_feat not found in dataset element.")

        # A 2-D tensor is averaged over its first axis; a 1-D tensor is used as is.
        rank = mesh.dim()
        if rank not in (1, 2):
            raise ValueError(f"Unexpected mesh_feat shape: {tuple(mesh.shape)}")
        pooled = mesh.mean(dim=0) if rank == 2 else mesh
        mesh_mean = pooled.cpu().numpy().ravel()

        # Every sample must provide the same number of features.
        if feat_dim is None:
            feat_dim = mesh_mean.shape[0]
        elif mesh_mean.shape[0] != feat_dim:
            raise ValueError(
                f"Inconsistent mesh_feat dimension: {feat_dim} vs {mesh_mean.shape[0]} (idx={idx})"
            )

        # --- Tabular data ---
        vol = _to_float(volume)
        pid = int(_to_float(proc_id))
        process_name = dataset.proc_id_to_name.get(pid, f"proc_{pid}")

        raw_label = _to_float(label)
        # Undo log1p when the dataset stores log-scale targets.
        time_real = float(np.expm1(raw_label)) if labels_in_log1p else raw_label

        piece_name = str(getattr(data, "filename", f"piece_{idx}"))

        # --- Build row (feature columns first, then metadata) ---
        record = {f"meshfeat_{j}": value for j, value in enumerate(mesh_mean)}
        record["piece"] = piece_name
        record["process"] = process_name
        record["volume"] = vol
        record["time_real"] = time_real
        records.append(record)

        # --- Progress print ---
        done = idx + 1
        if print_every and (done % print_every == 0 or done == total):
            print(f"Processed {done}/{total} samples", flush=True)

    df = pd.DataFrame(records)

    if sort_output:
        df = df.sort_values(["piece", "process"], kind="mergesort")
        df = df.reset_index(drop=True)

    # --- Save CSV ---
    target_dir = os.path.dirname(output_path) or "."
    os.makedirs(target_dir, exist_ok=True)
    df.to_csv(output_path, index=False)

    # --- Summary ---
    n_pieces = df["piece"].nunique()
    n_procs = df["process"].nunique()
    print(f"\n✅ Exported {len(df)} rows | {n_pieces} pieces | {n_procs} processes | {feat_dim} mesh features")
    print(f"📂 Saved to: {output_path}")


# Export the full dataset (alias `train_dataset`) to the CSV consumed by the
# classical-model evaluation; labels are converted back from log1p scale.
train_dataset = orig_dataset
output_path = os.path.join(base_path, "Classical_Results/csvClassical.csv")
export_mesh_features(train_dataset, output_path, labels_in_log1p=True, sort_output=True)

In [None]:
# ==============================
# CONFIG
# ==============================
# CSV previously written by export_mesh_features.
INPUT_CSV = os.path.join(base_path, "Classical_Results/csvClassical.csv")
# NOTE(review): RESULTS_DIR and timestamp are module-level names that
# run_one_piece / run_all_pieces also read at call time; rebinding them here
# redirects their output if those functions are re-run after this cell.
RESULTS_DIR = os.path.join(base_path, "Classical_Results")
os.makedirs(RESULTS_DIR, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# ==============================
# LOAD DATA
# ==============================
df = pd.read_csv(INPUT_CSV)

# Column roles: regression target, grouping key, categorical process name.
target_col, piece_col, proc_col = "time_real", "piece", "process"

# Numeric inputs = every exported mesh feature plus the raw-stock volume column.
mesh_cols = [col for col in df.columns if col.startswith("meshfeat_")]
num_cols = [*mesh_cols, "volume"]

print(f"📂 Loaded dataset: {len(df)} rows, {len(mesh_cols)} mesh features, {df[piece_col].nunique()} unique pieces")

# ==============================
# METRICS
# ==============================
def _safe_mape(y, yhat):
    return np.mean(np.abs((yhat - y) / (y + 1e-8))) * 100

def _wape(y, yhat):
    return np.sum(np.abs(yhat - y)) / (np.sum(np.abs(y)) + 1e-8) * 100

# ==============================
# LOOP: Leave-One-Piece-Out
# ==============================
results = []      # one metrics dict per (piece, model)
all_preds = []    # one prediction DataFrame per (piece, model)
unique_pieces = df[piece_col].unique()
print(f"\n🔧 Found {len(unique_pieces)} unique pieces → LOPO setup\n")

for held_out in unique_pieces:
    print(f"🚀 Leave-One-Piece-Out: excluding '{held_out}'")

    # Rows of the held-out piece form the test fold; everything else trains.
    test_df = df[df[piece_col] == held_out]
    train_df = df[df[piece_col] != held_out]

    # Both preprocessors are fitted on the training fold only.
    scaler = StandardScaler()
    scaler.fit(train_df[num_cols])
    ohe = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
    ohe.fit(train_df[[proc_col]])

    # Design matrices: scaled numeric columns followed by one-hot process columns.
    X_train = np.concatenate(
        [scaler.transform(train_df[num_cols]), ohe.transform(train_df[[proc_col]])],
        axis=1,
    )
    X_test = np.concatenate(
        [scaler.transform(test_df[num_cols]), ohe.transform(test_df[[proc_col]])],
        axis=1,
    )

    y_train = train_df[target_col].values
    y_test = test_df[target_col].values

    # Fresh estimators for every fold.
    models = {
        "RandomForest": RandomForestRegressor(n_estimators=500, random_state=42, n_jobs=-1),
        "SVM": SVR(C=10, gamma="scale"),
        "DecisionTree": DecisionTreeRegressor(max_depth=None, random_state=42),
    }

    for name, model in models.items():
        model.fit(X_train, y_train)
        preds = model.predict(X_test)

        mse = mean_squared_error(y_test, preds)
        rmse = np.sqrt(mse)
        mae = mean_absolute_error(y_test, preds)
        mape = _safe_mape(y_test, preds)
        wape = _wape(y_test, preds)
        r2 = r2_score(y_test, preds)

        results.append({
            "piece": held_out,
            "model": name,
            "MSE": mse,
            "RMSE": rmse,
            "MAE": mae,
            "MAPE%": mape,
            "WAPE%": wape,
            "R2": r2,
        })

        # Per-sample predictions with signed absolute and relative error.
        df_pred = pd.DataFrame({
            "piece": held_out,
            "model": name,
            "process": test_df[proc_col].values,
            "real": y_test,
            "pred": preds,
        })
        df_pred["error"] = df_pred["pred"] - df_pred["real"]
        nonzero_real = df_pred["real"] != 0
        df_pred["error_%"] = np.where(
            nonzero_real,
            df_pred["error"] / df_pred["real"] * 100,
            np.nan,  # relative error is undefined for a zero target
        )
        all_preds.append(df_pred)

        print(f"   ▶ {name:12s} | RMSE={rmse:8.2f} | MAE={mae:8.2f} | R²={r2:6.3f}")

# ==============================
# SAVE RESULTS
# ==============================
# Collect the per-fold results into two tables: metrics and raw predictions.
df_metrics = pd.DataFrame(results)
df_preds = pd.concat(all_preds, ignore_index=True)

metrics_csv = os.path.join(RESULTS_DIR, f"classical_metrics_{timestamp}.csv")
preds_csv = os.path.join(RESULTS_DIR, f"classical_predictions_{timestamp}.csv")
for _frame, _destination in ((df_metrics, metrics_csv), (df_preds, preds_csv)):
    _frame.to_csv(_destination, index=False)

print(f"\n✅ Metrics saved to: {metrics_csv}")
print(f"📑 Detailed predictions saved to: {preds_csv}")

# ==============================
# SUMMARY
# ==============================
# Average each metric over the held-out pieces, best R2 first.
_summary_cols = ["RMSE", "MAE", "MAPE%", "WAPE%", "R2"]
_per_model_mean = df_metrics.groupby("model")[_summary_cols].mean()
summary = _per_model_mean.sort_values("R2", ascending=False)

print("\n=== 📊 Summary per model (average over all pieces) ===")
print(summary.round(3))

## Feature Selection

In [None]:
# =====================================
# CORRELATION ANALYSIS: Global Mesh Features ↔ Time
# (Clustered correlation map without 'time' in heatmap)
# =====================================

import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# NOTE(review): machine-specific absolute path; rebinding `base_path` here also
# affects any later cell that reads it.
base_path = "/Users/lorenzo/Desktop/GNN_Milling"

raw_dir, raw_simplified_dir, processed_dir = (
    os.path.join(base_path, subfolder)
    for subfolder in ("raw", "raw_simplified", "processed_FeatureSelection")
)
labels_path = os.path.join(base_path, "dataset.csv")
pdf_path = os.path.join(base_path, "GNNcorrelations_global.pdf")

# --- Load dataset (no features excluded, one piece excluded) ---
orig_dataset2 = STLDataset(
    root=base_path,
    raw_dir=raw_dir,
    raw_simplified_dir=raw_simplified_dir,
    processed_dir=processed_dir,
    printcfg_labels_path=labels_path,
    exclude_features=[],
    exclude_pieces=['4851D5091'],
)

# --- Mesh feature names, in the order they are read from data.mesh_feat ---
# assumes data.mesh_feat is a flat vector in exactly this order — TODO confirm
mesh_features = [
    'volume', 'surface_area', 'sa_to_vol', 'isoperimetric_q', 'compactness_inv',
    'bbox_len', 'bbox_wid', 'bbox_hei', 'aspect_max_min', 'aspect_max_med',
    'principal_inertia_0', 'principal_inertia_1', 'principal_inertia_2',
    'radius_gyr_0', 'radius_gyr_1', 'radius_gyr_2',
    'center_mass_x_rel', 'center_mass_y_rel', 'center_mass_z_rel',
    'convexity', 'steep_frac_best_axis', 'steep_frac_worst_axis',
    'steep_frac_posZ', 'sharp_edge_fraction',
]

# --- Build global DataFrame: one row per sample ---
# NOTE(review): elsewhere in this file labels are treated as log1p(time); if
# that holds here, the "time" column is on log scale — confirm.
records = []
for i in range(len(orig_dataset2)):
    data, volume_grezzo, proc_id, labels = orig_dataset2.get(i)
    row = {"process_id": float(proc_id.item()), "time": float(labels.item())}
    row.update(
        (feat, float(data.mesh_feat[j].item()))
        for j, feat in enumerate(mesh_features)
    )
    records.append(row)

df_full = pd.DataFrame(records)

# --- Pairwise correlation between all mesh features and time ---
corr_full = df_full[[*mesh_features, "time"]].corr()

# --- Helper: classify correlation strength ---
def classify_corr(value):
    """Label a correlation coefficient by the magnitude of its absolute value.

    Returns "Weak" for |value| < 0.3, "Moderate" for |value| < 0.6 and
    "Strong" otherwise. A NaN coefficient (e.g. produced by a constant
    feature) is reported as "Undefined" instead of falling through every
    comparison and being mislabelled "Strong".
    """
    import math

    abs_val = abs(value)
    if math.isnan(abs_val):
        return "Undefined"
    if abs_val < 0.3:
        return "Weak"
    if abs_val < 0.6:
        return "Moderate"
    return "Strong"

# --- Extract correlations with time only ---
# Correlation of every feature with "time", ordered by decreasing absolute
# value, plus a small table holding the value and its strength label.
corr_time = corr_full["time"].drop("time").sort_values(key=lambda x: abs(x), ascending=False)
corr_table = pd.DataFrame({
    "Feature": corr_time.index,
    "Correlation with time": corr_time.values,
    "Strength": [classify_corr(v) for v in corr_time.values]
})

# --- Correlation matrix without 'time' for visualization ---
corr_no_time = corr_full.drop(index="time", columns="time")

# --- Save all results to PDF (three pages, in the order drawn below) ---
with PdfPages(pdf_path) as pdf:
    # Page 1: clustered heatmap of feature-feature correlations (without 'time')
    sns.set(font_scale=0.8)
    cluster = sns.clustermap(
        corr_no_time,
        cmap="coolwarm_r",
        center=0,
        annot=True,
        fmt=".2f",
        figsize=(12, 10),
        linewidths=0.3,
        cbar_pos=(0.02, 0.8, 0.03, 0.18),
        dendrogram_ratio=(.2, .1)
    )
    pdf.savefig(cluster.fig)
    # NOTE(review): closes pyplot's *current* figure; presumably that is the
    # clustermap figure just saved — confirm.
    plt.close()

    # Page 2: bar plot of each feature's correlation with time, coloured by strength
    plt.figure(figsize=(10, 6))
    sns.barplot(
        x=corr_time.values,
        y=corr_time.index,
        hue=[classify_corr(v) for v in corr_time.values],
        dodge=False,
        palette="Blues_r"
    )
    plt.axvline(0, color="black", linewidth=1)
    plt.title("Feature ↔ Time Correlation (All Processes Combined)")
    plt.xlabel("Correlation coefficient")
    plt.ylabel("Feature")
    plt.legend(title="Strength")
    plt.tight_layout()
    # savefig() without a figure argument saves pyplot's current figure.
    pdf.savefig(); plt.close()

    # Page 3: correlation table (values and classification);
    # figure height scales with the number of table rows.
    fig, ax = plt.subplots(figsize=(8, len(corr_table) * 0.3))
    ax.axis('off')
    table = ax.table(
        cellText=corr_table.values,
        colLabels=corr_table.columns,
        loc='center'
    )
    table.auto_set_font_size(False)
    table.set_fontsize(8)
    table.scale(1, 1.2)
    plt.title("Correlation Table (All Processes Combined)")
    pdf.savefig(); plt.close()

print(f"✅ Clustered correlation PDF saved at: {pdf_path}")