In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys
import yaml
from tqdm import tqdm
import torch
sys.path.append('../')
from plane import Plane
from propagator import PropagatorFactory

In [None]:
# Load the shared simulation configuration. The `with` block closes the file
# handle, which the bare open() inside yaml.load() used to leave open.
with open('../config.yaml', 'r') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

In [None]:
# Input and output planes share one sampling grid, so the grid is defined once;
# only the name and the z-coordinate of the center differ between them.
PLANE_SIZE = (5.e-3, 5.e-3)   # value passed as 'size'; same length units as 'center'
PLANE_N = 166                 # samples per axis ('Nx' and 'Ny')
PROPAGATION_Z = 9.e-2         # z-coordinate of the output plane's center

_plane_grid = {'size': PLANE_SIZE, 'Nx': PLANE_N, 'Ny': PLANE_N}

plane1_params = {'name': 'input_plane', **_plane_grid,
                 'center': (0, 0, 0), 'normal': (0, 0, 1)}
plane2_params = {'name': 'output_plane', **_plane_grid,
                 'center': (0, 0, PROPAGATION_Z), 'normal': (0, 0, 1)}

In [None]:
# Instantiate the input and output sampling planes (in that order) from their
# parameter dicts.
plane1, plane2 = (Plane(params) for params in (plane1_params, plane2_params))

In [None]:
# Input field: ones inside a circular aperture centred on the plane1 origin,
# zeros outside.
APERTURE_RADIUS = 0.15e-3  # same length units as plane1.xx / plane1.yy
mask = np.sqrt(plane1.xx**2 + plane1.yy**2) < APERTURE_RADIUS
U = torch.ones((plane1.Nx, plane1.Ny))
# The type of plane1.xx (ndarray vs tensor) is not visible here, so convert
# the boolean mask explicitly to a tensor with U's dtype before multiplying.
U = U * torch.as_tensor(mask, dtype=U.dtype)

fig, ax = plt.subplots()
im = ax.imshow(U)
fig.colorbar(im, ax=ax)
ax.set(title='Input field U on input_plane', xlabel='column index', ylabel='row index');

In [None]:
# Build a propagator for the plane1 -> plane2 pair. Which propagation method
# is used is decided inside PropagatorFactory (presumably from `config`; the
# factory is not visible in this notebook).
pf = PropagatorFactory()
prop = pf(plane1, plane2, config)

In [None]:
# Propagate the input field U from input_plane to output_plane.
output = prop(U)

In [None]:
# Magnitude of the propagated field. `.abs()` is called as a method, so
# `output` is assumed to be a torch tensor; detach/cpu lets imshow convert it
# even if it tracks gradients or lives on another device.
output_mag = output.abs().detach().cpu()

fig, ax = plt.subplots()
im = ax.imshow(output_mag)
fig.colorbar(im, ax=ax, label='|output|')
ax.set(title='Propagated field magnitude on output_plane', xlabel='column index', ylabel='row index');

In [None]:
# Brute-force Rayleigh-Sommerfeld (first kind) reference, evaluated point by
# point on the output plane:
#   U2(x2, y2) = sum_{x1, y1} U1(x1, y1) * h * dx * dy
#   h = (1 / (2*pi)) * (z / r) * (1/r - 1j*k) * exp(1j*k*r) / r
#   r = sqrt((x2 - x1)**2 + (y2 - y1)**2 + z**2)
distance = plane2.center[-1] - plane1.center[-1]
wavelength = 1.55e-6  # NOTE(review): should match the wavelength `prop` reads from config.yaml — confirm
k = torch.pi * 2 / wavelength

# Compute in float64/complex128: with the values in this notebook k*r is
# ~3.6e5 rad, so float32 rounding of r alone gives phase errors of ~1e-2 rad.
z = float(distance)
xx_in = torch.as_tensor(plane1.xx, dtype=torch.float64)
yy_in = torch.as_tensor(plane1.yy, dtype=torch.float64)
x_out = torch.as_tensor(plane2.x, dtype=torch.float64)
y_out = torch.as_tensor(plane2.y, dtype=torch.float64)

# U was sampled on plane1's grid (its mask is built from plane1.xx/yy), so it
# is aligned element-wise with xx_in/yy_in and needs no padding. A new name
# is used instead of rebinding U, so re-running this cell (or the `prop(U)`
# cell) always sees the original field.
U_in = torch.as_tensor(U).to(torch.complex128)

# Area element of the input grid; assumes uniformly spaced 1-D coordinates
# with at least two samples per axis.
dA = abs(float(plane1.x[1] - plane1.x[0])) * abs(float(plane1.y[1] - plane1.y[0]))
prefactor = z * dA / (2 * torch.pi)
z_sq = z ** 2

output_field = torch.zeros((len(x_out), len(y_out)), dtype=torch.complex128)
for i, x2 in enumerate(tqdm(x_out)):
    dx_sq = (xx_in - x2) ** 2  # independent of y2, so hoisted out of the inner loop
    for j, y2 in enumerate(y_out):
        r = torch.sqrt(dx_sq + (yy_in - y2) ** 2 + z_sq)
        kernel = torch.exp(1j * k * r) * (1 / r - 1j * k) / r ** 2
        output_field[i, j] = prefactor * (U_in * kernel).sum()