In [1]:
import os
import time
import torch
import numpy as np
from network import Net1, Net2
import torch.nn.functional as F
from utils import splits_regression
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import WikipediaNetwork
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [2]:
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class Net2(torch.nn.Module):
    def __init__(self, num_features, hidden, num_layers):
        super(Net2, self).__init__()
        self.num_layers = num_layers
        self.conv = torch.nn.ModuleList()
        self.conv.append(GCNConv(num_features, hidden))
        for i in range(self.num_layers - 1):
            self.conv.append(GCNConv(hidden, hidden))
        self.lt1 = torch.nn.Linear(hidden, 1)

    def reset_parameters(self):
        for module in self.conv:
            module.reset_parameters()
        self.lt1.reset_parameters()

    def forward(self, x, edge_index):
        for i in range(self.num_layers):
            x = self.conv[i](x, edge_index)
            x = F.elu(x)
            x = F.dropout(x, training=self.training)
        x = self.lt1(x)
        return x

In [3]:
dataset_names = ['chameleon', "squirrel", "crocodile"]
for dataset_name in dataset_names:
    dataset = WikipediaNetwork(root='./dataset', name=dataset_name, geom_gcn_preprocess=False)
    data = splits_regression(dataset[0], 0.2, 0.3)
    model = Net2(data.x.shape[1], 512, 2).to(device)
    loss_fn = torch.nn.L1Loss()
    model.reset_parameters()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    all_loss = []
    avg_time = 0
    for run in range(20):
        best_val_loss = 100000
        model.reset_parameters()
        data = data.to(device)
        for epoch in range(100):
            model.train()
            optimizer.zero_grad()
            out = model(data.x, data.edge_index)
            loss = loss_fn(out[data.train_mask].view(-1, 1), data.y[data.train_mask].view(-1, 1))
            loss.backward()
            optimizer.step()

            model.eval()
            with torch.no_grad():
                out = model(data.x, data.edge_index)
                val_loss = loss_fn(out[data.val_mask].view(-1, 1), data.y[data.val_mask].view(-1, 1))
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    #save model
                    torch.save(model.state_dict(), f'save/node_regr_3/baselines/{dataset_name}_baseline_model.pt')
        start = time.time()
        out = model(data.x, data.edge_index)
        avg_time += time.time() - start
        test_loss = loss_fn(out[data.test_mask].view(-1, 1), data.y[data.test_mask].view(-1, 1))
        all_loss.append(test_loss.item()/data.y[data.test_mask].std().item())

    print(f"##########{dataset_name}##########")
    print(all_loss)
    print(sum(all_loss)/len(all_loss))
    print(avg_time/20)
    top_loss = sorted(all_loss)[:10]

    if not os.path.exists(f"results_3/baselines.csv"):
        with open(f"results_3/baelines.csv", 'w') as f:
            f.write('dataset,hidden,runs,num_layers,lr,top_10_loss,best_loss\n')

    with open(f"results_3/baselines.csv", 'a') as f:
        f.write(f"{dataset_name},512,20,2,0.01,{np.mean(top_loss)} +/- {np.std(top_loss)},{top_loss[0]}\n")

##########chameleon##########
[0.5805188905711305, 0.563514104340165, 0.5896040951614686, 0.5796865384435592, 0.6488201057997373, 0.5614945711502958, 0.5613331052866322, 0.5654249120489501, 0.5669812757893047, 0.550476230805993, 0.5703349371762502, 0.5651063799364365, 0.5900131493487745, 0.5731494807629592, 0.5723054146609464, 0.5907345211926696, 0.5550328018781956, 0.5596122509474295, 0.5760935930059323, 0.5677663320576284]
0.5744001345182228
1.5120256066322326
##########squirrel##########
[0.6583129494902763, 0.6489266371959453, 0.6639313396271047, 0.6548170256520042, 0.6666118181380352, 0.6689605260292223, 0.66460992699741, 0.6599452829246271, 0.6674846311331104, 0.6504289143488354, 0.6605970758153921, 0.6741779608260587, 0.6517506379319004, 0.6696449532219915, 0.6514133710086617, 0.6564704511510187, 0.6537905909742314, 0.6492191779494899, 0.6493877770592125, 0.6651732981056921]
0.6592827172790112
4.034725916385651
##########crocodile##########
[0.4100639721304551, 0.383353340718271