93 changes: 93 additions & 0 deletions README.md
@@ -75,3 +75,96 @@ The list can be found in the `configs/data/chebi50_graph_properties.yml` file.
```bash
python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/csv_logger.yml --model=../python-chebai-graph/configs/model/gnn_res_gated.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml
```

## Augmented Graphs

The following command uses the model and data configuration that achieved the best classification performance with augmented graphs.


```bash
python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/gat_aug_amgpool.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.config.v2=True --data=../python-chebai-graph/configs/data/chebi50_aug_prop_as_per_node.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml --trainer.logger.init_args.name=gatv2_amg_s0
```

### Model Hyperparameters

#### **GAT Architecture**

To use a GAT-based model, choose **one** of the following configs:

- **Atom–Motif–Graph Node Pooling**: `--model=../python-chebai-graph/configs/model/gat_aug_amgpool.yml`
- **Atom-Augmented Node Pooling**: `--model=../python-chebai-graph/configs/model/gat_aug_aagpool.yml`
- **Standard Pooling**: `--model=../python-chebai-graph/configs/model/gat.yml`

#### **GAT-Specific Hyperparameters**

- **Number of message-passing layers**: `--model.config.num_layers=5` (default: 4)
- **Attention heads**: `--model.config.heads=4` (default: 8)
> Note: The output channels (or the hidden channels, if output channels are not specified) should be divisible by the number of heads.
- **Use GATv2**: `--model.config.v2=True` (default: False)
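
The divisibility note above can be checked before launching a run. The following is a minimal sketch, assuming the model splits its channel width evenly across attention heads; `channels_per_head` is a hypothetical helper that is not part of chebai, and the width of 256 is an illustrative value rather than a default taken from the configs.

```python
def channels_per_head(channels: int, heads: int) -> int:
    """Return the width of each attention head, or raise if the split is uneven."""
    if heads <= 0:
        raise ValueError("heads must be a positive integer")
    if channels % heads != 0:
        raise ValueError(f"channels={channels} is not divisible by heads={heads}")
    return channels // heads


# Illustrative values only: 256 channels with the --model.config.heads=4 example.
print(channels_per_head(256, 4))  # 64
```

Running such a check up front turns a shape mismatch deep inside the model into an early, readable error.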

#### **ResGated Architecture**

To use a ResGated GNN model, choose **one** of the following configs:

- **Atom–Motif–Graph Node Pooling**: `--model=../python-chebai-graph/configs/model/res_aug_amgpool.yml`
- **Atom-Augmented Node Pooling**: `--model=../python-chebai-graph/configs/model/res_aug_aagpool.yml`
- **Standard Pooling**: `--model=../python-chebai-graph/configs/model/resgated.yml`

#### **Common Hyperparameters**

These can be used for both GAT and ResGated architectures:

- **Dropout**: `--model.config.dropout=0.1` (default: 0)
- **Number of final linear layers**: `--model.n_linear_layers=2` (default: 1)

# Random Node Initialization

## Static Node Initialization

In this type of node initialization, the node features (and/or edge features) of the given molecular graph are initialized only once during dataset creation with the given initialization scheme.

```bash
python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/resgated.yml --model.config.in_channels=203 --model.config.edge_dim=11 --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.pad_node_features=45 --data.pad_edge_features=4 --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --data.init_args.persistent_workers=False --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml --trainer.logger.init_args.name=gni_res_props+zeros_s0
```

In the above command, each node uses the 158 node features retrieved from RDKit (corresponding to the node properties defined in `chebi50_graph_properties.yml`) plus 45 additional features (specified by `--data.pad_node_features=45`) drawn from a normal distribution (the default). Together these give the 203 input channels set by `--model.config.in_channels=203`.

To change the distribution, add the following option to the above command: `--data.distribution=zeros`

Available distributions: `"normal", "uniform", "xavier_normal", "xavier_uniform", "zeros"`


Similarly, each edge is initialized with 7 RDKit features and 4 additional features (specified by `--data.pad_edge_features=4`) drawn from the given distribution, matching `--model.config.edge_dim=11`.
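
The feature counts above determine the model's input sizes. As a small sanity check (a sketch, with `padded_dims` as a hypothetical helper that is not part of chebai), the values passed to `--model.config.in_channels` and `--model.config.edge_dim` are the RDKit feature counts plus the padding:

```python
def padded_dims(
    node_feats: int, pad_node: int, edge_feats: int, pad_edge: int
) -> tuple[int, int]:
    """Return (in_channels, edge_dim) after appending random padding features."""
    return node_feats + pad_node, edge_feats + pad_edge


# Counts taken from the static-initialization command above.
in_channels, edge_dim = padded_dims(node_feats=158, pad_node=45, edge_feats=7, pad_edge=4)
print(f"--model.config.in_channels={in_channels} --model.config.edge_dim={edge_dim}")
# prints: --model.config.in_channels=203 --model.config.edge_dim=11
```

If the padding sizes change, the two model options have to change with them, otherwise the first layer receives inputs of the wrong width.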


If you want all node (and edge) features to be drawn from a given distribution (i.e., ignore RDKit features), use: `--data=../python-chebai-graph/configs/data/chebi50_static_gni.yml`


Refer to the data class code for details.


## Dynamic Node Initialization

In this type of node initialization, the node features (and/or edge features) of the molecular graph are initialized at **each forward pass** of the model using the given initialization scheme.
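
The difference between the static and dynamic schemes can be illustrated without any graph library. The sketch below uses Python's standard `random` module in place of torch tensors; `StaticFeatures` and `DynamicFeatures` are hypothetical stand-ins for a dataset item and a model input step, not classes from chebai-graph.

```python
import random


def draw_features(num_nodes: int, num_features: int, rng: random.Random) -> list[list[float]]:
    """Draw a num_nodes x num_features matrix from a standard normal distribution."""
    return [[rng.gauss(0.0, 1.0) for _ in range(num_features)] for _ in range(num_nodes)]


class StaticFeatures:
    """Static scheme: features are drawn once, when the item is created."""

    def __init__(self, num_nodes: int, num_features: int, rng: random.Random) -> None:
        self.x = draw_features(num_nodes, num_features, rng)

    def forward_input(self) -> list[list[float]]:
        return self.x  # identical values on every forward pass


class DynamicFeatures:
    """Dynamic scheme: features are re-drawn on every forward pass."""

    def __init__(self, num_nodes: int, num_features: int, rng: random.Random) -> None:
        self.num_nodes = num_nodes
        self.num_features = num_features
        self.rng = rng

    def forward_input(self) -> list[list[float]]:
        return draw_features(self.num_nodes, self.num_features, self.rng)


rng = random.Random(0)
static = StaticFeatures(num_nodes=3, num_features=2, rng=rng)
dynamic = DynamicFeatures(num_nodes=3, num_features=2, rng=rng)

print(static.forward_input() == static.forward_input())    # True: same values each pass
print(dynamic.forward_input() == dynamic.forward_input())  # False: fresh values each pass
```

In the static case the randomness is part of the stored dataset; in the dynamic case it is part of the model's forward computation.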



Currently, dynamic node initialization is implemented only for the **ResGated** architecture. To use it, specify: `--model=../python-chebai-graph/configs/model/resgated_dynamic_gni.yml`

To keep the RDKit features and *add* dynamically initialized features, include the following options in the command:

```
--model.config.complete_randomness=False
--model.config.pad_node_features=45
```

The additional features are drawn from a normal distribution (the default). You can change it using: `--model.config.distribution=uniform`

To initialize all features from the given distribution instead, omit the `--model.config.complete_randomness` flag (it defaults to `True`).
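
The two modes can be summarized in a few lines. This is a sketch only, using plain Python lists and the standard `random` module rather than the tensors the model operates on; `apply_dynamic_gni` is a hypothetical function, and it assumes that full randomization keeps the original feature width.

```python
import random
from typing import Optional


def apply_dynamic_gni(
    x: list[list[float]],
    complete_randomness: bool = True,
    pad_node_features: Optional[int] = None,
    rng: Optional[random.Random] = None,
) -> list[list[float]]:
    """Return the node features a single forward pass would see."""
    if rng is None:
        rng = random.Random()
    if complete_randomness:
        # Replace every existing feature with a fresh random value (same shape).
        return [[rng.gauss(0.0, 1.0) for _ in row] for row in x]
    if pad_node_features is None:
        # Nothing to append: the original features pass through unchanged.
        return x
    # Keep the original features and append freshly drawn columns.
    return [row + [rng.gauss(0.0, 1.0) for _ in range(pad_node_features)] for row in x]


rdkit_like = [[1.0, 0.0, 2.0], [0.0, 1.0, 3.0]]  # toy stand-in for RDKit features
padded = apply_dynamic_gni(rdkit_like, complete_randomness=False, pad_node_features=2)
replaced = apply_dynamic_gni(rdkit_like)

print(len(padded[0]), padded[0][:3])  # 5 [1.0, 0.0, 2.0]
print(len(replaced[0]))               # 3
```

Padding widens each feature vector, so the model's `in_channels` must account for it; full replacement leaves the width unchanged.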


A typical command for dynamic node initialization:

```bash
python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/resgated_dynamic_gni.yml --model.config.in_channels=203 --model.config.edge_dim=11 --model.config.complete_randomness=False --model.config.pad_node_features=45 --model.config.pad_edge_features=4 --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --data.init_args.persistent_workers=False --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml --trainer.logger.init_args.name=gni_dres_props+rand_s0
```
61 changes: 55 additions & 6 deletions chebai_graph/models/dynamic_gni.py
@@ -1,3 +1,23 @@
"""
ResGatedDynamicGNIGraphPred
------------------------------------------------

Module providing a ResGated GNN model that applies Random Node Initialization
(RNI) dynamically at each forward pass. This follows the approach from:

Abboud, R., et al. (2020). "The surprising power of graph neural networks with
random node initialization." arXiv preprint arXiv:2010.01179.

The module exposes:
- ResGatedDynamicGNI: a model that can either completely replace node/edge
features with random tensors each forward pass or pad existing features with
additional random features.
- ResGatedDynamicGNIGraphPred: a thin wrapper that instantiates the above for
graph-level prediction pipelines.
"""

__all__ = ["ResGatedDynamicGNIGraphPred"]

from typing import Any

import torch
@@ -14,12 +34,37 @@

class ResGatedDynamicGNI(GraphModelBase):
"""
Base model class for applying ResGatedGraphConv layers to graph-structured data
with dynamic initialization of features for nodes and edges.

Args:
config (dict): Configuration dictionary containing model hyperparameters.
**kwargs: Additional keyword arguments for parent class.
ResGated GNN with dynamic Random Node Initialization (RNI).

This model supports two modes controlled by the `config`:

- complete_randomness (bool-like): If True, **replace** node and edge
features entirely with randomly initialized tensors each forward pass.
If False, the model **pads** existing features with extra randomly
initialized features on-the-fly.

- pad_node_features (int, optional): Number of random columns to append
to each node feature vector when `complete_randomness` is False.

- pad_edge_features (int, optional): Number of random columns to append
to each edge feature vector when `complete_randomness` is False.

- distribution (str): Distribution for random initialization. Must be one
of RandomFeatureInitializationReader.DISTRIBUTIONS.

Parameters
----------
config : Dict[str, Any]
Configuration dictionary containing model hyperparameters. Expected keys
used by this class:
- distribution (optional, default "normal")
- complete_randomness (optional, default "True")
- pad_node_features (optional, int)
- pad_edge_features (optional, int)
Keys required by GraphModelBase (e.g., in_channels, hidden_channels,
out_channels, num_layers, edge_dim) should also be present.
**kwargs : Any
Additional keyword arguments forwarded to GraphModelBase.
"""

def __init__(self, config: dict[str, Any], **kwargs: Any):
@@ -96,6 +141,8 @@ def forward(self, batch: dict[str, Any]) -> Tensor:

new_x = None
new_edge_attr = None

# If replacing features entirely with random values
if self.complete_randomness:
new_x = torch.empty(
graph_data.x.shape[0], graph_data.x.shape[1], device=self.device
@@ -110,6 +157,8 @@ def forward(self, batch: dict[str, Any]) -> Tensor:
RandomFeatureInitializationReader.random_gni(
new_edge_attr, self.distribution
)

# If padding existing features with additional random columns
else:
if self.pad_node_features is not None:
pad_node = torch.empty(
159 changes: 135 additions & 24 deletions chebai_graph/preprocessing/reader/static_gni.py
@@ -1,47 +1,127 @@
"""
Abboud, Ralph, et al.
"The surprising power of graph neural networks with random node initialization."
arXiv preprint arXiv:2010.01179 (2020).
RandomFeatureInitializationReader
--------------------------------

Code Reference: https://github.com/ralphabb/GNN-RNI/blob/main/GNNHyb.py
Implements random node / edge / molecule feature initialization for graph neural
networks following:

Abboud, R., et al. (2020). "The surprising power of graph neural networks with
random node initialization." arXiv preprint arXiv:2010.01179.

Code reference: https://github.com/ralphabb/GNN-RNI/blob/main/GNNHyb.py

This module provides a reader that replaces node/edge/molecule features with
randomly initialized tensors drawn from a selected distribution.

Notes
-----
- This reader subclasses GraphPropertyReader and is intended to be used where a
graph object with attributes `x`, `edge_attr`, and optionally `molecule_attr`
is expected (e.g., `torch_geometric.data.Data`).
- The reader only performs random initialization and does not support reading
specific properties from the input data.
"""

from typing import Any, Optional

import torch
from torch import Tensor
from torch_geometric.data import Data as GeomData

from .reader import GraphPropertyReader


class RandomFeatureInitializationReader(GraphPropertyReader):
DISTRIBUTIONS = ["normal", "uniform", "xavier_normal", "xavier_uniform", "zeros"]
"""
Reader that initializes node, bond (edge), and molecule features with
random values according to a chosen distribution.

Supported distributions:
- "normal" : standard normal (mean=0, std=1)
- "uniform" : uniform in [-1, 1]
- "xavier_normal" : Xavier normal initialization
- "xavier_uniform" : Xavier uniform initialization
- "zeros" : all zeros

Parameters
----------
num_node_properties : int
Number of features to generate per node.
num_bond_properties : int
Number of features to generate per edge/bond.
num_molecule_properties : int
Number of global molecule-level features to generate.
distribution : str, optional
One of the supported distributions (default: "normal").
*args, **kwargs : Any
Additional positional and keyword arguments passed to the parent
GraphPropertyReader.
"""

DISTRIBUTIONS = [
"normal",
"uniform",
"xavier_normal",
"xavier_uniform",
"zeros",
]

def __init__(
self,
num_node_properties: int,
num_bond_properties: int,
num_molecule_properties: int,
distribution: str = "normal",
*args,
**kwargs,
):
*args: Any,
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self.num_node_properties = num_node_properties
self.num_bond_properties = num_bond_properties
self.num_molecule_properties = num_molecule_properties
assert distribution in self.DISTRIBUTIONS
self.distribution = distribution
if distribution not in self.DISTRIBUTIONS:
raise ValueError(
f"distribution must be one of {self.DISTRIBUTIONS}, got '{distribution}'"
)

self.num_node_properties: int = int(num_node_properties)
self.num_bond_properties: int = int(num_bond_properties)
self.num_molecule_properties: int = int(num_molecule_properties)
self.distribution: str = distribution

def name(self) -> str:
"""
Get the name identifier of the reader.
Return a human-readable identifier for this reader configuration.

Returns
-------
str
A name encoding the chosen distribution and generated feature sizes.
"""
return (
f"gni-{self.distribution}"
f"-node{self.num_node_properties}"
f"-bond{self.num_bond_properties}"
f"-mol{self.num_molecule_properties}"
)

Returns:
str: The name of the reader.
def _read_data(self, raw_data: Any) -> Optional[GeomData]:
"""
return f"gni-{self.distribution}-node{self.num_node_properties}-bond{self.num_bond_properties}-mol{self.num_molecule_properties}"
Read and return a `torch_geometric.data.Data` object with randomized
node/edge/molecule features.

def _read_data(self, raw_data):
data: GeomData = super()._read_data(raw_data)
This method calls the parent's `_read_data` to obtain a graph object,
then replaces `x`, `edge_attr` and sets `molecule_attr` with new tensors.

Parameters
----------
raw_data : Any
Raw input that the parent reader understands.

Returns
-------
Optional[GeomData]
A `Data` object with randomized attributes or `None` if the parent
`_read_data` returned `None`.
"""
data: Optional[GeomData] = super()._read_data(raw_data)
if data is None:
return None

@@ -51,24 +131,55 @@ def _read_data(self, raw_data):
)
random_molecule_properties = torch.empty(1, self.num_molecule_properties)

# Initialize them according to the chosen distribution.
self.random_gni(random_x, self.distribution)
self.random_gni(random_edge_attr, self.distribution)
self.random_gni(random_molecule_properties, self.distribution)

# Assign randomized attributes back to the data object.
data.x = random_x
data.edge_attr = random_edge_attr
# Use `molecule_attr` as the name in this codebase; if your Data object
# expects a different name (e.g., `u` or `global_attr`) adapt accordingly.
data.molecule_attr = random_molecule_properties

return data

def read_property(self, *args, **kwargs) -> Exception:
"""This reader does not support reading specific properties."""
raise NotImplementedError("This reader only performs random initialization.")
def read_property(self, *args: Any, **kwargs: Any) -> None:
"""
This reader does not support reading specific properties from the input.
It only performs random initialization of features.

Raises
------
NotImplementedError
Always raised to indicate unsupported operation.
"""
raise NotImplementedError(
"RandomFeatureInitializationReader only performs random initialization."
)

@staticmethod
def random_gni(tensor: torch.Tensor, distribution: str) -> None:
def random_gni(tensor: Tensor, distribution: str) -> None:
"""
Fill `tensor` in-place according to the requested initialization.

Parameters
----------
tensor : torch.Tensor
The tensor to initialize in-place.
distribution : str
One of the supported distribution identifiers.

Raises
------
ValueError
If an unknown distribution string is provided.
"""
if distribution == "normal":
torch.nn.init.normal_(tensor)
elif distribution == "uniform":
# Uniform in [-1, 1]
torch.nn.init.uniform_(tensor, a=-1.0, b=1.0)
elif distribution == "xavier_normal":
torch.nn.init.xavier_normal_(tensor)
@@ -77,4 +188,4 @@ def random_gni(tensor: torch.Tensor, distribution: str) -> None:
elif distribution == "zeros":
torch.nn.init.zeros_(tensor)
else:
raise ValueError("Unknown distribution type")
raise ValueError(f"Unknown distribution type: '{distribution}'")