In [None]:
# Project module first so its star import cannot shadow the explicit imports below.
from simu_PSF_polar import *  # presumably provides PSF, compute_M, noise, vectorial_BFP_perfect_focus

import copy
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm
# Default figure size for the notebook; individual plotting cells override it.
plt.rcParams.update({'figure.figsize': [5, 3]})

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
use_gpu = torch.cuda.is_available()
device = torch.device('cuda' if use_gpu else 'cpu')
print("Using GPU" if use_gpu else "Using CPU")

In [None]:
def limit(x, lim, slope, upper=True):
    '''Soft exponential wall penalising values of ``x`` beyond ``lim``.

    With ``upper=True`` the penalty is ``sum(exp(slope * (x - lim)))``, which
    grows exponentially as elements of ``x`` rise above ``lim``. With
    ``upper=False`` it is ``sum(exp(-slope * (x - lim)))``, which grows as
    elements fall below ``lim``. Each element contributes exactly 1 at
    ``x == lim``, so the penalty is never zero.

    Parameters
    ----------
    x : torch.Tensor or number
        Values to penalise. Non-tensor inputs are converted with
        ``torch.as_tensor``. Tensors are used as-is, so gradients flow.
    lim : float
        Position of the wall.
    slope : float
        Steepness of the wall (larger means sharper).
    upper : bool, optional
        True penalises ``x > lim`` (default). False penalises ``x < lim``.

    Returns
    -------
    torch.Tensor
        0-d tensor holding the summed penalty.
    '''
    x = torch.as_tensor(x)
    # +1 for an upper wall, -1 for a lower wall.
    sign = 1 if upper else -1
    return torch.sum(torch.exp(sign * (x - lim) * slope))
        

In [None]:
def _on_device(value):
    """Wrap a Python scalar as a tensor on ``device``; dtype is inferred from the value."""
    return torch.tensor(value, device=device)


# Integer literals become integer tensors and float literals float tensors,
# exactly as torch.tensor infers them.
N_photons = _on_device(5000.)
N = _on_device(80)
l_pixel = _on_device(16)
NA = _on_device(1.4)
mag = _on_device(100)
lambd = _on_device(617)
f_tube = _on_device(200)
MAG = _on_device(200 / 150)
del _on_device

In [None]:
# Seed NumPy's global RNG (used here and for the fit initialisations later) and
# PyTorch's RNG. This makes the simulated dataset reproducible across kernel restarts.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Random ground truth for 100 dipoles.
xp = torch.tensor(0.3*(2*np.random.rand(100)-1), device=device)  # uniform in [-0.3, 0.3)
yp = torch.tensor(0.3*(2*np.random.rand(100)-1), device=device)  # uniform in [-0.3, 0.3)
z = torch.tensor((1.+np.random.rand(100))*0.8, device=device)  # position of dipole in lambda units, uniform in [0.8, 1.6)
d_ = torch.full((100,), -1.5, device=device)  # defocus of dipole in lambda units (same for all); 4.4
rho = torch.tensor(10.+160.*np.random.rand(100), requires_grad=True, device=device)  # uniform in [10, 170)
eta = torch.tensor(30.+150.*np.random.rand(100), requires_grad=True, device=device)  # uniform in [30, 180); 70
delta = torch.tensor(70.+80.*np.random.rand(100), requires_grad=True, device=device)  # uniform in [70, 150); 50

In [None]:
# Quantities returned by vectorial_BFP_perfect_focus, consumed by compute_M below.
(x, y, th1, phi,
 [Ex0, Ex1, Ex2], [Ey0, Ey1, Ey2],
 r, r_cut, k_, f_o) = vectorial_BFP_perfect_focus(
    N, NA=NA, mag=mag, lambd=lambd, f_tube=f_tube, device=device)

In [None]:
# Per-channel settings handed to compute_M (three entries each, float32).
second_plane = torch.tensor((0.35, 0.0, -0.35), device=device)
polar_projections = torch.tensor((0.0, 45.0, 0.0), device=device)

In [None]:
# Forward-model quantities for the ground-truth dipoles (no aberrations).
# M is passed to PSF(); u and v serve as the pcolormesh coordinates when plotting.
u, v, M = compute_M(
    xp=xp, yp=yp, zp=z, d=d_,
    x=x, y=y, th1=th1, phi=phi,
    Ex0=Ex0, Ex1=Ex1, Ex2=Ex2,
    Ey0=Ey0, Ey1=Ey1, Ey2=Ey2,
    r=r, r_cut=r_cut, k=k_, f_o=f_o,
    second_plane=second_plane, polar_projections=polar_projections,
    N=N, l_pixel=l_pixel, NA=NA, mag=mag, lambd=lambd, f_tube=f_tube, MAG=MAG,
    device=device, aberrations=False, defocus_coef=1e-5, spherical_coef=-1.5,
)

In [None]:
# Noiseless PSF images for every simulated dipole (noise is added in a later cell).
psf = PSF(
    rho=rho, eta=eta, delta=delta,
    M=M, N_photons=N_photons,
)

In [None]:
# Inspect the noiseless PSF of one dipole: one panel per (i, j) image of psf[nn],
# all drawn on a shared intensity scale.
nn = 12  # index of the dipole to display
print('z = ', z[nn].cpu().detach().numpy())

psf_nn = psf[nn].detach().cpu()
vmax = torch.max(psf_nn).numpy()
vmin = torch.min(psf_nn).numpy()
u_np = u.detach().cpu().numpy()
v_np = v.detach().cpu().numpy()

plt.rcParams['figure.figsize'] = [15, 9]
fig, ax = plt.subplots(2, 3)
# Panel ax[j, i] shows psf[nn, i, j]: the first index selects the column and the
# second the row.
for i in range(3):
    for j in range(2):
        panel = ax[j, i]
        mesh = panel.pcolormesh(u_np, v_np, psf_nn[i, j].numpy(), cmap='gray', vmin=vmin, vmax=vmax)
        panel.set_xlim((-150, 150))
        panel.set_ylim((-150, 150))
        # Attach the colorbar to this panel explicitly. Without `ax=`, older
        # matplotlib versions take the space from the current axes instead.
        fig.colorbar(mesh, ax=panel, pad=0.15, label='Photon number')

In [None]:
def signal(psf):
    '''Average, over the leading axis, of each item's peak value.

    Each ``psf[k]`` is flattened and its maximum is taken. The result is the
    mean of those maxima, returned as a 0-d tensor.
    '''
    per_item_peak = psf.flatten(start_dim=1).max(dim=1).values
    return per_item_peak.mean()

In [None]:
def loss_pos(xp, yp, zp, rho, eta, delta, N_photons, data, second_plane, background, sigma, dim_simu, dim_data=6):
    '''Fitting loss for dipole positions and photon counts.

    Simulates PSFs for the candidate parameters, crops the last two axes to the
    window [dim_simu - dim_data, dim_simu + dim_data] (inclusive) and evaluates

        sum(h - (data + sigma**2) * log(h + background + sigma**2))

    It then adds soft exponential walls (see ``limit``) that keep xp, yp and
    zp inside fixed boxes.

    Besides its arguments, the function reads these notebook globals:
    d_, xx, yy, th1, phi, Ex0..Ex2, Ey0..Ey2, r, r_cut, k_, f_o,
    polar_projections, N, l_pixel, NA, mag, lambd, f_tube, MAG and device.
    ``xx``/``yy`` are only defined by a later cell, which must have run
    before the first call.

    Parameters
    ----------
    xp, yp, zp : torch.Tensor
        Candidate positions, one entry per dipole.
    rho, eta, delta : torch.Tensor
        Orientation parameters forwarded to ``PSF``.
    N_photons : torch.Tensor
        Photon counts forwarded to ``PSF``.
    data : torch.Tensor
        Measured images, already cropped to the same window as the model.
    second_plane : torch.Tensor
        Forwarded to ``compute_M``.
    background : float
        Constant added to the model inside the log term.
    sigma : float
        Noise standard deviation; ``sigma**2`` is added to data and model in the log term.
    dim_simu : int
        Centre pixel index of the simulated images.
    dim_data : int, optional
        Half-width of the cropped window (default 6, the crop applied to the data).

    Returns
    -------
    torch.Tensor
        Scalar loss: data term plus position penalties.
    '''
    _, _, M_ = compute_M(xp=xp, yp=yp, zp=zp, d=d_, x=xx, y=yy, th1=th1, phi=phi, Ex0=Ex0, Ex1=Ex1, Ex2=Ex2,
                         Ey0=Ey0, Ey1=Ey1, Ey2=Ey2, r=r, r_cut=r_cut, k=k_, f_o=f_o, second_plane=second_plane,
                         polar_projections=polar_projections, N=N, l_pixel=l_pixel, NA=NA, mag=mag,
                         lambd=lambd, f_tube=f_tube, MAG=MAG, device=device, aberrations=False)
    window = slice(dim_simu - dim_data, dim_simu + dim_data + 1)
    h = PSF(rho=rho, eta=eta, delta=delta, M=M_, N_photons=N_photons)[:, :, :, window, window]
    loss = torch.sum(h - (data + sigma**2) * torch.log(h + background + sigma**2))

    # Soft box constraints. 0.12 is presumably the pixel size in the position
    # units (i.e. +/-5 pixels laterally) -- TODO confirm. z is softly kept in [0, 5].
    xy_lim = 5*0.12
    x_bound = limit(xp, xy_lim, 100, upper=True) + limit(xp, -xy_lim, 100, upper=False)
    y_bound = limit(yp, xy_lim, 100, upper=True) + limit(yp, -xy_lim, 100, upper=False)
    z_bound = limit(zp, 5., 100, upper=True) + limit(zp, 0, 100, upper=False)
    # A photon-count wall, limit(N_photons, 10000., 0.01, upper=True), is not part of this loss.
    return loss + x_bound + y_bound + z_bound

In [None]:
def loss_N(xp, yp, zp, rho, eta, delta, N_photons, data, second_plane, background, sigma, dim_simu, dim_data=6):
    '''L1 mismatch between simulated and measured images after summing over axis 2.

    Simulates PSFs for the candidate parameters, crops the last two axes to the
    window [dim_simu - dim_data, dim_simu + dim_data] (inclusive) and returns
    ``sum(|sum(h - data, dim=2)|)``. Unlike ``loss_pos``, no soft box penalties
    are added. ``background`` and ``sigma`` are accepted so the signature
    matches ``loss_pos``, but they are unused.

    The function reads the same notebook globals as ``loss_pos``: d_, xx, yy,
    th1, phi, Ex0..Ex2, Ey0..Ey2, r, r_cut, k_, f_o, polar_projections, N,
    l_pixel, NA, mag, lambd, f_tube, MAG and device.

    Parameters
    ----------
    xp, yp, zp : torch.Tensor
        Candidate positions, one entry per dipole.
    rho, eta, delta : torch.Tensor
        Orientation parameters forwarded to ``PSF``.
    N_photons : torch.Tensor
        Photon counts forwarded to ``PSF``.
    data : torch.Tensor
        Measured images, already cropped to the same window as the model.
    second_plane : torch.Tensor
        Forwarded to ``compute_M``.
    background, sigma : float
        Unused.
    dim_simu : int
        Centre pixel index of the simulated images.
    dim_data : int, optional
        Half-width of the cropped window (default 6).

    Returns
    -------
    torch.Tensor
        Scalar loss.
    '''
    _, _, M_ = compute_M(xp=xp, yp=yp, zp=zp, d=d_, x=xx, y=yy, th1=th1, phi=phi, Ex0=Ex0, Ex1=Ex1, Ex2=Ex2,
                         Ey0=Ey0, Ey1=Ey1, Ey2=Ey2, r=r, r_cut=r_cut, k=k_, f_o=f_o, second_plane=second_plane,
                         polar_projections=polar_projections, N=N, l_pixel=l_pixel, NA=NA, mag=mag,
                         lambd=lambd, f_tube=f_tube, MAG=MAG, device=device, aberrations=False)
    window = slice(dim_simu - dim_data, dim_simu + dim_data + 1)
    h = PSF(rho=rho, eta=eta, delta=delta, M=M_, N_photons=N_photons)[:, :, :, window, window]
    # torch.abs instead of sqrt(x**2): same value, but the gradient at x == 0 is
    # 0 rather than NaN (sqrt's derivative is infinite at 0).
    return torch.sum(torch.abs(torch.sum(h - data, dim=2)))

In [None]:
# Noise parameters, then a noisy realisation of the noiseless PSFs.
background = 11.  # passed as b
sig_b = 6.        # passed as sigma_b
read = 1.5        # passed as sigma_r
bias = 11.
psf_n = noise(
    psf, QE=1, EM=1,
    b=background, sigma_b=sig_b, sigma_r=read, bias=bias,
)

In [None]:
# Crop the noisy data to the (2*dim_data+1) x (2*dim_data+1) window centred on
# pixel dim_simu. dim_simu is taken from the uncropped simulation `psf`, which
# loss_pos also slices with it. This keeps the cell idempotent: re-running it
# neither changes dim_simu nor crops psf_n a second time.
# Assumes noise() preserves the image size of psf (the fitting loop relies on this too).
dim_data = 6
dim_simu = int(psf.shape[-1] // 2)
if psf_n.shape[-1] == psf.shape[-1]:  # psf_n not cropped yet
    window = slice(dim_simu - dim_data, dim_simu + dim_data + 1)
    psf_n = psf_n[:, :, :, window, window]

In [None]:
# Same dipole as the noiseless plot, now with noise and cropped to the fit
# window. Axes are pixel indices, and all panels share one intensity scale.
print('z = ', z[nn].cpu().detach().numpy())

psf_n_nn = psf_n[nn].detach().cpu()
vmax = torch.max(psf_n_nn).numpy()
vmin = torch.min(psf_n_nn).numpy()

plt.rcParams['figure.figsize'] = [15, 9]
fig, ax = plt.subplots(2, 3)
# Panel ax[j, i] shows psf_n[nn, i, j].
for i in range(3):
    for j in range(2):
        panel = ax[j, i]
        mesh = panel.pcolormesh(psf_n_nn[i, j].numpy(), cmap='gray', vmin=vmin, vmax=vmax)
        # Explicit `ax=` keeps each colorbar next to its own panel on every matplotlib version.
        fig.colorbar(mesh, ax=panel, pad=0.15, label='Photon number')

In [None]:
# NOTE(review): identical call to the earlier compute_M cell, and none of its
# inputs are reassigned in between, so this looks redundant -- TODO confirm and drop.
u, v, M = compute_M(
    xp=xp, yp=yp, zp=z, d=d_,
    x=x, y=y, th1=th1, phi=phi,
    Ex0=Ex0, Ex1=Ex1, Ex2=Ex2,
    Ey0=Ey0, Ey1=Ey1, Ey2=Ey2,
    r=r, r_cut=r_cut, k=k_, f_o=f_o,
    second_plane=second_plane, polar_projections=polar_projections,
    N=N, l_pixel=l_pixel, NA=NA, mag=mag, lambd=lambd, f_tube=f_tube, MAG=MAG,
    device=device, aberrations=False, defocus_coef=1e-5, spherical_coef=-1.5,
)

In [None]:
# Same call as the earlier BFP cell, but the first two outputs are bound to
# xx / yy: loss_pos and loss_N read them through those global names.
(xx, yy, th1, phi,
 [Ex0, Ex1, Ex2], [Ey0, Ey1, Ey2],
 r, r_cut, k_, f_o) = vectorial_BFP_perfect_focus(
    N, NA=NA, mag=mag, lambd=lambd, f_tube=f_tube, device=device)

In [None]:
# Photon budgets to scan: 100 evenly spaced values from 800 to 8000 inclusive.
N_p_list = np.linspace(start=800, stop=8000, num=100)

In [None]:
# Output folder: one '<Np>.npz' file per photon budget (read back by the next cell).
out_dir = 'D:/AMAURY/data_simu/stage/2025_07_25_xy_loc_precision'
os.makedirs(out_dir, exist_ok=True)

# Loop-invariant settings. B and sig_r are handed to the likelihood: the total
# constant offset, and read and background noise combined in quadrature.
B = background + bias
sig_r = np.sqrt(read**2 + sig_b**2)
window = slice(dim_simu - dim_data, dim_simu + dim_data + 1)

for Np in N_p_list:
    # Fresh noisy dataset for this photon budget, cropped to the fit window.
    psf = PSF(rho=rho, eta=eta, delta=delta, M=M, N_photons=Np)
    psf_n = noise(psf, QE=1, EM=1, b=background, sigma_b=sig_b, sigma_r=read, bias=bias)
    psf_n = psf_n[:, :, :, window, window]

    # Initial guesses: x = y = 0, z perturbed uniformly by up to +/-0.5, 3000 photons.
    N_start = torch.full((100,), 3000., device=device)
    x_start = torch.zeros(100, device=device)
    y_start = torch.zeros(100, device=device)
    z_start = z + torch.tensor(np.random.rand(100)-0.5, device=device)

    # Orientations are held fixed during the fit. rho and delta are perturbed by
    # up to +/-20 and eta is set to 45. They are detached so backward() does not
    # accumulate unused gradients into the ground-truth leaves rho and delta.
    rho_ = (rho + torch.tensor(20*(2*np.random.rand(100)-1), device=device)).detach()
    delta_ = (delta + torch.tensor(20*(2*np.random.rand(100)-1), device=device)).detach()
    eta_ = torch.full((100,), 45., device=device)

    # Packed fit vector [x | y | z | N/1000]. Photon counts are divided by 1000,
    # presumably to give all four groups comparable magnitudes.
    params = torch.cat((x_start, y_start, z_start, N_start/1000))
    params.requires_grad = True

    # Adam optimizer, learning rate 0.1.
    optimizer = torch.optim.Adam([params], lr=0.1)

    num_epochs_max = 100
    loss_, x_, y_, z_, N_ = [], [], [], [], []

    for _ in tqdm(range(num_epochs_max)):
        optimizer.zero_grad()  # reset gradients
        loss = loss_pos(params[0:100], params[100:200], params[200:300], rho_, eta_, delta_,
                        params[300:]*1000, psf_n.detach(), second_plane, B, sig_r, dim_simu)
        loss_.append(loss.cpu().detach().numpy())

        # Parameters are recorded *before* this epoch's update.
        x_.append(params[:100].cpu().detach().numpy())
        y_.append(params[100:200].cpu().detach().numpy())
        z_.append(params[200:300].cpu().detach().numpy())
        N_.append(params[300:400].cpu().detach().numpy()*1000)

        loss.backward()   # backpropagation
        optimizer.step()  # update parameters

    # Save the errors of the last recorded estimate with respect to the ground truth.
    signal_ = signal(psf.cpu().detach())
    np.savez_compressed(
        os.path.join(out_dir, str(Np) + '.npz'),
        signal=signal_.detach().numpy(),
        std=sig_r,
        floor=B,
        x=x_[-1] - xp.cpu().detach().numpy(),
        y=y_[-1] - yp.cpu().detach().numpy(),
        z=z_[-1] - z.cpu().detach().numpy(),
        Nphotons_retreived=N_[-1],
    )

In [None]:
# Gather the saved fit results in increasing photon-count order, so that row k
# of every array below matches N_p_list[k] in the plots. os.listdir order is
# arbitrary, and even a name-sorted listing puts '8000.0.npz' before '872.7...npz'.
# NOTE(review): x, y, z and N shadow the simulation globals of the same names
# (z is read by the fitting loop, N by loss_pos/loss_N). Re-running the fit after
# this cell will fail -- TODO give these result arrays distinct names.
folder = 'D:/AMAURY/data_simu/stage/2025_07_25_xy_loc_precision'
x = []
y = []
z = []
N = []
floors = []
stds = []
signals = []

result_files = []
for file in os.listdir(folder):
    stem, ext = os.path.splitext(file)
    if ext != '.npz':
        continue
    try:
        n_photons = float(stem)
    except ValueError:
        continue  # not one of the '<Np>.npz' result files
    if n_photons in N_p_list:
        result_files.append((n_photons, file))

for _, file in sorted(result_files):
    # np.load returns an NpzFile that keeps the archive open; `with` closes it.
    with np.load(os.path.join(folder, file)) as data:
        x.append(data['x'])
        y.append(data['y'])
        z.append(data['z'])
        N.append(data['Nphotons_retreived'])
        stds.append(data['std'])
        signals.append(data['signal'])
        floors.append(data['floor'])


In [None]:
# Stack the per-file results into arrays (one row per photon budget).
x, y, z, stds, N, signals, floors = map(np.array, (x, y, z, stds, N, signals, floors))

In [None]:
# Distribution of the x localisation error, used to choose the outlier cut below.
plt.hist(x.flatten(), bins=1000)
plt.xlabel('x error')
plt.ylabel('Count')
# plt.xlim((-0.2, 0.2))  # zoom on the range kept by the outlier cut
plt.show()

In [None]:
# Discard lateral outliers (|error| > 0.2) in place by replacing them with NaN.
lateral_cut = 0.2
for err in (x, y):
    err[err**2 > lateral_cut**2] = np.nan

In [None]:
# Distribution of the z localisation error, used to choose the outlier cut below.
plt.hist(z.flatten(), bins=1000)
plt.xlabel('z error')
plt.ylabel('Count')
# plt.xlim((-1, 1))  # zoom on the central part
plt.show()

In [None]:
# Discard axial outliers (|error| > 1.2) in place by replacing them with NaN.
z[np.greater(np.square(z), 1.2**2)] = np.nan

In [None]:
# Localisation bias: mean error (nm) versus photon count, ignoring NaN outliers.
plt.rcParams.update({'font.size': 16})
plt.rcParams['figure.figsize'] = [3, 2]

fig, ax = plt.subplots(1)
ax.scatter(N_p_list, 1000*np.nanmean(x, axis=1), c='r', label='x localization', s=8)
ax.scatter(N_p_list, 1000*np.nanmean(y, axis=1), c='b', label='y localization', s=8)
ax.legend(fontsize='small')  # show the x / y labels attached to the scatters
ax.grid()
ax.set_xlim((800, 7000))
ax.set_ylim((-10, 10))
ax.set_xlabel('$ N_{\\mathrm{PHOTONS}}$')
ax.set_ylabel('Mean error (nm)')
plt.show()

fig, ax = plt.subplots(1)
ax.scatter(N_p_list, 1000*np.nanmean(z, axis=1), c='g', label='z localization', s=8)
ax.legend(fontsize='small')
ax.grid()
ax.set_xlim((800, 7000))
ax.set_ylim((-100, 100))
ax.set_xlabel('$ N_{\\mathrm{PHOTONS}}$')
ax.set_ylabel('Mean error (nm)')
plt.show()

In [None]:
# Localisation precision and photon-count retrieval versus photon budget.
plt.rcParams.update({'font.size': 10})
plt.rcParams['figure.figsize'] = [18, 3.]
fig, ax = plt.subplots(1, 3)

# Panel 0: lateral precision, the std (nm) of the x / y errors with NaN outliers ignored.
# An IQR-based alternative would be np.percentile(x, 75, axis=1) - np.percentile(x, 25, axis=1).
ax[0].scatter(N_p_list, 1000*np.nanstd(x, axis=1), c='r', label='x localization', s=8)
ax[0].scatter(N_p_list, 1000*np.nanstd(y, axis=1), c='b', label='y localization', s=8)
ax[0].legend(loc='lower left')
ax[0].grid()
ax[0].set_yscale('log')
ax[0].set_xlim((800, 7000))
ax[0].set_ylim((4, 200))
ax[0].set_xlabel('$ N_{\\mathrm{PHOTONS}}$')
# Backslash escaped: '\s' is an invalid string escape (SyntaxWarning on Python 3.12+).
ax[0].set_ylabel('$ \\sigma_{\\mathrm{lat}}$ (nm)')

# Panel 1: axial precision. z holds NaNs from the outlier mask, so nanstd is
# required; np.std returns NaN for any row that contains one.
ax[1].scatter(N_p_list, 1000*np.nanstd(z, axis=1), c='g', s=8)
ax[1].grid()
ax[1].set_yscale('log')
ax[1].set_xlim((800, 7000))
ax[1].set_xlabel('$ N_{\\mathrm{PHOTONS}}$')
ax[1].set_ylabel('$ \\sigma_{\\mathrm{z}}$ (nm)')

# Panel 2: retrieved photon count of every dipole (red dots), a band of
# mean +/- std/2 per photon budget (yellow), and the identity line.
N_mean = np.mean(N, axis=1)
N_half_std = np.std(N, axis=1) / 2
ax[2].scatter(np.repeat(N_p_list, N.shape[1]), N.ravel(), c='r', s=2)
ax[2].fill_between(N_p_list, N_mean - N_half_std, N_mean + N_half_std, color='y', alpha=0.8)
ax[2].plot(N_p_list, N_p_list)
ax[2].grid()
ax[2].set_xlim((800, 7000))
ax[2].set_xlabel('$ N_{\\mathrm{PHOTONS}}$')
ax[2].set_ylabel('$ N_{\\mathrm{RETRIEVED}}$')
plt.show()