# models.py
import torch
import torch.nn as nn

from layers import *  # GraphConv, GraphAttConv, PairNorm
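
# PairNorm itself is defined in layers.py. For reference, a minimal sketch of
# its core 'PN' mode per the PairNorm paper (Zhao & Akoglu, ICLR 2020): center
# the features across nodes, then rescale so the mean squared row norm equals
# scale**2. The name `pairnorm_reference` is illustrative and not part of
# layers.py.
def pairnorm_reference(x, scale=1.0, eps=1e-6):
    x = x - x.mean(dim=0, keepdim=True)                       # center each feature column
    rownorm_mean = (eps + x.pow(2).sum(dim=1).mean()).sqrt()  # sqrt of mean squared row norm
    return scale * x / rownorm_mean
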
class SGC(nn.Module):
    """Simple Graph Convolution: nlayer propagation steps, then one linear layer."""
    # for SGC we use data without normalization
    def __init__(self, nfeat, nhid, nclass, dropout, nlayer=2,
                 norm_mode='None', norm_scale=10, **kwargs):
        super(SGC, self).__init__()
        self.linear = torch.nn.Linear(nfeat, nclass)
        self.norm = PairNorm(norm_mode, norm_scale)
        self.dropout = nn.Dropout(p=dropout)
        self.nlayer = nlayer

    def forward(self, x, adj):
        x = self.norm(x)
        for _ in range(self.nlayer):
            x = adj.mm(x)       # feature propagation: x <- A x
            x = self.norm(x)
        x = self.dropout(x)
        x = self.linear(x)
        return x

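# Note: assuming PairNorm's 'None' mode is the identity, SGC.forward above
# computes linear(dropout(A^nlayer x)), i.e. Wu et al.'s Simple Graph
# Convolution. Since A^nlayer x depends on no parameters, it could be
# precomputed once before training, e.g.:
#
#   x_pre = x
#   for _ in range(nlayer):
#       x_pre = adj.mm(x_pre)   # propagate once, reuse every epoch
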
class GCN(nn.Module):
    """Two-layer GCN with optional PairNorm between the layers."""
    def __init__(self, nfeat, nhid, nclass, dropout,
                 norm_mode='None', norm_scale=1, **kwargs):
        super(GCN, self).__init__()
        self.gc1 = GraphConv(nfeat, nhid)
        self.gc2 = GraphConv(nhid, nclass)
        self.dropout = nn.Dropout(p=dropout)
        self.relu = nn.ReLU(True)
        self.norm = PairNorm(norm_mode, norm_scale)

    def forward(self, x, adj):
        x = self.dropout(x)
        x = self.gc1(x, adj)
        x = self.norm(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.gc2(x, adj)
        return x

class GAT(nn.Module):
    """Two-layer GAT; attention dropout reuses the feature dropout rate."""
    def __init__(self, nfeat, nhid, nclass, dropout, nhead,
                 norm_mode='None', norm_scale=1, **kwargs):
        super(GAT, self).__init__()
        alpha_droprate = dropout
        self.gac1 = GraphAttConv(nfeat, nhid, nhead, alpha_droprate)
        self.gac2 = GraphAttConv(nhid, nclass, 1, alpha_droprate)
        self.dropout = nn.Dropout(p=dropout)
        self.elu = nn.ELU(True)
        self.norm = PairNorm(norm_mode, norm_scale)

    def forward(self, x, adj):
        x = self.dropout(x)     # input dropout on the raw features
        x = self.gac1(x, adj)
        x = self.norm(x)
        x = self.elu(x)
        x = self.dropout(x)
        x = self.gac2(x, adj)
        return x

class DeepGCN(nn.Module):
    """nlayer-deep GCN with optional PairNorm and residual connections."""
    def __init__(self, nfeat, nhid, nclass, dropout, nlayer=2, residual=0,
                 norm_mode='None', norm_scale=1, **kwargs):
        super(DeepGCN, self).__init__()
        assert nlayer >= 1
        self.hidden_layers = nn.ModuleList([
            GraphConv(nfeat if i == 0 else nhid, nhid)
            for i in range(nlayer - 1)
        ])
        self.out_layer = GraphConv(nfeat if nlayer == 1 else nhid, nclass)
        self.dropout = nn.Dropout(p=dropout)
        self.dropout_rate = dropout
        self.relu = nn.ReLU(True)
        self.norm = PairNorm(norm_mode, norm_scale)
        self.skip = residual

    def forward(self, x, adj):
        x_old = 0
        for i, layer in enumerate(self.hidden_layers):
            x = self.dropout(x)
            x = layer(x, adj)
            x = self.norm(x)
            x = self.relu(x)
            if self.skip > 0 and i % self.skip == 0:
                x = x + x_old   # residual: add the activation from `skip` layers back
                x_old = x
        x = self.dropout(x)
        x = self.out_layer(x, adj)
        return x

class DeepGAT(nn.Module):
    """nlayer-deep GAT with optional PairNorm and residual connections."""
    def __init__(self, nfeat, nhid, nclass, dropout, nlayer=2, residual=0, nhead=1,
                 norm_mode='None', norm_scale=1, **kwargs):
        super(DeepGAT, self).__init__()
        assert nlayer >= 1
        alpha_droprate = dropout
        self.hidden_layers = nn.ModuleList([
            GraphAttConv(nfeat if i == 0 else nhid, nhid, nhead, alpha_droprate)
            for i in range(nlayer - 1)
        ])
        self.out_layer = GraphAttConv(nfeat if nlayer == 1 else nhid, nclass, 1, alpha_droprate)
        self.dropout = nn.Dropout(p=dropout)
        self.elu = nn.ELU(True)
        self.norm = PairNorm(norm_mode, norm_scale)
        self.skip = residual

    def forward(self, x, adj):
        x_old = 0
        for i, layer in enumerate(self.hidden_layers):
            x = self.dropout(x)
            x = layer(x, adj)
            x = self.norm(x)
            x = self.elu(x)
            if self.skip > 0 and i % self.skip == 0:
                x = x + x_old   # residual: add the activation from `skip` layers back
                x_old = x
        x = self.dropout(x)
        x = self.out_layer(x, adj)
        return x
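
# Minimal smoke test, assuming layers.py from this repo is importable and
# provides GraphConv, GraphAttConv, and PairNorm with the signatures used
# above. The random dense adjacency is only a stand-in for a real
# preprocessed (e.g. symmetrically normalized) graph.
if __name__ == '__main__':
    n, nfeat, nhid, nclass = 8, 16, 32, 3
    x = torch.randn(n, nfeat)
    adj = torch.rand(n, n)
    adj = (adj + adj.t()) / 2                    # symmetrize
    adj = adj / adj.sum(dim=1, keepdim=True)     # crude row normalization
    for Model in (SGC, GCN, DeepGCN):
        model = Model(nfeat, nhid, nclass, dropout=0.5, nlayer=3)
        out = model(x, adj)
        assert out.shape == (n, nclass)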