# **ENVIRONMENT SETUP**

In [24]:
%pip install -q pyg-lib -f https://data.pyg.org/whl/torch-2.6.0+cu124.html
%pip install -q torch torchvision torchaudio
%pip install -q torch-scatter -f https://data.pyg.org/whl/torch-2.6.0+cu124.html  # prevents the wheel from taking forever to build
%pip install -q torch-sparse -f https://data.pyg.org/whl/torch-2.6.0+cu124.html   # prevents the wheel from taking forever to build
%pip install -q torch-cluster -f https://data.pyg.org/whl/torch-2.6.0+cu124.html  # prevents the wheel from taking forever to build
%pip install -q torch_geometric

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [25]:
import torch
print(torch.__version__)

2.7.1+cu126


In [26]:
%pwd

'/home/nicholas_tyy/Documents/TU_Delft/CS4350/proj/GML-Project'

In [27]:
#Change based on project structure
%cd ../Documents/Y3S2/CS4350/GML-Project

[Errno 2] No such file or directory: '../Documents/Y3S2/CS4350/GML-Project'
/home/nicholas_tyy/Documents/TU_Delft/CS4350/proj/GML-Project


In [32]:
import os
import datetime
import time
import pickle

import numpy as np
import torch
import torch_geometric
import torch.nn as nn
import torch_geometric.nn as pyg
from torch_geometric.nn import aggr
from torch_geometric.data import Batch, Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import subgraph
from tqdm import tqdm

from apps.data import graph_to_matrix, FolderDataset
from neuralif.utils import (
    count_parameters, save_dict_to_file,
    TwoHop
)
from neuralif.logger import TrainResults
from neuralif.loss import loss

from krylov.cg import preconditioned_conjugate_gradient
from krylov.gmres import gmres

from numml.sparse import SparseCSRTensor

import metis  # pip install metis

In [29]:
device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


# **DATASET CREATION**

In [30]:
from apps.synthetic import create_dataset
n = 10_000
alpha=10e-4

#create_dataset(size, number of samples, density, mode, seed, is_graph, compute_solution)
create_dataset(n, 100, alpha=alpha, mode='train', rs=0, graph=True, solution=True)
create_dataset(n, 10, alpha=alpha, mode='val', rs=10000, graph=True)
create_dataset(n, 10, alpha=alpha, mode='test', rs=103600, graph=True)

Generating 100 samples for the train dataset.


  2%|▏         | 2/100 [00:07<05:49,  3.57s/it]


KeyboardInterrupt: 

# **CONFIG**

In [None]:
config = {
    "name": "experiment_6",
    "save": True,
    "seed": 42,
    "n": 10000,
    "batch_size": 1,
    "num_epochs": 100,
    "dataset": "random",
    "loss": None,
    "gradient_clipping": 1.0,
    "regularizer": 0.0,
    "scheduler": False,
    "model": "neuralif",
    "normalize": False,
    "latent_size": 8,
    "message_passing_steps": 3,
    "decode_nodes": False,
    "normalize_diag": False,
    "aggregate": ["mean", "sum"],
    "activation": "relu",
    "skip_connections": True,
    "augment_nodes": False,
    "global_features": 0,
    "edge_features": 1,
    "graph_norm": False,
    "two_hop": False,
    "samples": 10,
    "num_neighbors": [15, 10],  # number of neighbours to sample in each hop (GraphSAGE sampling)
    "num_clusters": 10,  # Number of clusters to partition each graph into for GNN
    "clusters_per_batch": 5,  # Number of clusters to sample per batch for GNN
    "cluster_method": "metis",  # Options: 'metis', 'random', 'kmeans'
    "solver": "cg"
}

# Prepare output folder
if config["name"]:
    folder = f"results/{config['name']}"
else:
    folder = datetime.datetime.now().strftime("results/%Y-%m-%d_%H-%M-%S")
if config["save"]:
    os.makedirs(folder, exist_ok=True)
    save_dict_to_file(config, os.path.join(folder, "config.json"))


# **DATALOADERS**

In [None]:
def get_dataloader(dataset, n=0, batch_size=1, spd=True, mode="train", size=None, graph=True):
    # Setup datasets

    if dataset == "random":
        data = FolderDataset(f"./dataset/{mode}/", n, size=size, graph=graph)
    else:
        raise NotImplementedError("Dataset not implemented, Available: random")

    # Data Loaders
    if mode == "train":
        dataloader = DataLoader(data, batch_size=batch_size, shuffle=True)
    else:
        dataloader = DataLoader(data, batch_size=1, shuffle=False)

    return dataloader

In [None]:
class ClusterGCNSampler:
    def __init__(self, num_clusters, cluster_method='metis'):
        self.num_clusters = num_clusters
        self.cluster_method = cluster_method
        self.partitions = {}  # Cache for partitioned clusters
        self.cluster_nodes = {}  # Cache for nodes in each cluster

    def _get_graph_hash(self, data):
        """
        Create a stable hash for the graph data that doesn't depend on object identity
        """
        # Create hash based on graph structure, not object identity
        edge_hash = hash(tuple(data.edge_index.flatten().tolist()))
        node_hash = hash(data.num_nodes)
        
        # Include matrix hash if present
        if hasattr(data, 'matrix'):
            if hasattr(data.matrix, '_indices'):  # Sparse tensor
                matrix_hash = hash((tuple(data.matrix._indices().flatten().tolist()), 
                                  tuple(data.matrix._values().tolist())))
            else:  # Dense tensor
                matrix_hash = hash(tuple(data.matrix.flatten().tolist()))
        else:
            matrix_hash = 0
        
        return hash((edge_hash, node_hash, matrix_hash))

    def partition_graph_once(self, data):
        """
        Partition the graph once and cache the results using stable graph hash
        """
        graph_hash = self._get_graph_hash(data)
        
        if graph_hash in self.partitions:
            return self.partitions[graph_hash], self.cluster_nodes[graph_hash]
        
        edge_index = data.edge_index
        num_nodes = data.num_nodes

        # Create adjacency list for clustering
        if self.cluster_method == 'metis':
            # Convert to format expected by METIS
            adjacency_list = [[] for _ in range(num_nodes)]
            for i in range(edge_index.size(1)):
                src, dst = edge_index[0, i].item(), edge_index[1, i].item()
                if src != dst:  # Avoid self-loops for clustering
                    adjacency_list[src].append(dst)
                    adjacency_list[dst].append(src)

            try:
                # Use METIS for graph partitioning
                _, node_clusters = metis.part_graph(adjacency_list, self.num_clusters)
                node_clusters = torch.tensor(node_clusters, dtype=torch.long)
            except Exception as e:
                # Fallback to random partitioning if METIS fails
                print(f"METIS failed: {e}, falling back to random partitioning")
                node_clusters = torch.randint(0, self.num_clusters, (num_nodes,))
        else:
            # Random partitioning fallback
            node_clusters = torch.randint(0, self.num_clusters, (num_nodes,))

        # Pre-compute which nodes belong to each cluster
        cluster_to_nodes = {}
        for cluster_id in range(self.num_clusters):
            mask = (node_clusters == cluster_id)
            cluster_nodes = torch.where(mask)[0]
            if len(cluster_nodes) > 0:  # Only store non-empty clusters
                cluster_to_nodes[cluster_id] = cluster_nodes

        # Cache the results using stable hash
        self.partitions[graph_hash] = node_clusters
        self.cluster_nodes[graph_hash] = cluster_to_nodes
        
        return node_clusters, cluster_to_nodes

    def sample_clusters(self, data, clusters_per_batch):
        """
        Sample a subset of clusters and return the induced subgraph
        """
        # Get or compute partitions (only done once per unique graph)
        node_clusters, cluster_to_nodes = self.partition_graph_once(data)

        # Get available clusters (only non-empty ones)
        available_clusters = list(cluster_to_nodes.keys())
        
        if len(available_clusters) == 0:
            raise ValueError("No valid clusters found")
        
        # Randomly sample clusters for this batch
        if len(available_clusters) <= clusters_per_batch:
            selected_clusters = available_clusters
        else:
            # Use numpy for faster random sampling
            selected_clusters = np.random.choice(
                available_clusters, 
                size=clusters_per_batch, 
                replace=False
            ).tolist()

        # Collect all nodes from selected clusters
        selected_nodes_list = []
        for cluster_id in selected_clusters:
            selected_nodes_list.append(cluster_to_nodes[cluster_id])
        
        selected_nodes = torch.cat(selected_nodes_list)

        # Extract subgraph
        edge_index, edge_attr = subgraph(
            selected_nodes,
            data.edge_index,
            edge_attr=data.edge_attr if hasattr(data, 'edge_attr') else None,
            relabel_nodes=True,
            num_nodes=data.num_nodes
        )

        # Create new data object for the subgraph
        subgraph_data = Data(
            edge_index=edge_index,
            num_nodes=len(selected_nodes)
        )

        # Copy relevant attributes
        if hasattr(data, 'x') and data.x is not None:
            subgraph_data.x = data.x[selected_nodes]

        if hasattr(data, 'edge_attr') and data.edge_attr is not None:
            subgraph_data.edge_attr = edge_attr

        # For your preconditioner application, you'll need to handle the matrix
        if hasattr(data, 'matrix'):
            # Extract the submatrix corresponding to selected nodes
            subgraph_data.matrix = data.matrix[selected_nodes][:, selected_nodes]

        # Store mapping for reconstruction if needed
        subgraph_data.original_nodes = selected_nodes
        subgraph_data.node_mapping = {new_idx: old_idx.item()
                                    for new_idx, old_idx in enumerate(selected_nodes)}
        subgraph_data.selected_clusters = selected_clusters

        return subgraph_data

class ClusterGCNDataLoader:
    """
    Custom DataLoader that uses ClusterGCN sampling with preprocessing partitioning
    """
    def __init__(self, dataset, batch_size, num_clusters, clusters_per_batch,
                 cluster_method='metis', shuffle=True, load_from_cache=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.sampler = ClusterGCNSampler(num_clusters, cluster_method)
        self.clusters_per_batch = clusters_per_batch
        self.load_from_cache = load_from_cache
        
        if not self.load_from_cache:
            self._precompute_all_partitions()
        else: 
            print("load_from_cache=True: skipping Graph partitioning step...")

    def _precompute_all_partitions(self):
        """
        Precompute all graph partitions during initialization
        """
        print(f"Preprocessing: Computing graph partitions for {len(self.dataset)} graphs...")
        start_time = time.time()
        
        for i, data in tqdm(enumerate(self.dataset)):
            self.sampler.partition_graph_once(data)
            
            # Progress reporting
            if len(self.dataset) > 100:
                if (i + 1) % max(1, len(self.dataset) // 20) == 0:  # Report every 5%
                    elapsed = time.time() - start_time
                    progress = (i + 1) / len(self.dataset)
                    eta = elapsed / progress - elapsed if progress > 0 else 0
                    print(f"  Progress: {i + 1}/{len(self.dataset)} ({progress*100:.1f}%) - "
                          f"Elapsed: {elapsed:.1f}s - ETA: {eta:.1f}s")
            elif (i + 1) % 10 == 0 or i == len(self.dataset) - 1:
                print(f"  Partitioned {i + 1}/{len(self.dataset)} graphs")
        
        total_time = time.time() - start_time
        print(f"Preprocessing completed in {total_time:.2f} seconds")
        print(f"Average partitioning time: {total_time/len(self.dataset)*1000:.2f}ms per graph")

    def __iter__(self):
        indices = list(range(len(self.dataset)))
        if self.shuffle:
            indices = torch.randperm(len(indices)).tolist()

        for i in range(0, len(indices), self.batch_size):
            batch_indices = indices[i:i + self.batch_size]
            batch_data = []

            for idx in batch_indices:
                data = self.dataset[idx]
                # Sample clusters for this graph (partitioning already done)
                subgraph_data = self.sampler.sample_clusters(data, self.clusters_per_batch)
                batch_data.append(subgraph_data)

            if len(batch_data) == 1:
                yield batch_data[0]
            else:
                yield Batch.from_data_list(batch_data)

    def __len__(self):
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size

def get_cluster_dataloader(dataset, n, batch_size, spd=True, mode="train", num_clusters=10,
                           clusters_per_batch=2, size=None, cache_dir="./partition_cache", load_from_cache=False):
    """
    If load_from_cache=False: partitions & caches to disk, returns loader.
    If load_from_cache=True: loads partitions & returns loader (no re-partition).
    """
    print(f"Loading dataset [{mode}]…")
    ds = FolderDataset(f"./dataset/{mode}/", n, size=size, graph=True)
    print(f"  → {len(ds)} graphs")

    os.makedirs(cache_dir, exist_ok=True)
    cache_file = os.path.join(cache_dir, f"{mode}_n{n}_k{num_clusters}_partitions.pkl")

    if load_from_cache:
        # 1) load pickle
        if not os.path.exists(cache_file):
            raise FileNotFoundError(f"No cache found at {cache_file!r}. Check if the parameters are correct.")
        with open(cache_file, "rb") as f:
            data = pickle.load(f)
        parts, nodes = data["partitions"], data["cluster_nodes"]
        # 2) build loader without precompute
        loader = ClusterGCNDataLoader(
            ds, batch_size, num_clusters, clusters_per_batch,
            cluster_method='metis',
            shuffle=(mode=="train"),
            load_from_cache=True
        )
        # 3) inject caches
        loader.sampler.partitions    = parts
        loader.sampler.cluster_nodes = nodes
        print(f"✔ Loaded partitions from {cache_file}")
    else:
        # build & precompute
        loader = ClusterGCNDataLoader(
            ds, batch_size, num_clusters, clusters_per_batch,
            cluster_method='metis',
            shuffle=(mode=="train"),
            load_from_cache=False
        )
        # save to disk
        with open(cache_file, "wb") as f:
            pickle.dump({
                "partitions": loader.sampler.partitions,
                "cluster_nodes": loader.sampler.cluster_nodes
            }, f)
        print(f"✔ Cached partitions to {cache_file}")

    print("DataLoader ready!")
    return loader

# **MODELS**

In [15]:
class GraphNet(nn.Module):
    # Follows roughly the outline of torch_geometric.nn.MessagePassing()
    # As shown in https://github.com/deepmind/graph_nets
    # Here is a helpful python implementation:
    # https://github.com/NVIDIA/GraphQSat/blob/main/gqsat/models.py
    # Also allows multirgaph GNN via edge_2_features
    def __init__(self, node_features, edge_features, global_features=0, hidden_size=0,
                 aggregate="mean", activation="relu", skip_connection=False, edge_features_out=None):

        super().__init__()

        # different aggregation functions
        if aggregate == "sum":
            self.aggregate = aggr.SumAggregation()
        elif aggregate == "mean":
            self.aggregate = aggr.MeanAggregation()
        elif aggregate == "max":
            self.aggregate = aggr.MaxAggregation()
        elif aggregate == "softmax":
            self.aggregate = aggr.SoftmaxAggregation(learn=True)
        else:
            raise NotImplementedError(f"Aggregation '{aggregate}' not implemented")

        self.global_aggregate = aggr.MeanAggregation()

        add_edge_fs = 1 if skip_connection else 0
        edge_features_out = edge_features if edge_features_out is None else edge_features_out

        # Graph Net Blocks (see https://arxiv.org/pdf/1806.01261.pdf)
        self.edge_block = MLP([global_features + (edge_features + add_edge_fs) + (2 * node_features),
                               hidden_size,
                               edge_features_out],
                              activation=activation)

        self.node_block = MLP([global_features + edge_features_out + node_features,
                               hidden_size,
                               node_features],
                              activation=activation)

        # optional set of blocks for global GNN
        self.global_block = None
        if global_features > 0:
            self.global_block = MLP([edge_features_out + node_features + global_features,
                                     hidden_size,
                                     global_features],
                                    activation=activation)

    def forward(self, x, edge_index, edge_attr, g=None):
        row, col = edge_index

        if self.global_block is not None:
            assert g is not None, "Need global features for global block"

            # run the edge update and aggregate features
            edge_embedding = self.edge_block(torch.cat([torch.ones(x[row].shape[0], 1, device=x.device) * g,
                                                        x[row], x[col], edge_attr], dim=1))
            aggregation = self.aggregate(edge_embedding, row)


            agg_features = torch.cat([torch.ones(x.shape[0], 1, device=x.device) * g, x, aggregation], dim=1)
            node_embeddings = self.node_block(agg_features)

            # aggregate over all edges and nodes (always mean)
            mp_global_aggr = g
            edge_aggregation_global = self.global_aggregate(edge_embedding)
            node_aggregation_global = self.global_aggregate(node_embeddings)

            # compute the new global embedding
            # the old global feature is part of mp_global_aggr
            global_embeddings = self.global_block(torch.cat([node_aggregation_global,
                                                             edge_aggregation_global,
                                                             mp_global_aggr], dim=1))

            return edge_embedding, node_embeddings, global_embeddings

        else:
            # update edge features and aggregate
            edge_embedding = self.edge_block(torch.cat([x[row], x[col], edge_attr], dim=1))
            aggregation = self.aggregate(edge_embedding, row)
            agg_features = torch.cat([x, aggregation], dim=1)
            # update node features
            node_embeddings = self.node_block(agg_features)
            return edge_embedding, node_embeddings, None


In [16]:
class NeuralIF(nn.Module):
    # Neural Incomplete factorization
    def __init__(self, drop_tol=0, **kwargs) -> None:
        super().__init__()

        self.global_features = kwargs["global_features"]
        self.latent_size = kwargs["latent_size"]
        # node features are augmented with local degree profile
        self.augment_node_features = kwargs["augment_nodes"]

        num_node_features = 8 if self.augment_node_features else 1
        message_passing_steps = kwargs["message_passing_steps"]

        # edge feature representation in the latent layers
        edge_features = kwargs.get("edge_features", 1)

        self.skip_connections = kwargs["skip_connections"]

        self.mps = torch.nn.ModuleList()
        for l in range(message_passing_steps):
            # skip connections are added to all layers except the first one
            self.mps.append(MP_Block(skip_connections=self.skip_connections,
                                     first=l==0,
                                     last=l==(message_passing_steps-1),
                                     edge_features=edge_features,
                                     node_features=num_node_features,
                                     global_features=self.global_features,
                                     hidden_size=self.latent_size,
                                     activation=kwargs["activation"],
                                     aggregate=kwargs["aggregate"]))

        # node decodings
        self.node_decoder = MLP([num_node_features, self.latent_size, 1]) if kwargs["decode_nodes"] else None

        # diag-aggregation for normalization of rows
        self.normalize_diag = kwargs["normalize_diag"] if "normalize_diag" in kwargs else False
        self.diag_aggregate = aggr.SumAggregation()

        # normalization
        self.graph_norm = pyg.norm.GraphNorm(num_node_features) if ("graph_norm" in kwargs and kwargs["graph_norm"]) else None

        # drop tolerance and additional fill-ins and more sparsity
        self.tau = drop_tol
        self.two = kwargs.get("two_hop", False)

    def forward(self, data):
        # ! data could be batched here...(not implemented)

        if self.augment_node_features:
            data = augment_features(data, skip_rhs=True)

        # add additional edges to the data
        if self.two:
            data = TwoHop()(data)

        # * in principle it is possible to integrate reordering here.

        data = ToLowerTriangular()(data)

        # get the input data
        edge_embedding = data.edge_attr
        l_index = data.edge_index

        if self.graph_norm is not None:
            node_embedding = self.graph_norm(data.x, batch=data.batch)
        else:
            node_embedding = data.x

        # copy the input data (only edges of original matrix A)
        a_edges = edge_embedding.clone()

        if self.global_features > 0:
            global_features = torch.zeros((1, self.global_features), device=data.x.device, requires_grad=False)
            # feature ideas: nnz, 1-norm, inf-norm col/row var, min/max variability, avg distances to nnz
        else:
            global_features = None

        # compute the output of the network
        for i, layer in enumerate(self.mps):
            if i != 0 and self.skip_connections:
                edge_embedding = torch.cat([edge_embedding, a_edges], dim=1)

            edge_embedding, node_embedding, global_features = layer(node_embedding, l_index, edge_embedding, global_features)

        # transform the output into a matrix
        return self.transform_output_matrix(node_embedding, l_index, edge_embedding, a_edges)

    def transform_output_matrix(self, node_x, edge_index, edge_values, a_edges):
        # force diagonal to be positive
        diag = edge_index[0] == edge_index[1]

        # normalize diag such that it has zero residual
        if self.normalize_diag:
            # copy the diag of matrix A
            a_diag = a_edges[diag]

            # compute the row norm
            square_values = torch.pow(edge_values, 2)
            aggregated = self.diag_aggregate(square_values, edge_index[0])

            # now, we renormalize the edge values such that they are the square root of the original value...
            edge_values = torch.sqrt(a_diag[edge_index[0]]) * edge_values / torch.sqrt(aggregated[edge_index[0]])

        else:
            # otherwise, just take the edge values as they are...
            # but take the square root as it is numerically better
            # edge_values[diag] = torch.exp(edge_values[diag])
            edge_values[diag] = torch.sqrt(torch.exp(edge_values[diag]))

        # node decoder
        node_output = self.node_decoder(node_x).squeeze() if self.node_decoder is not None else None

        # ! this if should only be activated when the model is in production!!
        if torch.is_inference_mode_enabled():

            # we can decide to remove small elements during inference from the preconditioner matrix
            if self.tau != 0:
                small_value = (torch.abs(edge_values) <= self.tau).squeeze()

                # small value and not diagonal
                elems = torch.logical_and(small_value, torch.logical_not(diag))

                # might be able to do this easily!
                edge_values[elems] = 0

                # remove zeros from the sparse representation
                filt = (edge_values != 0).squeeze()
                edge_values = edge_values[filt]
                edge_index = edge_index[:, filt]

            # ! this is the way to go!!
            # Doing pytorch -> scipy -> numml is a lot faster than pytorch -> numml on CPU
            # On GPU it is faster to go to pytorch -> numml -> CPU

            # convert to scipy sparse matrix
            # m = to_scipy_sparse_matrix(edge_index, matrix_values)
            m = torch.sparse_coo_tensor(edge_index, edge_values.squeeze(),
                                        size=(node_x.size()[0], node_x.size()[0]))
                                        # type=torch.double)

            # produce L and U seperatly
            l = SparseCSRTensor(m)
            u = SparseCSRTensor(m.T)

            return l, u, node_output

        else:
            # For training and testing (computing regular losses for examples.)
            # does not need to be performance optimized!
            # use torch sparse directly
            t = torch.sparse_coo_tensor(edge_index, edge_values.squeeze(),
                                        size=(node_x.size()[0], node_x.size()[0]))

            # normalized l1 norm is best computed here!
            # l2_nn = torch.linalg.norm(edge_values, ord=2)
            l1_penalty = torch.sum(torch.abs(edge_values)) / len(edge_values)

            return t, l1_penalty, node_output


In [17]:
class MLP(nn.Module):
    def __init__(self, width, layer_norm=False, activation="relu", activate_final=False):
        super().__init__()
        width = list(filter(lambda x: x > 0, width))
        assert len(width) >= 2, "Need at least one layer in the network!"

        lls = nn.ModuleList()
        for k in range(len(width)-1):
            lls.append(nn.Linear(width[k], width[k+1], bias=True))
            if k != (len(width)-2) or activate_final:
                if activation == "relu":
                    lls.append(nn.ReLU())
                elif activation == "tanh":
                    lls.append(nn.Tanh())
                elif activation == "leakyrelu":
                    lls.append(nn.LeakyReLU())
                elif activation == "sigmoid":
                    lls.append(nn.Sigmoid())
                else:
                    raise NotImplementedError(f"Activation '{activation}' not implemented")

        if layer_norm:
            lls.append(nn.LayerNorm(width[-1]))

        self.m = nn.Sequential(*lls)

    def forward(self, x):
        return self.m(x)

class ToLowerTriangular(torch_geometric.transforms.BaseTransform):
    def __init__(self, inplace=False):
        self.inplace = inplace

    def __call__(self, data, order=None):
        if not self.inplace:
            data = data.clone()

        # TODO: if order is given use that one instead
        if order is not None:
            raise NotImplementedError("Custom ordering not yet implemented...")

        # transform the data into lower triag graph
        # this should be a data transformation (maybe?)
        rows, cols = data.edge_index[0], data.edge_index[1]
        fil = cols <= rows
        l_index = data.edge_index[:, fil]
        edge_embedding = data.edge_attr[fil]

        data.edge_index, data.edge_attr = l_index, edge_embedding
        return data

In [18]:
class MP_Block(nn.Module):
    # L@L.T matrix multiplication graph layer
    # Aligns the computation of L@L.T - A with the learned updates
    def __init__(self, skip_connections, first, last, edge_features, node_features, global_features, hidden_size, **kwargs) -> None:
        super().__init__()

        # first and second aggregation
        if "aggregate" in kwargs and kwargs["aggregate"] is not None:
            aggr = kwargs["aggregate"] if len(kwargs["aggregate"]) == 2 else kwargs["aggregate"] * 2
        else:
            aggr = ["mean", "sum"]

        act = kwargs["activation"] if "activation" in kwargs else "relu"

        edge_features_in = 1 if first else edge_features
        edge_features_out = 1 if last else edge_features

        # We use 2 graph nets in order to operate on the upper and lower triangular parts of the matrix
        self.l1 = GraphNet(node_features=node_features, edge_features=edge_features_in, global_features=global_features,
                           hidden_size=hidden_size, skip_connection=(not first and skip_connections),
                           aggregate=aggr[0], activation=act, edge_features_out=edge_features)

        self.l2 = GraphNet(node_features=node_features, edge_features=edge_features, global_features=global_features,
                           hidden_size=hidden_size, aggregate=aggr[1], activation=act, edge_features_out=edge_features_out)

    def forward(self, x, edge_index, edge_attr, global_features):
        edge_embedding, node_embeddings, global_features = self.l1(x, edge_index, edge_attr, g=global_features)

        # flip row and column indices
        edge_index = torch.stack([edge_index[1], edge_index[0]], dim=0)
        edge_embedding, node_embeddings, global_features = self.l2(node_embeddings, edge_index, edge_embedding, g=global_features)

        return edge_embedding, node_embeddings, global_features

# **VALIDATION**

In [None]:
class Preconditioner:
    def __init__(self, A, **kwargs):
        self.breakdown = False
        self.nnz = 0
        self.time = 0
        self.n = kwargs.get("n", 0)

    def timed_setup(self, A, **kwargs):
        start = time_function()
        self.setup(A, **kwargs)
        stop = time_function()
        self.time = stop - start

    def get_inverse(self):
        ones = torch.ones(self.n)
        offset = torch.zeros(1).to(torch.int64)

        I = torch.sparse.spdiags(ones, offset, (self.n, self.n))
        I = I.to(torch.float64)

        return I

    def get_p_matrix(self):
        return self.get_inverse()

    def check_breakdown(self, P):
        if np.isnan(np.min(P)):
            self.breakdown = True

    def __call__(self, x):
        return x

class LearnedPreconditioner(Preconditioner):
    def __init__(self, data, model, **kwargs):
        super().__init__(data, **kwargs)

        self.model = model
        self.spd = isinstance(model, NeuralIF)

        self.timed_setup(data, **kwargs)

        if self.spd:
            self.nnz = self.L.nnz
        else:
            self.nnz = self.L.nnz + self.U.nnz - data.x.shape[0]

    def setup(self, data, **kwargs):
        L, U, _ = self.model(data)

        self.L = L.to("cpu").to(torch.float64)
        self.U = U.to("cpu").to(torch.float64)

    def get_inverse(self):
        L_inv = torch.inverse(self.L.to_dense())
        U_inv = torch.inverse(self.U.to_dense())

        return U_inv@L_inv

    def get_p_matrix(self):
        return self.L@self.U

    def __call__(self, x):
        return fb_solve(self.L, self.U, x, unit_upper=not self.spd)


def fb_solve(L, U, r, unit_lower=False, unit_upper=False):
    y = L.solve_triangular(upper=False, unit=unit_lower, b=r)
    z = U.solve_triangular(upper=True, unit=unit_upper, b=y)
    return z


def fb_solve_joint(LU, r):
    # Note: solve triangular ignores the values in lower/upper triangle
    y = LU.solve_triangular(upper=False, unit=False, b=r)
    z = LU.solve_triangular(upper=True, unit=False, b=y)
    return z

time_function = lambda: time.perf_counter()

In [21]:
@torch.no_grad()
def validate(model, validation_loader, solve=False, solver="cg", **kwargs):
    model.eval()

    acc_loss = 0.0
    num_loss = 0
    acc_solver_iters = 0.0

    for i, data in enumerate(validation_loader):
        data = data.to(device)

        # construct problem data
        A, b = graph_to_matrix(data)

        # run conjugate gradient method
        # this requires the learned preconditioner to be reasonably good!
        if solve:
            # run CG on CPU
            with torch.inference_mode():
                preconditioner = LearnedPreconditioner(data, model)

            A = A.to("cpu").to(torch.float64)
            b = b.to("cpu").to(torch.float64)
            x_init = None

            solver_start = time.time()

            if solver == "cg":
                l, x_hat = preconditioned_conjugate_gradient(A.to("cpu"), b.to("cpu"), M=preconditioner,
                                                             x0=x_init, rtol=1e-6, max_iter=1_000)
                
                # l, x_hat = preconditioned_conjugate_gradient(A.to("cpu"), b.to("cpu"), M=None,
                #                                              x0=x_init, rtol=1e-6, max_iter=1_000)
            elif solver == "gmres":
                l, x_hat = gmres(A, b, M=preconditioner, x0=x_init, atol=1e-6, max_iter=1_000, left=False)
            else:
                raise NotImplementedError("Solver not implemented choose between CG and GMRES!")

            solver_stop = time.time()

            # Measure preconditioning performance
            solver_time = (solver_stop - solver_start)
            acc_solver_iters += len(l) - 1

        else:
            output, _, _ = model(data)

            # Here, we compute the loss using the full forbenius norm (no estimator)
            # l = frobenius_loss(output, A)

            l = loss(data, output, config=None) ##USe the default loss instead of the frobenius loss

            acc_loss += l.item()
            num_loss += 1

    if solve:
        # print(f"Smallest eigenvalue: {dist[0]}")
        print(f"Validation\t iterations:\t{acc_solver_iters / len(validation_loader):.2f}")
        return acc_solver_iters / len(validation_loader)

    else:
        print(f"Validation loss:\t{acc_loss / num_loss:.2f}")
        return acc_loss / len(validation_loader)


# **MODEL TRAINING**

## **BASELINE NEURALIF**

In [None]:
from torch_geometric.loader import DataLoader
from tqdm import tqdm
if config["save"]:
        os.makedirs(folder, exist_ok=True)
        save_dict_to_file(config, os.path.join(folder, "config.json"))

# global seed-ish
torch_geometric.seed_everything(config["seed"])

# args for the model
model_args = {k: config[k] for k in ["latent_size", "message_passing_steps", "skip_connections",
                                      "augment_nodes", "global_features", "decode_nodes",
                                      "normalize_diag", "activation", "aggregate", "graph_norm",
                                      "two_hop", "edge_features", "normalize"]
              if k in config}

# run the GMRES algorithm instead of CG (?)
gmres = False

# Create model

if config["model"] == "nif" or config["model"] == "neuralif" or config["model"] == "inf":
    model = NeuralIF(**model_args)


else:
    raise NotImplementedError

model.to(device)

print(f"Number params in model: {count_parameters(model)}")
print()

optimizer = torch.optim.AdamW(model.parameters())
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=20)

# Setup datasets
train_loader = get_dataloader(config["dataset"], config["n"], config["batch_size"],
                              spd=not gmres, mode="train")

validation_loader = get_dataloader(config["dataset"], config["n"], 1, spd=(not gmres), mode="val")

best_val = float("inf")
logger = TrainResults(folder)


total_it = 0

# Train loop
for epoch in range(config["num_epochs"]):
    running_loss = 0.0
    grad_norm = 0.0

    start_epoch = time.perf_counter()

    for it, data in tqdm(enumerate(train_loader)):
        # increase iteration count
        total_it += 1

        # enable training mode
        model.train()

        start = time.perf_counter()
        data = data.to(device)

        output, reg, _ = model(data)
        l = loss(output, data, c=reg, config=config["loss"])

        #  if reg:
        #    l = l + config["regularizer"] * reg

        l.backward()
        running_loss += l.item()

        # track the gradient norm
        if "gradient_clipping" in config and config["gradient_clipping"]:
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config["gradient_clipping"])

        else:
            total_norm = 0.0

            for p in model.parameters():
                if p.grad is not None:
                    param_norm = p.grad.detach().data.norm(2)
                    total_norm += param_norm.item() ** 2

            grad_norm = total_norm ** 0.5 / config["batch_size"]

        # update network parameters
        optimizer.step()
        optimizer.zero_grad()

        logger.log(l.item(), grad_norm, time.perf_counter() - start)

        # Do validation after 100 updates (to support big datasets)
        # convergence is expected to be pretty fast...
        if (total_it + 1) % 1000 == 0:

            # start with cg-checks after 5 iterations
            val_its = validate(model, validation_loader, solve=True,
                                solver="gmres" if gmres else "cg")

            # use scheduler
            # if config["scheduler"]:
            #    scheduler.step(val_loss)

            logger.log_val(None, val_its)

            # val_perf = val_cgits if val_cgits > 0 else val_loss
            val_perf = val_its
            print(val_perf)

            if val_perf < best_val:
                if config["save"]:
                    torch.save(model.state_dict(), f"{folder}/best_model.pt")
                best_val = val_perf

    epoch_time = time.perf_counter() - start_epoch

    # save model every epoch for analysis...
    if config["save"]:
        torch.save(model.state_dict(), f"{folder}/model_epoch{epoch+1}.pt")

    print(f"Epoch {epoch+1} \t loss: {1/len(train_loader) * running_loss} \t time: {epoch_time}")

# save fully trained model
if config["save"]:
    logger.save_results()
    torch.save(model.to(torch.float).state_dict(), f"{folder}/final_model.pt")

# Test the model
# wandb.run.summary["validation_chol"] = best_val
print("Best validation loss:", best_val)

Number params in model: 460



136it [00:10, 12.68it/s]


Epoch 1 	 loss: 1097.0139586504768 	 time: 10.731323720000091


136it [00:10, 12.39it/s]


Epoch 2 	 loss: 471.38088181439565 	 time: 10.981040607999603


136it [00:10, 12.85it/s]


Epoch 3 	 loss: 431.88145850686465 	 time: 10.583723503000328


136it [00:10, 12.62it/s]


Epoch 4 	 loss: 419.03057457419004 	 time: 10.782060632000139


136it [00:10, 12.51it/s]


Epoch 5 	 loss: 407.4191353741814 	 time: 10.870200481999746


136it [00:10, 13.30it/s]


Epoch 6 	 loss: 393.6336337818819 	 time: 10.225598386000001


136it [00:10, 12.96it/s]


Epoch 7 	 loss: 384.4011674768784 	 time: 10.494448142999772


48it [02:06, 17.32s/it]

Validation	 iterations:	276.50
276.5


136it [02:13,  1.02it/s]


Epoch 8 	 loss: 376.21987129660215 	 time: 133.48633064900014


136it [00:11, 12.03it/s]


Epoch 9 	 loss: 370.46222080903897 	 time: 11.311065273999702


136it [00:11, 12.34it/s]


Epoch 10 	 loss: 367.51837360157685 	 time: 11.021956188999866


136it [00:11, 12.23it/s]


Epoch 11 	 loss: 365.1175483254825 	 time: 11.124394987999949


136it [00:10, 12.50it/s]


Epoch 12 	 loss: 362.9482177285587 	 time: 10.883757287000208


136it [00:10, 12.72it/s]


Epoch 13 	 loss: 361.38887382956113 	 time: 10.693607877999966


136it [00:10, 12.65it/s]


Epoch 14 	 loss: 360.9041445115033 	 time: 10.750646944000437


96it [02:05, 16.72s/it]

Validation	 iterations:	268.40
268.4


136it [02:08,  1.06it/s]


Epoch 15 	 loss: 359.74800760605757 	 time: 128.8257134720002


136it [00:11, 12.09it/s]


Epoch 16 	 loss: 357.1126677569221 	 time: 11.249121930000001


136it [00:11, 12.21it/s]


Epoch 17 	 loss: 356.74992280847886 	 time: 11.139036469000075


136it [00:10, 12.97it/s]


Epoch 18 	 loss: 355.80646829044116 	 time: 10.483333674999813


136it [00:10, 12.45it/s]


Epoch 19 	 loss: 356.0829209720387 	 time: 10.928011820999927


136it [00:10, 12.74it/s]


Epoch 20 	 loss: 354.3644451814539 	 time: 10.673291319000327


136it [00:10, 12.69it/s]


Epoch 21 	 loss: 354.01501554601333 	 time: 10.717812395000237


136it [00:10, 12.78it/s]


Epoch 22 	 loss: 353.87880661908315 	 time: 10.642187573999763


6it [00:32,  5.38s/it]


KeyboardInterrupt: 

## **CLUSTER-GCN**

In [22]:
# model setup

from tqdm import tqdm
if config["save"]:
        os.makedirs(folder, exist_ok=True)
        save_dict_to_file(config, os.path.join(folder, "config.json"))

# global seed-ish
torch_geometric.seed_everything(config["seed"])

# args for the model
model_args = {k: config[k] for k in ["latent_size", "message_passing_steps", "skip_connections",
                                      "augment_nodes", "global_features", "decode_nodes",
                                      "normalize_diag", "activation", "aggregate", "graph_norm",
                                      "two_hop", "edge_features", "normalize"]
              if k in config}

# run the GMRES algorithm instead of CG (?)
gmres = False

# Create model

if config["model"] == "nif" or config["model"] == "neuralif" or config["model"] == "inf":
    model = NeuralIF(**model_args)

else:
    raise NotImplementedError

model.to(device)

print(f"Number params in model: {count_parameters(model)}")

optimizer = torch.optim.AdamW(model.parameters())
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=20)

best_val = float("inf")
logger = TrainResults(folder)
total_it = 0

Number params in model: 460


In [23]:
#PARTITION
# running time for partitioning: ~ 5-10 mins for 100 graphs

load_from_cache = False  # set to false if no existing graph paritions cache

# hyperparams manually written to match cache file (if used)
n = 10000
num_clusters = 20

if load_from_cache: # load using params
    train_loader = get_cluster_dataloader(
        config["dataset"], n, config["batch_size"],
        mode="train", num_clusters=num_clusters,
        clusters_per_batch=config["clusters_per_batch"],
        load_from_cache=load_from_cache
    )

else:  # run full graph parititioning
    train_loader = get_cluster_dataloader(
        config["dataset"], config["n"], config["batch_size"],
        mode="train", num_clusters=config["num_clusters"],
        clusters_per_batch=config["clusters_per_batch"],
        load_from_cache=load_from_cache
    )

validation_loader = get_dataloader(config["dataset"], config["n"], 1, spd=(not gmres), mode="val")

Loading dataset [train]…
  → 100 graphs
Preprocessing: Computing graph partitions for 100 graphs...


10it [00:50,  5.00s/it]

  Partitioned 10/100 graphs


20it [01:39,  5.04s/it]

  Partitioned 20/100 graphs


30it [02:30,  5.05s/it]

  Partitioned 30/100 graphs


40it [03:21,  5.06s/it]

  Partitioned 40/100 graphs


50it [04:11,  5.00s/it]

  Partitioned 50/100 graphs


60it [05:01,  5.02s/it]

  Partitioned 60/100 graphs


70it [05:52,  5.11s/it]

  Partitioned 70/100 graphs


80it [06:41,  5.08s/it]

  Partitioned 80/100 graphs


90it [07:33,  5.24s/it]

  Partitioned 90/100 graphs


100it [08:23,  5.04s/it]

  Partitioned 100/100 graphs
Preprocessing completed in 503.67 seconds
Average partitioning time: 5036.70ms per graph
✔ Cached partitions to ./partition_cache/train_n10000_k10_partitions.pkl
DataLoader ready!





In [26]:
#TRAIN

# Train loop (similar to your existing code)
for epoch in range(config["num_epochs"]):
    running_loss = 0.0
    grad_norm = 0.0
    start_epoch = time.perf_counter()

    for it, data in tqdm(enumerate(train_loader)):
        total_it += 1
        model.train()

        start = time.perf_counter()
        data = data.to(device)

        # Your existing training logic
        output, reg, _ = model(data)
        l = loss(output, data, c=reg, config=config["loss"])

        l.backward()
        running_loss += l.item()

        # Gradient handling (your existing code)
        if "gradient_clipping" in config and config["gradient_clipping"]:
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config["gradient_clipping"])
        else:
            total_norm = 0.0
            for p in model.parameters():
                if p.grad is not None:
                    param_norm = p.grad.detach().data.norm(2)
                    total_norm += param_norm.item() ** 2
            grad_norm = total_norm ** 0.5 / config["batch_size"]

        optimizer.step()
        optimizer.zero_grad()

        logger.log(l.item(), grad_norm, time.perf_counter() - start)

        # Validation (your existing logic)
        if (total_it) % 1000 == 0:
            val_its = validate(model, validation_loader, solve=True, solver="cg")
            logger.log_val(None, val_its)

            if val_its < best_val:
                if config["save"]:
                    torch.save(model.state_dict(), f"{folder}/best_model.pt")
                best_val = val_its

    epoch_time = time.perf_counter() - start_epoch

    if config["save"]:
        torch.save(model.state_dict(), f"{folder}/model_epoch{epoch+1}.pt")

    print(f"Epoch {epoch+1} \t loss: {1/len(train_loader) * running_loss} \t time: {epoch_time}")

# Final save
if config["save"]:
    logger.save_results()
    torch.save(model.to(torch.float).state_dict(), f"{folder}/final_model.pt")

print("Best validation performance:", best_val)

100it [00:10,  9.45it/s]


Epoch 1 	 loss: 511.6380316162109 	 time: 10.578984948000652


100it [00:09, 10.64it/s]


Epoch 2 	 loss: 246.5203454589844 	 time: 9.402286567999909


100it [00:09, 10.92it/s]


Epoch 3 	 loss: 137.14080520629884 	 time: 9.1661058589998


100it [00:09, 10.68it/s]


Epoch 4 	 loss: 126.32836334228516 	 time: 9.363024966998637


100it [00:09, 10.01it/s]


Epoch 5 	 loss: 121.8111898803711 	 time: 9.989059386000008


100it [00:09, 11.08it/s]


Epoch 6 	 loss: 116.57205749511719 	 time: 9.02541889999884


100it [00:09, 10.24it/s]


Epoch 7 	 loss: 113.5141293334961 	 time: 9.769512048000252


100it [00:09, 10.59it/s]


Epoch 8 	 loss: 110.85219772338867 	 time: 9.441301711000051


100it [00:09, 10.72it/s]


Epoch 9 	 loss: 107.37228118896485 	 time: 9.33072001599976


100it [04:58,  2.98s/it]


Validation	 iterations:	702.40
Epoch 10 	 loss: 104.30412284851074 	 time: 298.16499000099975


100it [00:11,  8.75it/s]


Epoch 11 	 loss: 101.62905364990235 	 time: 11.434606515000269


100it [00:11,  9.07it/s]


Epoch 12 	 loss: 100.26534980773926 	 time: 11.020935622000252


100it [00:11,  9.08it/s]


Epoch 13 	 loss: 98.43288116455078 	 time: 11.011818689999927


100it [00:11,  8.68it/s]


Epoch 14 	 loss: 97.81258888244629 	 time: 11.525848602001133


100it [00:10,  9.21it/s]


Epoch 15 	 loss: 95.95898880004883 	 time: 10.86220885900002


100it [00:11,  9.09it/s]


Epoch 16 	 loss: 95.70913780212402 	 time: 11.002918280999438


100it [00:11,  8.84it/s]


Epoch 17 	 loss: 96.02757682800294 	 time: 11.316912579999553


100it [00:11,  8.41it/s]


Epoch 18 	 loss: 92.39644554138184 	 time: 11.892047715000444


100it [00:11,  8.57it/s]


Epoch 19 	 loss: 93.45857192993164 	 time: 11.674255378000453


100it [03:52,  2.33s/it]


Validation	 iterations:	491.90
Epoch 20 	 loss: 91.650072555542 	 time: 232.7134814789988


100it [00:11,  8.82it/s]


Epoch 21 	 loss: 91.9603653717041 	 time: 11.343098290000853


100it [00:10,  9.68it/s]


Epoch 22 	 loss: 94.18167778015138 	 time: 10.330789814999662


49it [00:05,  9.12it/s]


KeyboardInterrupt: 

# **DEBUGGING**

In [None]:
##DEBUGGING
cluster_train_loader = get_cluster_dataloader(
    config["dataset"],
    config["n"],
    config["batch_size"],
    spd=True,
    mode="train",
    num_clusters=config["num_clusters"],
    clusters_per_batch=config["clusters_per_batch"]
)

train_loader = get_dataloader(config["dataset"], config["n"], config["batch_size"],
                              spd=not gmres, mode="train")

print(cluster_train_loader)
print(train_loader)


In [None]:
for i, data in enumerate(cluster_train_loader):
  if i == 1:
    print (data)
for i, data in enumerate(train_loader):
  if i == 1:
    print (data)

# **TESTING**

In [5]:
import argparse
import os
import datetime

import numpy as np
import scipy
import scipy.sparse
import torch
import json

from krylov.cg import conjugate_gradient, preconditioned_conjugate_gradient
from krylov.gmres import gmres
from krylov.preconditioner import get_preconditioner

from neuralif.models import NeuralIF, NeuralPCG, PreCondNet, LearnedLU
from neuralif.utils import torch_sparse_to_scipy, time_function
from neuralif.logger import TestResults

# from apps.data import matrix_to_graph_sparse, get_dataloader
from apps.data import matrix_to_graph_sparse

experiment_number = 6


@torch.inference_mode()
def test(model, test_loader, device, folder, save_results=False, dataset="random", solver="cg"):
    
    if save_results:
        os.makedirs(folder, exist_ok=False)

    print()
    print(f"Test:\t{len(test_loader.dataset)} samples")
    print(f"Solver:\t{solver} solver")
    print()
    
    # Two modes: either test baselines or the learned preconditioner
    if model is None:
        methods = ["baseline", "jacobi", "ilu"]
    else:
        assert solver in ["cg", "gmres"], "Data-driven method only works with CG or GMRES"
        methods = ["learned"]
    
    # using direct solver
    if solver == "direct":
        methods = ["direct"]
    
    for method in methods:
        print(f"Testing {method} preconditioner")
        
        test_results = TestResults(method, dataset, folder,
                                   model_name= f"\n{model.__class__.__name__}" if method == "learned" else "",
                                   target=1e-6,
                                   solver=solver)
        
        for sample, data in enumerate(test_loader):
            plot = save_results and sample == (len(test_loader.dataset) - 1)
            
            # Getting the preconditioners
            start = time_function()
            
            data = data.to(device)
            prec = get_preconditioner(data, method, model=model)
            
            # Get properties...
            p_time = prec.time
            breakdown = prec.breakdown
            nnzL = prec.nnz
            
            stop = time_function()
            
            A = torch.sparse_coo_tensor(data.edge_index, data.edge_attr.squeeze(),
                                        dtype=torch.float64,
                                        requires_grad=False).to("cpu").to_sparse_csr()
            
            b = data.x[:, 0].squeeze().to("cpu").to(torch.float64)
            b_norm = torch.linalg.norm(b)
            
            # we assume that b is unit norm wlog
            b = b / b_norm
            solution = data.s.to("cpu").to(torch.float64).squeeze() / b_norm if hasattr(data, "s") else None
            
            overhead = (stop - start) - (p_time)
            
            # RUN CONJUGATE GRADIENT
            start_solver = time_function()
            
            solver_settings = {
                "max_iter": 10_000,
                "x0": None
            }
            
            if breakdown:
                res = []
            
            elif solver == "direct":
                
                # convert to sparse matrix (scipy)
                A_ = torch.sparse_coo_tensor(data.edge_index, data.edge_attr.squeeze(),
                                             dtype=torch.float64, requires_grad=False)
                
                # scipy sparse...
                A_s = torch_sparse_to_scipy(A_).tocsr()
                
                # override start time
                start_solver = time_function()
                
                dense = False
                
                if dense:
                    _ = scipy.linalg.solve(A_.to_dense().numpy(), b.numpy(), assume_a='pos')
                else:
                    _ = scipy.sparse.linalg.spsolve(A_s, b.numpy())
                
                # dummy values...
                res = [(torch.Tensor([0]), torch.Tensor([0]))] * 2
            
            elif solver == "cg" and method == "baseline":
                # no preconditioner required when using baseline method
                res, _ = conjugate_gradient(A, b, x_true=solution,
                                            rtol=test_results.target, **solver_settings)
            
            elif solver == "cg":
                res, _ = preconditioned_conjugate_gradient(A, b, M=None, x_true=solution,
                                                           rtol=test_results.target, **solver_settings)
                
            elif solver == "gmres":
                
                res, _ = gmres(A, b, M=prec, x_true=solution,
                               **solver_settings, plot=plot,
                               atol=test_results.target,
                               left=False)
            
            stop_solver = time_function()
            solver_time = (stop_solver - start_solver)
            
            # LOGGING
            test_results.log_solve(A.shape[0], solver_time, len(res) - 1,
                                   np.array([r[0].item() for r in res]),
                                   np.array([r[1].item() for r in res]),
                                   p_time, overhead)
            
            # ANALYSIS of the preconditioner and its effects!
            nnzA = A._nnz()
            
            test_results.log(nnzA, nnzL, plot=plot)
            
            svd = False
            if svd:
                # compute largest and smallest singular value
                Pinv = prec.get_inverse()
                APinv = A.to_dense() @ Pinv
                
                # compute the singular values of the preconditioned matrix
                S = torch.linalg.svdvals(APinv)
                
                # print the smallest and largest singular value
                test_results.log_eigenval_dist(S, plot=plot)
                
                # compute the loss of the preconditioner
                p = prec.get_p_matrix()
                loss1 = torch.linalg.norm(p.to_dense() - A.to_dense(), ord="fro")
                
                a_inv = torch.linalg.inv(A.to_dense())
                loss2 = torch.linalg.norm(p.to_dense()@a_inv - torch.eye(a_inv.shape[0]), ord="fro")
                
                test_results.log_loss(loss1, loss2, plot=False)
                
                print(f"Smallest singular value: {S[-1]} | Largest singular value: {S[0]} | Condition number: {S[0] / S[-1]}")
                print(f"Loss Lmax: {loss1}\tLoss Lmin: {loss2}")
                print()
                
        if save_results:
            test_results.save_results()
        
        test_results.print_summary()


def load_checkpoint(model, args, device):
    # load the saved weights of the model and the hyper-parameters
    checkpoint = "latest"
    
    if checkpoint == "latest":
        # list all the directories in the results folder
        d = os.listdir("./results/")
        d.sort()
        
        config = None
        
        # find the latest checkpoint
        for i in range(len(d)):
            if os.path.isdir("./results/" + d[-i-1]):
                dir_contents = os.listdir("./results/" + d[-i-1])
                
                # looking for a directory with both config and model weights
                if "config.json" in dir_contents and "final_model.pt" in dir_contents:
                    # load the config.json file
                    with open("./results/" + d[-i-1] + "/config.json") as f:
                        config = json.load(f)
                        
                        # if config["model"] != args.model:
                        #     config = None
                        #     continue
                        #There is currently no best model yet
                        if "best_model.pt" in dir_contents:
                            checkpoint = "./results/" + d[-i-1] + "/best_model.pt"
                            break
                        else:
                            checkpoint = "./results/" + d[-i-1] + "/final_model.pt"
                            break
        if config is None:
            print("Checkpoint not found...")
        
        # neuralif has optional drop tolerance...
        # if args.model == "neuralif":
        #     config["drop_tol"] = args.drop_tol
        
        # intialize model and hyper-parameters
        model = model(**config)
        checkpoint= f"./results/experiment_{experiment_number}/best_model.pt"
        print(f"load checkpoint: {checkpoint}")
        
        model.load_state_dict(torch.load(checkpoint, weights_only=False, map_location=torch.device(device)))
    
    elif checkpoint is not None:
        with open(checkpoint + "/config.json") as f:
            config = json.load(f)
        
        # if args.model == "neuralif":
        #     config["drop_tol"] = args.drop_tol
        
        model = model(**config)
        print(f"load checkpoint: {checkpoint}")
        model.load_state_dict(torch.load(checkpoint + f"/{args.weights}.pt",
                                            map_location=torch.device(model.device)))
    
    else:
        model = model(**{"global_features": 0, "latent_size": 8, "augment_nodes": False,
                            "message_passing_steps": 3, "skip_connections": True, "activation": "relu",
                            "aggregate": None, "decode_nodes": False})
        
        print("No checkpoint provided, using random weights")
    
    return model


def warmup(model, device):
    # set testing parameters
    model.to(device)
    model.eval()
    
    # run model warmup
    test_size = 1_000
    matrix = scipy.sparse.coo_matrix((np.ones(test_size), (np.arange(test_size), np.arange(test_size))))
    data = matrix_to_graph_sparse(matrix, torch.ones(test_size))
    data.to(device)
    _ = model(data)
    
    print("Model warmup done...")


# # argument is the model to load and the dataset to evaluate on
# def argparser():
#     parser = argparse.ArgumentParser()

#     parser.add_argument("--name", type=str, default=None)
#     parser.add_argument("--device", type=int, required=False)
    
#     # select data driven model to run
#     parser.add_argument("--model", type=str, required=False, default="none")
#     parser.add_argument("--checkpoint", type=str, required=False)
#     parser.add_argument("--weights", type=str, required=False, default="model")
#     parser.add_argument("--drop_tol", type=float, default=0)
    
#     parser.add_argument("--solver", type=str, default="cg")
    
#     # select dataset and subset
#     parser.add_argument("--dataset", type=str, required=False, default="random")
#     parser.add_argument("--subset", type=str, required=False, default="test")
#     parser.add_argument("--n", type=int, required=False, default=0)
#     parser.add_argument("--samples", type=int, required=False, default=None)
    
#     # select if to save
#     parser.add_argument("--save", action='store_true', default=False)
    
#     return parser.parse_args()


def call():
    
    test_device = "cpu"
        
    folder = folder = "results/" + datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    
    print()
    print(f"Using device: {test_device}")
    # torch.set_num_threads(1)
    
    # Load the model
    if config["model"] == "nif" or config["model"] == "neuralif":
        print("Use model: NeuralIF")
        model = NeuralIF
    
    else:
        raise NotImplementedError(f"Model {config["model"]} not available.")
    
    if model is not None:
        model = load_checkpoint(model, None, test_device)
        warmup(model, test_device)
    
    # spd = config["solver"] == "cg" or config["solver"] == "direct"
    spd = True
    testdata_loader = get_dataloader(config["dataset"], n=config["n"], batch_size=1, mode="test",
                                     size=config["samples"], spd=spd, graph=True)
    
    # Evaluate the model
    test(model, testdata_loader, test_device, folder,
         save_results=config["save"], dataset=config["dataset"], solver=config["solver"])

call()


Using device: cpu
Use model: NeuralIF
Checkpoint not found...


TypeError: neuralif.models.NeuralIF() argument after ** must be a mapping, not NoneType