In [None]:
from pinntorch import *
from functools import partial

# Use the GPU for faster training when one is available; otherwise keep the CPU
# defaults so the notebook still runs on machines without CUDA.
# NOTE(review): torch.set_default_tensor_type is deprecated in recent PyTorch
# releases in favour of torch.set_default_device / torch.set_default_dtype —
# consider migrating once the minimum supported PyTorch version is known.
if torch.cuda.is_available():
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

In [None]:
k = 1.0   # coefficient multiplying the second-order derivative term in the PDE residual
M = 5.0   # upper end of the spatial interval [0, M]
T = 12.5  # upper end of the time interval [0, T]

X_DOMAIN = (0.0, M)
T_DOMAIN = (0.0, T)

def exact_solution(x, t):
    """
    Evaluate the closed-form solution sin(pi*x/M) * exp(-k*(pi/M)^2 * t).

    Args:
        x: tensor of spatial locations.
        t: tensor of times, broadcastable against `x`.

    Returns:
        Tensor with the solution value at every (x, t) pair.
    """
    spatial_profile = torch.sin(torch.pi*x/M)
    decay_rate = k*((torch.pi**2)/(M**2))
    return spatial_profile*torch.exp(-decay_rate*t)

### Let us first define some bookkeeping plotting functions

In [None]:
def plot_data(x, t, data, grid_shape):
    """
    Draw `data` as a 3D surface over the (x, t) grid.

    Args:
        x: tensor of spatial coordinates; reshaped to `grid_shape`.
        t: tensor of time coordinates; reshaped to `grid_shape`.
        data: tensor of values at those points; reshaped to `grid_shape`.
        grid_shape: shape used to turn the three inputs into 2D arrays.
    """
    # matplotlib needs host-side NumPy arrays, so leave the GPU and the autograd graph first
    x = x.cpu().detach().numpy().reshape(grid_shape)
    t = t.cpu().detach().numpy().reshape(grid_shape)
    z = data.cpu().detach().numpy().reshape(grid_shape)

    # Set up plot
    fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
    ax.set(xlabel='x (location)', ylabel='t (time)', zlabel='function value')

    # Same colormap access as plot_heatmap (plt.cm) so both helpers rely on one name.
    surf = ax.plot_surface(x, t, z, cmap=plt.cm.winter, linewidth=0, antialiased=False, shade=False)
    fig.colorbar(surf, shrink=0.5, aspect=5)

def plot_heatmap(x, t, data, grid_shape):
    """
    Show `data` as a heat map with time on the horizontal and location on the vertical axis.

    Args:
        x: tensor of spatial coordinates; reshaped to `grid_shape`.
        t: tensor of time coordinates; reshaped to `grid_shape`.
        data: tensor of values at those points; reshaped to `grid_shape`.
        grid_shape: shape used to turn the three inputs into 2D arrays.
    """
    color_map = plt.cm.winter
    # matplotlib needs host-side NumPy arrays, so leave the GPU and the autograd graph first
    x = x.cpu().detach().numpy().reshape(grid_shape)
    t = t.cpu().detach().numpy().reshape(grid_shape)
    z = data.cpu().detach().numpy().reshape(grid_shape)

    # Only the coordinate ranges are used; the image itself is drawn from z's row/column layout.
    # NOTE(review): this assumes rows of z run along x and columns along t, and with
    # origin='upper' the first row is drawn at the top of the axis (x.max() of the extent)
    # — TODO confirm against the point ordering produced by generate_grid.
    plt.imshow(z, cmap=color_map, aspect='auto', origin='upper', extent=(t.min(), t.max(), x.min(), x.max()))
    plt.colorbar(label='Function Value')
    plt.xlabel('t (time)')
    plt.ylabel('x (location)')
    plt.title('Heat Map of Function')
    plt.show()

def plot_pareto_front(L_D, L_P, data_color, cmap = 'viridis'):
    """
    Scatter the data loss against the physics loss, coloured by `data_color`.

    Args:
        L_D: data-loss values (horizontal axis).
        L_P: physics-loss values (vertical axis).
        data_color: tensor with one colour value per point (labelled as α on the colour bar).
        cmap: name of the matplotlib colormap used for the points.
    """
    # Move the colour values to host memory first: .numpy() is not available on CUDA tensors.
    point_colors = data_color.detach().cpu().numpy()

    plt.scatter(L_D, L_P, c=point_colors, cmap=cmap)
    # Add colorbar for the colour values
    cbar = plt.colorbar()
    cbar.set_label(r' $α$ Values')

    plt.xlabel(r"data loss ($L_d$)")
    plt.ylabel(r"physics loss ($L_p$)")
    plt.title("Multi-objective optimization L = $α.L_d + (1-α)L_p$")

    #plt.savefig("heat_pareto_unstructured.png")
    plt.show()

In [None]:

def data_loss(model: nn.Module, data: torch.Tensor = None, x: torch.Tensor = None, t:torch.Tensor = None) -> torch.float:
    """
    Mean-squared error between the model evaluated at (x, t) and `data`.

    Args:
        model: network evaluated through the `f` helper.
        data: target values (in this notebook: exact solution plus Gaussian noise).
        x: spatial coordinates of the data points.
        t: time coordinates of the data points.

    Returns:
        Scalar tensor holding the mean of the squared differences.
    """
    # NOTE(review): `data` should have the same shape as the model output, otherwise the
    # subtraction broadcasts — confirm the shape returned by `f`.
    mismatch = f(model, x, t) - data
    return mismatch.pow(2).mean()

def physics_loss(
    model: nn.Module, x: torch.Tensor = None, t: torch.Tensor = None
) -> torch.float:
    """
    Sum of the mean-squared PDE residual, boundary residuals and initial residual.

    Args:
        model: network evaluated through the `f` / `df` helpers.
        x: spatial coordinates of the collocation points.
        t: time coordinates of the collocation points.

    Returns:
        Scalar tensor: PDE term + (left boundary + right boundary + initial condition).
    """
    # PDE term: first-order derivative w.r.t. input 1 minus k times the second-order
    # derivative w.r.t. input 0 (presumably u_t - k*u_xx; confirm df's `wrt` convention).
    residual = df(model, x, t, wrt=1, order=1) - k*df(model, x, t, wrt=0, order=2)
    pde_term = residual.pow(2).mean()

    # Dirichlet boundaries: the model should vanish at x = 0 and x = M.
    # presumably unique_excluding(t, 0.) returns the distinct times without t = 0 — verify helper
    t_raw = unique_excluding(t, 0.)
    left_term = f(model, fill_like(t_raw, 0.), t_raw).pow(2).mean()
    right_term = f(model, fill_like(t_raw, M), t_raw).pow(2).mean()

    # Initial condition: the model at t = 0 should equal sin(pi*x/M).
    x_raw = unique_excluding(x)
    initial_target = torch.sin(np.pi/M * x_raw).reshape(-1, 1)
    initial_term = (f(model, x_raw, fill_like(x_raw, 0.)) - initial_target).pow(2).mean()

    # Same summation order as before: PDE + ((left + right) + initial)
    return pde_term + ((left_term + right_term) + initial_term)

def val_loss(
    model: nn.Module, x: torch.Tensor, t: torch.Tensor) -> torch.float:
    """
    Mean-squared PDE residual only (no boundary or initial terms), used for validation.

    Args:
        model: network differentiated through the `df` helper.
        x: spatial coordinates of the validation points.
        t: time coordinates of the validation points.

    Returns:
        Scalar tensor with the mean of the squared residual.
    """
    first_order_term = df(model, x, t, wrt=1, order=1)
    second_order_term = df(model, x, t, wrt=0, order=2)
    return (first_order_term - k*second_order_term).pow(2).mean()

def total_loss(model: nn.Module, data: torch.Tensor, x_data: torch.Tensor, t_data: torch.Tensor, x_physics: torch.Tensor, t_physics: torch.Tensor, alpha: torch.float) -> torch.float:
    """
    Weighted combination alpha * data_loss + (1 - alpha) * physics_loss.

    Args:
        model: network being trained.
        data: target values for the data loss.
        x_data, t_data: coordinates at which `data` is given.
        x_physics, t_physics: collocation coordinates for the physics loss.
        alpha: weight of the data term; the physics term gets 1 - alpha.

    Returns:
        Scalar tensor with the combined loss.
    """
    weighted_data = alpha*data_loss(model, data, x_data, t_data)
    weighted_physics = (1 - alpha)*physics_loss(model, x_physics, t_physics)
    return weighted_data + weighted_physics




In [None]:
def create_noisy_data(std_dev, exact_soln):
    """
    Add zero-mean Gaussian noise with standard deviation `std_dev` to `exact_soln`.

    Args:
        std_dev: standard deviation of the noise.
        exact_soln: tensor of noise-free values.

    Returns:
        Tensor of the same shape as `exact_soln` with the noise added.
    """
    # Sample on the input's own device (and dtype, when it is floating point) instead of
    # the global defaults, so CPU or float64 inputs are handled without a device mismatch
    # or float32-precision noise. For inputs that already match the defaults this is the
    # same torch.randn call as before.
    if exact_soln.is_floating_point():
        noise_dtype = exact_soln.dtype
    else:
        noise_dtype = torch.get_default_dtype()
    noise = torch.randn(exact_soln.shape, dtype=noise_dtype, device=exact_soln.device)

    return exact_soln + noise*std_dev

## Let us plot the exact solution

In [None]:
# Training grids: the data loss and the physics loss each get their own grid, both
# requested with shape (20, 20) over the same (x, t) domain.
x_train_d, t_train_d = generate_grid((20,20), domain=(X_DOMAIN, T_DOMAIN))
x_train_p, t_train_p = generate_grid((20,20), domain=(X_DOMAIN, T_DOMAIN))
# Separate grids for validation (39, 39) and for plotting (200, 200).
x_val, t_val = generate_grid((39,39), domain=(X_DOMAIN, T_DOMAIN))
x_plot, t_plot = generate_grid((200,200), domain=(X_DOMAIN, T_DOMAIN))

# Noise-free reference values on the data grid; noise is added to these later.
exact_soln = exact_solution(x_train_d, t_train_d)

# Exact solution on the plotting grid, shown as a heat map.
plot_solution = exact_solution(x_plot, t_plot)
plot_heatmap(x_plot, t_plot, plot_solution, grid_shape=(200, 200))

In [None]:
# Experiment settings; this dictionary is stored alongside the results later on.
settings = {}
settings['seed'] = 11373
torch.manual_seed(settings['seed'])  # seed before sampling the noise below
# NOTE(review): the grids built earlier use shapes (20, 20) and (39, 39); confirm that
# these two bookkeeping values still describe them.
settings['n_train_points'] = 20
settings['n_val_points'] = 80

settings['noise_level'] = 0.1

# Noisy observations of the exact solution on the data grid
input_data = create_noisy_data(settings['noise_level'], exact_soln)

plot_data(x_train_d, t_train_d, input_data, (20, 20))

settings['start_learning_rate'] = 0.003
learning_rate = settings['start_learning_rate']
epochs = 20_000
# scalar_training reads settings['epochs']; without this entry it fails with a KeyError.
settings['epochs'] = epochs

def custom_color_normalize(value):
    """Raise `value` to the 80th power."""
    return value**80

# 20 weights from 1 - 80**-2 (just below 1) down to 1 - 80**0 = 0, clustered near 1.
alphas = 1-torch.logspace(start=-2, end=0.0, steps=20, base=80)

# Stored on the CPU so the settings dictionary does not hold a GPU tensor.
settings['alphas'] = alphas.cpu()

In [None]:
def scalar_training(alphas, input_data, x_train_d, t_train_d, x_train_p, t_train_p, x_val, t_val):
    """
    Train one PINN per weight in `alphas` on alpha * data_loss + (1 - alpha) * physics_loss.

    Args:
        alphas: iterable of loss weights; one model is trained for each entry.
        input_data: target values for the data loss.
        x_train_d, t_train_d: coordinates at which `input_data` is given.
        x_train_p, t_train_p: collocation coordinates for the physics loss.
        x_val, t_val: coordinates for the validation loss.

    Returns:
        Tuple (L_p, L_d, LR, L_VAL, models_trained) of lists with one entry per alpha:
        physics-loss, data-loss, learning-rate and validation-loss records as NumPy
        arrays, and the trained models.
    """
    L_p = []
    L_d = []
    L_VAL = []
    LR = []

    models_trained = []
    for i, alpha in enumerate(alphas):
        print("i:", i, "alpha:", alpha)

        # Each alpha needs its own scalarised loss.
        loss_fn = partial(total_loss, data=input_data, x_data=x_train_d, t_data=t_train_d, x_physics=x_train_p, t_physics=t_train_p, alpha=alpha)

        # Re-seed before building the model, presumably so every alpha starts from the
        # same initial weights.
        torch.manual_seed(7245)
        model = PINN(2, 4, 50, 1, activation="fourier")

        # The monitor receives the three losses it should track. physics_loss needs both
        # coordinate tensors; the previous `x=train_points` referred to an undefined name.
        callbacks = [AllDataMonitor(
            partial(data_loss, data=input_data, x=x_train_d, t=t_train_d),
            partial(physics_loss, x=x_train_p, t=t_train_p),
            partial(val_loss, x=x_val, t=t_val))]

        trained_model = train_model(
            model=model,
            loss_fn=loss_fn,
            optimizer_fn=partial(torch.optim.Adam, lr=settings['start_learning_rate']),
            # fall back to the module-level `epochs` when the settings entry is absent
            max_epochs=settings.get('epochs', epochs),
            lr_decay=2e-2,
            epoch_callbacks=callbacks)

        # NOTE(review): l_p, l_d, lr and l_val are not assigned anywhere in this function
        # (nor in any visible cell), so the lines below raise NameError as written. They
        # presumably should be read from the AllDataMonitor instance in `callbacks[0]`
        # after training — TODO confirm its attribute names and assign them here.
        L_p.append(np.array(l_p))
        L_d.append(np.array(l_d))
        LR.append(np.array(lr))
        L_VAL.append(np.array(l_val))
        models_trained.append(trained_model)

    return L_p, L_d, LR, L_VAL, models_trained

### Training with exact solution

In [None]:
# Run the alpha sweep: one model per entry of `alphas`, trained on the noisy `input_data`.
Loss_p, Loss_d, LR, Loss_VAL, models_trained = scalar_training(alphas, input_data, x_train_d, t_train_d, x_train_p, t_train_p, x_val, t_val)

In [None]:
# Name under which this run's folder and result dictionary are stored.
run_name = 'heat_L1_k25_test'

# Everything needed to reproduce the plots later: settings, the noisy inputs and the
# per-alpha records returned by scalar_training.
result_dict = {
    "settings" : settings,
    "input_data": input_data.detach().cpu().numpy(),
    "loss_data": Loss_d,
    "loss_physics": Loss_p,
    "LR": LR,
    "loss_val": Loss_VAL
}

# presumably creates (or returns) an output folder for the run — verify helper semantics
path = create_run_folder(run_name)
save_dictionary(path, run_name, result_dict)
save_models(path, models_trained)

In [None]:
# Directory (relative to the working directory) that receives the saved checkpoints.
save_dir = "heat_equation/trained_models"
# TODO: Add a conditional statement to make sure that we are not overwriting the directory.
os.makedirs(save_dir, exist_ok=True)


def save_model(model, filename):
    """
    Serialise `model` with torch.save to `filename` inside `save_dir`.

    Args:
        model: object handed to torch.save (callers in this notebook pass a state_dict).
        filename: file name, joined onto the module-level `save_dir`.
    """
    torch.save(model, os.path.join(save_dir, filename))


In [None]:
def weighted_training(alphas, input_data):
    """
    Train one PINN per weight in `alphas` and record final losses and training curves.

    Relies on the module-level grids `x_train_d`, `t_train_d`, `x_train_p`, `t_train_p`
    and on the module-level `epochs` and `learning_rate`.

    Args:
        alphas: 1D tensor of loss weights; one model is trained for each entry.
        input_data: target values for the data loss.

    Returns:
        Tuple (L_p, L_d, loss_train_evolution): the physics loss and data loss of each
        trained model (tensors shaped like `alphas`) and an array of shape
        (len(alphas), epochs) with the training-loss history of every run.
    """
    #TODO: output error_true_evolution
    #TODO: include a true-error monitor (so far only written for one input dimension)

    L_p = torch.zeros_like(alphas) #tensor to save physics loss for each \alpha
    L_d = torch.zeros_like(alphas) #tensor to save data loss for each \alpha

    loss_train_evolution = np.zeros((len(alphas), epochs))  #array to save training evolution for each \alpha

    for i, alpha in enumerate(alphas):

        # For each alpha we need a loss function with a different alpha. total_loss takes
        # separate data and physics grids (x_data/t_data/x_physics/t_physics); it has no
        # `x`/`t` parameters, and no `x_train`/`t_train` is defined in the visible cells.
        # NOTE(review): the module-level data and physics grids are assumed to be the
        # intended inputs — confirm.
        loss_fn = partial(total_loss, data=input_data, x_data=x_train_d, t_data=t_train_d, x_physics=x_train_p, t_physics=t_train_p, alpha=alpha)
        callbacks = [TrainLossMonitor()]
        model = PINN(2, 4, 6, 1)
        # The learning rate is bound into the optimizer factory, the same way
        # scalar_training calls train_model.
        trained_model = train_model(
            model=model,
            loss_fn=loss_fn,
            max_epochs=epochs,
            optimizer_fn=partial(torch.optim.Adam, lr=learning_rate),
            epoch_callbacks=callbacks)

        loss_train_evolution[i, :] = callbacks[0].train_loss_history

        #calculating the physics losses and data losses from trained model with a loss function dependent on α
        l_p = physics_loss(trained_model, x_train_p, t_train_p)
        l_d = data_loss(trained_model, input_data, x_train_d, t_train_d)

        # Store plain values; keeping the autograd graph of every run alive is not needed.
        L_p[i] = l_p.detach()
        L_d[i] = l_d.detach()

        model_filename = f"model_heat{i}.pth"  #TODO: Saving the ith model does not give info about α. We need to save by α. So far, it is causing trouble.
        save_model(trained_model.state_dict(), model_filename)

    return L_p, L_d, loss_train_evolution


In [None]:
# Weights for the sweep: 20 evenly spaced values on [0, 0.95] followed by 50 on [0.95, 1],
# i.e. a finer spacing close to 1. Both linspace calls include their endpoints, so
# alpha = 0.95 appears twice in the concatenated tensor.
alphas1, alphas2 = torch.linspace(0, 0.95, 20), torch.linspace(0.95, 1, 50)
alphas_cat3 = torch.cat((alphas1, alphas2))

In [None]:
# Run the weighted sweep over the concatenated alpha grid; targets are passed as a column.
# NOTE(review): `data` is not defined in any visible cell — presumably `input_data` was
# meant; confirm before running on a fresh kernel.
loss_physics_cat4, loss_data_cat4, loss_training_evolution_cat4 = weighted_training(alphas_cat3, data.reshape(-1, 1))

## Loading the saved models
The trained models have been saved to disk, and we now load them again. To do so, we first instantiate a model with the same architecture and then load the saved weights into it. After loading, we want to try the following:
* Plot the Euclidean norm of the difference between the weights of the models belonging to consecutive values of α.
* Maybe run BGD starting from the model of a particular α and see whether the loss can be reduced further.

In [None]:
#current_dir = os.getcwd()
#print(current_dir+"/heat_equation/trained_models")

In [None]:
def euclidean_norm(parameters1, parameters2):
    """
    Euclidean distance between two parameter dictionaries.

    Args:
        parameters1: mapping from parameter name to tensor (e.g. a state_dict).
        parameters2: mapping containing a same-shaped tensor for every name in `parameters1`.

    Returns:
        0-dim tensor: square root of the summed squared differences over all entries.
        An empty `parameters1` gives a zero tensor.

    Raises:
        KeyError: if a name in `parameters1` is missing from `parameters2`.
    """
    squared_diff_sum = 0.0
    for name in parameters1:
        diff = parameters1[name] - parameters2[name]
        squared_diff_sum = squared_diff_sum + torch.sum(diff ** 2)

    # With no entries the accumulator is still a Python float, which torch.sqrt rejects.
    if not torch.is_tensor(squared_diff_sum):
        return torch.tensor(0.0)
    return torch.sqrt(squared_diff_sum)

In [None]:
save_dir = "heat_equation/trained_models"  # directory containing trained models

# Distance between the weights of models trained with consecutive alphas.
norm_list = []
for i in range(len(alphas_cat3)-1):
    # Same architecture as used in weighted_training, so the saved state_dicts fit.
    model = PINN(2, 4, 6, 1)
    model_next = PINN(2, 4, 6, 1)

    model.load_state_dict(torch.load(os.path.join(os.getcwd(), save_dir, f"model_heat{i}.pth")))
    model_next.load_state_dict(torch.load(os.path.join(os.getcwd(), save_dir, f"model_heat{i+1}.pth")))

    w = model.state_dict()
    w_next = model_next.state_dict()

    # Store plain floats: matplotlib cannot plot a list of (possibly CUDA) tensors.
    norm_list.append(euclidean_norm(w, w_next).item())
    

In [None]:
# Sanity check: the loop above appends one distance per pair of consecutive alphas.
len(norm_list)

In [None]:
# matplotlib cannot consume CUDA tensors, so convert both series to host-side values first.
alpha_values = alphas_cat3[:-1].detach().cpu().numpy()
norm_values = [float(norm) for norm in norm_list]

plt.plot(alpha_values, norm_values)
plt.xlabel(r"$\alpha \in [0, 1]$")
plt.ylabel(r"$||w_i - w_{i+1}||_2$")
plt.title(r"$||w_i - w_{i+1}||_2$ Vs α")
plt.savefig("heat_norms.png")

In [None]:
# Same plot restricted to the entries from index 30 onwards.
# matplotlib cannot consume CUDA tensors, so convert both series to host-side values first.
alpha_tail = alphas_cat3[30:-1].detach().cpu().numpy()
norm_tail = [float(norm) for norm in norm_list[30:]]

plt.plot(alpha_tail, norm_tail)
plt.xlabel(r"$\alpha \in [0, 1]$")
plt.ylabel(r"$||w_i - w_{i+1}||_2$")
plt.title(r"$||w_i - w_{i+1}||_2$ Vs α")
plt.savefig("heat_norms1.png")