-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
87 lines (75 loc) · 2.88 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import torch
import math
import numpy as np
import random
def set_seed(seed):
    """Seed every RNG used by this module with the same value.

    Seeds, in order: Python's ``random`` module, NumPy's global RNG,
    PyTorch's CPU generator, the current CUDA device's generator, and the
    generators of all CUDA devices.

    Args:
        seed: Integer seed forwarded unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
def cheby(i, x):
    """Evaluate the Chebyshev polynomial of the first kind ``T_i`` at ``x``.

    Uses the three-term recurrence ``T_k = 2*x*T_{k-1} - T_{k-2}`` seeded
    with ``T_0 = 1`` and ``T_1 = x``. Only ``*`` and ``-`` are applied to
    ``x``, so it may be a Python number or any array/tensor type that
    supports those operators.

    Args:
        i: Non-negative integer degree.
        x: Evaluation point(s).

    Returns:
        ``T_i(x)``. For ``i == 0`` this is the plain int ``1`` regardless of
        the type of ``x``; for ``i == 1`` it is ``x`` itself.

    Raises:
        ValueError: If ``i`` is negative.
    """
    # Without this guard a negative degree falls through to the loop branch,
    # the loop never runs, and the function dies with an UnboundLocalError.
    if i < 0:
        raise ValueError(f"Chebyshev degree must be non-negative, got {i}")
    if i == 0:
        return 1
    if i == 1:
        return x
    t_prev, t_curr = 1, x
    for _ in range(2, i + 1):
        t_prev, t_curr = t_curr, 2 * x * t_curr - t_prev
    return t_curr
def index_to_mask(index, size):
    """Build a boolean mask of shape ``size`` that is True at ``index``.

    Args:
        index: Anything accepted by tensor ``__setitem__`` as an index
            (e.g. a list of ints, a numpy int array, or a LongTensor).
        size: Shape passed to ``torch.zeros``.

    Returns:
        A new ``torch.bool`` tensor on the default (CPU) device, all False
        except at the positions selected by ``index``.
    """
    selected = torch.zeros(size, dtype=torch.bool)
    selected[index] = True
    return selected
def random_splits(data, num_classes, percls_trn, val_lb, seed=42):
    """Assign train/val/test masks to ``data`` using a seeded NumPy RNG.

    For each class ``c`` in ``range(num_classes)``, ``percls_trn`` nodes are
    sampled without replacement for training; a class with fewer nodes than
    that contributes all of its nodes. From the remaining nodes, ``val_lb``
    are sampled without replacement for validation, and everything left is
    the test set.

    Args:
        data: Graph object with a label tensor ``data.y`` and an int
            ``data.num_nodes`` (presumably a PyG ``Data``; assumes
            ``data.y.shape[0] == data.num_nodes`` — TODO confirm).
        num_classes: Number of label values to sample training nodes from.
        percls_trn: Training nodes per class.
        val_lb: Total number of validation nodes.
        seed: Seed for the local ``np.random.RandomState`` (the global NumPy
            RNG is not touched).

    Returns:
        The same ``data`` object, with ``train_mask``, ``val_mask`` and
        ``test_mask`` set to boolean tensors of length ``data.num_nodes``.

    Raises:
        ValueError: From NumPy, if fewer than ``val_lb`` nodes remain after
            choosing the training set.
    """
    num_samples = data.y.shape[0]
    labels = data.y.cpu()  # hoisted: loop-invariant
    rnd_state = np.random.RandomState(seed)

    train_idx = []
    for c in range(num_classes):
        class_idx = np.where(labels == c)[0]
        if len(class_idx) < percls_trn:
            train_idx.extend(class_idx)
        else:
            train_idx.extend(rnd_state.choice(class_idx, percls_trn, replace=False))

    # Set membership keeps both filtering passes linear in the node count.
    # The previous list / ndarray membership tests were quadratic. Iteration
    # order and the RNG call sequence are unchanged, so splits are identical.
    train_set = set(train_idx)
    rest_index = [i for i in range(num_samples) if i not in train_set]
    val_idx = rnd_state.choice(rest_index, val_lb, replace=False)
    val_set = set(val_idx.tolist())
    test_idx = [i for i in rest_index if i not in val_set]

    data.train_mask = index_to_mask(train_idx, size=data.num_nodes)
    data.val_mask = index_to_mask(val_idx, size=data.num_nodes)
    data.test_mask = index_to_mask(test_idx, size=data.num_nodes)
    return data
def fixed_splits(data, num_classes, percls_trn, val_lb, name, seed=42):
    """Like ``random_splits``, but pins the seed for certain named datasets.

    For ``name`` in ``["Chameleon", "Squirrel", "Actor"]`` (case-sensitive)
    the ``seed`` argument is ignored and ``1941488137`` is used instead, so
    those datasets always get the same split. All other names use ``seed``.
    The splitting itself is delegated to ``random_splits``. Previously this
    body was a verbatim copy of that function.

    Args:
        data: Graph object passed through to ``random_splits``.
        num_classes: Number of label values to sample training nodes from.
        percls_trn: Training nodes per class.
        val_lb: Total number of validation nodes.
        name: Dataset name; selects whether the fixed seed applies.
        seed: Seed used when ``name`` is not one of the pinned datasets.

    Returns:
        ``data`` with ``train_mask``, ``val_mask`` and ``test_mask`` set.
    """
    if name in ["Chameleon", "Squirrel", "Actor"]:
        # NOTE(review): the origin of this specific seed value is not visible
        # here; it is kept so existing splits stay reproducible.
        seed = 1941488137
    return random_splits(data, num_classes, percls_trn, val_lb, seed=seed)
def random_splits_citation(data, num_classes, percls_trn=20, val_lb=500, test_lb=1000):
    """Randomly split nodes into per-class train nodes plus fixed-size val/test sets.

    Each class's nodes are shuffled and the first ``percls_trn`` go to the
    training set. All remaining nodes are pooled and shuffled again. The
    first ``val_lb`` become validation and the next ``test_lb`` become test.
    Any leftover nodes are in no split. Randomness comes from the global
    PyTorch RNG (``torch.randperm``), so results depend on the torch seed.
    The defaults reproduce the original hard-coded 20 / 500 / 1000 split.

    Args:
        data: Graph object with a label tensor ``data.y`` and an int
            ``data.num_nodes`` (presumably a PyG ``Data`` with 1-D ``y`` —
            TODO confirm).
        num_classes: Number of label values to split.
        percls_trn: Training nodes per class (fewer if a class is smaller).
        val_lb: Number of validation nodes.
        test_lb: Maximum number of test nodes.

    Returns:
        The same ``data`` object, with ``train_mask``, ``val_mask`` and
        ``test_mask`` set to boolean tensors of length ``data.num_nodes``.
    """
    indices = []
    for c in range(num_classes):
        class_idx = (data.y == c).nonzero().view(-1)
        indices.append(class_idx[torch.randperm(class_idx.size(0))])

    train_index = torch.cat([idx[:percls_trn] for idx in indices], dim=0)
    rest_index = torch.cat([idx[percls_trn:] for idx in indices], dim=0)
    rest_index = rest_index[torch.randperm(rest_index.size(0))]

    data.train_mask = index_to_mask(train_index, size=data.num_nodes)
    data.val_mask = index_to_mask(rest_index[:val_lb], size=data.num_nodes)
    data.test_mask = index_to_mask(
        rest_index[val_lb:val_lb + test_lb], size=data.num_nodes
    )
    return data