In [None]:
# Kept first so the explicit imports below take precedence over any same-named star exports.
# PSF, compute_M, noise and vectorial_BFP_perfect_focus are used below and come from here.
from simu_PSF_polar import *
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm
# Default size for quick-look figures; cells that need larger figures override it.
plt.rcParams.update({'figure.figsize': [5, 3]})

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using GPU" if device.type == 'cuda' else "Using CPU")

In [None]:
def limit(x, lim, slope, upper=True):
    """Exponential soft-barrier penalty on the entries of ``x``.

    Returns ``sum(exp(s * slope * (x - lim)))`` where ``s = +1`` when ``upper``
    is True (penalises values above ``lim``) and ``s = -1`` otherwise
    (penalises values below ``lim``). Values well inside the allowed side
    contribute ~0; the penalty grows exponentially past ``lim`` and can
    overflow to ``inf`` for large violations.
    """
    direction = 1 if upper else -1
    return torch.sum(torch.exp(direction * slope * (x - lim)))
        

In [None]:
# Simulation constants, each stored as a tensor on `device` and passed to the
# simu_PSF_polar routines below. Units are not visible in this notebook
# (presumably nm for lambd, mm for f_tube, um for l_pixel) — TODO confirm.

# Photon count handed to PSF.
N_photons = torch.tensor(5000., device=device)
# Grid size handed to vectorial_BFP_perfect_focus / compute_M.
N = torch.tensor(80, device=device)
# Camera pixel size.
l_pixel = torch.tensor(16, device=device)
# Numerical aperture and objective magnification.
NA = torch.tensor(1.4, device=device)
mag = torch.tensor(100, device=device)
# Emission wavelength.
lambd = torch.tensor(617, device=device)
# Tube-lens focal length.
f_tube = torch.tensor(200, device=device)
# Additional magnification factor (200 / 150).
MAG = torch.tensor(200 / 150, device=device)

In [None]:
# Ground-truth parameters of the simulated emitters.
# Seed the global RNGs so the random z / rho draws here (and the random
# perturbations drawn later in the notebook) are reproducible across runs.
# torch is seeded too in case `noise` draws from torch's RNG — TODO confirm.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

N_EMITTERS = 100

# Lateral positions: every emitter at the origin.
xp = torch.zeros(N_EMITTERS, device=device)
yp = torch.zeros(N_EMITTERS, device=device)
# Axial position of each dipole, in lambda units (uniform in [0.8, 1.6)).
z = torch.tensor((1. + np.random.rand(N_EMITTERS)) * 0.8, device=device)
# Defocus of each dipole, in lambda units.
d_ = torch.full((N_EMITTERS,), -1.5, device=device)
# Orientation angles (degrees): rho random in [10, 170), eta and delta shared by all emitters.
rho = torch.tensor(10. + 160. * np.random.rand(N_EMITTERS), requires_grad=True, device=device)
eta = torch.full((N_EMITTERS,), 50., requires_grad=True, device=device)
delta = torch.full((N_EMITTERS,), 80., requires_grad=True, device=device)

In [None]:
# Pupil-plane grids and field components returned by vectorial_BFP_perfect_focus;
# all of them are passed straight on to compute_M below.
(x, y, th1, phi,
 [Ex0, Ex1, Ex2], [Ey0, Ey1, Ey2],
 r, r_cut, k, f_o) = vectorial_BFP_perfect_focus(
    N, NA=NA, mag=mag, lambd=lambd, f_tube=f_tube, device=device)

In [None]:
# Per-channel settings handed to compute_M (three entries each): an axial
# offset and a polarizer projection angle. Units presumably lambda and
# degrees respectively — TODO confirm against simu_PSF_polar.
polar_projections = torch.tensor([0., 45., 0.], device=device)
second_plane = torch.tensor([0.35, 0., -0.35], device=device)

In [None]:
# Forward model for every emitter: M is consumed by PSF, and u / v are the
# image-plane coordinate grids used by pcolormesh in the figures.
# NOTE(review): aberrations=False — whether defocus_coef / spherical_coef are
# still used in that case is not visible here; check compute_M.
u, v, M = compute_M(
    xp=xp, yp=yp, zp=z, d=d_,
    x=x, y=y, th1=th1, phi=phi,
    Ex0=Ex0, Ex1=Ex1, Ex2=Ex2,
    Ey0=Ey0, Ey1=Ey1, Ey2=Ey2,
    r=r, r_cut=r_cut, k=k, f_o=f_o,
    second_plane=second_plane, polar_projections=polar_projections,
    N=N, l_pixel=l_pixel, NA=NA, mag=mag, lambd=lambd, f_tube=f_tube, MAG=MAG,
    device=device,
    aberrations=False, defocus_coef=1e-5, spherical_coef=-1.5,
)

In [None]:
# PSF images for the ground-truth orientations, before the noise model is applied.
psf = PSF(M=M, N_photons=N_photons, rho=rho, eta=eta, delta=delta)

In [None]:
# Index of the emitter to display (also reused by the noisy-frame figure).
nn = 12
print('z = ', z[nn].cpu().detach().numpy())

# All panels share one colour scale so intensities are comparable across channels.
stack = psf[nn].cpu().detach().numpy()
vmax = stack.max()
vmin = stack.min()

# psf[nn, i, j] is drawn in row j, column i.
n_cols, n_rows = stack.shape[:2]
view_lim = (-150, 150)
fig, ax = plt.subplots(n_rows, n_cols, figsize=(15, 9), squeeze=False)
for i in range(n_cols):
    for j in range(n_rows):
        panel = ax[j, i]
        mesh = panel.pcolormesh(u.cpu(), v.cpu(), stack[i, j], cmap='gray', vmin=vmin, vmax=vmax)
        panel.set_xlim(view_lim)
        panel.set_ylim(view_lim)
        # Pass the axes explicitly: without `ax`, older matplotlib attaches every
        # colorbar to the current axes (the last subplot) instead of this panel.
        fig.colorbar(mesh, ax=panel, pad=0.15, label='Photon number')

In [None]:
def signal(psf):
    """Average peak value over the leading axis.

    Each slice ``psf[i]`` is flattened, its maximum is taken, and the mean of
    those maxima is returned as a 0-d tensor.
    """
    peaks = psf.flatten(start_dim=1).max(dim=1).values
    return peaks.mean()

In [None]:
def general_loss(M, rho, eta, delta, N_photons, data, floor=None, noise_std=None):
    """Data-fit term plus a soft bound on ``delta``, to be minimised over the angles.

    The data term is ``sum(h - (data + s**2) * log(h + floor + s**2))`` with
    ``h = PSF(rho, eta, delta, M, N_photons)``. This matches the shifted-Poisson
    negative log-likelihood up to terms that do not depend on the fitted angles.

    Args:
        M, N_photons: forward-model inputs forwarded unchanged to ``PSF``.
        rho, eta, delta: orientation angles being fitted.
        data: observed (noisy) frames, same shape as the PSF output.
        floor: constant offset added to the model (background + bias). Defaults
            to the notebook-level ``B`` defined in the optimisation cell.
        noise_std: total Gaussian noise std. Defaults to the notebook-level
            ``sig_r`` defined in the optimisation cell.

    Returns:
        Scalar tensor: data term + 1000 * (soft penalty keeping delta in [1, 180]).
    """
    # Fall back to the notebook globals so existing calls keep working; passing the
    # values explicitly removes the hidden dependency on cell execution order.
    if floor is None:
        floor = B
    if noise_std is None:
        noise_std = sig_r

    h = PSF(rho=rho, eta=eta, delta=delta, M=M, N_photons=N_photons)
    var = noise_std ** 2
    data_term = torch.sum(h - (data + var) * torch.log(h + floor + var))
    # Steep exponential walls just outside delta's allowed range.
    delta_bound = limit(delta, 180, 100, upper=True) + limit(delta, 1, 100, upper=False)
    return data_term + 1000. * delta_bound

In [None]:
# Camera / background settings for the noise model; the same values are reused
# later to build the likelihood's offset and noise std.
bias = 11.
read = 1.5
sig_b = 6.
background = 11.
psf_n = noise(
    psf,
    QE=1, EM=1,
    b=background, sigma_b=sig_b,
    sigma_r=read, bias=bias,
)

In [None]:
print('z = ', z[nn].cpu().detach().numpy())

# Noisy frames of emitter `nn`, all panels on one shared colour scale.
stack_n = psf_n[nn].cpu().detach().numpy()
vmax = stack_n.max()
vmin = stack_n.min()

# psf_n[nn, i, j] is drawn in row j, column i.
n_cols, n_rows = stack_n.shape[:2]
view_lim = (-150, 150)
fig, ax = plt.subplots(n_rows, n_cols, figsize=(15, 9), squeeze=False)
for i in range(n_cols):
    for j in range(n_rows):
        panel = ax[j, i]
        mesh = panel.pcolormesh(u.cpu(), v.cpu(), stack_n[i, j], cmap='gray', vmin=vmin, vmax=vmax)
        panel.set_xlim(view_lim)
        panel.set_ylim(view_lim)
        # Pass the axes explicitly: without `ax`, older matplotlib attaches every
        # colorbar to the current axes (the last subplot) instead of this panel.
        fig.colorbar(mesh, ax=panel, pad=0.15, label='Photon number')

In [None]:
# Noise model used by the likelihood in `general_loss`: constant offset
# (background + bias) and total Gaussian std (read noise and background std).
B = background + bias
sig_r = np.sqrt(read**2 + sig_b**2)

n_emitters = rho.numel()


def _wrap_to_reference(angles, reference, period=180.):
    """Reduce ``angles`` modulo ``period``, then shift by one period toward ``reference``.

    Treats the angles as ``period``-periodic. For references in ``[0, period)`` the
    result lies in ``[reference - period/2, reference + period/2]`` element-wise.
    ``angles`` and ``reference`` are NumPy arrays of the same shape; ``angles`` is
    not modified.
    """
    wrapped = angles % period
    half = period / 2
    wrapped[wrapped - reference > half] -= period
    wrapped[wrapped - reference < -half] += period
    return wrapped


# Initial guess: rho and delta are the ground truth plus uniform perturbations of
# +/-65 and +/-40 degrees; eta starts at 90 degrees for every emitter.
# (rho is drawn before delta, as before, so seeded runs reproduce.)
rho_init = rho.detach().cpu().numpy() + 130. * (np.random.rand(n_emitters) - 0.5)
eta_init = np.full(n_emitters, 90.)
delta_init = delta.detach().cpu().numpy() + 80. * (np.random.rand(n_emitters) - 0.5)
params = torch.tensor(np.concatenate((rho_init, eta_init, delta_init)),
                      requires_grad=True, device=device)

# Adam over all angles at once (layout: [rho | eta | delta]).
optimizer = torch.optim.Adam([params], lr=0.9)

num_epochs_max = 250
loss_ = []
rho_ = []
eta_ = []
delta_ = []

# Ground-truth angles as NumPy arrays, computed once instead of every iteration.
rho_ref = rho.detach().cpu().numpy()
eta_ref = eta.detach().cpu().numpy()

for _ in range(num_epochs_max):
    optimizer.zero_grad()
    loss = general_loss(M.detach(),
                        params[:n_emitters],
                        params[n_emitters:2 * n_emitters],
                        params[2 * n_emitters:],
                        N_photons.detach(), psf_n.detach())
    loss_.append(loss.item())

    # Record the parameters before this step. `clone` gives a private copy so the
    # in-place optimizer update below cannot alter the stored history.
    snapshot = params.detach().cpu().clone().numpy()
    rho_.append(_wrap_to_reference(snapshot[:n_emitters], rho_ref))
    eta_.append(_wrap_to_reference(snapshot[n_emitters:2 * n_emitters], eta_ref))
    delta_.append(snapshot[2 * n_emitters:])

    loss.backward()
    optimizer.step()

In [None]:
# Mean per-emitter peak of the noiseless PSFs, then that peak relative to the total noise std.
signal_ = signal(psf.detach().cpu())
print(signal_ / sig_r)

In [None]:
plt.rcParams.update({'font.size': 16})
fig, ax = plt.subplots(1, 4, figsize=(25, 3.5))

# Estimates at the last recorded iteration, summarised in the panel titles.
rho_err_hist = np.array(rho_) - rho.detach().cpu().numpy()
rho_err_final = rho_err_hist[-1]
eta_final = eta_[-1]
delta_final = delta_[-1]


def _stats_title(values):
    """Title line with NaN-aware std and mean, rounded to 3 decimals."""
    return f'std = {np.nanstd(values):.3f}\n mean = {np.nanmean(values):.3f}'


ax[0].plot(loss_)
ax[0].set_xlabel('Iteration')
ax[0].set_ylabel('Loss', labelpad=0.1)
ax[0].grid()

ax[1].plot(rho_err_hist)
ax[1].set_xlabel('Iteration')
ax[1].set_ylabel('$\\Delta \\rho$ ($\\degree$)', labelpad=0.01)
ax[1].set_title(_stats_title(rho_err_final))
ax[1].grid()

# eta and delta share a single ground-truth value across all emitters.
ax[2].plot(eta_)
ax[2].axhline(eta[0].item(), c='r', label='Ground truth')
ax[2].legend()
ax[2].set_xlabel('Iteration')
ax[2].set_ylabel('$\\eta$ ($\\degree$)', labelpad=0.1)
ax[2].set_title(_stats_title(eta_final))
ax[2].grid()

ax[3].plot(delta_)
ax[3].axhline(delta[0].item(), c='r', label='Ground truth')
ax[3].legend()
ax[3].set_xlabel('Iteration')
ax[3].set_ylabel('$\\delta$ ($\\degree$)', labelpad=0.1)
ax[3].set_title(_stats_title(delta_final))
ax[3].grid()

In [None]:
# Display the photon count used for this simulation run.
N_photons

In [None]:
# Save the final per-emitter estimates of this run. The folder and file names are
# derived from the ground-truth (eta, delta) and the photon count instead of being
# hard-coded, so a run cannot overwrite the file of a different configuration.
# Layout matches the loading cells: <root>/2025_07_10_SNR_<eta>_<delta>/<N>_photons.npz
SAVE_ROOT = Path('D:/AMAURY/data_simu/stage')
# All emitters share one eta and one delta value, so the first entry identifies the run.
run_dir = SAVE_ROOT / f'2025_07_10_SNR_{int(eta[0].item())}_{int(delta[0].item())}'
run_dir.mkdir(parents=True, exist_ok=True)
np.savez_compressed(
    run_dir / f'{int(N_photons.item())}_photons.npz',
    signal=signal_.detach().numpy(),
    std=sig_r,
    floor=B,
    # rho is stored as the error w.r.t. ground truth; eta and delta as raw estimates.
    rho=np.array(rho_)[-1, :] - rho.detach().cpu().numpy(),
    eta=np.array(eta_)[-1, :],
    delta=np.array(delta_)[-1, :],
)

In [None]:
# Fitted results for ground truth eta = 70 deg, delta = 120 deg; one file per photon count.
folder = 'D:/AMAURY/data_simu/stage/2025_07_10_SNR_70_120'
Nphotons = [1000, 1500, 2000, 3000, 4000, 5000]
rhos, etas, deltas, floors, stds, signals = [], [], [], [], [], []
for n_ph in Nphotons:
    # An .npz NpzFile keeps the archive open until closed; `with` releases it right away.
    with np.load(f'{folder}/{n_ph}_photons.npz') as data:
        rhos.append(data['rho'])
        etas.append(data['eta'])
        deltas.append(data['delta'])
        stds.append(data['std'])
        signals.append(data['signal'])
        floors.append(data['floor'])

eta_true = 70.
delta_true = 120.

In [None]:
# Fitted results for ground truth eta = 45 deg, delta = 120 deg; one file per photon count.
folder = 'D:/AMAURY/data_simu/stage/2025_07_10_SNR_45_120'
Nphotons = [1000, 1500, 2000, 3000, 4000, 5000]
rhos, etas, deltas, floors, stds, signals = [], [], [], [], [], []
for n_ph in Nphotons:
    # An .npz NpzFile keeps the archive open until closed; `with` releases it right away.
    with np.load(f'{folder}/{n_ph}_photons.npz') as data:
        rhos.append(data['rho'])
        etas.append(data['eta'])
        deltas.append(data['delta'])
        stds.append(data['std'])
        signals.append(data['signal'])
        floors.append(data['floor'])

eta_true = 45.
delta_true = 120.

In [None]:
# Fitted results for ground truth eta = 20 deg, delta = 120 deg; one file per photon count.
folder = 'D:/AMAURY/data_simu/stage/2025_07_10_SNR_20_120'
Nphotons = [1000, 1500, 2000, 3000, 4000, 5000]
rhos, etas, deltas, floors, stds, signals = [], [], [], [], [], []
for n_ph in Nphotons:
    # An .npz NpzFile keeps the archive open until closed; `with` releases it right away.
    with np.load(f'{folder}/{n_ph}_photons.npz') as data:
        rhos.append(data['rho'])
        etas.append(data['eta'])
        deltas.append(data['delta'])
        stds.append(data['std'])
        signals.append(data['signal'])
        floors.append(data['floor'])

eta_true = 20.
delta_true = 120.

In [None]:
# Fitted results for ground truth eta = 70 deg, delta = 70 deg; one file per photon count.
folder = 'D:/AMAURY/data_simu/stage/2025_07_10_SNR_70_70'
Nphotons = [1000, 1500, 2000, 3000, 4000, 5000]
rhos, etas, deltas, floors, stds, signals = [], [], [], [], [], []
for n_ph in Nphotons:
    # An .npz NpzFile keeps the archive open until closed; `with` releases it right away.
    with np.load(f'{folder}/{n_ph}_photons.npz') as data:
        rhos.append(data['rho'])
        etas.append(data['eta'])
        deltas.append(data['delta'])
        stds.append(data['std'])
        signals.append(data['signal'])
        floors.append(data['floor'])

eta_true = 70.
delta_true = 70.

In [None]:
# Fitted results for ground truth eta = 45 deg, delta = 70 deg; one file per photon count.
folder = 'D:/AMAURY/data_simu/stage/2025_07_10_SNR_45_70'
Nphotons = [1000, 1500, 2000, 3000, 4000, 5000]
rhos, etas, deltas, floors, stds, signals = [], [], [], [], [], []
for n_ph in Nphotons:
    # An .npz NpzFile keeps the archive open until closed; `with` releases it right away.
    with np.load(f'{folder}/{n_ph}_photons.npz') as data:
        rhos.append(data['rho'])
        etas.append(data['eta'])
        deltas.append(data['delta'])
        stds.append(data['std'])
        signals.append(data['signal'])
        floors.append(data['floor'])

eta_true = 45.
delta_true = 70.

In [None]:
# Fitted results for ground truth eta = 20 deg, delta = 70 deg; one file per photon count.
folder = 'D:/AMAURY/data_simu/stage/2025_07_10_SNR_20_70'
Nphotons = [1000, 1500, 2000, 3000, 4000, 5000]
rhos, etas, deltas, floors, stds, signals = [], [], [], [], [], []
for n_ph in Nphotons:
    # An .npz NpzFile keeps the archive open until closed; `with` releases it right away.
    with np.load(f'{folder}/{n_ph}_photons.npz') as data:
        rhos.append(data['rho'])
        etas.append(data['eta'])
        deltas.append(data['delta'])
        stds.append(data['std'])
        signals.append(data['signal'])
        floors.append(data['floor'])

eta_true = 20.
delta_true = 70.

In [None]:
plt.rcParams.update({'font.size': 7})
fig, ax = plt.subplots(1, 3, figsize=(17, 3.5))

# x positions are photon counts in thousands; the top axis shows the matching SNR
# (mean peak signal / total noise std), truncated to an integer.
positions = np.array(Nphotons) / 1000
position_labels = positions.astype(str)
snr_labels = (np.array(signals) / np.array(stds)).astype(int).astype(str)

# (values per photon count, y label, y limits, ground truth or None)
panels = [
    (rhos, '$\\Delta \\rho$ ($^{\\circ}$)', (-90, 90), None),
    (etas, '$\\eta$ ($^{\\circ}$)', (0, 180), eta_true),
    (deltas, '$\\delta$ ($^{\\circ}$)', (0, 180), delta_true),
]

for axis, (values, ylabel, ylim, truth) in zip(ax, panels):
    axis.violinplot(values, positions=positions, showextrema=True, showmedians=True, widths=0.5)
    axis.grid()
    if truth is not None:
        axis.axhline(truth, c='r', label='ground truth')
    axis.set_xticks(positions)
    axis.set_xticklabels(position_labels)
    # Secondary x axis on top, aligned with the bottom one, labelled with the SNR.
    top = axis.twiny()
    top.set_xlim(axis.get_xlim())
    top.set_xticks(positions)
    top.set_xticklabels(snr_labels)
    top.set_xlabel('SNR')
    axis.set_ylabel(ylabel)
    axis.set_xlabel('Number of photon ($\\times$1000)')
    if truth is not None:
        axis.legend()
    axis.set_ylim(ylim)