In [None]:
from pinntorch import *
from functools import partial
import matplotlib
from matplotlib.transforms import Bbox
from matplotlib.ticker import FuncFormatter

In [None]:
# Rate constant shared by the exact sigmoid solution 1/(1+exp(-K*x)) and the
# physics residual u' - K*u*(1-u) used in the loss below.
K = 5.0

In [None]:
def exact_solution_log(x):
    """Evaluate the exact logistic solution ``1 / (1 + exp(-K*x))`` as a tensor.

    Args:
        x: Evaluation point(s). A ``torch.Tensor`` is used directly; a Python
            scalar or NumPy array is converted to a tensor of the default
            floating dtype after being scaled by ``K``.

    Returns:
        torch.Tensor: The sigmoid with rate ``K`` evaluated at ``x``.
    """
    z = K * x
    if not isinstance(z, torch.Tensor):
        # The legacy ``torch.Tensor(...)`` constructor rejects Python floats
        # (and treats a bare int as a size), so scalar calls such as
        # ``exact_solution_log(1)`` used to fail. ``torch.as_tensor`` accepts
        # scalars and arrays alike.
        z = torch.as_tensor(z, dtype=torch.get_default_dtype())
    return 1/(1+torch.exp(-z))

# Exact solution in NumPy. This variant is the one used for the boundary target
# in the loss function; the tensor variant was found not to work there at the
# time of writing.
def exact_solution_log_np(x):
    """Return ``1 / (1 + exp(-K*x))`` computed with NumPy."""
    exponent = -K * x
    return 1 / (1 + np.exp(exponent))

def create_noisy_data(x, std_dev, noise_seed = 42):
    """Sample the exact solution at ``x`` and add seeded Gaussian noise.

    Args:
        x: Points passed to ``exact_solution_log``.
        std_dev: Factor applied to the standard-normal noise.
        noise_seed: Seed handed to ``torch.manual_seed`` before sampling.
            Note that this reseeds PyTorch's global generator.

    Returns:
        The exact values plus ``std_dev`` times standard-normal noise of the
        same size.
    """
    clean_values = exact_solution_log(x)

    # Seed right before drawing so the noise is reproducible per seed.
    torch.manual_seed(noise_seed)
    noise = torch.randn(clean_values.size())

    return clean_values + noise*std_dev

def data_loss(model: PINN, data: torch.Tensor = None, x: torch.Tensor = None) -> torch.float:
    """Mean squared difference between the model output at ``x`` and ``data``.

    Args:
        model: Network evaluated through the ``f`` helper.
        data: Target values compared against the model output.
        x: Points at which the model is evaluated.

    Returns:
        The mean of the squared residuals.
    """
    residual = f(model, x) - data
    squared = residual.pow(2)
    return squared.mean()

def physics_loss(model: PINN, x: torch.Tensor = None) -> torch.Tensor:
    """Physics-informed loss: ODE residual plus the right boundary condition.

    The residual term penalises ``df(model, x) - K*u*(1 - u)`` with
    ``u = f(model, x)``; the boundary term ties the model value at ``+1.0`` to
    ``exact_solution_log_np(+1)``. Only the right boundary is enforced here.

    Args:
        model: Network evaluated through the ``f`` / ``df`` helpers
            (presumably the model output and its first derivative with
            respect to ``x`` -- confirm against pinntorch).
        x: Collocation points for the residual term.

    Returns:
        torch.Tensor: Mean squared ODE residual plus mean squared boundary
        mismatch.
    """
    # Evaluate the network once and reuse it; the residual previously ran the
    # same forward pass twice.
    u = f(model, x)

    # ODE residual
    pde_residual = df(model, x) - K*u*(1 - u)
    pde_loss = pde_residual.pow(2).mean()

    # right boundary condition at x = +1
    boundary_residual_right = f(model, at(+1.0)) - exact_solution_log_np(+1)
    boundary_loss_right = boundary_residual_right.pow(2).mean()

    # combine all losses
    return pde_loss + boundary_loss_right


def total_loss(model: PINN, data: torch.Tensor = None, x: torch.Tensor= None, alpha: torch.float= None) -> torch.Tensor:
    """Weighted sum of the data loss and the physics loss.

    The data loss is weighted by ``alpha`` and the physics loss by
    ``1 - alpha``.

    Args:
        model: Network being trained.
        data: Target values for the data loss.
        x: Points used for both the data loss and the physics loss.
        alpha: Weight of the data term; must be provided (the ``None``
            default only exists so the function can be bound with
            ``functools.partial`` keyword arguments).

    Returns:
        torch.Tensor: ``alpha * data_loss + (1 - alpha) * physics_loss``.
    """
    weighted_data = alpha * data_loss(model, data, x)
    weighted_physics = (1 - alpha) * physics_loss(model, x)

    return weighted_data + weighted_physics

def custom_color_normalize(value, power=80):
    """Raise ``value`` to ``power`` to spread values close to 1 over the color range.

    Args:
        value: Scalar or array to transform.
        power: Exponent applied to ``value``. Defaults to 80, the previously
            hard-coded exponent.

    Returns:
        ``value ** power`` (same type rules as the ``**`` operator on ``value``).
    """
    return value**power
    

def find_last_threshold_index(physics, validation, threshold = 0.0):
    """Return the last index at which ``validation - physics`` does not exceed ``threshold``.

    Args:
        physics: Sequence/array of physics-loss values.
        validation: Sequence/array of validation-loss values, same length.
        threshold: Largest gap ``validation - physics`` still accepted.

    Returns:
        Index (NumPy integer) of the last entry whose gap is ``<= threshold``.

    Raises:
        IndexError: If every entry exceeds the threshold, so no index
            qualifies. (Previously this surfaced as an unexplained
            ``IndexError`` from indexing an empty array.)
    """
    gap = np.asarray(validation) - np.asarray(physics)

    # Indices (along the first axis) where the gap is NOT above the threshold.
    within = np.nonzero(~(gap > threshold))[0]
    if within.size == 0:
        raise IndexError(
            "validation - physics exceeds threshold %r at every index" % (threshold,)
        )
    return within[-1]

def get_best_indices(physics, validation, threshold = 0.0):
    """Apply ``find_last_threshold_index`` to each pair of loss curves.

    Args:
        physics: Indexable collection of physics-loss curves.
        validation: Indexable collection of validation-loss curves, aligned
            with ``physics``.
        threshold: Threshold forwarded to ``find_last_threshold_index``.

    Returns:
        list: One selected index per entry of ``physics``.
    """
    return [
        find_last_threshold_index(physics[run], validation[run], threshold)
        for run in range(len(physics))
    ]

In [None]:
def scalar_training(settings, input_data, train_points, val_points):
    """Train one model per alpha in ``settings['alphas']`` and collect histories.

    Args:
        settings: Dict providing ``'alphas'``, ``'model_seed'``, ``'epochs'``
            and ``'start_learning_rate'``.
        input_data: Target values used by the data loss.
        train_points: Points used for the training losses.
        val_points: Points used for the validation physics loss.

    Returns:
        Tuple ``(physics, data, lr, validation, models)`` of lists with one
        entry per alpha; the first four hold NumPy arrays built from the
        monitor's histories, the last holds the trained models.
    """
    physics_runs, data_runs, val_runs, lr_runs = [], [], [], []
    trained_models = []

    for run_idx, alpha in enumerate(settings['alphas']):
        print("i:", run_idx, "alpha:", alpha)

        # Every alpha gets its own weighted objective.
        weighted_loss = partial(total_loss, data=input_data, x=train_points, alpha=alpha)

        # Reseed before construction so each model is built under the same RNG state.
        torch.manual_seed(settings['model_seed'])
        network = PINN(1, 3, 9, 1)

        # Monitor receives: data loss, physics loss on train points, physics loss on validation points.
        monitor = AllDataMonitor(
            partial(data_loss, data=input_data, x=train_points),
            partial(physics_loss, x=train_points),
            partial(physics_loss, x=val_points),
        )

        fitted = train_model(
            model=network,
            loss_fn=weighted_loss,
            max_epochs=settings['epochs'],
            lr_decay=2e-2,
            optimizer_fn=partial(torch.optim.Adam, lr=settings['start_learning_rate']),
            epoch_callbacks=[monitor],
        )

        physics_runs.append(np.array(monitor.physics_history))
        data_runs.append(np.array(monitor.data_history))
        lr_runs.append(np.array(monitor.lr_history))
        val_runs.append(np.array(monitor.val_history))
        trained_models.append(fitted)

    return physics_runs, data_runs, lr_runs, val_runs, trained_models

In [None]:
# Alpha sweep: 1 - 80**e for 20 exponents e evenly spaced between -2 and 0.
alphas = 1 - torch.logspace(start=-2, end=0.0, steps=20, base=80)

# Experiment configuration consumed by the training and data-generation code.
settings = {
    'n_train_points': 20,
    'n_val_points': 39,
    'noise_level': 0.1,
    'epochs': 10000,
    'alphas': alphas.cpu(),
    'noise_seed': 123,
    'model_seed': 111,
    'start_learning_rate': 0.01,
}

# Training and validation grids on the interval [-1, 1].
training_points = generate_grid(settings['n_train_points'], (-1.0, 1.0))
validation_points = generate_grid(settings['n_val_points'], (-1.0, 1.0))

# Noisy observations of the exact solution at the training points.
data_noise = create_noisy_data(
    training_points,
    settings['noise_level'],
    noise_seed=settings['noise_seed'],
)

In [None]:
# Run the whole alpha sweep; each returned list has one entry per alpha.
(
    Loss_physics,
    Loss_data,
    LR_evolution,
    Loss_val,
    models_trained,
) = scalar_training(settings, data_noise, training_points, validation_points)

In [None]:
def tick_formatter(val, pos):
    """Colorbar tick-label callback for ``matplotlib.ticker.FuncFormatter``.

    Args:
        val: Tick value (unused; the label is looked up by position).
        pos: Tick position index supplied by matplotlib; may be ``None``.

    Returns:
        str: ``''`` for suppressed or unmappable positions, otherwise the
        text of ``alphas[pos]``.
    """
    # Tick positions whose labels are hidden to avoid clutter.
    skip_indices = (1, 2, 3, 5, 15, 16, 17, 18, 19, 20)

    # An earlier ``elif pos == 1: return '(1.0)'`` branch was unreachable
    # because 1 is in the skip list; it is dropped without changing the
    # rendered labels.
    if pos is None or pos in skip_indices:
        return ''

    # Positions without a matching alpha get no label instead of raising.
    if not 0 <= pos < len(alphas):
        return ''

    # NOTE(review): the plotting code prepends 1.0 to the tick list, so tick
    # position p may actually belong to alphas[p - 1] -- confirm the label
    # alignment against the rendered colorbar.
    return str(alphas[pos])

def plot_pareto_front(L_D, L_P, data_color, xtick_rotation=0, file_name='log_pareto', bbox_bounds = (0.1, -0.1, 5.2, 3.8)):
    """Scatter data loss against physics loss, colored by alpha, and save the figure.

    Args:
        L_D: Data-loss values (x axis).
        L_P: Physics-loss values (y axis).
        data_color: Alpha values; passed through ``custom_color_normalize``
            for point colors and colorbar ticks.
        xtick_rotation: Rotation applied to the x tick labels.
        file_name: Output path without extension; ``'.png'`` is appended.
        bbox_bounds: ``(x0, y0, width, height)`` passed to
            ``Bbox.from_bounds`` and used as ``bbox_inches`` when saving.
    """
    plt.figure(figsize=(5,4))

    point_colors = custom_color_normalize(data_color)
    points = plt.scatter(L_D, L_P, c=point_colors, cmap=matplotlib.colormaps['jet'])

    # Colorbar ticks: one per alpha plus an extra leading tick for 1.0.
    tick_values = custom_color_normalize(np.concatenate(([1.0], data_color)))
    colorbar = plt.colorbar(points, ticks=tick_values, format=FuncFormatter(tick_formatter))
    colorbar.set_label(r' $α$ (log scaled)', fontsize=13)

    plt.ylabel(r"$\mathcal{L}_\mathrm{PHYSICS}$", loc='center', fontsize=13)
    plt.xlabel(r"$\mathcal{L}_\mathrm{DATA}$", loc='center', fontsize=13)

    plt.grid()
    plt.xticks(rotation=xtick_rotation)

    # Crop the saved image to the requested bounding box.
    crop_box = Bbox.from_bounds(*bbox_bounds)
    plt.savefig(file_name+'.png', dpi=600, bbox_inches=crop_box)
    plt.show()

In [None]:
# Convert the alpha sweep and the recorded loss histories to NumPy arrays.
alphas = np.array(alphas)

log_data, log_physics, log_val = (
    np.array(history) for history in (Loss_data, Loss_physics, Loss_val)
)

In [None]:
# For every model, pick the converged / not-overfitted epoch: the last one where
# the validation loss exceeds the physics loss by at most the threshold.
best_indices = get_best_indices(log_physics, log_val, 0.001)

# Read the data loss and physics loss of each model at its selected epoch.
log_best_data = [log_data[run][epoch] for run, epoch in enumerate(best_indices)]
log_best_physics = [log_physics[run][epoch] for run, epoch in enumerate(best_indices)]

In [None]:
# Show the selected index for each alpha run.
best_indices

In [None]:
# Pareto front of the selected epochs, saved as 'log_best.png'.
plot_pareto_front(
    log_best_data,
    log_best_physics,
    alphas,
    xtick_rotation=0,
    file_name='log_best',
    bbox_bounds=(0.1, -0.1, 5.2, 3.8),
)