In [None]:
#Install dependencies
%pip install torch_geometric

In [2]:
%pip install torch torchvision torchaudio

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [None]:
!pip3 uninstall numml
!git clone https://github.com/nicknytko/numml.git
%cd numml/
!pip3 install .
%cd ../

In [3]:
!pip3 install ilupp

Collecting ilupp
  Downloading ilupp-1.0.2.tar.gz (155 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/155.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m155.4/155.4 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: ilupp
  Building wheel for ilupp (setup.py) ... [?25l[?25hdone
  Created wheel for ilupp: filename=ilupp-1.0.2-cp311-cp311-linux_x86_64.whl size=3237628 sha256=26328801784c1609940f6b29e2a4460202fc1be7b22caaa9d15a62ed7d5c80f8
  Stored in directory: /root/.cache/pip/wheels/f8/29/86/ee4ba827a16c7450e9750424fd645ce86fc0d958fadf3e9f9c
Successfully built ilupp
Installing collected packages: ilupp
Successfully installed ilupp-1.0.2


In [4]:
#Imports

import os
import datetime
import pprint
import time

import ilupp
import numpy as np
import scipy
import torch
import torch_geometric
import torch.nn as nn
import torch_geometric.nn as pyg
from torch_geometric.nn import aggr
from torch_geometric.utils import to_scipy_sparse_matrix
from torch_geometric.loader import NeighborLoader
from torch_geometric.data import Batch
from scipy.sparse import tril, coo_matrix


from apps.data import get_dataloader, FolderDataset
from apps.data import graph_to_matrix, matrix_to_graph

from neuralif.utils import (
    count_parameters, save_dict_to_file,
    condition_number, eigenval_distribution, gershgorin_norm,
    TwoHop
)
from neuralif.logger import TrainResults, TestResults
from neuralif.loss import loss

from krylov.cg import preconditioned_conjugate_gradient
from krylov.gmres import gmres


In [5]:
#Select GPU as device to run code
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(torch.__version__)

Using device: cuda:0
2.6.0+cu124


# **DATA GENERATION**
Data has been pregenerated and put into github to prevent long time taken to re-generate matrices. If another dataset is to be generated, use the following code ot generate


In [None]:
#Generate matrices (This function takes a long time. The dataset has been pregenerated and put into our github repo)
#from apps.synthetic import create_dataset

# mat_size = 10_000
# density = 10e-4

# create_dataset(mat_size, 1000, alpha=density, mode='train', rs=0, graph=True, solution=True)
# create_dataset(mat_size, 10, alpha=density, mode='val', rs=10000, graph=True, solution=True)
# create_dataset(mat_size, 100, alpha=density, mode='test', rs=103600, graph=True, solution=True)


# **GRAPH NETWORK**

In [6]:
class GraphNet(nn.Module):
    # Follows roughly the outline of torch_geometric.nn.MessagePassing()
    # As shown in https://github.com/deepmind/graph_nets
    # Here is a helpful python implementation:
    # https://github.com/NVIDIA/GraphQSat/blob/main/gqsat/models.py
    # Also allows multirgaph GNN via edge_2_features
    def __init__(self, node_features, edge_features, global_features=0, hidden_size=0,
                 aggregate="mean", activation="relu", skip_connection=False, edge_features_out=None):

        super().__init__()

        # different aggregation functions
        if aggregate == "sum":
            self.aggregate = aggr.SumAggregation()
        elif aggregate == "mean":
            self.aggregate = aggr.MeanAggregation()
        elif aggregate == "max":
            self.aggregate = aggr.MaxAggregation()
        elif aggregate == "softmax":
            self.aggregate = aggr.SoftmaxAggregation(learn=True)
        else:
            raise NotImplementedError(f"Aggregation '{aggregate}' not implemented")

        self.global_aggregate = aggr.MeanAggregation()

        add_edge_fs = 1 if skip_connection else 0
        edge_features_out = edge_features if edge_features_out is None else edge_features_out

        # Graph Net Blocks (see https://arxiv.org/pdf/1806.01261.pdf)
        self.edge_block = MLP([global_features + (edge_features + add_edge_fs) + (2 * node_features),
                               hidden_size,
                               edge_features_out],
                              activation=activation)

        self.node_block = MLP([global_features + edge_features_out + node_features,
                               hidden_size,
                               node_features],
                              activation=activation)

        # optional set of blocks for global GNN
        self.global_block = None
        if global_features > 0:
            self.global_block = MLP([edge_features_out + node_features + global_features,
                                     hidden_size,
                                     global_features],
                                    activation=activation)

    def forward(self, x, edge_index, edge_attr, g=None):
        row, col = edge_index

        if self.global_block is not None:
            assert g is not None, "Need global features for global block"

            # run the edge update and aggregate features
            edge_embedding = self.edge_block(torch.cat([torch.ones(x[row].shape[0], 1, device=x.device) * g,
                                                        x[row], x[col], edge_attr], dim=1))
            aggregation = self.aggregate(edge_embedding, row)


            agg_features = torch.cat([torch.ones(x.shape[0], 1, device=x.device) * g, x, aggregation], dim=1)
            node_embeddings = self.node_block(agg_features)

            # aggregate over all edges and nodes (always mean)
            mp_global_aggr = g
            edge_aggregation_global = self.global_aggregate(edge_embedding)
            node_aggregation_global = self.global_aggregate(node_embeddings)

            # compute the new global embedding
            # the old global feature is part of mp_global_aggr
            global_embeddings = self.global_block(torch.cat([node_aggregation_global,
                                                             edge_aggregation_global,
                                                             mp_global_aggr], dim=1))

            return edge_embedding, node_embeddings, global_embeddings

        else:
            # update edge features and aggregate
            edge_embedding = self.edge_block(torch.cat([x[row], x[col], edge_attr], dim=1))
            aggregation = self.aggregate(edge_embedding, row)
            agg_features = torch.cat([x, aggregation], dim=1)
            # update node features
            node_embeddings = self.node_block(agg_features)
            return edge_embedding, node_embeddings, None




In [7]:
class MLP(nn.Module):
    def __init__(self, width, layer_norm=False, activation="relu", activate_final=False):
        super().__init__()
        width = list(filter(lambda x: x > 0, width))
        assert len(width) >= 2, "Need at least one layer in the network!"

        lls = nn.ModuleList()
        for k in range(len(width)-1):
            lls.append(nn.Linear(width[k], width[k+1], bias=True))
            if k != (len(width)-2) or activate_final:
                if activation == "relu":
                    lls.append(nn.ReLU())
                elif activation == "tanh":
                    lls.append(nn.Tanh())
                elif activation == "leakyrelu":
                    lls.append(nn.LeakyReLU())
                elif activation == "sigmoid":
                    lls.append(nn.Sigmoid())
                else:
                    raise NotImplementedError(f"Activation '{activation}' not implemented")

        if layer_norm:
            lls.append(nn.LayerNorm(width[-1]))

        self.m = nn.Sequential(*lls)

    def forward(self, x):
        return self.m(x)



In [8]:
class MP_Block(nn.Module):
    # L@L.T matrix multiplication graph layer
    # Aligns the computation of L@L.T - A with the learned updates
    def __init__(self, skip_connections, first, last, edge_features, node_features, global_features, hidden_size, **kwargs) -> None:
        super().__init__()

        # first and second aggregation
        if "aggregate" in kwargs and kwargs["aggregate"] is not None:
            aggr = kwargs["aggregate"] if len(kwargs["aggregate"]) == 2 else kwargs["aggregate"] * 2
        else:
            aggr = ["mean", "sum"]

        act = kwargs["activation"] if "activation" in kwargs else "relu"

        edge_features_in = 1 if first else edge_features
        edge_features_out = 1 if last else edge_features

        # We use 2 graph nets in order to operate on the upper and lower triangular parts of the matrix
        self.l1 = GraphNet(node_features=node_features, edge_features=edge_features_in, global_features=global_features,
                           hidden_size=hidden_size, skip_connection=(not first and skip_connections),
                           aggregate=aggr[0], activation=act, edge_features_out=edge_features)

        self.l2 = GraphNet(node_features=node_features, edge_features=edge_features, global_features=global_features,
                           hidden_size=hidden_size, aggregate=aggr[1], activation=act, edge_features_out=edge_features_out)

    def forward(self, x, edge_index, edge_attr, global_features):
        edge_embedding, node_embeddings, global_features = self.l1(x, edge_index, edge_attr, g=global_features)

        # flip row and column indices
        edge_index = torch.stack([edge_index[1], edge_index[0]], dim=0)
        edge_embedding, node_embeddings, global_features = self.l2(node_embeddings, edge_index, edge_embedding, g=global_features)

        return edge_embedding, node_embeddings, global_features



# **HELPER FUNCTIONS**

In [9]:
def augment_features(data, skip_rhs=False):
    # transform nodes to include more features

    if skip_rhs:
        # use instead notde position as an input feature!
        data.x = torch.arange(data.x.size()[0], device=data.x.device).unsqueeze(1)

    data = torch_geometric.transforms.LocalDegreeProfile()(data)

    # diagonal dominance and diagonal decay from the paper
    row, col = data.edge_index
    diag = (row == col)
    diag_elem = torch.abs(data.edge_attr[diag])
    # remove diagonal elements by setting them to zero
    non_diag_elem = data.edge_attr.clone()
    non_diag_elem[diag] = 0

    row_sums = aggr.SumAggregation()(torch.abs(non_diag_elem), row)
    alpha = diag_elem / row_sums
    row_dominance_feature = alpha / (alpha + 1)
    row_dominance_feature = torch.nan_to_num(row_dominance_feature, nan=1.0)

    # compute diagonal decay features
    row_max = aggr.MaxAggregation()(torch.abs(non_diag_elem), row)
    alpha = diag_elem / row_max
    row_decay_feature = alpha / (alpha + 1)
    row_decay_feature = torch.nan_to_num(row_decay_feature, nan=1.0)

    data.x = torch.cat([data.x, row_dominance_feature, row_decay_feature], dim=1)

    return data


class ToLowerTriangular(torch_geometric.transforms.BaseTransform):
    def __init__(self, inplace=False):
        self.inplace = inplace

    def __call__(self, data, order=None):
        if not self.inplace:
            data = data.clone()

        # TODO: if order is given use that one instead
        if order is not None:
            raise NotImplementedError("Custom ordering not yet implemented...")

        # transform the data into lower triag graph
        # this should be a data transformation (maybe?)
        rows, cols = data.edge_index[0], data.edge_index[1]
        fil = cols <= rows
        l_index = data.edge_index[:, fil]
        edge_embedding = data.edge_attr[fil]

        data.edge_index, data.edge_attr = l_index, edge_embedding
        return data




In [10]:
@torch.no_grad()
def validate(model, validation_loader, solve=False, solver="cg"):
    model.eval()
    acc_loss = 0.0
    num_loss = 0
    acc_solver_iters = 0.0

    for data in validation_loader:
        data = data.to(device)
        A, b = graph_to_matrix(data)

        if solve:
            preconditioner = LearnedPreconditioner(data, model)
            print(preconditioner)
            A_cpu = A.cpu().double()
            b_cpu = b.cpu().double()
            x0 = None

            start = time.time()
            if solver == "cg":
                print("Code reaches here")
                iters, x_hat = preconditioned_conjugate_gradient(
                    A_cpu, b_cpu, M=None, x0=x0,
                    rtol=1e-6, max_iter=1000
                )
            else:
                iters, x_hat = gmres(
                    A_cpu, b_cpu, M=preconditioner, x0=x0,
                    atol=1e-6, max_iter=1000, left=False
                )
            acc_solver_iters += len(iters) - 1
        else:
            output, _, _ = model(data)
            # l = frobenius_loss(output, A)
            l = loss(data, output, config="frobenius")
            acc_loss += l.item()
            num_loss += 1

    if solve:
        print("BLABLABLA")
        avg_iters = acc_solver_iters / len(validation_loader)
        print(f"Validation iterations: {avg_iters:.2f}")
        return avg_iters
    else:
        avg_loss = acc_loss / num_loss
        print(f"Validation loss: {avg_loss:.4f}")
        return avg_loss



# **MODEL CONFIGURATION**

In [11]:
config = {
    "name": "experiment_1",
    "save": True,
    "seed": 42,
    "n": 10000,
    "batch_size": 1,
    "num_epochs": 100,
    "dataset": "random",
    "loss": None,
    "gradient_clipping": 1.0,
    "regularizer": 0.0,
    "scheduler": False,
    "model": "neuralif",
    "normalize": False,
    "latent_size": 8,
    "message_passing_steps": 3,
    "decode_nodes": False,
    "normalize_diag": False,
    "aggregate": ["mean", "sum"],
    "activation": "relu",
    "skip_connections": True,
    "augment_nodes": False,
    "global_features": 0,
    "edge_features": 1,
    "graph_norm": False,
    "two_hop": False,
    "num_neighbors": [15, 10]  # number of neighbours to sample in each hop (GraphSAGE sampling)
}

# Prepare output folder
if config["name"]:
    folder = f"results/{config['name']}"
else:
    folder = datetime.datetime.now().strftime("results/%Y-%m-%d_%H-%M-%S")
if config["save"]:
    os.makedirs(folder, exist_ok=True)
    save_dict_to_file(config, os.path.join(folder, "config.json"))

In [12]:
#NeuralIF imported manually here to avoid using numml

class NeuralIF(nn.Module):
    # Neural Incomplete factorization
    def __init__(self, drop_tol=0, **kwargs) -> None:
        super().__init__()

        self.global_features = kwargs["global_features"]
        self.latent_size = kwargs["latent_size"]
        # node features are augmented with local degree profile
        self.augment_node_features = kwargs["augment_nodes"]

        num_node_features = 8 if self.augment_node_features else 1
        message_passing_steps = kwargs["message_passing_steps"]

        # edge feature representation in the latent layers
        edge_features = kwargs.get("edge_features", 1)

        self.skip_connections = kwargs["skip_connections"]

        self.mps = torch.nn.ModuleList()
        for l in range(message_passing_steps):
            # skip connections are added to all layers except the first one
            self.mps.append(MP_Block(skip_connections=self.skip_connections,
                                     first=l==0,
                                     last=l==(message_passing_steps-1),
                                     edge_features=edge_features,
                                     node_features=num_node_features,
                                     global_features=self.global_features,
                                     hidden_size=self.latent_size,
                                     activation=kwargs["activation"],
                                     aggregate=kwargs["aggregate"]))

        # node decodings
        self.node_decoder = MLP([num_node_features, self.latent_size, 1]) if kwargs["decode_nodes"] else None

        # diag-aggregation for normalization of rows
        self.normalize_diag = kwargs["normalize_diag"] if "normalize_diag" in kwargs else False
        self.diag_aggregate = aggr.SumAggregation()

        # normalization
        self.graph_norm = pyg.norm.GraphNorm(num_node_features) if ("graph_norm" in kwargs and kwargs["graph_norm"]) else None

        # drop tolerance and additional fill-ins and more sparsity
        self.tau = drop_tol
        self.two = kwargs.get("two_hop", False)

    def forward(self, data):
        # ! data could be batched here...(not implemented)

        if self.augment_node_features:
            data = augment_features(data, skip_rhs=True)

        # add additional edges to the data
        if self.two:
            data = TwoHop()(data)

        # * in principle it is possible to integrate reordering here.

        data = ToLowerTriangular()(data)

        # get the input data
        edge_embedding = data.edge_attr
        l_index = data.edge_index

        if self.graph_norm is not None:
            node_embedding = self.graph_norm(data.x, batch=data.batch)
        else:
            node_embedding = data.x

        # copy the input data (only edges of original matrix A)
        a_edges = edge_embedding.clone()

        if self.global_features > 0:
            global_features = torch.zeros((1, self.global_features), device=data.x.device, requires_grad=False)
            # feature ideas: nnz, 1-norm, inf-norm col/row var, min/max variability, avg distances to nnz
        else:
            global_features = None

        # compute the output of the network
        for i, layer in enumerate(self.mps):
            if i != 0 and self.skip_connections:
                edge_embedding = torch.cat([edge_embedding, a_edges], dim=1)

            edge_embedding, node_embedding, global_features = layer(node_embedding, l_index, edge_embedding, global_features)

        # transform the output into a matrix
        return self.transform_output_matrix(node_embedding, l_index, edge_embedding, a_edges)

    def transform_output_matrix(self, node_x, edge_index, edge_values, a_edges):
        # force diagonal to be positive
        diag = edge_index[0] == edge_index[1]

        # normalize diag such that it has zero residual
        if self.normalize_diag:
            # copy the diag of matrix A
            a_diag = a_edges[diag]

            # compute the row norm
            square_values = torch.pow(edge_values, 2)
            aggregated = self.diag_aggregate(square_values, edge_index[0])

            # now, we renormalize the edge values such that they are the square root of the original value...
            edge_values = torch.sqrt(a_diag[edge_index[0]]) * edge_values / torch.sqrt(aggregated[edge_index[0]])

        else:
            # otherwise, just take the edge values as they are...
            # but take the square root as it is numerically better
            # edge_values[diag] = torch.exp(edge_values[diag])
            edge_values[diag] = torch.sqrt(torch.exp(edge_values[diag]))

        # node decoder
        node_output = self.node_decoder(node_x).squeeze() if self.node_decoder is not None else None

        # ! this if should only be activated when the model is in production!!
        if torch.is_inference_mode_enabled():

            # we can decide to remove small elements during inference from the preconditioner matrix
            if self.tau != 0:
                small_value = (torch.abs(edge_values) <= self.tau).squeeze()

                # small value and not diagonal
                elems = torch.logical_and(small_value, torch.logical_not(diag))

                # might be able to do this easily!
                edge_values[elems] = 0

                # remove zeros from the sparse representation
                filt = (edge_values != 0).squeeze()
                edge_values = edge_values[filt]
                edge_index = edge_index[:, filt]

            # ! this is the way to go!!
            # Doing pytorch -> scipy -> numml is a lot faster than pytorch -> numml on CPU
            # On GPU it is faster to go to pytorch -> numml -> CPU

            # convert to scipy sparse matrix
            # m = to_scipy_sparse_matrix(edge_index, matrix_values)
            m = torch.sparse_coo_tensor(edge_index, edge_values.squeeze(),
                                        size=(node_x.size()[0], node_x.size()[0]))
                                        # type=torch.double)

            # produce L and U seperatly
            l = sp.SparseCSRTensor(m)
            u = sp.SparseCSRTensor(m.T)

            return l, u, node_output

        else:
            # For training and testing (computing regular losses for examples.)
            # does not need to be performance optimized!
            # use torch sparse directly
            t = torch.sparse_coo_tensor(edge_index, edge_values.squeeze(),
                                        size=(node_x.size()[0], node_x.size()[0]))

            # normalized l1 norm is best computed here!
            # l2_nn = torch.linalg.norm(edge_values, ord=2)
            l1_penalty = torch.sum(torch.abs(edge_values)) / len(edge_values)

            return t, l1_penalty, node_output



# **SETUP**
Model instantiation

In [13]:
# Seed for reproducibility
torch_geometric.seed_everything(config["seed"])

# Select model
model_args = {k: config[k] for k in [
    "latent_size", "message_passing_steps", "skip_connections",
    "augment_nodes", "global_features", "decode_nodes",
    "normalize_diag", "activation", "aggregate", "graph_norm",
    "two_hop", "edge_features", "normalize"
] if k in config}

use_gmres = False
if config["model"] in ("nif", "neuralif", "inf"):
    model = NeuralIF(**model_args) ##Assign model to be NeuralIF
else:
    raise ValueError("Unknown model type")

model.to(device)
print("Number of parameters:", count_parameters(model))

optimizer = torch.optim.AdamW(model.parameters())
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=20
)

Number of parameters: 460


# DATALOADER **CLASS**

In [14]:
from glob import glob
import numpy as np
import torch
from scipy.sparse import coo_matrix
from torch_geometric.data import Data

from torch_geometric.loader import DataLoader


def matrix_to_graph_sparse(A, b):
    edge_index = torch.tensor(list(map(lambda x: [x[0], x[1]], zip(A.row, A.col))), dtype=torch.long)
    edge_features = torch.tensor(list(map(lambda x: [x], A.data)), dtype=torch.float)
    node_features = torch.tensor(list(map(lambda x: [x], b)), dtype=torch.float)

    # diag_elements = edge_index[:, 0] == edge_index[:, 1]
    # node_features = edge_features[diag_elements]
    # node_features = torch.cat((node_features, torch.tensor(list(map(lambda x: [x], b)), dtype=torch.float)), dim=1)

    # Embed the information into data object
    data = Data(x=node_features, edge_index=edge_index.t().contiguous(), edge_attr=edge_features)
    return data


def matrix_to_graph(A, b):
    return matrix_to_graph_sparse(coo_matrix(A), b)


def graph_to_matrix(data, normalize=False):
    A = torch.sparse_coo_tensor(data.edge_index, data.edge_attr[:, 0].squeeze(), requires_grad=False)
    b = data.x[:, 0].squeeze()

    if normalize:
        b = b / torch.linalg.norm(b)

    return A, b


def get_dataloader(dataset, n=0, batch_size=1, spd=True, mode="train", size=None, graph=True):
    # Setup datasets

    if dataset == "random":
        print("TEST")
        data = FolderDataset(f"./dataset/train/", n, size=size, graph=graph) #Note that for our dataset, n must be 10000??
        validation = FolderDataset(f"./dataset/validate/", n, size=size, graph=graph)
    else:
        raise NotImplementedError("Dataset not implemented, Available: random")

    # Data Loaders
    if mode == "train":
        dataloader = DataLoader(data, batch_size=batch_size, shuffle=True)
    else:
        dataloader = DataLoader(validation, batch_size=1, shuffle=False)

    return dataloader


class FolderDataset(torch.utils.data.Dataset):
    def __init__(self, folder, n, graph=True, size=None) -> None:
        super().__init__()

        self.graph = True
        assert self.graph, "Graph keyword is depracated, only graph=True is supported."

        if n != 0:
            if self.graph:
                self.files = list(filter(lambda x: x.split("/")[-1].split('_')[0] == str(n), glob(folder+'*.pt')))
            else:
                self.files = list(filter(lambda x: x.split("/")[-1].split('_')[0] == str(n), glob(folder+'*.npz')))
        else:
            file_ending = "pt" if self.graph else "npz"
            self.files = list(glob(folder+f'*.{file_ending}'))

        if size is not None:
            assert len(self.files) >= size, f"Only {len(self.files)} files found in {folder} with n={n}"
            self.files = self.files[:size]

        if len(self.files) == 0:
            raise FileNotFoundError(f"No files found in {folder} with n={n}")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        if self.graph:
            g = torch.load(self.files[idx], weights_only=False)

        else:
            # deprecated...
            d = np.load(self.files[idx], allow_pickle=True)
            g = matrix_to_graph(d["A"], d["b"])

        return g

# **MAIN CODE TO RUN (TRAINING THE MODEL)**

In [25]:
# 1. retrieve the FolderDataset object for each of train and validation loader based on .pt files
# 2. pass the FolderDataset object into torch.util's DataLoader as the dataset field
# 3. return this DataLoader object

# the loader below passes train_dataset directly to torch.util's DataLoader class



train_loader = get_dataloader(config["dataset"], config["n"], config["batch_size"],
                                  spd=not gmres, mode="train")

print(train_loader.dataset[0])
print(train_loader.dataset[0].x)
print(train_loader.dataset[0].edge_index)

print(len(train_loader))

validation_loader = get_dataloader(config["dataset"], config["n"], 1, spd=(not gmres), mode="val")

TEST
Data(x=[10000, 1], edge_index=[2, 1005398], edge_attr=[1005398, 1], s=[10000], n=10000)
tensor([[0.8808],
        [0.6831],
        [0.3183],
        ...,
        [0.7282],
        [0.6627],
        [0.9138]])
tensor([[   0,    0,    0,  ..., 9999, 9999, 9999],
        [   0, 1255, 3754,  ..., 4025, 4467, 6623]])
33
TEST


In [None]:
from torch_geometric.utils import add_self_loops

# training loop
best_val = float("inf")
logger = TrainResults(folder)
total_it = 0

for epoch in range(config["num_epochs"]):
    running_loss = 0.0
    start_epoch = time.perf_counter()

    for data in train_loader:
        total_it += 1
        model.train()

        start = time.perf_counter()
        data = data.to(device)

        ### resolving the unmatching dimension bug for NeighborSampler class
        # 1) override n properly
        data.n = int(data.x.size(0))

        # 2) add self-loops so each node has at least one incoming edge
        data.edge_index, data.edge_attr = add_self_loops(
            data.edge_index,
            data.edge_attr,
            fill_value=0.0,
            num_nodes=data.n
        )
        ###

        # print(f"Input training data to the model is {data}")
        output, reg, _ = model(data)
        # print(f"Output from the model is {output} and reg term is {reg}")

        l = loss(output, data, c=reg, config=config["loss"])
        l.backward()

        # gradient clipping or manual norm
        if config["gradient_clipping"]:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), config["gradient_clipping"]
            )
        else:
            total_norm = sum(
                p.grad.detach().data.norm(2).item() ** 2
                for p in model.parameters() if p.grad is not None
            )
            grad_norm = (total_norm ** 0.5) / config["batch_size"]

        optimizer.step()
        optimizer.zero_grad()

        running_loss += l.item()
        logger.log(l.item(), grad_norm, time.perf_counter() - start)

        # periodic validation
        if total_it % 200 == 0:
            use_gmres = False
            val_metric = validate(
                model, validation_loader, solve=True,
                solver="gmres" if use_gmres else "cg"
            )
            logger.log_val(None, val_metric)
            print(val_metric)
            if val_metric < best_val:
                best_val = val_metric
                if config["save"]:
                    torch.save(model.state_dict(), f"{folder}/best_model.pt")

    epoch_time = time.perf_counter() - start_epoch
    print(f"Epoch {epoch+1} ‚Äî loss: {running_loss/len(train_loader):.4f}, time: {epoch_time:.1f}s")
    if config["save"]:
        torch.save(model.state_dict(), f"{folder}/model_epoch{epoch+1}.pt")



Epoch 1 ‚Äî loss: 394.4080, time: 48.8s
Epoch 2 ‚Äî loss: 393.4740, time: 47.9s
Epoch 3 ‚Äî loss: 392.4386, time: 48.4s
Epoch 4 ‚Äî loss: 391.9327, time: 48.1s
Epoch 5 ‚Äî loss: 391.6303, time: 48.3s
Epoch 6 ‚Äî loss: 391.1181, time: 48.9s
<__main__.LearnedPreconditioner object at 0x7994a4a8ce50>
Code reaches here
<__main__.LearnedPreconditioner object at 0x7994a4bf9910>
Code reaches here
<__main__.LearnedPreconditioner object at 0x7994a4b79450>
Code reaches here
<__main__.LearnedPreconditioner object at 0x7994a4dc9a10>
Code reaches here
<__main__.LearnedPreconditioner object at 0x7994a51d8550>
Code reaches here
<__main__.LearnedPreconditioner object at 0x7994a51d9190>
Code reaches here
<__main__.LearnedPreconditioner object at 0x7994a4d07010>
Code reaches here


# **VALIDATION**

In [18]:
class Preconditioner:
    def __init__(self, A, **kwargs):
        self.breakdown = False
        self.nnz = 0
        self.time = 0
        self.n = kwargs.get("n", 0)

    def timed_setup(self, A, **kwargs):
        start = time_function()
        self.setup(A, **kwargs)
        stop = time_function()
        self.time = stop - start

    def get_inverse(self):
        ones = torch.ones(self.n)
        offset = torch.zeros(1).to(torch.int64)

        I = torch.sparse.spdiags(ones, offset, (self.n, self.n))
        I = I.to(torch.float64)

        return I

    def get_p_matrix(self):
        return self.get_inverse()

    def check_breakdown(self, P):
        if np.isnan(np.min(P)):
            self.breakdown = True

    def __call__(self, x):
        return x

time_function = lambda: time.perf_counter()


In [19]:
class LearnedPreconditioner(Preconditioner):
    def __init__(self, data, model, **kwargs):
        super().__init__(data, **kwargs)

        self.model = model
        self.spd = isinstance(model, NeuralIF)

        self.timed_setup(data, **kwargs)

        if self.spd:
            self.nnz = self.L._nnz()
        else:
            self.nnz = self.L._nnz() + self.U._nnz() - data.x.shape[0]

    def setup(self, data, **kwargs):
        L, U, _ = self.model(data)

        self.L = L.to("cuda")
        self.U = U.to("cuda")

    def get_inverse(self):
        L_inv = torch.inverse(self.L.to_dense())
        U_inv = torch.inverse(self.U.to_dense())

        return U_inv@L_inv

    def get_p_matrix(self):
        return self.L@self.U

    def __call__(self, x):
        # return fb_solve(self.L, self.U, x, unit_upper=not self.spd)
        P = self.get_p_matrix()  # P = L @ U
        return torch.linalg.solve(P, x)


def fb_solve(L, U, r, unit_lower=False, unit_upper=False):
    y = L.solve_triangular(upper=False, unit=unit_lower, b=r)
    z = U.solve_triangular(upper=True, unit=unit_upper, b=y)
    return z
