-
Notifications
You must be signed in to change notification settings - Fork 1
/
train_helper_2d.py
201 lines (178 loc) · 8.18 KB
/
train_helper_2d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import torch
import random
from torch import nn, optim
from torch.utils.data import DataLoader
from data_creator_2d import GraphCreator_FS_2D
# from PDEs import *
def training_itp(itp_model: torch.nn.Module,
                 mesh_model: torch.nn.Module,
                 unrolling: list,
                 batch_size: int,
                 optimizer: torch.optim,
                 optimizer2: torch.optim,
                 loader: DataLoader,
                 graph_creator: GraphCreator_FS_2D,
                 criterion: torch.nn.modules.loss,
                 device: torch.cuda.device = "cpu") -> torch.Tensor:
    """
    One training epoch of the interpolation model with random starting points
    for every trajectory.
    Args:
        itp_model (torch.nn.Module): interpolation model being trained
        mesh_model (torch.nn.Module): moving mesh operator (may be None)
        unrolling (list): list of different unrolling steps for each batch entry
        batch_size (int): batch size
        optimizer (torch.optim): optimizer used for training
        optimizer2 (torch.optim): optional second optimizer; skipped when None
        loader (DataLoader): training dataloader
        graph_creator (GraphCreator_FS_2D): helper object to handle graph data
        criterion (torch.nn.modules.loss): criterion for training
        device (torch.cuda.device): device (cpu/gpu)
    Returns:
        torch.Tensor: per-batch training losses, stacked into one tensor
    """
    losses = []
    for (u_base, u_super) in loader:
        optimizer.zero_grad()
        if optimizer2 is not None:  # PEP 8: compare to None by identity
            optimizer2.zero_grad()
        # Randomly choose number of unrollings
        unrolled_graphs = random.choice(unrolling)
        # Valid starting time indices: leave room for the input window (tw)
        # and for `unrolled_graphs` future windows of size tw.
        steps = list(range(graph_creator.tw,
                           graph_creator.t_res - graph_creator.tw
                           - (graph_creator.tw * unrolled_graphs) + 1))
        # Randomly choose starting (time) point at the PDE solution manifold
        random_steps = random.choices(steps, k=batch_size)
        data, labels = graph_creator.create_data(u_super, random_steps)
        graph = graph_creator.create_graph(itp_model, data, labels, random_steps, device, mesh_model)
        itp_u = graph.x
        # Interpolate the mesh prediction back onto the uniform grid and
        # train it to reconstruct the original data.
        u_uni = graph_creator.interpolate_pred(itp_model, itp_u, graph, data, device)
        data = data.to(device)
        loss = criterion(u_uni, data.reshape(-1, 1))
        loss.backward()
        losses.append(loss.detach() / 2)  # halved loss kept for logging parity
        optimizer.step()
        if optimizer2 is not None:
            optimizer2.step()
    return torch.stack(losses)
def training_loop_branch(model: torch.nn.Module,
                         model_b: torch.nn.Module,
                         itp_model: torch.nn.Module,
                         mesh_model: torch.nn.Module,
                         unrolling: list,
                         batch_size: int,
                         optimizer: torch.optim,
                         optimizer2: torch.optim,
                         loader: DataLoader,
                         graph_creator: GraphCreator_FS_2D,
                         criterion: torch.nn.modules.loss,
                         device: torch.cuda.device = "cpu") -> torch.Tensor:
    """
    One training epoch with random starting points for every trajectory.
    Args:
        model (torch.nn.Module): neural network PDE solver
        model_b (torch.nn.Module): branch neural network PDE solver
        itp_model (torch.nn.Module): interpolation model
        mesh_model (torch.nn.Module): moving mesh operator (may be None)
        unrolling (list): list of different unrolling steps for each batch entry
        batch_size (int): batch size
        optimizer (torch.optim): optimizer used for training
        optimizer2 (torch.optim): optional second optimizer; skipped when None
        loader (DataLoader): training dataloader
        graph_creator (GraphCreator_FS_2D): helper object to handle graph data
        criterion (torch.nn.modules.loss): criterion for training
        device (torch.cuda.device): device (cpu/gpu)
    Returns:
        torch.Tensor: per-batch training losses, stacked into one tensor
    """
    losses = []
    for (u_base, u_super) in loader:
        optimizer.zero_grad()
        if optimizer2 is not None:  # PEP 8: compare to None by identity
            optimizer2.zero_grad()
        # Randomly choose number of unrollings
        unrolled_graphs = random.choice(unrolling)
        # Valid starting time indices: leave room for the input window (tw)
        # and for `unrolled_graphs` future windows of size tw.
        steps = list(range(graph_creator.tw,
                           graph_creator.t_res - graph_creator.tw
                           - (graph_creator.tw * unrolled_graphs) + 1))
        # Randomly choose starting (time) point at the PDE solution manifold
        random_steps = random.choices(steps, k=batch_size)
        data, labels = graph_creator.create_data(u_super, random_steps)
        # The model's string form identifies the solver family (e.g. 'GNN').
        is_gnn = f'{model}' == 'GNN'
        if is_gnn:
            # Graph on the (possibly moved) mesh and on the uniform mesh.
            graph = graph_creator.create_graph(itp_model, data, labels, random_steps, device, mesh_model)
            graph_uni = graph_creator.create_graph(itp_model, data, labels, random_steps, device, None)
        else:
            data, labels = data.to(device), labels.to(device)
        if is_gnn:
            if mesh_model is not None:
                # Branch prediction on the moved mesh, interpolated to the
                # uniform grid, plus the uniform-grid prediction.
                pred = graph_creator.interpolate_pred(itp_model, model_b(graph), graph, data, device) + model(graph_uni)
            else:
                pred = model(graph_uni)
            labels = labels.to(device)
            loss = criterion(pred, labels.reshape(-1, 1))
        else:
            pred = model(data)
            loss = criterion(pred, labels.squeeze())
        loss.backward()
        losses.append(loss.detach())
        optimizer.step()
        # NOTE: the original gated this on `idx % 1 == 0`, which is always
        # true, so optimizer2 steps every batch.
        if optimizer2 is not None:
            optimizer2.step()
    return torch.stack(losses)
def test_timestep_losses(model: torch.nn.Module,
                         model_b: torch.nn.Module,
                         itp_model: torch.nn.Module,
                         mesh_model: torch.nn.Module,
                         steps: list,
                         batch_size: int,
                         loader: DataLoader,
                         graph_creator: GraphCreator_FS_2D,
                         criterion: torch.nn.modules.loss,
                         device: torch.cuda.device = "cpu") -> torch.Tensor:
    """
    Loss for one neural network forward pass at certain timepoints on the
    validation/test datasets.
    Args:
        model (torch.nn.Module): neural network PDE solver
        model_b (torch.nn.Module): branch neural network PDE solver
        itp_model (torch.nn.Module): interpolation model
        mesh_model (torch.nn.Module): moving mesh operator (may be None)
        steps (list): input list of possible starting (time) points
        batch_size (int): batch size
        loader (DataLoader): dataloader [valid, test]
        graph_creator (GraphCreator_FS_2D): helper object to handle graph data
        criterion (torch.nn.modules.loss): criterion for training
        device (torch.cuda.device): device (cpu/gpu)
    Returns:
        torch.Tensor: scalar mean loss over all evaluated timesteps
        (the original annotation said None, but a value is returned)
    """
    losses_t = []
    for step in steps:
        # Only evaluate at multiples of the time window
        # (step == tw is itself a multiple, so a single modulo test suffices).
        if step % graph_creator.tw != 0:
            continue
        losses = []
        for (u_base, u_super) in loader:
            same_steps = [step] * batch_size
            data, labels = graph_creator.create_data(u_super, same_steps)
            is_gnn = f'{model}' == 'GNN'
            if is_gnn:
                # BUG FIX: graph_uni was previously created only when
                # mesh_model is not None, yet it is used in the
                # mesh_model-is-None branch below (NameError). Create it
                # unconditionally, as training_loop_branch does.
                graph_uni = graph_creator.create_graph(itp_model, data, labels, same_steps, device, None)
                if mesh_model is not None:
                    graph = graph_creator.create_graph(itp_model, data, labels, same_steps, device, mesh_model)
            with torch.no_grad():
                if is_gnn:
                    if mesh_model is not None:
                        pred = graph_creator.interpolate_pred(itp_model, model_b(graph), graph, data, device) + model(graph_uni)
                    else:
                        pred = model(graph_uni)
                    labels = labels.to(device)
                    loss = criterion(pred, labels.reshape(-1, 1))
                else:
                    data, labels = data.to(device), labels.to(device)
                    pred = model(data)
                    loss = criterion(pred, labels.squeeze())
            losses.append(loss)
        losses = torch.stack(losses)
        losses_t.append(torch.mean(losses))
        if step % 2 == 1:  # thin out logging to odd steps only
            print(f'Step {step}, time step loss {torch.mean(losses)}')
    losses_t = torch.stack(losses_t)
    print(f'Mean Timestep Test Error: {torch.mean(losses_t)}')
    return torch.mean(losses_t)