In [8]:
import os
import time
import torch
import numpy as np
from network import Net1, Net2
import torch.nn.functional as F
from utils import splits_regression
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import WikipediaNetwork
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [9]:
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class Net2(torch.nn.Module):
    def __init__(self, num_features, hidden, num_layers):
        super(Net2, self).__init__()
        self.num_layers = num_layers
        self.conv = torch.nn.ModuleList()
        self.conv.append(GCNConv(num_features, hidden))
        for i in range(self.num_layers - 1):
            self.conv.append(GCNConv(hidden, hidden))
        self.lt1 = torch.nn.Linear(hidden, 1)

    def reset_parameters(self):
        for module in self.conv:
            module.reset_parameters()
        self.lt1.reset_parameters()

    def forward(self, x, edge_index):
        for i in range(self.num_layers):
            x = self.conv[i](x, edge_index)
            x = F.elu(x)
            x = F.dropout(x, training=self.training)
        x = self.lt1(x)
        return x

In [12]:
dataset_names = ['chameleon', "squirrel", "crocodile"]
for dataset_name in dataset_names:
    dataset = WikipediaNetwork(root='./dataset', name=dataset_name, geom_gcn_preprocess=False)
    data = splits_regression(dataset[0], 0.2, 0.3)
    model = Net2(data.x.shape[1], 512, 2).to(device)
    loss_fn = torch.nn.L1Loss()
    model.reset_parameters()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    all_loss = []
    for run in range(20):
        best_val_loss = 100000
        model.reset_parameters()
        data = data.to(device)
        for epoch in range(100):
            model.train()
            optimizer.zero_grad()
            out = model(data.x, data.edge_index)
            loss = loss_fn(out[data.train_mask].view(-1, 1), data.y[data.train_mask].view(-1, 1))
            loss.backward()
            optimizer.step()

            model.eval()
            with torch.no_grad():
                out = model(data.x, data.edge_index)
                val_loss = loss_fn(out[data.val_mask].view(-1, 1), data.y[data.val_mask].view(-1, 1))
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    #save model
                    torch.save(model.state_dict(), f'save/node_regr_3/baselines/{dataset_name}_baseline_model.pt')
            
        out = model(data.x, data.edge_index)
        test_loss = loss_fn(out[data.test_mask].view(-1, 1), data.y[data.test_mask].view(-1, 1))
        all_loss.append(test_loss.item()/data.y[data.test_mask].std().item())

    print(f"##########{dataset_name}##########")
    print(all_loss)
    print(sum(all_loss)/len(all_loss))
    top_loss = sorted(all_loss)[:10]

    if not os.path.exists(f"results_3/baselines.csv"):
        with open(f"results_3/baelines.csv", 'w') as f:
            f.write('dataset,hidden,runs,num_layers,lr,top_10_loss,best_loss\n')

    with open(f"results_3/baselines.csv", 'a') as f:
        f.write(f"{dataset_name},512,20,2,0.01,{np.mean(top_loss)} +/- {np.std(top_loss)},{top_loss[0]}\n")

##########chameleon##########
[0.5518204942304205, 0.5725686377116527, 0.527127460376401, 0.6272746001457263, 0.5580381190713856, 0.5374346504036874, 0.5619520844254143, 0.5579015288605618, 0.5334654496870863, 0.5485142874717378, 0.5686058255272631, 0.5509294918618015, 0.5369901103259938, 0.5465972755228081, 0.5841896039471858, 0.5412594024494877, 0.594582954355803, 0.5564494098353093, 0.5638471605292282, 0.625656435818657]
0.5622602491278805
##########squirrel##########
[0.6680730518748575, 0.642920393343751, 0.6377924423482937, 0.6556296501801275, 0.6506518562267309, 0.6636130898487476, 0.6618066017917837, 0.6468744385189505, 0.6663209940531583, 0.6405619529135137, 0.6615383094754239, 0.6792426235860832, 0.6362052323162326, 0.6529953351122001, 0.6469621768086308, 0.6551953591860703, 0.6611102468167451, 0.6574498811945026, 0.6664678338296369, 0.6619160715559912]
0.6556663770490715
##########crocodile##########
[0.4310035272274854, 0.376981601345508, 0.383104700753609, 0.37368957030948