In [1]:
import os
import time
import torch
import numpy as np
from network import Net1, Net2
import torch.nn.functional as F
from utils import splits_regression
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import WikipediaNetwork
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [2]:
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class Net2(torch.nn.Module):
    def __init__(self, num_features, hidden, num_layers):
        super(Net2, self).__init__()
        self.num_layers = num_layers
        self.conv = torch.nn.ModuleList()
        self.conv.append(GCNConv(num_features, hidden))
        for i in range(self.num_layers - 1):
            self.conv.append(GCNConv(hidden, hidden))
        self.lt1 = torch.nn.Linear(hidden, 1)

    def reset_parameters(self):
        for module in self.conv:
            module.reset_parameters()
        self.lt1.reset_parameters()

    def forward(self, x, edge_index):
        for i in range(self.num_layers):
            x = self.conv[i](x, edge_index)
            x = F.elu(x)
            x = F.dropout(x, training=self.training)
        x = self.lt1(x)
        return x

In [4]:
dataset_names = ['chameleon', "squirrel", "crocodile"]
for dataset_name in dataset_names:
    dataset = WikipediaNetwork(root='./dataset', name=dataset_name, geom_gcn_preprocess=False)
    data = splits_regression(dataset[0], 0.2, 0.3)
    model = Net2(data.x.shape[1], 512, 2).to(device)
    loss_fn = torch.nn.L1Loss()
    model.reset_parameters()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    all_loss = []
    avg_time = 0
    for run in range(20):
        best_val_loss = 100000
        model.reset_parameters()
        data = data.to(device)
        for epoch in range(100):
            model.train()
            optimizer.zero_grad()
            out = model(data.x, data.edge_index)
            loss = loss_fn(out[data.train_mask].view(-1, 1), data.y[data.train_mask].view(-1, 1))
            loss.backward()
            optimizer.step()

            model.eval()
            with torch.no_grad():
                out = model(data.x, data.edge_index)
                val_loss = loss_fn(out[data.val_mask].view(-1, 1), data.y[data.val_mask].view(-1, 1))
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    #save model
                    torch.save(model.state_dict(), f'save/node_regr_3/baselines/{dataset_name}_baseline_model.pt')
        start = time.time()
        out = model(data.x, data.edge_index)
        avg_time += time.time() - start
        test_loss = loss_fn(out[data.test_mask].view(-1, 1), data.y[data.test_mask].view(-1, 1))
        all_loss.append(test_loss.item()/data.y[data.test_mask].std().item())

    print(f"##########{dataset_name}##########")
    print(all_loss)
    print(sum(all_loss)/len(all_loss))
    print(avg_time/20)
    top_loss = sorted(all_loss)[:10]

    if not os.path.exists(f"results_3/baselines.csv"):
        with open(f"results_3/baelines.csv", 'w') as f:
            f.write('dataset,hidden,runs,num_layers,lr,avg_time,top_10_loss,best_loss\n')

    with open(f"results_3/baselines.csv", 'a') as f:
        f.write(f"{dataset_name},512,20,2,0.01,{avg_time/20},{np.mean(top_loss)} +/- {np.std(top_loss)},{top_loss[0]}\n")

##########chameleon##########
[0.5542109280863412, 0.5476571426535451, 0.5821994771802472, 0.5656288211347624, 0.5576801371026772, 0.5880737951240456, 0.5954102625255281, 0.5708912008041599, 0.5532072312932941, 0.5585349689989131, 0.5513660720310137, 0.5777760074323333, 0.5569076429709826, 0.5558179897827686, 0.5339317323620706, 0.5733611110554399, 0.5622672802973467, 0.5330684384372949, 0.5532056724985631, 0.5465342537348165]
0.5608865082753073
0.002548110485076904
##########squirrel##########
[0.6468659435179561, 0.6665409743062363, 0.6570845787301263, 0.6499941530666998, 0.6515432896918232, 0.6638391565447899, 0.6509730967563887, 0.6597220468142232, 0.6461966534158452, 0.6487787259833039, 0.6539380599273911, 0.6496815593069479, 0.6534733790797395, 0.6578310483517996, 0.6491654254462547, 0.6625796268787867, 0.6419095482759566, 0.6717810961630085, 0.6683477769291372, 0.6521001182082858]
0.655117312869735
0.006338047981262207
##########crocodile##########
[0.40969306639895303, 0.382740