-
Notifications
You must be signed in to change notification settings - Fork 8
/
model.py
63 lines (48 loc) · 2.19 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
from torch import nn
from torch_geometric.nn import GATConv
class GNNImageClassificator(nn.Module):
"""
See Figure 2 from https://iopscience.iop.org/article/10.1088/1742-6596/1871/1/012071/pdf
"""
def __init__(
self,
in_channels: int = 3,
hidden_dim: int = 152,
num_classes: int = 10,
):
super().__init__()
self.in_channels = in_channels
self.hidden_dim = hidden_dim
self.num_classes = num_classes
self.conv1 = GATConv(in_channels=in_channels, out_channels=hidden_dim)
self.conv2 = GATConv(in_channels=hidden_dim, out_channels=hidden_dim)
self.conv3 = GATConv(in_channels=in_channels + hidden_dim, out_channels=hidden_dim)
self.fc = nn.Sequential(
nn.Linear(in_channels + 3 * hidden_dim, 256),
nn.ReLU(True),
nn.Linear(256, 32),
nn.ReLU(True),
nn.Linear(32, 32),
nn.ReLU(True),
nn.Linear(32, num_classes),
)
def forward_one_base(self, node_features: torch.Tensor, edge_indices: torch.Tensor) -> torch.Tensor:
assert node_features.ndim == 2 and node_features.shape[1] == self.in_channels
assert edge_indices.ndim == 2 and edge_indices.shape[0] == 2
x0 = node_features
x1 = self.conv1(x0, edge_indices)
x2 = self.conv2(x1, edge_indices)
x0_x2 = torch.cat((x0, x2), dim=-1)
x3 = self.conv3(x0_x2, edge_indices)
x0_x1_x2_x3 = torch.cat((x0, x1, x2, x3), dim=-1)
return x0_x1_x2_x3
def forward(self, batch_node_features: list[torch.Tensor], batch_edge_indices: list[torch.Tensor]) -> torch.Tensor:
assert len(batch_node_features) == len(batch_edge_indices)
features_list = []
for node_features, edge_indices in zip(batch_node_features, batch_edge_indices):
features_list.append(self.forward_one_base(node_features=node_features, edge_indices=edge_indices))
features = torch.stack(features_list, dim=0) # BATCH_SIZE x NUM_NODES x NUM_FEATURES
features = features.mean(dim=1) # readout operation [BATCH_SIZE x NUM_FEATURES]
logits = self.fc(features)
return logits