In [1]:
import torch
import numpy as np
from globalbiopak.linop import *
from globalbiopak.phaseretrieval import ForwardPhaseRetrieval
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

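# Random Gaussian model

Recover a complex signal from phaseless measurements taken through a dense complex Gaussian matrix.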

In [2]:
# Problem sizes: signal dimension d and number of measurements n.
d = 200
n = 1000
# Fix the RNG so that re-running the notebook from scratch reproduces the same data.
torch.manual_seed(0)
# Ground-truth complex signal and a dense complex Gaussian measurement matrix.
x = torch.randn(d, dtype=torch.complex64)
A = torch.randn(n, d, dtype=torch.complex64)
op = LinOpMatrix(A)
forward = ForwardPhaseRetrieval(op)
# Observed measurements of the ground truth under the phase-retrieval forward model.
y = forward.apply(x)

In [4]:
# Start from the spectral initializer, then refine by minimizing the misfit
# between predicted and observed measurements with Nesterov-momentum SGD.
xest = forward.spectralinit(y)
xest.requires_grad_(True)
n_iter = 100
step = 1e-3
# Optimizers expect an ordered iterable of parameters: pass a list, not a set.
optimizer = torch.optim.SGD([xest], lr=step, momentum=0.1, nesterov=True)

for _ in range(n_iter):
    optimizer.zero_grad()
    yest = forward.apply(xest)
    # torch.norm is deprecated; vector_norm (dim=None) flattens, matching the old result.
    loss = torch.linalg.vector_norm(yest - y)
    loss.backward()
    optimizer.step()

with torch.no_grad():
    # Normalized correlation |<xest, x>| / (||x|| * ||xest||). Taking the modulus
    # makes it invariant to a global phase factor on xest, which phaseless
    # measurements cannot determine. torch.vdot conjugates its first argument.
    correlation = torch.abs(torch.vdot(xest, x)) / (
        torch.linalg.vector_norm(x) * torch.linalg.vector_norm(xest)
    )
    print("Final correlation (close to 1 means successful recovery):")
    print(correlation)

Final correlation (close to 1 means successful recovery):
tensor(0.9974)


# Coded Diffraction Imaging

In [11]:
# Coded diffraction imaging: one measurement block per random complex mask,
# each block combining an elementwise mask multiplication with an FFT
# (presumably FFT(mask * x) -- confirm LinOpComposition's application order).
d = 200
n_masks = 10
# Fix the RNG so that re-running the notebook from scratch reproduces the same data.
torch.manual_seed(0)
x = torch.randn(d, dtype=torch.complex64)
masks = torch.randn(n_masks, d, dtype=torch.complex64)
# Build the operator for the first mask, then stack one block per remaining mask.
op = LinOpComposition(LinOpFFT(), LinOpMul(masks[0]))
for mask in masks[1:]:
    op = StackLinOp(op, LinOpComposition(LinOpFFT(), LinOpMul(mask)))
forward = ForwardPhaseRetrieval(op)
y = forward.apply(x)

In [12]:
# Start from the spectral initializer, then refine by minimizing the misfit
# between predicted and observed measurements with Nesterov-momentum SGD.
xest = forward.spectralinit(y)
xest.requires_grad_(True)
n_iter = 100
step = 1e-3
# Optimizers expect an ordered iterable of parameters: pass a list, not a set.
optimizer = torch.optim.SGD([xest], lr=step, momentum=0.1, nesterov=True)

for _ in range(n_iter):
    optimizer.zero_grad()
    yest = forward.apply(xest)
    # torch.norm is deprecated; vector_norm (dim=None) flattens, matching the old result.
    loss = torch.linalg.vector_norm(yest - y)
    loss.backward()
    optimizer.step()

with torch.no_grad():
    # Normalized correlation |<xest, x>| / (||x|| * ||xest||). Taking the modulus
    # makes it invariant to a global phase factor on xest, which phaseless
    # measurements cannot determine. torch.vdot conjugates its first argument.
    correlation = torch.abs(torch.vdot(xest, x)) / (
        torch.linalg.vector_norm(x) * torch.linalg.vector_norm(xest)
    )
    print("Final correlation (close to 1 means successful recovery):")
    print(correlation)

Final correlation (close to 1 means successful recovery):
tensor(0.9991)


# Ptychography

In [40]:
d = 500
n_img = 10
# Width (in samples) of the probe support: the probe is nonzero on [0, ptycho_radius).
ptycho_radius = 250
# Centres of a probe flush with the left edge and one flush with the right edge.
left_origin = ptycho_radius / 2
right_origin = d - ptycho_radius / 2
# Shift between consecutive probe positions.
step_size = int((right_origin - left_origin) / n_img)
# Fraction of the probe support shared by two consecutive positions.
overlap = 1 - step_size / ptycho_radius
print(f"The overlap is {overlap}")

# Fix the RNG so that re-running the notebook from scratch reproduces the same data.
torch.manual_seed(0)
sampling_grid = torch.linspace(0, d-1, d)
probe = torch.randn(d, dtype=torch.complex64) * (sampling_grid < ptycho_radius)
# torch.roll wraps around the end of the signal; the last shifted probe must stay inside it.
assert (n_img - 1) * step_size + ptycho_radius <= d, "probe positions wrap around the signal"

x = torch.randn(d, dtype=torch.complex64)
op = LinOpComposition(LinOpFFT(), LinOpMul(probe))
# n_img probe positions in total: shift 0 above, then shifts step_size * (1 .. n_img-1).
# (Previously this looped over n_masks, a variable leaked from the coded-diffraction section.)
for i in range(1, n_img):
    op = StackLinOp(op, LinOpComposition(LinOpFFT(), LinOpMul(torch.roll(probe, i * step_size))))
forward = ForwardPhaseRetrieval(op)
y = forward.apply(x)

The overlap is 0.9


In [42]:
# Start from the spectral initializer, then refine by minimizing the misfit
# between predicted and observed measurements with Nesterov-momentum SGD.
xest = forward.spectralinit(y)
xest.requires_grad_(True)
n_iter = 1000
step = 1e-3
# Optimizers expect an ordered iterable of parameters: pass a list, not a set.
optimizer = torch.optim.SGD([xest], lr=step, momentum=0.1, nesterov=True)

for _ in range(n_iter):
    optimizer.zero_grad()
    yest = forward.apply(xest)
    # torch.norm is deprecated; vector_norm (dim=None) flattens, matching the old result.
    loss = torch.linalg.vector_norm(yest - y)
    loss.backward()
    optimizer.step()

with torch.no_grad():
    # Normalized correlation |<xest, x>| / (||x|| * ||xest||). Taking the modulus
    # makes it invariant to a global phase factor on xest, which phaseless
    # measurements cannot determine. torch.vdot conjugates its first argument.
    correlation = torch.abs(torch.vdot(xest, x)) / (
        torch.linalg.vector_norm(x) * torch.linalg.vector_norm(xest)
    )
    print("Final correlation (close to 1 means successful recovery):")
    print(correlation)

Final correlation (close to 1 means successful recovery):
tensor(0.9751)
