-
Notifications
You must be signed in to change notification settings - Fork 1
/
gnn_2d.py
141 lines (118 loc) · 5.25 KB
/
gnn_2d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import torch
import sys
from torch import nn
from torch.nn import functional as F
from torch_geometric.nn import MessagePassing, global_mean_pool, InstanceNorm, avg_pool_x, BatchNorm
# from einops import rearrange
from IPython import embed
class Swish(nn.Module):
    """Swish activation with a fixed slope: ``x * sigmoid(beta * x)``.

    Parameters
    ----------
    beta : float, optional
        Slope inside the sigmoid gate (default 1). Stored as a plain
        attribute, not an ``nn.Parameter``, so it is not trained.
    """

    def __init__(self, beta=1):
        super().__init__()
        self.beta = beta

    def forward(self, x):
        """Gate ``x`` element-wise by ``sigmoid(beta * x)``."""
        gate = torch.sigmoid(self.beta * x)
        return x * gate
class GNN_Layer_FS_2D(MessagePassing):
    """Message-passing layer for 2D PDE graphs with a residual node update.

    For every edge, the message MLP receives:

    - both node embeddings;
    - the difference of the raw node values ``u``;
    - the relative x and y positions;
    - the equation variables of the ``_i`` node.

    Messages are averaged per node (``aggr='mean'``). The update MLP combines
    each node's embedding, its aggregated message and its variables. The
    result is added to the input embedding and then batch-normalised.

    Parameters
    ----------
    in_features : int
        Dimensionality of input features. Must equal ``out_features``
        because the node update is added residually to the input.
    out_features : int
        Dimensionality of output features.
    hidden_features : int
        Dimensionality of hidden features: the width of the messages and
        the hidden width of the update MLP.
    time_window : int
        Number of columns of ``u``; the difference ``u_i - u_j`` enters the
        message MLP.
    n_variables : int
        Number of columns of ``variables``.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 hidden_features,
                 time_window,
                 n_variables):
        super().__init__(node_dim=-2, aggr='mean')
        # Inputs: x_i and x_j (2 * in_features), u_i - u_j (time_window),
        # relative x and y positions (2), variables_i (n_variables).
        self.message_net_1 = nn.Sequential(
            nn.Linear(2 * in_features + time_window + 2 + n_variables, hidden_features),
            nn.ReLU(),
        )
        # Messages stay hidden_features wide, because update_net_1 is sized
        # for (in_features + hidden_features + n_variables) inputs.
        self.message_net_2 = nn.Sequential(
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
        )
        self.update_net_1 = nn.Sequential(
            nn.Linear(in_features + hidden_features + n_variables, hidden_features),
            nn.ReLU(),
        )
        self.update_net_2 = nn.Sequential(
            nn.Linear(hidden_features, out_features),
            nn.ReLU(),
        )
        # Applied to the layer output, which has out_features channels.
        self.norm = BatchNorm(out_features)

    def forward(self, x, u, pos_x, pos_y, variables, edge_index, batch):
        """Run one round of message passing and normalise the result.

        ``x`` holds the node embeddings. ``u``, ``pos_x``, ``pos_y`` and
        ``variables`` are per-node inputs consumed by ``message`` and
        ``update``. ``batch`` is accepted for interface compatibility but is
        not used.
        """
        x = self.propagate(edge_index, x=x, u=u, pos_x=pos_x, pos_y=pos_y, variables=variables)
        return self.norm(x)

    def message(self, x_i, x_j, u_i, u_j, pos_x_i, pos_x_j, pos_y_i, pos_y_j, variables_i):
        """Build one message per edge from embeddings and relative inputs."""
        edge_input = torch.cat(
            (x_i, x_j, u_i - u_j, pos_x_i - pos_x_j, pos_y_i - pos_y_j, variables_i),
            dim=-1,
        )
        return self.message_net_2(self.message_net_1(edge_input))

    def update(self, message, x, variables):
        """Residually update node embeddings from their aggregated messages.

        ``message`` is the mean-aggregated message per node. The residual sum
        requires ``in_features == out_features``.
        """
        update = self.update_net_1(torch.cat((x, message, variables), dim=-1))
        update = self.update_net_2(update)
        return x + update
class MP_PDE_Solver_2D(torch.nn.Module):
    """Message-passing PDE solver for 2D problems.

    Pipeline (see ``forward``):

    1. Node inputs ``[u, x, y, t]`` are embedded by an MLP.
    2. ``hidden_layer`` stacked ``GNN_Layer_FS_2D`` layers refine the
       embeddings.
    3. A 1D convolutional stack decodes each node's embedding.
    4. The decoded values are scaled by cumulative time increments of
       ``0.1 * pde.dt``.

    Parameters
    ----------
    pde :
        Object providing ``Lx``, ``Ly`` and ``tmax`` (used to normalise the
        node positions and time) and ``dt`` (used to scale the output).
    time_window : int
        Number of time steps per node in ``data.x`` and in the output.
    hidden_features : int
        Width of the node embeddings and message-passing layers. The conv
        decoder has fixed kernels and strides; for 128 it reduces each
        embedding to a single value (128 -> 38 -> 9 -> 1).
        NOTE(review): other widths may give a different length or fail —
        confirm before changing.
    hidden_layer : int
        Number of stacked message-passing layers.
    eq_variables : dict, optional
        Extra equation variables (default: empty). Only their count is used
        to size the layers. ``forward`` does not feed them yet, so a
        non-empty dict raises ``NotImplementedError`` at call time.
    """

    def __init__(
        self,
        pde,
        time_window=1,
        hidden_features=128,
        hidden_layer=6,
        eq_variables=None,
    ):
        super().__init__()
        self.pde = pde
        self.out_features = time_window
        self.hidden_features = hidden_features
        self.hidden_layer = hidden_layer
        self.time_window = time_window
        # None instead of a shared mutable {} default argument.
        self.eq_variables = {} if eq_variables is None else eq_variables

        # in_features have to be of the same size as out_features for the
        # time being: each layer adds its update residually to its input.
        self.gnn_layers = torch.nn.ModuleList(modules=(GNN_Layer_FS_2D(
            in_features=self.hidden_features,
            hidden_features=self.hidden_features,
            out_features=self.hidden_features,
            time_window=self.time_window,
            n_variables=len(self.eq_variables) + 1  # variables = eq_variables + time
        ) for _ in range(self.hidden_layer)))

        # Input: u (time_window) + x, y, t (3) + extra equation variables.
        self.embedding_mlp = nn.Sequential(
            nn.Linear(self.time_window + 3 + len(self.eq_variables), self.hidden_features),
            nn.BatchNorm1d(self.hidden_features),
            nn.ReLU(),
            nn.Linear(self.hidden_features, self.hidden_features),
            nn.BatchNorm1d(self.hidden_features),
        )

        # Decodes each node embedding as a one-channel 1D signal.
        self.output_mlp = nn.Sequential(
            nn.Conv1d(1, 4, 16, stride=3),
            nn.ReLU(),
            nn.Conv1d(4, 8, 12, stride=3),
            nn.ReLU(),
            nn.Conv1d(8, 1, 8, stride=2),
        )

    def __repr__(self):
        return 'GNN'

    def forward(self, data):
        """Compute the per-node output ``dt * diff`` for a graph batch.

        Parameters
        ----------
        data :
            Graph batch with the following attributes:

            - ``x``: node values ``u``, with ``time_window`` columns;
            - ``pos``: columns ``t, x, y``;
            - ``edge_index``;
            - ``batch``.

        Returns
        -------
        torch.Tensor
            ``dt * diff``, where ``dt`` holds the cumulative increments
            ``0.1 * pde.dt * [1, ..., time_window]`` and ``diff`` is the
            decoder output. The shape is ``(num_nodes, time_window)`` when
            the decoder yields one value per node. NOTE(review): presumably
            an increment the caller adds to the current solution — confirm.

        Raises
        ------
        NotImplementedError
            If ``eq_variables`` is non-empty. Those variables are counted in
            the layer sizes but are not fed here, so the shapes would not
            match.
        """
        if self.eq_variables:
            raise NotImplementedError(
                "MP_PDE_Solver_2D.forward only feeds time as equation "
                "variable; eq_variables must be empty"
            )
        u = data.x
        pos = data.pos
        # pos columns are (t, x, y); normalise each by its domain extent.
        pos_t = pos[:, 0:1] / self.pde.tmax
        pos_x = pos[:, 1:2] / self.pde.Lx
        pos_y = pos[:, 2:3] / self.pde.Ly
        edge_index = data.edge_index
        batch = data.batch

        variables = pos_t  # time is the only equation variable fed to the layers
        node_input = torch.cat((u, pos_x, pos_y, variables), -1)
        h = self.embedding_mlp(node_input)
        for layer in self.gnn_layers:
            h = layer(h, u, pos_x, pos_y, variables, edge_index, batch)

        # (N, hidden) -> (N, 1, hidden) for Conv1d, then drop the channel axis.
        diff = self.output_mlp(h[:, None]).squeeze(1)
        # Cumulative time offsets 0.1*dt, 0.2*dt, ..., built on h's device.
        dt = torch.ones(1, self.time_window, device=h.device) * self.pde.dt * 0.1
        dt = torch.cumsum(dt, dim=1)
        return dt * diff