In [2]:
from dynabench.dataset import DynabenchIterator, download_equation
from torch.utils.data import DataLoader
from dynabench.model import NeuralPDE
from models.MamPDE import MambaPDE

import torch
import torch.optim as optim
import torch.nn as nn
from torchdiffeq import odeint_adjoint, odeint

import os
import shutil
import glob
import hashlib
from typing import List, Optional, Tuple

In [3]:
class NeuralPDE_ssm(nn.Module):
    """
        Neural PDE model for grid data. The model combines a CNN with a differentiable ODE solver to learn the dynamics of the data using the method of lines. The CNN is used to approximate the spatial derivatives of the data, while the ODE solver is used to approximate the temporal evolution of the data.
        The model has been taken from `NeuralPDE: Modelling Dynamical Systems from Data <https://arxiv.org/abs/2111.07671>`_ by Dulny et al.

        Parameters
        ----------
        input_dim : int
            Number of input channels.
        hidden_channels : int
            Number of channels in each hidden layer of the CNN. Default is 64.
        hidden_layers : int 
            Number of hidden layers in the CNN. Default is 1.
        solver : dict
            Dictionary of solver parameters. Default is {"method": "dopri5"}.
        use_adjoint : bool
            Whether to use the adjoint method for backpropagation. Default is True.

    """
    def __init__(self,
                 input_dim: int,
                 hidden_channels: int = 64,
                 hidden_layers: int = 1,
                 solver: dict = {"method": "dopri5"},
                 use_adjoint: bool = True,
                 *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.input_dim = input_dim

        self.ssm = MambaPDE(input_dim=15*15, d_model=200, output_dim=15*15, n_layers=3)
        
        self.solver = solver
        self.use_adjoint = use_adjoint

    def _ssm(self, t, x):
        return self.ssm(x)
    
    def forward(self, 
                x: torch.Tensor, 
                t_eval: List[float]=[0.0, 1.0]):
        
        """
            Forward pass of the model. Should not be called directly, instead call the model instance.

            Parameters
            ----------
            x : torch.Tensor
                Input tensor of shape (batch_size, input_size, height, width).
            t_eval : List[float], default [0.0, 1.0]
                List of times to evaluate the ODE solver at. Default is [0.0, 1.0].

            Returns
            -------
            torch.Tensor
                Output tensor of shape (batch_size, rollout, output_size, height, width).
        """

        t_eval = torch.tensor(t_eval, dtype=x.dtype, device=x.device)
        if self.use_adjoint:
            pred = odeint_adjoint(self._ssm, x, t_eval, **self.solver, adjoint_params=self.cnn.parameters())[1:]
        else:
            pred =  odeint(self._ssm, x, t_eval, **self.solver)[1:]
        
        pred = torch.swapaxes(pred, 0, 1)
        return pred

In [4]:
burgers_train_iterator = DynabenchIterator(split="train",
                                           equation='burgers',
                                           structure='grid',
                                           resolution='low',
                                           lookback=2,
                                           rollout=1)

  burgers_train_iterator = DynabenchIterator(split="train",


RuntimeError: Dataset not found. You can use download=True to download it

In [12]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# device = 'cpu'
print(f"Using device: {device}")

train_loader = DataLoader(burgers_train_iterator, batch_size=32, shuffle=True)

model = NeuralPDE_ssm(input_dim=2, hidden_channels=64, hidden_layers=3,
                solver={'method': 'euler', 'options': {'step_size': 0.1}},
                use_adjoint=False).to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

Using device: mps


In [13]:
for epoch in range(5):
    model.train()
    for i, (x, y, p) in enumerate(train_loader):
        print(x.shape)
        x, y = x[:,0].float().to(device), y.float().to(device) # only use the first channel and convert to float32
        print(x.shape)
        # print(y.shape)
        # optimizer.zero_grad()
        # y_pred = model(x)
        # loss = criterion(y_pred, y)
        # loss.backward()
        # optimizer.step()
        # print(f"Epoch: {epoch}, Batch: {i}, Loss: {loss.item()}")

torch.Size([32, 2, 2, 15, 15])
torch.Size([32, 2, 15, 15])
torch.Size([32, 2, 2, 15, 15])
torch.Size([32, 2, 15, 15])
torch.Size([32, 2, 2, 15, 15])
torch.Size([32, 2, 15, 15])
torch.Size([32, 2, 2, 15, 15])
torch.Size([32, 2, 15, 15])
torch.Size([32, 2, 2, 15, 15])
torch.Size([32, 2, 15, 15])
torch.Size([32, 2, 2, 15, 15])
torch.Size([32, 2, 15, 15])
torch.Size([32, 2, 2, 15, 15])
torch.Size([32, 2, 15, 15])
torch.Size([32, 2, 2, 15, 15])
torch.Size([32, 2, 15, 15])
torch.Size([32, 2, 2, 15, 15])
torch.Size([32, 2, 15, 15])
torch.Size([32, 2, 2, 15, 15])
torch.Size([32, 2, 15, 15])
torch.Size([32, 2, 2, 15, 15])
torch.Size([32, 2, 15, 15])
torch.Size([32, 2, 2, 15, 15])
torch.Size([32, 2, 15, 15])
torch.Size([32, 2, 2, 15, 15])
torch.Size([32, 2, 15, 15])
torch.Size([32, 2, 2, 15, 15])
torch.Size([32, 2, 15, 15])
torch.Size([32, 2, 2, 15, 15])
torch.Size([32, 2, 15, 15])
torch.Size([32, 2, 2, 15, 15])
torch.Size([32, 2, 15, 15])
torch.Size([32, 2, 2, 15, 15])
torch.Size([32, 2, 15, 15

KeyboardInterrupt: 

In [6]:
for epoch in range(5):
    model.train()
    for i, (x, y, p) in enumerate(train_loader):
        x, y = x[:,0].float().to(device), y.float().to(device) # only use the first channel and convert to float32
        optimizer.zero_grad()
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
        print(f"Epoch: {epoch}, Batch: {i}, Loss: {loss.item()}")

Epoch: 0, Batch: 0, Loss: 0.01726127229630947
Epoch: 0, Batch: 1, Loss: 0.01126931793987751
Epoch: 0, Batch: 2, Loss: 0.013541501946747303
Epoch: 0, Batch: 3, Loss: 0.011269127018749714
Epoch: 0, Batch: 4, Loss: 0.01682719588279724
Epoch: 0, Batch: 5, Loss: 0.007590473163872957
Epoch: 0, Batch: 6, Loss: 0.013149096630513668
Epoch: 0, Batch: 7, Loss: 0.014652708545327187
Epoch: 0, Batch: 8, Loss: 0.008443024009466171
Epoch: 0, Batch: 9, Loss: 0.019929584115743637
Epoch: 0, Batch: 10, Loss: 0.006708018481731415
Epoch: 0, Batch: 11, Loss: 0.015017409808933735
Epoch: 0, Batch: 12, Loss: 0.008847776800394058
Epoch: 0, Batch: 13, Loss: 0.009199068881571293
Epoch: 0, Batch: 14, Loss: 0.015973132103681564
Epoch: 0, Batch: 15, Loss: 0.010173874907195568
Epoch: 0, Batch: 16, Loss: 0.01763436757028103
Epoch: 0, Batch: 17, Loss: 0.00655487272888422
Epoch: 0, Batch: 18, Loss: 0.01763397455215454
Epoch: 0, Batch: 19, Loss: 0.015493819490075111
Epoch: 0, Batch: 20, Loss: 0.024265464395284653
Epoch: 0

In [7]:
burgers_test_iterator = DynabenchIterator(split="test",
                                          equation='burgers',
                                          structure='grid',
                                          resolution='low',
                                          lookback=1,
                                          rollout=16)

test_loader = DataLoader(burgers_test_iterator, batch_size=32, shuffle=False)

model.eval()

  burgers_test_iterator = DynabenchIterator(split="test",


NeuralPDE_ssm(
  (ssm): MambaPDE(
    (input_layer): Linear(in_features=225, out_features=200, bias=True)
    (mamba_tower): MambaTower(
      (blocks): ModuleList(
        (0-2): 3 x MambaBlock(
          (ssm): Mamba(
            (embedding): Embedding(40000, 200)
            (layers): ModuleList(
              (0): ResidualBlock(
                (mixer): MambaBlock_1(
                  (in_proj): Linear(in_features=200, out_features=800, bias=False)
                  (conv1d): Conv1d(400, 400, kernel_size=(4,), stride=(1,), padding=(3,), groups=400)
                  (x_proj): Linear(in_features=400, out_features=45, bias=False)
                  (dt_proj): Linear(in_features=13, out_features=400, bias=True)
                  (out_proj): Linear(in_features=400, out_features=200, bias=False)
                )
                (norm): RMSNorm()
              )
            )
            (norm_f): RMSNorm()
            (lm_head): Linear(in_features=200, out_features=40000, bias=False)
  

In [13]:
loss_values = []
for i, (x, y, p) in enumerate(test_loader):
    x, y = x[:,0].float().to(device), y.float().to(device) # only use the first channel and convert to float32
    y_pred = model(x, t_eval=range(17))
    loss = criterion(y_pred, y)
    loss_values.append(loss.item())
    print(i)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
27

In [14]:
print(f"Mean Loss: {sum(loss_values) / len(loss_values)}")

Mean Loss: 0.15976723989688274
