# Example 003 - use band replacement to include optical sectioning as part of 2D-SIM reconstruction
  
In real samples, there is often background due to fluorescence from above / below the focal plane. This lowers the modulation-contrast-to-noise ratio (MCNR) and limits the reconstruction quality for 2D-SIM. Here, we introduce a variant on existing methods by Shaw and O'Holleran to achieve optical sectioning during 2D-SIM reconstruction.

We use a z-stack of 2D-SIM acquisitions through a *Tetrahymena* sample, immunofluorescence-labeled for basal bodies, as our dataset. Sample provided by Dr. Nick Galati.

### Import libraries

In [None]:
import os
import pickle
from pathlib import Path
import mcsim.analysis.sim_reconstruction as sim
import localize_psf.fit_psf as psf
import localize_psf.affine as affine
import localize_psf.rois as rois
import napari
import tifffile
import numpy as np
from numpy import fft
from mcsim.analysis import analysis_tools
import matplotlib.pyplot as plt
from matplotlib.colors import PowerNorm

### Load experimental data containing all z stage positions and patterns.

In [None]:
# Dimensions used to interpret the raw stack: number of colors, SIM angles,
# SIM phases, z planes, and the frame size (ny rows by nx columns).
ncolors, nangles, nphases = 1, 3, 3
nz = 35
ny, nx = 2048, 2048

# Read the raw TIFF and unfold it into six explicit axes.
raw_data_path = Path("data", "example_003", "raw_data", "tetrahymena_basal_bodies.tif")
imgs = tifffile.imread(raw_data_path)
imgs = imgs.reshape((ncolors, nangles, nphases, nz, ny, nx))

# Exchange axes 1 and 2 of the reshaped stack.
# NOTE(review): nangles == nphases here, so the shape is unchanged by the swap;
# confirm which of the two axes is angle and which is phase in the file layout.
imgs = imgs.swapaxes(1, 2)

### Define experimental metadata

In [None]:
# Optical and sampling parameters for this acquisition.
# pixel_size and dz are used as napari scale factors with the scale bar unit
# set to 'um' elsewhere in this notebook, so lengths here are in micrometers.
na = 1.3  # numerical aperture used by the OTF model and the reconstruction
pixel_size, dz = 0.065, 0.25  # lateral pixel size and z step
excitation_wavelengths = 0.465
emission_wavelengths = 0.509  # sets the detection cutoff frequency fmax

### View raw DMD-SIM data

In [None]:
# Open a napari viewer on the full 6D raw stack. The first three axes are
# shown with unit scale; the last three (z, y, x) use the physical step sizes.
raw_scale = (1, 1, 1, dz, pixel_size, pixel_size)
viewer = napari.view_image(imgs, name='Raw DMD-SIM', scale=raw_scale)

### Define region of interest containing data

In [None]:
# ROI center, given as [cy, cx].
roi_center = [1230, 996]
# Second argument to get_centered_roi (presumably the ROI size as [ny, nx] -
# TODO confirm against the localize_psf documentation).
roi_size = [731, 631]
roi = rois.get_centered_roi(roi_center, roi_size)

# The reconstruction cell slices rows with roi[0]:roi[1] and columns with
# roi[2]:roi[3], so these differences are the ROI height and width in pixels.
ny_roi = roi[1] - roi[0]
nx_roi = roi[3] - roi[2]

### Load experimental optical transfer function

In [None]:
otf_data_path = Path("data", "example_003", "calibration", "2020_05_19_otf_fit_blue.pkl")

# NOTE: pickle.load executes arbitrary code from the file; only use it on
# calibration files from a trusted source.
with otf_data_path.open('rb') as otf_file:
    otf_data = pickle.load(otf_file)
otf_p = otf_data['fit_params']


def otf_fn(f, fmax):
    """Evaluate the fitted OTF model at frequency ``f``.

    The model is the circular-aperture OTF from ``psf.circ_aperture_otf``
    multiplied by an attenuation term ``1 / (1 + (f / fmax * otf_p[0]) ** 2)``,
    where ``otf_p[0]`` is the first fitted parameter loaded above.

    :param f: frequency (scalar or array) at which to evaluate the OTF
    :param fmax: cutoff frequency used to normalize ``f``
    :return: OTF value(s) at ``f``
    """
    attenuation = 1 / (1 + (f / fmax * otf_p[0]) ** 2)
    aperture_otf = psf.circ_aperture_otf(f, 0, na, 2 * na / fmax)
    return attenuation * aperture_otf

### Load pre-calculated affine transforms for qi2lab DMD-SIM.

In [None]:
# One affine-transform calibration file per color channel.
affine_fnames = [Path("data", "example_003", "calibration", "2021-04-13_12;49;32_affine_xform_blue_z=0.pkl")]

# Collect the 'affine_xform' entry stored in each calibration pickle.
# NOTE: pickle.load executes arbitrary code from the file; only use it on
# calibration files from a trusted source.
affine_xforms = []
for xform_path in affine_fnames:
    with xform_path.open('rb') as xform_file:
        xform_record = pickle.load(xform_file)
    affine_xforms.append(xform_record['affine_xform'])

### Load qi2lab DMD-SIM patterns. Estimate frequency and phase using extracted affine transforms.

In [None]:
# One DMD pattern description file per color channel.
dmd_pattern_data_fpath = [Path("data", "example_003", "calibration", "period=6.0_nangles=3", "wavelength=473nm", "sim_patterns_period=6.01_nangles=3.pkl")]

# Pattern frequencies/phases on the DMD, sized from the acquisition dimensions
# (previously frqs_dmd was hard-coded to (2, 3, 2), inconsistent with phases_dmd).
frqs_dmd = np.zeros((ncolors, nangles, 2))
phases_dmd = np.zeros((ncolors, nangles, nphases))

# OTF sampled on the frequency grid of the ROI. It does not depend on the
# color index, so it is computed once instead of inside the color loop.
fmax = 1 / (0.5 * emission_wavelengths / na)
fx = fft.fftshift(fft.fftfreq(nx_roi, pixel_size))
fy = fft.fftshift(fft.fftfreq(ny_roi, pixel_size))
ff = np.sqrt(fx[None, :] ** 2 + fy[:, None] ** 2)
otf = otf_fn(ff, fmax)
# no signal is passed beyond the cutoff frequency
otf[ff >= fmax] = 0

for kk in range(ncolors):
    ppath = dmd_pattern_data_fpath[kk]
    # Use this color's own affine transform and DMD size. The original code
    # reused whatever values the last iteration of an earlier loop left behind.
    xform = affine_xforms[kk]

    # NOTE: pickle.load executes arbitrary code from the file; only use it on
    # calibration files from a trusted source.
    with open(ppath, 'rb') as f:
        pattern_data = pickle.load(f)

    # DMD intensity frequency and phase (twice electric field frq/phase)
    frqs_dmd[kk] = 2 * pattern_data['frqs']
    phases_dmd[kk] = 2 * pattern_data['phases']
    dmd_nx = pattern_data['nx']
    dmd_ny = pattern_data['ny']

    # Guess frequencies/phases in the ROI by mapping the DMD pattern through
    # the affine transform. These arrays are rebuilt for every color, so after
    # the loop they hold the estimates for the last color only.
    frqs_guess = np.zeros((nangles, 2))
    phases_guess = np.zeros((nangles, nphases))
    for ii in range(nangles):
        for jj in range(nphases):
            # the frequency pair is rewritten for every phase of the same
            # angle; only the phase estimate differs between iterations
            frqs_guess[ii, 0], frqs_guess[ii, 1], phases_guess[ii, jj] = \
                affine.xform_sinusoid_params_roi(frqs_dmd[kk, ii, 0], frqs_dmd[kk, ii, 1],
                                                 phases_dmd[kk, ii, jj], [dmd_ny, dmd_nx], roi, xform)

    # Rescale the frequency guesses by the pixel size to get 1/um
    # (assumes the helper returns frequencies per camera pixel - TODO confirm).
    frqs_guess = frqs_guess / pixel_size

### Perform SIM reconstruction

In [None]:
%matplotlib inline
# Directory where reconstruction output is written. This is the directory the
# original cell actually passed as save_dir; the old, unused save_path pointed
# at "example_004" by mistake.
save_path = Path("data", "example_003", "reconstruction")

# Output stacks are allocated on the first reconstructed plane, because the
# result shapes are only known after a reconstruction. The original allocated
# them only when z_idx == 0, which never happens for range(14, 15), so the
# assignments below raised NameError.
wf_images = None
SR_images = None

for ch_idx in range(ncolors):
    for z_idx in range(14, 15):
        # Reconstruct one z plane of one channel, cropped to the ROI.
        # NOTE(review): frqs_guess / phases_guess hold a single channel's
        # estimates; confirm before running with ncolors > 1.
        imgset = sim.SimImageSet({"pixel_size": pixel_size, "na": na, "wavelength": emission_wavelengths},
                                 imgs[ch_idx, :, :, z_idx, roi[0]:roi[1], roi[2]:roi[3]],
                                 frq_estimation_mode="band-correlation",
                                 frq_guess=frqs_guess,
                                 phases_guess=phases_guess,
                                 phase_estimation_mode="wicker-iterative",
                                 combine_bands_mode="fairSIM",
                                 fmax_exclude_band0=0.5,
                                 normalize_histograms=True,
                                 otf=otf,
                                 wiener_parameter=0.4,
                                 background=100,
                                 gain=2,
                                 min_p2nr=0.25,
                                 max_phase_err=15,
                                 save_dir=save_path,
                                 save_suffix="_z"+str(z_idx).zfill(3)+"_ch"+str(ch_idx).zfill(3),
                                 interactive_plotting=False,
                                 figsize=(20, 13))

        imgset.reconstruct()
        imgset.plot_figs()
        imgset.save_imgs()
        imgset.log_file.close()

        if wf_images is None:
            # Planes that are not reconstructed stay zero in these stacks.
            wf_images = np.zeros((ncolors, nz, imgset.widefield.shape[0], imgset.widefield.shape[1]), dtype=np.float32)
            SR_images = np.zeros((ncolors, nz, imgset.sim_sr.shape[0], imgset.sim_sr.shape[1]), dtype=np.float32)
        wf_images[ch_idx, z_idx, :] = imgset.widefield
        SR_images[ch_idx, z_idx, :] = imgset.sim_sr

        # release the image set before processing the next plane
        del imgset

### Display results

In [None]:
# Colormap names for the widefield layer and the SIM layer, in that order.
colormaps = ['bop blue', 'bop orange']

# NOTE: ch_idx is the value left over from the reconstruction loop above;
# that cell must have been run before this one.
channel_tag = str(ch_idx).zfill(2)
max_count = 2**16 - 1  # upper contrast limit shared by both layers

# Widefield stack on the raw pixel grid.
viewer = napari.view_image(
    wf_images[ch_idx, :],
    name='Widefield CH ' + channel_tag,
    scale=(dz, pixel_size, pixel_size),
    colormap=colormaps[0],
    blending='additive',
    contrast_limits=[0, max_count],
)

# SIM stack, added with half the lateral pixel size of the widefield layer.
viewer.add_image(
    SR_images[ch_idx, :],
    name='SR CH ' + channel_tag,
    scale=(dz, pixel_size / 2, pixel_size / 2),
    colormap=colormaps[1],
    blending='additive',
    contrast_limits=[0, max_count],
)

viewer.scale_bar.unit = 'um'
viewer.scale_bar.visible = True

### Calculate FFT of raw data to find frequency guesses

In [None]:
%matplotlib widget
# Show the magnitude of the 2D FFT of the raw (uncropped) images so that the
# SIM frequency peaks can be located by eye.
for z_idx in range(14, 15):

    # all images of channel 0 at this z plane, full field of view
    image_set = imgs[0, :, :, z_idx, :, :]

    # frequency axes for the full frame
    dx = pixel_size
    fxs = analysis_tools.get_fft_frqs(nx, dx)
    df = fxs[1] - fxs[0]
    fys = analysis_tools.get_fft_frqs(ny, dx)
    # radial frequency grid (not used by the plot below, kept as in the original)
    ff = np.sqrt(np.expand_dims(fxs, axis=0)**2 + np.expand_dims(fys, axis=1)**2)
    for ii in range(image_set.shape[0]):
        # FFT of the first image along the second axis for this index
        ft = fft.fftshift(fft.fft2(fft.ifftshift(np.squeeze(image_set[ii, 0, :]))))

        figh = plt.figure()
        # Label with the z index; the original printed the channel index here.
        plt.title('Z='+str(z_idx)+", Angle="+str(ii))
        # Extent places pixel centers on the frequency samples. The same
        # spacing df is used for both axes, which is valid because nx == ny.
        plt.imshow(np.abs(ft), norm=PowerNorm(gamma=0.1, vmin=65, vmax=1.5e7),
                   extent=[fxs[0] - 0.5 * df, fxs[-1] + 0.5 * df, fys[-1] + 0.5 * df, fys[0] - 0.5 * df])
        plt.show()