# Scalable GNN–Based Preconditioners for Conjugate Gradient Methods
**Authors: Nicholas Tan Yun Yu, Low Jun Yu, Yuhan Wu**

This project was inspired by [NeuralIF](https://arxiv.org/abs/2305.16368).

**Summary**: The authors propose a novel message-passing GNN block that the network uses to predict efficient preconditioners for sparse linear systems. These preconditioners are evaluated with the preconditioned conjugate gradient (CG) method, where they make the solver converge faster than other state-of-the-art preconditioners.

**Motivation**: Modern data-driven and physics-based applications frequently require us to work with dense matrices. We therefore hope to show that the message-passing GNN block can also learn effective preconditioners in these scaled-up settings. One machine learning problem that could benefit is Gaussian process regression, which relies on a dense kernel (covariance) matrix $K$ with entries $K_{ij} = k(x_i, x_j)$ for a kernel function $k$: (some image)

**The problem**: Scaling the approach to dense matrices is nontrivial. The Coates graph representation has one node per row/column and one edge for each nonzero entry of $A$. For a dense $n \times n$ matrix that graph becomes complete – with $n^2$ edges – so both memory and compute grow as $O(n^2)$.

**Research direction**: Implement an edge-regression GNN that can work on dense matrices. We plan to achieve this with graph sampling techniques such as GraphSAGE (neighbourhood sampling) and Cluster-GCN (cluster/subgraph sampling).

# 1. Installation & Setup

## Load files from Google Drive

In [None]:
# Mount Google Drive so the project files stored under MyDrive are reachable.
# force_remount=True re-mounts even if an earlier session already mounted it.
from google.colab import drive

drive.mount("/content/drive", force_remount=True)

Mounted at /content/drive


In [None]:
# Add the project root to Python's import path.
# The path must lead straight to the root of the project, which should contain
# the folders "krylov", "apps" and "neuralif" holding the files from
#   https://github.com/paulhausner/neural-incomplete-factorization/tree/main
import sys

# Guard against duplicate entries when the cell is re-run in the same kernel.
if '/content/drive/MyDrive/CS4350/project/' not in sys.path:
    sys.path.append('/content/drive/MyDrive/CS4350/project/')

## Package installation

In [None]:
# Install PyTorch Geometric quietly (IPython %pip magic: only valid inside a notebook cell).
%pip install -q torch_geometric

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.1/1.1 MB[0m [31m58.8 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m31.3 MB/s[0m eta [36m0:00:00[0m
[?25h

## Imports

In [None]:
import os
import datetime
import pprint
import time

import numpy as np
import scipy
import torch
import torch_geometric
import torch.nn as nn
import torch_geometric.nn as pyg
from torch_geometric.nn import aggr
from torch_geometric.utils import to_scipy_sparse_matrix
from scipy.sparse import tril, coo_matrix

from apps.data import get_dataloader, graph_to_matrix, matrix_to_graph
from neuralif.utils import (
    count_parameters, save_dict_to_file,
    condition_number, eigenval_distribution, gershgorin_norm
)
from neuralif.logger import TrainResults, TestResults
from neuralif.loss import loss
# from neuralif.models import NeuralPCG, NeuralIF, PreCondNet, LearnedLU

from krylov.cg import preconditioned_conjugate_gradient
from krylov.gmres import gmres

# import from self-curated numml file
from numml import SparseCSRTensor

## Set GPU

In [None]:
# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cpu


# 2. Dataset generation

## Helper functions

In [None]:
def generate_sparse_random(n, alpha=1e-4, random_state=0, sol=False, ood=False):
    """Generate a random sparse symmetric positive definite system ``A x = b``.

    ``A = M @ M.T + alpha * I``, where ``M`` is an ``n x n`` sparse matrix with
    standard-normal values at uniformly random (deduplicated) positions.

    Args:
        n: matrix dimension.
        alpha: diagonal shift that makes ``A`` positive definite. If None, it
            is drawn uniformly from [1e-4, 1e-2).
        random_state: seed for ``np.random.RandomState`` (fully determines
            the sample).
        sol: if True, also compute a reference solution with SciPy's CG.
        ood: if True, rescale the density of ``M`` by a random factor in
            [0.22, 2.2) to create out-of-distribution samples.

    Returns:
        ``(A, x, b)``: ``A`` as a SciPy sparse matrix, ``x`` the CG solution
        (None when ``sol`` is False), and ``b`` a right-hand side drawn
        uniformly from [0, 1).
    """
    # Explicit submodule import: `import scipy` alone does not guarantee that
    # scipy.sparse.linalg is loaded.
    from scipy.sparse.linalg import cg
    import warnings

    rng = np.random.RandomState(random_state)
    if alpha is None:
        alpha = rng.uniform(1e-4, 1e-2)
    # Density of M: 10e-4 == 1e-3, i.e. ~0.1% of the entries (1e5 for
    # n = 10 000) before duplicates are dropped; M @ M.T then adds fill-in.
    sparsity = 10e-4

    if ood:
        factor = rng.uniform(0.22, 2.2)
        sparsity *= factor

    nnz = int(sparsity * n ** 2)
    rows = [rng.randint(0, n) for _ in range(nnz)]
    cols = [rng.randint(0, n) for _ in range(nnz)]
    # Drop duplicate (row, col) positions so every coordinate appears once.
    uniques = set(zip(rows, cols))
    if uniques:
        rows, cols = zip(*uniques)
    else:
        # Small n: no positions were drawn (nnz == 0), so M is the zero matrix
        # and A reduces to alpha * I.
        rows = np.zeros(0, dtype=np.int64)
        cols = np.zeros(0, dtype=np.int64)
    vals = rng.normal(0, 1, size=len(cols))

    M = coo_matrix((vals, (rows, cols)), shape=(n, n))
    I = scipy.sparse.identity(n)
    A = (M @ M.T) + alpha * I  # SPD for alpha > 0
    print(f"Generated matrix with {100 * (A.nnz / n**2):.2f}% non-zeros ({A.nnz} entries)")

    b = rng.uniform(0, 1, size=n)
    x = None
    if sol:
        x, info = cg(A, b)
        if info != 0:
            # Do not silently store an unconverged "reference" solution.
            warnings.warn(
                f"CG reference solve did not converge (info={info}); "
                "the stored solution is inaccurate",
                RuntimeWarning,
            )
    return A, x, b

def create_dataset(n, samples, alpha=1e-2, graph=True, rs=0, mode='train', solution=False):
    """Generate ``samples`` random SPD systems and save them to ./data/Random/<mode>/.

    Sample ``sam`` is generated with seed ``rs + sam``, so different splits
    must use disjoint seed ranges.

    Args:
        n: matrix dimension of every sample.
        samples: number of samples to generate.
        alpha: diagonal shift passed to ``generate_sparse_random``.
        graph: if True, save each sample as a graph (``.pt``) built with
            ``matrix_to_graph``; otherwise save ``A``, ``b`` and ``x`` in one
            ``.npz`` archive.
        rs: base random seed.
        mode: split name, which is also the output sub-folder. ``'test_ood'``
            generates out-of-distribution samples.
        solution: if True, also compute and store a reference solution.

    Raises:
        ValueError: if ``mode`` is not ``'train'`` and ``rs`` is 0.
    """
    if mode != 'train' and rs == 0:
        raise ValueError('`rs` must be non-zero for val/test to avoid overlap')

    out_dir = f'./data/Random/{mode}'
    # Create the split folder on demand (e.g. 'test_ood' is not pre-created).
    os.makedirs(out_dir, exist_ok=True)

    print(f"Creating {samples} samples for '{mode}' set (n={n})")
    for sam in range(samples):
        A, x, b = generate_sparse_random(
            n, alpha=alpha, random_state=rs + sam,
            sol=solution, ood=(mode == "test_ood")
        )
        if graph:
            g = matrix_to_graph(A, b)
            if x is not None:
                g.s = torch.tensor(x, dtype=torch.float)
            g.n = n
            torch.save(g, f'{out_dir}/{n}_{sam}.pt')
        else:
            # One archive holds A, b and x. (A separate scipy.sparse.save_npz
            # to this same path would just be overwritten by this call.)
            # NOTE: A is a SciPy sparse matrix and x may be None, so np.savez
            # stores them as pickled object arrays; read them back with
            # np.load(path, allow_pickle=True).
            np.savez(f'{out_dir}/{n}_{sam}.npz', A=A, b=b, x=x)


## Create Train, Validation and Test datasets

In [None]:
# Make sure every split folder exists before samples are written into it.
for split in ("train", "val", "test"):
    os.makedirs(f"./data/Random/{split}", exist_ok=True)

# Problem size and diagonal shift shared by all splits (10e-4 == 1e-3).
n, alpha = 10_000, 10e-4

# Each split uses its own seed range (sample i of a split is seeded with rs + i);
# only the training graphs store a reference solution.
create_dataset(n, samples=1000, alpha=alpha, graph=True, rs=0, mode="train", solution=True)
create_dataset(n, samples=10, alpha=alpha, graph=True, rs=10000, mode="val", solution=False)
create_dataset(n, samples=100, alpha=alpha, graph=True, rs=103600, mode="test", solution=False)


# 3. Models

In [None]:
class NeuralIF(nn.Module):
    """Neural incomplete factorization network.

    A stack of ``MP_Block`` message-passing layers updates the edge values of
    the graph returned by ``ToLowerTriangular``. The final edge values are the
    non-zeros of a sparse factor L. In inference mode, U = L^T is returned
    alongside L.

    NOTE(review): ``MP_Block``, ``MLP``, ``augment_features``, ``TwoHop`` and
    ``ToLowerTriangular`` are neither defined nor imported in this notebook
    (the ``neuralif.models`` import is commented out). They must be brought
    into scope before this class is instantiated or called.
    """

    def __init__(self, drop_tol=0, **kwargs) -> None:
        """Build the network.

        Args:
            drop_tol: in inference mode, off-diagonal entries with absolute
                value <= ``drop_tol`` are removed from the predicted factor
                (0 disables dropping).
            **kwargs: configuration.
                Required keys: ``global_features``, ``latent_size``,
                ``augment_nodes``, ``message_passing_steps``,
                ``skip_connections``, ``activation``, ``aggregate``,
                ``decode_nodes``.
                Optional keys: ``edge_features`` (default 1),
                ``normalize_diag`` (default False), ``graph_norm`` (default
                False), ``two_hop`` (default False).
                Any other key is ignored.
        """
        super().__init__()

        self.global_features = kwargs["global_features"]
        self.latent_size = kwargs["latent_size"]
        # node features are augmented with local degree profile
        self.augment_node_features = kwargs["augment_nodes"]

        # 8 input node features when augmented, otherwise a single one
        num_node_features = 8 if self.augment_node_features else 1
        message_passing_steps = kwargs["message_passing_steps"]

        # edge feature representation in the latent layers
        edge_features = kwargs.get("edge_features", 1)

        self.skip_connections = kwargs["skip_connections"]

        self.mps = torch.nn.ModuleList()
        for l in range(message_passing_steps):
            # skip connections are added to all layers except the first one
            self.mps.append(MP_Block(skip_connections=self.skip_connections,
                                     first=l == 0,
                                     last=l == (message_passing_steps - 1),
                                     edge_features=edge_features,
                                     node_features=num_node_features,
                                     global_features=self.global_features,
                                     hidden_size=self.latent_size,
                                     activation=kwargs["activation"],
                                     aggregate=kwargs["aggregate"]))

        # optional node decoder (its output is returned as the third value)
        self.node_decoder = MLP([num_node_features, self.latent_size, 1]) if kwargs["decode_nodes"] else None

        # diag-aggregation for normalization of rows
        self.normalize_diag = kwargs["normalize_diag"] if "normalize_diag" in kwargs else False
        self.diag_aggregate = aggr.SumAggregation()

        # optional graph normalization of the input node features
        self.graph_norm = pyg.norm.GraphNorm(num_node_features) if ("graph_norm" in kwargs and kwargs["graph_norm"]) else None

        # drop tolerance (inference only) and optional two-hop fill-in edges
        self.tau = drop_tol
        self.two = kwargs.get("two_hop", False)

    def forward(self, data):
        """Predict the factor for a single graph.

        Args:
            data: a graph object exposing ``x``, ``edge_index``, ``edge_attr``
                (and ``batch`` when graph normalization is enabled).
                Batching is not implemented.

        Returns:
            See ``transform_output_matrix``.
        """
        if self.augment_node_features:
            data = augment_features(data, skip_rhs=True)

        # add additional edges to the data
        if self.two:
            data = TwoHop()(data)

        # * in principle it is possible to integrate reordering here.

        data = ToLowerTriangular()(data)

        # get the input data
        edge_embedding = data.edge_attr
        l_index = data.edge_index

        if self.graph_norm is not None:
            node_embedding = self.graph_norm(data.x, batch=data.batch)
        else:
            node_embedding = data.x

        # keep a copy of the input edge values; they are re-injected as skip
        # connections and used for diagonal normalization
        a_edges = edge_embedding.clone()

        if self.global_features > 0:
            global_features = torch.zeros((1, self.global_features), device=data.x.device, requires_grad=False)
            # feature ideas: nnz, 1-norm, inf-norm col/row var, min/max variability, avg distances to nnz
        else:
            global_features = None

        # compute the output of the network
        for i, layer in enumerate(self.mps):
            if i != 0 and self.skip_connections:
                edge_embedding = torch.cat([edge_embedding, a_edges], dim=1)

            edge_embedding, node_embedding, global_features = layer(node_embedding, l_index, edge_embedding, global_features)

        # transform the output into a matrix
        return self.transform_output_matrix(node_embedding, l_index, edge_embedding, a_edges)

    def transform_output_matrix(self, node_x, edge_index, edge_values, a_edges):
        """Turn predicted edge values into a sparse factor.

        Returns:
            In inference mode (``torch.inference_mode()``):
                ``(L, U, node_output)`` with ``L`` and ``U = L^T`` as numml
                ``SparseCSRTensor``s. Off-diagonal entries with absolute value
                <= ``tau`` are removed when ``tau != 0``.
            Otherwise:
                ``(L, l1_penalty, node_output)``. Here ``L`` is a torch sparse
                COO tensor and ``l1_penalty`` is the mean absolute edge value.
            ``node_output`` is None when no node decoder is configured.
        """
        # mask of diagonal entries (row index == column index)
        diag = edge_index[0] == edge_index[1]

        if self.normalize_diag:
            # Rescale each row so that its squared entries sum to the
            # corresponding diagonal value of A.
            # NOTE(review): a_diag is indexed by row id below, which assumes
            # exactly one diagonal edge per row, ordered by row -- confirm.
            a_diag = a_edges[diag]

            # squared row norms of the prediction
            square_values = torch.pow(edge_values, 2)
            aggregated = self.diag_aggregate(square_values, edge_index[0])

            edge_values = torch.sqrt(a_diag[edge_index[0]]) * edge_values / torch.sqrt(aggregated[edge_index[0]])

        else:
            # exp forces a strictly positive diagonal; sqrt(exp(.)) is used
            # instead of exp(.) as it is numerically better
            edge_values[diag] = torch.sqrt(torch.exp(edge_values[diag]))

        # node decoder
        node_output = self.node_decoder(node_x).squeeze() if self.node_decoder is not None else None

        # ! this branch should only be active when the model is used as a
        # preconditioner ("production"), i.e. under torch.inference_mode()
        if torch.is_inference_mode_enabled():

            # optionally remove small off-diagonal elements from the preconditioner
            if self.tau != 0:
                small_value = (torch.abs(edge_values) <= self.tau).squeeze()

                # small value and not diagonal
                elems = torch.logical_and(small_value, torch.logical_not(diag))
                edge_values[elems] = 0

                # remove zeros from the sparse representation
                filt = (edge_values != 0).squeeze()
                edge_values = edge_values[filt]
                edge_index = edge_index[:, filt]

            # Doing pytorch -> scipy -> numml is a lot faster than pytorch -> numml on CPU
            # On GPU it is faster to go to pytorch -> numml -> CPU
            m = torch.sparse_coo_tensor(edge_index, edge_values.squeeze(),
                                        size=(node_x.size()[0], node_x.size()[0]))

            # produce L and U separately. SparseCSRTensor is imported directly
            # from numml at the top of the notebook; the previous `sp.` prefix
            # referred to an undefined name and raised NameError.
            l = SparseCSRTensor(m)
            u = SparseCSRTensor(m.T)

            return l, u, node_output

        else:
            # For training and testing (computing regular losses for examples):
            # does not need to be performance optimized, use torch sparse directly.
            t = torch.sparse_coo_tensor(edge_index, edge_values.squeeze(),
                                        size=(node_x.size()[0], node_x.size()[0]))

            # normalized l1 norm of the predicted values
            l1_penalty = torch.sum(torch.abs(edge_values)) / len(edge_values)

            return t, l1_penalty, node_output


# 4. Training

## Set Training Configuration

In [None]:
# Experiment configuration: training hyper-parameters plus model settings.
# NOTE(review): check which keys are actually consumed -- e.g. "n" and
# "regularizer" are not read anywhere else in this notebook.
config = dict(
    name="experiment_1",
    save=True,
    seed=42,
    n=0,
    batch_size=1,
    num_epochs=100,
    dataset="random",
    loss="frobenius",
    gradient_clipping=1.0,
    regularizer=0.0,
    scheduler=False,
    model="neuralif",
    normalize=False,
    latent_size=8,
    message_passing_steps=3,
    decode_nodes=False,
    normalize_diag=False,
    aggregate=[],
    activation="relu",
    skip_connections=True,
    augment_nodes=False,
    global_features=0,
    edge_features=1,
    graph_norm=False,
    two_hop=False,
)

# Output folder: named after the experiment, or timestamped when unnamed.
folder = (
    f"results/{config['name']}"
    if config["name"]
    else datetime.datetime.now().strftime("results/%Y-%m-%d_%H-%M-%S")
)
if config["save"]:
    os.makedirs(folder, exist_ok=True)
    save_dict_to_file(config, os.path.join(folder, "config.json"))


In [None]:
# Seed all random number generators for reproducibility.
torch_geometric.seed_everything(config["seed"])

# Forward only those model hyper-parameters that are present in the config.
model_args = {
    key: config[key]
    for key in (
        "latent_size", "message_passing_steps", "skip_connections",
        "augment_nodes", "global_features", "decode_nodes",
        "normalize_diag", "activation", "aggregate", "graph_norm",
        "two_hop", "edge_features", "normalize",
    )
    if key in config
}

use_gmres = False
# NOTE(review): NeuralPCG is not defined or imported in this notebook; the
# "__" entry only works once it is brought into scope.
if config["model"] in ("nif", "neuralif", "inf"):
    model = NeuralIF(**model_args)
elif config["model"] == "__":
    model = NeuralPCG(**model_args)
else:
    raise ValueError("Unknown model type")

model = model.to(device)  # Module.to moves in place and returns the module
print("Number of parameters:", count_parameters(model))

optimizer = torch.optim.AdamW(params=model.parameters())
# Halve the learning rate after 20 checks without improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.5, patience=20, mode="min"
)

In [None]:
# helper functions
@torch.no_grad()
def validate(model, validation_loader, solve=False, solver="cg"):
    """Evaluate ``model`` on ``validation_loader``.

    Args:
        model: preconditioner network. It is switched to eval mode here, and
            callers must call ``model.train()`` again before training.
        validation_loader: non-empty iterable of graph batches.
        solve: if True, solve each system with a Krylov method using the
            learned preconditioner and report the mean iteration count.
            Otherwise report the mean Frobenius loss.
        solver: ``"cg"`` or ``"gmres"``; only used when ``solve`` is True.

    Returns:
        Mean solver iterations (``solve=True``) or mean loss (``solve=False``).

    Raises:
        ValueError: for an unknown ``solver`` or an empty ``validation_loader``.
    """
    if solve and solver not in ("cg", "gmres"):
        raise ValueError(f"Unknown solver {solver!r}; choose 'cg' or 'gmres'")

    model.eval()
    acc_loss = 0.0
    acc_solver_iters = 0.0
    num_batches = 0

    for data in validation_loader:
        num_batches += 1
        data = data.to(device)
        A, b = graph_to_matrix(data)

        if solve:
            # NeuralIF only returns its solver-ready (L, U) factors when
            # inference mode is enabled, so build the preconditioner under it.
            # NOTE(review): assumes LearnedPreconditioner evaluates model(data).
            with torch.inference_mode():
                preconditioner = LearnedPreconditioner(data, model)
            # the Krylov solvers run on CPU in double precision
            A_cpu = A.cpu().double()
            b_cpu = b.cpu().double()

            if solver == "cg":
                history, _ = preconditioned_conjugate_gradient(
                    A_cpu, b_cpu, M=preconditioner, x0=None,
                    rtol=1e-6, max_iter=1000
                )
            else:
                history, _ = gmres(
                    A_cpu, b_cpu, M=preconditioner, x0=None,
                    atol=1e-6, max_iter=1000, left=False
                )
            # one entry per iterate, including the initial one
            acc_solver_iters += len(history) - 1
        else:
            output, _, _ = model(data)
            # same argument order as in the training loop: loss(output, data, ...)
            l = loss(output, data, config="frobenius")
            acc_loss += l.item()

    if num_batches == 0:
        raise ValueError("validation_loader yielded no batches")

    if solve:
        avg_iters = acc_solver_iters / num_batches
        print(f"Validation iterations: {avg_iters:.2f}")
        return avg_iters

    avg_loss = acc_loss / num_batches
    print(f"Validation loss: {avg_loss:.4f}")
    return avg_loss

In [None]:
# training loop
# NOTE(review): `train_loader` and `val_loader` are not created anywhere in this
# notebook (presumably via apps.data.get_dataloader) -- define them first.
best_val = float("inf")
logger = TrainResults(folder)
total_it = 0

for epoch in range(config["num_epochs"]):
    running_loss = 0.0
    start_epoch = time.perf_counter()

    for data in train_loader:
        total_it += 1
        start_it = time.perf_counter()  # per-step wall-clock timing
        model.train()  # validate() leaves the model in eval mode
        data = data.to(device)

        output, reg, _ = model(data)
        l = loss(output, data, c=reg, config=config["loss"])
        l.backward()

        # gradient clipping or manual norm
        if config["gradient_clipping"]:
            # float() so the logger receives a plain number, not a tensor
            grad_norm = float(torch.nn.utils.clip_grad_norm_(
                model.parameters(), config["gradient_clipping"]
            ))
        else:
            total_norm = sum(
                p.grad.detach().data.norm(2).item() ** 2
                for p in model.parameters() if p.grad is not None
            )
            grad_norm = (total_norm ** 0.5) / config["batch_size"]

        optimizer.step()
        optimizer.zero_grad()

        step_loss = l.item()
        running_loss += step_loss
        # (was `time.perf_counter() - start=0`, a SyntaxError)
        logger.log(step_loss, grad_norm, time.perf_counter() - start_it)

        # periodic validation
        if total_it % 1000 == 0:
            val_metric = validate(
                model, val_loader, solve=True,
                solver="gmres" if use_gmres else "cg"
            )
            logger.log_val(None, val_metric)
            # honour config["scheduler"]: reduce the LR when validation plateaus
            if config["scheduler"]:
                scheduler.step(val_metric)
            if val_metric < best_val:
                best_val = val_metric
                if config["save"]:
                    torch.save(model.state_dict(), f"{folder}/best_model.pt")

    epoch_time = time.perf_counter() - start_epoch
    print(f"Epoch {epoch+1} — loss: {running_loss/len(train_loader):.4f}, time: {epoch_time:.1f}s")
    if config["save"]:
        torch.save(model.state_dict(), f"{folder}/model_epoch{epoch+1}.pt")


In [None]:
# Persist the training log and the final weights (cast to float32 first).
if config["save"]:
    logger.save_results()
    model.to(torch.float)  # in-place cast; Module.to returns the same module
    torch.save(model.state_dict(), f"{folder}/final_model.pt")



In [None]:
# Report the best validation metric recorded by the training loop.
print(f"Best validation performance: {best_val}")