# Example 003 - use band replacement to include optical sectioning as part of 2D-SIM reconstruction
  
In real samples, there is often background due to fluorescence from above / below the focal plane. This lowers the modulation contrast-to-noise ratio (MCNR) and limits the reconstruction quality for 2D-SIM. Here, we introduce a variant of existing methods by Shaw and O'Halleran to achieve optical sectioning during 2D-SIM reconstruction.

We use a z-stack of 2D-SIM acquisitions through a *Tetrahymena* sample, immunofluorescence-labeled for basal bodies, as our dataset. Sample provided by Dr. Nick Galati.

### Import libraries

In [None]:
from pathlib import Path
import mcsim.analysis.sim_reconstruction as sim
import localize_psf.fit_psf as psf
import localize_psf.affine as affine
import localize_psf.rois as rois
import napari
import tifffile
import json
import numpy as np
from numpy import fft
from mcsim.analysis import analysis_tools
import matplotlib.pyplot as plt
from matplotlib.colors import PowerNorm

### Load experimental data containing all z stage positions and patterns.

In [None]:
# acquisition dimensions
ncolors = 1  # number of channels
nangles = 3  # number of SIM pattern angles
nphases = 3  # number of SIM pattern phases per angle
nz = 35  # number of z planes
ny = 2048  # number of y pixels on camera
nx = 2048  # number of x pixels on camera

# read the raw stack from disk and view it as a 6D array indexed (color, angle, phase, z, y, x).
# NOTE(review): assumes the TIFF frames are stored in exactly this nesting order -- confirm.
raw_data_path = Path("data", "example_003", "raw_data", "tetrahymena_basal_bodies.tif")
raw_shape = (ncolors, nangles, nphases, nz, ny, nx)
imgs = tifffile.imread(raw_data_path).reshape(raw_shape)

### Define experimental metadata

In [None]:
# optical and sampling parameters of the acquisition (lengths in um)
na = 1.3  # numerical aperture
emission_wavelengths = 0.509  # um
excitation_wavelengths = 0.465  # um
pixel_size = 0.065  # um, lateral image pixel size
dz = 0.25  # um, spacing between z planes

### View raw DMD-SIM data

In [None]:
# show the raw stack in napari. With a single channel (ncolors = 1), np.squeeze drops the
# color axis, leaving (angle, phase, z, y, x); z and y/x are scaled to physical units (um).
raw_layer_scale = (1, 1, dz, pixel_size, pixel_size)
viewer = napari.view_image(np.squeeze(imgs), name='Raw DMD-SIM', scale=raw_layer_scale)

# name the viewer's slider axes
viewer.dims.axis_labels = ['a', 'p', 'z', 'y', 'x']

### Define region of interest containing data

In [None]:
# crop window for reconstruction; roi is used as [ystart, yend, xstart, xend]
roi_center = [1230, 996]  # [cy, cx]
roi_size = [731, 631]  # NOTE(review): presumably [ny, nx] -- confirm against rois.get_centered_roi
roi = rois.get_centered_roi(roi_center, roi_size)

# ROI extent in pixels along y and x
ny_roi = roi[1] - roi[0]
nx_roi = roi[3] - roi[2]

### Load experimental optical transfer function

In [None]:
# fitted OTF parameters from a previous calibration.
# NOTE(review): this file is read from the example_002 folder while the other inputs come from
# example_003 -- confirm the calibration is meant to be shared between examples.
otf_data_path = Path("data", "example_002", "calibration", "2020_05_19_otf_fit_blue.json")
otf_data = json.loads(otf_data_path.read_bytes())
otf_p = np.asarray(otf_data['fit_params'])


def otf_fn(f, fmax):
    """Return the model optical transfer function at spatial frequency ``f``.

    The circular-aperture OTF from ``psf.circ_aperture_otf`` is multiplied by the
    attenuation factor ``1 / (1 + (f / fmax * otf_p[0]) ** 2)``, where ``otf_p[0]`` is
    the first fitted parameter loaded above. ``f`` and ``fmax`` share the same units;
    ``2 * na / fmax`` is the wavelength implied by ``fmax = 2 * na / wavelength``.
    """
    attenuation = 1 / (1 + (f / fmax * otf_p[0]) ** 2)
    return attenuation * psf.circ_aperture_otf(f, 0, na, 2 * na / fmax)

### Load pre-calculated affine transforms for qi2lab DMD-SIM.

In [None]:
# pre-calculated affine transform files, one per channel
# (presumably mapping DMD coordinates to camera coordinates -- confirm)
affine_fnames = [Path("data", "example_003", "calibration", "2021-04-13_12;49;32_affine_xform_blue_z=0.json")]

# parse each file and keep its 'affine_xform' entry as an array, in channel order
affine_xforms = [np.asarray(json.loads(fname.read_bytes())['affine_xform'])
                 for fname in affine_fnames]

### Load qi2lab DMD-SIM patterns. Estimate frequency and phase using extracted affine transforms.

In [None]:
dmd_pattern_data_fpath = [Path("data", "example_003", "calibration", "period=6.0_nangles=3", "wavelength=473nm", "sim_patterns_period=6.01_nangles=3.json")]

# DMD intensity-pattern frequencies (1/mirrors, [fx, fy] per angle) and phases for every channel
frqs_dmd = np.zeros((ncolors, nangles, 2))
phases_dmd = np.zeros((ncolors, nangles, nphases))

# optical transfer function sampled on the ROI's (fftshifted) Fourier grid. It depends only on
# the emission wavelength, NA and ROI size -- not on the channel -- so it is computed once.
fmax = 1 / (0.5 * emission_wavelengths / na)
fx = fft.fftshift(fft.fftfreq(nx_roi, pixel_size))
fy = fft.fftshift(fft.fftfreq(ny_roi, pixel_size))
ff = np.sqrt(fx[None, :] ** 2 + fy[:, None] ** 2)
otf = otf_fn(ff, fmax)
otf[ff >= fmax] = 0  # no transfer beyond the cutoff frequency

# loop over all channels. Pattern data, affine transform and DMD size are read and used within
# the SAME iteration, so every channel is mapped with its own transform.
# NOTE(review): frqs_guess/phases_guess are overwritten per channel, so after the loop they hold
# the last channel's guesses (sufficient for ncolors = 1).
for kk in range(ncolors):

    # load patterns displayed on DMD
    ppath = dmd_pattern_data_fpath[kk]
    with open(ppath, 'rb') as f:
        pattern_data = json.load(f)

    # affine transform for this channel
    xform = affine_xforms[kk]

    # DMD intensity frequency and phase (twice electric field frq/phase)
    frqs_dmd[kk] = 2 * np.asarray(pattern_data['frqs'])
    phases_dmd[kk] = 2 * np.asarray(pattern_data['phases'])
    dmd_nx = int(pattern_data['nx'])
    dmd_ny = int(pattern_data['ny'])

    # guess frequencies/phases in ROI coordinates from the DMD patterns and the affine transform
    frqs_guess = np.zeros((nangles, 2))
    phases_guess = np.zeros((nangles, nphases))
    for ii in range(nangles):
        for jj in range(nphases):
            frqs_guess[ii, 0], frqs_guess[ii, 1], phases_guess[ii, jj] = \
                affine.xform_sinusoid_params_roi(frqs_dmd[kk, ii, 0], frqs_dmd[kk, ii, 1],
                                                 phases_dmd[kk, ii, jj], [dmd_ny, dmd_nx], roi, xform)

    # convert from 1/pixels to 1/um
    frqs_guess = frqs_guess / pixel_size

### Perform SIM reconstruction. Loop over several optical-sectioning (OS) strengths (`fmax_exclude_band0` values) to see their effect on the reconstruction.

In [None]:
%matplotlib inline
# path to save data
save_path = Path("data","example_004","reconstruction")

# range of z planes to reconstruct
min_z=9
max_z=12

# loop over all channels and z planes.
for ch_idx in range(ncolors):
    for z_idx in range(min_z,max_z):
        # define list of fmax_exclude_band0 parameters to test for optical sectionining
        b_idx = 0
        band_exclusion=[0.1,0.2,0.3,0.4,0.5]

        # loop over all fmax_exclude_band0 parameters
        for fmax_exclude_band0 in band_exclusion:

            # create mcSIM reconstruction object. See docstring for details on each parameter.
            imgset = sim.SimImageSet({"pixel_size": pixel_size, "na": na, "wavelength": emission_wavelengths},
                                    imgs[0, :, :, z_idx, roi[0]:roi[1], roi[2]:roi[3]],
                                    frq_estimation_mode="band-correlation",
                                    frq_guess=frqs_guess,
                                    phases_guess=phases_guess,
                                    phase_estimation_mode="wicker-iterative",
                                    combine_bands_mode="fairSIM",
                                    fmax_exclude_band0=fmax_exclude_band0,
                                    normalize_histograms=True,
                                    otf=otf,
                                    wiener_parameter=0.2,
                                    background=100, 
                                    gain=1, 
                                    min_p2nr=0.25, 
                                    max_phase_err=15*np.pi/180,
                                    save_dir=Path("data","example_003","reconstruction"),
                                    save_suffix="_b"+str(b_idx).zfill(3)+"_z"+str(z_idx).zfill(3)+"_ch"+str(ch_idx).zfill(3),
                                    interactive_plotting=False, 
                                    figsize=(20, 13))

            # perform reconstruction, plot figures, save, and clean up log file
            imgset.reconstruct()
            imgset.plot_figs()
            imgset.save_imgs()
            imgset.save_result()

            # create variables to hold widefield and SIM SR images
            if ch_idx == 0 and z_idx == min_z and b_idx==0:
                wf_images = np.zeros((len(band_exclusion),ncolors,nz,imgset.widefield.shape[0],imgset.widefield.shape[1]),dtype=np.float32)
                SR_images = np.zeros((len(band_exclusion),ncolors,nz,imgset.sim_sr.shape[0],imgset.sim_sr.shape[1]),dtype=np.float32)
            
            # store widefield and SIM SR images for display
            wf_images[b_idx,ch_idx,z_idx,:]=imgset.widefield
            SR_images[b_idx,ch_idx,z_idx,:]=imgset.sim_sr
            b_idx=b_idx+1

            # clean up mcSIM reconstruction object
            del imgset

### Display results

In [None]:
# colormaps: widefield layer first, SR layer second
colormaps = ['bop blue', 'bop orange']
wf_colormap, sr_colormap = colormaps

# ch_idx is the value left over from the reconstruction loop (the last channel processed)
ch_label = str(ch_idx).zfill(2)

# With a single channel, np.squeeze drops the channel axis, leaving (b_idx, z, y, x) stacks.
# The SR layer is drawn at half the widefield lateral pixel size.
viewer = napari.view_image(np.squeeze(wf_images), name='Widefield CH ' + ch_label,
                           scale=(1, dz, pixel_size, pixel_size), colormap=wf_colormap,
                           blending='additive', contrast_limits=[0, 2**16 - 1])
viewer.add_image(np.squeeze(SR_images), name='SR CH ' + ch_label,
                 scale=(1, dz, pixel_size / 2, pixel_size / 2), colormap=sr_colormap,
                 blending='additive', contrast_limits=[0, 2**16 - 1])

# show a scale bar in um
viewer.scale_bar.unit = 'um'
viewer.scale_bar.visible = True

# name the viewer's slider axes
viewer.dims.axis_labels = ['b_idx', 'z', 'y', 'x']

## Importance of SLM calibration

### Calculate FFT of raw data to find frequency guesses

In [None]:
%matplotlib widget

# range of z planes to view FFTs
min_z = 9
max_z = 10

for z_idx in range(min_z,max_z):

    # extract angle/phase data for current z plane
    image_set = imgs[0, :, :, z_idx, :, :]

    # create grid containing correct spatial frequencies given image sampling
    dx = pixel_size
    fxs = analysis_tools.get_fft_frqs(nx, dx)
    df = fxs[1] - fxs[0]
    fys = analysis_tools.get_fft_frqs(ny, dx)
    ff = np.sqrt(np.expand_dims(fxs, axis=0)**2 + np.expand_dims(fys, axis=1)**2)

    # loop over all angle/phases and plot absolute value of 2D fourier transform
    for ii in range(image_set.shape[0]):
        ft = fft.fftshift(fft.fft2(fft.ifftshift(np.squeeze(image_set[ii, 0, :]))))

        figh = plt.figure()
        plt.title('Z='+str(ch_idx)+", Angle="+str(ii))
        plt.imshow(np.abs(ft), norm=PowerNorm(gamma=0.1,vmin=65,vmax=1.5e7),
                    extent=[fxs[0] - 0.5 * df, fxs[-1] + 0.5 * df, fys[-1] + 0.5 * df, fys[0] - 0.5 * df])
        plt.show()