In [None]:
from torch_geometric.data import Data
import torch
import torch.nn.functional as F
import torch_geometric

from scipy import sparse
import math
from numba import cuda
from tqdm import tqdm

# Use the GPU when PyTorch reports one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("using device: ", device)

import numpy as np
import networkx as nx
import os
import json

import random as random
from math import ceil

from utils.seeds import val_seeds
import wandb

In [1]:
from torch_geometric.datasets import Actor

from torch_geometric.datasets import Amazon #Computers & Photo dataset

from torch_geometric.datasets import Coauthor #CoauthorCS dataset

from torch_geometric.datasets import KarateClub,Planetoid

from torch_geometric.datasets import WebKB #Cornell,Texas & Wisconsin dataset

from torch_geometric.datasets import WikipediaNetwork #Squirrel & Chameleon

## Seeds

In [None]:
# Fixed seed read by the data-split helpers in this notebook: it seeds the
# selection of the test / development partition so that partition does not
# depend on the per-run seed.
development_seed = 1684992425

## Hyperparameter dictionaries

In [None]:
def _sweep_config(loops, tau, c_plus, max_first):
    """Build the random-search configuration of a single dataset.

    Args:
        loops: (min, max) bounds of the "loops" parameter.
        tau: (min, max) bounds of the "tau" parameter.
        c_plus: (min, max) bounds of the "C+" parameter.
        max_first: when True the learning-rate / dropout / weight-decay ranges
            are written as {"max": ..., "min": ...}; otherwise "min" comes first.
            This only affects key order, not the values.

    Returns:
        dict with the keys "method", "metric" and "parameters".
    """
    def bounds(low, high, flip=False):
        return {"max": high, "min": low} if flip else {"min": low, "max": high}

    # Hidden-layer candidates: 1, 2 or 3 layers of equal width.
    widths, depths = (16, 32, 64, 128), (1, 2, 3)
    return {
        "method": "random",
        "metric": {"goal": "maximize", "name": "mean accuracy"},
        "parameters": {
            "learning_rate": bounds(0.0001, 0.5555, max_first),
            "layers": {"values": [[width] * depth for width in widths for depth in depths]},
            "dropout": bounds(0.0001, 0.9999, max_first),
            "weight_decay": bounds(0.0001, 0.5555, max_first),
            "loops": bounds(*loops),
            "tau": bounds(*tau),
            "C+": bounds(*c_plus),
        },
    }


# Per-dataset search spaces. Only the "loops" range differs between datasets
# that share the same "tau" / "C+" ranges.
hyperparameters_Neurips_1 = {
    dataset_name: _sweep_config(loops_range, (25, 216), (0.2, 17.5), True)
    for dataset_name, loops_range in (
        ("Cora", (80, 120)),
        ("Citeseer", (67, 101)),
        ("Pubmed", (92, 138)),
    )
}
hyperparameters_Neurips_1.update(
    (dataset_name, _sweep_config(loops_range, (9, 302), (0.7, 21.2), False))
    for dataset_name, loops_range in (
        ("Cornell", (100, 151)),
        ("Texas", (71, 107)),
        ("Wisconsin", (108, 163)),
        ("Chameleon", (665, 999)),
        ("Squirrel", (4925, 7325)),
        ("Actor", (808, 1212)),
    )
)

# Functions

## Get Dataset

In [None]:
DEFAULT_DATA_PATH = "data"


def get_dataset(
    name: str, data_dir=DEFAULT_DATA_PATH
):
    """Instantiate the torch_geometric dataset registered under `name`.

    Args:
        name: One of "Cora", "Citeseer", "Pubmed", "Computers", "Photo",
            "CoauthorCS", "Cornell", "Texas", "Wisconsin", "Chameleon",
            "Squirrel" or "Actor".
        data_dir: Root directory handed to the dataset constructor. All
            datasets share this root (no per-dataset sub-folder is appended).
            Defaults to DEFAULT_DATA_PATH.

    Returns:
        The dataset object built by the matching torch_geometric class.

    Raises:
        ValueError: if `name` is not one of the supported dataset names.
    """
    # Previously the argument was ignored and DEFAULT_DATA_PATH was always used;
    # the default keeps that behaviour for existing callers.
    path = data_dir
    if name in ("Cora", "Citeseer", "Pubmed"):
        return Planetoid(path, name)
    if name in ("Computers", "Photo"):
        return Amazon(path, name)
    if name == "CoauthorCS":
        return Coauthor(path, "CS")
    if name in ("Cornell", "Texas", "Wisconsin"):
        return WebKB(path, name)
    if name in ("Chameleon", "Squirrel"):
        return WikipediaNetwork(path, name, geom_gcn_preprocess=True)
    if name == "Actor":
        return Actor(path, "Actor")
    # ValueError is a subclass of Exception, so callers catching Exception still work.
    raise ValueError(f"Unknown dataset: {name}")


def load_data(name: str, make_undirected: bool = False):
    """Load dataset `name` and return it with its first graph and a NetworkX view.

    Args:
        name: Dataset name understood by get_dataset.
        make_undirected: Force the NetworkX graph to be undirected even when
            the graph data does not report itself as undirected.

    Returns:
        Tuple (dataset, data, G): the dataset, its first element, and the
        NetworkX conversion of that element.
    """
    dataset = get_dataset(name)
    graph_data = dataset[0]
    nx_graph = torch_geometric.utils.to_networkx(graph_data)

    # Collapse i->j / j->i into one undirected edge so NetworkX does not keep
    # them as two separate directed edges.
    wants_undirected = graph_data.is_undirected() or make_undirected
    if wants_undirected:
        nx_graph = nx_graph.to_undirected()

    return dataset, graph_data, nx_graph

def data_information(dataset,data):
    """Print a short textual summary of a dataset and one of its graphs.

    Args:
        dataset: Object exposing `num_features` and `num_classes`.
        data: Object exposing `num_nodes`, `num_edges` and the methods
            `has_isolated_nodes()`, `has_self_loops()` and `is_undirected()`.
    """
    def report_lines():
        # Lazily produced so every value is read right before its line is printed.
        yield ''
        yield f'Dataset: {dataset}:'
        yield '======================'

        yield f'Number of features: {dataset.num_features}'
        yield f'Number of classes: {dataset.num_classes}'
        yield ''

        # Statistics of the graph itself.
        yield f'Number of nodes: {data.num_nodes}'
        yield f'Number of edges: {data.num_edges}'
        yield f'Average node degree: {data.num_edges / data.num_nodes:.2f}'
        yield f'Has isolated nodes: {data.has_isolated_nodes()}'
        yield f'Has self-loops: {data.has_self_loops()}'
        yield f'Is undirected: {data.is_undirected()}'

    for line in report_lines():
        print(line)

## Largest Connected Component

In [None]:
def get_largest_connected_component_pytorch(connectiontype: str, directed ,data,num_nodes:int):
    """
    PyTorch Geometric (SciPy backend) implementation of the largest-connected-component extraction.

    Args:
    connectiontype (str): Forwarded to SciPy as `connection` ('weak' or 'strong').
    directed (boolean): Forwarded to SciPy as `directed`.
    data: Graph object exposing `edge_index` and `subgraph(mask)`.
    num_nodes (int): Number of nodes of the graph.

    Returns:
    The result of `data.subgraph(mask)`, where `mask` selects the nodes of the
    component with the most nodes.
    """

    adj = torch_geometric.utils.to_scipy_sparse_matrix(data.edge_index, num_nodes=num_nodes)
    _, component = sparse.csgraph.connected_components(adj, directed=directed, connection=connectiontype)

    # Component labels are 0..k-1, so positions in `count` are the labels themselves.
    _, count = np.unique(component, return_counts=True)
    largest_label = count.argsort()[-1:]

    # np.isin replaces the deprecated np.in1d (same result for 1-D input).
    subset_np = np.isin(component, largest_label)
    subset = torch.from_numpy(subset_np).to(data.edge_index.device, torch.bool)

    # The subgraph is computed exactly once; the former NetworkX conversion was never used.
    return data.subgraph(subset)

def get_largest_connected_component_networkx(connectiontype: str,directed, G):
    """
    NetworkX implementation of the largest-connected-component extraction.

    Args:
    connectiontype (str or None): Only read for directed graphs; must then be 'strong' or 'weak'.
    directed (boolean): Whether G is treated as a directed graph.
    G (networkx graph): Input graph.

    Returns:
    Subgraph (NetworkX graph): Copy of the subgraph induced by the largest connected component.

    Raises:
    ValueError: if `directed` is true and `connectiontype` is neither 'strong' nor 'weak'
        (previously the function silently returned None in that case).
    """
    if not directed:
        find_components = nx.connected_components
    elif connectiontype == 'strong':
        find_components = nx.strongly_connected_components
    elif connectiontype == 'weak':
        find_components = nx.weakly_connected_components
    else:
        raise ValueError(
            f"connectiontype must be 'strong' or 'weak' for directed graphs, got {connectiontype!r}"
        )

    largest_nodes = max(find_components(G), key=len)
    return G.subgraph(largest_nodes).copy()

def get_component_toppingetal(data, start: int = 0) -> set:
    """Collect every node reachable from `start` by following edges source -> target.

    Args:
        data: Object whose `edge_index` has a `.numpy()` method returning two
            rows (sources, targets).
        start: Node the traversal starts from.

    Returns:
        Set of visited nodes, `start` included.
    """
    sources, targets = data.edge_index.numpy()
    seen = set()
    frontier = {start}
    while frontier:
        node = frontier.pop()
        seen.add(node)
        # Targets of all edges leaving `node`, in edge order.
        for neighbour in targets[sources == node]:
            if neighbour not in seen and neighbour not in frontier:
                frontier.add(neighbour)
    return seen

def get_largest_connected_component_toppingetal(data):
    """Return the node ids of the largest component found by repeated traversal.

    Components are discovered with get_component_toppingetal, always starting
    from the smallest node id that has not been assigned yet.

    Args:
        data: Graph object exposing `x` (one row per node) and `edge_index`.

    Returns:
        np.ndarray holding the node ids of the component with the most nodes
        (the first one found when several share the maximum size).
    """
    unassigned = set(range(data.x.shape[0]))
    components = []
    while unassigned:
        seed_node = min(unassigned)
        component = get_component_toppingetal(data, seed_node)
        components.append(component)
        unassigned = unassigned - component

    sizes = [len(component) for component in components]
    largest = components[np.argmax(sizes)]
    return np.array(list(largest))

def remap_edges(edges: list, mapper: dict) -> list:
    """Translate both endpoints of every edge through `mapper`.

    Args:
        edges: Sequence of pairs; element 0 and element 1 are the two endpoints.
        mapper: Lookup from old node id to new node id.

    Returns:
        Two-element list [new_first_endpoints, new_second_endpoints].
    """
    firsts = [edge[0] for edge in edges]
    seconds = [edge[1] for edge in edges]
    return [
        [mapper[node] for node in firsts],
        [mapper[node] for node in seconds],
    ]

def get_node_mapper(lcc: np.ndarray) -> dict:
    """Map each node in `lcc` to its position in `lcc`.

    Args:
        lcc: Iterable of node ids.

    Returns:
        dict {node id: index}; if an id occurs more than once, the last
        position wins.
    """
    return {node: position for position, node in enumerate(lcc)}

def lcc_dataset(dataset,to_undirected = True):
    """Restrict `dataset` in place to its largest connected component.

    The component is found on `dataset[0]`. Features and labels are sliced to
    the kept nodes, edges with both endpoints kept are re-indexed to 0..n-1,
    all three split masks are reset to all-False, and labels are re-encoded to
    consecutive integers starting at 0.

    Args:
        dataset: Dataset whose `data` attribute is read and then replaced.
        to_undirected: When True the re-indexed edges are passed through
            torch_geometric.utils.to_undirected.

    Returns:
        The same `dataset` object, with `dataset.data` replaced.
    """
    lcc = get_largest_connected_component_toppingetal(dataset[0])

    x_new = dataset.data.x[lcc]
    y_new = dataset.data.y[lcc]

    # The mapper doubles as a hash-based membership test, which avoids scanning
    # the whole `lcc` array for every edge endpoint.
    node_mapper = get_node_mapper(lcc)
    row, col = dataset.data.edge_index.numpy()
    edges = [[i, j] for i, j in zip(row, col) if i in node_mapper and j in node_mapper]
    edges = remap_edges(edges, node_mapper)

    edge_index = torch.LongTensor(edges)
    if to_undirected:
        edge_index = torch_geometric.utils.to_undirected(edge_index)

    num_kept = y_new.size()[0]
    dataset.data = Data(
        x=x_new,
        edge_index=edge_index,
        y=y_new,
        train_mask=torch.zeros(num_kept, dtype=torch.bool),
        test_mask=torch.zeros(num_kept, dtype=torch.bool),
        val_mask=torch.zeros(num_kept, dtype=torch.bool),
    )

    # Re-encode labels so the classes present in the component are 0..k-1.
    classes = np.unique(dataset.data.y)
    mapping = dict(zip(classes, range(len(classes))))
    dataset.data.y = torch.LongTensor([mapping[u] for u in np.array(dataset.data.y)])

    return dataset

## GNN Models

In [None]:
from typing import List
from torch.nn import ModuleList, Dropout, ReLU
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data


class GCN_Toppingetal(torch.nn.Module):
    """Stack of GCNConv layers with ReLU and dropout between them.

    Layer widths are: input feature count, each entry of `hidden`, number of
    classes. The forward pass returns per-node log-probabilities.
    """

    def __init__(
        self, dataset, hidden: List[int] = [64], dropout: float = 0.5
    ):
        """
        Args:
            dataset: Provides `data.x` (its second dimension is the input
                width) and `num_classes` (the output width).
            hidden: Widths of the hidden layers.
            dropout: Dropout probability applied after every non-final layer.
        """
        super().__init__()

        widths = [dataset.data.x.shape[1]] + hidden + [dataset.num_classes]
        convs = [GCNConv(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])]
        self.layers = ModuleList(convs)

        # Parameters of the first layer vs. all later layers, kept as separate lists.
        self.reg_params = list(convs[0].parameters())
        self.non_reg_params = [param for conv in convs[1:] for param in conv.parameters()]

        self.dropout = Dropout(p=dropout)
        self.act_fn = ReLU()

    def reset_parameters(self):
        """Re-initialise every convolution layer."""
        for conv in self.layers:
            conv.reset_parameters()

    def forward(self, data: Data,device):
        """Return log-softmax class scores for every node of `data`.

        `device` is accepted for call compatibility but not used here.
        """
        x = data.x
        edge_index = data.edge_index
        edge_attr = data.edge_attr

        final_position = len(self.layers) - 1
        for position, conv in enumerate(self.layers):
            x = conv(x, edge_index, edge_weight=edge_attr)
            # Activation and dropout follow every layer except the last.
            if position != final_position:
                x = self.dropout(self.act_fn(x))

        return torch.nn.functional.log_softmax(x, dim=1)


class GCN(torch.nn.Module):
    """Two-layer GCN returning per-node log-probabilities."""

    def __init__(self, num_features,num_classes,hidden_channels, dropout: float = 0.4144):
        """
        Args:
            num_features: Input feature width.
            num_classes: Output width (number of classes).
            hidden_channels: Width of the hidden layer.
            dropout: Dropout probability applied after the first layer while
                training. Defaults to the value that used to be hard-coded.
        """
        super().__init__()
        self.conv1 = GCNConv(num_features,hidden_channels)
        self.conv2 = GCNConv(hidden_channels,num_classes)
        self.dropout = dropout

    def forward(self, data):
        """Return log-softmax class scores for every node of `data`."""
        x,edge_index = data.x,data.edge_index
        x = self.conv1(x, edge_index)
        x = x.relu()
        # F.dropout defaults to training=True; tie it to the module mode so that
        # evaluation (model.eval()) is deterministic.
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

## Experiment Class

In [None]:
def load_hyperparameters_gnn(dataset_name):
    """Read the stored GNN hyperparameters of one dataset.

    The values come from the JSON file experiment_utils/hyperparameters_gnn.json
    (relative to the current working directory).

    Args:
        dataset_name: Top-level key to look up in the JSON document.

    Returns:
        The entry stored under `dataset_name`, or an empty dict when the key
        is absent or the file does not exist (a message is printed in the
        latter case).
    """
    config_path = os.path.join('experiment_utils','hyperparameters_gnn.json')
    try:
        with open(config_path, 'r') as handle:
            all_settings = json.load(handle)
    except FileNotFoundError:
        print("Hyperparameters file not found.")
        return {}
    return all_settings.get(dataset_name, {})

class Experiment():
    """One training run: a GCN_Toppingetal model, its Adam optimizer and the graph data.

    The data object must expose `y`, `train_mask`, `val_mask` and `test_mask`;
    it is passed unchanged to the model on every forward pass.
    """

    def __init__(self,device,datasetname,dataset,data,hyperparameters):
        """
        Args:
            device: Device the model is moved to.
            datasetname: Not used by this class.
            dataset: Passed to GCN_Toppingetal to size the layers.
            data: Graph object used for training and evaluation.
            hyperparameters: Mapping with the keys "learning_rate", "layers",
                "weight_decay" and "dropout".
        """
        self.data = data

        self.lr = hyperparameters["learning_rate"]
        self.layers = hyperparameters["layers"]
        self.weight_decay = hyperparameters["weight_decay"]
        self.dropout = hyperparameters["dropout"]

        self.model = GCN_Toppingetal(dataset,self.layers,self.dropout).to(device = device)
        self.device = device
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay= self.weight_decay)

        # Upper bound on the epoch counter used by training().
        self.epoch = 10000#hyperparameters["epochs"]

    def train(self):
        """Run one optimisation step on the training nodes and return the loss tensor."""
        self.model.train()
        self.optimizer.zero_grad()  # Clear gradients.

        out = self.model(self.data,self.device)  # Perform a single forward pass.
        # The loss only looks at the nodes selected by train_mask.
        loss = F.nll_loss(out[self.data.train_mask], self.data.y[self.data.train_mask].to(self.device))
        loss.backward()  # Derive gradients.
        self.optimizer.step()  # Update parameters based on gradients.
        return loss

    def _masked_accuracy(self, mask):
        """Return the fraction of correctly classified nodes among those selected by `mask`."""
        self.model.eval()
        # No gradients are needed for evaluation; skipping them avoids building
        # an autograd graph on every call.
        with torch.no_grad():
            out = self.model(self.data,self.device)
        pred = out[mask].max(1)[1]
        return pred.eq(self.data.y[mask]).sum().item() / mask.sum().item()

    def validate(self):
        """Accuracy on the nodes selected by `data.val_mask`."""
        return self._masked_accuracy(self.data.val_mask)

    def test(self):
        """Accuracy on the nodes selected by `data.test_mask`."""
        return self._masked_accuracy(self.data.test_mask)

    def training(self):
        """Train with early stopping on the validation accuracy.

        Runs epochs 1 .. self.epoch - 1 (so self.epoch - 1 steps at most) and
        stops once the validation accuracy has failed to match or beat the best
        value more than 100 times in a row.

        Returns:
            Tuple (losses, validations): per-epoch loss values (NumPy scalars)
            and per-epoch validation accuracies.
        """
        losses = []
        validations = []
        best_val = None
        counter = 0
        for _ in range(1, self.epoch):
            loss = self.train()
            losses.append(loss.detach().cpu().numpy())
            val = self.validate()
            validations.append(val)

            # Ties count as an improvement and reset the patience counter.
            if best_val is None or val >= best_val:
                best_val = val
                counter = 0
            else:
                counter += 1

            if counter > 100:
                break
        return losses,validations


## Data Splits

In [None]:
def set_train_val_test_split_frac(seed: int, data: "Data", val_frac: float, test_frac: float):
    """Attach fraction-based train / val / test masks to `data`.

    The test nodes are drawn with the module-level `development_seed`, so they
    are the same for every `seed`; the remaining nodes are split into train
    and validation according to `seed`. Note that this reseeds Python's global
    `random` generator.

    Args:
        seed: Seed for the train / validation shuffle.
        data: Graph object with a label tensor `y`; receives `train_mask`,
            `val_mask` and `test_mask`.
        val_frac: Fraction of all nodes used for validation (rounded up).
        test_frac: Fraction of all nodes used for testing (rounded up).

    Returns:
        The same `data` object with the three boolean masks set.

    Raises:
        ValueError: if the validation and test sizes together exceed the
            number of nodes.
    """
    num_nodes = data.y.shape[0]

    val_size = ceil(val_frac * num_nodes)
    test_size = ceil(test_frac * num_nodes)
    train_size = num_nodes - val_size - test_size
    if train_size < 0:
        raise ValueError(
            f"val_frac + test_frac leave no training nodes "
            f"(val={val_size}, test={test_size}, nodes={num_nodes})"
        )

    nodes = list(range(num_nodes))

    # Take same test set every time using development seed for robustness
    random.seed(development_seed)
    random.shuffle(nodes)
    test_idx = sorted(nodes[:test_size])
    # Node ids are unique, so dropping the test nodes while keeping the shuffled
    # order is exactly the tail of the list (no per-node membership scan needed).
    nodes = nodes[test_size:]

    # Take train / val split according to seed
    random.seed(seed)
    random.shuffle(nodes)
    train_idx = sorted(nodes[:train_size])
    val_idx = sorted(nodes[train_size:])

    assert len(train_idx) + len(val_idx) + len(test_idx) == num_nodes

    def get_mask(idx):
        mask = torch.zeros(num_nodes, dtype=torch.bool)
        # Explicit long index tensor so that an empty split is handled too.
        mask[torch.tensor(idx, dtype=torch.long)] = True
        return mask

    data.train_mask = get_mask(train_idx)
    data.val_mask = get_mask(val_idx)
    data.test_mask = get_mask(test_idx)

    return data

def set_train_val_test_split(
        seed: int,
        data: "Data",
        development_frac: float = 0.5,
        num_per_class: int = 20,
        num_development: int = 1500) -> "Data":
    """Attach class-balanced train / val / test masks to `data`.

    A development set is drawn with the module-level `development_seed`; all
    other nodes form the test set. From the development set, `seed` picks up
    to `num_per_class` training nodes per class (never more than half of that
    class, rounded up); the remaining development nodes are validation nodes.

    Args:
        seed: Seed for the per-class training-node selection.
        data: Graph object with a 1-D integer label tensor `y`; receives
            `train_mask`, `val_mask` and `test_mask`.
        development_frac: Fraction of nodes in the development set. Only used
            when `num_development` is None.
        num_per_class: Maximum number of training nodes per class.
        num_development: Size of the development set. Defaults to 1500 (the
            value that used to be hard-coded); pass None to derive it from
            `development_frac` instead.

    Returns:
        The same `data` object with the three boolean masks set.
    """
    num_nodes = data.y.shape[0]
    if num_development is None:
        num_development = ceil(development_frac * num_nodes)

    rnd_state = np.random.RandomState(development_seed)
    development_idx = rnd_state.choice(num_nodes, num_development, replace=False)

    # Boolean masks make the complement / difference below linear in the
    # number of nodes instead of scanning an index array per node.
    development_mask = np.zeros(num_nodes, dtype=bool)
    development_mask[development_idx] = True

    # Labels of the development nodes, fetched once for all classes.
    development_labels = data.y[development_idx].cpu().numpy()

    train_idx = []
    rnd_state = np.random.RandomState(seed)
    for c in range(int(data.y.max()) + 1):
        class_idx = development_idx[np.where(development_labels == c)[0]]
        train_idx.extend(rnd_state.choice(class_idx, min(num_per_class, ceil(len(class_idx) * 0.5)), replace=False))

    train_mask = np.zeros(num_nodes, dtype=bool)
    train_mask[np.asarray(train_idx, dtype=np.int64)] = True

    data.train_mask = torch.from_numpy(train_mask)
    data.val_mask = torch.from_numpy(development_mask & ~train_mask)
    data.test_mask = torch.from_numpy(~development_mask)

    return data

## Curvature

In [None]:
# CUDA kernel: the thread at 2-D grid position (i, j) writes one entry C[i, j].
#
#   * C[i, j] = 0 when A[i, j] == 0, or when the smaller of d_in[i] / d_out[j] equals 1.
#   * Otherwise C[i, j] starts as
#         2/d_max + 2/d_min - 2 + (2/d_max + 1/d_min) * A2[i, j] * A[i, j]
#     and, when `fcc` is true, sharp_ij / (d_max * lambda_ij) is added on top.
#
# Arguments (dtypes as declared in the signature string):
#   A, A2          float32 2-D arrays; the host wrapper passes A2 = A @ A.
#   edge_index     int32 2-D array; only row 1 is read here.
#   indices_neigh  int32 2-D array; row n holds the [start, end) slice bounds
#                  into edge_index[1] that are used as "neighbours of n".
#   d_in, d_out    float32 1-D arrays of per-node degree sums.
#   N              number of nodes (bounds of the i / j guard).
#   C              float32 2-D output array.
#   fcc            boolean switch for the extra sharp/lambda term.
#
# NOTE(review): slicing edge_index[1] by the bounds in indices_neigh presumes the
# edges are grouped by source node in node order -- confirm against the caller.
@cuda.jit(
    "void(float32[:,:], float32[:,:], int32[:,:], int32[:,:], float32[:], float32[:], int32, float32[:,:],boolean)"
)

def _balanced_forman_curvature_undirected_personal(A, A2,edge_index,indices_neigh,d_in, d_out, N, C,fcc = True):
    i, j = cuda.grid(2)

    # Threads outside the N x N domain do nothing.
    if (i < N) and (j < N):
        # No edge between i and j: curvature entry is zero.
        if A[i, j] == 0:
            C[i, j] = 0
            return

        # Larger / smaller of the two endpoint degrees.
        if d_in[i] > d_out[j]:
            d_max = d_in[i]
            d_min = d_out[j]
        else:
            d_max = d_out[j]
            d_min = d_in[i]

        # An endpoint of degree 1 gets curvature 0 by construction here.
        if d_min == 1:
           C[i, j] = 0
           return

        # Degree term plus the A2[i, j] (common-neighbour count) term.
        C[i, j] = ((2 / d_max) + (2 / d_min) - 2
                    + (2 / d_max + 1 / d_min) * A2[i, j] * A[i, j]
                  )
        if fcc:
            # Neighbour slices of i and j inside edge_index[1].
            ind1_i,ind2_i = indices_neigh[i,0], indices_neigh[i,1]
            neighs_i = edge_index[1,ind1_i:ind2_i]

            ind1_j,ind2_j = indices_neigh[j,0],indices_neigh[j,1]
            neighs_j = edge_index[1,ind1_j:ind2_j]


            # sharp_ij counts the neighbours with a positive TMP value;
            # lambda_ij tracks the largest TMP value seen.
            sharp_ij = 0
            lambda_ij = 0
            # Pass 1: neighbours k of i.
            for k_count in range(len(neighs_i)):
                k = neighs_i[k_count]

                ind1_k = indices_neigh[k,0]
                ind2_k = indices_neigh[k,1]
                neighs_k = edge_index[1,ind1_k:ind2_k]
                if A[k,i]*(1-A[k,j]) !=0 and k != j: #Only have k in S(i)\S(j)

                    # had = sum over neighbours l of k of A[k,l]*A[i,l]*A[j,l].
                    had = 0
                    for l_count in range(len(neighs_k)):
                        l = neighs_k[l_count]
                        had += A[k,l]*A[i,l]*A[j,l]

                    TMP =A[k,i]*(1-A[k,j])*(A2[k,j] -had- 1)

                    if TMP > 0:
                        sharp_ij += 1
                        if TMP > lambda_ij:
                            lambda_ij = TMP

            # Pass 2: the same computation with the roles of i and j swapped.
            for k_count in range(len(neighs_j)):
                k = neighs_j[k_count]

                ind1_k,ind2_k = indices_neigh[k,0],indices_neigh[k,1]
                neighs_k = edge_index[1,ind1_k:ind2_k]

                if A[j,k]*(1-A[k,i]) !=0 and k != i: #Only have k in S(j)\S(i)
                    had = 0

                    for l_count in range(len(neighs_k)):
                        l = neighs_k[l_count]
                        had += A[k,l]*A[i,l]*A[j,l]

                    TMP = A[j,k]*(1-A[k,i])*(A2[k,i] -had- 1)

                    if TMP > 0:
                        sharp_ij += 1
                        if TMP > lambda_ij:
                            lambda_ij = TMP

            # The extra term is only added when at least one positive TMP was found.
            if lambda_ij > 0:
                C[i, j] += sharp_ij / (d_max * lambda_ij)


def balanced_forman_curvature_undirected_personal(A,edge_index, C=None,fcc = True):
    """Launch the `_balanced_forman_curvature_undirected_personal` kernel for all node pairs.

    Args:
        A: Square adjacency tensor; `.cuda()` results are created next to it,
            so it is expected to live on the GPU.
        edge_index: Edge list handed to the kernel unchanged.
        C: Optional N x N output tensor; a zero tensor on the GPU is allocated
            when omitted.
        fcc: Forwarded to the kernel; switches its extra sharp/lambda term.

    Returns:
        The tensor `C` that the kernel wrote into.
    """
    N = A.shape[0]
    threadsperblock = (32,16)
    blockspergrid_2d = (
        math.ceil(N / threadsperblock[0]),
        math.ceil(N / threadsperblock[1]),
    )

    A2 = torch.matmul(A, A)

    d_in = A.sum(axis=0)
    d_out = A.sum(axis=1)

    # Per-node [start, end) offsets obtained by accumulating d_in. The degrees
    # are copied to the host once instead of one `.item()` call per node.
    offsets = [0]
    for degree in d_in.tolist():
        offsets.append(offsets[-1] + int(degree))
    index_tuples = list(zip(offsets[:-1], offsets[1:]))
    index_tuples = torch.tensor(index_tuples).cuda()

    if C is None:
        C = torch.zeros(N, N).cuda()

    _balanced_forman_curvature_undirected_personal[blockspergrid_2d, threadsperblock](A, A2,edge_index,index_tuples,d_in, d_out, N, C,fcc)
    return C

# CUDA kernel: the thread at 2-D grid position (i, j) writes one entry C[i, j].
#
#   * C[i, j] = 0 when A[i, j] == 0 or when d_max * d_min == 0.
#   * Otherwise
#         C[i, j] = 2/d_max + 2/d_min - 2 + (2/d_max + 1/d_min) * A2[i, j] * A[i, j]
#     plus sharp_ij / (d_max * lambda_ij) when lambda_ij > 0.
#
# sharp_ij / lambda_ij are obtained by scanning every node k in range(N) using
# only the dense arrays A and A2 (the host wrapper passes A2 = A @ A).
# Dtypes are the ones declared in the signature string.
@cuda.jit(
    "void(float32[:,:], float32[:,:], float32[:], float32[:], int32, float32[:,:])"
)
def _balanced_forman_curvature_jctopping(A, A2, d_in, d_out, N, C):
    i, j = cuda.grid(2)

    # Threads outside the N x N domain do nothing.
    if (i < N) and (j < N):
        # No edge between i and j: curvature entry is zero.
        if A[i, j] == 0:
            C[i, j] = 0
            return

        # Larger / smaller of the two endpoint degrees.
        if d_in[i] > d_out[j]:
            d_max = d_in[i]
            d_min = d_out[j]
        else:
            d_max = d_out[j]
            d_min = d_in[i]

        # Guards the divisions by d_max and d_min below.
        if d_max * d_min == 0:
            C[i, j] = 0
            return

        # sharp_ij counts positive TMP values; lambda_ij is the largest one.
        sharp_ij = 0
        lambda_ij = 0
        for k in range(N):
            # Contribution of k seen from the j side ...
            TMP = A[k, j] * (A2[i, k] - A[i, k]) * A[i, j]
            if TMP > 0:
                sharp_ij += 1
                if TMP > lambda_ij:
                    lambda_ij = TMP

            # ... and from the i side.
            TMP = A[i, k] * (A2[k, j] - A[k, j]) * A[i, j]
            if TMP > 0:
                sharp_ij += 1
                if TMP > lambda_ij:
                    lambda_ij = TMP

        # Degree term plus the A2[i, j] term.
        C[i, j] = (
            (2 / d_max) + (2 / d_min) - 2 + (2 / d_max + 1 / d_min) * A2[i, j] * A[i, j]
        )
        if lambda_ij > 0:
            C[i, j] += sharp_ij / (d_max * lambda_ij)

def balanced_forman_curvature_jctopping(A, C=None):
    """Launch the `_balanced_forman_curvature_jctopping` kernel for all node pairs.

    Args:
        A: Square adjacency tensor.
        C: Optional N x N output tensor; a zero tensor on the GPU is allocated
            when omitted.

    Returns:
        The tensor `C` that the kernel wrote into.
    """
    N = A.shape[0]
    A2 = torch.matmul(A, A)
    d_in, d_out = A.sum(axis=0), A.sum(axis=1)
    if C is None:
        C = torch.zeros(N, N).cuda()

    # 16 x 16 threads per block, with enough blocks to cover the N x N domain.
    block_dim = (16, 16)
    grid_dim = (math.ceil(N / block_dim[0]), math.ceil(N / block_dim[1]))

    launch = _balanced_forman_curvature_jctopping[grid_dim, block_dim]
    launch(A, A2, d_in, d_out, N, C)
    return C



## SDRF Cuda

In [None]:
# Assigning a plain Python variable does not configure Numba. The setting that
# controls the low-occupancy launch warnings is numba.config.CUDA_LOW_OCCUPANCY_WARNINGS
# (or the NUMBA_CUDA_LOW_OCCUPANCY_WARNINGS environment variable set before
# Numba is imported), so the value is applied there as well.
from numba import config as _numba_config

NUMBA_CUDA_LOW_OCCUPANCY_WARNINGS=0  # name kept for any code that still reads it
_numba_config.CUDA_LOW_OCCUPANCY_WARNINGS = NUMBA_CUDA_LOW_OCCUPANCY_WARNINGS

def softmax(a, tau=1):
    """Softmax of `a * tau`, computed in a numerically stable way.

    Args:
        a: NumPy array of scores.
        tau: Multiplier applied to the scores before exponentiation (larger
            values concentrate the mass on the largest scores).

    Returns:
        Array of the same shape whose entries are non-negative and sum to 1.
        An empty input yields an empty array.
    """
    scaled = np.asarray(a * tau)
    if scaled.size == 0:
        return np.exp(scaled)
    # Subtracting the maximum leaves the result unchanged mathematically but
    # keeps exp() from overflowing to inf (which used to produce NaNs).
    exp_a = np.exp(scaled - scaled.max())
    return exp_a / exp_a.sum()

# CUDA kernel: evaluates, for every candidate pair (i, j) with i taken from
# i_neighbors and j taken from j_neighbors, a curvature value for the edge (x, y)
# as if the edge (i, j) had been added, and stores it in D[I, J] where I / J are
# the positions of i / j in their lists.
#
#   * D[I, J] = -1000 when i == j or A[i, j] != 0 (candidate is not a new edge).
#   * D[I, J] = 0 when the smaller adjusted degree equals 1.
#   * Otherwise D[I, J] = 2/d_max + 2/d_min - 2 + (2/d_max + 1/d_min) * A2_x_y * A[x, y],
#     plus sharp_xy / (d_max * lambda_xy) when `fcc` is true and lambda_xy > 0.
#
# A and A2 are never modified; the effect of the hypothetical edge is emulated
# with local "+1" corrections (d_in_x / d_out_y, A2_x_y, A_k_y, A_x_k, ...).
#
# Arguments (dtypes as declared in the signature string):
#   A, A2                     float32 2-D arrays; the host wrapper passes A2 = A @ A.
#   edge_index                int32 2-D array; only row 1 is read.
#   indices_neigh             int32 2-D array of [start, end) slice bounds into
#                             edge_index[1], one row per node.
#   d_in_x, d_out_y           float32 scalars (degree sums of x and y from the wrapper).
#   N                         int32; not referenced in the body.
#   D                         float32 2-D output array.
#   x, y                      int32 endpoints of the edge being evaluated.
#   i_neighbors, j_neighbors  int32 1-D candidate lists.
#   dim_i, dim_j              bounds of the I / J guard.
#   fcc                       boolean switch for the sharp/lambda term.
#
# NOTE(review): as in the curvature kernel, slicing edge_index[1] by indices_neigh
# presumes edges grouped by source node in node order -- confirm against the caller.
@cuda.jit(
    "void(float32[:,:], float32[:,:], int32[:,:], int32[:,:], float32, float32, int32, float32[:,:], int32, int32, int32[:], int32[:], int32, int32,boolean)"
)

def _curvature_post_rewiring_undirected_personal( A, A2,edge_index,indices_neigh, d_in_x, d_out_y, N, D, x, y, i_neighbors, j_neighbors, dim_i, dim_j,fcc = True
):
    I, J = cuda.grid(2)

    # Threads outside the dim_i x dim_j candidate grid do nothing.
    if (I < dim_i) and (J < dim_j):
        i = i_neighbors[I]
        j = j_neighbors[J]

        # Self-pairs and already existing edges get a large negative sentinel.
        if (i == j) or (A[i, j] != 0):
            D[I, J] = -1000
            return

        # NOTE(review): A_i_j is computed here but not read again in this kernel.
        A_i_j = A[i, j]
        A_i_j += 1

        # Adjust the local copy of one endpoint degree when the candidate edge
        # touches x (via i) or, otherwise, y (via j).
        if i == x:
            d_in_x += 1
        elif j == y:
            d_out_y += 1


        # Larger / smaller of the two adjusted degrees.
        if d_in_x > d_out_y:
            d_max = d_in_x
            d_min = d_out_y
        else:
            d_max = d_out_y
            d_min = d_in_x

        if d_min ==1:#d_in_x * d_out_y == 0:
            D[I, J] = 0
            return

        A2_x_y = A2[x, y]
        # Difference in triangles term
        if (x == i) and (A[j, y] != 0):
            A2_x_y += 1.
        elif (y == j) and (A[x, i] != 0):
            A2_x_y += 1.

        # Difference in four-cycles term
        ind1_x,ind2_x = indices_neigh[x,0], indices_neigh[x,1]
        neighs_x = edge_index[1,ind1_x:ind2_x]

        ind1_y,ind2_y = indices_neigh[y,0],indices_neigh[y,1]
        neighs_y = edge_index[1,ind1_y:ind2_y]

        # Degree term plus the (adjusted) A2[x, y] term.
        D[I, J] = (
                (2 / d_max)
                + (2 / d_min)
                - 2
                + (2 / d_max + 1 / d_min) * A2_x_y * A[x, y]
            )

        if fcc:

            # sharp_xy counts positive TMP values; lambda_xy is the largest one.
            sharp_xy = 0
            lambda_xy = 0


            # NOTE(review): A_x_j is computed here but not read again in this kernel.
            A_x_j = A[x,j] + 0
            if i == x and y !=j:
                A_x_j += 1

            # Pass 1: neighbours k of x (taken from the unmodified edge_index).
            for k_count in range(len(neighs_x)):
                k = neighs_x[k_count]

                ind1_k = indices_neigh[k,0]
                ind2_k = indices_neigh[k,1]
                neighs_k = edge_index[1,ind1_k:ind2_k]

                # A2[k, y] corrected for the hypothetical edge (i, j), case by case.
                if k != i and k != j and y !=i and y!=j:
                    A2_k_y = A2[k, y]
                elif k ==i and y !=j:
                    A2_k_y = A2[k, y] + A[j,y]
                elif k ==j and y !=i:
                    A2_k_y = A2[k, y] + A[i,y]
                elif k!=j and y==i:
                    A2_k_y = A2[k, y] + A[k,j]
                elif k!=i and y==j:
                    A2_k_y = A2[k, y] + A[k,i]
                elif (k ==i and y ==j) or (k ==j and y == i):
                    A2_k_y = A2[k, y] + +1*A[k,k] + 1*A[y,y]

                # Local copies of A[k, y] and A[x, k], incremented when the
                # hypothetical edge coincides with that entry.
                A_k_y = A[k,y] + 0
                A_x_k = A[x,k] + 0

                if  (k == i and j ==y) or (k == y and j == i):
                    A_k_y +=1
                if (i == x and k == j) or (x==j and k == i):
                    A_x_k +=1

                if A_x_k*(1-A_k_y) !=0 and k!=y:

                    # had = sum over neighbours l of k of the corrected A[k,l]*A[x,l]*A[y,l].
                    had = 0
                    for l_count in range(len(neighs_k)): #This doesn't sum over j since we haven't adapted the edge index yet
                        l = neighs_k[l_count]
                        A_k_l = A[k,l] + 0
                        A_x_l = A[x,l] + 0
                        A_y_l = A[y,l] + 0
                        if (k == i and l == j) or (k == j and l == i):
                            A_k_l +=1
                        if (x == i and l == j) or (x == j and l == i):
                            A_x_l +=1
                        if (y == i and l == j) or (y == j and l == i):
                            A_y_l +=1
                        had += A_k_l*A_x_l*A_y_l

                    TMP =A_x_k*(1-A_k_y)*(A2_k_y -had- 1)

                    if TMP > 0:
                        sharp_xy += 1
                        if TMP > lambda_xy:
                            lambda_xy = TMP

            # Pass 2: neighbours w of y, mirroring pass 1 with x and y swapped.
            for w_count in range(len(neighs_y)):
                w = neighs_y[w_count]
                ind1_w = indices_neigh[w,0]
                ind2_w = indices_neigh[w,1]
                neighs_w = edge_index[1,ind1_w:ind2_w]

                # A2[w, x] corrected for the hypothetical edge (i, j), case by case.
                if w != i and w != j and x !=i and x!=j:
                    A2_w_x = A2[w, x]
                elif w ==i and x !=j:
                    A2_w_x = A2[w, x] + A[j,x]
                elif w ==j and x !=i:
                    A2_w_x = A2[w, x] + A[i,x]
                elif w!=j and x==i:
                    A2_w_x = A2[w, x] + A[w,j]
                elif w!=i and x==j:
                    A2_w_x = A2[w, x] + A[w,i]
                elif (w ==i and x ==j) or (w ==j and x == i):
                    A2_w_x = A2[w, x] +1*A[w,w] + 1*A[x,x]

                # A_x_w only receives the +1 in the single case w == j and x == i.
                A_x_w = A[x,w] + 0
                if  w ==j and x ==i:
                    A_x_w +=1

                A_y_w = A[y,w] + 0
                A_w_x = A[w,x] + 0

                if  (w == i and j ==y) or (w == y and j == i):
                        A_y_w +=1
                if (i == x and w == j) or (x==j and w == i):
                        A_w_x +=1

                if A_y_w*(1-A_w_x) !=0 and w != x:
                    had = 0
                    for l_count in range(len(neighs_w)): # If w ==j (SHOULD NEVER HAPPEN), this doesn't sum over i since we haven't adapted the edge index yet
                        l = neighs_w[l_count]
                        A_w_l = A[w,l] + 0
                        A_x_l = A[x,l] + 0
                        A_y_l = A[y,l] + 0
                        if (w == i and l == j) or (w == j and l == i):
                            A_w_l +=1
                        if (x == i and l == j) or (x == j and l == i):
                            A_x_l +=1
                        if (y == i and l == j) or (y == j and l == i):
                            A_y_l +=1

                        had += A_w_l*A_x_l*A_y_l


                    # NOTE(review): the guard above tests (1-A_w_x) while this product
                    # uses (1-A_x_w); the two locals are corrected under different
                    # conditions -- confirm that the asymmetry is intended.
                    TMP = A_y_w*(1-A_x_w)*(A2_w_x -had- 1)

                    if TMP > 0:
                        sharp_xy +=  1
                        if TMP > lambda_xy:
                            lambda_xy = TMP


            # The extra term is only added when at least one positive TMP was found.
            if lambda_xy > 0:
                D[I, J] += sharp_xy / (d_max * lambda_xy)


def curvature_post_rewiring_personal(A, x, y,edge_index, i_neighbors, j_neighbors, D=None,is_undirected = False,fcc = True):
    """Launch `_curvature_post_rewiring_undirected_personal` for all candidate pairs.

    Args:
        A: Square adjacency tensor.
        x, y: Endpoints of the edge whose curvature is re-evaluated.
        edge_index: Edge list handed to the kernel unchanged.
        i_neighbors, j_neighbors: Candidate endpoints; entry (I, J) of the
            result belongs to the pair (i_neighbors[I], j_neighbors[J]).
        D: Optional output tensor of shape (len(i_neighbors), len(j_neighbors));
            a zero tensor on the GPU is allocated when omitted.
        is_undirected: Must be True; the directed case is not implemented.
        fcc: Forwarded to the kernel; switches its extra sharp/lambda term.

    Returns:
        The tensor `D` written by the kernel, or None (after printing a
        message) when `is_undirected` is False.
    """
    # Bail out before doing any GPU work when the request cannot be served.
    if not is_undirected:
        print("Not implemented for directed graphs")
        return

    N = A.shape[0]
    A2 = torch.matmul(A, A)
    d_in = A.sum(axis = 0)
    d_out = A.sum(axis = 1)
    if D is None:
        D = torch.zeros(len(i_neighbors), len(j_neighbors)).cuda()

    # Per-node [start, end) offsets obtained by accumulating d_in. The degrees
    # are copied to the host once instead of one `.item()` call per node.
    offsets = [0]
    for degree in d_in.tolist():
        offsets.append(offsets[-1] + int(degree))
    index_tuples = list(zip(offsets[:-1], offsets[1:]))
    index_tuples = torch.tensor(index_tuples).cuda()

    # The kernel takes the degree sums of x and y as scalars.
    d_in = d_in[x]
    d_out = d_out[y]

    threadsperblock = (16, 16)
    blockspergrid = (
        math.ceil(D.shape[0] / threadsperblock[0]),
        math.ceil(D.shape[1] / threadsperblock[1]),
    )
    _curvature_post_rewiring_undirected_personal[blockspergrid, threadsperblock](
        A,
        A2,
        edge_index,
        index_tuples,
        d_in,
        d_out,
        N,
        D,
        x,
        y,
        np.array(i_neighbors),
        np.array(j_neighbors),
        D.shape[0],
        D.shape[1],
        fcc
    )
    return D


def sdrf_cuda_personal(
    data,
    loops=10,
    remove_edges=False,
    removal_bound=0.5,
    tau=1,
    int_node = False,
    is_undirected=False,
    fcc = True
):
    """Curvature-driven rewiring loop (add and optionally remove edges).

    Each iteration recomputes the curvature matrix ``C``, takes the entry with
    the smallest value as the edge ``(x, y)``, samples a new edge between the
    neighbourhoods of ``x`` and ``y`` with probability given by
    ``softmax(improvement, tau)``, and optionally removes the edge with the
    largest curvature if it exceeds ``removal_bound``.

    Args:
        data: PyG ``Data`` object to rewire.
        loops: maximum number of iterations.
        remove_edges: whether to remove the max-curvature edge each iteration.
        removal_bound: curvature threshold above which an edge is removed.
        tau: temperature forwarded to ``softmax``.
        int_node: if True, connect the sampled pair ``(k, l)`` through a newly
            created intermediate node instead of a direct edge.
        is_undirected: only ``True`` is supported.
        fcc: forwarded to the curvature helpers.

    Returns:
        ``(G_in, count_edge_removal)``: the rewired networkx graph and the
        number of removed edges. Returns ``None`` (after printing a message)
        if an iteration is entered with ``is_undirected=False``.
    """
    G_in = torch_geometric.utils.to_networkx(data)

    if is_undirected:
        G_in = G_in.to_undirected()

    count_edge_removal = 0

    A = torch.tensor(nx.adjacency_matrix(G_in).todense(), dtype = torch.float)
    A = A.cuda()

    edge_index = data.edge_index.clone()
    edge_index = edge_index.cuda()
    N = A.shape[0]

    C = torch.zeros(N, N).cuda()

    for idx in range(loops):
        # With int_node=True the adjacency matrix gains one row/column per
        # added node. Keep the curvature buffer (and the N used to decode flat
        # argmin/argmax indices) the same size as A. This is done here rather
        # than right after padding because the removal step below still needs
        # the C computed for this iteration.
        if C.shape[0] != A.shape[0]:
            N = A.shape[0]
            C = torch.zeros(N, N).cuda()

        count_new_node = len(G_in.nodes)
        can_add = True
        if is_undirected:
            # The helper's return value is ignored; C is read right after, so
            # it is expected to be filled in place.
            balanced_forman_curvature_undirected_personal(A,edge_index ,C=C,fcc = fcc)
        else:
            print("Not implemented for directed graphs")
            return

        # Decode the flat index of the minimum entry into (row, col).
        ix_min = C.argmin().item()
        x = ix_min // N
        y = ix_min % N

        if is_undirected:
            x_neighbors = list(G_in.neighbors(x)) + [x] # !! We're adding x to the set of neighbours
            y_neighbors = list(G_in.neighbors(y)) + [y]
        else:
            x_neighbors = list(G_in.successors(x)) + [x]
            y_neighbors = list(G_in.predecessors(y)) + [y]

        candidates = [
            (i, j)
            for i in x_neighbors
            for j in y_neighbors
            if (i != j) and (not G_in.has_edge(i, j))
        ]

        if len(candidates):
            D = curvature_post_rewiring_personal(A,x,y,edge_index,x_neighbors,y_neighbors,D=None,is_undirected=is_undirected,fcc = fcc)

            # Subtract once and copy to the host once, instead of one GPU
            # subtraction plus one sync per candidate.
            gains = (D - C[x, y]).tolist()

            # First position of every node, matching list.index() semantics.
            x_pos = {}
            for pos, node in enumerate(x_neighbors):
                x_pos.setdefault(node, pos)
            y_pos = {}
            for pos, node in enumerate(y_neighbors):
                y_pos.setdefault(node, pos)

            improvements = [gains[x_pos[i]][y_pos[j]] for (i, j) in candidates]
            k, l = candidates[np.random.choice(range(len(candidates)), p=softmax(np.array(improvements), tau=tau))] ##For directed graph: Makes sense: k is selected uit of "i" and "l" out of j

            if int_node:
                # Route k -> new node -> l through a fresh intermediate node.
                A = F.pad(input=A, pad=(0,1,0,1), mode='constant', value=0)
                G_in.add_node(count_new_node)
                G_in.add_edge(k,count_new_node)
                G_in.add_edge(count_new_node, l)
                A[k, count_new_node] = A[count_new_node, l] = 1.
                if is_undirected:
                    A[count_new_node, k] = A[l, count_new_node] = 1.
            else:
                G_in.add_edge(k, l)
                A[k, l] = 1.
                if is_undirected:
                    A[l, k] = 1.
            edge_index=A.to_sparse().indices()

        else:
            can_add = False
            if not remove_edges:
                break

        if remove_edges:
            # C still holds the curvature computed at the top of this
            # iteration, i.e. before the edge addition above.
            ix_max = C.argmax()
            xmax = torch.div(ix_max,N,rounding_mode='trunc').item()
            ymax = (ix_max % N).item()
            if C[xmax, ymax] > removal_bound:
                G_in.remove_edge(xmax, ymax)

                A[xmax, ymax] = 0.
                if is_undirected:
                    A[ymax, xmax] = 0.
                edge_index=A.to_sparse().indices()
                count_edge_removal += 1

            else:
                if can_add is False:
                    break

    return G_in,count_edge_removal

@cuda.jit(
    "void(float32[:,:], float32[:,:], float32, float32, int32, float32[:,:], int32, int32, int32[:], int32[:], int32, int32)"
)
def _balanced_forman_post_delta_jctopping(
    A, A2, d_in_x, d_out_y, N, D, x, y, i_neighbors, j_neighbors, dim_i, dim_j
):
    """CUDA kernel: curvature-style score of edge (x, y) if edge (i, j) were added.

    Each thread handles one pair ``(I, J)`` of the 2-D launch grid and writes
    ``D[I, J]`` for the candidate edge ``(i_neighbors[I], j_neighbors[J])``.

    Arguments (types as declared in the signature string above):
        A: float32 2-D adjacency matrix.
        A2: float32 2-D matrix; the host wrapper passes ``A @ A``.
        d_in_x, d_out_y: float32 scalar degree terms for x and y.
        N: number of nodes iterated over in the four-cycle loop.
        D: float32 2-D output, indexed ``[I, J]``.
        x, y: endpoints of the edge whose score is evaluated.
        i_neighbors, j_neighbors: int32 arrays of candidate endpoints.
        dim_i, dim_j: bounds on I and J (threads outside do nothing).

    Special values written to ``D[I, J]``:
        -1000 when ``i == j`` or ``A[i, j] != 0`` (candidate not addable);
        0 when the product of the (adjusted) degree terms is zero.
    """
    # Global 2-D thread index.
    I, J = cuda.grid(2)

    # Guard against threads that fall outside the output matrix.
    if (I < dim_i) and (J < dim_j):
        i = i_neighbors[I]
        j = j_neighbors[J]

        # Self-pair or already-present edge: mark with a large negative sentinel.
        if (i == j) or (A[i, j] != 0):
            D[I, J] = -1000
            return

        # Difference in degree terms
        # (d_in_x / d_out_y are scalar arguments, so these increments only
        # affect this thread's copy.)
        if j == x:
            d_in_x += 1
        elif i == y:
            d_out_y += 1

        # A zero degree term would lead to a division by zero below.
        if d_in_x * d_out_y == 0:
            D[I, J] = 0
            return

        if d_in_x > d_out_y:
            d_max = d_in_x
            d_min = d_out_y
        else:
            d_max = d_out_y
            d_min = d_in_x

        # Difference in triangles term
        # Start from A2[x, y] and add the contribution the new edge (i, j)
        # would make when it is incident to x or to y.
        A2_x_y = A2[x, y]
        if (x == i) and (A[j, y] != 0):
            A2_x_y += A[j, y]
        elif (y == j) and (A[x, i] != 0):
            A2_x_y += A[x, i]

        # Difference in four-cycles term
        # sharp_ij counts the positive TMP contributions, lambda_ij tracks the
        # largest one.
        sharp_ij = 0
        lambda_ij = 0
        for z in range(N):
            # "+ 0" yields local values that can be modified without touching A / A2.
            A_z_y = A[z, y] + 0
            A_x_z = A[x, z] + 0
            A2_z_y = A2[z, y] + 0
            A2_x_z = A2[x, z] + 0

            # Patch the local A / A2 entries as if edge (i, j) were present.
            if (z == i) and (y == j):
                A_z_y += 1
            if (x == i) and (z == j):
                A_x_z += 1
            if (z == i) and (A[j, y] != 0):
                A2_z_y += A[j, y]
            if (x == i) and (A[j, z] != 0):
                A2_x_z += A[j, z]
            if (y == j) and (A[z, i] != 0):
                A2_z_y += A[z, i]
            if (z == j) and (A[x, i] != 0):
                A2_x_z += A[x, i]

            # First of the two per-z contributions; both are multiplied by
            # A[x, y], so they vanish when (x, y) is not an edge.
            TMP = A_z_y * (A2_x_z - A_x_z) * A[x, y]
            if TMP > 0:
                sharp_ij += 1
                if TMP > lambda_ij:
                    lambda_ij = TMP

            # Second contribution, with the roles of x and y swapped.
            TMP = A_x_z * (A2_z_y - A_z_y) * A[x, y]
            if TMP > 0:
                sharp_ij += 1
                if TMP > lambda_ij:
                    lambda_ij = TMP

        # Degree and triangle terms.
        D[I, J] = (
            (2 / d_max) + (2 / d_min) - 2 + (2 / d_max + 1 / d_min) * A2_x_y * A[x, y]
        )
        # Four-cycle term, only when at least one positive contribution exists.
        if lambda_ij > 0:
            D[I, J] += sharp_ij / (d_max * lambda_ij)


def balanced_forman_post_delta_jctopping(A, x, y, i_neighbors, j_neighbors, D=None):
    """Host wrapper that launches ``_balanced_forman_post_delta_jctopping``.

    Args:
        A: dense float adjacency matrix; passed directly to the kernel.
        x, y: endpoints of the edge being evaluated.
        i_neighbors, j_neighbors: candidate endpoints (sequences of ints).
        D: optional pre-allocated output of shape
            ``(len(i_neighbors), len(j_neighbors))``; allocated on the GPU if None.

    Returns:
        ``D``, with one entry per candidate pair written by the kernel.
    """
    N = A.shape[0]
    A2 = torch.matmul(A, A)
    d_in = A[:, x].sum()
    d_out = A[y].sum()
    if D is None:
        D = torch.zeros(len(i_neighbors), len(j_neighbors)).cuda()

    threadsperblock = (16, 16)
    blockspergrid_x = math.ceil(D.shape[0] / threadsperblock[0])
    blockspergrid_y = math.ceil(D.shape[1] / threadsperblock[1])
    blockspergrid = (blockspergrid_x, blockspergrid_y)

    # The kernel signature declares both neighbour arrays as int32[:].
    # np.array() on Python ints yields the platform-default integer width,
    # so the dtype is made explicit to match the declaration.
    i_neighbors_arr = np.array(i_neighbors, dtype=np.int32)
    j_neighbors_arr = np.array(j_neighbors, dtype=np.int32)

    _balanced_forman_post_delta_jctopping[blockspergrid, threadsperblock](
        A,
        A2,
        d_in,
        d_out,
        N,
        D,
        x,
        y,
        i_neighbors_arr,
        j_neighbors_arr,
        D.shape[0],
        D.shape[1],
    )
    return D


def sdrf_jctopping(
    data,
    loops=10,
    remove_edges=True,
    removal_bound=0.5,
    tau=1,
    is_undirected=False,
):
    """Curvature-driven rewiring loop using the ``*_jctopping`` curvature helpers.

    Each iteration recomputes the curvature matrix ``C``, takes its minimum
    entry as edge ``(x, y)``, samples a new edge between the neighbourhoods of
    ``x`` and ``y`` with probability ``softmax(improvement, tau)``, and
    optionally removes the maximum-curvature edge if it exceeds
    ``removal_bound``.

    Args:
        data: PyG ``Data`` object to rewire.
        loops: maximum number of iterations.
        remove_edges: whether to remove the max-curvature edge each iteration.
        removal_bound: curvature threshold above which an edge is removed.
        tau: temperature forwarded to ``softmax``.
        is_undirected: treat the graph as undirected (symmetric updates).

    Returns:
        The rewired networkx graph.
    """
    edge_index = data.edge_index
    if is_undirected:
        edge_index = torch_geometric.utils.to_undirected(edge_index)
    A = torch_geometric.utils.to_dense_adj(torch_geometric.utils.remove_self_loops(edge_index)[0])[0]
    N = A.shape[0]
    G = torch_geometric.utils.to_networkx(data)
    if is_undirected:
        G = G.to_undirected()
    A = A.cuda()
    C = torch.zeros(N, N).cuda()

    # The loop counter is unused; it no longer shares a name with node x.
    for _ in tqdm(range(loops)):
        can_add = True
        # The return value is ignored; C is read right after, so the helper is
        # expected to fill it in place.
        balanced_forman_curvature_jctopping(A, C=C)
        ix_min = C.argmin().item()
        x = ix_min // N
        y = ix_min % N

        if is_undirected:
            x_neighbors = list(G.neighbors(x)) + [x]
            y_neighbors = list(G.neighbors(y)) + [y]
        else:
            x_neighbors = list(G.successors(x)) + [x]
            y_neighbors = list(G.predecessors(y)) + [y]

        candidates = [
            (i, j)
            for i in x_neighbors
            for j in y_neighbors
            if (i != j) and (not G.has_edge(i, j))
        ]

        if len(candidates):
            D = balanced_forman_post_delta_jctopping(A, x, y, x_neighbors, y_neighbors)

            # Subtract once and copy to the host once, instead of one GPU
            # subtraction plus one sync per candidate.
            gains = (D - C[x, y]).tolist()

            # First position of every node, matching list.index() semantics.
            x_pos = {}
            for pos, node in enumerate(x_neighbors):
                x_pos.setdefault(node, pos)
            y_pos = {}
            for pos, node in enumerate(y_neighbors):
                y_pos.setdefault(node, pos)

            improvements = [gains[x_pos[i]][y_pos[j]] for (i, j) in candidates]

            k, l = candidates[
                np.random.choice(
                    range(len(candidates)), p=softmax(np.array(improvements), tau=tau)
                )
            ]
            G.add_edge(k, l)
            A[k, l] = 1
            if is_undirected:
                A[l, k] = 1
        else:
            can_add = False
            if not remove_edges:
                break

        if remove_edges:
            # C still holds the curvature computed before the addition above.
            ix_max = C.argmax().item()
            x_max = ix_max // N
            y_max = ix_max % N
            if C[x_max, y_max] > removal_bound:
                G.remove_edge(x_max, y_max)
                A[x_max, y_max] = 0
                if is_undirected:
                    A[y_max, x_max] = 0
            else:
                if can_add is False:
                    break

    return G

# Experiments Rewiring Sweeps: Undirected

In [None]:
"""
Parameters for the experiment
"""

datasetname = "Texas"
results_dir = "results"
rewiring_run = True
make_undirected = True
int_node = False
Curvature_type = "BFC_no4cycle"

path = ""

os.environ["WANDB_SILENT"] = "true"

# Numba documents this setting as an integer flag (1 = on, 0 = off), so the
# former value "false" was not a valid setting. Numba has also already been
# imported by the time this cell runs, so the flag is additionally set on the
# live config object, which is what kernel launches consult.
os.environ["NUMBA_CUDA_LOW_OCCUPANCY_WARNINGS"] = "0"
from numba import config as numba_config
numba_config.CUDA_LOW_OCCUPANCY_WARNINGS = 0


dataset,data,G = load_data(datasetname)
dataset_lcc = lcc_dataset(dataset,to_undirected = make_undirected)
data_lcc = dataset_lcc[0]

data_information(dataset_lcc,data_lcc)


# Shallow copy so that adding "name" below does not mutate the shared
# hyperparameters_Neurips_1 dictionary.
sweep_configuration = dict(hyperparameters_Neurips_1[datasetname])
sweep_configuration["name"] = datasetname + '_' + Curvature_type


def objective(config,rewire = False):
    """Train and evaluate one hyperparameter configuration over all validation seeds.

    Reads the module-level experiment settings (``data_lcc``, ``dataset_lcc``,
    ``datasetname``, ``int_node``, ``Curvature_type``, ``val_seeds``,
    ``device``) and logs per-epoch metrics to the active wandb run.

    Args:
        config: hyperparameter mapping passed to the rewiring routine and to
            ``Experiment``.
        rewire: if True, rewire the graph once up front and train on the
            rewired edge index.

    Returns:
        ``(mean validation accuracy, mean test accuracy)`` across seeds.
    """
    accuracies = []
    test_acc = []
    if rewire:
        print("===Starting Rewiring===")
        G_rewired,edge_index_rewired = create_rewired_edge_index(data_lcc,config,intermediate_node=int_node,remove_edges=True,curvaturetype=Curvature_type)
        print(" ")
    print(" == Starting Runs == ")
    for idx_k,k in tqdm(enumerate(val_seeds)):
        if datasetname in ("Cora", "Citeseer", "Pubmed"):
            data_undirected_split = set_train_val_test_split(k,data_lcc)
        else:
            data_undirected_split = set_train_val_test_split_frac(k,data_lcc,0.2,0.2)

        if rewire:
            delta = len(G_rewired.nodes) - data_undirected_split.num_nodes
            if delta != 0:
                print("Additional Nodes added: ", delta)

                # delta is measured against the split built for this seed, so
                # that split is what gets padded. The global `data` is the
                # original dataset and does not carry this seed's masks.
                split = data_undirected_split

                # pad(left, right, top, bottom)
                new_x = F.pad(input=split.x, pad=(0, 0, 0,delta), mode='constant', value=0)
                new_y = F.pad(input=split.y, pad=(0,delta), mode='constant', value=0)
                # Added nodes are excluded from every mask.
                new_train_mask = F.pad(input=split.train_mask, pad=(0,delta), mode='constant', value=False)
                new_val_mask = F.pad(input=split.val_mask, pad=(0,delta), mode='constant', value=False)
                new_test_mask = F.pad(input=split.test_mask, pad=(0,delta), mode='constant', value=False)

                data_undirected_split = Data(x = new_x, edge_index = edge_index_rewired, y = new_y,train_mask = new_train_mask,val_mask = new_val_mask,test_mask = new_test_mask,is_undirected = True)
            else:
                data_undirected_split.edge_index = edge_index_rewired

        # Use the returned object rather than relying on in-place movement.
        data_undirected_split = data_undirected_split.to(device)

        Exp = Experiment(device,datasetname,dataset_lcc,data_undirected_split,config)

        # Early stopping: stop once validation has not improved for more than
        # 100 consecutive epochs.
        counter = 0
        for epoch in range(1, Exp.epoch):
            loss = Exp.train()
            val = Exp.validate()
            wandb.log({"loss " + str(idx_k): loss, "val " + str(idx_k): val,"epoch": epoch})
            if epoch ==1:
                best_val = val
            elif epoch > 1 and val > best_val:
                best_val = val
                counter = 0
            else:
                counter += 1
            if counter > 100:
                break
        final_accuracy = Exp.validate()
        final_test_acc = Exp.test()
        accuracies.append(final_accuracy)
        test_acc.append(final_test_acc)
    print("")
    return np.mean(np.array(accuracies)),np.mean(np.array(test_acc))

def create_rewired_edge_index(data,hyperparameters,intermediate_node,remove_edges,curvaturetype: str ):
    """Rewire ``data`` with the selected curvature variant and return the result.

    Args:
        data: PyG ``Data`` object to rewire.
        hyperparameters: mapping providing ``"loops"``, ``"C+"`` (used as the
            removal bound) and ``"tau"``.
        intermediate_node: forwarded as ``int_node`` to ``sdrf_cuda_personal``
            (ignored for ``"JcT"``).
        remove_edges: whether the rewiring routine may remove edges.
        curvaturetype: one of ``"BFC_w4cycle"``, ``"BFC_no4cycle"``, ``"JcT"``.

    Returns:
        ``(G_rewired, edge_index_rewired)``: the rewired networkx graph and its
        edge index after ``to_undirected``.

    Raises:
        ValueError: if ``curvaturetype`` is not one of the supported names.
    """
    supported_types = ("BFC_w4cycle", "BFC_no4cycle", "JcT")
    if curvaturetype not in supported_types:
        # Fail early with a clear message instead of an UnboundLocalError on
        # G_rewired further down.
        raise ValueError(
            "Unknown curvaturetype %r; expected one of %s"
            % (curvaturetype, ", ".join(supported_types))
        )

    if curvaturetype == "JcT":
        G_rewired = sdrf_jctopping(
            data,
            loops=hyperparameters["loops"],
            remove_edges=remove_edges,
            removal_bound=hyperparameters["C+"],
            tau=hyperparameters["tau"],
            is_undirected=data.is_undirected(),
        )
    else:
        # The two BFC variants differ only in the fcc flag.
        G_rewired,_ = sdrf_cuda_personal(
            data,
            loops=hyperparameters["loops"],
            remove_edges=remove_edges,
            removal_bound=hyperparameters["C+"],
            tau=hyperparameters["tau"],
            int_node = intermediate_node,
            is_undirected=data.is_undirected(),
            fcc = (curvaturetype == "BFC_w4cycle"),
        )

    edge_index_rewired = torch_geometric.utils.to_undirected(torch.tensor(list(G_rewired.edges)).t())
    return G_rewired,edge_index_rewired

def main():
    """Run one sweep trial: start a wandb run, evaluate its config, log the means."""
    wandb.init(dir = "")
    mean_val_acc, mean_test_acc = objective(wandb.config, rewiring_run)
    summary_metrics = {
        "mean accuracy": mean_val_acc,
        "mean test accuracy": mean_test_acc,
    }
    wandb.log(summary_metrics)

# Hard-coded sweep id; a new sweep could instead be created with
# wandb.sweep(sweep=sweep_configuration, project="curvature").
sweep_id = "9f46xixw"
wandb.agent(
    sweep_id,
    function=main,
    project="Curvature_Neurips",
    count=150,
)

# Alternative: create the sweep here and run a short agent.
#sweep_id = wandb.sweep(sweep=sweep_configuration, project="Curvature_Neurips")
#wandb.agent(sweep_id, function=main,count = 10)