## 3D Thin Lens ray equations
To find how a thin lens bends an arbitrary incoming ray, consider the following construction:

A ray $\{R,s_0\}$ (originating from point $R$ in the direction $s_0$) passes through the thin lens plane $\{L,\hat{n}\}$ at point $P$ and is bent into the direction $s_1$. A thin lens focuses all rays parallel to $s_0$ onto the same point of the back focal plane, and a ray through the lens center passes undeviated. The parallel ray $\{L,s_0\}$ through the lens center $L$ therefore intersects the refracted ray $\{P,s_1\}$ at point $Q$ on the back focal plane. The focal plane $\{F,\hat{n}\}$ lies at a distance $f$ from the thin lens plane.
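For a ray $\{A,d\}$ and a plane $\{C,\hat{n}\}$, the intersection point is $A + \frac{(C-A)\cdot\hat{n}}{d\cdot\hat{n}}\,d$. Applying it twice gives the construction above in closed form. This is a sketch that assumes $\hat{n}$ is a unit normal pointing back towards the incoming ray (so $s_0\cdot\hat{n} < 0$), which places the back focal plane center at $F = L - f\hat{n}$:

$$
\begin{aligned}
P &= R + \frac{(L-R)\cdot\hat{n}}{s_0\cdot\hat{n}}\,s_0, \\
Q &= L + \frac{(F-L)\cdot\hat{n}}{s_0\cdot\hat{n}}\,s_0 = L - \frac{f}{s_0\cdot\hat{n}}\,s_0, \\
s_1 &= \frac{Q-P}{\|Q-P\|}.
\end{aligned}
$$

For example, with $R = (0,0,0)$, $s_0 = (0,1,2)/\sqrt{5}$, $L = (0,0,2)$, $\hat{n} = (0,0,-1)$ and $f = 2$, we get $s_0\cdot\hat{n} = -2/\sqrt{5}$, $P = (0,1,2)$, $Q = (0,1,4)$ and $s_1 = (0,0,1)$.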

![3D Lens](img/3D-lens.svg)

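To find the path length shift $\Delta_{TL}$ introduced by the thin lens, consider a ray originating from the front focal point $F_1$, hitting the lens at point $P$ on the thin lens plane, and continuing parallel to the optical axis. The path length shift $\Delta_{TL}$ must make the total path length the same for every ray originating from $F_1$. The axial ray travels the distance $f$ from $F_1$ to the lens center $L$, where $\Delta_{TL} = 0$, whereas by Pythagoras the ray through $P$ travels $\sqrt{f^2 + \|P-L\|^2}$. Beyond the lens both rays travel parallel to the axis, so equal path lengths require $\sqrt{f^2 + \|P-L\|^2} + \Delta_{TL} = f$, which gives $\Delta_{TL} = f - \sqrt{f^2 + \|P-L\|^2}$.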
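As a quick numerical sanity check (a sketch, not part of the original notebook), the snippet below evaluates $\Delta_{TL} = f - \sqrt{f^2 + r^2}$ for a few radii $r = \|P-L\|$. It confirms that the path from $F_1$ to the lens plane plus $\Delta_{TL}$ always equals $f$, and compares $\Delta_{TL}$ with the paraxial approximation $\Delta_{TL} \approx -r^2/(2f)$, which follows from $\sqrt{f^2+r^2} \approx f + r^2/(2f)$ for $r \ll f$. The helper name `delta_tl` is hypothetical.

```python
import math

def delta_tl(r, f):
    """Thin-lens path length shift for a ray crossing the lens at radius r = |P - L| (hypothetical helper)."""
    return f - math.sqrt(f * f + r * r)

f = 2.0
for r in (0.1, 0.5, 1.0):
    exact = delta_tl(r, f)
    paraxial = -r * r / (2 * f)              # second-order Taylor approximation of the exact shift
    path_from_F1 = math.sqrt(f * f + r * r)  # |F1 P| by Pythagoras
    print(f"r={r:4.1f}  exact={exact:+.5f}  paraxial={paraxial:+.5f}  "
          f"|F1P|+Delta={path_from_F1 + exact:.5f}")
```

With $f = 2$ and $r = 1$ the exact shift is $2 - \sqrt{5} \approx -0.23607$, matching the notebook output below, while the paraxial value is $-0.25$; the gap between the two grows with $r/f$.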

![3D Lens](img/3D-lens-phase.svg)

We want to find the new ray $\{P,s_1\}$ and the path length shift $\Delta_{TL}$, both as functions of the incoming ray $\{R,s_0\}$, the thin lens plane $\{L,\hat{n}\}$ and the focal distance $f$.

```python
import torch

# Use the first CUDA device if available, otherwise fall back to the CPU
if torch.cuda.is_available():
    dev = torch.device("cuda:0")
    print('Running on CUDA')
else:
    dev = torch.device("cpu")
    print('Running on CPU')


# Define vector space functions
def vector3(v):
    """Create a 3D (1x1x3) vector array, broadcastable against NxMx3 arrays."""
    return torch.tensor(v, dtype=torch.float32, device=dev).view(1,1,3)

def scalar(s):
    """Create a (1x1x1) scalar array, broadcastable against NxMx1 arrays."""
    return torch.tensor(s, dtype=torch.float32, device=dev).view(1,1,1)

def inner(v, w):
    """Vector inner product for NxMxD vector arrays, where D is the vector dimension."""
    return torch.sum(v*w, dim=2, keepdim=True)

def normsq(v):
    """Squared L2-norm for NxMxD vector arrays, where D is the vector dimension."""
    return inner(v,v)

def norm(v):
    """L2-norm for NxMxD vector arrays, where D is the vector dimension."""
    return torch.norm(v, dim=2, keepdim=True)

def unit(v):
    """Unit vectors for NxMxD vector arrays, where D is the vector dimension."""
    return v / norm(v)


# Define ray operations
def dist_to_plane(ray_pos, ray_dir, plane_pos, plane_dir):
    """Ray parameter t at which the ray meets the plane (a distance if ray_dir is a unit vector)."""
    return inner(plane_pos - ray_pos, plane_dir) / inner(ray_dir, plane_dir)

def ray_plane_intersect(ray_pos, ray_dir, plane_pos, plane_dir):
    """Point where the ray intersects the plane."""
    t = dist_to_plane(ray_pos, ray_dir, plane_pos, plane_dir)
    return ray_pos + ray_dir*t


# TODO: implement directional vector / pos-dir classes representing rays or infinite planes
#       (ray.dir, ray.pos, plane.dir, plane.pos)

def thinlens(ray_pos, ray_dir, lens_pos, lens_dir, f):
    """New ray position, direction and additional path length for a thin lens.

    lens_dir is the lens plane normal pointing back towards the incoming rays,
    so the back focal plane lies at lens_pos - f*lens_dir.
    """
    L = lens_pos
    F = L - f*lens_dir                                       # Back focal plane center
    P = ray_plane_intersect(ray_pos, ray_dir, L, lens_dir)   # Ray intersection with lens plane
    Q = ray_plane_intersect(L, ray_dir, F, lens_dir)         # Parallel ray through L meets back focal plane
    Delta = f - torch.sqrt(f*f + normsq(L-P))                # Lens-induced path length shift
    return P, unit(Q-P), Delta


# Example: ray from the origin, lens plane at z = 2 with its normal facing the ray, f = 2
R  = vector3((0,0,0))
s0 = unit(vector3((0,1,2)))
L  = vector3((0,0,2))
nL = vector3((0,0,-1))
f  = scalar(2)

thinlens(R, s0, L, nL, f)

# TODO: implement function for plotting an array of points
```


```
Running on CUDA

(tensor([[[0., 1., 2.]]], device='cuda:0'),
 tensor([[[0., 0., 1.]]], device='cuda:0'),
 tensor([[[-0.2361]]], device='cuda:0'))
```
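The helpers are written for NxMx3 arrays, although the example above traces only a single 1x1x3 ray. The following self-contained sketch re-declares the notebook's functions on the CPU and traces two 5x5 grids of rays: a tilted parallel bundle, and a fan from the front focal point $F_1 = L + f\hat{n}$. It checks the two properties used in the derivation: the parallel bundle meets at one point on the back focal plane, and the rays from $F_1$ leave parallel to the optical axis with equal total path length. The grid size, tilt and variable names are illustrative assumptions, and `torch.meshgrid(..., indexing="ij")` needs a reasonably recent PyTorch.

```python
import torch

# Sketch: CPU re-declaration of the notebook helpers; grid size and tilt are arbitrary choices
dev = torch.device("cpu")

def vector3(v):
    return torch.tensor(v, dtype=torch.float32, device=dev).view(1, 1, 3)

def scalar(s):
    return torch.tensor(s, dtype=torch.float32, device=dev).view(1, 1, 1)

def inner(v, w):
    return torch.sum(v * w, dim=2, keepdim=True)

def normsq(v):
    return inner(v, v)

def unit(v):
    return v / torch.norm(v, dim=2, keepdim=True)

def ray_plane_intersect(ray_pos, ray_dir, plane_pos, plane_dir):
    t = inner(plane_pos - ray_pos, plane_dir) / inner(ray_dir, plane_dir)
    return ray_pos + ray_dir * t

def thinlens(ray_pos, ray_dir, lens_pos, lens_dir, f):
    F = lens_pos - f * lens_dir                                   # back focal plane center
    P = ray_plane_intersect(ray_pos, ray_dir, lens_pos, lens_dir)
    Q = ray_plane_intersect(lens_pos, ray_dir, F, lens_dir)
    Delta = f - torch.sqrt(f * f + normsq(lens_pos - P))
    return P, unit(Q - P), Delta

L  = vector3((0, 0, 2))      # lens center
nL = vector3((0, 0, -1))     # lens normal pointing back towards the incoming rays
f  = scalar(2)

# 5x5 grid of points on the plane z = 0, shape (5, 5, 3)
xs = torch.linspace(-0.5, 0.5, 5, device=dev)
gx, gy = torch.meshgrid(xs, xs, indexing="ij")
grid = torch.stack((gx, gy, torch.zeros_like(gx)), dim=2)

# (a) Tilted parallel bundle: every ray starts on the grid with the same direction s0
s0 = unit(vector3((0, 0.2, 1)))
P_a, s1_a, _ = thinlens(grid, s0.expand_as(grid), L, nL, f)
Q_a = ray_plane_intersect(P_a, s1_a, L - f * nL, nL)   # where each bent ray hits the back focal plane
focal_spread = (Q_a - Q_a[:1, :1]).abs().max().item()
print("common focal point:", Q_a[0, 0].tolist(), " spread:", focal_spread)

# (b) Fan from the front focal point F1 through grid points shifted onto the lens plane
F1 = L + f * nL
targets = grid + L                                     # points on the lens plane z = 2
P_b, s1_b, Delta_b = thinlens(F1.expand_as(grid), unit(targets - F1), L, nL, f)
total_b = torch.sqrt(normsq(P_b - F1)) + Delta_b       # |F1 P| plus the lens shift
dir_dev = (s1_b - (-nL)).abs().max().item()           # deviation from the optical axis direction
path_spread = (total_b - total_b[:1, :1]).abs().max().item()
print("max exit direction deviation:", dir_dev, " path length spread:", path_spread)
```

All three printed deviations should be close to zero, at the level of float32 rounding, and the common focal point should come out at approximately $(0, 0.4, 4)$, which is $L - \frac{f}{s_0\cdot\hat{n}}\,s_0$ for this tilt.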