# **ENVIRONMENT SETUP**

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

In [None]:
import sys
sys.path.append('/content/drive/MyDrive/CS4350/project/')

In [1]:
import torch
print(torch.__version__)

2.7.1


In [None]:
%pip install -q pyg-lib -f https://data.pyg.org/whl/torch-2.6.0+cu124.html
%pip install -q torch torchvision torchaudio
%pip install -q torch-scatter -f https://data.pyg.org/whl/torch-2.6.0+cu124.html  # prevents the wheel from taking forever to build
%pip install -q torch-sparse -f https://data.pyg.org/whl/torch-2.6.0+cu124.html   # prevents the wheel from taking forever to build
%pip install -q torch-cluster -f https://data.pyg.org/whl/torch-2.6.0+cu124.html  # prevents the wheel from taking forever to build
%pip install -q torch_geometric

In [None]:
# run this only if you don't already have a copy of the built numml wheel on google drive (~7 mins)
# build the wheel just once and then store in google drive /content/drive/MyDrive/CS4350/project/wheels
!git clone https://github.com/nicknytko/numml.git
%cd numml
!pip install --upgrade build
!python -m build --wheel
!mkdir -p /content/drive/MyDrive/CS4350/project/wheels
!mv dist/*.whl /content/drive/MyDrive/CS4350/project/wheels
%cd ..

In [None]:
# once the .whl file is in the drive, just run this cell to install numml within seconds
!pip install \
  /content/drive/MyDrive/CS4350/project/wheels/numml-0.1.0-cp311-cp311-linux_x86_64.whl

In [3]:
%cd ../Documents/Y3S2/CS4350/GML-Project


/Users/jylow/Documents/Y3S2/CS4350/GML-Project


In [4]:
import os
import datetime
import pprint
import time

import numpy as np
import scipy
import torch
import torch_geometric
import torch.nn as nn
import torch_geometric.nn as pyg
from torch_geometric.nn import aggr
from torch_geometric.utils import to_scipy_sparse_matrix
from torch_geometric.loader import NeighborLoader
from torch_geometric.data import Batch
from scipy.sparse import tril, coo_matrix

from apps.data import get_dataloader, graph_to_matrix, matrix_to_graph, FolderDataset
from neuralif.utils import (
    count_parameters, save_dict_to_file,
    condition_number, eigenval_distribution, gershgorin_norm,
    TwoHop
)
from neuralif.logger import TrainResults, TestResults
from neuralif.loss import loss

from krylov.cg import preconditioned_conjugate_gradient
from krylov.gmres import gmres

from numml.sparse import SparseCSRTensor


In [14]:
!nvcc --version
!gcc --version
!python -c "import torch; print(torch.__version__, torch.version.cuda)"


zsh:1: command not found: nvcc
Apple clang version 17.0.0 (clang-1700.0.13.3)
Target: arm64-apple-darwin24.3.0
Thread model: posix
InstalledDir: /Library/Developer/CommandLineTools/usr/bin
2.7.1 None


In [5]:
device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


# **DATASET CREATION**

In [52]:
from apps.synthetic import create_dataset
n = 10_000
alpha=10e-4

create_dataset(n, 1000, alpha=alpha, mode='train', rs=0, graph=True, solution=True)
create_dataset(n, 10, alpha=alpha, mode='val', rs=10000, graph=True)
create_dataset(n, 100, alpha=alpha, mode='test', rs=103600, graph=True)

Generating 1000 samples for the train dataset.


100%|██████████| 1000/1000 [50:44<00:00,  3.04s/it]


Generating 10 samples for the val dataset.


100%|██████████| 10/10 [00:18<00:00,  1.86s/it]


Generating 100 samples for the test dataset.


100%|██████████| 100/100 [03:03<00:00,  1.83s/it]


In [6]:
config = {
    "name": "experiment_1",
    "save": True,
    "seed": 42,
    "n": 10000,
    "batch_size": 1,
    "num_epochs": 10000,
    "dataset": "random",
    "loss": None,
    "gradient_clipping": 1.0,
    "regularizer": 0.0,
    "scheduler": False,
    "model": "neuralif",
    "normalize": False,
    "latent_size": 8,
    "message_passing_steps": 3,
    "decode_nodes": False,
    "normalize_diag": False,
    "aggregate": ["mean", "sum"],
    "activation": "relu",
    "skip_connections": True,
    "augment_nodes": False,
    "global_features": 0,
    "edge_features": 1,
    "graph_norm": False,
    "two_hop": False,
    "num_neighbors": [15, 10],  # number of neighbours to sample in each hop (GraphSAGE sampling)
    "num_clusters": 10,  # Number of clusters to partition each graph into
    "clusters_per_batch": 3,  # Number of clusters to sample per batch
    "cluster_method": "metis"  # Options: 'metis', 'random', 'kmeans'
}

# Prepare output folder
if config["name"]:
    folder = f"results/{config['name']}"
else:
    folder = datetime.datetime.now().strftime("results/%Y-%m-%d_%H-%M-%S")
if config["save"]:
    os.makedirs(folder, exist_ok=True)
    save_dict_to_file(config, os.path.join(folder, "config.json"))


In [None]:
!sudo apt-get install libmetis-dev
!pip install metis

In [8]:
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.utils import subgraph
import torch_geometric
import numpy as np
from sklearn.cluster import KMeans
import metis  # pip install metis
import time
import os

class ClusterGCNSampler:
    def __init__(self, num_clusters, cluster_method='metis'):
        self.num_clusters = num_clusters
        self.cluster_method = cluster_method
        self.node_clusters = {}  # Cache for node cluster assignments

    def partition_graph(self, data):
        edge_index = data.edge_index
        num_nodes = data.num_nodes

        # Create adjacency list for clustering
        if self.cluster_method == 'metis':
            # Convert to format expected by METIS
            adjacency_list = [[] for _ in range(num_nodes)]
            for i in range(edge_index.size(1)):
                src, dst = edge_index[0, i].item(), edge_index[1, i].item()
                if src != dst:  # Avoid self-loops for clustering
                    adjacency_list[src].append(dst)
                    adjacency_list[dst].append(src)

            try:
                # Use METIS for graph partitioning
                _, node_clusters = metis.part_graph(adjacency_list, self.num_clusters)
                node_clusters = torch.tensor(node_clusters, dtype=torch.long)
            except:
                # Fallback to random partitioning if METIS fails
                print("metis failed")
                node_clusters = torch.randint(0, self.num_clusters, (num_nodes,))

        return node_clusters

    def sample_clusters(self, data, clusters_per_batch):
        """
        Sample a subset of clusters and return the induced subgraph
        """
        # Get or compute node cluster assignments
        graph_id = id(data)  # Use object id as cache key
        if graph_id not in self.node_clusters:
            self.node_clusters[graph_id] = self.partition_graph(data)

        node_clusters = self.node_clusters[graph_id]

        # Randomly sample clusters for this batch
        available_clusters = torch.unique(node_clusters)
        if len(available_clusters) <= clusters_per_batch:
            selected_clusters = available_clusters
        else:
            selected_clusters = available_clusters[torch.randperm(len(available_clusters))[:clusters_per_batch]]

        # Get nodes belonging to selected clusters
        mask = torch.zeros(data.num_nodes, dtype=torch.bool)
        for cluster_id in selected_clusters:
            mask |= (node_clusters == cluster_id)

        selected_nodes = torch.where(mask)[0]

        # Extract subgraph
        edge_index, edge_attr = subgraph(
            selected_nodes,
            data.edge_index,
            edge_attr=data.edge_attr if hasattr(data, 'edge_attr') else None,
            relabel_nodes=True,
            num_nodes=data.num_nodes
        )

        # Create new data object for the subgraph
        subgraph_data = Data(
            edge_index=edge_index,
            num_nodes=len(selected_nodes)
        )

        # Copy relevant attributes
        if hasattr(data, 'x') and data.x is not None:
            subgraph_data.x = data.x[selected_nodes]

        if hasattr(data, 'edge_attr') and data.edge_attr is not None:
            subgraph_data.edge_attr = edge_attr

        # For your preconditioner application, you'll need to handle the matrix
        if hasattr(data, 'matrix'):
            # Extract the submatrix corresponding to selected nodes
            subgraph_data.matrix = data.matrix[selected_nodes][:, selected_nodes]

        # Store mapping for reconstruction if needed
        subgraph_data.original_nodes = selected_nodes
        subgraph_data.node_mapping = {new_idx: old_idx.item()
                                    for new_idx, old_idx in enumerate(selected_nodes)}

        return subgraph_data

class ClusterGCNDataLoader:
    """
    Custom DataLoader that uses ClusterGCN sampling
    """
    def __init__(self, dataset, batch_size, num_clusters, clusters_per_batch,
                 cluster_method='metis', shuffle=True):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.sampler = ClusterGCNSampler(num_clusters, cluster_method)
        self.clusters_per_batch = clusters_per_batch

    def __iter__(self):
        indices = list(range(len(self.dataset)))
        if self.shuffle:
            torch.randperm(len(indices))

        for i in range(0, len(indices), self.batch_size):
            batch_indices = indices[i:i + self.batch_size]
            batch_data = []

            for idx in batch_indices:
                data = self.dataset[idx]
                # Sample clusters for this graph
                subgraph_data = self.sampler.sample_clusters(data, self.clusters_per_batch)
                batch_data.append(subgraph_data)

            if len(batch_data) == 1:
                yield batch_data[0]
            else:
                yield Batch.from_data_list(batch_data)

    def __len__(self):
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size

# Modified training function for your code
def get_cluster_dataloader(dataset, n, size, spd=True, mode="train",
                          num_clusters=10, clusters_per_batch=2):
    """
    Modified version of your get_dataloader function that uses ClusterGCN
    """
    # Get your original dataset (replace with your actual dataset loading logic)
    # dataset = get_dataset(dataset_name, n, spd, mode)  # Your existing function
    data = FolderDataset(f"./dataset/{mode}/", n, size=size, graph=True)

    # if dataset == "random":
    #     data = FolderDataset(f"./dataset/{mode}/", n, size=size, graph=graph)

    # else:
    #     raise NotImplementedError("Dataset not implemented, Available: random")

    # # Data Loaders
    # if mode == "train":
    #     dataloader = DataLoader(data, batch_size=batch_size, shuffle=True)
    # else:
    #     dataloader = DataLoader(data, batch_size=1, shuffle=False)

    # Wrap with ClusterGCN loader
    return ClusterGCNDataLoader(
        data,
        batch_size=size,
        num_clusters=num_clusters,
        clusters_per_batch=clusters_per_batch,
        cluster_method='metis',
        shuffle=(mode == "train")
    )


In [9]:
class GraphNet(nn.Module):
    # Follows roughly the outline of torch_geometric.nn.MessagePassing()
    # As shown in https://github.com/deepmind/graph_nets
    # Here is a helpful python implementation:
    # https://github.com/NVIDIA/GraphQSat/blob/main/gqsat/models.py
    # Also allows multirgaph GNN via edge_2_features
    def __init__(self, node_features, edge_features, global_features=0, hidden_size=0,
                 aggregate="mean", activation="relu", skip_connection=False, edge_features_out=None):

        super().__init__()

        # different aggregation functions
        if aggregate == "sum":
            self.aggregate = aggr.SumAggregation()
        elif aggregate == "mean":
            self.aggregate = aggr.MeanAggregation()
        elif aggregate == "max":
            self.aggregate = aggr.MaxAggregation()
        elif aggregate == "softmax":
            self.aggregate = aggr.SoftmaxAggregation(learn=True)
        else:
            raise NotImplementedError(f"Aggregation '{aggregate}' not implemented")

        self.global_aggregate = aggr.MeanAggregation()

        add_edge_fs = 1 if skip_connection else 0
        edge_features_out = edge_features if edge_features_out is None else edge_features_out

        # Graph Net Blocks (see https://arxiv.org/pdf/1806.01261.pdf)
        self.edge_block = MLP([global_features + (edge_features + add_edge_fs) + (2 * node_features),
                               hidden_size,
                               edge_features_out],
                              activation=activation)

        self.node_block = MLP([global_features + edge_features_out + node_features,
                               hidden_size,
                               node_features],
                              activation=activation)

        # optional set of blocks for global GNN
        self.global_block = None
        if global_features > 0:
            self.global_block = MLP([edge_features_out + node_features + global_features,
                                     hidden_size,
                                     global_features],
                                    activation=activation)

    def forward(self, x, edge_index, edge_attr, g=None):
        row, col = edge_index

        if self.global_block is not None:
            assert g is not None, "Need global features for global block"

            # run the edge update and aggregate features
            edge_embedding = self.edge_block(torch.cat([torch.ones(x[row].shape[0], 1, device=x.device) * g,
                                                        x[row], x[col], edge_attr], dim=1))
            aggregation = self.aggregate(edge_embedding, row)


            agg_features = torch.cat([torch.ones(x.shape[0], 1, device=x.device) * g, x, aggregation], dim=1)
            node_embeddings = self.node_block(agg_features)

            # aggregate over all edges and nodes (always mean)
            mp_global_aggr = g
            edge_aggregation_global = self.global_aggregate(edge_embedding)
            node_aggregation_global = self.global_aggregate(node_embeddings)

            # compute the new global embedding
            # the old global feature is part of mp_global_aggr
            global_embeddings = self.global_block(torch.cat([node_aggregation_global,
                                                             edge_aggregation_global,
                                                             mp_global_aggr], dim=1))

            return edge_embedding, node_embeddings, global_embeddings

        else:
            # update edge features and aggregate
            edge_embedding = self.edge_block(torch.cat([x[row], x[col], edge_attr], dim=1))
            aggregation = self.aggregate(edge_embedding, row)
            agg_features = torch.cat([x, aggregation], dim=1)
            # update node features
            node_embeddings = self.node_block(agg_features)
            return edge_embedding, node_embeddings, None


In [10]:
class NeuralIF(nn.Module):
    # Neural Incomplete factorization
    def __init__(self, drop_tol=0, **kwargs) -> None:
        super().__init__()

        self.global_features = kwargs["global_features"]
        self.latent_size = kwargs["latent_size"]
        # node features are augmented with local degree profile
        self.augment_node_features = kwargs["augment_nodes"]

        num_node_features = 8 if self.augment_node_features else 1
        message_passing_steps = kwargs["message_passing_steps"]

        # edge feature representation in the latent layers
        edge_features = kwargs.get("edge_features", 1)

        self.skip_connections = kwargs["skip_connections"]

        self.mps = torch.nn.ModuleList()
        for l in range(message_passing_steps):
            # skip connections are added to all layers except the first one
            self.mps.append(MP_Block(skip_connections=self.skip_connections,
                                     first=l==0,
                                     last=l==(message_passing_steps-1),
                                     edge_features=edge_features,
                                     node_features=num_node_features,
                                     global_features=self.global_features,
                                     hidden_size=self.latent_size,
                                     activation=kwargs["activation"],
                                     aggregate=kwargs["aggregate"]))

        # node decodings
        self.node_decoder = MLP([num_node_features, self.latent_size, 1]) if kwargs["decode_nodes"] else None

        # diag-aggregation for normalization of rows
        self.normalize_diag = kwargs["normalize_diag"] if "normalize_diag" in kwargs else False
        self.diag_aggregate = aggr.SumAggregation()

        # normalization
        self.graph_norm = pyg.norm.GraphNorm(num_node_features) if ("graph_norm" in kwargs and kwargs["graph_norm"]) else None

        # drop tolerance and additional fill-ins and more sparsity
        self.tau = drop_tol
        self.two = kwargs.get("two_hop", False)

    def forward(self, data):
        # ! data could be batched here...(not implemented)

        if self.augment_node_features:
            data = augment_features(data, skip_rhs=True)

        # add additional edges to the data
        if self.two:
            data = TwoHop()(data)

        # * in principle it is possible to integrate reordering here.

        data = ToLowerTriangular()(data)

        # get the input data
        edge_embedding = data.edge_attr
        l_index = data.edge_index

        if self.graph_norm is not None:
            node_embedding = self.graph_norm(data.x, batch=data.batch)
        else:
            node_embedding = data.x

        # copy the input data (only edges of original matrix A)
        a_edges = edge_embedding.clone()

        if self.global_features > 0:
            global_features = torch.zeros((1, self.global_features), device=data.x.device, requires_grad=False)
            # feature ideas: nnz, 1-norm, inf-norm col/row var, min/max variability, avg distances to nnz
        else:
            global_features = None

        # compute the output of the network
        for i, layer in enumerate(self.mps):
            if i != 0 and self.skip_connections:
                edge_embedding = torch.cat([edge_embedding, a_edges], dim=1)

            edge_embedding, node_embedding, global_features = layer(node_embedding, l_index, edge_embedding, global_features)

        # transform the output into a matrix
        return self.transform_output_matrix(node_embedding, l_index, edge_embedding, a_edges)

    def transform_output_matrix(self, node_x, edge_index, edge_values, a_edges):
        # force diagonal to be positive
        diag = edge_index[0] == edge_index[1]

        # normalize diag such that it has zero residual
        if self.normalize_diag:
            # copy the diag of matrix A
            a_diag = a_edges[diag]

            # compute the row norm
            square_values = torch.pow(edge_values, 2)
            aggregated = self.diag_aggregate(square_values, edge_index[0])

            # now, we renormalize the edge values such that they are the square root of the original value...
            edge_values = torch.sqrt(a_diag[edge_index[0]]) * edge_values / torch.sqrt(aggregated[edge_index[0]])

        else:
            # otherwise, just take the edge values as they are...
            # but take the square root as it is numerically better
            # edge_values[diag] = torch.exp(edge_values[diag])
            edge_values[diag] = torch.sqrt(torch.exp(edge_values[diag]))

        # node decoder
        node_output = self.node_decoder(node_x).squeeze() if self.node_decoder is not None else None

        # ! this if should only be activated when the model is in production!!
        if torch.is_inference_mode_enabled():

            # we can decide to remove small elements during inference from the preconditioner matrix
            if self.tau != 0:
                small_value = (torch.abs(edge_values) <= self.tau).squeeze()

                # small value and not diagonal
                elems = torch.logical_and(small_value, torch.logical_not(diag))

                # might be able to do this easily!
                edge_values[elems] = 0

                # remove zeros from the sparse representation
                filt = (edge_values != 0).squeeze()
                edge_values = edge_values[filt]
                edge_index = edge_index[:, filt]

            # ! this is the way to go!!
            # Doing pytorch -> scipy -> numml is a lot faster than pytorch -> numml on CPU
            # On GPU it is faster to go to pytorch -> numml -> CPU

            # convert to scipy sparse matrix
            # m = to_scipy_sparse_matrix(edge_index, matrix_values)
            m = torch.sparse_coo_tensor(edge_index, edge_values.squeeze(),
                                        size=(node_x.size()[0], node_x.size()[0]))
                                        # type=torch.double)

            # produce L and U seperatly
            l = SparseCSRTensor(m)
            u = SparseCSRTensor(m.T)

            return l, u, node_output

        else:
            # For training and testing (computing regular losses for examples.)
            # does not need to be performance optimized!
            # use torch sparse directly
            t = torch.sparse_coo_tensor(edge_index, edge_values.squeeze(),
                                        size=(node_x.size()[0], node_x.size()[0]))

            # normalized l1 norm is best computed here!
            # l2_nn = torch.linalg.norm(edge_values, ord=2)
            l1_penalty = torch.sum(torch.abs(edge_values)) / len(edge_values)

            return t, l1_penalty, node_output


In [11]:
class MLP(nn.Module):
    def __init__(self, width, layer_norm=False, activation="relu", activate_final=False):
        super().__init__()
        width = list(filter(lambda x: x > 0, width))
        assert len(width) >= 2, "Need at least one layer in the network!"

        lls = nn.ModuleList()
        for k in range(len(width)-1):
            lls.append(nn.Linear(width[k], width[k+1], bias=True))
            if k != (len(width)-2) or activate_final:
                if activation == "relu":
                    lls.append(nn.ReLU())
                elif activation == "tanh":
                    lls.append(nn.Tanh())
                elif activation == "leakyrelu":
                    lls.append(nn.LeakyReLU())
                elif activation == "sigmoid":
                    lls.append(nn.Sigmoid())
                else:
                    raise NotImplementedError(f"Activation '{activation}' not implemented")

        if layer_norm:
            lls.append(nn.LayerNorm(width[-1]))

        self.m = nn.Sequential(*lls)

    def forward(self, x):
        return self.m(x)

class ToLowerTriangular(torch_geometric.transforms.BaseTransform):
    def __init__(self, inplace=False):
        self.inplace = inplace

    def __call__(self, data, order=None):
        if not self.inplace:
            data = data.clone()

        # TODO: if order is given use that one instead
        if order is not None:
            raise NotImplementedError("Custom ordering not yet implemented...")

        # transform the data into lower triag graph
        # this should be a data transformation (maybe?)
        rows, cols = data.edge_index[0], data.edge_index[1]
        fil = cols <= rows
        l_index = data.edge_index[:, fil]
        edge_embedding = data.edge_attr[fil]

        data.edge_index, data.edge_attr = l_index, edge_embedding
        return data

In [12]:
class MP_Block(nn.Module):
    # L@L.T matrix multiplication graph layer
    # Aligns the computation of L@L.T - A with the learned updates
    def __init__(self, skip_connections, first, last, edge_features, node_features, global_features, hidden_size, **kwargs) -> None:
        super().__init__()

        # first and second aggregation
        if "aggregate" in kwargs and kwargs["aggregate"] is not None:
            aggr = kwargs["aggregate"] if len(kwargs["aggregate"]) == 2 else kwargs["aggregate"] * 2
        else:
            aggr = ["mean", "sum"]

        act = kwargs["activation"] if "activation" in kwargs else "relu"

        edge_features_in = 1 if first else edge_features
        edge_features_out = 1 if last else edge_features

        # We use 2 graph nets in order to operate on the upper and lower triangular parts of the matrix
        self.l1 = GraphNet(node_features=node_features, edge_features=edge_features_in, global_features=global_features,
                           hidden_size=hidden_size, skip_connection=(not first and skip_connections),
                           aggregate=aggr[0], activation=act, edge_features_out=edge_features)

        self.l2 = GraphNet(node_features=node_features, edge_features=edge_features, global_features=global_features,
                           hidden_size=hidden_size, aggregate=aggr[1], activation=act, edge_features_out=edge_features_out)

    def forward(self, x, edge_index, edge_attr, global_features):
        edge_embedding, node_embeddings, global_features = self.l1(x, edge_index, edge_attr, g=global_features)

        # flip row and column indices
        edge_index = torch.stack([edge_index[1], edge_index[0]], dim=0)
        edge_embedding, node_embeddings, global_features = self.l2(node_embeddings, edge_index, edge_embedding, g=global_features)

        return edge_embedding, node_embeddings, global_features

In [13]:
def get_dataloader(dataset, n=0, batch_size=1, spd=True, mode="train", size=None, graph=True):
    # Setup datasets

    if dataset == "random":
        data = FolderDataset(f"./dataset/{mode}/", n, size=size, graph=graph)

    else:
        raise NotImplementedError("Dataset not implemented, Available: random")

    # Data Loaders
    if mode == "train":
        dataloader = DataLoader(data, batch_size=batch_size, shuffle=True)
    else:
        dataloader = DataLoader(data, batch_size=1, shuffle=False)

    return dataloader

In [14]:
class Preconditioner:
    def __init__(self, A, **kwargs):
        self.breakdown = False
        self.nnz = 0
        self.time = 0
        self.n = kwargs.get("n", 0)

    def timed_setup(self, A, **kwargs):
        start = time_function()
        self.setup(A, **kwargs)
        stop = time_function()
        self.time = stop - start

    def get_inverse(self):
        ones = torch.ones(self.n)
        offset = torch.zeros(1).to(torch.int64)

        I = torch.sparse.spdiags(ones, offset, (self.n, self.n))
        I = I.to(torch.float64)

        return I

    def get_p_matrix(self):
        return self.get_inverse()

    def check_breakdown(self, P):
        if np.isnan(np.min(P)):
            self.breakdown = True

    def __call__(self, x):
        return x

class LearnedPreconditioner(Preconditioner):
    def __init__(self, data, model, **kwargs):
        super().__init__(data, **kwargs)

        self.model = model
        self.spd = isinstance(model, NeuralIF)

        self.timed_setup(data, **kwargs)

        if self.spd:
            self.nnz = self.L.nnz
        else:
            self.nnz = self.L.nnz + self.U.nnz - data.x.shape[0]

    def setup(self, data, **kwargs):
        L, U, _ = self.model(data)

        self.L = L.to("cpu").to(torch.float64)
        self.U = U.to("cpu").to(torch.float64)

    def get_inverse(self):
        L_inv = torch.inverse(self.L.to_dense())
        U_inv = torch.inverse(self.U.to_dense())

        return U_inv@L_inv

    def get_p_matrix(self):
        return self.L@self.U

    def __call__(self, x):
        return fb_solve(self.L, self.U, x, unit_upper=not self.spd)


def fb_solve(L, U, r, unit_lower=False, unit_upper=False):
    y = L.solve_triangular(upper=False, unit=unit_lower, b=r)
    z = U.solve_triangular(upper=True, unit=unit_upper, b=y)
    return z


def fb_solve_joint(LU, r):
    # Note: solve triangular ignores the values in lower/upper triangle
    y = LU.solve_triangular(upper=False, unit=False, b=r)
    z = LU.solve_triangular(upper=True, unit=False, b=y)
    return z

time_function = lambda: time.perf_counter()

# **VALIDATION FUNCTION**

In [15]:
@torch.no_grad()
def validate(model, validation_loader, solve=False, solver="cg", **kwargs):
    model.eval()

    acc_loss = 0.0
    num_loss = 0
    acc_solver_iters = 0.0

    for i, data in enumerate(validation_loader):
        data = data.to(device)

        # construct problem data
        A, b = graph_to_matrix(data)

        # run conjugate gradient method
        # this requires the learned preconditioner to be reasonably good!
        if solve:
            # run CG on CPU
            with torch.inference_mode():
                preconditioner = LearnedPreconditioner(data, model)

            A = A.to("cpu").to(torch.float64)
            b = b.to("cpu").to(torch.float64)
            x_init = None

            solver_start = time.time()

            if solver == "cg":
                l, x_hat = preconditioned_conjugate_gradient(A.to("cpu"), b.to("cpu"), M=preconditioner,
                                                             x0=x_init, rtol=1e-6, max_iter=1_000)
            elif solver == "gmres":
                l, x_hat = gmres(A, b, M=preconditioner, x0=x_init, atol=1e-6, max_iter=1_000, left=False)
            else:
                raise NotImplementedError("Solver not implemented choose between CG and GMRES!")

            solver_stop = time.time()

            # Measure preconditioning performance
            solver_time = (solver_stop - solver_start)
            acc_solver_iters += len(l) - 1

        else:
            output, _, _ = model(data)

            # Here, we compute the loss using the full forbenius norm (no estimator)
            # l = frobenius_loss(output, A)

            l = loss(data, output, config=None) ##USe the default loss instead of the frobenius loss

            acc_loss += l.item()
            num_loss += 1

    if solve:
        # print(f"Smallest eigenvalue: {dist[0]}")
        print(f"Validation\t iterations:\t{acc_solver_iters / len(validation_loader):.2f}")
        return acc_solver_iters / len(validation_loader)

    else:
        print(f"Validation loss:\t{acc_loss / num_loss:.2f}")
        return acc_loss / len(validation_loader)


# MAIN CODE

In [61]:
from torch_geometric.loader import DataLoader
from tqdm import tqdm
if config["save"]:
        os.makedirs(folder, exist_ok=True)
        save_dict_to_file(config, os.path.join(folder, "config.json"))

# global seed-ish
torch_geometric.seed_everything(config["seed"])

# args for the model
model_args = {k: config[k] for k in ["latent_size", "message_passing_steps", "skip_connections",
                                      "augment_nodes", "global_features", "decode_nodes",
                                      "normalize_diag", "activation", "aggregate", "graph_norm",
                                      "two_hop", "edge_features", "normalize"]
              if k in config}

# run the GMRES algorithm instead of CG (?)
gmres = False

# Create model
if config["model"] == "neuralpcg":
    model = NeuralPCG(**model_args)

elif config["model"] == "nif" or config["model"] == "neuralif" or config["model"] == "inf":
    model = NeuralIF(**model_args)

elif config["model"] == "precondnet":
    model = PreCondNet(**model_args)

elif config["model"] == "lu" or config["model"] == "learnedlu":
    gmres = True
    model = LearnedLU(**model_args)

else:
    raise NotImplementedError

model.to(device)

print(f"Number params in model: {count_parameters(model)}")
print()

optimizer = torch.optim.AdamW(model.parameters())
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=20)

# Setup datasets
train_loader = get_dataloader(config["dataset"], config["n"], config["batch_size"],
                              spd=not gmres, mode="train")

validation_loader = get_dataloader(config["dataset"], config["n"], 1, spd=(not gmres), mode="val")

best_val = float("inf")
logger = TrainResults(folder)

# todo: compile the model
# compiled_model = torch.compile(model, mode="reduce-overhead")
# model = torch_geometric.compile(model, mode="reduce-overhead")

total_it = 0

# Train loop
for epoch in range(config["num_epochs"]):
    running_loss = 0.0
    grad_norm = 0.0

    start_epoch = time.perf_counter()

    for it, data in tqdm(enumerate(train_loader)):
        # increase iteration count
        total_it += 1

        # enable training mode
        model.train()

        # print(data)
        # print(data.x)
        start = time.perf_counter()
        data = data.to(device)

        output, reg, _ = model(data)
        l = loss(output, data, c=reg, config=config["loss"])

        #  if reg:
        #    l = l + config["regularizer"] * reg

        l.backward()
        running_loss += l.item()

        # track the gradient norm
        if "gradient_clipping" in config and config["gradient_clipping"]:
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config["gradient_clipping"])

        else:
            total_norm = 0.0

            for p in model.parameters():
                if p.grad is not None:
                    param_norm = p.grad.detach().data.norm(2)
                    total_norm += param_norm.item() ** 2

            grad_norm = total_norm ** 0.5 / config["batch_size"]

        # update network parameters
        optimizer.step()
        optimizer.zero_grad()

        logger.log(l.item(), grad_norm, time.perf_counter() - start)

        # Do validation after 100 updates (to support big datasets)
        # convergence is expected to be pretty fast...
        if (total_it + 1) % 1000 == 0:

            # start with cg-checks after 5 iterations
            val_its = validate(model, validation_loader, solve=True,
                                solver="gmres" if gmres else "cg")

            # use scheduler
            # if config["scheduler"]:
            #    scheduler.step(val_loss)

            logger.log_val(None, val_its)

            # val_perf = val_cgits if val_cgits > 0 else val_loss
            val_perf = val_its
            print(val_perf)

            if val_perf < best_val:
                if config["save"]:
                    torch.save(model.state_dict(), f"{folder}/best_model.pt")
                best_val = val_perf

    epoch_time = time.perf_counter() - start_epoch

    # save model every epoch for analysis...
    if config["save"]:
        torch.save(model.state_dict(), f"{folder}/model_epoch{epoch+1}.pt")

    print(f"Epoch {epoch+1} \t loss: {1/len(train_loader) * running_loss} \t time: {epoch_time}")

# save fully trained model
if config["save"]:
    logger.save_results()
    torch.save(model.to(torch.float).state_dict(), f"{folder}/final_model.pt")

# Test the model
# wandb.run.summary["validation_chol"] = best_val
print()
print("Best validation loss:", best_val)



Number params in model: 460



999it [08:05, 21.67s/it]

Validation	 iterations:	287.80
287.8


1000it [08:05,  2.06it/s]


Epoch 1 	 loss: 517.5744089355469 	 time: 485.9105565000209


999it [08:14, 20.78s/it]

Validation	 iterations:	270.70
270.7


1000it [08:15,  2.02it/s]


Epoch 2 	 loss: 367.9119475708008 	 time: 495.18716924998444


999it [08:17, 20.65s/it]

Validation	 iterations:	267.60
267.6


1000it [08:17,  2.01it/s]


Epoch 3 	 loss: 357.13817221069337 	 time: 497.9681067909987


999it [09:17, 21.45s/it]

Validation	 iterations:	267.80
267.8


1000it [09:18,  1.79it/s]


Epoch 4 	 loss: 353.96421224975586 	 time: 556.3167774169997


999it [08:02, 20.12s/it]

Validation	 iterations:	268.50
268.5


1000it [08:02,  2.07it/s]


Epoch 5 	 loss: 350.8254994812012 	 time: 482.6419573329913


999it [08:00, 20.18s/it]

Validation	 iterations:	269.20
269.2


1000it [08:00,  2.08it/s]


Epoch 6 	 loss: 348.95432199096683 	 time: 480.44848229098716


999it [07:58, 20.27s/it]

Validation	 iterations:	269.90
269.9


1000it [07:58,  2.09it/s]


Epoch 7 	 loss: 347.16505059814455 	 time: 478.91117370899883


999it [08:00, 20.19s/it]

Validation	 iterations:	269.50
269.5


1000it [08:00,  2.08it/s]


Epoch 8 	 loss: 345.94179406738283 	 time: 480.5256363330118


999it [08:00, 20.29s/it]

Validation	 iterations:	270.60
270.6


1000it [08:00,  2.08it/s]


Epoch 9 	 loss: 344.573157409668 	 time: 480.91597533301683


999it [07:57, 20.29s/it]

Validation	 iterations:	270.80
270.8


1000it [07:57,  2.09it/s]


Epoch 10 	 loss: 343.7699348144531 	 time: 477.79206800000975


999it [07:56, 20.31s/it]

Validation	 iterations:	271.40
271.4


1000it [07:56,  2.10it/s]


Epoch 11 	 loss: 342.9541725158692 	 time: 476.48587433301145


999it [07:56, 20.24s/it]

Validation	 iterations:	270.60
270.6


1000it [07:57,  2.10it/s]


Epoch 12 	 loss: 342.3806929016113 	 time: 477.0009245830006


999it [07:59, 20.39s/it]

Validation	 iterations:	271.70
271.7


1000it [07:59,  2.08it/s]


Epoch 13 	 loss: 341.4950877990723 	 time: 479.82606774999294


698it [05:06,  2.28it/s]


KeyboardInterrupt: 

In [None]:
from torch_geometric.loader import DataLoader
from tqdm import tqdm
if config["save"]:
        os.makedirs(folder, exist_ok=True)
        save_dict_to_file(config, os.path.join(folder, "config.json"))

# global seed-ish
torch_geometric.seed_everything(config["seed"])

# args for the model
model_args = {k: config[k] for k in ["latent_size", "message_passing_steps", "skip_connections",
                                      "augment_nodes", "global_features", "decode_nodes",
                                      "normalize_diag", "activation", "aggregate", "graph_norm",
                                      "two_hop", "edge_features", "normalize"]
              if k in config}

# run the GMRES algorithm instead of CG (?)
gmres = False

# Create model
if config["model"] == "neuralpcg":
    model = NeuralPCG(**model_args)

elif config["model"] == "nif" or config["model"] == "neuralif" or config["model"] == "inf":
    model = NeuralIF(**model_args)

elif config["model"] == "precondnet":
    model = PreCondNet(**model_args)

elif config["model"] == "lu" or config["model"] == "learnedlu":
    gmres = True
    model = LearnedLU(**model_args)

else:
    raise NotImplementedError

model.to(device)

print(f"Number params in model: {count_parameters(model)}")

optimizer = torch.optim.AdamW(model.parameters())
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=20)

# Setup datasets with ClusterGCN
train_loader = get_cluster_dataloader(
    config["dataset"],
    config["n"],
    config["batch_size"],
    spd=True,
    mode="train",
    num_clusters=config["num_clusters"],
    clusters_per_batch=config["clusters_per_batch"]
)

validation_loader = get_cluster_dataloader(
    config["dataset"],
    config["n"],
    1,
    spd=True,
    mode="val",
    num_clusters=config["num_clusters"],
    clusters_per_batch=config["clusters_per_batch"]
)

best_val = float("inf")
logger = TrainResults(folder)
total_it = 0

# Train loop (similar to your existing code)
for epoch in range(config["num_epochs"]):
    running_loss = 0.0
    grad_norm = 0.0
    start_epoch = time.perf_counter()

    for it, data in tqdm(enumerate(train_loader)):
        total_it += 1
        model.train()

        start = time.perf_counter()
        data = data.to(device)

        # Your existing training logic
        output, reg, _ = model(data)
        l = loss(output, data, c=reg, config=config["loss"])

        l.backward()
        running_loss += l.item()

        # Gradient handling (your existing code)
        if "gradient_clipping" in config and config["gradient_clipping"]:
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config["gradient_clipping"])
        else:
            total_norm = 0.0
            for p in model.parameters():
                if p.grad is not None:
                    param_norm = p.grad.detach().data.norm(2)
                    total_norm += param_norm.item() ** 2
            grad_norm = total_norm ** 0.5 / config["batch_size"]

        optimizer.step()
        optimizer.zero_grad()

        logger.log(l.item(), grad_norm, time.perf_counter() - start)

        # Validation (your existing logic)
        if (total_it) % 10 == 0:
            val_its = validate(model, validation_loader, solve=True, solver="cg")
            logger.log_val(None, val_its)

            if val_its < best_val:
                if config["save"]:
                    torch.save(model.state_dict(), f"{folder}/best_model.pt")
                best_val = val_its

    epoch_time = time.perf_counter() - start_epoch

    if config["save"]:
        torch.save(model.state_dict(), f"{folder}/model_epoch{epoch+1}.pt")

    print(f"Epoch {epoch+1} \t loss: {1/len(train_loader) * running_loss} \t time: {epoch_time}")

# Final save
if config["save"]:
    logger.save_results()
    torch.save(model.to(torch.float).state_dict(), f"{folder}/final_model.pt")

print("Best validation performance:", best_val)

Number params in model: 460


1it [00:06,  6.49s/it]


Epoch 1 	 loss: 648.0010375976562 	 time: 6.498716834001243


1it [00:06,  6.17s/it]


Epoch 2 	 loss: 644.8334350585938 	 time: 6.175045791984303


1it [00:06,  6.26s/it]


Epoch 3 	 loss: 644.45849609375 	 time: 6.265549583011307


1it [00:06,  6.28s/it]


Epoch 4 	 loss: 648.1729736328125 	 time: 6.28558462500223


1it [00:06,  6.28s/it]


Epoch 5 	 loss: 661.9956665039062 	 time: 6.277478999982122


1it [00:06,  6.35s/it]


Epoch 6 	 loss: 650.3807983398438 	 time: 6.3484238749952056


1it [00:00, 14.90it/s]


Epoch 7 	 loss: 629.5044555664062 	 time: 0.06798545800847933


1it [00:06,  6.31s/it]


Epoch 8 	 loss: 672.1594848632812 	 time: 6.30877825000789


1it [00:06,  6.30s/it]


Epoch 9 	 loss: 633.2813720703125 	 time: 6.3031272499938495


1it [00:15, 15.31s/it]


Validation	 iterations:	1000.00
Epoch 10 	 loss: 658.324951171875 	 time: 15.311978582991287


1it [00:06,  6.24s/it]


Epoch 11 	 loss: 630.0914916992188 	 time: 6.245356749976054


1it [00:00, 14.97it/s]


Epoch 12 	 loss: 627.0044555664062 	 time: 0.06777891700039618


1it [00:06,  6.31s/it]


Epoch 13 	 loss: 649.7213134765625 	 time: 6.307013624988031


1it [00:00, 14.60it/s]


Epoch 14 	 loss: 648.6812744140625 	 time: 0.06939237500773743


1it [00:06,  6.24s/it]


Epoch 15 	 loss: 682.5653076171875 	 time: 6.235936249984661


1it [00:06,  6.16s/it]


Epoch 16 	 loss: 636.4898681640625 	 time: 6.157793249993119


1it [00:08,  8.56s/it]


Epoch 17 	 loss: 661.7369384765625 	 time: 8.560618791001616


1it [00:00,  5.96it/s]


Epoch 18 	 loss: 644.547119140625 	 time: 0.17553054197924212


1it [00:06,  6.34s/it]


Epoch 19 	 loss: 620.40380859375 	 time: 6.345016290986678


1it [00:15, 15.11s/it]


Validation	 iterations:	1000.00
Epoch 20 	 loss: 660.9364013671875 	 time: 15.111671457998455


1it [00:00, 15.12it/s]


Epoch 21 	 loss: 635.2854614257812 	 time: 0.06703549998928793


1it [00:06,  6.34s/it]


Epoch 22 	 loss: 644.2913208007812 	 time: 6.340084083989495


1it [00:08,  8.17s/it]


Epoch 23 	 loss: 628.8077392578125 	 time: 8.169379583006958


1it [00:00, 15.17it/s]


Epoch 24 	 loss: 579.22509765625 	 time: 0.06687212499673478


1it [00:09,  9.15s/it]


Epoch 25 	 loss: 649.2479858398438 	 time: 9.14973970799474


1it [00:06,  6.49s/it]


Epoch 26 	 loss: 597.6178588867188 	 time: 6.487424667022424


1it [00:06,  6.15s/it]


Epoch 27 	 loss: 632.5523681640625 	 time: 6.148694792005699


1it [00:00, 14.29it/s]


Epoch 28 	 loss: 620.0474243164062 	 time: 0.07086966600036249


1it [00:06,  6.18s/it]


Epoch 29 	 loss: 616.0174560546875 	 time: 6.179065167001681


1it [00:15, 15.32s/it]


Validation	 iterations:	1000.00
Epoch 30 	 loss: 634.068603515625 	 time: 15.319460375001654


1it [00:06,  6.24s/it]


Epoch 31 	 loss: 615.252197265625 	 time: 6.237149166001473


1it [00:06,  6.16s/it]


Epoch 32 	 loss: 586.6155395507812 	 time: 6.165363124979194


1it [00:06,  6.21s/it]


Epoch 33 	 loss: 607.7601318359375 	 time: 6.211012707994087


1it [00:00, 14.85it/s]


Epoch 34 	 loss: 585.810302734375 	 time: 0.06811220801318996


1it [00:06,  6.15s/it]


Epoch 35 	 loss: 590.9260864257812 	 time: 6.154999124992173


1it [00:00, 14.61it/s]


Epoch 36 	 loss: 600.8659057617188 	 time: 0.06935170799260959


1it [00:06,  6.22s/it]


Epoch 37 	 loss: 568.624267578125 	 time: 6.219065666984534


1it [00:00, 14.13it/s]


Epoch 38 	 loss: 616.122802734375 	 time: 0.07188670799951069


1it [00:00, 14.10it/s]


Epoch 39 	 loss: 572.6077880859375 	 time: 0.07194279201212339


1it [00:09,  9.35s/it]


Validation	 iterations:	1000.00
Epoch 40 	 loss: 600.86865234375 	 time: 9.351634042017395


1it [00:06,  6.23s/it]


Epoch 41 	 loss: 612.6912231445312 	 time: 6.233692207984859


1it [00:00, 15.46it/s]


Epoch 42 	 loss: 559.1852416992188 	 time: 0.06566987498081289


1it [00:06,  6.23s/it]


Epoch 43 	 loss: 608.8047485351562 	 time: 6.233021374995587


1it [00:00, 13.75it/s]


Epoch 44 	 loss: 635.92626953125 	 time: 0.07356741701369174


1it [00:00, 14.57it/s]


Epoch 45 	 loss: 586.8803100585938 	 time: 0.07026966701960191


1it [00:06,  6.26s/it]


Epoch 46 	 loss: 593.5850830078125 	 time: 6.261091791006038


1it [00:06,  6.28s/it]


Epoch 47 	 loss: 681.6267700195312 	 time: 6.276212958007818


1it [00:06,  6.24s/it]


Epoch 48 	 loss: 566.6027221679688 	 time: 6.24294545900193


1it [00:06,  6.38s/it]


Epoch 49 	 loss: 598.2230224609375 	 time: 6.381671875016764


1it [00:09,  9.24s/it]


Validation	 iterations:	1000.00
Epoch 50 	 loss: 612.3037719726562 	 time: 9.241110875009326


1it [00:00, 13.21it/s]


Epoch 51 	 loss: 616.7483520507812 	 time: 0.07662987502408214


1it [00:06,  6.26s/it]


Epoch 52 	 loss: 587.8331298828125 	 time: 6.261465791991213


1it [00:00, 13.37it/s]


Epoch 53 	 loss: 611.2992553710938 	 time: 0.07574324999586679


1it [00:06,  6.26s/it]


Epoch 54 	 loss: 558.134765625 	 time: 6.260391041985713


1it [00:00, 14.42it/s]


Epoch 55 	 loss: 585.4784545898438 	 time: 0.07014537500799634


1it [00:06,  6.22s/it]


Epoch 56 	 loss: 597.5950927734375 	 time: 6.216192541993223


1it [00:00, 14.33it/s]


Epoch 57 	 loss: 575.9331665039062 	 time: 0.07063495900365524


1it [00:06,  6.26s/it]


Epoch 58 	 loss: 608.222900390625 	 time: 6.261052207992179


1it [00:06,  6.27s/it]


Epoch 59 	 loss: 560.3479614257812 	 time: 6.2714502080052625


1it [00:09,  9.27s/it]


Validation	 iterations:	1000.00
Epoch 60 	 loss: 629.9467163085938 	 time: 9.273060791019816


1it [00:06,  6.24s/it]


Epoch 61 	 loss: 566.429443359375 	 time: 6.243187458021566


1it [00:06,  6.27s/it]


Epoch 62 	 loss: 579.3529052734375 	 time: 6.271157542010769


1it [00:00, 13.30it/s]


Epoch 63 	 loss: 587.9664916992188 	 time: 0.07619120800518431


1it [00:00, 14.99it/s]


Epoch 64 	 loss: 547.9495849609375 	 time: 0.06788666700595059


1it [00:06,  6.26s/it]


Epoch 65 	 loss: 542.1576538085938 	 time: 6.2580850000085775


1it [00:06,  6.25s/it]


Epoch 66 	 loss: 563.93505859375 	 time: 6.253113667014986


1it [00:00,  9.79it/s]


Epoch 67 	 loss: 545.6659545898438 	 time: 0.10279995799646713


1it [00:06,  6.24s/it]


Epoch 68 	 loss: 570.7387084960938 	 time: 6.237256124994019


1it [00:00, 13.48it/s]


Epoch 69 	 loss: 567.3051147460938 	 time: 0.07519554102327675


1it [00:15, 15.38s/it]


Validation	 iterations:	1000.00
Epoch 70 	 loss: 517.20166015625 	 time: 15.382228417001897


1it [00:06,  6.22s/it]


Epoch 71 	 loss: 557.7051391601562 	 time: 6.219496709003579


1it [00:06,  6.24s/it]


Epoch 72 	 loss: 566.9926147460938 	 time: 6.2441775419865735


1it [00:06,  6.29s/it]


Epoch 73 	 loss: 522.4611206054688 	 time: 6.291806624998571


1it [00:00, 10.17it/s]


Epoch 74 	 loss: 521.5474853515625 	 time: 0.09941475000232458


1it [00:00, 13.50it/s]


Epoch 75 	 loss: 523.2269897460938 	 time: 0.07529162499122322


1it [00:00, 13.55it/s]


Epoch 76 	 loss: 543.7357177734375 	 time: 0.07482095900923014


1it [00:06,  6.37s/it]


Epoch 77 	 loss: 524.1746215820312 	 time: 6.369766249990789


1it [00:00, 14.26it/s]


Epoch 78 	 loss: 560.7022094726562 	 time: 0.0711592499865219


1it [00:06,  6.37s/it]


Epoch 79 	 loss: 534.071044921875 	 time: 6.366313332982827


1it [00:08,  8.39s/it]


Validation	 iterations:	656.00
Epoch 80 	 loss: 542.5164184570312 	 time: 8.390458750014659


1it [00:00, 13.12it/s]


Epoch 81 	 loss: 531.866943359375 	 time: 0.07721358400885947


1it [00:06,  6.32s/it]


Epoch 82 	 loss: 485.89306640625 	 time: 6.322980500000995


1it [00:06,  6.44s/it]


Epoch 83 	 loss: 504.1593017578125 	 time: 6.435696417000145


1it [00:00, 11.68it/s]


Epoch 84 	 loss: 512.3087158203125 	 time: 0.08645895801601


1it [00:00, 14.98it/s]


Epoch 85 	 loss: 503.5458984375 	 time: 0.06776350000291131


1it [00:00, 12.81it/s]


Epoch 86 	 loss: 450.6434631347656 	 time: 0.07888179100700654


1it [00:00, 15.39it/s]


Epoch 87 	 loss: 463.2708435058594 	 time: 0.06582466600229964


1it [00:00, 15.43it/s]


Epoch 88 	 loss: 478.26776123046875 	 time: 0.06573383300565183


1it [00:06,  6.31s/it]


Epoch 89 	 loss: 429.6103210449219 	 time: 6.31205070798751


1it [00:06,  6.57s/it]


Validation	 iterations:	263.00
Epoch 90 	 loss: 455.484375 	 time: 6.569106167007703


1it [00:00, 15.08it/s]


Epoch 91 	 loss: 442.7342529296875 	 time: 0.06727133301319554


1it [00:05,  5.64s/it]


Epoch 92 	 loss: 426.3216857910156 	 time: 5.637004000018351


1it [00:05,  5.69s/it]


Epoch 93 	 loss: 436.168701171875 	 time: 5.689788499992574


1it [00:05,  5.72s/it]


Epoch 94 	 loss: 416.42041015625 	 time: 5.718342833017232


1it [00:00, 16.29it/s]


Epoch 95 	 loss: 393.9168395996094 	 time: 0.06227674998808652


1it [00:05,  5.65s/it]


Epoch 96 	 loss: 422.5009765625 	 time: 5.648008999996819


1it [00:06,  6.30s/it]


Epoch 97 	 loss: 396.6545104980469 	 time: 6.300569084007293


1it [00:05,  5.72s/it]


Epoch 98 	 loss: 381.119873046875 	 time: 5.718054124998162


1it [00:05,  5.87s/it]


Epoch 99 	 loss: 376.19427490234375 	 time: 5.8698034579865634


0it [00:00, ?it/s]