In [None]:
from scipy.special import spherical_jn, spherical_yn

from scipy import special

import torch
from torch.autograd import Function
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter

import matplotlib.pyplot as plt
import numpy as np


Here we define torch autograd Functions for the spherical Bessel functions, giving both the forward pass and the backward pass (the gradient). Because the equations for the scattering coefficients involve the derivatives of the spherical Bessel functions, we also need their second derivatives. These can be found starting from the recurrence relations,

$$
f_n^\prime(z) = f_{n-1}(z) - \frac{n+1}{z}f_n(z),
$$
and,
$$
f_n^\prime(z) = - f_{n+1}(z) + \frac{n}{z}f_n(z),
$$

where $f_n(z)$ is any spherical Bessel function ($j_n$, $y_n$ or $h_n^{(1)}$). For later use, first substitute $n \to n+1$ into the first equation to get,

$$
f_{n+1}^\prime(z) = f_{n}(z) - \frac{n+2}{z}f_{n+1}(z)
$$





Differentiating the second equation gives,

$$
\frac{d^2}{dz^2}f_n(z) = - \frac{d}{dz}f_{n+1}(z) + n\frac{d}{dz} \left( \frac{f_n(z)}{z} \right),
$$

$$
f_n^{\prime\prime}(z) = - f_{n+1}^\prime(z) + n \left( \frac{f_n^\prime(z)z-f_n(z)}{z^2} \right).
$$
Multiplying through by $z^2$ rearranges this to

$$
z^2f_n^{\prime\prime}(z) = -z^2f_{n+1}^\prime(z) + nzf_n^\prime(z) - nf_n(z),
$$
and then substitute the shifted first equation (for $f_{n+1}^\prime$) and the second equation (for $f_n^\prime$) to get,
$$
z^2f_n^{\prime\prime}(z) = -z^2\left(f_{n}(z) - \frac{n+2}{z}f_{n+1}(z)\right) + nz \left( - f_{n+1}(z) + \frac{n}{z}f_n(z) \right) - nf_n(z).
$$
Collecting the $f_n$ and $f_{n+1}$ terms gives the equation for $f_n^{\prime\prime}(z)$,

$$
z^2f_n^{\prime\prime}(z) = f_n(z) \left ( -z^2 + n^2 - n\right) + f_{n+1}(z) \left ( z(n+2) -nz \right),
$$

$$
f_n^{\prime\prime}(z) = \frac{1}{z^2} \left [ (n^2 - n - z^2)f_n(z) + 2z f_{n+1}(z)     \right ].
$$









In [None]:
class torch_jn(Function):
    """Spherical Bessel function of the first kind, j_n(z), as a torch autograd Function.

    Call as ``torch_jn.apply(z, n)`` with ``z`` a tensor and ``n`` a non-negative int
    (``n`` is treated as a constant and receives no gradient).

    Forward evaluates ``scipy.special.spherical_jn`` on a detached CPU copy of ``z``
    and returns a ``torch.complex64`` tensor on ``z``'s device.

    Backward uses the recurrence (DLMF 10.51.2)

        (2n+1) j_n'(z) = n j_{n-1}(z) - (n+1) j_{n+1}(z),    j_0'(z) = -j_1(z),

    evaluated with ``torch_jn.apply`` itself. The returned gradient therefore carries an
    autograd graph when ``create_graph=True``, so second and higher derivatives work.
    The formula also has no 1/z factor.
    """

    @staticmethod
    def forward(ctx, input, n):
        # Save the original (not detached) input: backward rebuilds the graph through it.
        ctx.save_for_backward(input)
        ctx.n = n  # n is not learnable, so it is stored on ctx rather than saved
        values = spherical_jn(n, input.detach().cpu().numpy())
        # torch.as_tensor (unlike torch.from_numpy) also accepts the numpy scalar that
        # scipy returns for a 0-d input.
        return torch.as_tensor(values, dtype=torch.complex64, device=input.device)

    @staticmethod
    def backward(ctx, grad_output):
        (z,) = ctx.saved_tensors
        n = ctx.n
        if n == 0:
            derivative = -torch_jn.apply(z, 1)
        else:
            derivative = (n * torch_jn.apply(z, n - 1)
                          - (n + 1) * torch_jn.apply(z, n + 1)) / (2 * n + 1)
        # PyTorch's complex-autograd convention: for a holomorphic f the
        # vector-Jacobian product is grad_output * conj(f'(z)).
        return grad_output * derivative.conj(), None
    

class torch_jn_der(Function):
    """First derivative of the spherical Bessel function of the first kind, j_n'(z).

    Call as ``torch_jn_der.apply(z, n)`` with ``n`` a non-negative int (no gradient for n).

    Forward evaluates the recurrence (DLMF 10.51.2)

        (2n+1) j_n'(z) = n j_{n-1}(z) - (n+1) j_{n+1}(z),    j_0'(z) = -j_1(z),

    with ``scipy.special.spherical_jn``. Unlike the form j_{n-1} - (n+1) j_n / z, it has
    no 1/z factor, so it stays finite at z = 0. The result is a ``torch.complex64``
    tensor on ``z``'s device.

    Backward differentiates the same recurrence,

        (2n+1) j_n''(z) = n j_{n-1}'(z) - (n+1) j_{n+1}'(z),    j_0''(z) = -j_1'(z),

    using ``torch_jn_der.apply`` so the gradient is itself differentiable.
    This is equivalent to j_n'' = [(n^2 - n - z^2) j_n + 2 z j_{n+1}] / z^2
    but avoids the 1/z^2 singularity.
    """

    @staticmethod
    def forward(ctx, input, n):
        # Save the original (not detached) input: backward rebuilds the graph through it.
        ctx.save_for_backward(input)
        ctx.n = n  # n is not learnable, so it is stored on ctx rather than saved
        z = input.detach().cpu().numpy()
        if n == 0:
            values = -spherical_jn(1, z)
        else:
            values = (n * spherical_jn(n - 1, z)
                      - (n + 1) * spherical_jn(n + 1, z)) / (2 * n + 1)
        # torch.as_tensor also accepts the numpy scalar produced for a 0-d input.
        return torch.as_tensor(values, dtype=torch.complex64, device=input.device)

    @staticmethod
    def backward(ctx, grad_output):
        (z,) = ctx.saved_tensors
        n = ctx.n
        if n == 0:
            second = -torch_jn_der.apply(z, 1)
        else:
            second = (n * torch_jn_der.apply(z, n - 1)
                      - (n + 1) * torch_jn_der.apply(z, n + 1)) / (2 * n + 1)
        # Holomorphic VJP in PyTorch's complex convention: grad_output * conj(f'(z)).
        return grad_output * second.conj(), None


class torch_yn(Function):
    """Spherical Bessel function of the second kind, y_n(z), as a torch autograd Function.

    Call as ``torch_yn.apply(z, n)`` with ``z`` a tensor and ``n`` a non-negative int
    (``n`` is treated as a constant and receives no gradient).

    Forward evaluates ``scipy.special.spherical_yn`` on a detached CPU copy of ``z``
    and returns a ``torch.complex64`` tensor on ``z``'s device.

    Backward uses the recurrence (DLMF 10.51.2)

        (2n+1) y_n'(z) = n y_{n-1}(z) - (n+1) y_{n+1}(z),    y_0'(z) = -y_1(z),

    evaluated with ``torch_yn.apply`` itself, so the returned gradient carries an
    autograd graph when ``create_graph=True`` and higher derivatives work.
    """

    @staticmethod
    def forward(ctx, input, n):
        # Save the original (not detached) input: backward rebuilds the graph through it.
        ctx.save_for_backward(input)
        ctx.n = n  # n is not learnable, so it is stored on ctx rather than saved
        values = spherical_yn(n, input.detach().cpu().numpy())
        # torch.as_tensor (unlike torch.from_numpy) also accepts the numpy scalar that
        # scipy returns for a 0-d input.
        return torch.as_tensor(values, dtype=torch.complex64, device=input.device)

    @staticmethod
    def backward(ctx, grad_output):
        (z,) = ctx.saved_tensors
        n = ctx.n
        if n == 0:
            derivative = -torch_yn.apply(z, 1)
        else:
            derivative = (n * torch_yn.apply(z, n - 1)
                          - (n + 1) * torch_yn.apply(z, n + 1)) / (2 * n + 1)
        # PyTorch's complex-autograd convention: for a holomorphic f the
        # vector-Jacobian product is grad_output * conj(f'(z)).
        return grad_output * derivative.conj(), None
    

class torch_yn_der(Function):
    """First derivative of the spherical Bessel function of the second kind, y_n'(z).

    Call as ``torch_yn_der.apply(z, n)`` with ``n`` a non-negative int (no gradient for n).

    Forward evaluates ``scipy.special.spherical_yn(n, z, derivative=True)`` on a detached
    CPU copy of ``z`` and returns a ``torch.complex64`` tensor on ``z``'s device.

    Backward differentiates the recurrence (DLMF 10.51.2) once more,

        (2n+1) y_n''(z) = n y_{n-1}'(z) - (n+1) y_{n+1}'(z),    y_0''(z) = -y_1'(z),

    using ``torch_yn_der.apply`` so the gradient is itself differentiable.
    This is equivalent to y_n'' = [(n^2 - n - z^2) y_n + 2 z y_{n+1}] / z^2.
    """

    @staticmethod
    def forward(ctx, input, n):
        # Save the original (not detached) input: backward rebuilds the graph through it.
        ctx.save_for_backward(input)
        ctx.n = n  # n is not learnable, so it is stored on ctx rather than saved
        values = spherical_yn(n, input.detach().cpu().numpy(), derivative=True)
        # torch.as_tensor also accepts the numpy scalar produced for a 0-d input.
        return torch.as_tensor(values, dtype=torch.complex64, device=input.device)

    @staticmethod
    def backward(ctx, grad_output):
        (z,) = ctx.saved_tensors
        n = ctx.n
        if n == 0:
            second = -torch_yn_der.apply(z, 1)
        else:
            second = (n * torch_yn_der.apply(z, n - 1)
                      - (n + 1) * torch_yn_der.apply(z, n + 1)) / (2 * n + 1)
        # Holomorphic VJP in PyTorch's complex convention: grad_output * conj(f'(z)).
        return grad_output * second.conj(), None

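Using these functions we define the spherical Hankel function of the first kind, $h_n^{(1)}(z) = j_n(z) + i\,y_n(z)$, and the Riccati–Bessel functions $\psi_n(z) = z j_n(z)$, $\chi_n(z) = -z y_n(z)$ and $\xi_n(z) = z h_n^{(1)}(z)$, together with their derivatives: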

In [None]:
# Callable aliases for the autograd Functions defined earlier; each is called as f(z, n).
sph_jn = torch_jn.apply
sph_yn = torch_yn.apply

sph_jn_der = torch_jn_der.apply
sph_yn_der = torch_yn_der.apply


def sph_h1n(z, n):
    """Spherical Hankel function of the first kind: h_n^(1)(z) = j_n(z) + i y_n(z)."""
    first = sph_jn(z, n)
    second = sph_yn(z, n)
    return first + 1j * second


def sph_h1n_der(z, n):
    """Derivative of h_n^(1): j_n'(z) + i y_n'(z)."""
    first_der = sph_jn_der(z, n)
    second_der = sph_yn_der(z, n)
    return first_der + 1j * second_der


def psi(z, n):
    """Riccati-Bessel function psi_n(z) = z j_n(z)."""
    bessel = sph_jn(z, n)
    return z * bessel


def chi(z, n):
    """Riccati-Bessel function chi_n(z) = -z y_n(z)."""
    neumann = sph_yn(z, n)
    return -z * neumann


def xi(z, n):
    """Riccati-Bessel function xi_n(z) = z h_n^(1)(z)."""
    hankel = sph_h1n(z, n)
    return z * hankel


def psi_der(z, n):
    """psi_n'(z) = j_n(z) + z j_n'(z) (product rule on z j_n(z))."""
    bessel = sph_jn(z, n)
    bessel_der = sph_jn_der(z, n)
    return bessel + z * bessel_der


def chi_der(z, n):
    """chi_n'(z) = -y_n(z) - z y_n'(z) (product rule on -z y_n(z))."""
    neumann = sph_yn(z, n)
    neumann_der = sph_yn_der(z, n)
    return -neumann - z * neumann_der


def xi_der(z, n):
    """xi_n'(z) = h_n^(1)(z) + z h_n^(1)'(z) (product rule on z h_n^(1)(z))."""
    hankel = sph_h1n(z, n)
    hankel_der = sph_h1n_der(z, n)
    return hankel + z * hankel_der

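From these we define the scattering coefficients $a_n$ and $b_n$ of a coated sphere, together with the auxiliary coefficients $A_n$ and $B_n$ (the coated-sphere solution of Bohren & Huffman, Sec. 8.1). Here $x = ka$ and $y = kb$ are the size parameters of the core and the shell: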

In [None]:
def An(x, n, m1, m2):
    """Auxiliary coefficient A_n of the coated-sphere solution, used by `an`.

    x is the core size parameter (ka). m1 and m2 are presumably the relative
    refractive indices of core and shell (m1 only ever multiplies x); confirm.

        A_n = [m2 psi_n(m2 x) psi_n'(m1 x) - m1 psi_n'(m2 x) psi_n(m1 x)]
              / [m2 chi_n(m2 x) psi_n'(m1 x) - m1 chi_n'(m2 x) psi_n(m1 x)]
    """
    core_arg = m1 * x
    shell_arg = m2 * x
    top = m2 * psi(shell_arg, n) * psi_der(core_arg, n) - m1 * psi_der(shell_arg, n) * psi(core_arg, n)
    bottom = m2 * chi(shell_arg, n) * psi_der(core_arg, n) - m1 * chi_der(shell_arg, n) * psi(core_arg, n)
    return top / bottom

def Bn(x, n, m1, m2):
    """Auxiliary coefficient B_n of the coated-sphere solution, used by `bn`.

    x is the core size parameter (ka). m1 and m2 are presumably the relative
    refractive indices of core and shell; confirm.

        B_n = [m2 psi_n(m1 x) psi_n'(m2 x) - m1 psi_n(m2 x) psi_n'(m1 x)]
              / [m2 chi_n'(m2 x) psi_n(m1 x) - m1 psi_n'(m1 x) chi_n(m2 x)]
    """
    core_arg = m1 * x
    shell_arg = m2 * x
    top = m2 * psi(core_arg, n) * psi_der(shell_arg, n) - m1 * psi(shell_arg, n) * psi_der(core_arg, n)
    bottom = m2 * chi_der(shell_arg, n) * psi(core_arg, n) - m1 * psi_der(core_arg, n) * chi(shell_arg, n)
    return top / bottom

def an(x, y, n, m1, m2):
    """Scattering coefficient a_n of a coated sphere.

    x = ka (core size parameter), y = kb (shell size parameter). m1 and m2 are
    presumably the relative refractive indices of core and shell; confirm.

        a_n = [psi_n(y) S'  - m2 psi_n'(y) S] / [xi_n(y) S' - m2 xi_n'(y) S]

    with S = psi_n(m2 y) - A_n chi_n(m2 y) and S' = psi_n'(m2 y) - A_n chi_n'(m2 y).
    """
    # A_n and the shell combinations appear in both numerator and denominator.
    # Evaluate them once: each costs several scipy round trips and autograd nodes.
    A = An(x, n, m1, m2)
    shell_arg = m2 * y
    shell = psi(shell_arg, n) - A * chi(shell_arg, n)
    shell_der = psi_der(shell_arg, n) - A * chi_der(shell_arg, n)

    numerator = psi(y, n) * shell_der - m2 * psi_der(y, n) * shell
    denominator = xi(y, n) * shell_der - m2 * xi_der(y, n) * shell
    return numerator / denominator

def bn(x, y, n, m1, m2):
    """Scattering coefficient b_n of a coated sphere.

    x = ka (core size parameter), y = kb (shell size parameter). m1 and m2 are
    presumably the relative refractive indices of core and shell; confirm.

        b_n = [m2 psi_n(y) T' - psi_n'(y) T] / [m2 xi_n(y) T' - xi_n'(y) T]

    with T = psi_n(m2 y) - B_n chi_n(m2 y) and T' = psi_n'(m2 y) - B_n chi_n'(m2 y).
    """
    # B_n and the shell combinations are shared by numerator and denominator.
    # Evaluate them once instead of recomputing them per occurrence.
    B = Bn(x, n, m1, m2)
    shell_arg = m2 * y
    shell = psi(shell_arg, n) - B * chi(shell_arg, n)
    shell_der = psi_der(shell_arg, n) - B * chi_der(shell_arg, n)

    numerator = m2 * psi(y, n) * shell_der - psi_der(y, n) * shell
    denominator = m2 * xi(y, n) * shell_der - xi_der(y, n) * shell
    return numerator / denominator

In [None]:
dtype = torch.complex64
device = torch.device("cpu")

# x = ka, where a is the core radius
# y = kb, where b is the radius of the shell
# NOTE(review): x and y take identical values here (a zero-thickness shell);
# a physical coated sphere needs y > x.
x = torch.linspace(1, 2, 3, device=device, dtype=dtype, requires_grad=True)
y = torch.linspace(1, 2, 3, device=device, dtype=dtype, requires_grad=True)

# Per-element tensors so each element of m1/m2 gets its own gradient.
m1 = torch.full(x.shape, 3.0, device=device, dtype=dtype, requires_grad=True)
m2 = torch.full(x.shape, 2.0, device=device, dtype=dtype, requires_grad=True)

n = 3

output = an(x, y, n, m1, m2)  # testing with a_n to begin with

print(output, output.shape)
print(x, x.shape)

# First derivatives of a_n w.r.t. x, y, m1 and m2 in a single call.
# Note: in `output.backward(gradient)` the `gradient` argument is the upstream vector v
# of the vector-Jacobian product, NOT the list of inputs to differentiate with respect to.
# Every operation in `an` is element-wise, so the Jacobians are diagonal and v = ones
# yields the element-wise derivatives. create_graph=True keeps the graph for 2nd derivatives.
inputs = (x, y, m1, m2)
grads = torch.autograd.grad(output, inputs, grad_outputs=torch.ones_like(output), create_graph=True)

# PyTorch's complex convention: for a holomorphic f the returned gradient is conj(df/dz),
# so conjugate to display the derivative itself.
for label, grad in zip(("ddX", "ddY", "ddm1", "ddm2"), grads):
    print(label, grad.conj())

x_grad = grads[:1]  # tuple in PyTorch's convention, as returned by torch.autograd.grad
print("ddX via autograd:", x_grad[0].conj())

# Second derivative w.r.t. x: differentiate the graph-carrying first derivative again
# (the result is again conj(d^2 a_n / dx^2) under PyTorch's convention).
x_grad_grad = torch.autograd.grad(x_grad, x, grad_outputs=torch.ones_like(x_grad[0]))
print("d2 dX2:", x_grad_grad[0].conj())

# Numerical check of d a_n / dx by central differences along the real axis.
# Values are complex64, so expect only a few significant digits of agreement.
step = 1e-3
with torch.no_grad():
    fd = (an(x + step, y, n, m1, m2) - an(x - step, y, n, m1, m2)) / (2 * step)
print("ddX finite difference:", fd)
print("max |autograd - finite difference|:", (x_grad[0].conj() - fd).abs().max().item())