In [None]:
!pip install numpy matplotlib torch pandas pykan jax scikit-learn polars

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from kan import MultKAN
import os

# Candidate symbolic primitives (not referenced anywhere else in this notebook).
lib = ['x', 'x^2', 'x^3', 'x^4', 'exp', 'log', 'sqrt', 'tanh', 'sin', 'abs']

# Hyperparameters shared by every example.
EPOCHS = 150           # number of optimizer.step calls in train_model_with_physics
LEARNING_RATE = 0.008  # passed to torch.optim.LBFGS as lr
GRIDS = 100            # passed to MultKAN as grid
SPLINE = 3             # passed to MultKAN as k

# Run on the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# Helper function to print comparison table
def print_comparison_table(t_eval_np, y_pred, y_exact_eval, example_number, output_dir, meanSe, meanAe):
    table_lines = []
    table_lines.append(f"\nComparison Table for Example {example_number}:")
    table_lines.append(f"{'t':<10}{'Network Output':<20}{'Exact Solution':<20}{'MAE':<10}{'MSE':<10}")
    table_lines.append("-" * 90)
    for t, pred, exact, mae, mse in zip(t_eval_np.flatten(), y_pred.flatten(), y_exact_eval.flatten(), np.abs(y_pred.flatten() - y_exact_eval.flatten()), (y_pred.flatten() - y_exact_eval.flatten())**2):
        table_lines.append(f"{t:<10.4f}{pred:<20.6f}{exact:<20.6f}{mae}{mse}")
    table_lines.append(f"MSE:{meanSe}")
    table_lines.append(f"MAE:{meanAe}")


    # Write table to file
    file_path = os.path.join(output_dir, f"comparison_table_example_{example_number}.txt")
    with open(file_path, "a") as f:
        f.write("\n".join(table_lines))
    print(f"Comparison table for Example {example_number} saved to {file_path}")

def save_plot(t_eval_np, y_pred, y_exact_eval, example_number, output_dir, dpi=1800):
    """Plot the KAN solution against the exact solution and save it as a PNG.

    Args:
        t_eval_np: Evaluation points (x-axis values).
        y_pred: KAN trial-solution values at ``t_eval_np``.
        y_exact_eval: Exact solution values at ``t_eval_np``.
        example_number: Example index used in the output file name.
        output_dir: Existing directory the PNG is written into.
        dpi: Output resolution passed to ``savefig``. The default keeps the
            original 1800 dpi; that is very large for a screen-sized figure, so
            pass e.g. 300 when the file is saved many times.
    """
    file_path = os.path.join(output_dir, f"example_{example_number}_solution.png")

    # Work on an explicit Figure so exactly this figure is closed, even when
    # saving fails (a failed save would otherwise leave it open).
    fig, ax = plt.subplots()
    try:
        ax.plot(t_eval_np, y_exact_eval, label='Exact Solution', color='g', marker='s', markersize=6, markerfacecolor='none', linestyle='--', linewidth=1)
        ax.plot(t_eval_np, y_pred, label='KAN Solution', color='b', marker='o', markersize=4, linestyle='-', linewidth=1)
        ax.legend()

        # Italicize the (intentionally blank) title and the axis labels.
        ax.set_title(' ', fontstyle='italic')
        ax.set_xlabel('t', fontstyle='italic')
        ax.set_ylabel('y(t)', fontstyle='italic')
        ax.grid(True)

        fig.savefig(file_path, dpi=dpi)
    finally:
        plt.close(fig)
    print(f"Plot for Example {example_number} saved to {file_path}")



def train_model_with_physics(model, epochs, optimizer, t_train, ode_loss, patience=100, *, show_plot=True, plot_loss=False):
    """Train ``model`` by minimising a physics-informed loss on ``t_train``.

    Each epoch performs one ``optimizer.step(closure)``. The closure zeroes the
    gradients, evaluates ``ode_loss(t_train, model(t_train))`` and
    back-propagates, so closure-based optimizers such as LBFGS (which may
    call it several times per step) are supported. Training stops early once
    the loss returned by ``optimizer.step`` has failed to improve on the best
    value seen for ``patience`` consecutive epochs.

    Args:
        model: Callable network mapping ``t_train`` to its output z(t). When
            ``show_plot`` is true it must also provide ``plot(scale=...)``.
        epochs: Maximum number of optimizer steps.
        optimizer: torch optimizer over the model's parameters.
        t_train: Collocation points. They need ``requires_grad=True`` if
            ``ode_loss`` differentiates with respect to them.
        ode_loss: Callable ``(t, z) -> scalar loss tensor``.
        patience: Consecutive non-improving epochs tolerated before stopping.
        show_plot: If true (the default), call ``model.plot(scale=1)`` and
            ``plt.show()`` after training.
        plot_loss: If true, also show the per-epoch training-loss curve.

    Returns:
        The same ``model`` object, trained in place.
    """
    best_loss = float('inf')
    epochs_without_improvement = 0
    losses = []

    def closure():
        optimizer.zero_grad()
        z_pred = model(t_train)
        total_loss = ode_loss(t_train, z_pred)
        total_loss.backward()
        return total_loss

    for epoch in range(epochs):
        # Read the scalar once: every .item() on a CUDA tensor forces a host/device sync.
        loss_value = optimizer.step(closure).item()

        # Record and log before the early-stopping check so the final epoch is kept.
        losses.append(loss_value)
        if epoch % 10 == 0:
            print(f"Epoch {epoch}/{epochs}, Loss: {loss_value}")

        if loss_value < best_loss:
            best_loss = loss_value
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

    if plot_loss:
        plt.figure()
        plt.plot(losses)
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.title('Training Loss Curve')
        plt.grid(True)
        plt.show()
    if show_plot:
        model.plot(scale=1)
        plt.show()
    return model

def _build_kan(architecture):
    """Return a MultKAN using the hyperparameters shared by every example.

    ``architecture`` is passed through as ``width``; grid size, spline order
    and device come from the module-level GRIDS, SPLINE and ``device``.
    """
    return MultKAN(
        width=architecture,
        grid=GRIDS,
        k=SPLINE,
        device=device,
        mult_arity=2,
        noise_scale=0.1,
        scale_base_mu=0,
        scale_base_sigma=1,
        base_fun=torch.nn.SiLU(),
        symbolic_enabled=False,  # for efficiency
        affine_trainable=True,  # update (sub)node_scale, (sub)node_bias
        grid_eps=1,  # 1 - uniform grid, 0 - percentiles
        grid_range=[-5, 5],
        seed=1107,  # same seed on every call
        sparse_init=False,
        first_init=True,
    )


def _solve_first_order_ivp(architecture, output_dir, example_number, rhs, exact_solution,
                           y0, t0, t_end, n_train, n_eval=21):
    """Fit a KAN trial solution to y'(t) = rhs(t, y), y(t0) = y0, and report its error.

    The trial solution is y(t) = y0 + (t - t0) * z(t), where z is the network
    output, so the initial condition holds by construction. The loss is the
    ODE residual ``mean((y' - rhs(t, y))**2) / 2`` on ``n_train`` uniformly
    spaced collocation points in [t0, t_end].

    Args:
        architecture: MultKAN ``width`` specification.
        output_dir: Directory the comparison table is written into.
        example_number: Index used in log messages and the table file name.
        rhs: Callable ``(t, y) -> dy/dt`` on torch tensors.
        exact_solution: Callable mapping a NumPy array of t values to y(t).
        y0, t0: Initial condition y(t0) = y0.
        t_end: Right end of the training/evaluation interval.
        n_train: Number of collocation points.
        n_eval: Number of evaluation points used for the error report.

    Returns:
        ``(t_eval_np, y_pred, y_exact_eval)`` as NumPy arrays; ``t_eval_np``
        has shape ``(n_eval, 1)``.
    """
    model = _build_kan(architecture)

    def ode_loss(t, z):
        y = y0 + (t - t0) * z
        # dy/dt = z(t) + (t - t0) * dz/dt, with dz/dt obtained from autograd.
        dz_dt = torch.autograd.grad(z, t, grad_outputs=torch.ones_like(z), create_graph=True)[0]
        dy_dt = z + (t - t0) * dz_dt
        return torch.mean((dy_dt - rhs(t, y)) ** 2) / 2

    # The collocation points need requires_grad so ode_loss can differentiate w.r.t. t.
    t_train = torch.linspace(t0, t_end, n_train, device=device).reshape(-1, 1).requires_grad_()
    optimizer = torch.optim.LBFGS(
        model.parameters(),
        lr=LEARNING_RATE,
        max_iter=20,
        tolerance_grad=1e-7,
        tolerance_change=1e-9,
        history_size=50,
    )
    model = train_model_with_physics(model, epochs=EPOCHS, optimizer=optimizer, t_train=t_train, ode_loss=ode_loss)

    # Evaluation needs no autograd graph.
    t_eval = torch.linspace(t0, t_end, n_eval, device=device).reshape(-1, 1)
    with torch.no_grad():
        y_pred = (y0 + (t_eval - t0) * model(t_eval)).cpu().numpy()
    t_eval_np = t_eval.cpu().numpy()
    y_exact_eval = exact_solution(t_eval_np)

    # True means over the evaluated points; dividing by a fixed constant
    # (200 vs. 21 actual points) mis-scaled both metrics.
    errors = y_exact_eval - y_pred
    mse = float(np.mean(errors ** 2))
    mae = float(np.mean(np.abs(errors)))
    print(f"MSE for example {example_number}: {mse}")
    print(f"MAE for example {example_number}: {mae}")
    print_comparison_table(t_eval_np, y_pred, y_exact_eval, example_number, output_dir, mse, mae)

    return t_eval_np, y_pred, y_exact_eval


def solve_example1(architecture, output_dir):
    """Example 1: y' = 2t + t^3 + t^2 p(t) - (t + p(t)) y on [0, 1], y(0) = 1.

    Here p(t) = (1 + 3t^2) / (1 + t + t^3). The exact solution is
    y(t) = (exp(-t^2/2) + t^5 + t^3 + t^2) / (t^3 + t + 1).
    Returns ``(t_eval, y_pred, y_exact)`` as NumPy arrays.
    """
    def rhs(t, y):
        p = (1 + 3 * t**2) / (1 + t + t**3)
        f_term = (t + p) * y
        g_term = 2 * t + t**3 + t**2 * p
        return g_term - f_term

    def exact(t):
        return (np.exp(-t**2 / 2) + t**5 + t**3 + t**2) / (t**3 + t + 1)

    return _solve_first_order_ivp(architecture, output_dir, 1, rhs, exact, y0=1, t0=0, t_end=1, n_train=100)


def solve_example2(architecture, output_dir):
    """Example 2: y' = cos(4t) - 2y on [0, 3], y(0) = 3.

    The exact solution is y(t) = sin(4t)/5 + cos(4t)/10 + 2.9 exp(-2t).
    Returns ``(t_eval, y_pred, y_exact)`` as NumPy arrays.
    """
    def rhs(t, y):
        return torch.cos(4 * t) - 2 * y

    def exact(t):
        return np.sin(4 * t) / 5 + np.cos(4 * t) / 10 + 2.9 * np.exp(-2 * t)

    return _solve_first_order_ivp(architecture, output_dir, 2, rhs, exact, y0=3, t0=0, t_end=3, n_train=150)


def solve_example3(architecture, output_dir):
    """Example 3: y' = y - t^2 + 1 on [0, 2], y(0) = 0.5.

    The exact solution is y(t) = (t + 1)^2 - 0.5 exp(t).
    Returns ``(t_eval, y_pred, y_exact)`` as NumPy arrays.
    """
    def rhs(t, y):
        return y - t**2 + 1

    def exact(t):
        return (t + 1)**2 - 0.5 * np.exp(t)

    return _solve_first_order_ivp(architecture, output_dir, 3, rhs, exact, y0=0.5, t0=0, t_end=2, n_train=100)

def plot_results(t1, y1, y1_exact, t2, y2, y2_exact, t3, y3, y3_exact, architecture):
    """Show the three example solutions side by side and save the figure.

    Each panel overlays the KAN solution (solid line plus markers) on the exact
    solution (dashed line plus markers). The PNG is written to the current
    working directory as ``results_architecture_<tag>.png``, where the tag joins
    the layer widths with "_" and renders nested widths like ``[50, 20]`` as
    ``50-20``.
    """
    print("Results for architecture: ", architecture)

    # (t, kan, exact, kan line fmt, kan marker colour, exact line fmt, exact marker colour)
    panels = (
        (t1, y1, y1_exact, 'b-', 'blue', 'g--', 'green'),
        (t2, y2, y2_exact, 'r-', 'red', 'g--', 'green'),
        (t3, y3, y3_exact, 'm-', 'magenta', 'c--', 'cyan'),
    )

    plt.figure(figsize=(18, 5))
    for panel_no, (t, kan, exact, kan_fmt, kan_color, exact_fmt, exact_color) in enumerate(panels, start=1):
        plt.subplot(1, 3, panel_no)
        plt.plot(t, kan, kan_fmt, label='KAN Solution')
        plt.scatter(t, kan, color=kan_color, s=15, label='KAN Points')
        plt.plot(t, exact, exact_fmt, label='Exact Solution')
        plt.scatter(t, exact, color=exact_color, s=15, label='Exact Points')
        plt.title(f'Example {panel_no}: Solution')
        plt.xlabel('t')
        plt.ylabel('y(t)')
        plt.grid(True)
        plt.legend()
    plt.tight_layout()

    arch_tag = "_".join(str(width) for width in architecture)
    for old, new in (("[", ""), ("]", ""), (",", "-"), (" ", "")):
        arch_tag = arch_tag.replace(old, new)
    filename = f"results_architecture_{arch_tag}.png"
    plt.savefig(filename, dpi=300)
    print(f"Figure saved as {filename}")

    plt.show()


def main():
    """Create ``results/`` and repeatedly solve, tabulate and plot all three examples."""
    output_dir = "results"
    os.makedirs(output_dir, exist_ok=True)

    # Other widths that have been tried:
    #   [1, 50, 20, 1], [1, 64, 32, 1], [1, [50, 20], 10, 1], [1, [50, 20], [20, 8], 1]
    architecture = [1, 14, 8, 1]

    solvers = (solve_example1, solve_example2, solve_example3)

    # NOTE(review): every solver builds its MultKAN with seed=1107, so the 200
    # repetitions presumably retrace the same run (up to GPU nondeterminism);
    # each pass overwrites the PNGs and appends another comparison table. Confirm intent.
    for _ in range(200):
        for example_number, solver in enumerate(solvers, start=1):
            print(f"Solving Example {example_number}...")
            t_eval, y_pred, y_exact = solver(architecture, output_dir)
            save_plot(t_eval, y_pred, y_exact, example_number, output_dir)

        # To also show the side-by-side summary, collect each solver's
        # (t, y, y_exact) triple and pass them, flattened, to plot_results
        # together with `architecture`.


# Entry point. Inside a Jupyter/IPython kernel ``__name__`` is also "__main__",
# so executing this cell starts the full (long-running) loop in main().
if __name__ == "__main__":
    main()
