# %% [markdown]
# # CryoEM-MotorMetaNet: A GNN-Stacked Ensemble for 3D Bacterial Motor Detection in Cryo-ET
#
# ## Overall Goal
# Develop a state-of-the-art 3D deep learning pipeline, delivered as a single, comprehensive Jupyter Notebook (.ipynb file), to automatically detect and localize bacterial motors within 3D cryo-electron tomography (cryo-ET) reconstructions. The training data consists of folders of 2D image slices (JPG format). The core detection will be performed by one or more base 3D CNN models, and their outputs will be intelligently combined and refined using a Graph Neural Network (GNN) acting as a stacking meta-model. The system must address challenges of extremely small motor size, low SNR, and the need for precise 3D localization. Optimization will be driven by the F-beta score (beta=2.0) using Optuna, with MONAI for data augmentation and other pipeline components. A critical contingency is handling potentially 2D slice-based test data by adapting the entire 3D-trained pipeline.
#
# ## 1. Context and Motivation
# Bacterial flagellar motors are crucial nanomachines. Cryo-ET offers 3D visualization, but manual analysis is a bottleneck. This project aims to create an automated system leveraging 3D CNNs for robust feature extraction and a GNN meta-model for advanced ensembling of base model predictions. This sophisticated stacking approach is expected to enhance detection accuracy for these challenging, small targets. The entire solution is encapsulated in this notebook and designed to be adaptable for different test data formats.
#
# ## 2. Input Data Detailed Description
# *   **Training/Validation Data:**
#     *   Format: Each 3D tomogram is a sequence of 2D JPG slices.
#     *   Structure: Slices for one tomogram are in a folder (e.g., `tomo_003acc/slice_0000.jpg`, `.../slice_NNNN.jpg`).
#     *   3D Volume: Slices are ordered along the Z-axis, typically 8-bit grayscale.
# *   **Test Data Contingency:**
#     *   While training uses 3D tomograms, the test set might be individual 2D slices. The solution includes adaptation for this.
# *   **Data Characteristics:** Low Signal-to-Noise Ratio (SNR), low contrast, extremely small target size, variable appearance, potential anisotropy.
# *   **Dataset Size (Illustrative - User to define):**
#     *   Number of tomograms (train/val/test): e.g., 50 train, 10 val, 10 test.
#     *   Slices/tomogram: e.g., 100-500.
#     *   X-Y dimensions: e.g., 512x512 or 1024x1024.
#     *   Class imbalance: Motors are sparse.
# *   **Ground Truth (for 3D Tomograms - Training/Validation):**
#     *   Format: CSV file.
#     *   Assumed CSV Structure: `tomo_id,center_x,center_y,center_z` (plus potentially a `row_id`). Coordinates are voxel indices.
#     *   Target Generation: Since motors are small, ground truth for base CNNs will be small 3D spherical/cubical masks centered at these (X,Y,Z) coordinates. Bounding boxes are not assumed to be directly provided in the GT CSV but will be generated for output.
#
# ## 3. Desired Output Format
# *   **For 3D Tomogram Input:**
#     *   Output: List of 3D detections per tomogram.
#     *   Each detection: 3D Coordinates (center: X,Y,Z), 3D Bounding Box (X_min,Y_min,Z_min,X_max,Y_max,Z_max), Confidence Score.
# *   **For 2D Slice Input (Adapted):**
#     *   Output: List of 2D detections per slice.
#     *   Each detection: 2D Coordinates (center on slice: X,Y), 2D Bounding Box (on slice: X_min,Y_min,X_max,Y_max), Confidence Score.
#
# ## 4. Proposed Methodology: Stacked Ensemble with GNN Meta-Model
#
# ### 4.1. Base Model(s): 3D Convolutional Neural Networks (3D CNNs)
# *   **Purpose:** Extract features and generate initial candidate detections or probability maps from 3D tomogram patches.
# *   **Number:** N=1 base model (MONAI DynUNet) will be implemented for simplicity in this notebook, but the framework can be extended.
# *   **Training:** Trained on 3D tomogram patches to produce 3D probability maps for motor presence.
# *   **Output:** A 3D probability map indicating motor likelihood at each voxel.
#
# ### 4.2. GNN as a Stacking Meta-Model
# *   **Purpose:** Combine predictions/features from base model(s) to make a refined prediction, leveraging relational information between candidate detections.
# *   **Graph Construction:**
#     1.  **Node Definition:** Candidate motor regions identified from base model output (e.g., peaks in probability maps after Non-Maximum Suppression - NMS).
#     2.  **Node Features:** Derived from base model(s) output (e.g., probability scores, local intensity statistics).
#     3.  **Edge Definition:** Connect nodes based on spatial proximity (e.g., Euclidean distance).
# *   **GNN Architecture:** Graph Attention Network (GAT) is suitable due to its ability to weigh neighbor importance.
# *   **GNN Task:** Node Classification (motor vs. non-motor).
# *   **Training the GNN:** Two-stage: Train base model(s) first, then use their outputs on a validation set to create a dataset for training the GNN.
#
# ### 4.3. Overall Pipeline Flow
# 1.  Input 3D tomogram (or patch).
# 2.  Base 3D CNN Model processes input, generating a probability map.
# 3.  Construct a graph: nodes are candidate detections from the base model, features from base model output.
# 4.  GNN Meta-Model processes the graph, outputting refined detections.
#
# ## Notebook Structure:
# *   Part 0: Setup and Imports
# *   Part 1: Configuration and Global Parameters
# *   Part 2: Data Loading and Preprocessing (MONAI)
# *   Part 3: Base 3D CNN Model (MONAI DynUNet)
# *   Part 4: GNN Meta-Model (PyTorch Geometric - GAT)
# *   Part 5: Training Orchestration (Base Model + GNN)
# *   Part 6: Evaluation Metrics and Utilities
# *   Part 7: Inference Pipeline (3D Tomograms and 2D Slices)
# *   Part 8: Hyperparameter Optimization with Optuna
# *   Part 9: Usage Instructions
# *   Part 10: Discussion, Challenges, and Future Work
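
# %% [markdown]
# ### Sketch: Proximity-Based Edge Construction (Section 4.2)
# The edge definition in Section 4.2 (connect candidate nodes by Euclidean distance) can be sketched with NumPy alone. This is a minimal illustration, not the notebook's GNN code: the helper name `build_proximity_edges` and the example coordinates are hypothetical. It returns edges as a `(2, E)` index array, which is the layout PyTorch Geometric uses for `edge_index`.

```python
# %%
import numpy as np


def build_proximity_edges(centers_zyx, max_distance):
    """Return a (2, E) array of directed edges between nodes within max_distance.

    centers_zyx: (N, 3) array of candidate centers in voxel coordinates.
    Both directions (i -> j and j -> i) are emitted; self-loops are excluded.
    """
    centers = np.asarray(centers_zyx, dtype=np.float64).reshape(-1, 3)
    # Pairwise Euclidean distances via broadcasting: result has shape (N, N).
    diff = centers[:, None, :] - centers[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    adjacency = dist <= max_distance
    np.fill_diagonal(adjacency, False)  # drop self-loops
    src, dst = np.nonzero(adjacency)
    return np.stack([src, dst], axis=0)


# Three hypothetical candidates: the first two are 5 voxels apart, the third is far away.
example_centers = np.array([[10, 10, 10], [10, 13, 14], [80, 80, 80]])
example_edges = build_proximity_edges(example_centers, max_distance=20)
print(example_edges)
```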

# %% [markdown]
# ## Part 0: Setup and Imports
# Ensure you have a Python environment with CUDA-enabled PyTorch.

# %%
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 # Or your CUDA version
# !pip install monai[all]==1.3.0 # Using a specific version for stability
# !pip install optuna==3.5.0
# !pip install torch_geometric pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-$(python -c 'import torch; print(torch.__version__)')+cu118.html # Adjust cuXXX
# !pip install pandas scikit-image scikit-learn matplotlib opencv-python joblib tqdm

# %%
import os
import glob
import shutil
import time
import json
import random
from pathlib import Path
from collections import defaultdict
from tqdm.notebook import tqdm

import numpy as np
import pandas as pd
import cv2 # OpenCV for image loading if needed, though MONAI's LoadImage is preferred

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import monai
from monai.config import print_config
from monai.data import (
    Dataset as MonaiDataset,
    DataLoader as MonaiDataLoader,
    ImageReader,
    PersistentDataset,
    CacheDataset,
    list_data_collate,
    pad_list_data_collate
)
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, ScaleIntensityRanged, NormalizeIntensityd,
    Spacingd, RandSpatialCropSamplesd, RandFlipd, RandRotate90d, RandAffined,
    Rand3DElasticd, RandGaussianNoised, RandAdjustContrastd, ToTensord,
    Activationsd, AsDiscreted, KeepLargestConnectedComponentd, LabelToContourD,
    CropForegroundd
)
from monai.networks.nets import DynUNet, UNet
from monai.losses import DiceCELoss, DiceLoss, FocalLoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference, SimpleInferer
from monai.utils import set_determinism, first

# PyTorch Geometric
try:
    import torch_geometric
    import torch_geometric.nn as pyg_nn
    from torch_geometric.data import Data as PyGData
    from torch_geometric.loader import DataLoader as PyGDataLoader
    PYG_AVAILABLE = True
except ImportError:
    print("PyTorch Geometric not found. GNN functionality will be disabled.")
    PYG_AVAILABLE = False

import optuna
from optuna.pruners import MedianPruner
from optuna.trial import TrialState

import matplotlib.pyplot as plt
from skimage.measure import label, regionprops
from scipy.ndimage import center_of_mass, binary_dilation, binary_erosion
from scipy.spatial.distance import cdist

# Print MONAI config
print_config()

# %% [markdown]
# ## Part 1: Configuration and Global Parameters

# %%
CONFIG = {
    # General Paths and Data Settings
    "project_name": "CryoEM_MotorMetaNet",
    "dataset_root": "./Downloads/byu-locating-bacterial-flagellar-motors-2025/train", # Path to root of dataset
    "gt_csv_path": "./Downloads/byu-locating-bacterial-flagellar-motors-2025/train_labels.csv", # Path to ground truth CSV
    "output_dir": "./output_cryoEM_MotorMetaNet", # Where to save models, logs, HPO results

    # Data characteristics (user may need to adjust)
    "num_slices_per_tomo_avg": 200, # Approximate, for patch size considerations
    "xy_dimensions_avg": (512, 512), # Approximate

    # Training Parameters - Base Model
    "base_model_name": "DynUNet",
    "patch_size_base": (96, 96, 96), # Cubic patch for training base model (transforms see array axes in D, H, W order)
    "base_model_target_radius": 3, # Half-width (voxels) of the GT target; _create_target_mask paints a cube of side 2*radius+1
    "base_model_train_epochs": 2, # Low for demo; increase for real training (e.g., 100-300)
    "base_model_val_interval": 1, # Low for demo
    "base_model_lr": 1e-4,
    "base_model_batch_size": 1, # Limited by GPU VRAM with large patches
    "num_samples_per_volume": 4, # For RandSpatialCropSamplesd
    "num_base_models": 1, # For this demo, N=1. Can be extended.

    # Training Parameters - GNN Meta-Model
    "gnn_model_name": "GAT",
    "gnn_node_candidate_threshold": 0.3, # Probability threshold to consider a peak a candidate node
    "gnn_node_NMS_footprint": (5,5,5), # Footprint for NMS for candidate generation
    "gnn_edge_max_distance": 20, # Max distance (voxels) to connect nodes in graph
    "gnn_train_epochs": 2, # Low for demo; increase (e.g., 50-100)
    "gnn_lr": 1e-3,
    "gnn_batch_size": 1, # Number of graphs per batch
    "gnn_hidden_channels": 64,
    "gnn_num_layers": 3,
    "gnn_gat_heads": 4,

    # Inference Parameters
    "inference_roi_size_base": (96, 96, 96), # Sliding window ROI for base model
    "inference_sw_batch_size_base": 2, # Sliding window batch size
    "inference_overlap_base": 0.5,
    "output_bbox_size_3d": (11, 11, 11), # Odd numbers for centered box
    "output_bbox_size_2d": (11, 11),

    # HPO Parameters
    "optuna_n_trials": 3, # Low for demo; increase significantly (e.g., 50-100+)
    "optuna_objective_metric": "f2_score", # F-beta with beta=2

    # System
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    # 0 loads data in the main process, which makes DataLoader errors easier to debug.
    # Raise it (the previous value was 2) once loading works.
    "num_workers": 0,
    "random_seed": 42,

    # For 2D slice inference adaptation
    "pseudo_3d_depth_for_2d_slice": 32,
}

# Create output directory
Path(CONFIG["output_dir"]).mkdir(parents=True, exist_ok=True)
Path(CONFIG["dataset_root"]).mkdir(parents=True, exist_ok=True) # For simulated data

# Set seed for reproducibility
set_determinism(CONFIG["random_seed"])
np.random.seed(CONFIG["random_seed"])
random.seed(CONFIG["random_seed"])

print(f"Using device: {CONFIG['device']}")
if not PYG_AVAILABLE and CONFIG["gnn_model_name"]:
    print("WARNING: PyTorch Geometric not available, GNN part will not function correctly.")
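
# %% [markdown]
# ### Sketch: The F-beta Objective (beta = 2)
# `optuna_objective_metric` is set to an F-beta score with beta = 2, which weights recall more heavily than precision: F_beta = (1 + beta^2) * P * R / (beta^2 * P + R). The snippet below is a minimal sketch of that formula, assuming the true-positive, false-positive, and false-negative counts have already been obtained by matching predictions to ground truth. The function name `fbeta_score` is illustrative, and the matching rule itself is not specified here.

```python
# %%
def fbeta_score(tp, fp, fn, beta=2.0):
    """F-beta from detection counts; returns 0.0 when there are no true positives."""
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    beta_sq = beta ** 2
    return (1 + beta_sq) * precision * recall / (beta_sq * precision + recall)


# The same pair of values (0.5 and 0.9) assigned the other way round:
recall_heavy = fbeta_score(tp=9, fp=9, fn=1)     # precision 0.5, recall 0.9
precision_heavy = fbeta_score(tp=9, fp=1, fn=9)  # precision 0.9, recall 0.5
print(f"recall-heavy: {recall_heavy:.3f}, precision-heavy: {precision_heavy:.3f}")
```

# %% [markdown]
# With beta = 2 the recall-heavy case scores higher, so a missed motor costs more than a spurious detection.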

# %% [markdown]
# ## Utility: Simulate Dataset Creation
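# This function creates dummy JPG slice folders and a ground truth CSV file.
# **Replace this with your actual data loading logic if you have the dataset.**
# **Warning:** if the target paths hold data that does not match the requested simulation (for example, a different number of `tomo_*` folders), the function deletes the existing `tomo_*` folders and the ground truth CSV before regenerating them. Do not point it at a real dataset.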

# %%
def simulate_dataset(root_dir, gt_csv_path, num_tomos, slices_per_tomo_range, dims, num_motors_range):
    # Ensure root_dir and gt_csv_path are Path objects for consistency
    root_dir = Path(root_dir)
    gt_csv_path = Path(gt_csv_path)

    # Check if the specific GT file exists and the number of tomo folders matches
    # This check should be robust to whether root_dir was passed with a trailing slash or not
    tomo_folders_exist = list(root_dir.glob("tomo_*"))

    if gt_csv_path.exists() and len(tomo_folders_exist) == num_tomos :
        # Further check: do these tomo_folders actually contain JPGs? (Simplified check here)
        # This check is basic; a more robust one would verify content or a marker file.
        first_tomo_path_check = root_dir / tomo_folders_exist[0].name if tomo_folders_exist else None
        if first_tomo_path_check and list(first_tomo_path_check.glob("*.jpg")):

            print(f"Simulated dataset appears to exist at {root_dir} with matching GT {gt_csv_path}. Skipping creation.")
            try:
                gt_df = pd.read_csv(gt_csv_path)
                tomo_ids = sorted(list(gt_df['tomo_id'].unique()))
                if len(tomo_ids) == num_tomos: # Ensure CSV also reflects the correct number of tomos
                    return tomo_ids
                else:
                    print(f"Warning: GT CSV tomo_ids count ({len(tomo_ids)}) doesn't match expected num_tomos ({num_tomos}). Recreating.")
            except Exception as e:
                print(f"Error loading existing GT CSV or validating: {e}. Will recreate.")
        else:
            print(f"Tomogram folders found, but content check failed or first tomogram is empty. Recreating dataset.")


    print(f"Simulating dataset at {root_dir}...")
    # If we are recreating, clean up existing simulated data in that specific location first
    if root_dir.exists():
        # Be careful with recursive deletion. Only delete if it's the intended simulation dir.
        # For safety, only delete tomo_* folders and the gt_csv file.
        for item in root_dir.glob("tomo_*"):
            if item.is_dir():
                shutil.rmtree(item)
            elif item.is_file(): # Should not happen for tomo_*
                item.unlink()
        if gt_csv_path.exists():
            gt_csv_path.unlink()
        print(f"Cleaned up existing simulated data in {root_dir} for recreation.")


    root_dir.mkdir(parents=True, exist_ok=True) # Ensure root exists
    all_gt_data = []
    tomo_ids = []

    for i in range(num_tomos):
        tomo_id = f"tomo_{i:03d}"
        tomo_ids.append(tomo_id)
        tomo_path = root_dir / tomo_id
        tomo_path.mkdir(parents=True, exist_ok=True)

        num_slices = random.randint(slices_per_tomo_range[0], slices_per_tomo_range[1])
        height, width = dims

        for s_idx in range(num_slices):
            slice_img = np.random.randint(0, 256, size=(height, width), dtype=np.uint8)
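            if num_slices // 3 < s_idx < 2 * num_slices // 3:
                # Shift the central region's brightness. Work in int16 so a negative offset
                # cannot wrap around or raise on uint8 data, then clip back to [0, 255].
                region = slice_img[height//4:3*height//4, width//4:3*width//4].astype(np.int16)
                slice_img[height//4:3*height//4, width//4:3*width//4] = np.clip(region + random.randint(-30, 30), 0, 255).astype(np.uint8)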
            cv2.imwrite(str(tomo_path / f"slice_{s_idx:04d}.jpg"), slice_img)

        num_motors = random.randint(num_motors_range[0], num_motors_range[1])
        for m_idx in range(num_motors):
            center_z = random.randint(CONFIG["base_model_target_radius"] + 1, num_slices - CONFIG["base_model_target_radius"] -1 )
            center_y = random.randint(CONFIG["base_model_target_radius"] + 1, height - CONFIG["base_model_target_radius"]-1)
            center_x = random.randint(CONFIG["base_model_target_radius"] + 1, width - CONFIG["base_model_target_radius"]-1)
            all_gt_data.append({
                "row_id": f"{tomo_id}_motor_{m_idx:03d}",
                "tomo_id": tomo_id,
                "center_x": center_x,
                "center_y": center_y,
                "center_z": center_z
            })

    gt_df = pd.DataFrame(all_gt_data)
    gt_df.to_csv(gt_csv_path, index=False)
    print(f"Simulated dataset created. GT at {gt_csv_path}")
    return sorted(tomo_ids)

# Paths used for both the simulation and dataset loading. They are relative to the
# current working directory. Forward slashes are used so the same strings resolve on
# Windows, Linux, and macOS (backslash-separated paths only work on Windows).
USER_DATASET_ROOT = "./Downloads/byu-locating-bacterial-flagellar-motors-2025/train"
USER_GT_CSV_PATH = "./Downloads/byu-locating-bacterial-flagellar-motors-2025/train_labels.csv"

# Update CONFIG to use these paths for simulation and dataset loading
CONFIG["dataset_root"] = USER_DATASET_ROOT
CONFIG["gt_csv_path"] = USER_GT_CSV_PATH

SIM_NUM_TOMOS = 5
SIM_SLICES_PER_TOMO = (CONFIG["patch_size_base"][2] + 20, CONFIG["patch_size_base"][2] + 50)
SIM_DIMS = (CONFIG["patch_size_base"][0] + 20, CONFIG["patch_size_base"][1] + 20)
SIM_NUM_MOTORS = (1, 5)

ALL_TOMO_IDS = simulate_dataset(
    CONFIG["dataset_root"],
    CONFIG["gt_csv_path"],
    SIM_NUM_TOMOS,
    SIM_SLICES_PER_TOMO,
    SIM_DIMS,
    SIM_NUM_MOTORS
)

if len(ALL_TOMO_IDS) >= 5:
    TRAIN_TOMO_IDS = ALL_TOMO_IDS[:3]
    VAL_TOMO_IDS = ALL_TOMO_IDS[3:4]
    TEST_TOMO_IDS = ALL_TOMO_IDS[4:5]
elif len(ALL_TOMO_IDS) >=3:
    TRAIN_TOMO_IDS = ALL_TOMO_IDS[:1]
    VAL_TOMO_IDS = ALL_TOMO_IDS[1:2]
    TEST_TOMO_IDS = ALL_TOMO_IDS[2:3]
else:
    TRAIN_TOMO_IDS = ALL_TOMO_IDS
    VAL_TOMO_IDS = ALL_TOMO_IDS
    TEST_TOMO_IDS = ALL_TOMO_IDS

print(f"Train Tomo IDs: {TRAIN_TOMO_IDS}")
print(f"Val Tomo IDs: {VAL_TOMO_IDS}")
print(f"Test Tomo IDs: {TEST_TOMO_IDS}")

# %% [markdown]
# ## Part 2: Data Loading and Preprocessing (MONAI)
#
# ### 2.1. Custom ImageReader for JPG Sequences

# %%
# Path, glob, os, cv2, and numpy are already imported in Part 0.
import traceback
from typing import Sequence, Union

class JPGSequenceReader(ImageReader):
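    def __init__(self, **kwargs):
        # ImageReader is an abstract base class without its own __init__, so forwarding
        # kwargs to super().__init__() would raise TypeError whenever any are given.
        # Store them on the instance instead.
        super().__init__()
        self.kwargs = kwargs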

    def read(self, data: Union[str, Path, Sequence[Union[str, Path]]], **kwargs): # Adjusted type hint for data
        # print(f"[JPGSequenceReader] Reading from: {data}") # Debug print
        if isinstance(data, (list, tuple)):
             img_files = sorted([str(f) for f in data])
        elif Path(data).is_dir():
            img_files = sorted(glob.glob(os.path.join(str(data), "*.jpg"))) # Ensure data is str for os.path.join
        else:
            # Check if 'data' is a single file path string that was expected to be a directory
            if isinstance(data, (str, Path)) and not Path(data).exists():
                 raise FileNotFoundError(f"Input path {data} does not exist.")
            raise ValueError(f"JPGSequenceReader expects a directory path or list of JPG files, got {type(data)}: {data}")

        if not img_files:
            # print(f"[JPGSequenceReader] No JPG files found in {data}") # Debug print
            raise FileNotFoundError(f"No JPG files found in {data}")
        # print(f"[JPGSequenceReader] Found {len(img_files)} files in {data}") # Debug print

        slices = []
        expected_shape = None
        for idx, img_file in enumerate(img_files):
            # print(f"[JPGSequenceReader] Reading slice: {img_file}") # Debug print
            slice_img = cv2.imread(img_file, cv2.IMREAD_GRAYSCALE)
            if slice_img is None:
                # print(f"[JPGSequenceReader] ERROR: Could not read image {img_file}") # Debug print
                raise IOError(f"Could not read image {img_file}")
            
            if expected_shape is None:
                expected_shape = slice_img.shape
            elif slice_img.shape != expected_shape:
                # print(f"[JPGSequenceReader] ERROR: Slice {img_file} shape {slice_img.shape} differs from expected {expected_shape}") # Debug print
                raise ValueError(f"Inconsistent slice dimensions in tomogram {data}. Expected {expected_shape}, got {slice_img.shape} for {img_file}")
            slices.append(slice_img)

        if not slices: # Should be caught by 'if not img_files' earlier, but as a safeguard
            raise RuntimeError(f"No slices were loaded for {data}, though image files might have been listed.")

        volume = np.stack(slices, axis=-1) # Stack along Z: HxWxD
        # print(f"[JPGSequenceReader] Stacked volume shape: {volume.shape}") # Debug print
        
        # JPG files carry no voxel spacing or origin, so use an identity affine
        # (unit spacing, origin at zero) in voxel space.
        affine = np.eye(4)
        # If the physical pixel spacing and slice thickness are known, put them on the diagonal:
        # pixel_spacing_x, pixel_spacing_y, slice_thickness_z = 1.0, 1.0, 1.0 # Example values
        # affine = np.array([
        #     [pixel_spacing_x, 0, 0, 0],
        #     [0, pixel_spacing_y, 0, 0],
        #     [0, 0, slice_thickness_z, 0],
        #     [0, 0, 0, 1]
        # ])
        return volume, {"original_affine": affine, "spatial_shape": np.array(volume.shape)}

    def get_data(self, img_obj): # img_obj is what read() returned
        # ImageReader.get_data expects to return (array, metadata_dict)
        return img_obj 

    # <<< BEGIN MODIFICATION >>>
    def verify_suffix(self, filename: Union[str, Path, Sequence[Union[str, Path]]]) -> bool:
        """
        Verify that the filename is a directory (assumed to contain JPGs),
        or a .jpg/.jpeg file (if single file passed, though `read` handles directories primarily),
        or a sequence of .jpg/.jpeg files.
        """
        if isinstance(filename, (list, tuple)):
            if not filename:  # Empty sequence
                return False
            return all(Path(f).suffix.lower() in (".jpg", ".jpeg") for f in filename)
        
        # Handle single PathLike object (str or Path)
        p_fn = Path(filename)
        if p_fn.is_dir():
            # The reader is designed to read JPG sequences from a directory.
            return True 
        if p_fn.is_file():
            return p_fn.suffix.lower() in (".jpg", ".jpeg")
        
        # If filename is a string but not an existing file or directory path,
        # it's ambiguous. For example, it could be a file pattern.
        # Since `read` primarily expects an existing directory,
        # we are stricter here. If it's not a recognized file/dir, return False.
        return False
    # <<< END MODIFICATION >>>
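
# %% [markdown]
# ### Sketch: Axis Order from Slices to Volume
# `JPGSequenceReader.read()` stacks slices along the last axis, giving an H x W x D array, and `CryoETDataset.__getitem__` transposes it to D x H x W. The toy example below, with made-up slice sizes, illustrates how a ground-truth `(center_x, center_y, center_z)` maps to an array index under that convention. It assumes `center_x` runs along image width and `center_y` along image height.

```python
# %%
import numpy as np

# Four hypothetical 2D slices, each with height 3 and width 5; slice z is filled with the value z.
slices = [np.full((3, 5), fill_value=z, dtype=np.uint8) for z in range(4)]

volume_hwd = np.stack(slices, axis=-1)      # (H, W, D), the layout produced by read()
volume_dhw = volume_hwd.transpose(2, 0, 1)  # (D, H, W), the layout used by the dataset

# A motor at (center_x, center_y, center_z) is addressed as [z, y, x] in the D x H x W array.
center_x, center_y, center_z = 4, 1, 2
value_at_center = volume_dhw[center_z, center_y, center_x]
print(volume_hwd.shape, volume_dhw.shape, value_at_center)
```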


# %% [markdown]
# ### 2.2. Ground Truth Processing and Dataset Definition

# %%
class CryoETDataset(MonaiDataset):
    def __init__(self, data_root, tomo_ids, gt_csv_path, transforms, target_radius, is_train=True):
        self.data_root = Path(data_root)
        self.tomo_ids = tomo_ids
        # print(f"[CryoETDataset] Initializing with gt_csv_path: {gt_csv_path}") # Debug
        try:
            self.gt_df = pd.read_csv(gt_csv_path)
            # print(f"[CryoETDataset] Successfully loaded GT CSV. Columns: {self.gt_df.columns.tolist()}") #Debug
        except FileNotFoundError:
            print(f"[CryoETDataset] ERROR: Ground truth CSV file not found at {gt_csv_path}")
            raise
        except Exception as e:
            print(f"[CryoETDataset] ERROR: Failed to load or parse GT CSV {gt_csv_path}: {e}")
            raise

        self.transforms = transforms
        self.target_radius = target_radius
        self.is_train = is_train

        self.data_files = []
        for tomo_id in self.tomo_ids:
            tomo_path = self.data_root / tomo_id
            if tomo_path.is_dir() and list(tomo_path.glob("*.jpg")): # Basic check
                 self.data_files.append({"image": str(tomo_path), "id": tomo_id})
            else:
                print(f"Warning: Tomogram directory {tomo_path} not found or empty for tomo_id '{tomo_id}'. Skipping.")
        
        if not self.data_files and self.tomo_ids: # If tomo_ids were provided but no valid files found
            raise ValueError(f"No valid data files found for the provided tomo_ids. Check dataset_root ('{self.data_root}') and tomo_id subfolders.")
        # print(f"[CryoETDataset] Initialized with {len(self.data_files)} data files.") # Debug


    def _create_target_mask(self, shape_3d, motor_coords):
        # shape_3d is (D, H, W)
        mask = np.zeros(shape_3d, dtype=np.float32)
        # print(f"[_create_target_mask] Mask shape: {shape_3d}, Num motor_coords: {len(motor_coords)}") # Debug

        for i, mc in enumerate(motor_coords):
            # print(f"[_create_target_mask] Processing motor coord {i}: {mc}") # Debug
            if len(mc) != 3:
                print(f"Warning: Motor coord {mc} is not of length 3. Skipping.")
                continue
            try:
                # Ensure they are integers for indexing
                cz, cy, cx = int(round(mc[0])), int(round(mc[1])), int(round(mc[2])) # Z, Y, X
            except (ValueError, TypeError) as e:
                print(f"Warning: Could not convert motor coord {mc} to int: {e}. Skipping.")
                continue

            if not (0 <= cz < shape_3d[0] and \
                    0 <= cy < shape_3d[1] and \
                    0 <= cx < shape_3d[2]):
                print(f"Warning: Motor coord ({cx},{cy},{cz}) out of bounds for shape {shape_3d[::-1]} (orig shape_3d D,H,W: {shape_3d}). Skipping.")
                continue

            z_min, z_max = max(0, cz - self.target_radius), min(shape_3d[0], cz + self.target_radius + 1)
            y_min, y_max = max(0, cy - self.target_radius), min(shape_3d[1], cy + self.target_radius + 1)
            x_min, x_max = max(0, cx - self.target_radius), min(shape_3d[2], cx + self.target_radius + 1)
            mask[z_min:z_max, y_min:y_max, x_min:x_max] = 1.0
        return mask

    def __len__(self):
        return len(self.data_files)

    def __getitem__(self, index):
        # MODIFICATION: Wrap major parts of __getitem__ in a try-except for better error reporting
        data_item = {} # Initialize to ensure it exists in except block
        try:
            data_item = self.data_files[index].copy()
            # print(f"[__getitem__] Processing index {index}, tomo_id: {data_item.get('id', 'Unknown')}") # Debug

            img_arr, meta = JPGSequenceReader().read(data_item["image"]) # HxWxD
            img_arr = img_arr.transpose(2,0,1) # Transpose to DxHxW
            data_item["image"] = img_arr
            data_item["image_meta_dict"] = meta
            # print(f"[__getitem__] Loaded image shape (D,H,W): {img_arr.shape}") # Debug

            tomo_id_current = data_item["id"]
            tomo_gt = self.gt_df[self.gt_df["tomo_id"] == tomo_id_current]
            motor_coords_zyx = []
            if not tomo_gt.empty:
                # Ensure columns exist before trying to access
                required_cols = ["center_z", "center_y", "center_x"]
                if not all(col in tomo_gt.columns for col in required_cols):
                    raise KeyError(f"Ground truth CSV for tomo_id {tomo_id_current} is missing one of required columns: {required_cols}. Found: {tomo_gt.columns.tolist()}")
                
                motor_coords_zyx = list(zip(tomo_gt["center_z"].values,
                                            tomo_gt["center_y"].values,
                                            tomo_gt["center_x"].values))
            # print(f"[__getitem__] Found {len(motor_coords_zyx)} motors for {tomo_id_current}") # Debug
            
            target_mask = self._create_target_mask(img_arr.shape, motor_coords_zyx)
            data_item["label"] = target_mask
            # print(f"[__getitem__] Created target_mask shape: {target_mask.shape}, sum: {np.sum(target_mask)}") # Debug

            # Apply transforms
            # print(f"[__getitem__] Applying transforms for {tomo_id_current}...") # Debug
            processed_data = self.transforms(data_item)
            # if isinstance(processed_data, list): # RandSpatialCropSamplesd case
            #     print(f"[__getitem__] Transforms returned list of {len(processed_data)} samples.") # Debug
            #     for i, p_data in enumerate(processed_data):
            #          print(f"  Sample {i} image shape: {p_data['image'].shape}, label shape: {p_data['label'].shape}") # Debug
            # else: # Single item case
            #     print(f"[__getitem__] Transforms returned single item. Image shape: {processed_data['image'].shape}, label shape: {processed_data['label'].shape}") # Debug

            return processed_data

        except Exception as e:
            # This will now catch errors from JPG reading, GT processing, mask creation, or transforms
            print(f"ERROR in CryoETDataset.__getitem__ for index {index}, data_item id '{data_item.get('id', 'Unknown')}': {e}")
            traceback.print_exc() # Print full traceback
            # To make DataLoader continue and not hang, can return None or a dummy item,
            # but this often hides problems. For debugging, re-raising is better with num_workers=0.
            raise e
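
# %% [markdown]
# ### Sketch: Spherical Target Mask
# `_create_target_mask` fills an axis-aligned cube of side `2 * target_radius + 1` around each motor center. If a spherical target is preferred, one option is to test squared distance inside the same bounding box. The helper below is a hypothetical alternative, not part of the dataset class, and assumes integer voxel centers in (Z, Y, X) order.

```python
# %%
import numpy as np


def paint_sphere(mask, center_zyx, radius):
    """Set voxels within `radius` of `center_zyx` to 1.0, in place, clipping at the borders."""
    cz, cy, cx = center_zyx
    depth, height, width = mask.shape
    z0, z1 = max(0, cz - radius), min(depth, cz + radius + 1)
    y0, y1 = max(0, cy - radius), min(height, cy + radius + 1)
    x0, x1 = max(0, cx - radius), min(width, cx + radius + 1)
    # Open grids broadcast to the bounding-box shape without allocating full coordinate arrays.
    zz, yy, xx = np.ogrid[z0:z1, y0:y1, x0:x1]
    inside = (zz - cz) ** 2 + (yy - cy) ** 2 + (xx - cx) ** 2 <= radius ** 2
    # Basic slicing returns a view, so the boolean assignment writes into `mask` itself.
    mask[z0:z1, y0:y1, x0:x1][inside] = 1.0
    return mask


demo_mask = paint_sphere(np.zeros((9, 9, 9), dtype=np.float32), center_zyx=(4, 4, 4), radius=3)
print(int(demo_mask.sum()), "voxels in the sphere vs", 7 ** 3, "in the enclosing cube")
```

# %% [markdown]
# A sphere covers noticeably fewer voxels than the enclosing cube, which makes the positive class even sparser. Whether that helps or hurts training is an empirical question.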


# %% [markdown]
# ### 2.3. MONAI Transforms

# %%
# For training base model (includes patch sampling and augmentation)
train_transforms_base = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
    RandSpatialCropSamplesd(
        keys=["image", "label"],
        roi_size=CONFIG["patch_size_base"],
        num_samples=CONFIG["num_samples_per_volume"],
        random_center=True,
        random_size=False,
    ),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
    # After EnsureChannelFirstd the spatial axes (0, 1, 2) are (D, H, W), i.e. (Z, Y, X).
    # 90-degree rotations swap two axis lengths, so patches must be cubic to keep a fixed shape.
    RandRotate90d(keys=["image", "label"], prob=0.5, max_k=3, spatial_axes=(0,1)), # Rotate in the ZY plane
    RandRotate90d(keys=["image", "label"], prob=0.5, max_k=3, spatial_axes=(0,2)), # Rotate in the ZX plane
    RandRotate90d(keys=["image", "label"], prob=0.5, max_k=3, spatial_axes=(1,2)), # Rotate in the YX (in-slice) plane

    RandGaussianNoised(keys=["image"], prob=0.1, mean=0.0, std=0.01),
    RandAdjustContrastd(keys=["image"], prob=0.1, gamma=(0.5, 1.5)),
    ToTensord(keys=["image", "label"]),
])

# For validation base model (processes full volumes or large ROIs, no augmentation usually)
val_transforms_base = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
    ToTensord(keys=["image", "label"]),
])


# %% [markdown]
# ### 2.4. DataLoaders

# %%
if TRAIN_TOMO_IDS:
    print(f"Attempting to create train_ds_base with {len(TRAIN_TOMO_IDS)} tomograms: {TRAIN_TOMO_IDS}")
    try:
        train_ds_base = CryoETDataset(
            data_root=CONFIG["dataset_root"],
            tomo_ids=TRAIN_TOMO_IDS,
            gt_csv_path=CONFIG["gt_csv_path"],
            transforms=train_transforms_base,
            target_radius=CONFIG["base_model_target_radius"],
            is_train=True
        )
        if len(train_ds_base) == 0: # Check if dataset ended up empty
            print("Warning: train_ds_base is empty after initialization. No data will be loaded.")
            train_loader_base = None
        else:
            train_loader_base = MonaiDataLoader(
                train_ds_base,
                batch_size=CONFIG["base_model_batch_size"],
                shuffle=True,
                num_workers=CONFIG["num_workers"],  # Set to 0 in CONFIG so dataset errors surface in the main process
                collate_fn=list_data_collate,
                pin_memory=torch.cuda.is_available()
            )
            print(f"Train dataset size: {len(train_ds_base)}")
            # Smoke test: fetch one batch so that data-loading errors surface here rather than mid-training.
            print("Attempting to get first item from train_loader_base...")
            first_batch = first(train_loader_base)
            if first_batch is not None:
                print(f"Shape of first training batch 'image': {first_batch['image'].shape}")
                print(f"Shape of first training batch 'label': {first_batch['label'].shape}")
            else:
                print("first(train_loader_base) returned None. DataLoader might be empty or failed silently on first item.")

    except Exception as e:
        print(f"ERROR: Failed to create train_ds_base or train_loader_base: {e}")
        traceback.print_exc()
        train_loader_base = None # Ensure it's None on failure
else:
    train_loader_base = None
    print("No training data specified (TRAIN_TOMO_IDS is empty).")


if VAL_TOMO_IDS:
    print(f"Attempting to create val_ds_base with {len(VAL_TOMO_IDS)} tomograms: {VAL_TOMO_IDS}")
    try:
        val_ds_base = CryoETDataset(
            data_root=CONFIG["dataset_root"],
            tomo_ids=VAL_TOMO_IDS,
            gt_csv_path=CONFIG["gt_csv_path"],
            transforms=val_transforms_base,
            target_radius=CONFIG["base_model_target_radius"],
            is_train=False
        )
        if len(val_ds_base) == 0:
            print("Warning: val_ds_base is empty after initialization.")
            val_loader_base = None
        else:
            val_loader_base = MonaiDataLoader(
                val_ds_base, 
                batch_size=1, # Typically 1 for validation of full volumes
                shuffle=False, 
                num_workers=CONFIG["num_workers"],  # Set to 0 in CONFIG so dataset errors surface in the main process
                pin_memory=torch.cuda.is_available()
                # No list_data_collate needed if not using RandSpatialCropSamplesd for val
            )
            print(f"Validation dataset size: {len(val_ds_base)}")
            # Smoke test: fetch one full-volume item to confirm that loading works.
            print("Attempting to get first item from val_loader_base...")
            first_val_item = first(val_loader_base)
            if first_val_item is not None:
                print(f"Shape of first validation item 'image': {first_val_item['image'].shape}")
                print(f"Shape of first validation item 'label': {first_val_item['label'].shape}")
            else:
                print("first(val_loader_base) returned None.")


    except Exception as e:
        print(f"ERROR: Failed to create val_ds_base or val_loader_base: {e}")
        traceback.print_exc()
        val_loader_base = None
else:
    val_loader_base = None
    print("No validation data specified (VAL_TOMO_IDS is empty).")

# %% [markdown]
# ## Part 3: Base 3D CNN Model (MONAI DynUNet)
# Using DynUNet from MONAI, which is highly configurable.

# %%
def get_base_model(config):
    # Example: 5 block levels (including the input block and the bottleneck), i.e. 4 downsampling stages.
    num_blocks = 5  # Length of kernel_size, strides, and filters

    # Filters for each of the 5 blocks.
    # These could be made configurable through config, for example:
    # start_filters = config.get("dynunet_start_filters", 32)
    # num_blocks = config.get("dynunet_num_blocks", 5)
    # dynunet_filters = [start_filters * (2**i) for i in range(num_blocks)]
    dynunet_filters = [32, 64, 128, 256, 512]
    if len(dynunet_filters) != num_blocks:
        # To vary num_blocks freely, regenerate the filters to match.
        print(f"Warning: Length of predefined dynunet_filters ({len(dynunet_filters)}) "
              f"does not match num_blocks ({num_blocks}). Adjusting filters.")
        start_filters = dynunet_filters[0] if dynunet_filters else 32
        dynunet_filters = [start_filters * (2**i) for i in range(num_blocks)]


    # Strides: the first applies to the input block, the rest perform downsampling.
    # The first stride is usually (1,1,1), or (1,2,2) to reduce XY immediately.
    # The remaining strides are (2,2,2), halving each dimension at every downsampling stage.
    # The strides list has length num_blocks.
    initial_stride = config.get("dynunet_initial_stride", (1,1,1))  # Configurable initial stride
    downsample_stride = config.get("dynunet_downsample_stride", (2,2,2))

    actual_strides = [initial_stride] + [downsample_stride] * (num_blocks - 1)

    # Kernel sizes: one per block, so this list also has length num_blocks.
    # Typically (3,3,3) for every block.
    kernel_dim = config.get("dynunet_kernel_dim", 3)
    unet_kernel_sizes = [(kernel_dim, kernel_dim, kernel_dim)] * num_blocks

    # upsample_kernel_size: kernel sizes of the transposed convolutions, which in effect
    # are the upsampling strides. There must be num_blocks - 1 of them (one per upsampling step).
    # They must match the downsampling strides (all strides except initial_stride),
    # i.e. a (2,2,2) downsampling step is undone by a (2,2,2) upsampling step.
    upsample_strides_values = [downsample_stride] * (num_blocks - 1)


    # --- Length checks required by DynUNet ---
    if not (len(unet_kernel_sizes) == len(actual_strides) == len(dynunet_filters)):
        # Should not trigger if the logic above is correct
        raise ValueError(
            f"CRITICAL INTERNAL ERROR: Lengths of kernel_size ({len(unet_kernel_sizes)}), "
            f"strides ({len(actual_strides)}), and filters ({len(dynunet_filters)}) "
            f"must be equal. Current num_blocks={num_blocks}."
        )

    # Check the length of upsample_kernel_size:
    # len(upsample_kernel_size) must equal len(strides) - 1
    expected_upsample_len = len(actual_strides) - 1
    if len(upsample_strides_values) != expected_upsample_len:
        # Should not trigger if the logic above is correct
        raise ValueError(
            f"CRITICAL INTERNAL ERROR: Length of upsample_kernel_size ({len(upsample_strides_values)}) "
            f"must be {expected_upsample_len} (which is len(strides) - 1)."
        )

    # DynUNet requires kernel_size and strides to have at least 3 entries (effectively num_blocks >= 3)
    if num_blocks < 3:
        raise ValueError(f"DynUNet requires at least 3 levels (num_blocks >= 3). Current num_blocks={num_blocks}.")

    model = DynUNet(
        spatial_dims=3,
        in_channels=1, # Grayscale input
        out_channels=1, # Output probability map (pre-sigmoid)
        kernel_size=unet_kernel_sizes,
        strides=actual_strides,
        upsample_kernel_size=upsample_strides_values,  # Upsampling strides, equal to strides[1:]
        filters=dynunet_filters,
        norm_name="instance",
        act_name=("leakyrelu", {"negative_slope": 0.01}),
        deep_supervision=False,  # Set to True to enable deep supervision
        # deep_supr_num=num_blocks - 2,  # Typical value when deep_supervision is enabled (requires num_blocks >= 3)
        res_block=config.get("dynunet_res_block", True),  # Residual blocks are configurable
    )
    return model
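
# %% [markdown]
# With the strides built in `get_base_model`, each spatial axis is downsampled by the product of its strides. Patch sizes and sliding-window ROI sizes therefore generally need to be divisible by that product so that the skip connections line up after upsampling. The following is a small sketch of that arithmetic under the default strides above; the helper names are illustrative and not part of MONAI.

# %%
from math import prod


def total_downsampling(strides):
    """Per-axis product of all strides, e.g. [(1,1,1)] + [(2,2,2)] * 4 gives (16, 16, 16)."""
    return tuple(prod(axis_strides) for axis_strides in zip(*strides))


def is_divisible(size, strides):
    """True if every edge of `size` is a multiple of the total downsampling factor."""
    return all(edge % factor == 0 for edge, factor in zip(size, total_downsampling(strides)))


_example_strides = [(1, 1, 1)] + [(2, 2, 2)] * 4  # defaults assumed from get_base_model
print(total_downsampling(_example_strides))         # (16, 16, 16)
print(is_divisible((96, 96, 96), _example_strides))   # True
print(is_divisible((100, 96, 96), _example_strides))  # False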

# %% [markdown]
# ### Base Model Training Loop

# %%
def train_base_model(model, train_loader, val_loader, config, trial=None): # trial for Optuna
    model.to(config["device"])
    
    # Loss function: focal loss suits the extreme class imbalance of sparse motor targets.
    # DiceCELoss is a common alternative for segmentation:
    # loss_function = DiceCELoss(to_onehot_y=False, sigmoid=True, squared_pred=True, lambda_dice=0.5, lambda_ce=0.5)
    # DynUNet outputs raw logits. MONAI's FocalLoss expects logits and applies the sigmoid
    # internally, so no activation is applied before the loss. Labels must be in {0, 1}.
    loss_function = FocalLoss(to_onehot_y=False, gamma=2.0, reduction="mean")

    optimizer = torch.optim.AdamW(model.parameters(), lr=config["base_model_lr"], weight_decay=1e-5)
    # The scheduler is stepped with the validation Dice (higher is better), hence mode='max'.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5, factor=0.5)

    best_metric = -1
    best_metric_epoch = -1
    train_losses = []
    val_metrics_list = [] # Store F-beta or Dice scores from validation

    # Inferer for validation on full volumes
    val_inferer = sliding_window_inference
    roi_size_val = config["inference_roi_size_base"] # Use same as training patch or larger if GPU allows
    sw_batch_size_val = config["inference_sw_batch_size_base"]

    post_pred_val = Compose([Activationsd(keys="pred", sigmoid=True), AsDiscreted(keys="pred", threshold=0.5)])
    post_label_val = Compose([AsDiscreted(keys="label", threshold=0.5)]) # GT is already 0 or 1

    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    print(f"Starting base model training for {config['base_model_train_epochs']} epochs.")

    for epoch in range(config["base_model_train_epochs"]):
        model.train()
        epoch_loss = 0
        step = 0
        # train_loader might be None if no train_ids
        if not train_loader:
            print("Skipping training epoch as train_loader is None.")
            break # Exit training loop if no data

        for batch_data in tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['base_model_train_epochs']} Training"):
            step += 1
            inputs, labels = batch_data["image"].to(config["device"]), batch_data["label"].to(config["device"])
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        
        epoch_loss /= max(step, 1)  # Guard against division by zero if the loader yielded no batches
        train_losses.append(epoch_loss)
        print(f"Epoch {epoch+1} average training loss: {epoch_loss:.4f}")

        if (epoch + 1) % config["base_model_val_interval"] == 0:
            model.eval()
            if not val_loader:
                print("Skipping validation as val_loader is None.")
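                # Without validation data, report the training loss to Optuna instead.
                # Caution: the loss is minimized while the validation Dice is maximized,
                # so the study direction must match whichever value is reported.
                if trial:
                    trial.report(epoch_loss, epoch)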
                    if trial.should_prune():
                        raise optuna.exceptions.TrialPruned()
                continue # Continue to next epoch

            with torch.no_grad():
                for val_data in tqdm(val_loader, desc=f"Epoch {epoch+1} Validation"):
                    val_inputs = val_data["image"].to(config["device"])
                    val_labels = val_data["label"].to(config["device"]) # Full label mask

                    # Sliding window inference for full volume
                    val_outputs = val_inferer(
                        val_inputs, roi_size_val, sw_batch_size_val, model,
                        overlap=config["inference_overlap_base"], mode="gaussian", progress=False,
                    )
                    
                    # Apply post-processing (sigmoid, threshold)
                    processed_val_outputs = [post_pred_val({"pred": val_out})["pred"] for val_out in val_outputs]
                    processed_val_labels = [post_label_val({"label": val_lab})["label"] for val_lab in val_labels]

                    # Compute Dice metric (or F-beta later)
                    dice_metric(y_pred=processed_val_outputs, y=processed_val_labels)

                metric_val = dice_metric.aggregate().item()
                dice_metric.reset()
                val_metrics_list.append(metric_val) # Store Dice for now
                scheduler.step(metric_val) # Or use train_loss for scheduler if val is too noisy / infrequent

                print(f"Epoch {epoch+1} validation Dice: {metric_val:.4f}")

                if metric_val > best_metric:
                    best_metric = metric_val
                    best_metric_epoch = epoch + 1
                    # Save best model (can be more sophisticated, e.g. based on F-beta)
                    model_save_path = Path(config["output_dir"]) / f"{config['base_model_name']}_best_epoch{best_metric_epoch}_dice{best_metric:.4f}.pth"
                    torch.save({
                        'epoch': epoch + 1,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': epoch_loss,
                        'metric': best_metric,
                        'config': config # Save config used for this model
                    }, model_save_path)
                    print(f"Saved new best model to {model_save_path}")

                # Optuna pruning
                if trial:
                    trial.report(metric_val, epoch) # Report validation metric to Optuna
                    if trial.should_prune():
                        torch.cuda.empty_cache()
                        raise optuna.exceptions.TrialPruned()
    
    print(f"Finished base model training. Best validation metric: {best_metric:.4f} at epoch {best_metric_epoch}")
    
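    # Reload the best checkpoint so that later steps (e.g. GNN data generation) use the best weights.
    # Note: the checkpoint also pickles the config dict. On PyTorch versions where torch.load
    # defaults to weights_only=True, loading it may require weights_only=False (trusted files only).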
    if best_metric_epoch != -1:
        best_model_path = Path(config["output_dir"]) / f"{config['base_model_name']}_best_epoch{best_metric_epoch}_dice{best_metric:.4f}.pth"
        if best_model_path.exists():
            checkpoint = torch.load(best_model_path, map_location=config["device"])
            model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Loaded best model from {best_model_path} for GNN data generation.")
        else:
            print(f"Warning: Best model path {best_model_path} not found. Using last epoch model.")

    return model, {"train_losses": train_losses, "val_metrics": val_metrics_list, "best_val_metric": best_metric, "best_epoch": best_metric_epoch}
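
# %% [markdown]
# The loop above selects checkpoints by Dice, while the stated project goal is the F-beta score with beta = 2, which weights recall more heavily than precision. As a minimal sketch, the function below computes F-beta from detection counts. How predictions are matched to ground-truth motors to obtain those counts is left open here, and the name `fbeta_from_counts` is illustrative.

# %%
def fbeta_from_counts(tp, fp, fn, beta=2.0):
    """F-beta = (1 + b^2) * TP / ((1 + b^2) * TP + b^2 * FN + FP); returns 0.0 when undefined."""
    b2 = beta ** 2
    denom = (1.0 + b2) * tp + b2 * fn + fp
    return (1.0 + b2) * tp / denom if denom > 0 else 0.0


# With beta = 2, a missed motor (FN) costs more than a false alarm (FP):
print(fbeta_from_counts(tp=8, fp=4, fn=0))  # four false alarms
print(fbeta_from_counts(tp=8, fp=0, fn=4))  # four misses, lower score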


# %% [markdown]
# ## Part 4: GNN Meta-Model (PyTorch Geometric - GAT)
#
# ### 4.1. Graph Construction Utilities

# %%
def get_candidate_nodes_from_prob_map(prob_map_tensor, prob_threshold, nms_footprint_voxels, max_candidates=500):
    """
    Extracts candidate motor locations (nodes) from a probability map.
    Args:
        prob_map_tensor (torch.Tensor): Single channel probability map (D, H, W) on CPU or GPU.
        prob_threshold (float): Minimum probability to consider a voxel.
        nms_footprint_voxels (tuple): Size of footprint for non-maximum suppression (Z,Y,X).
        max_candidates (int): Max number of candidates to return.
    Returns:
        np.ndarray: Array of candidate coordinates (N, 3) in (z, y, x) order.
        np.ndarray: Array of corresponding probabilities (N,).
    """
    if prob_map_tensor.ndim == 4: # (C,D,H,W)
        if prob_map_tensor.shape[0] != 1:
            raise ValueError("Prob map tensor should be single channel for candidate extraction.")
        prob_map_tensor = prob_map_tensor.squeeze(0)
    
    prob_map_np = prob_map_tensor.cpu().numpy()

    # Threshold
    binary_map = prob_map_np > prob_threshold
    if np.sum(binary_map) == 0:
        return np.empty((0,3), dtype=int), np.empty((0,), dtype=float)

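    # Non-maximum suppression via a local-maximum filter: a voxel survives if it equals the
    # maximum of its neighborhood. All voxels of a flat plateau survive, so tied peaks can
    # yield several candidates. skimage.feature.peak_local_max or similar would be more robust.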
    from scipy.ndimage import maximum_filter
    local_max = maximum_filter(prob_map_np, footprint=np.ones(nms_footprint_voxels)) == prob_map_np
    
    # Combine binary map (threshold) and local maxima
    candidate_mask = binary_map & local_max
    
    coords_zyx = np.argwhere(candidate_mask) # Get (z,y,x) coordinates of candidates
    
    if coords_zyx.shape[0] == 0:
        return np.empty((0,3), dtype=int), np.empty((0,), dtype=float)
        
    candidate_probs = prob_map_np[coords_zyx[:,0], coords_zyx[:,1], coords_zyx[:,2]]

    # Sort by probability and take top N
    if len(candidate_probs) > max_candidates:
        sorted_indices = np.argsort(candidate_probs)[::-1][:max_candidates]
        coords_zyx = coords_zyx[sorted_indices]
        candidate_probs = candidate_probs[sorted_indices]
        
    return coords_zyx, candidate_probs
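
# %% [markdown]
# The local-maximum filter above can return several adjacent candidates when neighboring voxels share the same probability. One possible follow-up, sketched here as an assumption rather than as part of the pipeline, is a greedy distance-based suppression: candidates are visited from most to least probable and kept only if they lie at least `min_dist` voxels from every candidate already kept. The function name and the `min_dist` parameter are hypothetical.

# %%
import numpy as np


def greedy_distance_nms(coords_zyx, probs, min_dist):
    """Keep the most probable candidates that are at least min_dist voxels apart."""
    coords_zyx = np.asarray(coords_zyx).reshape(-1, 3)
    probs = np.asarray(probs, dtype=float)
    order = np.argsort(probs)[::-1]  # most probable first
    keep = []
    for idx in order:
        # Compare against candidates that were already accepted
        if all(np.linalg.norm(coords_zyx[idx] - coords_zyx[k]) >= min_dist for k in keep):
            keep.append(int(idx))
    keep = np.array(keep, dtype=int)
    return coords_zyx[keep], probs[keep]


_demo_coords = np.array([[5, 5, 5], [5, 5, 6], [20, 20, 20]])
_demo_probs = np.array([0.9, 0.9, 0.8])
_kept_coords, _kept_probs = greedy_distance_nms(_demo_coords, _demo_probs, min_dist=3.0)
print(len(_kept_coords))  # the two tied neighbors collapse into one candidate

# %%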


def extract_node_features(candidate_coords_zyx, base_model_outputs_list, patch_radius=2):
    """
    Extracts features for GNN nodes from base model probability maps.
    Args:
        candidate_coords_zyx (np.ndarray): (N, 3) array of (z,y,x) node coordinates.
        base_model_outputs_list (list of torch.Tensor): List of probability maps [(D,H,W), ...] from base models.
                                                      Assumed to be on the same device.
        patch_radius (int): Radius around candidate center to extract local stats (e.g., mean prob).
    Returns:
        torch.Tensor: Node features (N, num_features) on the same device as base_model_outputs.
    """
    if candidate_coords_zyx.shape[0] == 0:
        return torch.empty((0, len(base_model_outputs_list)), device=base_model_outputs_list[0].device if base_model_outputs_list else CONFIG["device"])

    num_nodes = candidate_coords_zyx.shape[0]
    num_base_models = len(base_model_outputs_list)
    # Feature: probability at center from each base model
    # Feature: mean probability in a small patch around center from each base model
    # num_features_per_model = 2
    num_features_per_model = 1 # Just prob at center for simplicity now
    
    node_features = torch.zeros((num_nodes, num_base_models * num_features_per_model),
                                device=base_model_outputs_list[0].device)

    for i, (z, y, x) in enumerate(candidate_coords_zyx):
        feat_idx = 0
        for model_idx, prob_map in enumerate(base_model_outputs_list):
            # Ensure prob_map is (D,H,W)
            if prob_map.ndim == 4 and prob_map.shape[0] == 1:
                prob_map = prob_map.squeeze(0)
            
            # Boundary checks for safety
            d, h, w = prob_map.shape
            z, y, x = int(z), int(y), int(x)
            if not (0 <= z < d and 0 <= y < h and 0 <= x < w):
                # Node is out of bounds for this prob_map (should not happen if maps cover same space)
                # Fill with zeros or a specific value
                node_features[i, feat_idx : feat_idx + num_features_per_model] = 0
                feat_idx += num_features_per_model
                continue

            # Prob at center
            node_features[i, feat_idx] = prob_map[z, y, x]
            feat_idx += 1
            
            # Mean prob in patch (optional, adds complexity)
            # z_min, z_max = max(0, z - patch_radius), min(d, z + patch_radius + 1)
            # y_min, y_max = max(0, y - patch_radius), min(h, y + patch_radius + 1)
            # x_min, x_max = max(0, x - patch_radius), min(w, x + patch_radius + 1)
            # local_patch = prob_map[z_min:z_max, y_min:y_max, x_min:x_max]
            # node_features[i, feat_idx] = local_patch.mean() if local_patch.numel() > 0 else 0
            # feat_idx +=1

    return node_features
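
# %% [markdown]
# The commented-out block in `extract_node_features` describes an optional second feature: the mean probability in a small patch around each candidate, clipped at the volume borders. The sketch below shows that clipping logic on a NumPy array for illustration only; the pipeline itself operates on torch tensors, and `local_patch_mean` is a hypothetical helper.

# %%
import numpy as np


def local_patch_mean(prob_map, z, y, x, patch_radius=2):
    """Mean of prob_map in a cube of the given radius around (z, y, x), clipped to the volume."""
    d, h, w = prob_map.shape
    z0, z1 = max(0, z - patch_radius), min(d, z + patch_radius + 1)
    y0, y1 = max(0, y - patch_radius), min(h, y + patch_radius + 1)
    x0, x1 = max(0, x - patch_radius), min(w, x + patch_radius + 1)
    patch = prob_map[z0:z1, y0:y1, x0:x1]
    return float(patch.mean()) if patch.size > 0 else 0.0


_demo_map = np.zeros((5, 5, 5), dtype=np.float32)
_demo_map[0, 0, 0] = 1.0
# At the corner, a radius-1 cube is clipped to 2 x 2 x 2 = 8 voxels, so the mean is 1/8.
print(local_patch_mean(_demo_map, 0, 0, 0, patch_radius=1))

# %%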


def create_graph_edges(candidate_coords_zyx, max_dist):
    """
    Creates graph edges based on spatial proximity.
    Args:
        candidate_coords_zyx (np.ndarray): (N, 3) node coordinates (z,y,x).
        max_dist (float): Maximum Euclidean distance to connect two nodes.
    Returns:
        torch.Tensor: Edge index (2, num_edges) in PyG format.
    """
    if candidate_coords_zyx.shape[0] < 2:
        return torch.empty((2,0), dtype=torch.long)

    dist_matrix = cdist(candidate_coords_zyx, candidate_coords_zyx)
    # dist > 0 excludes self-loops (and exactly coincident nodes); GATConv adds self-loops itself by default.
    adj_matrix = (dist_matrix > 0) & (dist_matrix <= max_dist)
    
    edge_index_np = np.array(np.where(adj_matrix))
    return torch.from_numpy(edge_index_np).long()
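
# %% [markdown]
# To make the edge format concrete, here is a NumPy-only sketch of the same radius rule on three hypothetical nodes. It reimplements the distance computation with broadcasting instead of `cdist` so that the cell is self-contained. Each undirected connection appears twice, once per direction, which is the usual PyG convention for undirected graphs.

# %%
import numpy as np


def radius_edges_numpy(coords, max_dist):
    """Return a (2, num_edges) index array connecting nodes closer than max_dist."""
    coords = np.asarray(coords, dtype=float).reshape(-1, 3)
    diff = coords[:, None, :] - coords[None, :, :]  # (N, N, 3) pairwise differences
    dist = np.sqrt((diff ** 2).sum(axis=-1))        # (N, N) Euclidean distances
    adj = (dist > 0) & (dist <= max_dist)           # no self-loops
    return np.array(np.where(adj))


_demo_edges = radius_edges_numpy([[0, 0, 0], [0, 0, 3], [10, 10, 10]], max_dist=5.0)
print(_demo_edges)  # nodes 0 and 1 are linked in both directions; node 2 is isolated

# %%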


def label_graph_nodes(candidate_coords_zyx, gt_motor_coords_zyx, matching_radius):
    """
    Assigns ground truth labels (motor/non-motor) to GNN nodes.
    Args:
        candidate_coords_zyx (np.ndarray): (N, 3) node coordinates (z,y,x).
        gt_motor_coords_zyx (np.ndarray): (M, 3) ground truth motor coordinates (z,y,x).
        matching_radius (float): Max distance for a candidate to be matched to a GT motor.
    Returns:
        torch.Tensor: Node labels (N,), 1 for motor, 0 for non-motor.
    """
    num_nodes = candidate_coords_zyx.shape[0]
    node_labels = torch.zeros(num_nodes, dtype=torch.float) # Use float for BCEWithLogitsLoss

    if num_nodes == 0 or gt_motor_coords_zyx.shape[0] == 0:
        return node_labels

    dist_matrix = cdist(candidate_coords_zyx, gt_motor_coords_zyx)
    
    # Simple many-to-one assignment: every candidate within matching_radius of ANY GT motor
    # is labeled positive, so one motor can produce several positive nodes.
    # A one-to-one matching (e.g. Hungarian) could be used instead.
    min_dist_to_gt = dist_matrix.min(axis=1)
    matched_nodes_indices = np.where(min_dist_to_gt <= matching_radius)[0]
    
    node_labels[matched_nodes_indices] = 1.0
    
    return node_labels
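
# %% [markdown]
# `label_graph_nodes` marks every candidate near a motor as positive. If exactly one positive node per motor is preferred, a greedy one-to-one assignment is a simple alternative to the Hungarian method: candidate/motor pairs are visited from closest to farthest, and each candidate and each motor is used at most once. This is a sketch of a different labeling rule, not the one used above, and `label_nodes_one_to_one` is a hypothetical name.

# %%
import numpy as np


def label_nodes_one_to_one(candidate_coords, gt_coords, matching_radius):
    """Greedy one-to-one labeling: at most one positive candidate per GT motor."""
    candidates = np.asarray(candidate_coords, dtype=float).reshape(-1, 3)
    gts = np.asarray(gt_coords, dtype=float).reshape(-1, 3)
    labels = np.zeros(len(candidates), dtype=np.float32)
    if len(candidates) == 0 or len(gts) == 0:
        return labels

    dist = np.linalg.norm(candidates[:, None, :] - gts[None, :, :], axis=-1)  # (N, M)
    used_candidates, used_gts = set(), set()
    for flat_idx in np.argsort(dist, axis=None):  # closest pairs first
        c, g = divmod(int(flat_idx), dist.shape[1])
        if dist[c, g] > matching_radius:
            break  # all remaining pairs are even farther apart
        if c in used_candidates or g in used_gts:
            continue
        labels[c] = 1.0
        used_candidates.add(c)
        used_gts.add(g)
    return labels


_demo_labels = label_nodes_one_to_one(
    [[0, 0, 0], [0, 0, 1], [50, 50, 50]], [[0, 0, 0]], matching_radius=3.0
)
print(_demo_labels)  # only the closest candidate is positive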

# %% [markdown]
# ### 4.2. GNN Model Definition (e.g., GAT)

# %%
if PYG_AVAILABLE:
    class GATNet(torch.nn.Module):
        def __init__(self, in_channels, hidden_channels, out_channels, num_layers, heads=4, dropout=0.2):
            super().__init__()
            self.dropout_p = dropout
            self.convs = torch.nn.ModuleList()
            self.batch_norms = torch.nn.ModuleList() # Optional: Batch norm for GNNs

            # Input layer
            self.convs.append(pyg_nn.GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout))
            self.batch_norms.append(pyg_nn.BatchNorm(hidden_channels * heads))

            # Hidden layers
            for _ in range(num_layers - 2):
                self.convs.append(pyg_nn.GATConv(hidden_channels * heads, hidden_channels, heads=heads, dropout=dropout))
                self.batch_norms.append(pyg_nn.BatchNorm(hidden_channels * heads))
            
            # Output layer
            self.convs.append(pyg_nn.GATConv(hidden_channels * heads, out_channels, heads=1, concat=False, dropout=dropout))
            # No batch norm for output typically, as it's logits

        def forward(self, x, edge_index):
            for i, conv in enumerate(self.convs[:-1]):
                x = conv(x, edge_index)
                x = self.batch_norms[i](x)
                x = F.elu(x) # ELU or LeakyReLU are common with GAT
                x = F.dropout(x, p=self.dropout_p, training=self.training)
            
            x = self.convs[-1](x, edge_index) # Output logits for classification
            return x # Output shape (num_nodes, out_channels)
else:
    class GATNet(torch.nn.Module):  # Placeholder used when PyG is not available
        def __init__(self, *args, **kwargs):
            super().__init__()
            # A dummy layer so that the module has parameters and optimizers can be constructed.
            self.dummy_layer = torch.nn.Linear(1, 1)
            print("WARNING: GATNet is a placeholder as PyTorch Geometric is not installed.")

        def forward(self, x, edge_index):
            # Return zero logits of the expected shape (num_nodes, out_channels=1).
            # This cannot train meaningfully but allows the pipeline to run.
            return torch.zeros((x.shape[0], 1), device=x.device)


# %% [markdown]
# ### 4.3. Generating Training Data for GNN
# This involves running the trained base model(s) on a GNN training split of tomograms.

# %%
def generate_gnn_training_data(base_model, tomo_ids_for_gnn, data_root, gt_csv_path, gnn_config, base_model_config, device):
    print(f"Generating GNN training data for {len(tomo_ids_for_gnn)} tomograms...")
    base_model.eval()
    base_model.to(device)
    
    gnn_data_list = []
    
    # Use validation transforms for loading full volumes for GNN data generation
    # No random cropping here.
    gnn_data_gen_transforms = Compose([
        EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
        ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
        ToTensord(keys=["image", "label"]),
    ])

    # Dataset for GNN data generation (full volumes)
    gnn_raw_ds = CryoETDataset(
        data_root=data_root,
        tomo_ids=tomo_ids_for_gnn,
        gt_csv_path=gt_csv_path,
        transforms=gnn_data_gen_transforms,
        target_radius=base_model_config["base_model_target_radius"], # Not strictly needed for GNN data gen, but part of dataset
        is_train=False
    )
    gnn_raw_loader = MonaiDataLoader(gnn_raw_ds, batch_size=1, shuffle=False, num_workers=base_model_config["num_workers"])

    full_gt_df = pd.read_csv(gt_csv_path)

    with torch.no_grad():
        for batch_data in tqdm(gnn_raw_loader, desc="Processing Tomos for GNN Data"):
            inputs = batch_data["image"].to(device)
            tomo_id = batch_data["id"][0] # DataLoader with batch_size=1 gives list

            # 1. Get base model probability map for the full tomogram
            # Sliding window inference for full volume
            prob_map = sliding_window_inference(
                inputs,
                roi_size=base_model_config["inference_roi_size_base"],
                sw_batch_size=base_model_config["inference_sw_batch_size_base"],
                predictor=base_model,
                overlap=base_model_config["inference_overlap_base"],
                mode="gaussian",
                progress=False # Reduce verbosity
            )
            prob_map = torch.sigmoid(prob_map) # Apply sigmoid to get probabilities
            # prob_map is (1, C, D, H, W), C=1. Squeeze to (D,H,W) for CPU ops.
            prob_map_squeezed = prob_map.squeeze(0).squeeze(0) # Now (D,H,W)

            # 2. Get candidate nodes
            candidate_coords_zyx, candidate_probs = get_candidate_nodes_from_prob_map(
                prob_map_squeezed,
                gnn_config["gnn_node_candidate_threshold"],
                gnn_config["gnn_node_NMS_footprint"]
            )
            if candidate_coords_zyx.shape[0] == 0:
                print(f"No candidate nodes found for tomogram {tomo_id}. Skipping.")
                continue

            # 3. Extract node features
            # For now, assuming one base model. If multiple, base_model_outputs_list would contain their maps.
            node_features = extract_node_features(candidate_coords_zyx, [prob_map_squeezed])
            
            # 4. Create graph edges
            edge_index = create_graph_edges(candidate_coords_zyx, gnn_config["gnn_edge_max_distance"])

            # 5. Label graph nodes
            tomo_gt_df = full_gt_df[full_gt_df["tomo_id"] == tomo_id]
            if not tomo_gt_df.empty:
                gt_motor_coords_zyx_np = tomo_gt_df[["center_z", "center_y", "center_x"]].values.astype(int)
            else:
                gt_motor_coords_zyx_np = np.empty((0,3), dtype=int)
            
            # Matching radius for GNN node labeling (can be a hyperparameter)
            # Should be related to base_model_target_radius.
            gnn_node_labeling_radius = base_model_config["base_model_target_radius"] + 2 
            
            node_labels = label_graph_nodes(candidate_coords_zyx, gt_motor_coords_zyx_np, gnn_node_labeling_radius)

            # Create PyG Data object
            if PYG_AVAILABLE:
                graph_data = PyGData(
                    x=node_features.to(device), # Keep features on device if GNN training on GPU
                    edge_index=edge_index.to(device),
                    y=node_labels.unsqueeze(1).to(device), # Target shape (num_nodes, 1)
                    pos=torch.from_numpy(candidate_coords_zyx).float().to(device) # Store positions for potential visualization or use
                )
                graph_data.tomo_id = tomo_id # Store for reference
                gnn_data_list.append(graph_data)
            else: # Fallback if PyG not available (store as dict for potential later conversion)
                 gnn_data_list.append({
                     "x": node_features, "edge_index": edge_index, "y": node_labels.unsqueeze(1),
                     "pos": torch.from_numpy(candidate_coords_zyx).float(), "tomo_id": tomo_id
                 })


    print(f"Generated {len(gnn_data_list)} graphs for GNN training.")
    if not PYG_AVAILABLE and gnn_data_list:
        print("WARNING: PyG not available. GNN data stored as dicts, actual GNN training will fail or be dummied.")
    return gnn_data_list
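
# %% [markdown]
# The graphs returned by `generate_gnn_training_data` usually contain far more non-motor than motor nodes. One common remedy is the `pos_weight` argument of `torch.nn.BCEWithLogitsLoss`, often set to the ratio of negative to positive training nodes. The sketch below computes only that ratio from plain label lists; the helper name is hypothetical and the fallback value of 1.0 is an assumption.

# %%
def bce_pos_weight(node_labels_per_graph):
    """Ratio n_negative / n_positive over all graphs; 1.0 if there are no positives."""
    n_pos = sum(1 for labels in node_labels_per_graph for value in labels if value >= 0.5)
    n_total = sum(len(labels) for labels in node_labels_per_graph)
    n_neg = n_total - n_pos
    if n_pos == 0:
        return 1.0  # assumption: fall back to an unweighted loss
    return n_neg / n_pos


_demo_weight = bce_pos_weight([[1, 0, 0, 0], [0, 0, 1, 0, 0, 0]])
print(_demo_weight)  # 8 negatives / 2 positives = 4.0
# The value could then be passed as pos_weight=torch.tensor([_demo_weight]) to BCEWithLogitsLoss.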


# %% [markdown]
# ### 4.4. GNN Training Loop

# %%
def train_gnn_model(gnn_model, gnn_train_loader, gnn_val_loader, gnn_config, device, trial=None):
    if not PYG_AVAILABLE:
        print("Skipping GNN training as PyTorch Geometric is not available.")
        # Return the untrained placeholder model with dummy results so that callers
        # (e.g. an Optuna trial) do not crash.
        return gnn_model, {"train_losses": [0], "val_f2_scores": [0], "best_val_f2": 0}


    gnn_model.to(device)
    optimizer = torch.optim.AdamW(gnn_model.parameters(), lr=gnn_config["gnn_lr"], weight_decay=1e-5)
    # Loss for node classification (binary: motor vs non-motor)
    # BCEWithLogitsLoss is suitable as GNN outputs logits.
    # Handle class imbalance for GNN nodes if necessary (many candidates are non-motors)
    # Can use pos_weight in BCEWithLogitsLoss. Calculate from training data.
    # For now, simple BCE.
    criterion = torch.nn.BCEWithLogitsLoss()

    best_val_f2 = -1
    train_losses = []
    val_f2_scores = [] # We'll use F-beta for GNN validation too

    print(f"Starting GNN training for {gnn_config['gnn_train_epochs']} epochs.")

    for epoch in range(gnn_config["gnn_train_epochs"]):
        gnn_model.train()
        epoch_loss = 0
        num_batches = 0

        if not gnn_train_loader:
            print("Skipping GNN training epoch as gnn_train_loader is None.")
            break

        for graph_batch in tqdm(gnn_train_loader, desc=f"GNN Epoch {epoch+1} Training"):
            # graph_batch is a Batch object from PyG if PyGDataLoader is used
            # It automatically handles batching of graphs.
            graph_batch = graph_batch.to(device)
            
            optimizer.zero_grad()
            # Ensure graph_batch.x and graph_batch.edge_index are correctly formatted
            if not hasattr(graph_batch, 'x') or not hasattr(graph_batch, 'edge_index') or not hasattr(graph_batch, 'y'):
                 print(f"Skipping batch due to missing attributes: {graph_batch}")
                 continue
            if graph_batch.x is None or graph_batch.edge_index is None or graph_batch.y is None:
                 print(f"Skipping batch due to None attributes: x:{graph_batch.x is None}, edge_index:{graph_batch.edge_index is None}, y:{graph_batch.y is None}")
                 continue
            # Empty or edgeless batches are deliberately passed through: the GNN can still
            # score nodes from their features alone, and any failure is caught below.

            try:
                out_logits = gnn_model(graph_batch.x, graph_batch.edge_index)
                # BCEWithLogitsLoss needs float targets with the same shape as the logits.
                targets = graph_batch.y.float().view_as(out_logits)
                loss = criterion(out_logits, targets)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                num_batches += 1
            except RuntimeError as e:
                print(f"Runtime error during GNN training: {e}")
                print(f"Graph batch details: Nodes: {graph_batch.num_nodes}, Edges: {graph_batch.num_edges}")
                print(f"Node features shape: {graph_batch.x.shape if graph_batch.x is not None else 'None'}")
                print(f"Edge index shape: {graph_batch.edge_index.shape if graph_batch.edge_index is not None else 'None'}")
                # Consider skipping problematic batch or re-raising
                continue # Skip this batch


        if num_batches > 0:
            epoch_loss /= num_batches
            train_losses.append(epoch_loss)
            print(f"GNN Epoch {epoch+1} average training loss: {epoch_loss:.4f}")
        else:
            train_losses.append(0) # Or handle as error
            print(f"GNN Epoch {epoch+1} - No batches processed in training.")


        # GNN Validation (on a GNN validation set, could be same as base model val set)
        if gnn_val_loader:
            gnn_model.eval()
            all_preds = []
            all_true_labels = []
            with torch.no_grad():
                for graph_batch_val in tqdm(gnn_val_loader, desc=f"GNN Epoch {epoch+1} Validation"):
                    graph_batch_val = graph_batch_val.to(device)
                    if not hasattr(graph_batch_val, 'x') or not hasattr(graph_batch_val, 'edge_index') or not hasattr(graph_batch_val, 'y'):
                        continue
                    if graph_batch_val.x is None or graph_batch_val.edge_index is None or graph_batch_val.y is None:
                        continue
                    
                    out_logits_val = gnn_model(graph_batch_val.x, graph_batch_val.edge_index)
                    preds_probs_val = torch.sigmoid(out_logits_val).cpu().numpy().ravel()
                    true_labels_val = graph_batch_val.y.cpu().numpy().ravel()
                    
                    all_preds.extend(preds_probs_val)
                    all_true_labels.extend(true_labels_val)
            
            if all_true_labels: # If any validation data processed
                # For F-beta, we need TP, FP, FN. Threshold predictions.
                # A common threshold is 0.5, but this could also be optimized.
                threshold = 0.5
                preds_binary = (np.array(all_preds) >= threshold).astype(int)
                true_binary = (np.array(all_true_labels) >= 0.5).astype(int) # GT labels are 0 or 1

                f2_score_val, _, _ = calculate_fbeta_precision_recall(preds_binary, true_binary, beta=2.0)
                val_f2_scores.append(f2_score_val)
                print(f"GNN Epoch {epoch+1} validation F2-score: {f2_score_val:.4f}")

                if f2_score_val > best_val_f2:
                    best_val_f2 = f2_score_val
                    # Save best GNN model
                    gnn_save_path = Path(CONFIG["output_dir"]) / f"{CONFIG['gnn_model_name']}_best_f2_{best_val_f2:.4f}.pth"
                    torch.save(gnn_model.state_dict(), gnn_save_path)
                    print(f"Saved new best GNN model to {gnn_save_path}")
                
                if trial:
                    trial.report(f2_score_val, epoch)
                    if trial.should_prune():
                        torch.cuda.empty_cache()
                        raise optuna.exceptions.TrialPruned()
            else:
                print(f"GNN Epoch {epoch+1} - No validation data processed or no true labels.")
                if trial:
                    # No validation metric this epoch. Report a neutral low score rather than the
                    # training loss, which is lower-is-better and would distort F2-based pruning.
                    trial.report(0.0, epoch)
                    if trial.should_prune():
                         torch.cuda.empty_cache()
                         raise optuna.exceptions.TrialPruned()


    print(f"Finished GNN training. Best validation F2-score: {best_val_f2:.4f}")
    # Load best GNN model
    if best_val_f2 != -1:
        best_gnn_path_str = f"{CONFIG['gnn_model_name']}_best_f2_{best_val_f2:.4f}.pth"
        best_gnn_path = Path(CONFIG["output_dir"]) / best_gnn_path_str
        if best_gnn_path.exists():
            gnn_model.load_state_dict(torch.load(best_gnn_path, map_location=device))
            print(f"Loaded best GNN model from {best_gnn_path}")
        else:
            print(f"Warning: Best GNN model path {best_gnn_path} not found. Using last epoch GNN model.")


    return gnn_model, {"train_losses": train_losses, "val_f2_scores": val_f2_scores, "best_val_f2": best_val_f2}
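
# %% [markdown]
# The GNN loss above is plain BCE, although most candidate nodes are expected to be non-motors. A minimal sketch of how a `pos_weight` could be derived from the training node labels follows. The helper name `compute_pos_weight_sketch` and the cap `max_weight` are illustrative assumptions and are not used elsewhere in this notebook.

# %%
import numpy as np

def compute_pos_weight_sketch(node_labels, max_weight=50.0):
    """Hypothetical helper: ratio of negative to positive nodes, capped at max_weight."""
    labels = np.asarray(node_labels, dtype=float).ravel()
    num_pos = float(np.sum(labels >= 0.5))
    num_neg = float(labels.size - num_pos)
    if num_pos == 0:
        return 1.0  # No positives seen: fall back to unweighted BCE.
    return float(min(num_neg / num_pos, max_weight))

example_pos_weight = compute_pos_weight_sketch([0, 0, 0, 0, 0, 0, 1, 1])
print(f"pos_weight for 6 negatives and 2 positives: {example_pos_weight:.1f}")
# The value could then be passed to the loss, e.g.
# torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([example_pos_weight], device=device))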

# %% [markdown]
# ## Part 5: Training Orchestration (Base Model + GNN)
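# The function below manages the two-stage training process: it trains the base model, builds graphs from its predictions, and then trains the GNN meta-model on those graphs.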

# %%
def run_full_training_pipeline(config, trial=None):
    """
    Manages the entire training:
    1. Train base model(s).
    2. Generate GNN training data using trained base model(s).
    3. Train GNN meta-model.
    Returns (trained_base_model, trained_gnn_model, pipeline_results); entries are None
    for any stage that could not run.
    """
    # --- 1. Train Base Model ---
    print("--- Stage 1: Training Base Model ---")
    base_model = get_base_model(config).to(config["device"])
    
    # For Optuna, if base model HPs are tuned, they should be passed via config.
    # For now, use fixed HPs from global CONFIG, or let Optuna override them in `objective`
    
    # Ensure data loaders are available
    if not train_loader_base:
        print("Cannot train base model: train_loader_base is None.")
        # If in Optuna trial, this is a failure or needs specific handling
        if trial: raise optuna.exceptions.TrialPruned("No training data for base model")
        return None, None, None # Or some indicator of failure

    trained_base_model, base_model_history = train_base_model(
        base_model, train_loader_base, val_loader_base, config, trial=trial # Pass trial for pruning base model
    )
    # Note: if base model training prunes, this function will exit via exception.

    # --- 2. Generate GNN Training Data ---
    print("--- Stage 2: Generating GNN Training Data ---")
    # Use validation tomograms for GNN training/validation to avoid leakage from base model's training set.
    # Split VAL_TOMO_IDS further if needed: e.g., 70% for GNN train, 30% for GNN val.
    # For simplicity here, use all VAL_TOMO_IDS for GNN data generation.
    # This data will then be split into GNN_train_graphs and GNN_val_graphs.
    if not VAL_TOMO_IDS:
        print("Cannot generate GNN data: VAL_TOMO_IDS is empty.")
        if trial: raise optuna.exceptions.TrialPruned("No validation data for GNN generation")
        return trained_base_model, None, None

    gnn_graphs_all = generate_gnn_training_data(
        trained_base_model, VAL_TOMO_IDS, # Use validation set of tomograms
        config["dataset_root"], config["gt_csv_path"], config, config, config["device"]
    )

    if not gnn_graphs_all:
        print("No graphs generated for GNN training. Cannot train GNN.")
        if trial: raise optuna.exceptions.TrialPruned("No GNN graphs generated")
        return trained_base_model, None, None # Return trained base model, but no GNN

    # Split GNN graphs into train/val for GNN training
    random.shuffle(gnn_graphs_all) # Shuffle before splitting
    gnn_train_split_idx = int(0.8 * len(gnn_graphs_all)) # 80/20 split
    gnn_train_graphs = gnn_graphs_all[:gnn_train_split_idx]
    gnn_val_graphs = gnn_graphs_all[gnn_train_split_idx:]

    if not gnn_train_graphs:
        print("Not enough GNN graphs for training split. GNN training skipped.")
        if trial: trial.report(0.0, config["base_model_train_epochs"]) # Report 0 for final GNN metric
        # No Optuna prune here, as base model trained. The pipeline just didn't get to GNN.
        return trained_base_model, None, None


    if PYG_AVAILABLE:
        gnn_train_loader = PyGDataLoader(gnn_train_graphs, batch_size=config["gnn_batch_size"], shuffle=True)
        gnn_val_loader = PyGDataLoader(gnn_val_graphs, batch_size=config["gnn_batch_size"], shuffle=False) if gnn_val_graphs else None
    else: # Fallback if PyG not available
        # Cannot create PyGDataLoaders. Training will be dummied.
        print("PyG not available. GNN DataLoaders are placeholders.")
        gnn_train_loader = gnn_train_graphs # Pass list of dicts
        gnn_val_loader = gnn_val_graphs if gnn_val_graphs else None


    # --- 3. Train GNN Meta-Model ---
    print("--- Stage 3: Training GNN Meta-Model ---")
    # Determine the GNN input width from the node features. extract_node_features yields one
    # feature per base model, so the fallback is config["num_base_models"].
    # Inspect the first batch from the loader when one exists.
    if gnn_train_loader and PYG_AVAILABLE:
        first_gnn_batch_for_shape = first(gnn_train_loader)
        if first_gnn_batch_for_shape and hasattr(first_gnn_batch_for_shape, 'x') and first_gnn_batch_for_shape.x is not None:
            gnn_in_channels = first_gnn_batch_for_shape.x.shape[1]
        else: # Fallback if batch is weird or empty
            gnn_in_channels = config["num_base_models"] # Default
    elif gnn_train_graphs and not PYG_AVAILABLE: # list of dicts
        gnn_in_channels = gnn_train_graphs[0]['x'].shape[1] if gnn_train_graphs[0]['x'] is not None else config["num_base_models"]
    else:
        gnn_in_channels = config["num_base_models"]


    gnn_model = GATNet(
        in_channels=gnn_in_channels,
        hidden_channels=config["gnn_hidden_channels"],
        out_channels=1, # Binary classification (motor/non-motor)
        num_layers=config["gnn_num_layers"],
        heads=config["gnn_gat_heads"]
    ).to(config["device"])

    trained_gnn_model, gnn_model_history = train_gnn_model(
        gnn_model, gnn_train_loader, gnn_val_loader, config, config["device"], trial=trial # Pass trial for GNN pruning
    )

    # The metric for Optuna should be from evaluating the *entire pipeline* on a final hold-out validation set.
    # Here, gnn_model_history["best_val_f2"] is on GNN's own validation split.
    # This is what Optuna will use if called from objective function.
    
    # For returning from a standalone run:
    pipeline_results = {
        "base_model_history": base_model_history,
        "gnn_model_history": gnn_model_history,
        "final_gnn_val_metric": gnn_model_history.get("best_val_f2", 0) if gnn_model_history else 0
    }
    
    return trained_base_model, trained_gnn_model, pipeline_results
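
# %% [markdown]
# The 80/20 graph split inside `run_full_training_pipeline` relies on the global `random` state, so the split can differ between runs and Optuna trials. A minimal sketch of a reproducible split with a local seeded generator follows. The helper name `split_graphs_sketch` and the default seed are assumptions made for illustration.

# %%
import random

def split_graphs_sketch(graphs, train_fraction=0.8, seed=42):
    """Hypothetical helper: shuffle a copy with a local RNG and split it into train/val lists."""
    rng = random.Random(seed)  # Local generator, leaves the global random state untouched.
    shuffled = list(graphs)
    rng.shuffle(shuffled)
    split_idx = int(train_fraction * len(shuffled))
    return shuffled[:split_idx], shuffled[split_idx:]

example_train, example_val = split_graphs_sketch(range(10))
print(f"train graphs: {len(example_train)}, val graphs: {len(example_val)}")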


# %% [markdown]
# ## Part 6: Evaluation Metrics and Utilities
#
# ### F-beta Score and Detection Matching

# %%
def calculate_fbeta_precision_recall(predictions_binary, ground_truth_binary, beta, epsilon=1e-7):
    """
    Calculates F-beta score, precision, and recall.
    Args:
        predictions_binary (np.ndarray or torch.Tensor): Binary predictions (0 or 1).
        ground_truth_binary (np.ndarray or torch.Tensor): Binary ground truth (0 or 1).
        beta (float): The beta value for F-beta score.
        epsilon (float): Small value to prevent division by zero.
    Returns:
        tuple: (fbeta_score, precision, recall)
    """
    if isinstance(predictions_binary, torch.Tensor):
        predictions_binary = predictions_binary.cpu().numpy()
    if isinstance(ground_truth_binary, torch.Tensor):
        ground_truth_binary = ground_truth_binary.cpu().numpy()

    TP = np.sum((predictions_binary == 1) & (ground_truth_binary == 1))
    FP = np.sum((predictions_binary == 1) & (ground_truth_binary == 0))
    FN = np.sum((predictions_binary == 0) & (ground_truth_binary == 1))

    precision = TP / (TP + FP + epsilon)
    recall = TP / (TP + FN + epsilon)

    fbeta_denominator = (beta**2 * precision) + recall
    if fbeta_denominator > epsilon:
        fbeta_score = (1 + beta**2) * (precision * recall) / fbeta_denominator
    else:
        fbeta_score = 0.0
        
    return fbeta_score, precision, recall
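
# %% [markdown]
# As a worked check of the formula above: the F-beta score can be written either from precision and recall or directly from the TP/FP/FN counts, and the two forms agree when epsilon is ignored. The counts below (8 TP, 2 FP, 4 FN) are made up for illustration.

# %%
import numpy as np

example_preds = np.array([1] * 8 + [1] * 2 + [0] * 4)  # 8 TP, 2 FP, 4 FN
example_truth = np.array([1] * 8 + [0] * 2 + [1] * 4)

example_tp = int(np.sum((example_preds == 1) & (example_truth == 1)))
example_fp = int(np.sum((example_preds == 1) & (example_truth == 0)))
example_fn = int(np.sum((example_preds == 0) & (example_truth == 1)))

example_beta = 2.0
example_precision = example_tp / (example_tp + example_fp)
example_recall = example_tp / (example_tp + example_fn)

# Form 1: from precision and recall.
f2_from_pr = (1 + example_beta**2) * example_precision * example_recall / (
    example_beta**2 * example_precision + example_recall
)
# Form 2: directly from the counts.
f2_from_counts = (1 + example_beta**2) * example_tp / (
    (1 + example_beta**2) * example_tp + example_beta**2 * example_fn + example_fp
)
print(f"F2 from precision/recall: {f2_from_pr:.4f}, from counts: {f2_from_counts:.4f}")

# %%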

def match_detections_3d(pred_centers_scores, gt_centers, matching_dist_threshold):
    """
    Matches predicted 3D detections to ground truth 3D centers.
    Args:
        pred_centers_scores (list of tuples): [( (x,y,z), score ), ...]. Assumes x,y,z order.
        gt_centers (np.ndarray): Array of GT (x,y,z) centers, shape (M, 3).
        matching_dist_threshold (float): Maximum distance for a match.
    Returns:
        tuple: (TP, FP, FN, matched_pairs_info)
        matched_pairs_info: list of (pred_idx, gt_idx, dist, score)
    """
    if not pred_centers_scores or gt_centers.shape[0] == 0:
        FP = len(pred_centers_scores)
        FN = gt_centers.shape[0]
        return 0, FP, FN, []

    pred_centers_np = np.array([p[0] for p in pred_centers_scores])
    pred_scores_np = np.array([p[1] for p in pred_centers_scores])

    num_preds = pred_centers_np.shape[0]
    num_gts = gt_centers.shape[0]

    if num_preds == 0:
        return 0, 0, num_gts, []

    dist_matrix = cdist(pred_centers_np, gt_centers) # (num_preds, num_gts)

    # Greedy matching: visit predictions in descending score order and give each one the
    # closest unmatched GT within the threshold. (The Hungarian algorithm would give an
    # optimal assignment instead.)
    sorted_pred_indices = np.argsort(pred_scores_np)[::-1]

    gt_matched_flags = np.zeros(num_gts, dtype=bool)
    pred_is_tp = np.zeros(num_preds, dtype=bool)
    matched_pairs_info = []

    for pred_idx_sorted in sorted_pred_indices:
        # Find GTs within threshold for this prediction
        dists_to_gts_for_this_pred = dist_matrix[pred_idx_sorted, :]
        possible_gt_matches_indices = np.where((dists_to_gts_for_this_pred <= matching_dist_threshold) & (~gt_matched_flags))[0]

        if len(possible_gt_matches_indices) > 0:
            # Match to the closest available GT
            best_gt_match_local_idx = np.argmin(dists_to_gts_for_this_pred[possible_gt_matches_indices])
            gt_idx_global = possible_gt_matches_indices[best_gt_match_local_idx]
            
            pred_is_tp[pred_idx_sorted] = True
            gt_matched_flags[gt_idx_global] = True
            matched_pairs_info.append((
                pred_idx_sorted, 
                gt_idx_global, 
                dists_to_gts_for_this_pred[gt_idx_global], 
                pred_scores_np[pred_idx_sorted]
            ))

    # Cast to plain ints so the return types match the early-exit branches above.
    TP = int(np.sum(pred_is_tp))
    FP = num_preds - TP
    FN = num_gts - int(np.sum(gt_matched_flags))
            
    return TP, FP, FN, matched_pairs_info
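
# %% [markdown]
# Greedy matching can miss a valid pairing when a high-scoring prediction takes the only GT that a second prediction could reach. A minimal sketch of an optimal-assignment alternative using `scipy.optimize.linear_sum_assignment` follows. The function name `match_detections_hungarian_sketch` and the way out-of-threshold pairs are penalised are assumptions for illustration; scores are ignored here.

# %%
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist

def match_detections_hungarian_sketch(pred_centers, gt_centers, matching_dist_threshold):
    """Hypothetical alternative matcher: returns (TP, FP, FN) from an optimal assignment."""
    pred_centers = np.asarray(pred_centers, dtype=float).reshape(-1, 3)
    gt_centers = np.asarray(gt_centers, dtype=float).reshape(-1, 3)
    num_preds, num_gts = len(pred_centers), len(gt_centers)
    if num_preds == 0 or num_gts == 0:
        return 0, num_preds, num_gts

    dist_matrix = cdist(pred_centers, gt_centers)  # (num_preds, num_gts)
    # One out-of-threshold pair must cost more than all valid pairs combined,
    # so the solver maximises the number of valid matches first.
    big_cost = (matching_dist_threshold + 1.0) * (min(num_preds, num_gts) + 1)
    cost = np.where(dist_matrix <= matching_dist_threshold, dist_matrix, big_cost)

    rows, cols = linear_sum_assignment(cost)
    tp = int(np.sum(dist_matrix[rows, cols] <= matching_dist_threshold))
    return tp, num_preds - tp, num_gts - tp

# Prediction A is within range of both GTs, prediction B only of the first one.
example_match = match_detections_hungarian_sketch(
    pred_centers=[(3, 0, 0), (0, 0, 0)],
    gt_centers=[(2, 0, 0), (5, 0, 0)],
    matching_dist_threshold=2.5,
)
print(f"(TP, FP, FN) with optimal assignment: {example_match}")

# %%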


def evaluate_pipeline_on_tomogram(tomo_id, pred_detections, gt_csv_path, matching_dist_thresh_3d, beta=2.0):
    """
    Evaluates detections for a single 3D tomogram.
    Args:
        tomo_id (str): The ID of the tomogram.
        pred_detections (list of dicts): [{'center':(x,y,z), 'score':score, 'bbox':(...)}].
        gt_csv_path (str): Path to the ground truth CSV.
        matching_dist_thresh_3d (float): Distance threshold for matching.
        beta (float): Beta for F-beta score.
    Returns:
        dict: {CONFIG['optuna_objective_metric']: F-beta score, 'precision': score, 'recall': score, 'tp': count, 'fp': count, 'fn': count}
    """
    full_gt_df = pd.read_csv(gt_csv_path)
    tomo_gt_df = full_gt_df[full_gt_df["tomo_id"] == tomo_id]

    if tomo_gt_df.empty:
        gt_centers_xyz = np.empty((0,3))
    else:
        # Ensure correct X,Y,Z order for cdist
        gt_centers_xyz = tomo_gt_df[["center_x", "center_y", "center_z"]].values.astype(float)

    pred_centers_scores = []
    for det in pred_detections:
        # Ensure center is (x,y,z) tuple/list
        pred_centers_scores.append((tuple(det['center']), det['score']))
    
    TP, FP, FN, _ = match_detections_3d(pred_centers_scores, gt_centers_xyz, matching_dist_thresh_3d)
    
    fbeta, precision, recall = calculate_fbeta_precision_recall(
        np.concatenate([np.ones(TP), np.ones(FP), np.zeros(FN)]), # Equivalent binary predictions for TP,FP
        np.concatenate([np.ones(TP), np.zeros(FP), np.ones(FN)]), # Equivalent binary GT for TP,FN
        beta=beta
    )
    
    return {
        CONFIG["optuna_objective_metric"]: fbeta,  # Key name comes from config for consistency
        "precision": precision,
        "recall": recall,
        "tp": TP,
        "fp": FP,
        "fn": FN
    }
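
# %% [markdown]
# `evaluate_pipeline_on_tomogram` scores one tomogram at a time. One way to obtain a single dataset-level score is to pool the TP/FP/FN counts across tomograms before computing F-beta. The sketch below assumes that pooling is the desired aggregation, which should be checked against the target metric; the helper name `aggregate_fbeta_sketch` is hypothetical.

# %%
def aggregate_fbeta_sketch(per_tomo_results, beta=2.0):
    """Hypothetical helper: pool tp/fp/fn over per-tomogram result dicts and compute F-beta."""
    tp = sum(r["tp"] for r in per_tomo_results)
    fp = sum(r["fp"] for r in per_tomo_results)
    fn = sum(r["fn"] for r in per_tomo_results)
    beta_sq = beta**2
    denominator = (1 + beta_sq) * tp + beta_sq * fn + fp
    fbeta = (1 + beta_sq) * tp / denominator if denominator > 0 else 0.0
    return {"fbeta": fbeta, "tp": tp, "fp": fp, "fn": fn}

example_pooled = aggregate_fbeta_sketch([
    {"tp": 3, "fp": 1, "fn": 0},
    {"tp": 1, "fp": 0, "fn": 2},
])
print(f"Pooled F2 over two tomograms: {example_pooled['fbeta']:.4f}")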


# %% [markdown]
# ## Part 7: Inference Pipeline
# Includes Mode 1 (Full 3D Tomograms) and Mode 2 (Individual 2D Slices adaptation).
#
# ### Helper: Post-process GNN outputs to detections

# %%
def gnn_output_to_detections(graph_data_with_preds, gnn_score_threshold, output_bbox_size_3d_or_2d):
    """
    Converts GNN output probabilities on nodes to final detection list.
    Args:
        graph_data_with_preds (PyGData or similar object): Graph data exposing as attributes:
            - `pos`: (N,3) or (N,2) tensor of node coordinates (z,y,x) or (y,x).
            - `pred_probs`: (N,) tensor of GNN predicted probabilities for nodes.
            - (optional) `id_prefix`: string used to build detection ids (defaults to 'det').
        gnn_score_threshold (float): Confidence threshold to make a detection.
        output_bbox_size_3d_or_2d (tuple): Size of bbox (Dx,Dy,Dz) or (Dx,Dy). Assumed odd for centering.
    Returns:
        list of dicts: [{'center':(x,y,z) or (x,y), 'bbox':(...), 'score':score, 'id': id}]
    """
    detections = []
    node_coords = graph_data_with_preds.pos.cpu().numpy() # (z,y,x) or (y,x)
    node_probs = graph_data_with_preds.pred_probs.cpu().numpy()

    is_3d = (len(output_bbox_size_3d_or_2d) == 3)

    for i in range(node_coords.shape[0]):
        score = node_probs[i]
        if score >= gnn_score_threshold:
            if is_3d:
                # node_coords are (z,y,x)
                center_z, center_y, center_x = node_coords[i, 0], node_coords[i, 1], node_coords[i, 2]
                # Output center as (x,y,z)
                center_out = (float(center_x), float(center_y), float(center_z))
                
                # Bbox is (X_min, Y_min, Z_min, X_max, Y_max, Z_max).
                # output_bbox_size_3d_or_2d is taken in (X, Y, Z) order.
                bs_x, bs_y, bs_z = output_bbox_size_3d_or_2d
                
                bbox = (
                    center_x - (bs_x -1)//2, center_y - (bs_y -1)//2, center_z - (bs_z -1)//2,
                    center_x + (bs_x -1)//2, center_y + (bs_y -1)//2, center_z + (bs_z -1)//2
                )
            else: # 2D
                # node_coords are (y,x) for 2D
                center_y, center_x = node_coords[i, 0], node_coords[i, 1]
                # Output center as (x,y)
                center_out = (float(center_x), float(center_y))

                # Bbox (X_min,Y_min,X_max,Y_max)
                bs_x, bs_y = output_bbox_size_3d_or_2d[0], output_bbox_size_3d_or_2d[1] # Assumed (size_x, size_y)
                bbox = (
                    center_x - (bs_x -1)//2, center_y - (bs_y -1)//2,
                    center_x + (bs_x -1)//2, center_y + (bs_y -1)//2
                )

            detections.append({
                "center": center_out,
                "bbox": tuple(map(float, bbox)),
                "score": float(score),
                "id": f"{getattr(graph_data_with_preds, 'id_prefix', 'det')}_{i}"
            })
    return detections
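
# %% [markdown]
# The coordinate handling in `gnn_output_to_detections` is easy to get wrong: node positions are stored as (z, y, x), while centers and boxes are emitted in (x, y, z) order. A minimal sketch of that conversion for one node follows, assuming an odd box size given in (X, Y, Z) order. The helper name `bbox_from_center_zyx_sketch` is hypothetical.

# %%
def bbox_from_center_zyx_sketch(center_zyx, bbox_size_xyz):
    """Hypothetical helper: convert one (z, y, x) node to an (x, y, z) center and its bbox."""
    center_z, center_y, center_x = center_zyx
    half_x, half_y, half_z = [(size - 1) // 2 for size in bbox_size_xyz]
    center_xyz = (float(center_x), float(center_y), float(center_z))
    bbox = (
        center_x - half_x, center_y - half_y, center_z - half_z,
        center_x + half_x, center_y + half_y, center_z + half_z,
    )
    return center_xyz, tuple(float(v) for v in bbox)

example_center, example_bbox = bbox_from_center_zyx_sketch((30, 20, 10), (5, 5, 3))
print(f"center (x,y,z): {example_center}, bbox: {example_bbox}")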

# %% [markdown]
# ### 7.1. Mode 1: Inference on Full 3D Tomograms

# %%
def inference_on_3d_tomogram(tomo_id, base_model, gnn_model, config, data_root, device):
    """
    Performs full pipeline inference on a single 3D tomogram.
    Returns: list of detection dictionaries.
    """
    base_model.eval().to(device)
    if gnn_model: gnn_model.eval().to(device)

    # --- Load and preprocess tomogram ---
    tomo_path = Path(data_root) / tomo_id
    if not tomo_path.is_dir():
        print(f"Tomogram directory {tomo_path} not found for inference.")
        return []

    img_arr, meta = JPGSequenceReader().read(str(tomo_path)) # H,W,D
    img_arr = img_arr.transpose(2,0,1) # D,H,W
    
    data_item = {"image": img_arr, "id": tomo_id}
    
    # Use validation transforms but without label key
    inference_transforms_3d = Compose([
        EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
        ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
        ToTensord(keys=["image"]),
    ])
    processed_tomo = inference_transforms_3d(data_item)
    input_tensor = processed_tomo["image"].unsqueeze(0).to(device) # Add batch dim: (1,C,D,H,W)

    # --- Base Model Inference ---
    with torch.no_grad():
        base_model_output = sliding_window_inference(
            input_tensor,
            roi_size=config["inference_roi_size_base"],
            sw_batch_size=config["inference_sw_batch_size_base"],
            predictor=base_model,
            overlap=config["inference_overlap_base"],
            mode="gaussian",
            progress=True
        )
    base_prob_map = torch.sigmoid(base_model_output.squeeze(0).squeeze(0)) # (D,H,W) on device

    if not gnn_model or not PYG_AVAILABLE: # If no GNN, use base model output directly
        print("Performing inference using only base model (GNN not available or specified).")
        candidate_coords_zyx, candidate_probs = get_candidate_nodes_from_prob_map(
            base_prob_map,
            config["gnn_node_candidate_threshold"], # Use GNN's candidate threshold as a proxy
            config["gnn_node_NMS_footprint"]
        )
        # Convert these directly to detections
        detections = []
        for i, (z,y,x) in enumerate(candidate_coords_zyx):
            center_out = (float(x), float(y), float(z))
            bs_x, bs_y, bs_z = config["output_bbox_size_3d"]
            bbox = (
                x - (bs_x-1)//2, y - (bs_y-1)//2, z - (bs_z-1)//2,
                x + (bs_x-1)//2, y + (bs_y-1)//2, z + (bs_z-1)//2
            )
            detections.append({
                "center": center_out, "bbox": tuple(map(float, bbox)), 
                "score": float(candidate_probs[i]), "id": f"{tomo_id}_base_det_{i}"})
        return detections


    # --- Graph Construction for GNN ---
    candidate_coords_zyx, _ = get_candidate_nodes_from_prob_map( # Don't need probs here, GNN re-evals
        base_prob_map,
        config["gnn_node_candidate_threshold"],
        config["gnn_node_NMS_footprint"]
    )
    if candidate_coords_zyx.shape[0] == 0:
        print(f"No candidate nodes from base model for GNN input on tomo {tomo_id}.")
        return []

    node_features = extract_node_features(candidate_coords_zyx, [base_prob_map]) # Pass list of maps
    edge_index = create_graph_edges(candidate_coords_zyx, config["gnn_edge_max_distance"])
    
    graph_for_gnn = PyGData(
        x=node_features.to(device),
        edge_index=edge_index.to(device),
        pos=torch.from_numpy(candidate_coords_zyx).float().to(device)
    )
    graph_for_gnn.id_prefix = f"{tomo_id}_gnn_det"


    # --- GNN Meta-Model Inference ---
    with torch.no_grad():
        gnn_logits = gnn_model(graph_for_gnn.x, graph_for_gnn.edge_index)
        gnn_probs = torch.sigmoid(gnn_logits).squeeze(-1) # (Num_nodes,)
    
    graph_for_gnn.pred_probs = gnn_probs

    # --- Post-process GNN outputs ---
    # Threshold for GNN output can be different from candidate generation threshold
    gnn_final_score_threshold = config.get("gnn_final_score_threshold", 0.5) # Default 0.5 if not in config
    
    final_detections = gnn_output_to_detections(
        graph_for_gnn,
        gnn_final_score_threshold,
        config["output_bbox_size_3d"]
    )
    
    return final_detections

# %% [markdown]
# ### 7.2. Mode 2: Inference on Individual 2D Slices (Adaptation)
# This is the critical adaptation part.
#
# **Conceptual Explanation for 2D Slice Adaptation:**
#
# 1.  **Base Model Adaptation:**
#     *   A single 2D input slice `(H, W)` is received.
#     *   To feed it to the 3D CNN base model, we create a "pseudo-3D" volume. A simple method is to replicate the slice `D_pseudo` times along the Z-axis, resulting in a `(D_pseudo, H, W)` volume. `D_pseudo` (e.g., `CONFIG["pseudo_3d_depth_for_2d_slice"]`) should be chosen carefully:
#         *   It should be large enough to contain meaningful features for the 3D CNN, ideally related to its receptive field or patch depth during training.
#         *   If `D_pseudo` is smaller than the depth (Z) entry of `CONFIG["inference_roi_size_base"]`, the sliding-window depth must be reduced or the volume padded. For simplicity, `D_pseudo` can be set equal to that depth entry.
#     *   The base 3D CNN (using `sliding_window_inference` if `(H, W)` is larger than the XY dimensions of `inference_roi_size_base`) processes this pseudo-3D volume.
#     *   The output is a pseudo-3D probability map, e.g., `(1, 1, D_pseudo, H, W)`.
#     *   We extract the 2D probability map corresponding to the original input slice, typically the central slice of the `D_pseudo` dimension: `prob_map_2d = pseudo_3d_prob_map[:, :, D_pseudo // 2, :, :]`.
#
# 2.  **GNN Meta-Model Adaptation:**
#     *   **Candidate Generation (2D):** From the `prob_map_2d`, extract 2D candidate motor locations (nodes) using 2D NMS. Node coordinates will be `(y, x)`.
#     *   **Node Feature Extraction (2D):** For each 2D candidate node, features are extracted from `prob_map_2d`. These features should be analogous to those used in 3D GNN training (e.g., probability at the candidate center from the base model's (adapted) 2D output).
#     *   **Graph Construction (2D):** A 2D graph is built. Edges connect nodes based on 2D spatial proximity on the slice.
#     *   **GNN Inference:** The GNN model (which was trained on 3D graph data) processes this 2D graph.
#         *   *Challenge & Assumption:* The GNN's architecture (especially if using GAT) should be somewhat robust to changes in graph dimensionality (implicitly, as node features are similar and graph structure is just adjacency). The number of node features must match what the GNN expects.
#     *   **Output:** The GNN yields 2D detections (center `(x,y)` on slice, 2D bounding box, confidence score).
#
# **Challenges with 2D Adaptation:**
# *   **Feature Consistency:** Ensuring node features derived from adapted 2D base model outputs are semantically similar enough to those from 3D outputs for the GNN to generalize.
# *   **Graph Structure:** The GNN sees graphs whose typical node degree and density may differ from the 3D training graphs. Attention-based layers such as GAT may tolerate this better than GCN layers, but this is an assumption that should be validated.
# *   **Pseudo-3D Volume:** The choice of `D_pseudo` and the replication strategy can affect base model performance. Replication creates artificial Z-continuity; padding with zeros or mean values is an alternative.
# *   **Information Loss:** The GNN was trained on 3D contextual information. Applying it to 2D graphs means it loses Z-axis relational information between nodes. Performance might be lower than full 3D inference.
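
# %% [markdown]
# The shape bookkeeping for the pseudo-3D adaptation can be sketched with NumPy alone. The example below replicates a toy slice `D_pseudo` times, adds batch and channel dimensions, and pulls out the central slice. The base model is not run here: a simple normalisation stands in for its probability output, and the toy sizes are assumptions for illustration.

# %%
import numpy as np

example_slice_hw = np.arange(12, dtype=np.float32).reshape(3, 4)  # Toy (H, W) slice
example_d_pseudo = 5

# Replicate along a new leading Z axis: (D_pseudo, H, W)
example_pseudo_3d = np.repeat(example_slice_hw[np.newaxis], example_d_pseudo, axis=0)

# Stand-in for the base model's probability map, with batch and channel dims: (1, 1, D_pseudo, H, W)
example_prob_5d = example_pseudo_3d[np.newaxis, np.newaxis] / example_pseudo_3d.max()

# Central Z slice maps back to the original 2D input: (H, W)
example_prob_map_2d = example_prob_5d[0, 0, example_d_pseudo // 2]
print(f"pseudo-3D shape: {example_pseudo_3d.shape}, 2D map shape: {example_prob_map_2d.shape}")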

# %%
def inference_on_2d_slice(slice_img_hw, slice_id, base_model, gnn_model, config, device):
    """
    Performs adapted full pipeline inference on a single 2D slice.
    Args:
        slice_img_hw (np.ndarray): Single 2D slice (H, W), grayscale.
        slice_id (str): Identifier for the slice.
    Returns: list of 2D detection dictionaries.
    """
    base_model.eval().to(device)
    if gnn_model: gnn_model.eval().to(device)

    # --- 1. Base Model Adaptation: Pseudo-3D Volume Creation ---
    h, w = slice_img_hw.shape
    d_pseudo = config["pseudo_3d_depth_for_2d_slice"] # e.g., 32 or 64, or the depth entry of config["inference_roi_size_base"]

    # Replicate slice to form D_pseudo x H x W volume
    pseudo_3d_volume_dhw = np.stack([slice_img_hw] * d_pseudo, axis=0) # (D_pseudo, H, W)
    
    data_item = {"image": pseudo_3d_volume_dhw, "id": slice_id}
    
    inference_transforms_pseudo3d = Compose([ # Similar to 3D inference transforms
        EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"), # -> (1, D_pseudo, H, W)
        ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
        ToTensord(keys=["image"]),
    ])
    processed_pseudo_3d = inference_transforms_pseudo3d(data_item)
    input_tensor_pseudo_3d = processed_pseudo_3d["image"].unsqueeze(0).to(device) # (1, 1, D_pseudo, H, W)

    # The input tensor is laid out as (1, 1, D_pseudo, H, W), so the first ROI entry is the depth.
    # This assumes the ROI is given in (D, H, W) order, matching how it is passed to
    # sliding_window_inference for full tomograms.
    # Clamp the ROI depth so the window never exceeds D_pseudo; padding the volume in Z
    # up to the configured ROI depth would be the alternative.
    roi_size_for_2d_adapted = list(config["inference_roi_size_base"])
    roi_size_for_2d_adapted[0] = min(roi_size_for_2d_adapted[0], d_pseudo)
    
    with torch.no_grad():
        base_model_output_pseudo_3d = sliding_window_inference(
            input_tensor_pseudo_3d,
            roi_size=tuple(roi_size_for_2d_adapted),
            sw_batch_size=config["inference_sw_batch_size_base"],
            predictor=base_model,
            overlap=config["inference_overlap_base"],
            mode="gaussian",
            progress=False # Less verbose for multiple slice calls
        )
    # base_model_output_pseudo_3d is (1, 1, D_pseudo, H, W)
    base_prob_map_pseudo_3d = torch.sigmoid(base_model_output_pseudo_3d)
    
    # Extract central 2D probability map
    central_slice_idx = d_pseudo // 2
    # Squeeze out batch and channel, then take central Z slice: (H,W)
    prob_map_2d = base_prob_map_pseudo_3d[0, 0, central_slice_idx, :, :].to(device) # Keep on device


    if gnn_model is None or not PYG_AVAILABLE: # No usable GNN: fall back to the base model's 2D-adapted output
        print(f"Performing 2D adapted inference using only base model for slice {slice_id}.")
        # Reuse the 3D candidate extractor for 2D NMS: give the map a dummy Z axis, (1, H, W),
        # and flatten the NMS footprint to (1, FH, FW).
        candidate_coords_yx, candidate_probs_2d = get_candidate_nodes_from_prob_map(
            prob_map_2d.unsqueeze(0), # Make it (1,H,W)
            config["gnn_node_candidate_threshold"],
            (1, config["gnn_node_NMS_footprint"][1], config["gnn_node_NMS_footprint"][2]), # 2D NMS
            max_candidates=200 # Limit per slice
        ) # Returns (N,3) with Z-coord as 0. We need (N,2) for YX.
        
        candidate_coords_yx = candidate_coords_yx[:, 1:] # Keep only Y, X

        detections_2d = []
        for i, (y,x) in enumerate(candidate_coords_yx):
            center_out = (float(x), float(y)) # X, Y order for output
            bs_x, bs_y = config["output_bbox_size_2d"]
            bbox = (
                x - (bs_x-1)//2, y - (bs_y-1)//2,
                x + (bs_x-1)//2, y + (bs_y-1)//2
            )
            detections_2d.append({
                "center": center_out, "bbox": tuple(map(float, bbox)), 
                "score": float(candidate_probs_2d[i]), "id": f"{slice_id}_base_det_{i}"})
        return detections_2d


    # --- 2. GNN Meta-Model Adaptation ---
    # Candidate Generation (2D)
    # Adapt NMS footprint for 2D: (1, FootprintY, FootprintX)
    nms_footprint_2d = (1, config["gnn_node_NMS_footprint"][1], config["gnn_node_NMS_footprint"][2])
    # get_candidate_nodes returns ZYX, so Z will be 0. We take YX.
    candidate_coords_zyx_on_slice, _ = get_candidate_nodes_from_prob_map(
        prob_map_2d.unsqueeze(0), # Add dummy Z dim for compatibility: (1,H,W)
        config["gnn_node_candidate_threshold"],
        nms_footprint_2d
    )
    if candidate_coords_zyx_on_slice.shape[0] == 0:
        # print(f"No candidate nodes for GNN (2D adapted) on slice {slice_id}.") # Can be too verbose
        return []
    
    candidate_coords_yx = candidate_coords_zyx_on_slice[:, 1:] # Get (Y,X) coordinates

    # Node Feature Extraction (2D) - use the 2D prob map
    # extract_node_features expects (D,H,W) or (C,D,H,W) map and (Z,Y,X) coords
    # We pass map (1,H,W) and coords (0,Y,X)
    node_features_2d = extract_node_features(
        np.insert(candidate_coords_yx, 0, 0, axis=1), # Add Z=0: (N,3) with (0,Y,X)
        [prob_map_2d.unsqueeze(0)] # Pass list of maps [(1,H,W)]
    )

    # Graph Construction (2D) - use 2D proximity
    # create_graph_edges expects (Z,Y,X). Pass (0,Y,X)
    edge_index_2d = create_graph_edges(
        np.insert(candidate_coords_yx, 0, 0, axis=1), # (N,3) with (0,Y,X)
        config["gnn_edge_max_distance"] # Use same distance, effectively becomes 2D dist
    )

    graph_for_gnn_2d = PyGData(
        x=node_features_2d.to(device),
        edge_index=edge_index_2d.to(device),
        pos=torch.from_numpy(candidate_coords_yx).float().to(device) # Store (Y,X) positions
    )
    graph_for_gnn_2d.id_prefix = f"{slice_id}_gnn_det"

    # GNN Inference
    with torch.no_grad():
        gnn_logits_2d = gnn_model(graph_for_gnn_2d.x, graph_for_gnn_2d.edge_index)
        gnn_probs_2d = torch.sigmoid(gnn_logits_2d).squeeze(-1)
    
    graph_for_gnn_2d.pred_probs = gnn_probs_2d

    # Post-process GNN outputs for 2D
    gnn_final_score_threshold = config.get("gnn_final_score_threshold", 0.5)
    final_detections_2d = gnn_output_to_detections(
        graph_for_gnn_2d,
        gnn_final_score_threshold,
        config["output_bbox_size_2d"]
    )
    
    return final_detections_2d
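
# %% [markdown]
# The 2D adaptation above passes `(0, Y, X)` coordinates to `create_graph_edges`, so the 3D distance threshold reduces to a 2D one. The block below is a minimal, standalone sketch of that idea. It is not the notebook's `create_graph_edges` (defined earlier, and possibly different in detail); the helper name `radius_edges_2d` and its `(2, E)` return layout are assumptions made for illustration.

```python
import numpy as np

def radius_edges_2d(coords_yx, max_dist):
    """Hypothetical helper: connect every pair of points closer than max_dist.

    Returns a (2, E) integer array of directed edges (both directions),
    without self-loops, in the edge_index layout used by PyTorch Geometric.
    """
    coords = np.asarray(coords_yx, dtype=np.float64).reshape(-1, 2)
    if coords.shape[0] < 2:
        return np.zeros((2, 0), dtype=np.int64)
    # Pairwise Euclidean distances, shape (N, N).
    diff = coords[:, None, :] - coords[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    # Keep pairs within the threshold, excluding the diagonal (self-loops).
    mask = (dist <= max_dist) & ~np.eye(coords.shape[0], dtype=bool)
    src, dst = np.nonzero(mask)
    return np.stack([src, dst]).astype(np.int64)

# Adding a constant Z=0 column would leave all pairwise distances unchanged,
# which is why the 3D threshold can be reused on a single slice.
points_yx = np.array([[0, 0], [0, 3], [10, 10]])
edges = radius_edges_2d(points_yx, max_dist=5.0)
```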


# %% [markdown]
# ## Part 8: Hyperparameter Optimization with Optuna
# This is a complex HPO setup because the GNN training depends on the base model.
# The `objective` function for Optuna will:
# 1.  Suggest hyperparameters for base model and GNN.
# 2.  Run the full training pipeline (base model training, GNN data gen, GNN training).
# 3.  Evaluate the *entire trained pipeline* on a final validation set (e.g., VAL_TOMO_IDS or a dedicated HPO validation set).
# 4.  Return the F-beta score (beta=2.0).

# %%
# Global variable to store best trial results if needed, or use Optuna's study storage
BEST_OPTUNA_RESULTS = {
    "best_f2_score": -1.0,
    "best_trial_number": -1,
    "best_params": None
}

def objective_optuna(trial: optuna.trial.Trial):
    # --- Suggest Hyperparameters ---
    # Create a trial-specific config by overriding parts of global CONFIG
    trial_config = CONFIG.copy() # Start with global defaults

    # Base Model Hyperparameters
    trial_config["base_model_lr"] = trial.suggest_float("base_model_lr", 1e-5, 1e-3, log=True)
    # trial_config["patch_size_base"] = ... # Could tune this if data loading adapts, but tricky. Keep fixed for now.
    # trial_config["base_model_target_radius"] = trial.suggest_int("base_model_target_radius", 2, 5) # If this affects GT generation
    # For DynUNet, can tune filters, depth, etc. This adds significant complexity.
    # Example: num_unet_levels = trial.suggest_int("num_unet_levels", 3, 5)
    # dynunet_filters_start = trial.suggest_categorical("dynunet_filters_start", [16, 32])
    # trial_config["dynunet_filters"] = [dynunet_filters_start * (2**i) for i in range(num_unet_levels)]
    # trial_config["dynunet_strides"] = ... # adapt based on num_unet_levels
    # trial_config["unet_kernel_sizes"] = ... # adapt based on num_unet_levels

    # GNN Hyperparameters
    trial_config["gnn_node_candidate_threshold"] = trial.suggest_float("gnn_node_candidate_threshold", 0.1, 0.7)
    trial_config["gnn_edge_max_distance"] = trial.suggest_int("gnn_edge_max_distance", 10, 40)
    trial_config["gnn_lr"] = trial.suggest_float("gnn_lr", 1e-4, 1e-2, log=True) # Name matches the CONFIG key so CONFIG.update(best_trial.params) sets it
    trial_config["gnn_hidden_channels"] = trial.suggest_categorical("gnn_hidden_channels", [32, 64, 128])
    trial_config["gnn_num_layers"] = trial.suggest_int("gnn_num_layers", 2, 5)
    trial_config["gnn_gat_heads"] = trial.suggest_categorical("gnn_gat_heads", [2, 4, 8])
    # trial_config["gnn_final_score_threshold"] = trial.suggest_float("gnn_final_score_threshold", 0.3, 0.9) # Output threshold

    # For demo, keep epochs low. In real HPO, these should also be reasonably high.
    # trial_config["base_model_train_epochs"] = trial.suggest_int("base_model_train_epochs", 5, 20)
    # trial_config["gnn_train_epochs"] = trial.suggest_int("gnn_train_epochs", 5, 20)
    
    print(f"\nOptuna Trial {trial.number}: Starting with params {trial.params}")

    # --- Run Full Training Pipeline ---
    # This will use the trial_config which has the suggested HPs.
    # The `train_base_model` and `train_gnn_model` functions internally handle Optuna trial reporting for pruning.
    try:
        trained_base_model, trained_gnn_model, pipeline_results_history = run_full_training_pipeline(trial_config, trial=trial)
    except optuna.exceptions.TrialPruned:
        print(f"Trial {trial.number} pruned.")
        torch.cuda.empty_cache()
        raise # Re-raise to Optuna
    except Exception as e:
        print(f"Trial {trial.number} failed with error: {e}")
        # traceback.print_exc() # For detailed error
        torch.cuda.empty_cache()
        # Return a score below any valid F2 value so a failed trial can never rank as best.
        return -1.0

    if trained_base_model is None or (CONFIG["gnn_model_name"] and trained_gnn_model is None and PYG_AVAILABLE):
        print(f"Trial {trial.number}: Model training failed to produce models. Reporting low score.")
        torch.cuda.empty_cache()
        return -1.0 # Or appropriate low score for maximization problem

    # --- Evaluate Entire Pipeline on a Hold-out Validation Set ---
    # This validation set should NOT have been used for GNN training data generation if possible.
    # Or, use the same VAL_TOMO_IDS, acknowledging this is a validation of GNN's performance on data derived from these.
    # For a robust HPO, a dedicated HPO_VAL_TOMO_IDS set is best.
    # Here, we'll use VAL_TOMO_IDS for simplicity for final evaluation.
    
    if not VAL_TOMO_IDS:
        print(f"Trial {trial.number}: No validation tomograms for final Optuna evaluation. Returning GNN's internal best val metric if available.")
        torch.cuda.empty_cache()
        return pipeline_results_history.get("final_gnn_val_metric", 0.0) if pipeline_results_history else 0.0

    print(f"Trial {trial.number}: Evaluating pipeline on validation set: {VAL_TOMO_IDS}")
    all_tomogram_f2_scores = []
    total_tp, total_fp, total_fn = 0, 0, 0

    for tomo_id_val in VAL_TOMO_IDS:
        # Perform 3D inference
        detections = inference_on_3d_tomogram(
            tomo_id_val, trained_base_model, trained_gnn_model,
            trial_config, trial_config["dataset_root"], trial_config["device"]
        )
        
        # Match detections to GT for this tomogram
        # Define matching threshold - could also be part of HPO if desired
        matching_dist_3d = trial_config["base_model_target_radius"] * 3 # Heuristic
        
        eval_results_tomo = evaluate_pipeline_on_tomogram(
            tomo_id_val, detections, trial_config["gt_csv_path"], matching_dist_3d, beta=2.0
        )
        all_tomogram_f2_scores.append(eval_results_tomo[CONFIG['optuna_objective_metric']])
        total_tp += eval_results_tomo['tp']
        total_fp += eval_results_tomo['fp']
        total_fn += eval_results_tomo['fn']
        print(f"Tomo {tomo_id_val} - F2: {eval_results_tomo[CONFIG['optuna_objective_metric']]:.4f}, P: {eval_results_tomo['precision']:.4f}, R: {eval_results_tomo['recall']:.4f}")

    # Aggregate metric: either average F2 over tomograms, or calculate F2 from total TP/FP/FN
    if all_tomogram_f2_scores:
        # Micro-averaged F2 from total TP, FP, FN
        final_f2_score, _, _ = calculate_fbeta_precision_recall(
             np.concatenate([np.ones(total_tp), np.ones(total_fp), np.zeros(total_fn)]),
             np.concatenate([np.ones(total_tp), np.zeros(total_fp), np.ones(total_fn)]),
             beta=2.0
        )
        # final_f2_score = np.mean(all_tomogram_f2_scores) # Macro-average
    else:
        final_f2_score = 0.0 # Default if no validation occurred

    print(f"Trial {trial.number}: Overall F2-score on validation set: {final_f2_score:.4f}")

    # Store best trial info (optional, Optuna study object also stores this)
    if final_f2_score > BEST_OPTUNA_RESULTS["best_f2_score"]:
        BEST_OPTUNA_RESULTS["best_f2_score"] = final_f2_score
        BEST_OPTUNA_RESULTS["best_trial_number"] = trial.number
        BEST_OPTUNA_RESULTS["best_params"] = trial.params
        # Save best models from this trial (can get very disk intensive)
        # Consider saving just the trial_config or a reference.
    
    torch.cuda.empty_cache() # Clean up GPU memory after trial
    return final_f2_score


def run_optuna_study():
    if not TRAIN_TOMO_IDS or not VAL_TOMO_IDS:
        print("Optuna study cannot run: Missing training or validation tomogram IDs.")
        return None, None

    study_name = f"{CONFIG['project_name']}_study"
    # Ensure output_dir is a string for f-string formatting if it was converted to Path earlier
    output_dir_str = str(CONFIG["output_dir"]) 
    storage_path = f"sqlite:///{Path(output_dir_str) / study_name}.db" # Save study to a DB

    # Check if PyG is available; if not, GNN part is dummied, so HPO for GNN params is less meaningful
    if not PYG_AVAILABLE and CONFIG["gnn_model_name"]:
        print("WARNING: PyTorch Geometric not available. GNN-related HPO will be for a dummy GNN.")


    study = optuna.create_study(
        study_name=study_name,
        direction="maximize", # We want to maximize F2-score
        pruner=MedianPruner(n_startup_trials=2, n_warmup_steps=2, interval_steps=1), # Prune early if not promising
        storage=storage_path,
        load_if_exists=True # Resume study if it exists
    )

    # Local imports for config serialization and for saving the best parameters below.
    import json
    import copy

    def make_json_serializable(obj):
        if isinstance(obj, dict):
            return {k: make_json_serializable(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [make_json_serializable(i) for i in obj]
        elif isinstance(obj, tuple): # Tuples may also contain non-serializable items
            return tuple(make_json_serializable(i) for i in obj)
        elif isinstance(obj, (torch.device, Path)):
            return str(obj)
        # Add other non-serializable types here if necessary.
        # int, float, str, bool, and None are already serializable.
        return obj

    serializable_config = make_json_serializable(copy.deepcopy(CONFIG))

    # Store the main config (without trial-specific changes) as a study user attribute.
    study.set_user_attr("global_config", serializable_config)

    try:
        study.optimize(objective_optuna, n_trials=CONFIG["optuna_n_trials"], timeout=None) # Add timeout if desired
    except KeyboardInterrupt:
        print("Optuna study interrupted by user.")
    except Exception as e: # Catch other potential errors during optimize
        print(f"Exception during study.optimize: {e}")
        import traceback
        traceback.print_exc()

    print("\nOptuna Study Summary:")
    print(f"  Number of finished trials: {len(study.trials)}")
    
    best_trial = None
    try:
        best_trial = study.best_trial
        print(f"  Best trial number: {best_trial.number}")
        print(f"  Best F2-score: {best_trial.value:.4f}")
        print("  Best parameters:")
        for key, value in best_trial.params.items():
            print(f"    {key}: {value}")
        
        # Save best params to a JSON file
        best_params_path = Path(str(CONFIG["output_dir"])) / "best_optuna_params.json" # Ensure CONFIG["output_dir"] is string
        with open(best_params_path, "w") as f:
            json.dump(best_trial.params, f, indent=4)
        print(f"Best parameters saved to {best_params_path}")

    except ValueError: # If no trials completed or all failed
        print("  No successful trials completed in the study.")
    except Exception as e:
        print(f"  Error retrieving best trial: {e}")

    return study, best_trial
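
# %% [markdown]
# The objective above pools TP/FP/FN over all validation tomograms before scoring (a micro-average). As a hedged illustration, the same quantity can be computed directly from the pooled counts with $F_\beta = \frac{(1+\beta^2)\,TP}{(1+\beta^2)\,TP + \beta^2\,FN + FP}$. The helper `fbeta_from_counts` below is hypothetical and is not the notebook's `calculate_fbeta_precision_recall`; the example counts are invented for the demonstration.

```python
def fbeta_from_counts(tp, fp, fn, beta=2.0):
    """Hypothetical helper: F-beta computed directly from pooled TP/FP/FN counts."""
    beta_sq = beta ** 2
    denom = (1.0 + beta_sq) * tp + beta_sq * fn + fp
    # With no detections and no ground truth the score is undefined;
    # this sketch returns 0.0 in that case by convention.
    return (1.0 + beta_sq) * tp / denom if denom > 0 else 0.0

# Pool counts over tomograms first, then compute a single score (micro-average).
per_tomogram_counts = [(5, 1, 0), (3, 1, 2)]  # illustrative (tp, fp, fn) tuples
pooled_tp = sum(c[0] for c in per_tomogram_counts)
pooled_fp = sum(c[1] for c in per_tomogram_counts)
pooled_fn = sum(c[2] for c in per_tomogram_counts)
micro_f2 = fbeta_from_counts(pooled_tp, pooled_fp, pooled_fn, beta=2.0)
```

# With beta=2, a missed motor (FN) is weighted four times as heavily as a false alarm (FP), so this objective favors recall.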

# %% [markdown]
# ## Part 9: Usage Instructions
#
# 1.  **Setup:**
#     *   Install all required packages as listed in "Part 0: Setup and Imports". Ensure your PyTorch installation matches your CUDA version. PyTorch Geometric installation might require specific commands based on your PyTorch/CUDA versions.
#     *   Make sure a CUDA-enabled GPU is available and selected by PyTorch (`CONFIG["device"]`).
#
# 2.  **Data Structure:**
#     *   Set `CONFIG["dataset_root"]` to the main directory containing your tomogram data.
#     *   Inside `CONFIG["dataset_root"]`, each tomogram should be a subfolder (e.g., `tomo_001acc`, `tomo_002xyz`).
#     *   Each tomogram subfolder must contain its 2D slices as sequentially named JPG files (e.g., `slice_0000.jpg`, `slice_0001.jpg`, ...).
#     *   Set `CONFIG["gt_csv_path"]` to the path of your ground truth CSV file.
#     *   The CSV file must contain at least the columns: `tomo_id` (matching folder names), `center_x`, `center_y`, `center_z` (voxel coordinates of motor centers).
#     *   The simulation code (`simulate_dataset`) can be used to create a dummy dataset if you don't have one initially. Comment it out for real data.
#     *   Define `TRAIN_TOMO_IDS`, `VAL_TOMO_IDS`, `TEST_TOMO_IDS` lists with the respective tomogram IDs.
#
# 3.  **Configuration:**
#     *   Review and adjust parameters in `CONFIG` (Part 1), especially:
#         *   `patch_size_base`: Critical for base model training. Ensure it fits your GPU VRAM.
#         *   `*_train_epochs`: Increase for real training (e.g., 100+ for base, 50+ for GNN).
#         *   `optuna_n_trials`: Increase significantly for meaningful HPO (e.g., 50-100+).
#         *   `output_dir`: Where models and results will be saved.
#
# 4.  **Running Hyperparameter Optimization (HPO - Recommended First Step):**
#     ```python
#     # Make sure all necessary data loaders (train_loader_base, val_loader_base) are created
#     # by running the cells in Part 2 after configuring paths and IDs.
#     # Then run:
#     # optuna_study, optuna_best_trial = run_optuna_study()
#     # if optuna_best_trial:
#     #     print("\nTo train with best HPO params, update CONFIG with optuna_best_trial.params and re-run training.")
#     #     CONFIG.update(optuna_best_trial.params) # Update global CONFIG with best params
#     #     # Or create a new config from best_trial.params
#     #     # best_config = CONFIG.copy()
#     #     # best_config.update(optuna_best_trial.params)
#     ```
#     *   This will run the Optuna study. Results (including best parameters) will be saved.
#     *   After HPO, update `CONFIG` with the `best_trial.params` found by Optuna for final model training.
#
# 5.  **Running Training (with fixed or HPO-tuned config):**
#     ```python
#     # Ensure CONFIG has desired parameters (either default or from HPO)
#     # Ensure data loaders (train_loader_base, val_loader_base) are created.
#     # Then run:
#     # trained_base_model, trained_gnn_model, pipeline_history = run_full_training_pipeline(CONFIG)
#     # if trained_base_model and trained_gnn_model:
#     #     print("Training complete. Models are ready for inference.")
#     #     # Save final models explicitly if not done by training loops or if you want a specific name
#     #     final_base_model_path = Path(CONFIG["output_dir"]) / "final_base_model.pth"
#     #     final_gnn_model_path = Path(CONFIG["output_dir"]) / "final_gnn_model.pth"
#     #     torch.save(trained_base_model.state_dict(), final_base_model_path)
#     #     if PYG_AVAILABLE and trained_gnn_model:
#     #         torch.save(trained_gnn_model.state_dict(), final_gnn_model_path)
#     #     print(f"Final models saved to {final_base_model_path} and {final_gnn_model_path}")
#     # elif trained_base_model:
#     #     print("Base model training complete. GNN training might have been skipped or failed.")
#     #     final_base_model_path = Path(CONFIG["output_dir"]) / "final_base_model.pth"
#     #     torch.save(trained_base_model.state_dict(), final_base_model_path)
#     #     print(f"Final base model saved to {final_base_model_path}")
#     # else:
#     #     print("Training failed.")
#     ```
#
# 6.  **Running Inference:**
#     *   Load your trained models:
#         ```python
#         # # Example loading (adjust paths and model instantiation as needed)
#         # inference_config = CONFIG.copy() # Use the config the models were trained with
#         #
#         # # If HPO was run, merge the saved best parameters first so that the model
#         # # structures created below match the trained checkpoints.
#         # best_params_json_path = Path(inference_config["output_dir"]) / "best_optuna_params.json"
#         # if best_params_json_path.exists():
#         #    with open(best_params_json_path, 'r') as f:
#         #        inference_config.update(json.load(f))
#         #
#         # # Load Base Model
#         # loaded_base_model = get_base_model(inference_config) # Re-create model structure
#         # # This example assumes "final_base_model.pth" was saved after (HPO-informed) training.
#         # # A best-epoch checkpoint (e.g. "DynUNet_best_epochX_diceY.pth") could be used instead;
#         # # locating the checkpoint of a specific HPO trial needs more robust logic than shown here.
#         # base_model_checkpoint_path = Path(inference_config["output_dir"]) / "final_base_model.pth"
#         #
#         # if base_model_checkpoint_path.exists():
#         #    loaded_base_model.load_state_dict(torch.load(base_model_checkpoint_path, map_location=inference_config["device"]))
#         #    print(f"Loaded base model from {base_model_checkpoint_path}")
#         # else:
#         #    print(f"ERROR: Base model checkpoint {base_model_checkpoint_path} not found for inference.")
#         #    # loaded_base_model = None # Or handle error
#         #
#         # # Load GNN Model (if used)
#         # loaded_gnn_model = None
#         # if PYG_AVAILABLE and inference_config["gnn_model_name"]:
#         #    # Determine gnn_in_channels. This is tricky without knowing the base model outputs during that training.
#         #    # Assuming it's based on num_base_models for simplicity.
#         #    gnn_in_channels_inf = inference_config["num_base_models"] # Matching the training setup
#         #    loaded_gnn_model = GATNet(
#         #        in_channels=gnn_in_channels_inf,
#         #        hidden_channels=inference_config["gnn_hidden_channels"], # From loaded best_params or default
#         #        out_channels=1,
#         #        num_layers=inference_config["gnn_num_layers"],
#         #        heads=inference_config["gnn_gat_heads"]
#         #    )
#         #    # gnn_model_checkpoint_path = Path(inference_config["output_dir"]) / "GAT_best_f2_X.pth" # Or final_gnn_model.pth
#         #    gnn_model_checkpoint_path = Path(inference_config["output_dir"]) / "final_gnn_model.pth"
#         #    if gnn_model_checkpoint_path.exists():
#         #        loaded_gnn_model.load_state_dict(torch.load(gnn_model_checkpoint_path, map_location=inference_config["device"]))
#         #        print(f"Loaded GNN model from {gnn_model_checkpoint_path}")
#         #    else:
#         #        print(f"WARNING: GNN model checkpoint {gnn_model_checkpoint_path} not found. GNN will not be used or may error.")
#         #        # loaded_gnn_model = None
#         ```
#     *   **Mode 1: Inference on Full 3D Tomogram:**
#         ```python
#         # if loaded_base_model: # Check if base model loaded successfully
#         #    tomo_id_to_infer = TEST_TOMO_IDS[0] # Example
#         #    detections_3d = inference_on_3d_tomogram(
#         #        tomo_id_to_infer, loaded_base_model, loaded_gnn_model,
#         #        inference_config, inference_config["dataset_root"], inference_config["device"]
#         #    )
#         #    print(f"3D Detections for {tomo_id_to_infer}: {detections_3d}")
#         #
#         #    # Optional: Evaluate if ground truth is available for test tomogram
#         #    # eval_3d = evaluate_pipeline_on_tomogram(tomo_id_to_infer, detections_3d, inference_config["gt_csv_path"], matching_dist_thresh_3d=10)
#         #    # print(f"Evaluation for {tomo_id_to_infer}: {eval_3d}")
#         ```
#     *   **Mode 2: Inference on Individual 2D Slice:**
#         ```python
#         # if loaded_base_model: # Check if base model loaded successfully
#         #    # Load a sample 2D slice (e.g., first slice of first test tomogram)
#         #    tomo_id_for_slice = TEST_TOMO_IDS[0]
#         #    slice_path_pattern = str(Path(inference_config["dataset_root"]) / tomo_id_for_slice / "slice_*.jpg")
#         #    slice_files = sorted(glob.glob(slice_path_pattern))
#         #    if slice_files:
#         #        sample_slice_path = slice_files[len(slice_files)//2] # Middle slice
#         #        sample_2d_image_hw = cv2.imread(sample_slice_path, cv2.IMREAD_GRAYSCALE)
#         #        sample_slice_id = Path(sample_slice_path).stem
#         #
#         #        detections_2d = inference_on_2d_slice(
#         #            sample_2d_image_hw, sample_slice_id,
#         #            loaded_base_model, loaded_gnn_model,
#         #            inference_config, inference_config["device"]
#         #        )
#         #        print(f"2D Detections for slice {sample_slice_id}: {detections_2d}")
#         #    else:
#         #        print(f"No slices found for tomogram {tomo_id_for_slice} to test 2D inference.")
#         ```
#
# 7.  **Interpreting Output:**
#     *   Detections are lists of dictionaries, each containing `center`, `bbox`, and `score`.
#     *   Coordinates are (X,Y,Z) for 3D and (X,Y) for 2D. Bounding boxes are (X_min,Y_min,Z_min,X_max,Y_max,Z_max) or (X_min,Y_min,X_max,Y_max).
#
# **Important Notes for Execution:**
# *   **GPU Memory:** This pipeline, especially with DynUNet and GNNs on large graphs, is VRAM intensive. Monitor GPU usage. Reduce `patch_size_base`, `base_model_batch_size`, or GNN complexity if OOM errors occur.
# *   **Time:** Full training and HPO will take a very long time. The demo settings (`*_train_epochs`, `optuna_n_trials`) are minimal.
# *   **Model Saving/Loading:** The provided training loops save best models. Ensure the loading logic in inference correctly identifies and loads these models along with the configuration they were trained with (especially for model architecture HPs like GNN layers/hidden_dims). The example inference loading is basic and may need refinement for robustly finding the correct HPO-derived model.
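
# %% [markdown]
# Steps 4 and 6 above update the config with the parameters saved in `best_optuna_params.json`. The sketch below shows one cautious way to do that merge: it works on a copy and rejects keys that are not already in the config, so a misnamed Optuna parameter is noticed rather than silently added. The helper `merge_best_params` and the stand-in `demo_config` are hypothetical and exist only for this illustration.

```python
import copy
import json
import tempfile
from pathlib import Path

def merge_best_params(base_config, params_path):
    """Hypothetical helper: return a copy of base_config updated with saved HPO parameters."""
    with open(params_path, "r") as f:
        best_params = json.load(f)
    # Reject parameters whose names do not correspond to an existing config key.
    unknown = sorted(set(best_params) - set(base_config))
    if unknown:
        raise KeyError(f"Parameters not present in config: {unknown}")
    merged = copy.deepcopy(base_config)  # leave the original config untouched
    merged.update(best_params)
    return merged

# Illustrative round trip with a stand-in config and a temporary JSON file.
demo_config = {"gnn_lr": 1e-3, "gnn_num_layers": 3, "output_dir": "outputs"}
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = Path(tmp_dir) / "best_optuna_params.json"
    demo_path.write_text(json.dumps({"gnn_lr": 0.005, "gnn_num_layers": 4}))
    best_config = merge_best_params(demo_config, demo_path)
```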


# %% [markdown]
# ## Part 10: Discussion, Challenges, and Future Work
#
# ### Discussion
# This notebook outlines a comprehensive pipeline for 3D bacterial motor detection using a GNN-stacked ensemble. The core idea is to leverage 3D CNNs for powerful feature extraction from cryo-ET data and then use a GNN to intelligently combine and refine these predictions, considering spatial relationships between candidate detections. The inclusion of Optuna for HPO is critical for navigating the large parameter space of such a multi-stage model. The adaptation for 2D slice-based inference, while challenging, significantly increases the pipeline's versatility.
#
# ### Key Challenges Addressed (and inherent complexities)
# *   **Designing Effective GNN Node Features:** Node features are derived from base CNN probability maps. This is a common starting point. More sophisticated features could include embeddings from intermediate CNN layers, or geometric/intensity statistics of candidate regions, but this increases complexity.
# *   **Managing Training and HPO Complexity:** The two-stage training (base models then GNN) is standard for stacking. HPO over the entire pipeline is computationally expensive. Each Optuna trial involves re-training base models and then the GNN. Pruning strategies in Optuna are essential. Efficient data caching (e.g., `PersistentDataset` for base model training, pre-generating GNN graph data if base models are fixed) can save time.
# *   **GNN Adaptation for 2D Inference:** This is a novel aspect. The primary challenge is ensuring the GNN, trained on 3D graph structures and features, can generalize to graphs derived from pseudo-3D/2D information. Success hinges on:
#     *   The robustness of the GNN architecture (GATs are generally good with varying graph structures).
#     *   The semantic consistency of node features between 3D and adapted 2D inputs. If features are primarily "probability of motor," this should hold.
#     *   The quality of the pseudo-3D volume generation for the base model.
# *   **Robust Detection of Very Small Motors:** This is an inherent challenge in cryo-ET.
#     *   Base CNNs need appropriate receptive fields and capacity. FocalLoss or DiceCELoss can help with class imbalance.
#     *   High-resolution patches and careful augmentation are important.
#     *   The GNN's ability to filter false positives based on relational context is a key benefit here.
# *   **Computational Resources:** High-end GPUs with substantial VRAM (roughly 16-24 GB or more) are necessary, especially for training with large 3D patches and GNNs on potentially large graphs.
#
# ### Limitations
# *   **Single Notebook Constraint:** A real-world project of this scale would typically be modularized into multiple Python scripts for better organization, testing, and maintainability.
# *   **Data Simulation:** The provided data simulation is basic. Real cryo-ET data has complex noise and structural characteristics.
# *   **HPO Scope:** The HPO in this notebook tunes a subset of possible parameters. A more exhaustive search could tune DynUNet architectural details, GNN edge construction rules, GNN node feature selection, etc.
# *   **GNN Graph Size:** For very large tomograms generating many candidate nodes, GNN training/inference could become a bottleneck. Graph sampling or hierarchical approaches might be needed.
# *   **Ground Truth Assumption:** Assumes point center GT. If GT provides bounding boxes or segmentation masks, the target generation for base models and evaluation could be more precise.
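# *   **Matching Criteria:** The greedy matching used for evaluation is a common choice, but it does not guarantee the best one-to-one assignment; an optimal assignment (e.g., the Hungarian algorithm) can match more detections when candidates are crowded.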
#
# ### Future Improvements
# *   **Advanced Base Models:** Explore other 3D architectures or pre-trained models if applicable. Use multiple, diverse base models for the ensemble.
# *   **Richer GNN Node Features:** Incorporate learnable embeddings from base CNNs (e.g., features from bottleneck layers corresponding to candidate locations) or more handcrafted geometric/intensity features.
# *   **Hierarchical GNNs:** For multi-scale analysis if motor appearance varies significantly with scale or context.
# *   **End-to-End Trainability (Advanced):** While complex, exploring methods to jointly optimize (or fine-tune) base models and the GNN in an end-to-end fashion could potentially improve performance but significantly increases difficulty.
# *   **More Sophisticated 2D Adaptation:**
#     *   Experiment with different pseudo-3D volume creation strategies (e.g., padding with learned values, using adjacent real slices if available in a sequence).
#     *   Train a dedicated "adapter" module or fine-tune the GNN slightly on 2D-derived graphs if sufficient 2D labeled data or a good simulation thereof exists.
# *   **Uncertainty Quantification:** Estimate confidence scores that reflect true uncertainty, potentially using Bayesian deep learning techniques or ensemble disagreement.
# *   **Active Learning:** If labeling is a bottleneck, incorporate active learning to select the most informative tomograms/regions for annotation.
# *   **Interpretability:** Use GNN explainability methods (e.g., GNNExplainer) to understand which nodes/edges are most important for the GNN's decisions, providing insights into the detection process.
#
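
# %% [markdown]
# The Limitations list above refers to greedy matching between detections and ground-truth centers. The following is a minimal sketch of one common greedy scheme, assuming detections are visited in descending score order and each ground-truth point can be matched at most once. It is not the notebook's `evaluate_pipeline_on_tomogram`; the function name `greedy_match` and the example coordinates are invented for illustration.

```python
import numpy as np

def greedy_match(det_xyz, det_scores, gt_xyz, max_dist):
    """Hypothetical helper: greedy one-to-one matching of detections to ground truth.

    Detections are visited in descending score order; each takes the nearest
    still-unmatched ground-truth point within max_dist. Returns (tp, fp, fn).
    """
    det = np.asarray(det_xyz, dtype=np.float64).reshape(-1, 3)
    gt = np.asarray(gt_xyz, dtype=np.float64).reshape(-1, 3)
    gt_taken = np.zeros(len(gt), dtype=bool)
    tp = 0
    for i in np.argsort(-np.asarray(det_scores, dtype=np.float64)):
        if len(gt) == 0 or gt_taken.all():
            break
        dist = np.linalg.norm(gt - det[i], axis=1)
        dist[gt_taken] = np.inf  # already-matched GT points are unavailable
        j = int(np.argmin(dist))
        if dist[j] <= max_dist:
            gt_taken[j] = True
            tp += 1
    return tp, len(det) - tp, len(gt) - tp

# Two nearby detections compete for one motor; only the higher-scoring one matches.
demo_dets = [(10, 10, 10), (11, 10, 10), (50, 50, 50)]
demo_scores = [0.9, 0.8, 0.7]
demo_gt = [(10, 10, 11), (30, 30, 30)]
tp, fp, fn = greedy_match(demo_dets, demo_scores, demo_gt, max_dist=5.0)
```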

# %% [markdown]
# ---
# # Main Execution Block (Example Workflow)
# ---
# This block demonstrates a possible sequence of operations.
# You would typically run HPO first, then use the best parameters for final training and inference.
# For this demonstration, we'll run a very short version.
# **Comment out or select parts to run as needed.**

# %%
if __name__ == '__main__':
    # Ensure data loaders are initialized if needed for any step
    # This happens when their respective cells are run.

    # --- OPTION 1: Run Hyperparameter Optimization ---
    print("WORKFLOW: Starting Hyperparameter Optimization (Optuna)...")
    # Make sure train_loader_base and val_loader_base are initialized by running Part 2 cells
    if train_loader_base is None or val_loader_base is None:
         print("ERROR: DataLoaders not initialized. Cannot run Optuna. Please run dataset/dataloader cells first.")
    else:
        # Reduce epochs and trial count so the HPO demo finishes quickly
        original_base_epochs = CONFIG["base_model_train_epochs"]
        original_gnn_epochs = CONFIG["gnn_train_epochs"]
        original_n_trials = CONFIG["optuna_n_trials"]
        CONFIG["base_model_train_epochs"] = 1 # Minimal for HPO demo
        CONFIG["gnn_train_epochs"] = 1       # Minimal for HPO demo
        CONFIG["optuna_n_trials"] = 2       # Very few trials for demo

        optuna_study, optuna_best_trial = run_optuna_study()

        # Restore the settings that were changed for the HPO demo
        CONFIG["base_model_train_epochs"] = original_base_epochs
        CONFIG["gnn_train_epochs"] = original_gnn_epochs
        CONFIG["optuna_n_trials"] = original_n_trials
        
        if optuna_best_trial:
            print("\nOptuna finished. Best parameters found:")
            print(json.dumps(optuna_best_trial.params, indent=2))
            print(f"Best F2-score from HPO: {optuna_best_trial.value}")
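            # Update global CONFIG with these best parameters for subsequent training/inference.
            # Caveat: a best value of -1.0 is the score reported for trials that failed to
            # produce models, in which case these parameters carry no real signal.
            CONFIG.update(optuna_best_trial.params)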
            print("Global CONFIG updated with best HPO parameters.")
        else:
            print("\nOptuna study did not yield a best trial. Using default CONFIG for subsequent steps.")
    
    # --- OPTION 2: Run Full Training with Current CONFIG ---
    # (This could be the CONFIG updated by HPO, or default if HPO was skipped)
    print("\nWORKFLOW: Starting Full Training Pipeline with current CONFIG...")
    # Ensure data loaders are initialized
    if train_loader_base is None or val_loader_base is None:
        print("ERROR: DataLoaders not initialized. Cannot run full training. Please run dataset/dataloader cells first.")
    else:
        # Use slightly more epochs for this "final" training run than in HPO demo
        CONFIG["base_model_train_epochs"] = 2 # Still low for notebook demo
        CONFIG["gnn_train_epochs"] = 2       # Still low for notebook demo

        trained_base_model, trained_gnn_model, pipeline_history = run_full_training_pipeline(CONFIG)
        
        if trained_base_model: # At least base model should be trained
            print("Full training pipeline complete (or partially complete if GNN failed).")
            # Save final models
            final_base_model_path = Path(CONFIG["output_dir"]) / "final_trained_base_model.pth"
            torch.save(trained_base_model.state_dict(), final_base_model_path)
            print(f"Saved final base model to {final_base_model_path}")

            if trained_gnn_model and PYG_AVAILABLE and CONFIG["gnn_model_name"]:
                final_gnn_model_path = Path(CONFIG["output_dir"]) / "final_trained_gnn_model.pth"
                torch.save(trained_gnn_model.state_dict(), final_gnn_model_path)
                print(f"Saved final GNN model to {final_gnn_model_path}")
            elif CONFIG["gnn_model_name"] and not PYG_AVAILABLE:
                print("GNN model specified in config, but PyG not available. GNN part was dummied.")
            elif CONFIG["gnn_model_name"] and not trained_gnn_model:
                print("GNN model specified and PyG available, but GNN model not trained/returned.")

            # --- OPTION 3: Run Inference (after training or loading models) ---
            print("\nWORKFLOW: Starting Inference...")
            if not TEST_TOMO_IDS:
                print("No test tomogram IDs defined. Skipping inference.")
            else:
                # Use the models just trained, or implement loading logic here if running separately
                # For simplicity, use models from the training step above.
                # If running inference in a new session, you'd load them:
                # loaded_base_model = get_base_model(CONFIG)
                # loaded_base_model.load_state_dict(torch.load(final_base_model_path, map_location=CONFIG["device"]))
                # if PYG_AVAILABLE and CONFIG["gnn_model_name"]:
                #    gnn_in_channels_inf = ... # determine correctly
                #    loaded_gnn_model = GATNet(...)
                #    loaded_gnn_model.load_state_dict(torch.load(final_gnn_model_path, map_location=CONFIG["device"]))
                # else:
                #    loaded_gnn_model = None

                # Use `trained_base_model` and `trained_gnn_model` directly for this demo flow.
                
                # Mode 1: 3D Tomogram Inference
                tomo_id_to_infer_3d = TEST_TOMO_IDS[0]
                print(f"\nRunning 3D inference on tomogram: {tomo_id_to_infer_3d}")
                detections_3d = inference_on_3d_tomogram(
                    tomo_id_to_infer_3d, trained_base_model, trained_gnn_model,
                    CONFIG, CONFIG["dataset_root"], CONFIG["device"]
                )
                print(f"Found {len(detections_3d)} detections in 3D for {tomo_id_to_infer_3d}.")
                # for det_3d in detections_3d[:min(3, len(detections_3d))]: print(det_3d) # Print a few

                # Evaluate 3D detections against the ground-truth CSV
                # (only meaningful if it contains entries for this tomogram)
                eval_3d = evaluate_pipeline_on_tomogram(
                    tomo_id_to_infer_3d, detections_3d, CONFIG["gt_csv_path"], 
                    matching_dist_thresh_3d=CONFIG["base_model_target_radius"] * 3, beta=2.0
                )
                print(f"Evaluation for 3D detections on {tomo_id_to_infer_3d}: {eval_3d}")


                # Mode 2: 2D Slice Inference
                tomo_id_for_slice_2d = TEST_TOMO_IDS[0]
                slice_path_pattern_2d = str(Path(CONFIG["dataset_root"]) / tomo_id_for_slice_2d / "slice_*.jpg")
                slice_files_2d = sorted(glob.glob(slice_path_pattern_2d))
                
                if slice_files_2d:
                    # Pick a slice for demo, e.g., middle slice
                    sample_slice_path_2d = slice_files_2d[len(slice_files_2d) // 2]
                    sample_2d_image_hw = cv2.imread(sample_slice_path_2d, cv2.IMREAD_GRAYSCALE)
                    sample_slice_id_2d = Path(sample_slice_path_2d).stem
                    print(f"\nRunning 2D (adapted) inference on slice: {sample_slice_id_2d} from {tomo_id_for_slice_2d}")

                    detections_2d = inference_on_2d_slice(
                        sample_2d_image_hw, sample_slice_id_2d,
                        trained_base_model, trained_gnn_model,
                        CONFIG, CONFIG["device"]
                    )
                    print(f"Found {len(detections_2d)} detections in 2D for slice {sample_slice_id_2d}.")
                    # for det_2d in detections_2d[:min(3, len(detections_2d))]: print(det_2d) # Print a few
                else:
                    print(f"No slices found for tomogram {tomo_id_for_slice_2d} to test 2D inference.")
        else:
            print("Base model training failed or was skipped. Cannot proceed to inference with trained models.")

    print("\nCryoEM-MotorMetaNet Workflow Demo Finished.")
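
# %% [markdown]
# The workflow above copies the best trial's parameters into `CONFIG` whenever a best trial exists. In the run logged below, every trial reports -1.0 (the score used when training fails to produce models), so the "best" parameters are simply those of trial 0. A minimal sketch of a guard follows. It assumes -1.0 is the failure score, as the log suggests; the names `FAILURE_SCORE` and `usable_best_params` are hypothetical, and `SimpleNamespace` objects stand in for Optuna trials, using only the `.value` and `.params` attributes that the workflow already reads.

```python
from types import SimpleNamespace

# Assumption: matches the "low score" the log shows for failed trials
FAILURE_SCORE = -1.0


def usable_best_params(best_trial, failure_score=FAILURE_SCORE):
    """Return a copy of the best trial's params, or None if no trial beat the failure score."""
    if best_trial is None or best_trial.value is None:
        return None
    if best_trial.value <= failure_score:
        return None
    return dict(best_trial.params)


# Stand-ins for Optuna trial objects (hypothetical values)
failed = SimpleNamespace(value=-1.0, params={"gnn_num_layers": 2})
succeeded = SimpleNamespace(value=0.42, params={"gnn_num_layers": 3})

print(usable_best_params(failed))
print(usable_best_params(succeeded))
```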

# %%
# Alternative: run the workflow stage by stage from a separate cell.
# In a notebook, `__name__` is '__main__', so the guarded block above runs as soon as
# its cell is executed. In a notebook that defines functions and then calls them,
# it is often clearer to call each stage explicitly in a new cell.
# Example:
#
# if train_loader_base and val_loader_base: # Check if data setup was done
#    # Call optuna part
#    optuna_study, optuna_best_trial = run_optuna_study()
#    if optuna_best_trial:
#        CONFIG.update(optuna_best_trial.params)
#
#    # Call training part
#    trained_base_model, trained_gnn_model, pipeline_history = run_full_training_pipeline(CONFIG)
#
#    # Call inference part with trained_base_model, trained_gnn_model
# else:
#    print("Please run data setup cells (Part 2) before running main workflow.")

MONAI version: 1.4.0
Numpy version: 1.26.4
Pytorch version: 2.8.0.dev20250409+cu128
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 46a5272196a6c2590ca2589029eed8e4d56ff008
MONAI __file__: D:\Anaconda3\envs\my_env\Lib\site-packages\monai\__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.11
ITK version: 5.4.3
Nibabel version: 5.3.2
scikit-image version: 0.25.2
scipy version: 1.15.2
Pillow version: 11.0.0
Tensorboard version: 2.19.0
gdown version: 5.2.0
TorchVision version: 0.22.0.dev20250410+cu128
tqdm version: 4.67.1
lmdb version: 1.6.2
psutil version: 7.0.0
pandas version: 2.2.3
einops version: 0.8.1
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: 2.22.0
pynrrd version: 1.1.3
clearml version: 2.0.0rc0

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies

Using device: cuda
Simulated dataset appears

[I 2025-05-13 21:33:52,239] Using an existing study with name 'CryoEM_MotorMetaNet_study' instead of creating a new one.



Optuna Trial 4: Starting with params {'base_model_lr': 4.865150593642991e-05, 'gnn_node_candidate_threshold': 0.6454186578345281, 'gnn_edge_max_distance': 35, 'gnn_gnn_lr': 0.0011339926912395115, 'gnn_hidden_channels': 32, 'gnn_num_layers': 2, 'gnn_gat_heads': 4}
--- Stage 1: Training Base Model ---
Starting base model training for 1 epochs.


Epoch 1/1 Training:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 1 average training loss: 0.2642


Epoch 1 Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 1 validation Dice: 0.0011
Saved new best model to output_cryoEM_MotorMetaNet\DynUNet_best_epoch1_dice0.0011.pth
Finished base model training. Best validation metric: 0.0011 at epoch 1
Loaded best model from output_cryoEM_MotorMetaNet\DynUNet_best_epoch1_dice0.0011.pth for GNN data generation.
--- Stage 2: Generating GNN Training Data ---
Generating GNN training data for 1 tomograms...


Processing Tomos for GNN Data:   0%|          | 0/1 [00:00<?, ?it/s]

Generated 1 graphs for GNN training.
Not enough GNN graphs for training split. GNN training skipped.
Trial 4: Model training failed to produce models. Reporting low score.


[I 2025-05-13 21:33:55,848] Trial 4 finished with value: -1.0 and parameters: {'base_model_lr': 4.865150593642991e-05, 'gnn_node_candidate_threshold': 0.6454186578345281, 'gnn_edge_max_distance': 35, 'gnn_gnn_lr': 0.0011339926912395115, 'gnn_hidden_channels': 32, 'gnn_num_layers': 2, 'gnn_gat_heads': 4}. Best is trial 0 with value: -1.0.



Optuna Trial 5: Starting with params {'base_model_lr': 2.9824524172137357e-05, 'gnn_node_candidate_threshold': 0.4623229900911201, 'gnn_edge_max_distance': 11, 'gnn_gnn_lr': 0.0013344400450179684, 'gnn_hidden_channels': 64, 'gnn_num_layers': 2, 'gnn_gat_heads': 2}
--- Stage 1: Training Base Model ---
Starting base model training for 1 epochs.


Epoch 1/1 Training:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 1 average training loss: 0.2650


Epoch 1 Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 1 validation Dice: 0.0011
Saved new best model to output_cryoEM_MotorMetaNet\DynUNet_best_epoch1_dice0.0011.pth
Finished base model training. Best validation metric: 0.0011 at epoch 1
Loaded best model from output_cryoEM_MotorMetaNet\DynUNet_best_epoch1_dice0.0011.pth for GNN data generation.
--- Stage 2: Generating GNN Training Data ---
Generating GNN training data for 1 tomograms...


Processing Tomos for GNN Data:   0%|          | 0/1 [00:00<?, ?it/s]

Generated 1 graphs for GNN training.
Not enough GNN graphs for training split. GNN training skipped.
Trial 5: Model training failed to produce models. Reporting low score.


[I 2025-05-13 21:33:59,362] Trial 5 finished with value: -1.0 and parameters: {'base_model_lr': 2.9824524172137357e-05, 'gnn_node_candidate_threshold': 0.4623229900911201, 'gnn_edge_max_distance': 11, 'gnn_gnn_lr': 0.0013344400450179684, 'gnn_hidden_channels': 64, 'gnn_num_layers': 2, 'gnn_gat_heads': 2}. Best is trial 0 with value: -1.0.



Optuna Study Summary:
  Number of finished trials: 6
  Best trial number: 0
  Best F2-score: -1.0000
  Best parameters:
    base_model_lr: 1.0740586090703404e-05
    gnn_node_candidate_threshold: 0.5136299758870757
    gnn_edge_max_distance: 12
    gnn_gnn_lr: 0.003515395317060771
    gnn_hidden_channels: 32
    gnn_num_layers: 2
    gnn_gat_heads: 2
Best parameters saved to output_cryoEM_MotorMetaNet\best_optuna_params.json

Optuna finished. Best parameters found:
{
  "base_model_lr": 1.0740586090703404e-05,
  "gnn_node_candidate_threshold": 0.5136299758870757,
  "gnn_edge_max_distance": 12,
  "gnn_gnn_lr": 0.003515395317060771,
  "gnn_hidden_channels": 32,
  "gnn_num_layers": 2,
  "gnn_gat_heads": 2
}
Best F2-score from HPO: -1.0
Global CONFIG updated with best HPO parameters.

WORKFLOW: Starting Full Training Pipeline with current CONFIG...
--- Stage 1: Training Base Model ---
Starting base model training for 2 epochs.


Epoch 1/2 Training:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 1 average training loss: 0.2659


Epoch 1 Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 1 validation Dice: 0.0011
Saved new best model to output_cryoEM_MotorMetaNet\DynUNet_best_epoch1_dice0.0011.pth


Epoch 2/2 Training:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 2 average training loss: 0.2645


Epoch 2 Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 2 validation Dice: 0.0011
Saved new best model to output_cryoEM_MotorMetaNet\DynUNet_best_epoch2_dice0.0011.pth
Finished base model training. Best validation metric: 0.0011 at epoch 2
Loaded best model from output_cryoEM_MotorMetaNet\DynUNet_best_epoch2_dice0.0011.pth for GNN data generation.
--- Stage 2: Generating GNN Training Data ---
Generating GNN training data for 1 tomograms...


Processing Tomos for GNN Data:   0%|          | 0/1 [00:00<?, ?it/s]

Generated 1 graphs for GNN training.
Not enough GNN graphs for training split. GNN training skipped.
Full training pipeline complete (or partially complete if GNN failed).
Saved final base model to output_cryoEM_MotorMetaNet\final_trained_base_model.pth
GNN model specified and PyG available, but GNN model not trained/returned.

WORKFLOW: Starting Inference...

Running 3D inference on tomogram: tomo_004


100%|███████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 181.37it/s]


Performing inference using only base model (GNN not available or specified).
Found 500 detections in 3D for tomo_004.
Evaluation for 3D detections on tomo_004: {'f2_score': 0.00992063491079302, 'precision': 0.0019999999995999997, 'recall': 0.9999999000000099, 'tp': 1, 'fp': 499, 'fn': 0}

Running 2D (adapted) inference on slice: slice_0060 from tomo_004
Performing 2D adapted inference using only base model for slice slice_0060.
Found 200 detections in 2D for slice slice_0060.

CryoEM-MotorMetaNet Workflow Demo Finished.
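
# %% [markdown]
# The 3D evaluation line in the log above reports `tp=1`, `fp=499`, `fn=0` and an F2-score of about 0.0099. The sketch below recomputes that figure from the counts with the standard F-beta formula, to show that the low score comes entirely from the 499 false positives even though recall is 1. The helper `f_beta` is written for illustration and is not the notebook's evaluator, which appears to add a small smoothing term (an assumption based on the trailing digits in the log). The second call uses a hypothetical false-positive count for comparison.

```python
def f_beta(tp, fp, fn, beta=2.0):
    """F-beta score from raw counts; returns 0.0 when it is undefined."""
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    denom = beta ** 2 * precision + recall
    if denom == 0:
        return 0.0
    return (1 + beta ** 2) * precision * recall / denom


# Counts taken from the tomo_004 evaluation line in the log
logged = f_beta(tp=1, fp=499, fn=0)
# Hypothetical: same recall, far fewer false positives
fewer_fp = f_beta(tp=1, fp=9, fn=0)

print(f"F2 with 499 false positives: {logged:.6f}")
print(f"F2 with 9 false positives:   {fewer_fp:.6f}")
```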
