In [1]:
from scipy.special import spherical_jn, spherical_yn

from scipy import special

import torch
from torch.autograd import Function
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter

import matplotlib.pyplot as plt
import numpy as np


In [2]:
class torch_jn(Function):
    @staticmethod
    def forward(ctx, input, n):
        input = input.detach()
        result = torch.from_numpy(spherical_jn(n, input.numpy()))
        ctx.save_for_backward(result)
        ctx.n = n #n is not learnable so can just save in ctx

        return result
        #return torch.as_tensor(result, dtype=torch.complex64)

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.detach()
        input = ctx.saved_tensors
        grad_output = grad_output.numpy()
        grad_input = torch.from_numpy(spherical_jn(ctx.n, grad_output, derivative = True))

        return grad_input, None
        #return torch.as_tensor(grad_input, dtype=torch.complex64), None


class torch_jn_der(Function):
    @staticmethod
    def forward(ctx, input, n):
        input = input.detach()
        result = torch.from_numpy(spherical_jn(n, input.numpy(), derivative = True))
        ctx.save_for_backward(result)
        ctx.n = n #n is not learnable so can just save in ctx

        return result
        #return torch.as_tensor(result, dtype=torch.complex64)

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.detach()
        input = ctx.saved_tensors
        grad_output = grad_output.numpy()
        grad_input = torch.from_numpy( (1/grad_output**2)*((ctx.n.detach().numpy()**2 - ctx.n.detach().numpy() - grad_output**2)*spherical_jn(ctx.n.detach().numpy(), grad_output) + 2*grad_output*spherical_jn(ctx.n.detach().numpy() + 1, grad_output)) )

        return grad_input, None
        #return torch.as_tensor(grad_input, dtype=torch.complex64), None


class torch_yn(Function):
    @staticmethod
    def forward(ctx, input, n):
        input = input.detach()
        result = torch.from_numpy(spherical_yn(n, input.numpy()))
        ctx.save_for_backward(result)
        ctx.n = n #n is not learnable so can just save in ctx

        return result
        #return torch.as_tensor(result, dtype=torch.complex64)

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.detach()
        input = ctx.saved_tensors
        grad_output = grad_output.numpy()
        grad_input = torch.from_numpy(spherical_yn(ctx.n, grad_output, derivative = True))

        return grad_input, None
        #return torch.as_tensor(grad_input, dtype=torch.complex64), None


class torch_yn_der(Function):
    @staticmethod
    def forward(ctx, input, n):
        input = input.detach()
        result = torch.from_numpy(spherical_yn(n, input.numpy(), derivative = True))
        ctx.save_for_backward(result)
        ctx.n = n #n is not learnable so can just save in ctx

        return result
        #return torch.as_tensor(result, dtype=torch.complex64)

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.detach()
        input = ctx.saved_tensors
        grad_output = grad_output.numpy()
        grad_input = torch.from_numpy( (1/grad_output**2)*((ctx.n.detach().numpy()**2 - ctx.n.detach().numpy() - grad_output**2)*spherical_yn(ctx.n.detach().numpy(), grad_output) + 2*grad_output*spherical_yn(ctx.n.detach().numpy() + 1, grad_output)) )
        return grad_input, None
        #return torch.as_tensor(grad_input, dtype=torch.complex64), None

In [45]:
# Testing shapes
n = np.array([1,2,3,4,5])
n = np.reshape(n,(5,1))

z = np.array([0.1,0.2,0.3])
print(spherical_yn(n,z))


n = np.array([1,2,3,4,5])


z = np.array([0.1,0.2,0.3])
z = np.reshape(z,(3,1))
print(spherical_yn(n,z).T)

[[-1.00498751e+02 -2.54950111e+01 -1.15999172e+01]
 [-3.00501248e+03 -3.77524834e+02 -1.12814717e+02]
 [-1.50150125e+05 -9.41262583e+03 -1.86864537e+03]
 [-1.05075038e+07 -3.29064379e+05 -4.34889106e+04]
 [-9.45525188e+08 -1.47984844e+07 -1.30279867e+06]]
[[-1.00498751e+02 -2.54950111e+01 -1.15999172e+01]
 [-3.00501248e+03 -3.77524834e+02 -1.12814717e+02]
 [-1.50150125e+05 -9.41262583e+03 -1.86864537e+03]
 [-1.05075038e+07 -3.29064379e+05 -4.34889106e+04]
 [-9.45525188e+08 -1.47984844e+07 -1.30279867e+06]]


In [3]:
sph_jn = torch_jn.apply
sph_yn = torch_yn.apply

sph_jn_der = torch_jn_der.apply
sph_yn_der = torch_yn_der.apply

def sph_h1n(z, n):
    return sph_jn(z, n) + 1j*sph_yn(z, n)

def sph_h1n_der(z, n):
    return sph_jn_der(z, n) + 1j*sph_yn_der(z, n)

def psi(z, n):
    return z*sph_jn(z,n)

def chi(z, n):
    return -z*sph_yn(z, n)

def xi(z, n):
    return z*sph_h1n(z, n)

def psi_der(z, n):
    return sph_jn(z,n) + z*sph_jn_der(z,n)

def chi_der(z, n):
    return -sph_yn(z,n) - z*sph_yn_der(z,n)

def xi_der(z, n):
    return sph_h1n(z,n) + z*sph_h1n_der(z,n)

In [5]:
def An(k, a, b, n, m1, m2):
    return (m2*psi(m2*k*a, n)*psi_der(m1*k*a, n) - m1*psi_der(m2*k*a, n)*psi(m1*k*a, n))/(m2*chi(m2*k*a, n)*psi_der(m1*k*a, n) - m1*chi_der(m2*k*a, n)*psi(m1*k*a, n))

def Bn(k, a, n, m1, m2):
    return (m2*psi(m1*k*a, n)*psi_der(m2*k*a, n) - m1*psi(m2*k*a, n)*psi_der(m1*k*a, n))/(m2*chi_der(m2*k*a, n)*psi(m1*x, n) - m1*psi_der(m1*k*a, n)*chi(m2*k*a, n))

def an(k, a, b, n, m1, m2):
    return (psi(k*b, n)*(psi_der(m2*k*b, n) - An(k*a, n, m1, m2)*chi_der(m2*k*b, n)) - m2*psi_der(k*b, n)*(psi(m2*y, n) - An(k*a, n, m1, m2)*chi(m2*k*b, n)))/(xi(y, n)*(psi_der(m2*k*b, n) - An(k*a, n, m1, m2)*chi_der(m2*k*b, n)) - m2*xi_der(k*b, n)*(psi(m2*k*b, n) - An(k*a, n, m1, m2)*chi(m2*k*b, n)))

def bn(k, a, b, n, m1, m2):
    return (m2*psi(y, n)*(psi_der(m2*k*b, n) - Bn(k*a, n, m1, m2)*chi_der(m2*k*b, n)) - psi_der(k*b, n)*(psi(m2*k*b, n) - Bn(k*a, n, m1, m2)*chi(m2*k*b, n)))/(m2*xi(k*b, n)*(psi_der(m2*k*b, n) - Bn(k*a, n, m1, m2)*chi_der(m2*k*b, n)) - xi_der(k*b, n)*(psi(m2*k*b, n) - Bn(k*a, n, m1, m2)*chi(m2*k*b, n)))

In [4]:
def An(x, n, m1, m2):
    return (m2*psi(m2*x, n)*psi_der(m1*x, n) - m1*psi_der(m2*x, n)*psi(m1*x, n))/(m2*chi(m2*x, n)*psi_der(m1*x, n) - m1*chi_der(m2*x, n)*psi(m1*x, n))

def Bn(x, n, m1, m2):
    return (m2*psi(m1*x, n)*psi_der(m2*x, n) - m1*psi(m2*x, n)*psi_der(m1*x, n))/(m2*chi_der(m2*x, n)*psi(m1*x, n) - m1*psi_der(m1*x, n)*chi(m2*x, n))

def an(x, y, n, m1, m2):
    return (psi(y, n)*(psi_der(m2*y, n) - An(x, n, m1, m2)*chi_der(m2*y, n)) - m2*psi_der(y, n)*(psi(m2*y, n) - An(x, n, m1, m2)*chi(m2*y, n)))/(xi(y, n)*(psi_der(m2*y, n) - An(x, n, m1, m2)*chi_der(m2*y, n)) - m2*xi_der(y, n)*(psi(m2*y, n) - An(x, n, m1, m2)*chi(m2*y, n)))

def bn(x, y, n, m1, m2):
    return (m2*psi(y, n)*(psi_der(m2*y, n) - Bn(x, n, m1, m2)*chi_der(m2*y, n)) - psi_der(y, n)*(psi(m2*y, n) - Bn(x, n, m1, m2)*chi(m2*y, n)))/(m2*xi(y, n)*(psi_der(m2*y, n) - Bn(x, n, m1, m2)*chi_der(m2*y, n)) - xi_der(y, n)*(psi(m2*y, n) - Bn(x, n, m1, m2)*chi(m2*y, n)))

#def x(k, a):
#    return k*a

#def y(k, b):
#    return k*b

In [5]:
wlRes = 10

wl = np.linspace(500,550, wlRes)#500.0  # wavelength in nm

wl = np.reshape(wl,newshape=(wlRes,1))


r_core = 80.0
r_shell = r_core + 100.0

n_env = 1
n_core = 4
n_shell = 0.1   + .7j

mu_env = 1
mu_core = 1
mu_shell = 1

dtype = torch.complex64
dtype2 = torch.float64
device = torch.device("cpu")

n_max = 5
k = 2 * np.pi / (wl / n_env)

#print(k)

m1 = n_core / n_env
m2 = n_shell / n_env

K = torch.tensor(k, requires_grad=False, dtype=dtype2)

#a = torch.tensor(r_core, requires_grad=True)
#b = torch.tensor(r_shell, requires_grad=True)
#a = torch.full(K.shape, r_core, requires_grad=True, dtype=dtype)
#b = torch.full(K.shape, r_shell, requires_grad=True, dtype=dtype)
#m1 = torch.tensor(m1, requires_grad=True, dtype=dtype)
#m2 = torch.tensor(m2, requires_grad=True, dtype=dtype)
m1 = torch.full(K.shape, m1, requires_grad=True, dtype=dtype)
m2 = torch.full(K.shape, m2, requires_grad=True, dtype=dtype)

x = K * r_core
y = K * r_shell

x = torch.tensor(x, requires_grad=True)
y = torch.tensor(y, requires_grad=True)

n = torch.tensor([1,2,3,4,5], dtype=dtype)

#print(n, n.shape)

#x_ = x(K, a)
#y_ = y(K, b)

A = an(x, y, n, m1, m2)
B = bn(x, y, n, m1, m2)

print(K.shape)

print("here",torch.sum((2 * n + 1) * (A.real + B.real),dim=1).shape)

qext = torch.mul(K, torch.unsqueeze(torch.sum((2 * n + 1) * (A.real + B.real), dim=1),1))
qsca = K * torch.unsqueeze(torch.sum((2 * n + 1) * (A.real**2 + A.imag**2 + B.real**2 + B.imag**2), dim=1),1)
qabs = qext - qsca

print(qext, qext.shape)
print(qsca, qsca.shape)
print(qabs, qabs.shape)



torch.Size([10, 1])
here torch.Size([10])
tensor([[0.0698+0.j],
        [0.0674+0.j],
        [0.0650+0.j],
        [0.0628+0.j],
        [0.0607+0.j],
        [0.0586+0.j],
        [0.0567+0.j],
        [0.0549+0.j],
        [0.0531+0.j],
        [0.0515+0.j]], dtype=torch.complex128, grad_fn=<MulBackward0>) torch.Size([10, 1])
tensor([[0.0611+0.j],
        [0.0589+0.j],
        [0.0568+0.j],
        [0.0549+0.j],
        [0.0530+0.j],
        [0.0512+0.j],
        [0.0495+0.j],
        [0.0478+0.j],
        [0.0463+0.j],
        [0.0448+0.j]], dtype=torch.complex128, grad_fn=<MulBackward0>) torch.Size([10, 1])
tensor([[0.0087+0.j],
        [0.0085+0.j],
        [0.0082+0.j],
        [0.0080+0.j],
        [0.0077+0.j],
        [0.0075+0.j],
        [0.0073+0.j],
        [0.0070+0.j],
        [0.0068+0.j],
        [0.0066+0.j]], dtype=torch.complex128, grad_fn=<SubBackward0>) torch.Size([10, 1])


  x = torch.tensor(x, requires_grad=True)
  y = torch.tensor(y, requires_grad=True)
  return self.numpy().astype(dtype, copy=False)


In [54]:
qext.backward([x, y, m1, m2], retain_graph = True, create_graph = True)

print("ddX", x.grad)
print("ddY", y.grad)
print("ddm1", m1.grad)
print("ddm2", m2.grad)

ddX tensor([[nan+nanj],
        [nan+nanj],
        [nan+nanj],
        [nan+nanj],
        [nan+nanj],
        [nan+nanj],
        [nan+nanj],
        [nan+nanj],
        [nan+nanj],
        [nan+nanj]], grad_fn=<CopyBackwards>)
ddY tensor([[nan+nanj],
        [nan+nanj],
        [nan+nanj],
        [nan+nanj],
        [nan+nanj],
        [nan+nanj],
        [nan+nanj],
        [nan+nanj],
        [nan+nanj],
        [nan+nanj]], grad_fn=<CopyBackwards>)
ddm1 tensor([[12.8661-8.4589e-04j],
        [12.7248-4.2461e-04j],
        [12.5865-7.3176e-05j],
        [12.4511+2.3557e-04j],
        [12.3185+5.2130e-04j],
        [12.1888+7.9940e-04j],
        [12.0616+1.0820e-03j],
        [11.9371+1.3816e-03j],
        [11.8150+1.7101e-03j],
        [11.6953+2.0816e-03j]], grad_fn=<CopyBackwards>)
ddm2 tensor([[nan+nanj],
        [nan+nanj],
        [nan+nanj],
        [nan+nanj],
        [nan+nanj],
        [nan+nanj],
        [nan+nanj],
        [nan+nanj],
        [nan+nanj],
        [nan+n

  n = np.asarray(n, dtype=np.dtype("long"))
  n = np.asarray(n, dtype=np.dtype("long"))
  grad_input = torch.from_numpy( (1/grad_output**2)*((ctx.n.detach().numpy()**2 - ctx.n.detach().numpy() - grad_output**2)*spherical_yn(ctx.n.detach().numpy(), grad_output) + 2*grad_output*spherical_yn(ctx.n.detach().numpy() + 1, grad_output)) )
  grad_input = torch.from_numpy( (1/grad_output**2)*((ctx.n.detach().numpy()**2 - ctx.n.detach().numpy() - grad_output**2)*spherical_yn(ctx.n.detach().numpy(), grad_output) + 2*grad_output*spherical_yn(ctx.n.detach().numpy() + 1, grad_output)) )
