In [ ]:
# (C) Copyright IBM Corp. 2019, 2020, 2021, 2022.

#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at

#           http://www.apache.org/licenses/LICENSE-2.0

#     Unless required by applicable law or agreed to in writing, software
#     distributed under the License is distributed on an "AS IS" BASIS,
#     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#     See the License for the specific language governing permissions and
#     limitations under the License.


In [None]:
import os
import numpy as np
import random
import matplotlib.pyplot as plt
from scipy.integrate import odeint
import matplotlib.pyplot as plt

os.environ["engine"] = "pytorch"

from simulai.regression import DenseNetwork
from simulai.models import DeepONet
from simulai.optimization import Optimizer
from simulai.metrics import L2Norm
from simulai.io import IntersectingBatches

In [None]:
def project_to_interval(interval, data):
    """Linearly rescale ``data`` so its min maps to ``interval[0]`` and its max to ``interval[1]``.

    Parameters
    ----------
    interval : sequence of two numbers
        Target ``[lower, upper]`` bounds.
    data : array-like
        Values to rescale.

    Returns
    -------
    np.ndarray
        Array with the same shape as ``data``. When ``data`` is constant
        (max == min) every entry is set to ``interval[0]`` instead of
        producing NaNs from a division by zero.
    """
    lower, upper = interval[0], interval[1]
    data = np.asarray(data)

    data_min = data.min()
    span = data.max() - data_min

    if span == 0:
        return np.full_like(data, lower, dtype=float)

    # Scale by the interval width (upper - lower), not by ``upper`` alone;
    # the two only coincide when lower == 0.
    return (upper - lower)*(data - data_min)/span + lower

In [None]:
class LotkaVolterra:
    """Lotka-Volterra predator-prey model.

    The state is ``[x, y]`` and the dynamics are::

        dx/dt = alpha*x - beta*x*y
        dy/dt = delta*x*y - gamma*y
    """

    def __init__(self, alpha=None, beta=None, gamma=None, delta=None):
        # Model coefficients, stored exactly as given (no validation).
        self.alpha, self.beta = alpha, beta
        self.gamma, self.delta = gamma, delta

    def eval(self, state: np.ndarray = None, t: float = None) -> np.ndarray:
        """Return the time derivative ``[dx/dt, dy/dt]`` evaluated at ``state``.

        ``t`` is not used by the right-hand side; it is accepted because
        ``odeint`` calls the function as ``f(state, t)``.
        """
        prey, predator = state[0], state[1]

        return np.array([
            self.alpha*prey - self.beta*prey*predator,
            self.delta*prey*predator - self.gamma*predator,
        ])

    def run(self, initial_state, t):
        """Integrate the system from ``initial_state`` over the time grid ``t`` with ``odeint``.

        Returns a 2-D array with one row per time in ``t`` and one column per
        state variable.
        """
        return np.vstack(odeint(self.eval, initial_state, t))

In [None]:
# Lotka-Volterra coefficients:
#   dx/dt = alpha*x - beta*x*y,   dy/dt = delta*x*y - gamma*y
alpha = 1.1
beta = 0.4
gamma = 0.4
delta = 0.1

# Time discretization of the reference trajectory
dt = 0.01    # sampling/integration step
T_max = 150  # final simulation time
# round() guards against the float quotient landing just below an integer
n_samples = int(round(T_max/dt))

# The first `train_fraction` of the trajectory is used for training, the rest for testing
train_fraction = 0.8
n_samples_train = int(train_fraction*n_samples)

delta_t = 10          # length (in time units) of each time window given to the DeepONet
n_chunk_samples = 10  # number of time points randomly sampled from each training window

# Seed the global RNGs so the random sub-sampling of training windows is reproducible.
# NOTE(review): network weight initialisation presumably goes through the pytorch
# engine and is not covered by these seeds.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

In [None]:
# Reference trajectory: integrate the Lotka-Volterra system on the uniform grid
# 0, dt, 2*dt, ... (strictly below T_max), starting from x(0) = 20, y(0) = 5.
solver = LotkaVolterra(alpha=alpha, beta=beta, gamma=gamma, delta=delta)

initial_state = np.array([20, 5])
t = np.arange(0, T_max, dt)

data = solver.run(initial_state, t)

In [None]:
# Number of consecutive samples in one time window of length delta_t
window_size = int(round(delta_t/dt))

# Training windows: every window of `window_size` consecutive samples (stride 1)
# taken from the first `n_samples_train` points of the trajectory.
batcher = IntersectingBatches(skip_size=1, batch_size=window_size)

time_chunks_ = batcher(input_data=t[:n_samples_train])
data_chunks = batcher(input_data=data[:n_samples_train])

T_max_train = n_samples_train*dt

# Test windows: non-overlapping windows over the held-out tail, cut by index with the
# same length as the training windows. Each window is normalized to [0, 1] below, so
# windows of a different length (e.g. one extra sample from an inclusive float mask,
# or a truncated last window) would map normalized time to a different physical span.
test_starts = range(n_samples_train, t.shape[0] - window_size + 1, window_size)
time_aux = [t[start:start + window_size] for start in test_starts]
data_aux = [data[start:start + window_size] for start in test_starts]
assert time_aux, "no complete test window fits after the training range"

# Branch input for each training window: the state at the window's first sample
initial_states = [chunk[0] for chunk in data_chunks]

time_chunks = [project_to_interval([0, 1], chunk)[:, None] for chunk in time_chunks_]

time_chunks_train = list()
data_chunks_train = list()

for time_chunk, data_chunk in zip(time_chunks, data_chunks):
    # Sample WITHOUT replacement so each window contributes n_chunk_samples distinct,
    # time-ordered points (requires n_chunk_samples <= window_size).
    indices = np.sort(np.random.choice(time_chunk.shape[0], n_chunk_samples, replace=False))
    time_chunks_train.append(time_chunk[indices])
    data_chunks_train.append(data_chunk[indices])

initial_states_train = initial_states

time_chunks_test = [project_to_interval([0, 1], chunk)[:, None] for chunk in time_aux]
data_chunks_test = data_aux
initial_states_test = [chunk[0] for chunk in data_aux]

In [None]:
# Flatten the windows into DeepONet rows. Each row pairs one normalized time point
# (trunk input) with the initial state of its window (branch input).
branch_input_train = np.concatenate(
    [np.repeat(init[None, :], time_chunk.shape[0], axis=0)
     for init, time_chunk in zip(initial_states_train, time_chunks_train)],
    axis=0,
)
branch_input_test = np.concatenate(
    [np.repeat(init[None, :], time_chunk.shape[0], axis=0)
     for init, time_chunk in zip(initial_states_test, time_chunks_test)],
    axis=0,
)

# Trunk inputs are the normalized times; targets are the matching state values
trunk_input_train = np.concatenate(time_chunks_train, axis=0)
trunk_input_test = np.concatenate(time_chunks_test, axis=0)

output_train = np.concatenate(data_chunks_train, axis=0)
output_test = np.concatenate(data_chunks_test, axis=0)

In [None]:
# Problem dimensions
n_inputs = 1   # trunk input: the normalized time coordinate (one column)
n_outputs = 2  # state variables (x, y); also the size of the branch input (initial state)

lambda_1 = 0.0  # Penalty for the L¹ regularization (Lasso)
lambda_2 = 1e-4  # Penalty factor for the L² regularization
n_epochs = 2_000  # Maximum number of iterations for ADAM
lr = 1e-3  # Initial learning rate for the ADAM algorithm
n_latent = 100  # latent width per output; both networks emit n_latent*n_outputs features

# Configuration for the fully-connected trunk network (input: time)
trunk_config = {
                'layers_units': 3*[100],  # Hidden layers
                'activations': 'relu',
                'input_size': n_inputs,
                'output_size': n_latent*n_outputs,
                'name': 'trunk_net'
               }

# Configuration for the fully-connected branch network (input: initial state)
branch_config = {
                'layers_units': 3*[100],  # Hidden layers
                'activations': 'relu',
                'input_size': n_outputs,
                'output_size': n_latent*n_outputs,
                'name': 'branch_net',
                }

In [None]:
# Build the trunk network (fed with time) and the branch network (fed with the
# initial state). Training happens later, through the Optimizer.
trunk_net, branch_net = DenseNetwork(**trunk_config), DenseNetwork(**branch_config)

In [None]:
optimizer_config = {'lr': lr}

# Per-output loss weights: the reciprocal of the L2 norm of each target column,
# presumably to put both state variables on a comparable scale in the loss.
column_norms = np.linalg.norm(output_train, 2, axis=0)
maximum_values = (1/column_norms).tolist()

params = {
    'lambda_1': lambda_1,
    'lambda_2': lambda_2,
    'weights': maximum_values,
}

# Print an architecture summary for each network
trunk_net.summary()
branch_net.summary()

input_data = {
    'input_branch': branch_input_train,
    'input_trunk': trunk_input_train,
}

# NOTE: the variable keeps the name `lorenz_net` because later cells use it;
# the model itself is the Lotka-Volterra DeepONet.
lorenz_net = DeepONet(
    trunk_network=trunk_net,
    branch_network=branch_net,
    var_dim=n_outputs,
    model_id='lotka_volterra_net',
)

In [None]:
# Train with Adam. The "wrmse" loss presumably is a weighted RMSE driven by
# params['weights']. NOTE(review): device='gpu' assumes a GPU is available.
optimizer = Optimizer('adam', params=optimizer_config)

optimizer.fit(
    op=lorenz_net,
    input_data=input_data,
    target_data=output_train,
    n_epochs=n_epochs,
    loss="wrmse",
    params=params,
    device='gpu',
)

In [None]:
# Evaluate the trained DeepONet on the held-out test windows
approximated_data = lorenz_net.eval(trunk_data=trunk_input_test, branch_data=branch_input_test)

l2_norm = L2Norm()

# Relative L2 error, in percent
error = 100*l2_norm(data=approximated_data, reference_data=output_test, relative_norm=True)

# Names follow LotkaVolterra.eval: x grows at rate alpha (prey), y decays at rate gamma (predator)
state_labels = ["x (prey)", "y (predator)"]

# Loop over the OUTPUT variables (n_outputs), not the trunk inputs, so every state is plotted
for ii in range(n_outputs):
    label = state_labels[ii] if ii < len(state_labels) else f"variable {ii}"

    fig, ax = plt.subplots()
    ax.plot(approximated_data[:, ii], label="Approximated")
    ax.plot(output_test[:, ii], label="Exact")
    ax.set(xlabel="Sample index (concatenated test windows)",
           ylabel=label,
           title=f"Lotka-Volterra DeepONet, test set: {label}")
    ax.legend()
    fig.savefig(f'lotka_volterra_deeponet_time_int_{ii}.png')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across re-runs

print(f"Approximation error for the variables: {error} %")