In [1]:
# Imports: standard library, then third-party, then project-local modules.
import json
import os

import numpy as np
import torch
import torch.nn.functional as F
import torch_geometric
import wandb
from tqdm.auto import tqdm  # renders as a widget in Jupyter, plain bar elsewhere

from experiment_utils.experimentclass import Experiment
from experiment_utils.largestconnectedcomponent import lcc_dataset
from experiment_utils.sdrf_cuda import sdrf_AFc, sdrf_BFc, sdrf_JLc, sdrf_JTc
from utils.load_datasets import data_information, load_data
from utils.seeds import val_seeds
from utils.splits import set_train_val_test_split, set_train_val_test_split_frac

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("using device: ", device)

using device:  cuda


In [4]:
import random

# Draw `n_numbers` random integers from [n_min, n_max).
# A dedicated, explicitly seeded generator makes the list reproducible across
# kernel restarts and leaves the global `random` state untouched.
# Set `seed = None` to draw fresh (non-reproducible) values from OS entropy.
seed = 42

# Bounds of numbers (lower inclusive, upper exclusive)
n_min = 0
n_max = 5000000

# Final number of values
n_numbers = 100

rng = random.Random(seed)
numbers = [rng.randrange(n_min, n_max) for _ in range(n_numbers)]

print(numbers)

[352349, 2963575, 265738, 2242045, 4286979, 2572232, 361988, 4896762, 698240, 1080210, 1076899, 4470144, 647785, 2864677, 3673026, 1001466, 591596, 2955934, 4309601, 659396, 226608, 371554, 420705, 1627659, 820255, 476332, 1319946, 1544823, 1896655, 4024031, 3846613, 3294119, 1753540, 1542684, 1243694, 4846073, 1347739, 2071124, 2517929, 646582, 4475777, 4028314, 2060860, 119103, 3652748, 3094689, 3083191, 3603385, 4172515, 4667641, 2122177, 1352276, 2309021, 4868434, 3284025, 2780096, 1828231, 179863, 2910985, 3058147, 2073243, 1454711, 3348766, 1406724, 968063, 1075577, 652008, 3564090, 1987302, 3094829, 4183375, 3082331, 4634634, 2911863, 4255755, 1707861, 2791017, 1364313, 2926477, 1232824, 4476194, 2245886, 4768541, 1870819, 1958616, 2910357, 3886764, 1915449, 1291027, 395481, 3258661, 2211164, 4123811, 3719360, 1476017, 622206, 4295282, 3027382, 3764862, 3140854]


In [2]:
def create_rewired_edge_index(data,hyperparameters,intermediate_node,remove_edges,curvaturetype: str ):
    """Rewire ``data`` with the SDRF variant selected by ``curvaturetype``.

    Args:
        data: Graph to rewire. It is passed to the selected ``sdrf_*`` routine and
            queried with ``data.is_undirected()``.
        hyperparameters: Mapping that must provide ``"loops"``, ``"C+"`` (used as
            ``removal_bound``) and ``"tau"``.
        intermediate_node: Forwarded as ``int_node``; only the BFc variants use it.
        remove_edges: Forwarded unchanged to the ``sdrf_*`` routine.
        curvaturetype: One of ``"BFc_w4cycle"``, ``"BFc_no4cycle"``, ``"JTc"``,
            ``"JLc"``, ``"AFc_3"``, ``"AFc_4"``.

    Returns:
        ``(G_rewired, edge_index_rewired)``: the graph object returned by the
        ``sdrf_*`` routine, and an edge index built from ``G_rewired.edges`` and
        symmetrised with ``torch_geometric.utils.to_undirected``.

    Raises:
        ValueError: If ``curvaturetype`` is not one of the supported names.
            (Previously this fell through every branch and failed with an
            ``UnboundLocalError`` on ``G_rewired``.)
    """
    # Select the rewiring routine plus the keyword arguments specific to this
    # variant; the shared arguments are passed once below.
    if curvaturetype in ("BFc_w4cycle", "BFc_no4cycle"):
        rewire_fn = sdrf_BFc
        variant_kwargs = {
            "int_node": intermediate_node,
            # fcc toggles the 4-cycle contribution of the BFc curvature.
            "fcc": curvaturetype == "BFc_w4cycle",
        }
    elif curvaturetype == "JTc":
        rewire_fn = sdrf_JTc
        variant_kwargs = {}
    elif curvaturetype == "JLc":
        rewire_fn = sdrf_JLc
        variant_kwargs = {}
    elif curvaturetype == "AFc_3":
        rewire_fn = sdrf_AFc
        # k is passed as a float here and as an int for AFc_4, exactly as before.
        variant_kwargs = {"k": 3.}
    elif curvaturetype == "AFc_4":
        rewire_fn = sdrf_AFc
        variant_kwargs = {"k": 4}
    else:
        raise ValueError(
            f"Unknown curvaturetype {curvaturetype!r}; expected one of "
            "'BFc_w4cycle', 'BFc_no4cycle', 'JTc', 'JLc', 'AFc_3', 'AFc_4'."
        )

    G_rewired, _ = rewire_fn(
        data,
        loops=hyperparameters["loops"],
        remove_edges=remove_edges,
        removal_bound=hyperparameters["C+"],
        tau=hyperparameters["tau"],
        is_undirected=data.is_undirected(),
        progress_bar=False,
        **variant_kwargs,
    )
    # Edge list (pairs) -> (2, num_edges) tensor, then add reverse edges.
    edge_index_rewired = torch_geometric.utils.to_undirected(torch.tensor(list(G_rewired.edges)).t())

    return G_rewired, edge_index_rewired



In [3]:

"""
Parameters for the experiment
"""

os.environ["WANDB_SILENT"] = "true"
os.environ["NUMBA_CUDA_LOW_OCCUPANCY_WARNINGS"] = "false"

datasetname = "Cornell"
results_dir = "results"
rewiring_run = True
make_undirected = True
int_node = False
Curvature_type = "BFc_w4cycle"

path = ""

dataset,data,G = load_data(datasetname)
dataset_lcc = lcc_dataset(dataset,to_undirected = make_undirected)
data_lcc = dataset_lcc[0]

data_information(dataset_lcc,data_lcc)


with open(os.path.join('experiment_utils\hyperparameters','hyperparameters_Neurips_FixedGNNParameters.json'), 'r') as file:
     sweep_configuration = json.load(file)
     sweep_configuration =sweep_configuration.get(datasetname, {})

sweep_configuration["name"] = datasetname + '_' + Curvature_type

def objective(config, rewire=False, n_seeds=2, patience=100):
    """Train/evaluate on ``n_seeds`` random splits and return the mean accuracies.

    Args:
        config: Hyperparameter mapping (``wandb.config`` when called from
            ``main``). Rewiring reads ``"loops"``, ``"C+"`` and ``"tau"``; the
            whole mapping is also handed to ``Experiment``.
        rewire: If True, rewire ``data_lcc`` once with ``Curvature_type`` and use
            the rewired edge index for every split.
        n_seeds: Number of seeds taken from the front of ``val_seeds``.
        patience: Training stops once more than ``patience`` consecutive epochs
            pass without a validation improvement.

    Returns:
        ``(mean validation accuracy, mean test accuracy)`` over the seeds.
    """
    val_acc = []
    test_acc = []
    if rewire:
        print("===Starting Rewiring===")
        _, edge_index_rewired = create_rewired_edge_index(
            data_lcc, config, intermediate_node=int_node, remove_edges=True, curvaturetype=Curvature_type
        )
        print(" ")

    print(" == Starting Runs == ")
    for seed in tqdm(val_seeds[:n_seeds]):
        # Split a private copy: below we overwrite edge_index and move the
        # object to `device`. If a split helper returns the object it was given,
        # doing that on data_lcc itself would leak into later seeds and later
        # sweep runs (which would then rewire an already-rewired graph).
        data_copy = data_lcc.clone()
        if datasetname in ("Cora", "Citeseer", "Pubmed"):
            data_undirected_split = set_train_val_test_split(seed, data_copy)
        else:
            data_undirected_split = set_train_val_test_split_frac(seed, data_copy, 0.2, 0.2)

        if rewire:
            data_undirected_split.edge_index = edge_index_rewired

        data_undirected_split.to(device)

        experiment = Experiment(device, datasetname, dataset_lcc, data_undirected_split, config)

        # Patience-based early stopping on validation accuracy.
        # NOTE(review): weights are not restored to the best epoch, so the
        # accuracies reported below come from the last trained epoch.
        # NOTE(review): range(1, experiment.epoch) runs experiment.epoch - 1
        # epochs; confirm that is the intended bound.
        best_val = None
        epochs_without_improvement = 0
        for epoch in range(1, experiment.epoch):
            experiment.train()
            val = experiment.validate()

            if best_val is None:
                best_val = val
            elif val > best_val:
                best_val = val
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1
            if epochs_without_improvement > patience:
                break

        final_accuracy = experiment.validate()
        final_test_acc = experiment.test()
        val_acc.append(final_accuracy)
        test_acc.append(final_test_acc)
    print("")
    return np.mean(np.array(val_acc)), np.mean(np.array(test_acc))



def main():
    """Run one sweep trial: start a wandb run, score its config, log the means."""
    wandb.init(dir = "../../wandb")
    mean_val_acc, mean_test_acc = objective(wandb.config, rewiring_run)
    metrics = {
        "mean accuracy": mean_val_acc,
        "mean test accuracy": mean_test_acc,
    }
    wandb.log(metrics)

# Create a new sweep and let a local agent run three trials of `main`.
# To resume an earlier sweep instead, set sweep_id to its ID (e.g. "cwu2mmfw")
# rather than calling wandb.sweep, and run:
#     wandb.agent(sweep_id, project="curvature", function=main, count=100)
project_name = "Curvature_Neurips_FixedGNNParameters"
sweep_id = wandb.sweep(sweep=sweep_configuration, project=project_name)
wandb.agent(sweep_id, function=main, count=3)




Dataset: cornell():
Number of features: 1703
Number of classes: 5

Number of nodes: 140
Number of edges: 401
Average node degree: 2.86
Has isolated nodes: False
Has self-loops: True
Is undirected: True
Create sweep with ID: d72vz8e5
Sweep URL: https://wandb.ai/flotori/Curvature_Neurips_FixedGNNParameters/sweeps/d72vz8e5


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

===Starting Rewiring===




 
 == Starting Runs == 


2it [00:03,  1.86s/it]



===Starting Rewiring===




 
 == Starting Runs == 


2it [00:04,  2.33s/it]





VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

===Starting Rewiring===




 
 == Starting Runs == 


2it [00:03,  1.59s/it]



