In [1]:
import numpy as np

import torch
import torch.nn.init as init
from torch.autograd import grad

from scipy.integrate import solve_ivp

from modules.utils import FeedForwardNetwork, plot_ode, plot_losses, rmse
from modules.problems import DampedHarmonicOscillator

In [None]:
def train(
    problem,
    model,
    alpha, beta, N_F,
    num_iters, lr,
    print_every=1000, collect_every=1000
):
    """Fit ``model`` to ``problem`` with Adam and log loss / RMSE over training.

    Each step minimises
    ``alpha * problem.loss_initial(model) + beta * problem.loss_physical(model, t_F)``
    where ``t_F`` are ``N_F`` evenly spaced collocation points on ``[0, problem.T]``.

    Args:
        problem: object exposing ``T``, ``t`` (a tensor of evaluation times),
            ``numerical_solution``, ``loss_initial(model)`` and
            ``loss_physical(model, t)``. ``numerical_solution`` is assumed to be
            sampled on ``problem.t`` (the notebook plots them together).
        model: ``torch.nn.Module`` called on ``(n, 1)`` time tensors.
        alpha: weight of the initial-condition loss.
        beta: weight of the physics loss.
        N_F: number of collocation points.
        num_iters: the loop runs ``num_iters + 1`` steps (0..num_iters inclusive)
            so the final iteration can be logged.
        lr: Adam learning rate.
        print_every: print a progress line every this many iterations
            (never at iteration 0); a value ``<= 0`` disables printing.
        collect_every: record ``[loss, rmse]`` every this many iterations
            (including iteration 0); a value ``<= 0`` disables collection.

    Returns:
        ``np.ndarray`` with one ``[loss, rmse]`` row per collected iteration
        (shape ``(k, 2)``), or an empty array if nothing was collected.
        ``loss`` is the objective evaluated before that iteration's update;
        ``rmse`` is measured after it.
    """
    # requires_grad presumably lets loss_physical differentiate the network w.r.t. t
    collocation_t = torch.linspace(0, problem.T, N_F, requires_grad=True).reshape(-1, 1)
    # Evaluate on the grid the reference solution is sampled on, so the RMSE
    # compares predictions and reference values at matching time points.
    eval_t = problem.t.reshape(-1, 1)

    history = []
    optim = torch.optim.Adam(model.parameters(), lr=lr)

    for i in range(num_iters + 1):
        optim.zero_grad()

        L_I = problem.loss_initial(model)
        L_F = problem.loss_physical(model, collocation_t)
        L = alpha * L_I + beta * L_F

        L.backward()
        optim.step()

        # Guard both schedules against non-positive periods (avoids i % 0).
        do_print = print_every > 0 and i > 0 and i % print_every == 0
        do_collect = collect_every > 0 and i % collect_every == 0
        if not (do_print or do_collect):
            continue

        # Evaluation only: no_grad avoids building an unused autograd graph,
        # and the prediction is shared by printing and collection.
        with torch.no_grad():
            predicts = model(eval_t).detach().flatten().numpy()
        error = rmse(predicts, problem.numerical_solution)

        if do_print:
            print(f'Iteration {i} --- Loss {L.item()} --- RMSE {error}')
        if do_collect:
            history.append(np.array([L.item(), error]))

    return np.array(history)

In [None]:
# Problem setup: time horizon T, (zeta, omega) oscillator parameters and
# initial conditions (x_0, v_0) passed to DampedHarmonicOscillator.
T = 10
zeta, omega = 0.2, 2.0
x_0, v_0 = 5.0, 7.0
problem = DampedHarmonicOscillator(T, (zeta, omega), (x_0, v_0))

# Network size; assumes FeedForwardNetwork(num_layers, width) — confirm in modules.utils.
L, W = 2, 64
# Seed torch's RNG so weight initialisation (if drawn from torch's RNG) is reproducible.
torch.manual_seed(0)
model = FeedForwardNetwork(L, W)

alpha, beta = 1.0, 0.5   # weights of the initial-condition and physics losses
N_F = 256                # number of collocation points
num_iters, lr = 2500, 1e-3
print_every, collect_every = 2500, 500

history = train(
    problem=problem,
    model=model,
    alpha=alpha, beta=beta, N_F=N_F,
    num_iters=num_iters, lr=lr,
    print_every=print_every, collect_every=collect_every,
)
# train returns one [loss, rmse] row per collected iteration, so split the
# columns. Tuple-unpacking the 2-D array directly would iterate over its rows.
losses, errors = history[:, 0], history[:, 1]

In [None]:
# Evaluate the trained network on the reference time grid and overlay it
# on the numerical solution for a visual check.
t_column = problem.t.reshape(-1, 1)
predictions = model(t_column).detach().flatten().numpy()
plot_ode(
    problem.t.numpy(),
    predicted=[(predictions, 'Neural Network')],
    solutions=[(problem.numerical_solution, 'Numerical Solution')],
    size=(5, 3),
)

In [None]:
# One history row was collected every 500 iterations (0, 500, ..., num_iters),
# so the x-axis step here must match the collect_every passed to train.
checkpoints = np.arange(0, num_iters + 1, 500)
plot_losses(
    t=checkpoints,
    losses=[(losses, 'Loss Value')],
    errors=[(errors, 'RMSE')],
)