## **Import libraries**

In [None]:
import numpy as np
import os
import tinycudann as tcnn
from utils.Deep_Models import *
from Parameters import *
from utils.Forward import *
from utils.Simulate_Data_Process import*
from utils.Training_models import *
import h5py
import commentjson as json
import time

## **Load ground truth object and probe**

In [None]:
def recenter_probe(probe):
    """Circularly shift a probe so its largest-magnitude pixel lands at the centre.

    The peak is the element with the largest ``probe.abs()``; if several
    elements tie, the first one returned by ``nonzero()`` (row-major order)
    is used. The target centre is ``size // 2`` along each of the first two
    axes, so non-square probes are centred correctly on both axes. Pixels
    shifted past an edge wrap around to the opposite edge (``torch.roll``).

    Args:
        probe: torch tensor (real or complex) whose first two axes are the
            spatial axes.

    Returns:
        A new tensor with the same shape and dtype as ``probe``, rolled so
        the peak sits at ``(shape[0] // 2, shape[1] // 2)``.
    """
    # Compute the magnitude once and reuse it for both the max and the search.
    magnitude = probe.abs()
    peak = (magnitude == torch.max(magnitude)).nonzero()[0]
    # Each axis uses its own centre; ints (not 0-d tensors) are passed to roll.
    shift = (
        probe.shape[0] // 2 - int(peak[0]),
        probe.shape[1] // 2 - int(peak[1]),
    )
    return torch.roll(probe, shifts=shift, dims=(0, 1))

In [None]:
# Load the ground-truth object (presumably complex-valued, since later cells
# take np.abs / np.angle of it) and keep rows/columns from index 303 onward.
crystal = np.load("data/crystal.npy")
crystal = crystal[303:, 303:]

In [None]:
# Display the shape of the cropped object (last expression is rendered by the notebook).
crystal.shape

In [None]:
# Load the probe used for the simulation from disk as a NumPy array.
probe = np.load(file="data/probe_for_sim.npy")

In [None]:
# Show the magnitude of the loaded probe as an image.
plt.imshow(np.abs(probe))

In [None]:
# Display the probe magnitudes as a flat 1-D array for a quick look at the value range.
np.abs(probe).flatten()

In [None]:
# Show the phase (np.angle) of the loaded probe as an image.
plt.imshow(np.angle(probe))

## **Crop the central 64 × 64 region of the real probe**

In [None]:
# Center-crop the loaded probe to a 64x64 window and move it to the GPU.
# The crop offset is derived from the probe size (a fixed 24-px margin only
# yields 64x64 for a 112x112 input), and the slice is taken on the host array
# so only the needed window is copied to the device.
probe_crop_size = 64
probe_rows, probe_cols = probe.shape[0], probe.shape[1]
if min(probe_rows, probe_cols) < probe_crop_size:
    raise ValueError(
        f"probe of shape {probe.shape} is smaller than the "
        f"{probe_crop_size}x{probe_crop_size} crop"
    )
row_start = (probe_rows - probe_crop_size) // 2
col_start = (probe_cols - probe_crop_size) // 2
probe = torch.tensor(
    probe[row_start:row_start + probe_crop_size, col_start:col_start + probe_crop_size]
).cuda()

## **Define the overlap ratio, noise level, and probe settings**

In [None]:
overlap = 0.95          # choose from 0.95, 0.9, 0.8, 0.6, 0.3
noise_tag = "clean"     # choose from clean, gaussian, poisson, combined
parameters["probe_known"] = False
# Derive the run tag from the flag so the tag cannot disagree with the actual
# setting when probe_known is toggled.
probe_tag = "probe_known" if parameters["probe_known"] else "probe_unknown"
parameters["probe_shape"] = "circ"      # choose from circ and rect

In [None]:
parameters["overlap_ratio"]=overlap

probe_size = probe.shape[0]
# Pixel distance between neighbouring scan positions.
step_size = round((1 - parameters["overlap_ratio"]) * probe_size)
if step_size < 1:
    # A zero step would make the modulo below divide by zero.
    raise ValueError(
        f"overlap ratio {parameters['overlap_ratio']} gives a scan step of "
        f"{step_size} px for a {probe_size} px probe; use a smaller overlap ratio"
    )

# Zero-pad the object on the right/bottom so that (object size - probe size)
# is a whole number of scan steps. The object size is read from `crystal`
# itself rather than hard-coded, so a different crop still pads correctly.
# NOTE(review): assumes `crystal` is a square 2-D array - confirm.
remainder = (crystal.shape[0] - probe_size) % step_size
pad_number = (step_size - remainder) % step_size  # 0 when already aligned
case_obj = np.pad(crystal, ((0, pad_number), (0, pad_number)))

# If the number of scan positions per axis is odd, add one more step of zero
# padding so that it becomes even.
n_positions = (case_obj.shape[0] - probe_size) // step_size + 1
if n_positions % 2 != 0:
    case_obj = np.pad(case_obj, ((0, step_size), (0, step_size)))

print("test object shape: ",case_obj.shape)
print("overlap ratio: ",parameters["overlap_ratio"])
parameters["obj_size"]=case_obj.shape[0]

## **Generate diffraction patterns**

In [None]:
# Generate the simulated diffraction data from the padded object's amplitude
# and phase, the cropped probe, the overlap ratio and the chosen noise setting.
h5 = diffraction_pattern_generate(
    amplitude_gt=torch.tensor(np.abs(case_obj)),
    phase_gt=torch.tensor(np.angle(case_obj)),
    overlap_ratio=overlap,
    probe=probe,
    parameters=parameters,
    noise=noise_tag,
)

In [None]:
# Show one simulated diffraction amplitude (frame index 10) as a sanity check.
# Index the frame directly rather than materialising the whole dataset with
# [()], and use "gray", which all Matplotlib versions register ("grey" is only
# an alias in recent releases).
plt.imshow(h5["diffamp"][10], cmap="gray")

## Training
Additional hyperparameters can be changed in `Parameters.py`.

In [None]:
# Run name and training hyperparameters for this experiment. Keys not set here
# keep their current value (presumably the defaults from Parameters.py).
parameters.update({
    "tag": f"crystal_{noise_tag}_PtyINR_{overlap}_{probe_tag}",
    "total_steps": 5000,
    "diffraction_scale": 1200,
    # Optimisation schedule.
    "train_method": "mini_batch",
    "batches": 3600,
    "LR": 6e-5,     # object amplitude
    "LR2": 6e-5,    # object phase
    "LR3": 5e-5,    # probe amplitude
    "LR4": 5e-5,    # probe phase
    "regularized_loss_weight": 1e-1,
    "regularized_steps": 50,
    "show_every": 100,
    "first_omega": 30,
    # Loss function.
    "loss": "SmoothL1",
    "beta_for_smoothl1": 1e-3,
})

In [None]:
# Train on the simulated diffraction data with the configured parameters and
# the cropped probe. train_model presumably comes from utils.Training_models
# via its star import - TODO confirm.
train_model(parameters,h5,probe)