In [1]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import scienceplots
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.special import gamma
from torch import optim
from torch.autograd import Variable
from fkan.torch import FractionalJacobiNeuralBlock as fJNB

In [2]:
plt.style.use("science")
mpl.use("pgf")

plt.rcParams.update(
    {"text.usetex": True, "pgf.preamble": r"\usepackage{amssymb} \usepackage{amsmath}"}
)

In [3]:
def dy_dx(y, x):
    return torch.autograd.grad(
        y, x, grad_outputs=torch.ones_like(y), create_graph=True
    )[0]


def fracmatrix(N, b, alpha):
    A = torch.zeros((N, N))
    t = torch.linspace(0, b, N)
    dt = t[1] - t[0]

    for i in range(1, N):
        A[i, : i + 1] = fracweights(1 + i, dt, alpha)

    return t, A


def fracweights(n, dt, alpha):
    start = float(n - 1)
    mu = torch.zeros(n + 1)
    p = (torch.arange(start, -1.0, -1) * dt) ** (1 - alpha)
    mu[1:n] = (p[:-1] - p[1:]) / (gamma(2 - alpha) * dt)
    w = mu[:-1] - mu[1:]
    return w


def closure():
    loss = get_loss(x)
    optimizer.zero_grad()
    loss.backward()
    return loss

In [5]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        self.jacobies = []
        self.hiddens = []
        self.denses = []

        n_hidden = 10
        for i in range(1, 7):
            act = fJNB(i)
            # act = nn.Tanh()
            self.jacobies.append(act)
            self.hiddens.append(nn.Linear(1, n_hidden))

        self.aggregate = nn.Linear(n_hidden * len(self.jacobies), n_hidden)

        for i in range(3):
            self.denses.append(nn.Linear(n_hidden, n_hidden))

        self.output = nn.Linear(n_hidden, 1)

    def forward(self, x):
        acts = []
        for hidden, jacobi in zip(self.hiddens, self.jacobies):
            q = hidden(x)
            acts.append(jacobi(q))
        h = torch.cat(acts, dim=1)
        agg = self.aggregate(h)
        for i in range(len(self.denses)):
            agg = self.denses[i](agg)
        output = self.output(agg)
        return output

In [6]:
alpha = 0.3

domain = 0, 1

n_discretization = 300 * (domain[1] - domain[0])

x, FM = fracmatrix(n_discretization, domain[1], alpha)
FM = FM.clone()
x = Variable(x.detach().clone(), requires_grad=True).reshape(-1, 1)

mlp = Model()

optimizer = optim.LBFGS(list(mlp.parameters()), lr=0.05)

In [7]:
def get_loss(x, ret_res=False):
    y = mlp(x)
    t = x - 1
    y_t = mlp(t)
    fracdiff = FM @ y
    residual = (
        fracdiff
        - y_t
        + y
        - 1
        + 3 * x
        - 3 * x**2
        - ((2000 * x**2.7) / (1071 * gamma(0.7)))
    )
    boundary1 = y[0] - 0

    loss = 1e4 * ((residual**2).mean() + boundary1**2)
    return (loss, residual) if ret_res else loss

In [None]:
losses = []
for i in range(30):
    loss = get_loss(x)
    optimizer.step(closure)
    losses.append(loss.detach().numpy())
    if i % 2 == 0:
        print("Epoch %3d: Current loss: %.2e" % (i, losses[-1]))

Epoch   0: Current loss: 1.26e+04
Epoch   2: Current loss: 1.23e+03
Epoch   4: Current loss: 2.19e+01
Epoch   6: Current loss: 4.66e+00
Epoch   8: Current loss: 4.28e+00
Epoch  10: Current loss: 3.38e+00
Epoch  12: Current loss: 1.50e+00
Epoch  14: Current loss: 1.39e+00
Epoch  16: Current loss: 1.39e+00


In [None]:
fig, axs = plt.subplots(1, 3, figsize=(15, 4.5))

x_test = torch.linspace(0, 1, x.shape[0] + 1)[1:].reshape(-1, 1)
loss, res = get_loss(x_test, ret_res=True)
predict = mlp.forward(x_test).detach().numpy().flatten()
exact = (x_test**3).detach().numpy().flatten()

axs[0].plot(x_test, exact, "g-", label="exact", lw=2)
axs[0].plot(
    [*x_test[::19], x_test[-1]],
    [*predict[::19], predict[-1]],
    "r*",
    label="Prediction",
    lw=2,
)

axs[1].plot(x_test, predict - exact, "c", label="Loss")
axs[2].plot(x_test, res.detach().numpy(), "c", label="Res")
titles = [r"$\chi(\tau)$", r"$\mathfrak{R}(\tau)$", r"Network $\mathfrak{R}(\tau)$"]
axs[0].legend()

for i in range(1, 3):
    axs[i].yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True))
    axs[i].yaxis.get_major_formatter().set_scientific(True)
    axs[i].yaxis.get_major_formatter().set_powerlimits((0, 0))


for ax, title in zip(axs, titles):
    ax.set_ylabel(title)
    ax.set_xlabel(r"$\tau$")


fig.savefig(
    "fractional-delay.pdf",
    bbox_inches="tight",
    pad_inches=0,
)