In [None]:
import h5py, sys, os, time
import numpy as np
import scipy.ndimage as ndimage
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.cm as cm

In [None]:
import spimage

In [None]:
import sys; sys.path.append("../../offline/")
import sparse, geom, utils

In [None]:
# Make Dragonfly's Python utilities (writeemc, detector, reademc) importable.
# The location can be overridden with the DRAGONFLY_PY_SRC environment variable;
# it defaults to the originally hard-coded path.
sys.path.append(os.environ.get('DRAGONFLY_PY_SRC', '/mnt/cbis/home/benedikt/.local/dragonfly/utils/py_src'))

In [None]:
import writeemc
import detector
import reademc

In [None]:
# Input/output locations relative to this notebook; the trailing comments give
# the corresponding locations on Maxwell.
path_to_recons   = "../../data/recons/"
path_to_geometry = "../../geometry/"
path_to_aux      = "../../data/aux/"     # scratch/benedikt/aux on Maxwell
path_to_data     = "../../data/sparse/"  # scratch/sparse on Maxwell

In [None]:
# Files of the "sucrose_0000" reconstruction: the EMC output file output_060.h5 of
# run r0, the photons list file and the detector file.
emc_folder   = f"{path_to_recons}sucrose_0000/"
emc_output   = f"{emc_folder}data/r0/output_060.h5"
emc_photons  = f"{emc_folder}photons.txt"
emc_detector = f"{path_to_recons}det/det_2145_lowq5.h5"

In [None]:
# photons.txt names the EMC data file(s) relative to emc_folder (assumed one per line;
# '#' starts a comment, as with np.loadtxt). A single entry yields one path string as
# before; several entries yield a list of paths instead of a garbled str(array).
# NOTE(review): confirm reademc.EMCReader accepts a list when several files are given.
with open(emc_photons) as _fh:
    _photon_files = [line.split('#', 1)[0].strip() for line in _fh]
_photon_files = [emc_folder + name for name in _photon_files if name]
photons_list = _photon_files[0] if len(_photon_files) == 1 else _photon_files

## Load EMC reconstruction

In [None]:
# Read the EMC result arrays; 'intens' holds the reconstructed 2D class images.
with h5py.File(emc_output, 'r') as f:
    occupancies, likelihood, orientations, images, scale = (
        f[key][:] for key in ('occupancies', 'likelihood', 'orientations', 'intens', 'scale'))

## Determine modes

In [None]:
# Decode orientation indices as orientation = rotation_index * nr_modes + mode for
# the nr_rot rotations; indices beyond that range are mapped to the static modes
# (mode >= nr_modes).
nr_rot = 180
nr_static = 0
nr_modes = occupancies.shape[1] - nr_static
rotind, modes = np.divmod(orientations, nr_modes)
modes[rotind >= nr_rot] = orientations[rotind >= nr_rot] - nr_modes * (nr_rot - 1)
# A frame whose occupancies are all zero counts as blacklisted.
blacklisted = ~np.any(occupancies != 0, axis=1)
ndata = len(modes) - blacklisted.sum()

## Inspect 2D EMC classes

In [None]:
# One panel per 2D class, titled with the number (and percentage) of non-blacklisted
# frames assigned to it; the class index is drawn in red at the top-left corner.
N = occupancies.shape[1]
ncols = 5
# Round up so every class gets a panel even when N is not a multiple of ncols;
# squeeze=False keeps `axes` 2-D even when there is only one row.
nrows = max(1, -(-N // ncols))
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*4,nrows*4), dpi=200, squeeze=False)
for ax in axes.flat:
    ax.axis('off')  # also hides the unused panels of the last row
for i in range(N):
    ax = axes[i // ncols, i % ncols]
    nclass = ((modes==i) & (~blacklisted)).sum()
    percent = nclass/ndata*100. if ndata else 0.  # avoid 0-division if every frame is blacklisted
    ax.set_title("%d/%d [%.1f%%]" %(nclass, ndata, percent))
    ax.imshow(images[i], norm=colors.LogNorm(vmin=0.0001, vmax=3), cmap='cividis')
    ax.text(0,0, "%d" %i, color='r')
os.makedirs("../../plots", exist_ok=True)  # savefig does not create directories
plt.savefig("../../plots/2dclasses.png", bbox_inches='tight')
plt.show()

## Blacklist non-spherical classes

In [None]:
# Flag (1) every frame assigned to one of the classes listed in `remove`; ndata is
# the number of frames left unflagged.
# NOTE(review): frames in `blacklisted` (all-zero occupancies) are only flagged here
# if their mode is in `remove` — confirm that is intended.
remove = [0,5]
blacklist = np.isin(modes, remove).astype(int)
ndata = blacklist.shape[0] - blacklist.sum()

In [None]:
# One integer flag per line; the file name embeds a literal 0 and the number of
# frames kept (ndata). NOTE(review): meaning of the leading 0 is not shown here.
np.savetxt(fname=emc_folder + "blacklist_%d_%d.dat" % (0, ndata), X=blacklist, fmt='%d')

## Simulate and save spherical templates

In [None]:
# Beam/detector parameters for simulating the spherical templates.
# NOTE(review): the fit section later re-binds these with distance = 0.732 m.
material = 'sucrose'
intensity0 = 1e7                                  # incident intensity handed to spimage
photon_energy = 6.01                              # keV
wavelength = 1240. / photon_energy / 1e3 * 1e-9   # m
distance = 0.705                                  # m
pixelsize = 200e-6                                # m
rmax = 100

In [None]:
# Detector sampling
boxsize = (181,181)
Y,X = spimage.grid(boxsize, (0,0))
Xc = spimage.x_to_qx(X,pixelsize,distance)
Yc = spimage.y_to_qy(Y,pixelsize,distance)
q = np.sqrt(Xc**2 + Yc**2)

In [None]:
# Spherical template
def template(diameter, intensity=None, q_map=None):
    """Simulate the diffraction pattern of a homogeneous sphere with spimage.

    Parameters
    ----------
    diameter : float
        Sphere diameter in metres (callers pass values scaled by 1e-9).
    intensity : float, optional
        Incident intensity passed to spimage; defaults to the global ``intensity0``.
    q_map : ndarray, optional
        Map of |q| on which the pattern is evaluated; defaults to the global ``q``.
        Passing it explicitly avoids silently using a later re-binding of ``q``
        (the fit section of this notebook re-defines ``q`` on another grid).

    wavelength, pixelsize, distance and material are read from the globals.

    Returns
    -------
    ndarray of float64, presumably shaped like ``q_map``.
    """
    if intensity is None:
        intensity = intensity0
    if q_map is None:
        q_map = q
    A = spimage.sphere_model_convert_intensity_to_scaling(intensity, diameter, wavelength, pixelsize,
                                                          distance, material=material)
    s = spimage.sphere_model_convert_diameter_to_size(diameter, wavelength, pixelsize, distance)
    return spimage.I_sphere_diffraction(A, q_map, s).astype(np.float64)

In [None]:
# Template diameters: N values starting at dmin nm in steps of dstep nm, in metres.
N = 100
dmin = 50
dstep = 2
dsamples = (dmin + dstep * np.arange(N)) * 1e-9

In [None]:
# One simulated sphere pattern per diameter in dsamples, stacked along axis 0.
templates = np.array(list(map(template, dsamples)))

In [None]:
# Write the templates under 'intens' (the key the EMC output uses for class images)
# to data/r2/init.h5, together with the parameters used to simulate them.
init_file = emc_folder + "data/r2/init.h5"
# h5py.File(..., "w") fails if the run directory does not exist yet.
os.makedirs(os.path.dirname(init_file), exist_ok=True)
with h5py.File(init_file, "w") as f:
    f['intens'] = templates.astype("<f8")
    # One scale factor of 1 per frame. NOTE(review): stored as int32 — confirm the
    # consumer converts it (the EMC output's 'scale' is read without conversion above).
    f['scale'] = np.ones(occupancies.shape[0], dtype=np.int32)
    f['dsamples'] = dsamples
    f['intensity'] = intensity0
    f['wavelength'] = wavelength
    f['pixelsize'] = pixelsize
    f['distance'] = distance
    f['material'] = material

## Pick one sucrose 2D class

In [None]:
# Load every frame assigned to class m: assembled images and raw module data
# reshaped to (frame, 4, 128, 128).
# NOTE(review): views_assembled is not used later in this notebook; blacklisted
# frames are not excluded.
m = 14
view_index = np.nonzero(modes == m)[0]
det = detector.Detector(emc_detector, mask_flag=True)
emc = reademc.EMCReader(photons_list, det)
views_assembled = np.array([emc.get_frame(idx) for idx in view_index])
views_modules = np.array([emc.get_frame(idx, raw=True) for idx in view_index]).reshape((-1, 4, 128, 128))

In [None]:
# Rotation index and EMC scale factor of each frame of class m (same order as
# view_index), and the x/y pixel maps of the detector geometry.
rot = rotind[modes == m]
sca = scale[modes == m]
# Use the configured geometry folder rather than repeating its path.
x,y = geom.pixel_maps_from_geometry_file(path_to_geometry + "b2_lowq.geom")

In [None]:
def test_geometry(dx,dy,N=10, rotation=True):
    """Assemble the first N frames of class m with shifted modules and average them.

    Parameters
    ----------
    dx, dy : sequence of float
        Offsets added to the x / y pixel maps of each 128x128 module (module i gets
        dx[i], dy[i]); shorter sequences leave the remaining modules unshifted.
    N : int
        Number of frames taken from the global ``views_modules``.
    rotation : bool
        If True, frame j and the mask are rotated by ``rot[j]`` (degrees, as taken
        by ``ndimage.rotate``) and each is added together with its 180-degree copy.

    Returns
    -------
    avg : ndarray
        Weight-normalised average; 0 wherever the accumulated weight is below 1.
    valid : ndarray of bool
        ``weights >= 1``; exactly the pixels where ``avg`` holds data.
    """
    x0 = x.copy().reshape(4,128,128)
    y0 = y.copy().reshape(4,128,128)
    for i, shift in enumerate(dx):
        x0[i] += shift
    for i, shift in enumerate(dy):
        y0[i] += shift
    x0 = x0.reshape(512,128)
    y0 = y0.reshape(512,128)
    # Assemble the raw mask and an all-ones image with the same [::-1, ::-1] flip
    # that is applied to the frames below.
    mask1 = geom.apply_geom_ij_yx((y0, x0), det.raw_mask.reshape(4,128,128))[::-1,::-1]
    mask2 = geom.apply_geom_ij_yx((y0, x0), np.ones((4,128,128)))[::-1,::-1]
    # Good pixels: not flagged in the raw mask and covered by a module.
    # np.float was removed in NumPy 1.24; the builtin float is the same float64 dtype.
    mask = ((mask1 == 0) & (mask2 == 1)).astype(float)
    # Array accumulators, so N == 0 yields zeros instead of failing on a Python float.
    avg = np.zeros(mask.shape)
    weights = np.zeros(mask.shape)
    for j in range(N):
        img = geom.apply_geom_ij_yx((y0, x0), views_modules[j])[::-1,::-1]
        if rotation:
            img = ndimage.rotate(img, rot[j], reshape=False)
            msk = ndimage.rotate(mask, rot[j], reshape=False)
            avg += (img + ndimage.rotate(img, 180., reshape=False))
            weights += (msk + ndimage.rotate(msk, 180., reshape=False))
        else:
            avg += img
            weights += mask
    valid = weights >= 1
    avg[valid] /= weights[valid]
    # Zero exactly the pixels outside the returned mask; the former `weights <= 1`
    # also wiped valid pixels whose weight was exactly 1 (e.g. every pixel for N=1).
    avg[~valid] = 0.
    return avg, valid

In [None]:
def radial_averages(dx,dy,N=10):
    """Radial mean profile of each of the first N frames of class m.

    Frames are assembled from the global ``views_modules`` with per-module offsets
    ``dx``/``dy`` added to the pixel maps (module i gets dx[i], dy[i]); no rotation
    is applied. Profiles are taken about the middle pixel (shape // 2) using the
    assembled good-pixel mask.

    Returns
    -------
    centers, radials : ndarray
        The ``spimage.radialMeanImage(..., output_r=True)`` outputs of all frames,
        stacked with one row per frame.
    """
    x0 = x.copy().reshape(4,128,128)
    y0 = y.copy().reshape(4,128,128)
    for i, shift in enumerate(dx):
        x0[i] += shift
    for i, shift in enumerate(dy):
        y0[i] += shift
    x0 = x0.reshape(512,128)
    y0 = y0.reshape(512,128)
    mask1 = geom.apply_geom_ij_yx((y0, x0), det.raw_mask.reshape(4,128,128))[::-1,::-1]
    mask2 = geom.apply_geom_ij_yx((y0, x0), np.ones((4,128,128)))[::-1,::-1]
    # Good pixels: not flagged in the raw mask and covered by a module.
    # np.float was removed in NumPy 1.24; the builtin float is the same float64 dtype.
    mask = ((mask1 == 0) & (mask2 == 1)).astype(float)
    centers = []
    radials = []
    for j in range(N):
        img = geom.apply_geom_ij_yx((y0, x0), views_modules[j])[::-1,::-1]
        center, radial = spimage.radialMeanImage(img, msk=mask, cx=img.shape[1]//2, cy=img.shape[0]//2, output_r=True)
        centers.append(center)
        radials.append(radial)
    return np.array(centers), np.array(radials)

In [None]:
# Average all frames of class m with the nominal (unshifted) geometry:
# test0 with per-frame rotation and 180-degree symmetrisation, test1 without.
t0 = time.time()
test0, mask0 = test_geometry([0]*4, [0]*4, N=len(rot), rotation=True)
test1, mask1 = test_geometry([0]*4, [0]*4, N=len(rot), rotation=False)
print(time.time() - t0)  # elapsed seconds

In [None]:
# Further restrict the rotation-averaged mask with the complement of
# spimage.rmask(shape, 19) — presumably a disc of radius 19 pixels around the
# centre (beamstop / low-q region); TODO confirm. In place, so it edits mask0 itself.
mask0 *= ~spimage.rmask(mask0.shape, 19)

In [None]:
# Masked rotation-averaged class image on a log colour scale.
fig, ax = plt.subplots(figsize=(10,10))
im = ax.imshow(test0*mask0, norm=colors.LogNorm())
fig.colorbar(im, ax=ax)
plt.show()

In [None]:
# Radial mean profiles of both class averages about the middle pixel.
centers0, radial0 = spimage.radialMeanImage(test0, msk=mask0, output_r=True,
                                            cy=test0.shape[0]//2, cx=test0.shape[1]//2)
centers1, radial1 = spimage.radialMeanImage(test1, msk=mask1, output_r=True,
                                            cy=test1.shape[0]//2, cx=test1.shape[1]//2)

## Parameters for the sphere fit

In [None]:
# Beam/detector parameters for the sphere fits below.
# NOTE(review): these re-bind the globals set before the template simulation;
# `distance` changes from 0.705 m to 0.732 m (the other values are unchanged), so
# template() would now use 0.732 m as well. Confirm which distance is correct.
photon_energy = 6.01 # keV
wavelength = 1240. / photon_energy / 1e3 * 1e-9 # m
distance = 0.732 # m
pixelsize = 200e-6 # m
rmax = 100
material = 'sucrose'

In [None]:
# Fit on copies so the class average and its mask stay untouched.
photons, mask = np.copy(test0), np.copy(mask0)

## Fit diameter

In [None]:
# Fit the sphere diameter to the rotation-averaged class image ('pearson' method),
# starting from 140e-9 with a brute-force pre-search (100 evaluations, up to 200e-9).
diameter, info = spimage.fit_sphere_diameter(
    photons, mask, 140e-9, 0.5e-9, wavelength, pixelsize, distance,
    method='pearson', full_output=True, x0=0, y0=0,
    material=material, rmax=rmax, downsampling=1,
    detector_adu_photon=1., detector_quantum_efficiency=1., do_photon_counting=True,
    do_brute_evals=100, brute_dmax=200e-9, maxfev=1000)

In [None]:
# Fitted diameter converted from metres to nanometres (shown as cell output).
diameter*1e9

In [None]:
#diameter = 65e-9

In [None]:
# Fit the intensity ('nrphotons' method), passing in the diameter found above.
intensity, info = spimage.fit_sphere_intensity(
    photons, mask, diameter, 1e-9, wavelength, pixelsize, distance,
    method='nrphotons', full_output=True, x0=0, y0=0,
    material=material, rmax=rmax, downsampling=1,
    detector_adu_photon=1., detector_quantum_efficiency=1., do_photon_counting=True)

In [None]:
# Intensity returned by spimage.fit_sphere_intensity (shown as cell output).
intensity

In [None]:
# Refine centre offsets (dx, dy), diameter and intensity together, starting from
# the two separate fits above.
# NOTE(review): do_photon_counting is False here but True in the preceding fits.
dx, dy, diameter, intensity, info = spimage.fit_full_sphere_model(
    photons, mask, diameter, intensity, wavelength, pixelsize, distance,
    full_output=True, x0=0, y0=0, n=2, deltab=0.5,
    material=material, rmax=rmax, downsampling=1,
    detector_adu_photon=1., detector_quantum_efficiency=1., do_photon_counting=False)

In [None]:
# Result of the full fit: diameter (m), intensity and centre offsets dx, dy (pixels).
diameter, intensity, dx,dy

In [None]:
#diameter *= 2
#intensity = 0.8e7

In [None]:
# Detector sampling for the fit: |q| on the grid of the averaged image, shifted by
# the fitted centre offsets. NOTE: re-binds X, Y, Xc, Yc and q from the template cells.
Y, X = spimage.grid(photons.shape, (0, 0))
Yc = spimage.y_to_qy(Y - dy, pixelsize, distance)
Xc = spimage.x_to_qx(X - dx, pixelsize, distance)
q = np.sqrt(np.square(Xc) + np.square(Yc))
# q at integer radii 0, 1, ... up to half the image height.
qr = spimage.x_to_qx(np.arange(0, photons.shape[0]/2.), pixelsize, distance)

In [None]:
# Radial average of the data about the fitted centre, truncated to
# photons.shape[0]//2 bins and converted to q.
# Centre = middle pixel (n-1)/2 plus the fitted offset; for odd n this equals the
# n//2 used for the class averages above, whereas the former n/2 was half a pixel
# off. Assumes spimage.grid / fit_full_sphere_model measure dx, dy from that same
# middle pixel — TODO confirm.
centers, photons_r = spimage.radialMeanImage(photons, msk=mask,
                                             cx=(photons.shape[1]-1)/2.+dx, cy=(photons.shape[0]-1)/2.+dy,
                                             output_r=True)
photons_qr = spimage.x_to_qx(centers, pixelsize, distance)[:photons.shape[0]//2]
photons_r  = photons_r[:photons.shape[0]//2]

In [None]:
# Fitted sphere model evaluated on the 2D |q| map and along the radial q axis.
s = spimage.sphere_model_convert_diameter_to_size(diameter, wavelength, pixelsize, distance)
A = spimage.sphere_model_convert_intensity_to_scaling(intensity, diameter, wavelength, pixelsize,
                                                      distance, material=material)
fit_1d = spimage.I_sphere_diffraction(A, qr, s)
fit_2d = spimage.I_sphere_diffraction(A, q, s).astype(np.float64)

In [None]:
# Customized colors: a private copy of 'magma' with masked/invalid values in green
# and values below vmin in black.
# matplotlib.cm.get_cmap was removed in Matplotlib 3.9, and mutating a registered
# colormap in place changes it for every later plot, so copy pyplot's instance.
import copy
cmap = copy.copy(plt.get_cmap('magma'))
cmap.set_bad('green')
cmap.set_under('black')

In [None]:
# Side-by-side check of the fit: the left image half shows the 2D model, the right
# half the statically averaged data (test1); the right panel compares the radial
# profiles of the data and the model.
# NOTE(review): the model was fitted to the rotation average (test0), while test1 is shown here.
img = np.copy(test1)
half = test1.shape[1]//2
img[:, :half] = fit_2d[:, :half]
img[img == 0] = 1e-10          # keep exact zeros visible under LogNorm (drawn as "under")
img[~(mask1)] = np.nan         # pixels outside the mask are drawn in the "bad" colour
cmap.set_bad('0.9')
cmap.set_under('white')
fig, axes = plt.subplots(ncols=2, figsize=(12,5), dpi=100)
axes[0].axis('off')
# imshow draws column 0 on the left, so the model (columns < half) is the left side.
axes[0].set_title('left side: fit / right side: data')
axes[0].imshow(img*mask1, norm=colors.LogNorm(vmin=1e-5, vmax=5), cmap=cmap)
axes[1].plot(photons_qr, photons_r, label='data')
axes[1].plot(qr, fit_1d, label='fit')
axes[1].semilogy()
axes[1].set_ylim(1e-3,100)
axes[1].set_title("Mode = %02d" %m)
axes[1].legend(frameon=False)
os.makedirs("../../plots/sucrose", exist_ok=True)  # savefig does not create directories
plt.savefig("../../plots/sucrose/fit_%02d.png" %m, bbox_inches='tight')
plt.show()

In [None]:
# Normalised radial profiles of the rotation-averaged vs statically averaged class.
fig, ax = plt.subplots()
ax.plot(centers0, radial0/radial0.sum(), label='rot. averaged')
ax.plot(centers1, radial1/radial1.sum(), label='stat. averaged')
ax.set_yscale('log')
ax.legend()
plt.show()

In [None]:
# Radial profile of every frame of class m with the nominal (unshifted) geometry.
C,R = radial_averages([0]*4, [0]*4, N=len(sca))

In [None]:
# Model radial profile of each frame: fitted diameter, intensity scaled by the
# frame's EMC scale factor, evaluated on qr[18:].
# The size parameter does not depend on the frame, so it is computed once.
s = spimage.sphere_model_convert_diameter_to_size(diameter, wavelength, pixelsize, distance)
Rfit = []
for j in range(len(sca)):
    A = spimage.sphere_model_convert_intensity_to_scaling(intensity*sca[j], diameter, wavelength, pixelsize,
                                                          distance, material=material)
    Rfit.append(spimage.I_sphere_diffraction(A, qr[18:], s))
# Frames ordered by decreasing scale factor.
Rfit = np.array(Rfit)[np.argsort(sca)[::-1]]

In [None]:
# Data radial profiles (first 159 bins), frames ordered by decreasing scale factor;
# exact zeros are lifted to the colour-scale floor so LogNorm can draw them.
Rdata = R[np.argsort(sca)[::-1], :159]
Rdata[Rdata == 0] = 1e-3
fig, ax = plt.subplots()
im = ax.imshow(Rdata, norm=colors.LogNorm(vmin=1e-3, vmax=2), aspect='auto')
fig.colorbar(im, ax=ax)
plt.show()

In [None]:
# Model radial profiles per frame, log colour scale.
fig, ax = plt.subplots()
im = ax.imshow(Rfit, norm=colors.LogNorm(vmin=1e-3, vmax=2), aspect='auto')
fig.colorbar(im, ax=ax)
plt.show()

In [None]:
# Model radial profiles per frame, linear colour scale clipped to [0, 2].
fig, ax = plt.subplots()
im = ax.imshow(Rfit, vmin=0, vmax=2, aspect='auto')
fig.colorbar(im, ax=ax)
plt.show()

In [None]:
# Spot check: model vs data radial profile of row 400 (in scale order).
# NOTE(review): Rfit starts at qr[18] while Rdata starts at radial bin 0, so the
# x axes are not aligned; also requires more than 400 frames in class m.
fig, ax = plt.subplots()
ax.plot(Rfit[400])
ax.plot(Rdata[400])
ax.set_yscale('log')
ax.set_ylim(1e-3,1000)
plt.show()

In [None]:
# Per-frame intensities (fitted intensity times scale factor, x 1e-6), largest first.
fig, ax = plt.subplots()
ax.plot(np.sort(intensity*sca*1e-6)[::-1])
plt.show()

In [None]:
# Histogram of the per-frame intensities (x 1e-6).
fig, ax = plt.subplots()
ax.hist(intensity*sca*1e-6, bins=50)
plt.show()

In [None]:
# Store the fit results of class m in <path_to_aux>/sucrose/results_<m>.h5.
# NOTE(review): 'rfit' is evaluated on qr[18:] while 'rdata' holds the first 159
# radial bins of the assembled frames, so they do not share one radial axis.
results_dir = os.path.join(path_to_aux, 'sucrose')  # avoids the former "aux//sucrose"
os.makedirs(results_dir, exist_ok=True)             # h5py cannot create directories
with h5py.File(os.path.join(results_dir, 'results_%02d.h5' % m), "w") as f:
    f['diameter'] = diameter
    f['intensity'] = intensity
    f['scale'] = np.sort(sca)[::-1]   # same decreasing order as rdata/rfit rows
    f['photons'] = photons
    f['mask'] = mask
    f['rdata'] = Rdata
    f['rfit'] = Rfit
    f['qr'] = qr[18:]
    f['wavelength'] = wavelength
    f['distance'] = distance
    f['pixelsize'] = pixelsize