-
Notifications
You must be signed in to change notification settings - Fork 2
/
main.py
82 lines (76 loc) · 4.35 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import argparse
from train import *
from model import NE_WNA
from utils import *
from datasets import *
import warnings
def _str2bool(value):
    """Convert a command-line string to a real boolean.

    ``type=bool`` in argparse is a well-known trap: every non-empty string
    (including ``"False"``) is truthy, so ``--random_splits False`` would
    silently enable the option. This converter accepts the usual spellings
    and keeps ``--flag True`` / ``--flag False`` invocations working.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognizable boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


# Command-line configuration for training NE_WNA on the supported graph datasets.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, required=True)  # cora/citeseer/pubmed/computers/photo/cs
parser.add_argument('--random_splits', type=_str2bool, default=False)
parser.add_argument('--runs', type=int, default=100)            # independent training runs to average
parser.add_argument('--epochs', type=int, default=1000)
parser.add_argument('--batch_size', type=int, default=2000)
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--weight_decay', type=float, default=0.005)
parser.add_argument('--early_stopping', type=int, default=100)  # patience in epochs
parser.add_argument('--hidden', type=int, default=64)           # hidden width of the encoder
parser.add_argument('--hidden_z', type=int, default=64)         # width of the latent representation
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--normalize_features', type=_str2bool, default=False)
parser.add_argument('--order', type=int, default=4)             # k in the A_hat^k precomputation
parser.add_argument('--alpha', type=float, default=1)           # loss-weight hyperparameter
parser.add_argument('--beta', type=float, default=0.5)          # loss-weight hyperparameter
parser.add_argument('--tau', type=float, default=0.5)           # temperature hyperparameter
args = parser.parse_args()

# Silence noisy library warnings that would otherwise flood the run logs.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
if args.dataset in ("cora", "citeseer", "pubmed"):
    # Planetoid citation graphs: load splits, then precompute the propagation
    # matrix A_hat^k once so training never multiplies sparse matrices again.
    adj, features, labels, idx_train, idx_val, idx_test = load_citation(args.dataset, True)
    print("Precompute the enhanced adjacency matrix...")
    adj = normalize_adj(adj + sp.eye(adj.shape[0]))  # symmetric norm with self-loops
    sp_A_hat = sparse_mx_to_torch_sparse_tensor(adj).float().cuda()
    # pubmed uses the plain k-th power; the smaller graphs use the enhanced
    # variant (presumably a memory/scale trade-off — TODO confirm with utils).
    if args.dataset == "pubmed":
        enhanced_adj = get_A_hat_k_power(sp_A_hat, args.order)
    else:
        enhanced_adj = get_ehanced_A_hat_k_power(sp_A_hat, args.order)
    print("Precompute Finish!")
    model = NE_WNA(args.hidden, args.hidden_z, features.shape[1], labels.max().item() + 1, args.dropout)
    print(f'Dataset:{args.dataset}')
    run_citation(model, args.runs, args.epochs, args.lr, args.weight_decay, args.early_stopping,
                 args.alpha, args.beta, args.tau, args.batch_size, enhanced_adj.to_dense(),
                 features, labels, idx_train, idx_val, idx_test)
elif args.dataset in ("computers", "photo", "cs"):
    # Amazon (computers/photo) and Coauthor (cs) graphs share one pipeline;
    # they differ only in the loader and in whether training restricts the
    # graph to its largest connected component (lcc).
    is_amazon = args.dataset in ("computers", "photo")
    if is_amazon:
        dataset = get_amazon_dataset(args.dataset, args.normalize_features)
    else:
        dataset = get_coauthor_dataset(args.dataset, args.normalize_features)
    permute_masks = random_coauthor_amazon_splits
    print("Precompute the enhanced adjacency matrix...")
    A = get_adj_ori(dataset[0])
    A_hat = normalize_adj(A + sp.eye(A.shape[0]))
    sp_A_hat = sparse_mx_to_torch_sparse_tensor(A_hat).float()
    enhanced_adj = get_ehanced_A_hat_k_power(sp_A_hat, args.order)
    print("Precompute Finish!")
    print(f'Dataset:{args.dataset}')
    print("Dataset:", dataset[0])
    model = NE_WNA(args.hidden, args.hidden_z, dataset.num_features, dataset.num_classes, args.dropout)
    run_coauthor_amazon(dataset, model, args.runs, args.epochs, args.lr, args.weight_decay,
                        args.early_stopping, args.alpha, args.beta, args.tau,
                        enhanced_adj.to_dense().to(device), permute_masks, lcc=is_amazon)
else:
    # Previously an unknown dataset fell through and the script exited silently.
    raise ValueError(
        f"Unknown dataset {args.dataset!r}; expected one of "
        "cora, citeseer, pubmed, computers, photo, cs")