In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from rkan.torch import JacobiRKAN, PadeRKAN
from sklearn.model_selection import train_test_split
from torch import nn, optim

In [None]:
# Closure function
def closure():
    """Recompute the training loss for ``optimizer.step`` (LBFGS may call it several times per step).

    Reads the module-level names ``optimizer``, ``mlp``, ``X_train``,
    ``y_train`` and ``criterion``. The training cell rebinds ``mlp`` and
    ``optimizer`` for every run, so those names must stay global.

    Returns:
        The scalar loss tensor. ``backward()`` has already populated the
        gradients of the model parameters.
    """
    optimizer.zero_grad()
    # The data are constants for the regression. Detaching them keeps backward()
    # confined to the model, so each call builds and frees only its own graph.
    # No retain_graph=True workaround is needed, even if the data tensors were
    # derived from something with requires_grad=True.
    outputs = mlp(X_train.detach())
    loss = criterion(outputs, y_train.detach())
    loss.backward()
    return loss

In [None]:
# Target functions to regress; each is applied elementwise to the input tensor.
# The list index is the function column written to activation_comparison.csv.
Fs = [
    lambda x: x / (1 + x**2),
    lambda x: 1 / (1 + x**2),
    lambda x: torch.exp(-(x**2)),
]
# Activation factories: each call builds a new activation module from the
# integer q. The training loop sweeps q over range(2, 7); presumably q is the
# rational/polynomial degree (TODO confirm against the rkan docs).
# The list index is the activation column written to activation_comparison.csv.
# NOTE(review): an entry `lambda q: fJNB(q)` was dropped because `fJNB` is never
# imported or defined in this notebook, so the training loop raised NameError
# on it. It presumably referred to a fractional Jacobi block (fkan package?).
# TODO confirm, then re-add it here together with its import.
Activations = [
    lambda q: JacobiRKAN(q),
    lambda q: PadeRKAN(q, 2),
    lambda q: PadeRKAN(q, 3),
    lambda q: PadeRKAN(q, 4),
    lambda q: PadeRKAN(q, 5),
    lambda q: PadeRKAN(q, 6),
]

In [None]:
# 300 inputs drawn uniformly from [a, b), shape (300, 1).
# Seeded so that Restart & Run All reproduces the data and model initialisations.
# The inputs are plain data (no requires_grad), so targets computed from them
# carry no autograd graph into the training closure.
torch.manual_seed(0)
a, b = -10, 10
X = a + (b - a) * torch.rand(300, 1)

In [None]:
# Mean-squared-error loss used both by the LBFGS closure and by the evaluation step.
criterion = nn.MSELoss(reduction="mean")

In [None]:
# Experiment grid. Each run appends one CSV row:
# trial, q, activation index, function index, train loss, test loss.
RESULTS_CSV = "activation_comparison.csv"
N_TRIALS = 5
DEGREES = range(2, 7)
LBFGS_STEPS = 9  # outer LBFGS steps per model; each step may evaluate the closure several times

# Open the results file once and let `with` guarantee it is closed.
# Previously a new, never-closed handle was opened for every row.
with open(RESULTS_CSV, "a") as results_file:
    for trial in range(N_TRIALS):
        for fn, f in enumerate(Fs):
            # Targets are constants for the regression, so build them without autograd.
            with torch.no_grad():
                y = f(X)
            # Fixed random_state: every trial reuses the same train/test split.
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.33, random_state=0
            )

            for q in DEGREES:
                for act, Activation in enumerate(Activations):
                    # `mlp` and `optimizer` must remain module-level names: `closure` reads them.
                    mlp = nn.Sequential(
                        nn.Linear(1, 10),
                        Activation(q),
                        nn.Linear(10, 10),
                        Activation(q),
                        nn.Linear(10, 1),
                    )
                    optimizer = optim.LBFGS(list(mlp.parameters()), lr=0.001)
                    for _ in range(LBFGS_STEPS):
                        optimizer.step(closure)

                    # Evaluation only: no autograd graph is needed.
                    with torch.no_grad():
                        train_loss = criterion(mlp(X_train), y_train).item()
                        test_loss = criterion(mlp(X_test), y_test).item()

                    row = "%d,%d,%d,%d,%.3e,%.3e" % (
                        trial, q, act, fn, train_loss, test_loss
                    )
                    # Flush per row so partial results survive an interrupted run.
                    print(row, file=results_file, flush=True)
                    print(row)