-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathgcn.py
31 lines (25 loc) · 900 Bytes
/
gcn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch.nn as nn
import torch.nn.functional as F
class GCN(nn.Module):
    """Graph convolutional network built from a stack of linear layers.

    Every layer computes ``A @ Linear(H)``: node features are first
    transformed by an ``nn.Linear`` and then propagated with ``A``. Hidden
    layers are followed by ReLU and dropout; the final layer applies no
    activation.

    Args:
        in_size: Size of the last dimension of the input features ``H``.
        out_size: Size of the last dimension of the output.
        num_layers: Number of linear layers. Any value below 2 (including
            0 or negative values) yields a single ``in_size -> out_size``
            layer.
        hidden_size: Width of the intermediate layers; unused when only a
            single layer is built.
        dropout: Dropout probability applied after each hidden activation
            (only while the module is in training mode).
    """

    def __init__(self,
                 in_size,
                 out_size,
                 num_layers,
                 hidden_size,
                 dropout):
        super().__init__()
        # Layer widths from input to output; each consecutive pair defines
        # one nn.Linear. Layers are created in order, so state_dict keys
        # (lins.0, lins.1, ...) and RNG-driven initialisation match a
        # sequential build.
        if num_layers >= 2:
            widths = [in_size] + [hidden_size] * (num_layers - 1) + [out_size]
        else:
            widths = [in_size, out_size]
        self.lins = nn.ModuleList(
            nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        )
        self.dropout = dropout

    def forward(self, A, H):
        """Run all layers on node features ``H`` using propagation matrix ``A``.

        Args:
            A: Matrix multiplied on the left of every layer's linear output
                (``A @ ...``). Presumably the (normalised) graph adjacency,
                dense or sparse -- TODO confirm with callers.
            H: Node features whose last dimension is ``in_size``.

        Returns:
            ``A @ Linear(H)`` for the final layer: last dimension
            ``out_size``, no activation applied.
        """
        *hidden_layers, output_layer = self.lins
        for layer in hidden_layers:
            propagated = A @ layer(H)
            # ReLU then dropout; dropout is a no-op outside training mode.
            H = F.dropout(F.relu(propagated), p=self.dropout,
                          training=self.training)
        return A @ output_layer(H)