## **Import libraries**

In [None]:
import numpy as np
import tinycudann as tcnn
from utils.Deep_Models import *
from Parameters import *
from utils.Forward import *
from utils.Simulate_Data_Process import*
from utils.Training_models import *
import h5py
import commentjson as json
import time

## **Load ground truth object and probe**

In [None]:
def recenter_probe(probe):
    """Circularly shift a probe so its peak-amplitude pixel sits at the array centre.

    Args:
        probe: torch tensor (real or complex) whose first two dimensions are the
            spatial axes; the peak is located on ``probe.abs()``.

    Returns:
        A new tensor of the same shape, rolled along dims (0, 1) so that the first
        (row-major) maximum of ``|probe|`` lands at ``(rows // 2, cols // 2)``.
    """
    magnitude = probe.abs()
    # nonzero() lists matches in row-major order, so [0] is the first maximum on ties.
    peak = (magnitude == magnitude.max()).nonzero()[0]
    peak_row, peak_col = int(peak[0]), int(peak[1])

    # Use each axis' own centre so non-square probes are handled correctly.
    center_row = probe.shape[0] // 2
    center_col = probe.shape[1] // 2
    # Plain Python ints (not 0-d tensors) for torch.roll's shifts argument.
    shifts = (center_row - peak_row, center_col - peak_col)
    return torch.roll(probe, shifts=shifts, dims=(0, 1))

In [None]:
# Load the stored crystal object, then keep only rows/columns from index 303 onward.
crystal = np.load("data/crystal.npy")
crystal = crystal[303:, 303:]

In [None]:
# Load the real probe array from disk.
probe = np.load(file="data/probe.npy")

In [None]:
# Display the probe amplitude |probe|.
plt.imshow(np.abs(probe))

In [None]:
# Display the probe phase (np.angle, in radians).
plt.imshow(np.angle(probe))

In [None]:
# Wrap the probe array as a torch tensor (torch.tensor always copies the data).
probe = torch.tensor(probe)

In [None]:
# Back to a NumPy array; combined with the previous cell this round trip leaves
# the probe values unchanged.
probe = probe.numpy()

## **Crop the central 64 × 64 region of the real probe**

In [None]:
# Centre-crop the probe to a 64x64 window. For a 112x112 probe this is exactly
# the original fixed 24-px trim on every side; other sizes now also yield 64x64.
target_size = 64
if probe.shape[0] < target_size or probe.shape[1] < target_size:
    raise ValueError(
        f"probe of shape {tuple(probe.shape)} is smaller than the "
        f"{target_size}x{target_size} crop"
    )
row0 = (probe.shape[0] - target_size) // 2
col0 = (probe.shape[1] - target_size) // 2
probe = probe[row0:row0 + target_size, col0:col0 + target_size]

In [None]:
# Convert the cropped probe to a torch tensor, then move it onto the GPU.
probe = torch.tensor(probe)
probe = probe.cuda()

In [None]:
# Peak amplitude of the probe (shown as the cell output); .cpu() first because
# the probe now lives on the GPU.
np.abs(probe.cpu()).max()

## **Define the overlap ratio and noise level**

In [None]:
# Experiment settings.
overlap = 0.95  # overlap ratio between neighbouring probe positions; values used: 0.95, 0.9, 0.8, 0.6, 0.3
# Noise setting passed to the data simulator; presumably one of clean, gaussian_low,
# gaussian_high, poisson_low, poisson_high, combined -- TODO confirm the exact tag strings.
noise_tag = "poisson_high"
parameters["probe_known"] = False
# Derive the run-name tag from the flag so the two settings cannot drift apart.
probe_tag = "probe_known" if parameters["probe_known"] else "probe_unknown"
parameters["probe_shape"] = "circ"  # "circ" or "rect"

In [None]:
# Derive the scan step from the overlap ratio, then zero-pad the object
# (bottom/right) so the probe positions tile it exactly with an even number of
# positions per axis. The object is treated as square: only shape[0] is used.
parameters["overlap_ratio"] = overlap
probe_size = probe.shape[0]
step_size = round((1 - parameters["overlap_ratio"]) * probe_size)
if step_size < 1:
    # An overlap this close to 1 would give a 0-px step and divide by zero below.
    raise ValueError(
        f"overlap_ratio={parameters['overlap_ratio']} gives a scan step of "
        f"{step_size} px for a {probe_size}-px probe; need at least 1 px"
    )

# Pad so (object size - probe size) is a multiple of the step (0 if already so).
# Use the crystal's real size so this works for any crop of crystal.npy.
pad_number = -(crystal.shape[0] - probe_size) % step_size
if pad_number:
    case_obj = np.pad(crystal, ((0, pad_number), (0, pad_number)))
else:
    case_obj = crystal

# Probe positions per axis; pad one extra step when that count is odd.
n_positions = (case_obj.shape[0] - probe_size) // step_size + 1
if n_positions % 2 != 0:
    case_obj = np.pad(case_obj, ((0, step_size), (0, step_size)))

print("test object shape: ",case_obj.shape)
print("overlap ratio: ",parameters["overlap_ratio"])
parameters["obj_size"] = case_obj.shape[0]

In [None]:
# Show the current parameters mapping as the cell output.
parameters

## **Generate diffraction patterns**

In [None]:
# NOTE(review): this rebinds case_obj to the unpadded crystal, discarding any
# zero padding applied while computing the scan geometry, while
# parameters["obj_size"] still holds the padded size. Confirm this is intended
# (e.g. that SCAN_Iterative_Methods_Process pads internally).
case_obj = crystal

In [None]:
# Generate the diffraction data from the ground-truth amplitude and phase of
# case_obj, using the chosen probe, overlap ratio and noise setting.
h5 = SCAN_Iterative_Methods_Process(
    amplitude_gt=torch.tensor(np.abs(case_obj)),
    phase_gt=torch.tensor(np.angle(case_obj)),
    overlap_ratio=overlap,
    probe=probe,
    parameters=parameters,
    noise=noise_tag,
)

In [None]:
# Show the noise setting that was passed to the data generator.
noise_tag

## **Iterative methods**

### 1) ePIE

In [None]:
# ePIE run configuration. The tag is a run identifier built from the settings.
parameters["tag"] = "_".join(["crystal", noise_tag, "ePIE", str(overlap), probe_tag])
parameters["total_steps"] = 300
parameters["a"] = 1
parameters["b"] = 1
parameters["probe_update_start"] = 0

In [None]:
# Run the ePIE reconstruction on the generated data.
train_model(
    parameters,
    h5,
    probe,
    model_name="ePIE",
    trained=False,
)

### 2) DM

In [None]:
# Difference-map (DM) run configuration. The tag is a run identifier built from the settings.
parameters["tag"] = "_".join(["crystal", noise_tag, "DM", str(overlap), probe_tag])
parameters["total_steps"] = 300
parameters["DM_alpha"] = 0.1
parameters["fourier_relax_factor"] = 0.05
parameters["probe_update_start"] = 0

In [None]:
# Run the DM reconstruction on the generated data.
train_model(
    parameters,
    h5,
    probe,
    model_name="DM",
    trained=False,
)

## **Deep methods**

### 1) PtyINR

In [None]:
# PtyINR run configuration. The tag is a run identifier built from the settings.
parameters["tag"] = "_".join(["crystal", noise_tag, "Pty_INR", str(overlap), probe_tag])
parameters["total_steps"] = 3000
parameters["diffraction_scale"] = 1600
parameters["batches"] = [9800, 9800]
# Learning rates: LR -> object amplitude, LR2 -> object phase,
# LR3 -> probe amplitude, LR4 -> probe phase.
parameters["LR"] = 1e-4
parameters["LR2"] = 1e-4
parameters["LR3"] = 5e-5
parameters["LR4"] = 5e-5
parameters["regularized_loss_weight"] = 1e-2
parameters["regularized_steps"] = 50
parameters["show_every"] = 50
parameters["first_omega"] = 60
parameters["loss"] = "SmoothL1"
parameters["beta_for_smoothl1"] = 1e-5

# Train the PtyINR model on the generated data.
train_model(parameters, h5, probe, model_name="Pty_INR", trained=False)