In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import pickle
import yaml

import sys
sys.path.append('../')
import datamodule
import don
from plane import Plane
from propagator import PropagatorFactory

In [None]:
# Load the experiment configuration. The context manager closes the file
# handle deterministically instead of leaving it open until garbage collection.
# NOTE(review): FullLoader is more permissive than SafeLoader; if config.yaml
# only holds plain scalars/lists/dicts, yaml.safe_load is the stricter choice.
with open('../config.yaml', 'r') as config_file:
    config = yaml.load(config_file, Loader = yaml.FullLoader)
# Override the configured batch size so it is 1 for this notebook.
config['batch_size'] = 1

In [None]:
# NOTE(review): machine-specific absolute path; the notebook only runs where
# this exact file exists. Consider deriving it from a configurable results dir.
checkpoint_path = '/home/mblgh6/Documents/diffractive_optical_model/results/my_models/early_testing/epoch=4-step=6250-v3.ckpt'
# map_location='cpu' lets a checkpoint whose tensors were saved on a GPU be
# loaded on a machine without that device; the tensors land on the CPU.
checkpoint = torch.load(checkpoint_path, map_location='cpu')

In [None]:
# Build the model from the loaded configuration.
# NOTE(review): the model is never switched to eval mode in this notebook;
# confirm that is intended for the inspection done in later cells.
model = don.DON(config)
# Select the data module for this configuration and run its preparation and
# 'fit'-stage setup before asking for a dataloader.
data = datamodule.select_data(config)
data.prepare_data()
data.setup(stage='fit')
# Training dataloader; presumably honours the batch_size of 1 set on `config`
# (verify against the datamodule).
dataloader = data.train_dataloader()

In [None]:
# Sanity check: the propagator `H` tensors stored in the checkpoint must equal,
# element for element, the ones held by the freshly constructed model for the
# first two layers.
state_dict = checkpoint['state_dict']
for layer_index in (0, 1):
    saved_H = state_dict['layers.{}.propagator.H'.format(layer_index)].cpu()
    assert (saved_H == model.layers[layer_index].propagator.H).all()

In [None]:
# Load the checkpoint's saved state dict into the model; the call's return
# value is left as the cell output.
model.load_state_dict(checkpoint['state_dict'])

In [None]:
# Query every block's modulator once, phase first and amplitude second, both
# with with_grad=False, then split the pairs into two parallel lists.
_modulations = [
    (
        layer.modulator.get_phase(with_grad=False),
        layer.modulator.get_amplitude(with_grad=False),
    )
    for layer in model.layers
]
phases = [phase_map for phase_map, _ in _modulations]
amplitudes = [amplitude_map for _, amplitude_map in _modulations]

# Print the minimum of each block's phase after taking it modulo 2*pi.
two_pi = torch.pi * 2
for phase in phases:
    print((phase % two_pi).min())

In [None]:
# One figure row per block: amplitude on the left, phase (modulo 2*pi) on the
# right, with the phase colour scale pinned to [0, 2*pi].
num_blocks = len(model.layers)
# squeeze=False keeps `ax` two-dimensional even when there is a single block;
# by default matplotlib collapses a 1x2 grid to a 1-D array and `ax[i][0]`
# would then fail.
fig,ax = plt.subplots(num_blocks, 2, figsize=(10, 5 * num_blocks), squeeze=False)

for i,(a,p) in enumerate(zip(amplitudes, phases)):
    ax[i][0].imshow(a.squeeze())
    ax[i][1].imshow(p.squeeze() % (torch.pi * 2), vmin=0, vmax=2*torch.pi)

    ax[i][0].set_title("Block {} Amplitude".format(i))
    ax[i][1].set_title("Block {} Phase".format(i))

In [None]:
# Take the first batch from the training dataloader and squeeze singleton
# dimensions out of each of its three tensors.
image, slm_sample, target = (
    tensor.squeeze() for tensor in next(iter(dataloader))
)

In [None]:
# Show each tensor from the batch on its own row: magnitude (abs) on the left,
# phase angle on the right. The rows are drawn in a loop instead of three
# copy-pasted stanzas, and every panel gets a title so the figure can be read
# on its own, matching the titled per-block figure.
fig,ax = plt.subplots(3,2,figsize=(10,15))
fields = (("image", image), ("slm_sample", slm_sample), ("target", target))
for row, (label, field) in zip(ax, fields):
    row[0].imshow(field.abs())
    row[1].imshow(field.angle())

    row[0].set_title("{} Amplitude".format(label))
    row[1].set_title("{} Phase".format(label))

In [None]:
# Run the model's shared step on the single batch; the second argument is 0
# (presumably the batch index - confirm against DON.shared_step).
# NOTE(review): this rebinds `target` to the value returned by shared_step, so
# re-running the cell passes that returned value back in as the input target.
outputs, target = model.shared_step((image, slm_sample, target), 0)
print(outputs.keys())
# Replace every entry with a detached, squeezed copy, updating the mapping in
# place under the same keys.
for name, tensor in list(outputs.items()):
    outputs[name] = tensor.detach().squeeze()

In [None]:
# Grid with one row per entry in `outputs`; only the first row is filled, with
# the magnitude and phase angle of 'output_wavefronts'.
# squeeze=False keeps `ax` two-dimensional when `outputs` has a single entry;
# otherwise matplotlib returns a 1-D array and `ax[0][0]` would fail.
# NOTE(review): rows after the first are left empty - confirm whether the other
# outputs were meant to be plotted too.
fig,ax = plt.subplots(len(outputs), 2, figsize=(10, 5 * len(outputs)), squeeze=False)
ax[0][0].imshow(outputs['output_wavefronts'].abs())
ax[0][1].imshow(outputs['output_wavefronts'].angle())

In [None]:
# Push the squeezed input image through the first layer only. The layer is
# invoked as a callable rather than through `.forward(...)` so that any hooks
# registered on the module run as well.
test = model.layers[0](image)

In [None]:
# Magnitude of the first layer's output. detach() drops any autograd tracking
# (matplotlib cannot convert a tensor that requires grad) and squeeze() removes
# singleton dimensions, matching how `test2` is displayed elsewhere in this
# notebook.
plt.imshow(test.abs().detach().squeeze())

In [None]:
# Feed the first layer's output into the second layer. The layer is invoked as
# a callable rather than through `.forward(...)` so that any hooks registered
# on the module run as well.
test2 = model.layers[1](test)

In [None]:
# Magnitude of `test2`, detached from autograd and squeezed, shown as an image.
layer1_magnitude = test2.abs().detach().squeeze()
plt.imshow(layer1_magnitude)

In [None]:
# Echo the `test` tensor (the first layer's output) as the cell result for
# inspection.
test

In [None]:
# Parameters for the plane the field starts on: an 8.96e-3 x 8.96e-3 extent
# (presumably metres - confirm against Plane) sampled on a 1080 x 1080 grid,
# with normal (0, 0, 1) and centre at the origin.
input_plane_params = {
    'name':'input',
    'size': torch.tensor([8.96e-3, 8.96e-3]),
    'Nx':1080,
    'Ny':1080,
    'normal': torch.tensor([0,0,1]),
    'center': torch.tensor([0,0,0])}

# Parameters for the plane the field is propagated to. The name was 'input'
# here as well (copy-paste slip); it is now 'output' so the two planes can be
# told apart.
# NOTE(review): 'center' is identical to the input plane's, i.e. the two planes
# coincide - confirm that a zero separation is really what this test intends.
output_plane_params = {
    'name':'output',
    'size': torch.tensor([8.96e-3, 8.96e-3]),
    'Nx':1080,
    'Ny':1080,
    'normal': torch.tensor([0,0,1]),
    'center': torch.tensor([0,0,0])}

# Wavelength handed to the propagator (same length unit as the plane size).
propagator_params = {'wavelength':torch.tensor(1.55e-6)}

input_plane = Plane(input_plane_params)
output_plane = Plane(output_plane_params)

# Let the factory build the propagator for this pair of planes.
propagator = PropagatorFactory()(input_plane, output_plane, propagator_params)

In [None]:
# Multiply the first layer's output by the second layer's modulator
# transmissivity (queried with with_grad=False), then run the product through
# the standalone propagator built in this notebook.
modulated_field = test * model.layers[1].modulator.get_transmissivity(with_grad=False)
test2 = propagator(modulated_field)

In [None]:
# Magnitude of the current `test2`, detached from autograd and squeezed, shown
# as an image.
propagated_magnitude = test2.abs().detach().squeeze()
plt.imshow(propagated_magnitude)

In [None]:
# Grid description: side length, samples per side, and the resulting sample
# pitch, plus the wavelength (same numbers as the planes defined above).
Lx = 8.96e-3
Nx = 1080
delta_x = Lx / Nx
wavelength = 1.55e-6

# first  = 2 * Lx * delta_x / wavelength
# second = first * sqrt(1 - (wavelength / (2 * delta_x))**2)
# NOTE(review): this looks like a propagation-distance bound for the sampling
# grid - confirm which criterion it is meant to reproduce.
first = (2 * Lx) * delta_x / wavelength
half_pitch_ratio = wavelength / (2 * delta_x)
second = first * np.sqrt(1 - half_pitch_ratio ** 2)
print(second)