This file shows how to train your own MSPC-LSTM model.

# Prepare Data

MSPC-LSTM is built on multi-layer AEs/CAEs. Training the multi-layer CAEs requires four datasets: high-fidelity training data, high-fidelity testing data, low-fidelity training data, and low-fidelity testing data.

In [None]:
HF_train_data_path = 'HF_train_data_path'
HF_test_data_path = 'HF_test_data_path'
LF_train_data_path = 'LF_train_data_path'
LF_test_data_path = 'LF_test_data_path'

The data can then be loaded with the custom dataloader.py module, which returns a dataloader suited to each model.
Import it with "from models import dataloader".

In [None]:
from models import dataloader
# for high fidelity cae
HF_CAE_train_loader = dataloader.load_data(HF_train_data_path,model='HighFidelityCAE', batch_size=20, shuffle=True)
# for the test data, we can set test=True to get tensor rather than a dataloader
HF_CAE_test_data = dataloader.load_data(HF_test_data_path, model='HighFidelityCAE', test=True)

The low-fidelity AE/CAE must accept both high-fidelity and low-fidelity data, so its dataloader is generated in a dedicated way.

When $\text{model='LowFidelityCAE'}$, the function $\text{dataloader.load\_data}$ accepts an extra parameter $\text{low\_fidelity\_path}$ that points to the low-fidelity data.

In [None]:
LF_CAE_train_loader = dataloader.load_data(HF_train_data_path, model='LowFidelityCAE', low_fidelity_path=LF_train_data_path, batch_size=20, shuffle=True)
LF_test_data, HF_test_data = dataloader.load_data(HF_test_data_path, model='LowFidelityCAE', low_fidelity_path=LF_test_data_path, batch_size=20, shuffle=True, test=True)

Besides the multi-layer CAEs, the LSTM is the other component of MSPC-LSTM. In this model, the LSTM performs Seq2Seq prediction. To reduce the memory required for training, the data is first batched and only split into sequences during training, so set $\text{model='LSTM'}$ when generating the training dataloader.

In [None]:
LSTM_train_loader = dataloader.load_data(HF_train_data_path, model='LSTM', batch_size=20, shuffle=True)

The sequence length is set in the training function.
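
The batch-then-sequence step can be pictured as a sliding window over each trajectory. The sketch below is only an illustration, assuming windows of `lookback` input steps followed by `lookahead` target steps with a stride of 1. `make_windows` is a hypothetical helper, not the implementation in dataloader.py, which may window the data differently.

```python
def make_windows(sequence, lookback=3, lookahead=3):
    """Split one trajectory into (input, target) pairs with a stride-1 sliding window.

    Hypothetical helper for illustration; the real dataloader may differ.
    """
    pairs = []
    last_start = len(sequence) - lookback - lookahead
    for start in range(last_start + 1):
        inputs = sequence[start:start + lookback]
        targets = sequence[start + lookback:start + lookback + lookahead]
        pairs.append((inputs, targets))
    return pairs


# A trajectory of 8 time steps, labelled by their index for readability.
trajectory = list(range(8))
windows = make_windows(trajectory)
print(len(windows))  # 3
print(windows[0])    # ([0, 1, 2], [3, 4, 5])
```

Because the windows are built from one batch at a time, only the compressed sequences of the current batch need to be held in memory.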

# Initialize MSPC-LSTM

To initialize an MSPC-LSTM, first prepare two AEs and one LSTM model customized for your data. The example here is the shallow water model, whose high-fidelity data has shape (3,64,64) and whose low-fidelity data has shape (3,32,32).
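
When adapting the CAEs to other data shapes, the feature-map sizes noted in the layer comments can be checked with the standard output-size formulas for convolutions and transposed convolutions. The sketch below assumes the settings used in this example (kernel size 3, stride 2, padding 1, output padding 1, dilation 1); `conv_out`, `conv_transpose_out`, and `trace` are hypothetical helper names.

```python
def conv_out(size, kernel=3, stride=2, padding=1):
    """Spatial output size of a 2D convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel=3, stride=2, padding=1, output_padding=1):
    """Spatial output size of a 2D transposed convolution (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


def trace(size, layer, depth=2):
    """Apply the same layer formula `depth` times and record every size."""
    sizes = [size]
    for _ in range(depth):
        sizes.append(layer(sizes[-1]))
    return sizes


print(trace(64, conv_out))            # [64, 32, 16]  high-fidelity encoder
print(trace(32, conv_out))            # [32, 16, 8]   low-fidelity encoder
print(trace(16, conv_transpose_out))  # [16, 32, 64]  high-fidelity decoder
print(trace(8, conv_transpose_out))   # [8, 16, 32]   low-fidelity decoder
print(16 * 16 * 16, 16 * 8 * 8)       # 4096 1024     flattened sizes fed to the dense layers
```

If the input resolution changes, the in-features of the dense layers must change with it, since they depend on the final feature-map size.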

In [None]:
# Structure of two CAEs
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1)  # output size: 32x32
        self.conv2 = nn.Conv2d(32, 16, 3, stride=2, padding=1)  # output size: 16x16
        self.dense = nn.Linear(16 * 16 * 16, 512)  # Added dense layer

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # Flatten the input
        x = self.dense(x) # F.relu(self.dense(x))  # Apply dense layer
        return x

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.dense = nn.Linear(512, 16 * 16 * 16)  # Added dense layer
        self.t_conv1 = nn.ConvTranspose2d(16, 32, 3, stride=2, padding=1, output_padding=1)  # output size: 32x32
        self.t_conv2 = nn.ConvTranspose2d(32, 3, 3, stride=2, padding=1, output_padding=1)  # output size: 64x64

    def forward(self, x):
        x = self.dense(x) # F.relu(self.dense(x))  # Apply dense layer
        x = x.view(x.size(0), 16, 16, 16)  # Reshape the input
        x = F.relu(self.t_conv1(x))
        x = self.t_conv2(x)  # No activation here
        return x

class CAE(nn.Module):
    def __init__(self):
        super(CAE, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

class low_Encoder(nn.Module):
    def __init__(self):
        super(low_Encoder, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, stride=2, padding=1)  # output size: 16x16
        self.conv2 = nn.Conv2d(16, 16, 3, stride=2, padding=1)  # output size: 8x8
        self.dense = nn.Linear(16 * 8 * 8, 512)  # Added dense layer

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # Flatten the input
        x = self.dense(x) # F.relu(self.dense(x))  # Apply dense layer
        return x

class low_Decoder(nn.Module):
    def __init__(self):
        super(low_Decoder, self).__init__()
        self.dense = nn.Linear(512, 16 * 8 * 8)  # Added dense layer
        self.t_conv1 = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1, output_padding=1)  # output size: 16x16
        self.t_conv2 = nn.ConvTranspose2d(16, 3, 3, stride=2, padding=1, output_padding=1)  # output size: 32x32

    def forward(self, x):
        x = self.dense(x) # F.relu(self.dense(x))  # Apply dense layer
        x = x.view(x.size(0), 16, 8, 8)  # Reshape the input
        x = F.relu(self.t_conv1(x))
        x = self.t_conv2(x)  # No activation here
        return x
    
class low_CAE(nn.Module):
    def __init__(self):
        super(low_CAE, self).__init__()
        self.encoder = low_Encoder()
        self.decoder = low_Decoder()

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [None]:
# Structure of LSTM
class Seq2Seq(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super(Seq2Seq, self).__init__()

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # LSTM layer
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)

        # output layer
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # initialize hidden state and cell state
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(x.device)

        out, _ = self.lstm(x, (h0, c0))
        out = self.fc(out)  # map the hidden state of every time step to the output dimension

        return out

After constructing the structure of the models, you can initialize them.

In [None]:
HF_CAE = CAE()
LF_CAE = low_CAE()
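# input_dim and output_dim match the 512-dimensional latent space of the CAEs above;
# hidden_dim and num_layers are illustrative values, adjust them for your data
LSTM = Seq2Seq(input_dim=512, hidden_dim=512, output_dim=512, num_layers=2)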

Now you can initialize the MSPC-LSTM directly or define a subclass of the MSPC-LSTM class, as shown below.

In [None]:
from models import model
# initialize MSPC-LSTM directly
MSPC_LSTM = model.MSPC_LSTM(HF_CAE, LF_CAE, LSTM)

If you initialize the MSPC-LSTM model directly, the class already provides basic training functions for the high-fidelity CAE, the low-fidelity CAE, and the LSTM. Call them with $\text{train\_HFCAE, train\_LFCAE, train\_LSTM}$. After training, use the function $\text{predict}$ to get predictions.

However, to embed your own physical constraints into the model, define a subclass of MSPC-LSTM and add the specific physical-constraint functions and the LSTM training process. The following example is the shallow water problem with energy conservation and flow operator physical constraints.
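
For reference, the two constraints can be written out as read from the subclass code in this section. The notation is an assumption of this note: $u, v$ are the velocity channels (0 and 1), $h$ is the height channel (2), $\Delta x$ is the grid spacing $\text{dx}$, and $\Delta t$ is the time step $\text{dt}$.

The total energy of one decoded snapshot, with $g = 9.8$:

$$E = \frac{1}{2}\sum_{i,j}\left(u_{ij}^2 + v_{ij}^2\right)\Delta x^2 + \frac{1}{2}\, g \sum_{i,j} h_{ij}^2\, \Delta x^2$$

The flow operator, with $g = 1$, $b = 0.2$ and $H = 0$ as set in $\text{d\_dt}$:

$$\frac{\partial u}{\partial t} = -g\,\frac{\partial h}{\partial x} - b\,u, \qquad \frac{\partial v}{\partial t} = -g\,\frac{\partial h}{\partial y} - b\,v, \qquad \frac{\partial h}{\partial t} = -\frac{\partial \big(u\,(H+h)\big)}{\partial x} - \frac{\partial \big(v\,(H+h)\big)}{\partial y}$$

Each call to $\text{evolve}$ advances the fields by one forward Euler step:

$$h^{n+1} = h^{n} + \Delta t\,\frac{\partial h}{\partial t}, \qquad u^{n+1} = u^{n} + \Delta t\,\frac{\partial u}{\partial t}, \qquad v^{n+1} = v^{n} + \Delta t\,\frac{\partial v}{\partial t}$$

The flow operator loss is the mean squared error between the fields evolved this way from the last input step and the decoded LSTM predictions.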

In [None]:
from numpy import matrix, ndarray, roll  # used by dxy and d_dt below

class SW_MSPC_LSTM(model.MSPC_LSTM):
    
    def __init__(self, HF_CAE, LF_CAE, LSTM, device=None):
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device
            
        self.HF_CAE = HF_CAE.to(self.device)
        self.LF_CAE = LF_CAE.to(self.device)
        self.LSTM = LSTM.to(self.device)
        self.HF_encoder = self.HF_CAE.encoder
        self.HF_decoder = self.HF_CAE.decoder
        self.LF_encoder = self.LF_CAE.encoder
        self.LF_decoder = self.LF_CAE.decoder

    def train_LSTM_with_PC(self, train_loader, test_data, energy_coef=None, fo_coef=None, decoder_type=None, batch_size=20, num_epochs=30, lr=0.0001, criterion_type='mse'):
        # Coerce coefficients to lists if they are not
        if isinstance(energy_coef, (int, float)):
            energy_coef = [energy_coef] * num_epochs
        if isinstance(fo_coef, (int, float)):
            fo_coef = [fo_coef] * num_epochs
        # Load loss function
        criterion = load_loss_function(criterion_type)

        # Optimizers for the LSTM
        optimizer = torch.optim.Adam(self.LSTM.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.8)

        # Set test dataset
        b, s, c, h, w = test_data.shape
        test_compr = self.HF_encoder(test_data.to(self.device).view(-1, c, h, w))
        test_dataset = dataloader.load_data(test_compr.reshape(b,s,-1), model='Seq2Seq', lookback=3, lookahead=3)

        train_losses = []
        test_losses = []

        for epoch in range(num_epochs):
            train_epoch_loss, test_epoch_loss = 0., 0.

            # Training
            self.LSTM.train()
            for i, train_batch in enumerate(train_loader):
                b, s, c, h, w = train_batch.shape
                train_compr = self.HF_encoder(train_batch.to(self.device).view(-1, c, h, w))
                train_compr_dataset = dataloader.load_data(train_compr.reshape(b, s, -1).cpu().detach().numpy(), model='seq2seq', lookback=3, lookahead=3)
                for j in range(len(train_compr_dataset)):
                    inputs, targets = train_compr_dataset[j]
                    inputs = torch.tensor(inputs).to(self.device)
                    targets = torch.tensor(targets).to(self.device)

                    outputs = self.LSTM(inputs)

                    mse_loss = criterion(outputs, targets)
                    total_loss = mse_loss
                    if energy_coef is not None:
                        energy_loss = torch.abs(torch.mean((self.calculate_total_energy(inputs, decoder_type) -  #, dx=0.02
                                                    torch.mean(self.calculate_total_energy(outputs, decoder_type), dim=1, keepdim=True)))) # , dx=0.02
                        total_loss += energy_coef[epoch]*energy_loss
                    if fo_coef is not None:
                        fo_loss = self.compute_evolve_loss(inputs, outputs, decoder_type)
                        total_loss += fo_coef[epoch]*fo_loss

                    train_epoch_loss += total_loss.item()/len(train_loader)/len(train_compr_dataset)

                    optimizer.zero_grad()
                    total_loss.backward()
                    optimizer.step()

            # Testing
            self.LSTM.eval()
            with torch.no_grad():
                for j in range(len(test_dataset)):
                    test_inputs, test_targets = test_dataset[j]
                    test_inputs = test_inputs.to(self.device)
                    test_targets = test_targets.to(self.device)

                    test_outputs = self.LSTM(test_inputs)

                    test_mse_loss = criterion(test_outputs, test_targets)
                    test_total_loss = test_mse_loss

                    if energy_coef is not None:
                        test_energy_loss = torch.abs(torch.mean((self.calculate_total_energy(test_inputs, decoder_type) -  #, dx=0.02
                                                    torch.mean(self.calculate_total_energy(test_outputs, decoder_type), dim=1, keepdim=True)))) # , dx=0.02
                        test_total_loss += energy_coef[epoch]*test_energy_loss
                    if fo_coef is not None:
                        test_fo_loss = self.compute_evolve_loss(test_inputs, test_outputs, decoder_type)
                        test_total_loss += fo_coef[epoch]*test_fo_loss

                    test_epoch_loss += test_total_loss.item()/len(test_dataset)

            # Print the averaged loss per epoch
            print('Epoch [{}/{}], MSE_Loss: {:.6f}, Total_Loss: {:.6f}, Test_MSE_Loss: {:.6f}, Test_Total_Loss: {:.6f}'
                .format(epoch+1, num_epochs, mse_loss.item(), train_epoch_loss, test_mse_loss.item(), test_epoch_loss))

            train_losses.append(train_epoch_loss)
            test_losses.append(test_epoch_loss)

    def calculate_total_energy(self, data_compr, decoder_type):
        # Set decoder for physical constraints
        if decoder_type == 'high':
            decoder = self.HF_decoder
            dx = 0.01
        else:
            decoder = self.LF_decoder
            dx = 0.02

        g = 9.8

        N, Nt, latent = data_compr.shape
        energies = torch.zeros(N, Nt)
        # reconstruct the data
        data = decoder(data_compr.reshape(N*Nt, latent))
        data = data.reshape(N, Nt, *data.shape[-3:])
        for i in range(N):
            for j in range(Nt):
                # Calculate kinetic energy
                kinetic_energy = 0.5 * (torch.sum(data[i][j][0]**2) + torch.sum(data[i][j][1]**2)) * dx**2

                # Calculate potential energy
                potential_energy = torch.sum(0.5 * g * data[i][j][2]**2) * dx**2

                # Calculate total energy
                energies[i, j] = kinetic_energy + potential_energy

        return energies
    
    def dxy(self, A, dx, axis=0):
        return (roll(A, -1, axis) - roll(A, 1, axis)) / (dx*2.)

    def d_dx(self, A, dx):
        return self.dxy(A, dx, 2)

    def d_dy(self, A, dx):
        return self.dxy(A, dx, 1)

    def d_dt(self, h, u, v, dx):
        for x in [h, u, v]:
            assert isinstance(x, ndarray) and not isinstance(x, matrix)
        g, b = 1., 0.2
        du_dt = -g*self.d_dx(h, dx) - b*u
        dv_dt = -g*self.d_dy(h, dx) - b*v
        H = 0 
        dh_dt = -self.d_dx(u * (H+h), dx) - self.d_dy(v * (H+h), dx)
        return dh_dt, du_dt, dv_dt
    
    def evolve(self, h, u, v, dt=0.0001, dx=0.01):
        dh_dt, du_dt, dv_dt = self.d_dt(h, u, v, dx)
        # Update out of place so the caller's arrays are not modified
        h = h + dh_dt * dt
        u = u + du_dt * dt
        v = v + dv_dt * dt
        return h, u, v

    def compute_evolve_loss(self, input_data_compr, outputdata_compr, decoder_type, dt=0.0001):
        if decoder_type == 'high':
            decoder = self.HF_decoder
            dx = 0.01
        else:
            decoder = self.LF_decoder
            dx = 0.02

        N, Nt_in, latent = input_data_compr.shape
        N, Nt_out, latent = outputdata_compr.shape
        input_images = decoder(input_data_compr.reshape(N*Nt_in, latent))
        input_images = input_images.reshape(N, Nt_in, *input_images.shape[-3:])
        output_images = decoder(outputdata_compr.reshape(N*Nt_out, latent))
        output_images = output_images.reshape(N, Nt_out, *output_images.shape[-3:])
        u = input_images[:,-1,0,:].cpu().detach().numpy()
        v = input_images[:,-1,1,:].cpu().detach().numpy()
        h = input_images[:,-1,2,:].cpu().detach().numpy()
        output_images_evolved = torch.empty((N, Nt_out, *output_images.shape[-3:]))
        for i in range(Nt_out):
            h, u, v = self.evolve(h, u, v, dt, dx)
            output_image_evolved = torch.stack((torch.tensor(u), torch.tensor(v), torch.tensor(h)), dim=1)
            output_images_evolved[:, i, :, :, :] = output_image_evolved
        output_images_evolved = output_images_evolved.to(self.device)
        # MSE between the physically evolved fields and the decoded predictions
        FO_loss = nn.MSELoss()(output_images_evolved, output_images)
        return FO_loss

In [None]:
sw_MSPC_LSTM = SW_MSPC_LSTM(HF_CAE, LF_CAE, LSTM)

Note that the CAE training processes are general, so there is no need to override them in the subclass.
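
The spatial derivatives in the flow operator constraint are periodic central differences: the $\text{dxy}$ method subtracts two rolled copies of the field and divides by twice the grid spacing. The sketch below shows the same stencil on a plain 1D Python list so that it runs without NumPy; `central_diff_periodic` is a hypothetical name used only here.

```python
import math


def central_diff_periodic(values, dx):
    """Central difference (f[i+1] - f[i-1]) / (2*dx) with wrap-around at both ends."""
    n = len(values)
    return [(values[(i + 1) % n] - values[(i - 1) % n]) / (2.0 * dx) for i in range(n)]


# Check against a known derivative: d/dx sin(x) = cos(x) on one full period.
n = 64
dx = 2.0 * math.pi / n
xs = [i * dx for i in range(n)]
approx = central_diff_periodic([math.sin(x) for x in xs], dx)
max_error = max(abs(a - math.cos(x)) for a, x in zip(approx, xs))
print(max_error < 1e-2)
```

The wrap-around treats the domain as periodic, so this stencil is only appropriate when the simulated fields have periodic boundaries or stay close to zero near the edges.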

# Training Processes

With the datasets and the model prepared, we can now train the model.

The first step is to train the high-fidelity CAE with the function $\text{train\_HFCAE}$.

In [None]:
sw_MSPC_LSTM.train_HFCAE(HF_CAE_train_loader, HF_CAE_test_data, num_epochs=30, lr=0.0001, criterion_type='mse')

The low-fidelity CAE can be trained after the high-fidelity CAE.

In [None]:
sw_MSPC_LSTM.train_LFCAE(LF_CAE_train_loader, LF_test_data, HF_test_data, num_epochs=30, lr=0.0001, criterion_type='mse')

After training the multi-layer CAEs, we can train the LSTM.

In [None]:
sw_MSPC_LSTM.train_LSTM_with_PC(LSTM_train_loader, HF_test_data, energy_coef=0.0001, fo_coef=0.0001, decoder_type='low', batch_size=20, num_epochs=30, lr=0.0001, criterion_type='mse')

Here, $\text{decoder\_type}$ selects which decoder is applied when evaluating the physical constraints. The default is the low-fidelity decoder; set $\text{decoder\_type = 'high'}$ to use the high-fidelity decoder.
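
The training function turns a numeric $\text{energy\_coef}$ or $\text{fo\_coef}$ into one value per epoch and indexes it by epoch, so a per-epoch list can be passed as well. The sketch below mirrors that coercion and builds one possible schedule. The linear warm-up is an assumption chosen for illustration, not a recommendation from this project, and both helper names are hypothetical.

```python
def as_epoch_schedule(coef, num_epochs):
    """Return a per-epoch list for a scalar or list coefficient (None stays None)."""
    if coef is None:
        return None
    if isinstance(coef, (int, float)):
        return [coef] * num_epochs
    schedule = list(coef)
    if len(schedule) != num_epochs:
        raise ValueError("expected one coefficient per epoch")
    return schedule


def linear_warmup(final_value, num_epochs, warmup_epochs):
    """Ramp the coefficient linearly up to final_value, then hold it constant."""
    return [final_value * min(1.0, (epoch + 1) / warmup_epochs)
            for epoch in range(num_epochs)]


print(as_epoch_schedule(0.5, 3))  # [0.5, 0.5, 0.5]
print(linear_warmup(1.0, 5, 4))   # [0.25, 0.5, 0.75, 1.0, 1.0]
```

A list built this way must contain exactly $\text{num\_epochs}$ entries, because the training loop reads one entry per epoch.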

At this point the MSPC-LSTM model is fully trained, and you can use the predict function to make predictions. If you already have trained CAEs and an LSTM, you can also pass them in directly and skip the training process.

For how to call predictions and produce the accompanying plots, see the prediction sample files "Burgers'_equation.ipynb" and "Shallow_Water.ipynb".