# **Bayesian GraphSAGE: Scalable protein function prediction on the ogbn-proteins dataset** 
Implementing multi-label classification with uncertainty estimation using JAX/NNX and PyArrow.

## ***Dataset description: ogbn-proteins (Biological network)***

In this notebook, you work with a dataset from the **Open Graph Benchmark (OGB)** series, a widely used suite for benchmarking graph learning algorithms.

- **Nature of the data:** Represents relationships between proteins in different biological organisms.
- **Structure:**
    - **Nodes (132,534):** Individual proteins.
    - **Edges (39,473,625):** Represent 8 types of interactions (e.g. homology, co-expression).
- **Input Features:** Nodes initially have no features of their own. Your process uses **8-dimensional edge vectors** (link features) that are aggregated into nodes as initial features.
- **Target:** Multi-label classification. You predict **112 binary labels** at once, corresponding to the presence of a protein in various biological functions.
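
The step from 8-dimensional edge vectors to initial node features can be sketched as follows. This is a minimal NumPy illustration, assuming that "aggregated" means averaging the features of all edges incident to a node; the function name `aggregate_edge_features` and the toy arrays are hypothetical and not taken from the notebook.

```python
import numpy as np

def aggregate_edge_features(edge_index, edge_attr, num_nodes):
    """Mean of incident edge feature vectors for every node.

    edge_index: int array of shape (2, E), one undirected edge per column.
    edge_attr:  float array of shape (E, D), with D = 8 for ogbn-proteins.
    """
    sums = np.zeros((num_nodes, edge_attr.shape[1]), dtype=np.float64)
    counts = np.zeros(num_nodes, dtype=np.int64)
    for endpoint in edge_index:  # both endpoints of an edge receive its features
        np.add.at(sums, endpoint, edge_attr)
        np.add.at(counts, endpoint, 1)
    # Isolated nodes keep an all-zero vector instead of dividing by zero.
    return (sums / np.maximum(counts, 1)[:, None]).astype(np.float32)

# Toy graph (hypothetical values): 3 nodes, edges 0-1 and 1-2, 2-dim edge features.
toy_edge_index = np.array([[0, 1], [1, 2]])
toy_edge_attr = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32)
toy_node_features = aggregate_edge_features(toy_edge_index, toy_edge_attr, num_nodes=3)
```

Node 1 touches both edges, so it receives the average of the two edge vectors, while nodes 0 and 2 simply inherit the features of their single edge.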

## ***Theory: Bayesian GraphSAGE***

The model in this notebook is not just a standard neural network; it is an inductive and probabilistic system.

### ***A. GraphSAGE (Inductive learning)***

Unlike transductive models (like GCN) that require the entire graph in memory, GraphSAGE learns the **aggregation function**. This means that after training, your model can process proteins it never saw during training.

**Key mechanism (SAmple and aggreGatE):** For each node $v$, the model samples a fixed number of neighbors (neighbor sampling) and performs a state update:

$$
h_v^{(k)} = \sigma \left( W^{(k)} \cdot \text{CONCAT} \left( h_v^{(k-1)}, \text{AGG} \left( \{ h_u^{(k-1)} , \forall u \in N(v) \}  \right) \right) \right)
$$

- This notebook uses the **mean aggregator** (neighbor averaging).
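
The update rule above, with a mean aggregator and ReLU standing in for $\sigma$, can be written out in a few lines of NumPy. This is only a sketch for building intuition: the function name `sage_mean_layer` and the toy values are assumptions, and the notebook's own layer is the `SAGEConv` module defined later.

```python
import numpy as np

def sage_mean_layer(h, senders, receivers, W):
    """One GraphSAGE update with a mean aggregator and ReLU as sigma.

    h: (N, F) node states h^{(k-1)}; W: (2F, F_out) weight matrix W^{(k)}.
    Edge e carries h[senders[e]] to node receivers[e].
    """
    agg = np.zeros_like(h)
    deg = np.zeros(h.shape[0])
    np.add.at(agg, receivers, h[senders])        # sum of incoming neighbor states
    np.add.at(deg, receivers, 1.0)
    agg = agg / np.maximum(deg, 1.0)[:, None]    # AGG = mean over N(v)
    concat = np.concatenate([h, agg], axis=-1)   # CONCAT(h_v, AGG(...))
    return np.maximum(concat @ W, 0.0)           # sigma = ReLU

h0 = np.array([[1.0], [3.0], [5.0]])
senders = np.array([1, 2, 0])
receivers = np.array([0, 0, 1])
W = np.array([[1.0], [1.0]])  # adds a node's own state to its neighbor mean
h1 = sage_mean_layer(h0, senders, receivers, W)
```

Because the layer only needs a node's own state and the states of its sampled neighbors, the same weights can be applied to nodes that were not present during training.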

### ***B. Bayesian Uncertainty (MC Dropout)***

The reason the model is called "Bayesian" is the implementation of **Monte Carlo Dropout**.

- **Theory:** Gal & Ghahramani (2016) showed that a network trained with dropout can be interpreted as approximate Bayesian inference (a variational approximation to a deep Gaussian process), so keeping dropout active at inference time yields samples from an approximate predictive distribution.
- **Usage in the notebook:** Dropout stays on even at test time (inference). Running the model N times gives a distribution of outputs: the mean is the prediction and the variance is the uncertainty.
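
The procedure of running the model N times with dropout left on can be sketched as follows. This is a minimal NumPy illustration with a plain two-layer network rather than the notebook's GraphSAGE model; the function `mc_dropout_predict`, the random weights, and the layer sizes are all assumptions made for the example.

```python
import numpy as np

def mc_dropout_predict(x, W1, W2, rate, n_passes, rng):
    """Run n_passes stochastic forward passes with dropout kept active."""
    outputs = []
    for _ in range(n_passes):
        h = np.maximum(x @ W1, 0.0)
        keep = rng.random(h.shape) >= rate             # fresh dropout mask per pass
        h = h * keep / (1.0 - rate)                    # inverted dropout scaling
        logits = h @ W2
        outputs.append(1.0 / (1.0 + np.exp(-logits)))  # sigmoid per label
    outputs = np.stack(outputs)                        # (n_passes, batch, labels)
    return outputs.mean(axis=0), outputs.var(axis=0)

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))
W1 = rng.normal(size=(8, 16))
W2 = rng.normal(size=(16, 3))
mean_pred, var_pred = mc_dropout_predict(x, W1, W2, rate=0.2, n_passes=50, rng=rng)
```

Labels whose probability changes a lot between passes get a large variance, which is the signal used to flag uncertain predictions.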

## ***Tutorial Description (Implementation process in the notebook)***

The notebook is divided into logical blocks that form the production pipeline:

### ***Phase 1: Data Layer (Hybrid Storage)***

Here you implement the critical solution for large graphs:

- **PyArrow (Parquet):** Use it to store 39 million edges and node properties. Thanks to memory-mapping, you only load what you need.
- **SQLite:** Serves as a fast topological index. The model asks SQLite: "Who are the neighbors of node 42?" and SQLite returns the positions of the matching rows in the Parquet file.
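
The "topological index" idea can be sketched with the standard-library `sqlite3` module. The table layout, the column name `parquet_row`, and the helper `neighbors_of` are hypothetical; the schema actually used with the notebook's data is not shown in the source.

```python
import sqlite3

conn = sqlite3.connect(":memory:")  # a file path would be used for the real graph
conn.execute("CREATE TABLE edges (src INTEGER, dst INTEGER, parquet_row INTEGER)")
# parquet_row is the (hypothetical) row offset of the edge in the Parquet file.
conn.executemany(
    "INSERT INTO edges VALUES (?, ?, ?)",
    [(42, 7, 0), (42, 13, 1), (7, 13, 2), (99, 42, 3)],
)
# An index on src turns the neighbor lookup into a B-tree search instead of a full scan.
conn.execute("CREATE INDEX idx_edges_src ON edges (src)")

def neighbors_of(node_id):
    """Return (neighbor_id, parquet_row) pairs for edges leaving node_id."""
    cur = conn.execute(
        "SELECT dst, parquet_row FROM edges WHERE src = ? ORDER BY parquet_row",
        (node_id,),
    )
    return cur.fetchall()

rows_for_42 = neighbors_of(42)
```

The returned row offsets could then be used to fetch only the needed edge features from the Parquet data, for example with PyArrow's `Table.take`, instead of materializing all edges in memory.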

### ***Phase 2: Model Architecture in NNX***

You are using **Flax NNX**, the object-oriented neural network API for JAX.

- **Neighbor Sampling:** At each training step, your dataloader randomly selects 10-25 neighbors per node for each layer. This bounds the size of the sampled subgraph and prevents out-of-memory (OOM) errors.
- **Multi-label Head:** The last layer has 112 outputs with **Sigmoid** activation.
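
Neighbor sampling with a fixed fanout can be sketched as below. The CSR adjacency layout (`indptr`, `indices`) and the helper `sample_neighbors` are assumptions for illustration; the notebook's own loader stores edges in a DataFrame instead.

```python
import numpy as np

def sample_neighbors(indptr, indices, node, fanout, rng):
    """Sample at most `fanout` distinct neighbors of `node` from a CSR adjacency."""
    neigh = indices[indptr[node]:indptr[node + 1]]
    if len(neigh) <= fanout:
        return neigh.copy()          # low-degree node: keep all neighbors
    return rng.choice(neigh, size=fanout, replace=False)

# Toy CSR adjacency: node 0 -> {1, 2, 3, 4}; nodes 1-4 -> {0}.
indptr = np.array([0, 4, 5, 6, 7, 8])
indices = np.array([1, 2, 3, 4, 0, 0, 0, 0])
rng = np.random.default_rng(0)
sampled_hub = sample_neighbors(indptr, indices, node=0, fanout=2, rng=rng)
sampled_leaf = sample_neighbors(indptr, indices, node=1, fanout=2, rng=rng)
```

With a fanout of 25 and two layers, each center node touches at most 25 + 25 × 25 other nodes, regardless of how many neighbors a hub protein has in the full graph.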

### ***Phase 3: Training and Optimization***

- **Loss:** You use `Binary Cross Entropy` calculated over all 112 labels.
- **JIT compilation:** The `train_step` function is wrapped in `@jax.jit`, which compiles your Python code into highly optimized machine code for GPU/TPU (OpenXLA).
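
A jitted training step with a binary cross-entropy loss over 112 labels might look like the sketch below. It is written in plain JAX with a linear stand-in for the GraphSAGE forward pass and a hand-written SGD update; the names `bce_with_logits` and `loss_fn`, the parameter layout, and the learning rate are assumptions, not the notebook's actual `train_step`.

```python
import jax
import jax.numpy as jnp

def bce_with_logits(logits, labels):
    """Numerically stable binary cross-entropy, averaged over nodes and labels."""
    per_label = (jnp.maximum(logits, 0) - logits * labels
                 + jnp.log1p(jnp.exp(-jnp.abs(logits))))
    return per_label.mean()

def loss_fn(params, x, y):
    logits = x @ params["w"] + params["b"]  # stand-in for the GraphSAGE forward pass
    return bce_with_logits(logits, y)

@jax.jit  # traced once per input shape, then run as a compiled XLA program
def train_step(params, x, y, lr):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (32, 8))
y = (jax.random.uniform(key_y, (32, 112)) < 0.1).astype(jnp.float32)
params = {"w": jnp.zeros((8, 112)), "b": jnp.zeros((112,))}

losses = []
for _ in range(20):
    params, loss = train_step(params, x, y, 0.5)
    losses.append(float(loss))
```

Working on logits rather than on sigmoid outputs avoids taking the logarithm of values that round to 0 or 1; the sigmoid is then only needed when probabilities are reported.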

## ***Why is this solution "Production-Ready"?***

This approach, as implemented in the notebook, addresses the three biggest pain points of graph learning tasks:

1. **Memory:** Thanks to PyArrow and neighbor sampling, you don't need 128GB of RAM.
2. **Speed:** JAX and XLA compilation let you train on millions of edges in minutes.
3. **Credibility:** The Bayesian element attaches an uncertainty estimate to every prediction, which is essential in biochemistry or industry.

***Summary of components for your documentation:***

|Element|Implementation in the notebook|Theoretical benefit|
|--------|-------------------------|-------------------|
|**Data**|PyArrow + SQLite|Efficient I/O for giant graphs|
|**Model**|GraphSAGE (NNX)|Inductive capability (new nodes)|
|**Inference**|MC Dropout|Uncertainty estimation (Bayesian)|
|**Engine**|JAX / OpenXLA|Maximum HW acceleration|

## ***Environment settings***

In [None]:
!rm -r /content/sample_data

In [None]:
# Remove existing JAX installations
!pip uninstall -y -qq jax jaxlib jax-cuda12-plugin scipy pyarrow gymnasium numpy

In [None]:
# Install JAX 
!pip install -qq --upgrade "jax[cpu]"
!pip install scipy gymnasium==0.29.0
!pip install tensorboard
!pip install tensorboard-plugin-profile

In [None]:
# Install core dependencies
%pip -qq install --upgrade jax jaxlib flax optax orbax-checkpoint grain
%pip -qq install numpy matplotlib scipy
%pip -qq install torch torchvision
%pip -qq install datasets 
%pip -qq install msgpack requests tqdm
%pip -qq install bitsandbytes
%pip -qq install jraph
%pip -qq install networkx
%pip -qq install ogb
%pip -qq install pyarrow
%pip -qq install db-sqlite3
%pip -qq install pandas polars
%pip -qq install bitsandbytes numpyro langdetect
%pip -qq install xprof
%pip -qq install jax2onnx
%pip -qq install onnx onnxruntime

In [None]:
# Install Git LFS for large files
!apt install git-lfs

In [None]:
import IPython
# Upgrade pip first: nothing after the kernel shutdown call is guaranteed to run
!pip install --upgrade pip
print("Rebooting kernel... Please wait 5-10 seconds.")
IPython.Application.instance().kernel.do_shutdown(restart=True)

In [None]:
!pip install -qq --upgrade "jax[cuda12]"

In [None]:
# Show metadata for every package used in this notebook (pip separates entries with "---")
packages = [
    "jax", "jaxlib", "jax-cuda12-plugin", "flax", "optax", "torch", "torchvision",
    "orbax-checkpoint", "numpy", "tqdm", "datasets", "msgpack", "bitsandbytes",
    "jraph", "networkx", "ogb", "pyarrow", "db-sqlite3", "polars", "pandas",
    "grain", "numpyro", "langdetect", "xprof", "jax2onnx", "onnx", "onnxruntime",
    "tensorboard-plugin-profile", "matplotlib", "scipy",
]
pkg_args = " ".join(packages)
%pip show {pkg_args}

In [None]:
print("Environment setup complete!")

## ***Import and configuration***

In [None]:
import os
import sys
import io
import glob
import gc

import warnings
warnings.filterwarnings("ignore")

import json
import time
import tensorflow as tf
import subprocess
import pickle
import zipfile
import base64
import shutil
import sqlite3
import numpy as np
import pandas as pd
import polars as pl
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from tqdm import tqdm
import networkx as nx
from pathlib import Path
import functools
from functools import partial
from typing import (
    Any,
    Tuple,
    Callable,
    Optional,
    Sequence,
    List,
    Dict
)
from IPython.display import clear_output

import pyarrow as pa
import pyarrow.parquet as pq

from ogb.nodeproppred import NodePropPredDataset

# JAX and Flax NNX
import jax
import jax.ops
import jax.lax
import jax.profiler
import jax.numpy as jnp
import jax.export as jax_export
from jax import (
    random,
    jit,
    value_and_grad,
    remat
)
import jax.tree_util as tree_util
import flax.nnx as nnx
from flax.nnx import filterlib
from flax.serialization import (
    msgpack_serialize,
    from_bytes
)
import orbax.checkpoint as ocp
from orbax.checkpoint import PyTreeCheckpointer, CheckpointManager

# Optimization
import optax

import jraph

from sklearn.metrics import accuracy_score, confusion_matrix

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.contrib.module import nnx_module

from datasets import load_dataset

from kaggle_secrets import UserSecretsClient

# PyTorch for compatibility (GGUF conversion)
import torch

import jax2onnx
from jax2onnx import onnx_function, to_onnx

import onnx
import onnxruntime as ort


## ***Ensure JAX uses GPU if available***

In [None]:
# Configure JAX for GPU
try:
    jax.config.update('jax_platform_name', 'gpu')
    print("JAX devices:", jax.devices())
except RuntimeError:
    print("GPU not available, using CPU")
    jax.config.update('jax_platform_name', 'cpu')
    print("JAX devices:", jax.devices())

## ***Setting up access to Hugging Face***

In [None]:
def set_git_config(email, name):
    try:
        # Setting global user.email
        subprocess.run(["git", "config", "--global", "user.email", email], check=True)
        print(f"Git user.email set to: {email}")

        # Setting the global user.name
        subprocess.run(["git", "config", "--global", "user.name", name], check=True)
        print(f"Git user.name set to: {name}")

        # Check settings (optional)
        email_output = subprocess.run(["git", "config", "--global", "user.email"], capture_output=True, text=True, check=True)
        name_output = subprocess.run(["git", "config", "--global", "user.name"], capture_output=True, text=True, check=True)
        print(f"Check - Email: {email_output.stdout.strip()}")
        print(f"Check - Name: {name_output.stdout.strip()}")

    except subprocess.CalledProcessError as e:
        print(f"Error while setting up Git configuration: {e}")

## ***SAGEConv layer***

In [None]:
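class SAGEConv(nnx.Module):
    """GraphSAGE layer with a mean aggregator: Linear(CONCAT(h_v, mean(h_u)))."""

    def __init__(self, in_features: int, out_features: int, rngs=None):
        # Input width is doubled because own and aggregated features are concatenated
        self.linear = nnx.Linear(in_features * 2, out_features, rngs=rngs or nnx.Rngs(0))

    def __call__(self, x, senders, receivers):
        # Average the features of incoming neighbors for every receiving node
        mean_neighbor = jraph.segment_mean(x[senders], receivers, x.shape[0])
        concatenated = jnp.concatenate([x, mean_neighbor], axis=-1)
        return self.linear(concatenated)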

## ***Bayesian GraphSAGE model***

In [None]:
class BayesianGraphSAGE(nnx.Module):
    """Two-layer GraphSAGE with dropout and a 112-way sigmoid (multi-label) head."""

    def __init__(self, in_features=8, hidden_features=256, out_features=112):
        rngs = nnx.Rngs(0)
        self.sage1 = SAGEConv(in_features, hidden_features, rngs)
        self.sage2 = SAGEConv(hidden_features, out_features, rngs)
        self.dropout = nnx.Dropout(0.2, rngs=rngs)

    def __call__(self, x, senders, receivers, training=False):
        def core(x):
            x = jax.nn.relu(self.sage1(x, senders, receivers))
            x = self.sage2(x, senders, receivers)
            return x

        # Rematerialize activations in the backward pass to save memory
        x = remat(core)(x)
        # Dropout acts on the logits; pass training=True to keep it active (MC Dropout)
        x = self.dropout(x, deterministic=not training)
        return jax.nn.sigmoid(x)

## ***Loader for OGBN-Proteins***

In [None]:
class OGBNProteinsLoader:
    def __init__(self, edges_parquet, nodes_parquet):
        self.edges = pq.read_table(edges_parquet).to_pandas()
        self.nodes = pq.read_table(nodes_parquet).to_pandas()
        
        self.features = np.stack(self.nodes["features"].values).astype(np.float32)
        self.labels   = np.stack(self.nodes["labels"].values).astype(np.float32)
        
        self.node_id_to_idx = {nid: i for i, nid in enumerate(self.nodes["node_id"])}
        
        self.train_idx = self.nodes[self.nodes["split"] == "train"]["node_id"].values
        self.val_idx   = self.nodes[self.nodes["split"] == "valid"]["node_id"].values
        self.test_idx  = self.nodes[self.nodes["split"] == "test"]["node_id"].values
        
        print(f"Loaded: {len(self.nodes)} nodes, {len(self.edges)} edges")

    def get_neighbors(self, node_id, max_neighbors=25):
        # Collect neighbors in both directions, then subsample at random
        src = self.edges[self.edges["source"] == node_id]["target"].values
        tgt = self.edges[self.edges["target"] == node_id]["source"].values
        neigh = np.unique(np.concatenate([src, tgt]))
        if len(neigh) > max_neighbors:
            neigh = np.random.choice(neigh, max_neighbors, replace=False)
        return neigh if len(neigh) > 0 else np.array([node_id])

    def sample_batch(self, batch_size=64, max_neighbors=25, split="train"):
        idx = {"train": self.train_idx, "valid": self.val_idx, "test": self.test_idx}[split]
        centers = np.random.choice(idx, batch_size, replace=False)
        return [(c, self.get_neighbors(c, max_neighbors)) for c in centers]

    def get_features(self, node_ids):
        return self.features[[self.node_id_to_idx[nid] for nid in node_ids]]

    def get_labels(self, node_ids):
        return self.labels[[self.node_id_to_idx[nid] for nid in node_ids]]

## ***Bayesian model and inference***

In [None]:
def bayesian_graphsage_model(graph_data, X, y=None):
    senders, receivers, num_nodes_batch = graph_data
    
    # Define priors for all model parameters with more reasonable scales
    in_features, hidden_features, out_features = 8, 256, 112
    
    # sage1 parameters - use wider priors for better exploration
    w1_scale = numpyro.sample('w1_scale', dist.HalfNormal(1.0))
    w1 = numpyro.sample('w1', dist.Normal(0, w1_scale).expand([in_features * 2, hidden_features]).to_event(2))
    b1 = numpyro.sample('b1', dist.Normal(0, 0.5).expand([hidden_features]).to_event(1))
    
    # sage2 parameters
    w2_scale = numpyro.sample('w2_scale', dist.HalfNormal(1.0))
    w2 = numpyro.sample('w2', dist.Normal(0, w2_scale).expand([hidden_features * 2, out_features]).to_event(2))
    b2 = numpyro.sample('b2', dist.Normal(0, 0.5).expand([out_features]).to_event(1))
    
    # Forward pass
    # Layer 1
    mean_neighbor = jraph.segment_mean(X[senders], receivers, num_nodes_batch)
    concat1 = jnp.concatenate([X, mean_neighbor], axis=-1)
    h = jax.nn.relu(concat1 @ w1 + b1)
    
    # Layer 2
    mean_neighbor2 = jraph.segment_mean(h[senders], receivers, num_nodes_batch)
    concat2 = jnp.concatenate([h, mean_neighbor2], axis=-1)
    logits = concat2 @ w2 + b2
    preds = jax.nn.sigmoid(logits)
    
    # Clip probabilities away from exactly 0 and 1 to avoid numerical issues
    preds = jnp.clip(preds, 1e-7, 1 - 1e-7)
    
    # Likelihood
    with numpyro.plate("data", X.shape[0]):
        numpyro.sample("obs", dist.Bernoulli(probs=preds).to_event(1), obs=y)

In [None]:
def run_mcmc(loader, batch_size=32, num_samples=500, warmup=500):
    print("Preparing the training batch...")
    batch = loader.sample_batch(batch_size=batch_size, split="train")
    
    # Create node ID to local index mapping
    all_nodes = set()
    for center, neigh in batch:
        all_nodes.add(center)
        all_nodes.update(neigh)
    
    all_nodes = sorted(list(all_nodes))
    node_to_idx = {nid: i for i, nid in enumerate(all_nodes)}
    
    # Build edges with local indices
    senders, receivers = [], []
    for center, neigh in batch:
        center_idx = node_to_idx[center]
        for n in neigh:
            n_idx = node_to_idx[n]
            senders += [center_idx, n_idx]
            receivers += [n_idx, center_idx]
    
    senders = jnp.array(senders)
    receivers = jnp.array(receivers)
    num_nodes_batch = len(all_nodes)
    
    graph_data = (senders, receivers, num_nodes_batch)
    
    # Get features and labels for all nodes in the batch
    X = jnp.array(loader.get_features(all_nodes))
    # Only get labels for center nodes
    center_nodes = [c for c, _ in batch]
    center_indices = jnp.array([node_to_idx[c] for c in center_nodes])
    
    # Create full label array (only center nodes have labels)
    y_full = jnp.zeros((num_nodes_batch, 112), dtype=jnp.float32)
    y_centers = jnp.array(loader.get_labels(center_nodes))
    y_full = y_full.at[center_indices].set(y_centers)
    
    key = random.PRNGKey(42)
    
    # Use more samples and longer warmup for better convergence
    kernel = NUTS(bayesian_graphsage_model, max_tree_depth=6)
    mcmc = MCMC(kernel, num_warmup=warmup, num_samples=num_samples, num_chains=1)
    
    print(f"Running MCMC (warmup={warmup}, samples={num_samples})...")
    mcmc.run(key, graph_data, X, y=y_full)
    
    # Print diagnostics
    mcmc.print_summary()
    
    print("MCMC completed!")
    
    return mcmc.get_samples()
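
The node-mapping and edge-building steps inside `run_mcmc` reappear in the validation and inference code, so they could be factored into one helper. The sketch below is a hypothetical refactoring (the name `build_local_graph` and its return layout are assumptions); it uses NumPy only, and the returned arrays can be wrapped in `jnp.array` before use.

```python
import numpy as np

def build_local_graph(batch):
    """Turn [(center, neighbors), ...] into locally indexed, bidirectional edges.

    Returns (senders, receivers, all_nodes, center_indices); all index arrays
    refer to positions in all_nodes, not to global node IDs.
    """
    all_nodes = sorted({int(c) for c, _ in batch} | {int(n) for _, ns in batch for n in ns})
    node_to_idx = {nid: i for i, nid in enumerate(all_nodes)}
    senders, receivers = [], []
    for center, neigh in batch:
        c = node_to_idx[int(center)]
        for n in neigh:
            j = node_to_idx[int(n)]
            senders += [c, j]      # center -> neighbor
            receivers += [j, c]    # neighbor -> center
    center_indices = np.array([node_to_idx[int(c)] for c, _ in batch])
    return np.array(senders), np.array(receivers), all_nodes, center_indices

toy_batch = [(10, np.array([20, 30])), (30, np.array([10]))]
senders, receivers, all_nodes, center_indices = build_local_graph(toy_batch)
```

Keeping the remapping in one place makes it harder to mix global node IDs with local indices by accident, since global IDs would point outside the batch feature matrix.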

## ***Prediction with uncertainty***

In [None]:
def predict_with_uncertainty(posterior_samples, graph_data, X, n_samples=100):
    senders, receivers, num_nodes_batch = graph_data
    
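    # Never index past the number of available posterior draws
    n_samples = min(n_samples, posterior_samples['w1'].shape[0])
    preds = []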
    for i in tqdm(range(n_samples), desc="MC prediction"):
        # Extract parameters for this sample
        w1 = posterior_samples['w1'][i]
        b1 = posterior_samples['b1'][i]
        w2 = posterior_samples['w2'][i]
        b2 = posterior_samples['b2'][i]
        
        # Forward pass with sampled parameters
        # Layer 1
        mean_neighbor = jraph.segment_mean(X[senders], receivers, num_nodes_batch)
        concat1 = jnp.concatenate([X, mean_neighbor], axis=-1)
        h = jax.nn.relu(concat1 @ w1 + b1)
        
        # Layer 2
        mean_neighbor2 = jraph.segment_mean(h[senders], receivers, num_nodes_batch)
        concat2 = jnp.concatenate([h, mean_neighbor2], axis=-1)
        logits = concat2 @ w2 + b2
        pred = jax.nn.sigmoid(logits)
        
        preds.append(pred)
    
    preds = jnp.stack(preds)
    return preds.mean(0), preds.std(0)


## ***Visualization of inference results***

In [None]:
def plot_predictions(mean_pred, uncertainty, true_labels, num_classes=30):
    mean_node = mean_pred[0]
    unc_node = uncertainty[0]
    true_node = true_labels[0]
    
    x = np.arange(num_classes)
    classes = [f'C{i}' for i in range(num_classes)]
    
    fig, ax = plt.subplots(figsize=(14, 7))
    ax.bar(x, mean_node[:num_classes], yerr=unc_node[:num_classes], capsize=5, color='skyblue', alpha=0.8, label='Prediction ± uncertainty')
    positive = np.where(true_node[:num_classes] > 0.5)[0]
    ax.plot(positive, true_node[positive], 'ro', markersize=8, label='True positive')
    
    ax.set_ylabel('Probability')
    ax.set_title('Bayesian GraphSAGE – Prediction with uncertainty (first node)')
    ax.set_xticks(x)
    ax.set_xticklabels(classes, rotation=45)
    ax.legend()
    ax.grid(True, axis='y')
    plt.tight_layout()
    plt.show()

## ***Export to all formats***

In [None]:
def export_model_all_formats(model, posterior_samples, output_dir="models"):
    """
    Fixed export with synchronous checkpointer to avoid async traceback
    """
    import time
    os.makedirs(output_dir, exist_ok=True)

    print(f"\n{'='*70}")
    print("MODEL EXPORT (Synchronous PyTreeCheckpointer)")
    print("="*70 + "\n")

    # Use synchronous checkpointer - no async threads
    checkpointer = PyTreeCheckpointer()

    # 1. Save model structure - separate graphdef and state
    graphdef, param_state, rng_state = nnx.split(model, nnx.Param, nnx.RngState)
    
    # Save graphdef separately using pickle (it's not a pytree)
    model_dir = os.path.join(output_dir, "bayesian_graphsage_nnx")
    os.makedirs(model_dir, exist_ok=True)
    
    graphdef_path = os.path.join(model_dir, "graphdef.pkl")
    with open(graphdef_path, 'wb') as f:
        pickle.dump(graphdef, f)
    print(f"✓ GraphDef saved as pickle: {graphdef_path}")
    
    # Save param_state and rng_state with Orbax (these ARE pytrees)
    state_items = {
        "param_state": param_state,
        "rng_state": rng_state
    }
    
    state_path = os.path.join(model_dir, "state")
    checkpointer.save(state_path, state_items)
    print(f"✓ Model state saved to: {state_path}")

    # 2. Save posterior samples (stacked)
    sample_keys = list(posterior_samples.keys())
    num_samples = len(next(iter(posterior_samples.values())))

    stacked_state = {}
    for key in sample_keys:
        stacked_state[key] = jnp.stack(posterior_samples[key])

    posterior_state = nnx.State(stacked_state)

    posterior_path = os.path.join(output_dir, "posterior_samples")
    checkpointer.save(posterior_path, posterior_state)
    print(f"✓ Posterior samples saved ({num_samples} samples) to: {posterior_path}")

    # CRITICAL: Wait for Orbax to finalize (prevents async traceback)
    print("\nWaiting for the finalization of Orbax checkpoints...")
    time.sleep(2)  # Give Orbax time to finish async operations
    
    print(f"\n{'='*70}")
    print("ALL DONE! Checkpoints are fully finalized.")
    print("="*70)
    print("   • bayesian_graphsage_nnx/")
    print("     ├── graphdef.pkl          — model structure")
    print("     └── state/                — model parameters")
    print("   • posterior_samples/        — MCMC samples")


## ***Memory-optimized inference - processes one sample at a time***

In [None]:
def load_and_run_inference(loader, path_model_dir, n_mc=50, batch_size=16):
    """
    Ultra memory-optimized inference - one sample at a time with immediate cleanup
    """
    
    print("=== LOADING THE MODEL AND POSTERIOR SAMPLES ===\n")

    checkpointer = PyTreeCheckpointer()

    # 1. Load model checkpoint (we don't actually need graphdef for inference)
    model_dir = os.path.join(path_model_dir, "bayesian_graphsage_nnx")

    if not os.path.exists(model_dir):
        raise FileNotFoundError(f"Model directory not found: {model_dir}")

    print("✓ Model directory found")

    # 2. Load posterior samples
    posterior_path = os.path.join(path_model_dir, "posterior_samples")
    
    if not os.path.exists(posterior_path):
        candidates = glob.glob(os.path.join(path_model_dir, "posterior_samples*"))
        if not candidates:
            raise FileNotFoundError("No directory with posterior samples found!")
        posterior_path = sorted(candidates)[-1]

    posterior_state = checkpointer.restore(posterior_path)
    posterior_samples = {k: posterior_state[k] for k in posterior_state.keys()}
    num_samples_total = posterior_state[list(posterior_state.keys())[0]].shape[0]

    print(f"✓ Loaded {num_samples_total} posterior samples\n")

    # 3. Prepare test batch with SMALLER size to reduce memory
    print("Preparing the test batch...")
    test_batch = loader.sample_batch(batch_size=batch_size, split="test")
    
    # Build node mapping for test batch
    all_nodes = set()
    for center, neigh in test_batch:
        all_nodes.add(center)
        all_nodes.update(neigh)
    
    all_nodes = sorted(list(all_nodes))
    node_to_idx = {nid: i for i, nid in enumerate(all_nodes)}
    
    # Build edges with local indices
    senders_list, receivers_list = [], []
    for center, neigh in test_batch:
        center_idx = node_to_idx[center]
        for n in neigh:
            n_idx = node_to_idx[n]
            senders_list += [center_idx, n_idx]
            receivers_list += [n_idx, center_idx]

    senders = jnp.array(senders_list)
    receivers = jnp.array(receivers_list)
    
    # Get features and labels for all nodes
    X_test = jnp.array(loader.get_features(all_nodes))
    
    # Get labels only for center nodes
    center_nodes = [c for c, _ in test_batch]
    center_indices = jnp.array([node_to_idx[c] for c in center_nodes])
    y_test = loader.get_labels(center_nodes)
    
    num_nodes_batch = len(all_nodes)

    print(f"Test batch ready: {len(center_nodes)} central nodes, {num_nodes_batch} total\n")

    # 4. Ultra memory-efficient Bayesian inference - ONE sample at a time
    n_mc = min(n_mc, num_samples_total)  # never index past the available draws
    print(f"Running Bayesian Monte Carlo inference ({n_mc} samples)...")
    print("Processing one sample at a time to minimize memory usage...\n")
    
    # Initialize accumulators for online statistics (Welford's algorithm)
    n_centers = len(center_nodes)
    n_classes = y_test.shape[1]
    
    # We'll accumulate mean and M2 for online variance calculation
    mean_accumulator = jnp.zeros((n_centers, n_classes), dtype=jnp.float32)
    m2_accumulator = jnp.zeros((n_centers, n_classes), dtype=jnp.float32)
    
    # Pre-compute layer 1 aggregation (doesn't depend on weights)
    mean_neighbor_layer1 = jraph.segment_mean(X_test[senders], receivers, num_nodes_batch)
    
    # Process one sample at a time
    for i in tqdm(range(n_mc), desc="MC Inference", ncols=80):
        # Extract i-th sample from posterior
        w1 = posterior_samples['w1'][i]
        b1 = posterior_samples['b1'][i]
        w2 = posterior_samples['w2'][i]
        b2 = posterior_samples['b2'][i]
        
        # Forward pass - Layer 1
        concat1 = jnp.concatenate([X_test, mean_neighbor_layer1], axis=-1)
        h = jax.nn.relu(concat1 @ w1 + b1)
        
        # Forward pass - Layer 2
        mean_neighbor2 = jraph.segment_mean(h[senders], receivers, num_nodes_batch)
        concat2 = jnp.concatenate([h, mean_neighbor2], axis=-1)
        logits = concat2 @ w2 + b2
        pred_all = jax.nn.sigmoid(logits)
        
        # Extract predictions for center nodes only
        pred = pred_all[center_indices]
        
        # Update online statistics (Welford's algorithm)
        delta = pred - mean_accumulator
        mean_accumulator += delta / (i + 1)
        delta2 = pred - mean_accumulator
        m2_accumulator += delta * delta2
        
        # Explicit cleanup every 10 samples
        if (i + 1) % 10 == 0:
            gc.collect()
    
    # Final statistics
    mean_pred = mean_accumulator
    variance = m2_accumulator / n_mc
    uncertainty = jnp.sqrt(variance)
    
    # Clear memory
    del mean_accumulator, m2_accumulator, variance
    gc.collect()

    # 5. Results
    acc = ((mean_pred > 0.5) == (y_test > 0.5)).mean() * 100

    print(f"\n{'='*60}")
    print(f"BAYESIAN GRAPHSAGE — TEST SET RESULTS")
    print(f"{'='*60}")
    print(f"Accuracy Test (Monte Carlo average): {acc:.2f}%")
    print(f"Average epistemic uncertainty: {uncertainty.mean():.5f}")
    print(f"Maximum uncertainty: {uncertainty.max():.5f}")
    print(f"{'='*60}\n")

    return mean_pred, uncertainty, y_test, acc
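
The online statistics used in the loop above (Welford's algorithm) can be checked in isolation. The following NumPy sketch, with the hypothetical helper name `welford` and random toy data, accumulates the mean and the sum of squared deviations one sample at a time, so the full stack of Monte Carlo predictions never has to be held in memory.

```python
import numpy as np

def welford(samples):
    """Streaming mean and population variance over the first axis."""
    mean = np.zeros_like(samples[0], dtype=np.float64)
    m2 = np.zeros_like(mean)
    for i, x in enumerate(samples, start=1):
        delta = x - mean               # deviation from the old mean
        mean = mean + delta / i
        m2 = m2 + delta * (x - mean)   # second factor uses the updated mean
    return mean, m2 / len(samples)

rng = np.random.default_rng(0)
draws = rng.random((50, 4, 3))  # 50 Monte Carlo passes, 4 nodes, 3 labels
online_mean, online_var = welford(draws)
```

Dividing by the number of samples gives the population variance, which matches `preds.std(0)` in `predict_with_uncertainty`; dividing by `n - 1` instead would give the unbiased sample variance.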


## ***Visualization function***

In [None]:
def plot_bayesian_prediction(mean_pred, uncertainty, y_test, acc, node_idx=0, num_classes=30):
    """
    Visualize Bayesian prediction for a single node
    """
    mean_node = mean_pred[node_idx]
    unc_node = uncertainty[node_idx]
    true_node = y_test[node_idx]

    x = np.arange(num_classes)
    fig, ax = plt.subplots(figsize=(15, 8))

    bars = ax.bar(x, mean_node[:num_classes], yerr=unc_node[:num_classes],
                  capsize=5, color='cornflowerblue', edgecolor='navy', alpha=0.8,
                  label='Prediction ± epistemic uncertainty')

    # True positive classes (red dots)
    positive = np.where(true_node[:num_classes] > 0.5)[0]
    ax.scatter(positive, true_node[positive], color='red', s=120, marker='o',
               label='True positive class', zorder=10)

    ax.set_title(f'Bayesian GraphSAGE — Prediction on a test node #{node_idx}\n'
                 f'Accuracy: {acc:.2f}% | Average uncertainty: {uncertainty.mean():.5f}',
                 fontsize=16, fontweight='bold')
    ax.set_xlabel('Class', fontsize=12)
    ax.set_ylabel('Probability', fontsize=12)
    ax.set_xticks(x)
    ax.set_xticklabels([f'C{i}' for i in range(num_classes)], rotation=45, ha='right')
    ax.legend(fontsize=12)
    ax.grid(True, axis='y', alpha=0.3, linestyle='--')
    ax.set_ylim(0, 1.05)

    # Text with values above bars
    for bar, val, err in zip(bars, mean_node[:num_classes], unc_node[:num_classes]):
        if val > 0.05 or err > 0.01:
            ax.text(bar.get_x() + bar.get_width()/2, val + err + 0.02,
                    f'{val:.2f} ± {err:.2f}', ha='center', va='bottom', fontsize=9)

    plt.tight_layout()
    plt.show()

## ***DIAGNOSTIC - Check what's actually saved***

In [None]:
def diagnose_saved_model(path_model_dir):
    """
    Diagnose what's in the saved checkpoints
    """
    import os
    import pickle
    from orbax.checkpoint import PyTreeCheckpointer
    
    print("="*70)
    print("DIAGNOSTIC REPORT")
    print("="*70)
    
    checkpointer = PyTreeCheckpointer()
    
    # Check posterior samples
    posterior_path = os.path.join(path_model_dir, "posterior_samples")
    
    if os.path.exists(posterior_path):
        print(f"\n✓ Found posterior_samples at: {posterior_path}")
        
        try:
            posterior_state = checkpointer.restore(posterior_path)
            
            print("\nPosterior samples keys:")
            for key in posterior_state.keys():
                if hasattr(posterior_state[key], 'shape'):
                    print(f"  - {key}: shape={posterior_state[key].shape}, dtype={posterior_state[key].dtype}")
                else:
                    print(f"  - {key}: {type(posterior_state[key])}")
            
            # Check if shapes are correct
            if 'w1' in posterior_state:
                w1_shape = posterior_state['w1'].shape
                print(f"\n✓ w1 shape: {w1_shape}")
                print(f"  Expected: (num_samples, 16, 256)")
                
            if 'w2' in posterior_state:
                w2_shape = posterior_state['w2'].shape
                print(f"✓ w2 shape: {w2_shape}")
                print(f"  Expected: (num_samples, 512, 112)")
                
            return posterior_state
            
        except Exception as e:
            print(f"\n✗ Error loading posterior samples: {e}")
            return None
    else:
        print(f"\n✗ Posterior samples not found at: {posterior_path}")
        return None

## ***Main script: training, validation, visualization and export***

In [None]:
if __name__ == "__main__":
    # === LOADING DATASET ===
    base_dir_input = os.path.join("/kaggle","input","bayesian-graphsage-dataset","datasets","graph_dataset")
    edges_parquet_file = os.path.join(base_dir_input, "proteins_edges.parquet")
    nodes_parquet_file = os.path.join(base_dir_input, "proteins_nodes.parquet")

    path_model_dir = os.path.join('/kaggle','working','models')
    os.makedirs(path_model_dir, exist_ok=True)
    
    loader = OGBNProteinsLoader(
        edges_parquet=edges_parquet_file,
        nodes_parquet=nodes_parquet_file
    )

    # MCMC (posterior training) - with better hyperparameters
    posterior_samples = run_mcmc(loader, batch_size=32, num_samples=500, warmup=500)

    print("\n" + "="*70)
    print("VALIDATION")
    print("="*70)
    
    val_batch = loader.sample_batch(batch_size=64, split="valid")
    
    # Create node mapping for validation
    all_nodes = set()
    for center, neigh in val_batch:
        all_nodes.add(center)
        all_nodes.update(neigh)
    
    all_nodes = sorted(list(all_nodes))
    node_to_idx = {nid: i for i, nid in enumerate(all_nodes)}
    
    senders, receivers = [], []
    for center, neigh in val_batch:
        center_idx = node_to_idx[center]
        for n in neigh:
            n_idx = node_to_idx[n]
            senders += [center_idx, n_idx]
            receivers += [n_idx, center_idx]
    
    graph_data = (jnp.array(senders), jnp.array(receivers), len(all_nodes))
    X_val = jnp.array(loader.get_features(all_nodes))
    
    # Get predictions with more samples
    mean_pred, uncertainty = predict_with_uncertainty(posterior_samples, graph_data, X_val, n_samples=500)
    
    # Get labels only for center nodes
    center_nodes = [c for c, _ in val_batch]
    center_indices = jnp.array([node_to_idx[c] for c in center_nodes])
    y_val = loader.get_labels(center_nodes)
    
    # Extract predictions for center nodes only
    mean_pred_centers = mean_pred[center_indices]
    uncertainty_centers = uncertainty[center_indices]
    
    acc = ((mean_pred_centers > 0.5) == (y_val > 0.5)).mean() * 100
    print(f"\nMicro-Accuracy: {acc:.2f}%")
    print(f"Average uncertainty: {uncertainty_centers.mean():.4f}")
    print(f"Max uncertainty: {uncertainty_centers.max():.4f}")
    print(f"Min uncertainty: {uncertainty_centers.min():.4f}")

    plot_predictions(mean_pred_centers, uncertainty_centers, y_val, num_classes=30)

    print("\n" + "="*70)
    print("ALL DONE! Graph displayed, no errors.")
    print("="*70)

    print("\nStarting model export...")
    
    sample_batch = loader.sample_batch(batch_size=8, split="valid")
    s, r = [], []
    for center, neigh in sample_batch:
        for n in neigh:
            s += [center, n]
            r += [n, center]
    
    sample_x = loader.get_features([c for c, _ in sample_batch])[:4]
    sample_senders = jnp.array(s[:200])
    sample_receivers = jnp.array(r[:200])
    
    export_model = BayesianGraphSAGE()  # fresh (untrained) model

    export_model_all_formats(
        model=export_model,
        posterior_samples=posterior_samples,
        output_dir=path_model_dir
    )

## ***INFERENCE ON A LEARNED BAYESIAN GRAPHSAGE MODEL***

In [None]:
# First, run diagnostics to see what's in the saved files
posterior_samples = diagnose_saved_model(path_model_dir)

# INFERENCE - Start with very conservative settings
mean_pred, uncertainty, y_test, acc = load_and_run_inference(
    loader, 
    path_model_dir,
    n_mc=20,        # Start with just 20 samples
    batch_size=16   # Very small batch
)
    
# VISUALIZE
plot_bayesian_prediction(mean_pred, uncertainty, y_test, acc, node_idx=0, num_classes=30)