<font size = "5"> **[Correcting Image Distortions](0_Correcting_Image_Distortions.ipynb)** </font>

<hr style="height:1px;border-top:4px solid #FF8200" />

by 

Gerd Duscher

Materials Science & Engineering<br>
Joint Institute of Advanced Materials<br>
The University of Tennessee, Knoxville

and 

Matthew F. Chisholm

Center of Nanophase Materials<br>
Oak Ridge National Laboratory

# Undistortion and Registration of Image-Stacks

Use this notebook **only** for image stacks for which a distortion matrix has already been determined from a different stack.


## First we import the usual libraries
Please visit the [Introductory notebook](0_Correcting_Image_Distortions.ipynb) to install ``pyTEMlib`` and for more information on the packages used.

You'll need at least pyTEMlib version 0.05.2020.0.

Run the code cell below to load all the python packages that we use here.

In [None]:
# import matplotlib and numpy
#                       use "inline" instead of "notebook" for non-interactive plots
%pylab --no-import-all notebook


# Import libraries from the book

import pyTEMlib
import pyTEMlib.file_tools as ft     # file input/output library
import pyTEMlib.image_tools as it    # image tools (undistortion, registration)

# For archiving reasons it is a good idea to print the version numbers out at this point
print('pyTEMlib version: ', pyTEMlib.__version__)

__notebook__ = 'Undistorted_Registration'
__notebook_version__ = '2020_06_09'


## Load an image stack :

Please load an image stack. <br>

A stack of images is used to reduce noise; before the images can be summed, they have to be aligned to compensate for drift and other microscope instabilities.

Note that the **open file dialog** might not appear in the foreground!
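Why the frames must be aligned before they are summed can be illustrated with synthetic data. The sketch below is not part of pyTEMlib: it builds a hypothetical noise-free lattice image, adds independent Gaussian noise to 16 copies, and compares the noise of one frame with the noise of their average. For uncorrelated noise the standard deviation should drop by roughly $\sqrt{N}$; if every frame is moved by a simulated drift before averaging, the lattice fringes wash out instead.

```python
import numpy as np

rng = np.random.default_rng(seed=0)

# Hypothetical noise-free image: a square lattice with a period of 8 pixels
yy, xx = np.mgrid[0:64, 0:64]
clean = 1.0 + 0.5 * np.cos(2 * np.pi * xx / 8) * np.cos(2 * np.pi * yy / 8)

n_frames = 16
noise_sigma = 0.5
# Every frame shows the same image with independent Gaussian noise
stack = clean + rng.normal(0.0, noise_sigma, size=(n_frames, 64, 64))

single_noise = np.std(stack[0] - clean)
mean_noise = np.std(stack.mean(axis=0) - clean)
print(f'noise of one frame: {single_noise:.3f}')
print(f'noise of the average of {n_frames} frames: {mean_noise:.3f}')
print(f'improvement: {single_noise / mean_noise:.2f} (sqrt(N) = {np.sqrt(n_frames):.0f})')

# Simulated drift: frame i is shifted by i pixels along x before averaging
drifted = np.array([np.roll(frame, shift=i, axis=1) for i, frame in enumerate(stack)])
drift_error = np.std(drifted.mean(axis=0) - clean)
print(f'error of the average without alignment: {drift_error:.3f}')
```

Here the unaligned average is worse than the aligned one even though it contains the same amount of signal, because the 16 one-pixel shifts span two full lattice periods and cancel the fringes.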

In [None]:
Open_Nion_Directory = True

try:  ### Close any h5_file that may be left open
    h5_file.close()
except Exception:  # no file open yet
    pass

### Open file Dialog
import importlib
importlib.reload(ft)

if Open_Nion_Directory:
    nion_selection = ft.nion_directory()
else:
    %gui qt5


### Plot Image Stack

In [None]:
### Load file
try:
    h5_file.close()
except Exception:  # no file open yet
    pass
if Open_Nion_Directory:
    try:
        h5_file = nion_selection.h5_file
    except Exception:
        pass
else:
    h5_file = ft.h5_open_file()

current_channel = h5_file['Measurement_000/Channel_000']
current_dataset = current_channel['nDim_Data']

if current_channel['data_type'][()] != 'image_stack':
    print(f"Please load an image stack for this notebook, this is an {current_channel['data_type'][()]}")
    

print('Previous analysis of ',current_dataset.attrs['title'])        
for key in current_channel:
    if 'Log' in key:
        if 'analysis' in current_channel[key].keys():
            print(f"{key} includes analysis: {current_channel[key]['analysis'][()]}")
            
view = ft.h5_plot(current_dataset)  # note this needs a view reference for interaction


## Load an image with Distortion Matrix

Here we need a file for which a distortion matrix has already been determined.
We use that distortion matrix to remove the distortion from the current file.

Please note that the scale will change to the **absolute scale** stored with the distortion matrix.


In [None]:
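import os

# Read the start directory stored in path.txt in the pyTEMlib configuration folder
with open(os.path.join(ft.config_path, 'path.txt'), 'r') as fp:
    path = fp.read()
distortion_selection = ft.nion_directory(path, extension=['hf5'])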


In [None]:
### Load file

distort_file  = distortion_selection.h5_file
current_distort_channel = distort_file['Measurement_000/Channel_000']

found_distortion_matrix = False  

for key in current_distort_channel:
    if 'Log' in key:
        if 'analysis' in current_distort_channel[key]:
            if  'Distortion' in current_distort_channel[key]['analysis'][()]:
                distortion_tags  = current_distort_channel[key]
                found_distortion_matrix = True
                d_scaleX = current_distort_channel[key]['scale_x'][()]
                d_scaleY = current_distort_channel[key]['scale_y'][()]
                distortion_matrix = current_distort_channel[key]['distortion_matrix'][()]
                if 'distortion_crop' in current_distort_channel[key]: 
                    distortion_crop = current_distort_channel[key]['distortion_crop'][()]
            if 'Rigid Registration' == current_distort_channel[key]['analysis'][()]:
                distortion_crop = current_distort_channel[key]['Rigid_registration_crop'][()]
if found_distortion_matrix:
    print('found distortion matrix')
    name = 'Distortion_Matrix'
    distortion_matrix_original = distortion_matrix.copy()
else:    
    print('No distortion matrix found!! We need a file with a distortion matrix to proceed!!!')
distortion_filename =  distort_file.filename   
distort_file.close()
print(distortion_filename)

In [None]:
#print(current_distort_channel[key]['analysis'][()])
print(distortion_filename)

### Visualize Distortion Matrix
This step is not necessary, but it allows a quick check of the data set.

In [None]:

## Visualize distortion matrix (check whether all pixels are included)
difference = distortion_matrix[:,:2]-distortion_matrix[:,2:]
distance_image = np.zeros(current_dataset.shape[1:])
angle_image = np.zeros(current_dataset.shape[1:])

distance_image[(distortion_matrix[:,0].astype(int),distortion_matrix[:,1].astype(int))] = np.linalg.norm(difference, axis =1)
angle_image[(distortion_matrix[:,0].astype(int),distortion_matrix[:,1].astype(int))] = np.degrees(np.arctan2(difference[:,1],difference[:,0]))
             
fig, ax = plt.subplots(1, 2, figsize=(8, 4), sharex=True, sharey=True)
fig.suptitle('Distortion Matrix in Polar Coordinates')

ax[0].set_title('norm')
norm_fig = ax[0].imshow(distance_image*d_scaleX*1000)
fig.colorbar(norm_fig, ax=ax[0] , label = 'distance [pm]')

ax[1].set_title('angle')
angle_fig = ax[1].imshow(angle_image,cmap = 'twilight')
fig.colorbar(angle_fig, ax=ax[1], label = 'angle [$^o$]');
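The polar plot above boils down to two NumPy operations on the displacement vectors. The following sketch repeats them on a hypothetical three-row matrix laid out like the loaded one (displacement = columns 0-1 minus columns 2-3), with an assumed pixel size of 0.02 nm, so the numbers can be checked by hand: the first row has a displacement of (-3, -4) pixels, i.e. a length of 5 pixels.

```python
import numpy as np

# Hypothetical distortion matrix with the same column layout as above:
# displacement = columns 0-1 minus columns 2-3
distortion_matrix_example = np.array([
    [0.0, 0.0, 3.0, 4.0],
    [1.0, 2.0, 1.0, 1.0],
    [2.0, 1.0, 2.0, 1.5],
])
scale_nm_per_pixel = 0.02    # hypothetical pixel size in nm

difference = distortion_matrix_example[:, :2] - distortion_matrix_example[:, 2:]
norm = np.linalg.norm(difference, axis=1)                           # length in pixels
angle = np.degrees(np.arctan2(difference[:, 1], difference[:, 0]))  # direction in degrees
norm_pm = norm * scale_nm_per_pixel * 1000                          # length in pm

print(norm, norm_pm, angle)
```

`np.arctan2` returns angles in the range -180° to 180°, which is why a cyclic colormap such as `twilight` suits the angle image.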

### Input stack data
We put all the input data into a dictionary named `tags`.

In [None]:
## spatial data
d_crop = distortion_crop
tags = {}
tags['sizeX'] = current_dataset.shape[1]
tags['sizeY'] = current_dataset.shape[2]
tags['scaleX'] = d_scaleX
tags['scaleY'] = d_scaleY
tags['crop'] = d_crop
tags['extent'] = it.make_extent(current_dataset.shape[1:], tags['scaleX'], tags['scaleY'])

data_cube = np.array(current_dataset[:, d_crop[0]:d_crop[1], d_crop[2]:d_crop[3]])
print(data_cube.shape, current_dataset.shape)

## reduce distortion matrix to the same cropped image
distortion_matrix = distortion_matrix[np.where(distortion_matrix[:,0]<data_cube.shape[1])]
distortion_matrix = distortion_matrix[np.where(distortion_matrix[:,1]<data_cube.shape[2])]
distortion_matrix = distortion_matrix[np.where(distortion_matrix[:,0]>0)]
distortion_matrix = distortion_matrix[np.where(distortion_matrix[:,1]>0)]
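The four successive `np.where` filters above can also be written as one boolean mask. This sketch uses a hypothetical helper `filter_to_crop` and a made-up matrix; note that it keeps pixel index 0 (`>= 0`), whereas the cell above uses `> 0` and therefore also drops the first row and column.

```python
import numpy as np

def filter_to_crop(distortion_matrix, crop_shape):
    """Hypothetical helper: keep rows whose pixel (x, y) lies inside an image of crop_shape."""
    x = distortion_matrix[:, 0]
    y = distortion_matrix[:, 1]
    inside = (x >= 0) & (x < crop_shape[0]) & (y >= 0) & (y < crop_shape[1])
    return distortion_matrix[inside]

example_matrix = np.array([
    [0, 0, 0.1, 0.2],
    [3, 1, 3.2, 1.1],
    [5, 2, 5.0, 2.0],   # x = 5 is outside a crop of size 5
    [2, 4, 2.0, 4.3],   # y = 4 is outside a crop of size 4
])
kept = filter_to_crop(example_matrix, (5, 4))
print(kept)
```

A single mask evaluates all bounds in one pass and makes the inclusive or exclusive choice for each edge explicit.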


### Undistort Stack

In [None]:
import importlib
importlib.reload(it)

interpolated = it.undistort_stack(distortion_matrix, data_cube)
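`it.undistort_stack` is a pyTEMlib function whose internals are not shown here. As a rough illustration only, the sketch below resamples each frame with `scipy.ndimage.map_coordinates`, *assuming* that each row of the distortion matrix holds (x, y, x_src, y_src): an integer pixel of the corrected image followed by the sub-pixel position in the distorted frame whose intensity belongs there. The function names are hypothetical. Pixels not covered by the matrix stay `NaN`, which is one way an empty rim like the one cropped in the next step can arise.

```python
import numpy as np
from scipy import ndimage

def undistort_frame(frame, distortion_matrix, order=1):
    """Hypothetical helper: resample one 2D frame with a per-pixel coordinate map.

    ASSUMPTION (not confirmed by pyTEMlib): each row of distortion_matrix is
    (x, y, x_src, y_src), where (x, y) is an integer pixel of the corrected
    image and (x_src, y_src) the sub-pixel position in the distorted frame.
    """
    corrected = np.full(frame.shape, np.nan)          # uncovered pixels stay NaN
    x = distortion_matrix[:, 0].astype(int)
    y = distortion_matrix[:, 1].astype(int)
    source = np.vstack([distortion_matrix[:, 2], distortion_matrix[:, 3]])
    # linear interpolation at the source positions; positions outside the frame give NaN
    corrected[x, y] = ndimage.map_coordinates(frame, source, order=order,
                                              mode='constant', cval=np.nan)
    return corrected

def undistort_stack_sketch(distortion_matrix, stack):
    """Apply the same coordinate map to every frame of a (frames, x, y) stack."""
    return np.array([undistort_frame(frame, distortion_matrix) for frame in stack])

# Usage: a ramp image frame[x, y] = x + 10*y and a map that samples half a pixel further in x
frame = np.add.outer(np.arange(6.0), 10 * np.arange(5.0))   # shape (6, 5)
xs, ys = np.meshgrid(np.arange(5), np.arange(5), indexing='ij')
example_matrix = np.column_stack([xs.ravel(), ys.ravel(), xs.ravel() + 0.5, ys.ravel()])
result = undistort_stack_sketch(example_matrix, frame[np.newaxis])
print(result[0])
```

Because the ramp is linear, linear interpolation reproduces it exactly, so each corrected pixel equals x + 0.5 + 10*y, and the last row along x (not listed in the example matrix) stays `NaN`.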

### Crop interpolated stack

#### Choose Area 
Select the area that contains the atoms, leaving out the empty rim created by the undistortion.
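Instead of dragging the rectangle by hand, a starting box can be estimated from the data. The hypothetical helper below (not part of pyTEMlib) assumes that the rim is marked by `NaN` values, as handled in the crop cell, and greedily removes the outer row or column with the most invalid pixels until the remaining box is valid in every frame. It is a heuristic and does not guarantee the largest possible rectangle.

```python
import numpy as np

def valid_extent(stack):
    """Hypothetical helper: greedy (xmin, xmax, ymin, ymax) with finite data in every frame."""
    finite = np.all(np.isfinite(stack), axis=0)   # (x, y) map: True where all frames have data
    xmin, xmax, ymin, ymax = 0, finite.shape[0], 0, finite.shape[1]
    while xmin < xmax and ymin < ymax:
        box = finite[xmin:xmax, ymin:ymax]
        if box.all():
            break
        # number of invalid pixels on the four edges of the current box
        bad = [np.count_nonzero(~box[0, :]), np.count_nonzero(~box[-1, :]),
               np.count_nonzero(~box[:, 0]), np.count_nonzero(~box[:, -1])]
        edge = int(np.argmax(bad))   # drop the worst edge
        if edge == 0:
            xmin += 1
        elif edge == 1:
            xmax -= 1
        elif edge == 2:
            ymin += 1
        else:
            ymax -= 1
    return xmin, xmax, ymin, ymax

# Usage: a (frames, x, y) stack with a NaN rim of different width on each side
example = np.ones((2, 10, 12))
example[:, :2, :] = np.nan
example[:, -1:, :] = np.nan
example[:, :, :3] = np.nan
example[:, :, -2:] = np.nan
print(valid_extent(example))
```

The result is in the same (xmin, xmax, ymin, ymax) order as `selector.extents`, and slicing `stack[:, xmin:xmax, ymin:ymax]` then drops the rim.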

In [None]:
from matplotlib.widgets import RectangleSelector
plt.figure(figsize=(8, 6))
plt.imshow(interpolated[0].T, origin='upper')
# gca() gets the current axis (plot); the drawtype argument is deprecated in newer matplotlib versions
selector = RectangleSelector(plt.gca(), None, interactive=True, drawtype='box')

selector.to_draw.set_visible(True)
radius = interpolated.shape[1]/2.1
center = np.array(interpolated[0].shape)/2

selector.extents = (center[0]-radius,center[0]+radius,center[1]-radius,center[1]+radius)

#### Crop

In [None]:
xmin, xmax, ymin, ymax = selector.extents
print(selector.extents)
reduced_stack= interpolated[:,int(xmin): int(xmax), int(ymin): int(ymax)]
print(interpolated.min())
reduced_stack[np.isnan(reduced_stack)]=0.
print(reduced_stack.shape)
extent = (xmin, xmax, ymax, ymin )
plt.figure()
plt.imshow(reduced_stack[0].T, interpolation='nearest',cmap='gray', extent = extent, origin = 'upper');


## Do Complete Registration

In [None]:

## Log Undistorted Stack
out_tags = {}
out_tags['notebook']= __notebook__ 
out_tags['notebook_version']= __notebook_version__
out_tags['interpolation_crop'] = [xmin, xmax, ymin, ymax]
out_tags['data'] = reduced_stack
out_tags['data_type'] = 'image_stack'
out_tags['name'] = 'Remove Distortion'
out_tags['distortion_filename'] = distortion_filename
out_tags['scale_x'] = d_scaleX  ## needs to be specified because it differs from the loaded stack
out_tags['scale_y'] = d_scaleY  ## is now the one from the distortion matrix

### scale NOT RIGHT
stack_group = ft.add_registration(current_channel, out_tags)

## Do all of registration
notebook_tags ={}
notebook_tags['notebook']= __notebook__ 
notebook_tags['notebook_version']= __notebook_version__
notebook_tags['scale_x'] = d_scaleX
notebook_tags['scale_y'] = d_scaleY
current_dataset = stack_group['nDim_Data']
stack_group = it.complete_registration(current_dataset, current_channel, notebook_tags)

h5_file.flush()

for key in current_channel:
    if 'Log' in key:
        if 'analysis' in current_channel[key]:
            print(f"{key} includes analysis: {current_channel[key]['analysis'][()]}")
plot = ft.h5_plot(stack_group['nDim_Data'])
print(h5_file.filename)

## Check Drift

In [None]:
for key in current_channel:
    if 'Log' in key:
        if 'analysis' in current_channel[key]:
            if 'Rigid Registration' == current_channel[key]['analysis'][()]:
                drift_channel = current_channel[key]
                
drift = drift_channel['Rigid_registration_drift']
polynom_degree = 2 # 1 is linear fit, 2 is parabolic fit, ...

x = np.linspace(0,drift.shape[0]-1,drift.shape[0])

line_fit_x = np.polyfit(x, drift[:,0], polynom_degree)
poly_x = np.poly1d(line_fit_x)
line_fit_y = np.polyfit(x, drift[:,1], polynom_degree)
poly_y = np.poly1d(line_fit_y)

plt.figure()
plt.axhline(color = 'gray')
plt.plot(x, drift[:,0], label = 'drift x')
plt.plot(x, drift[:,1], label = 'drift y')
plt.plot(x, poly_x(x),  label = 'fit_drift_x')
plt.plot(x, poly_y(x),  label = 'fit_drift_y')

plt.legend();
ax_pixels = plt.gca()
ax_pixels.step(1, 1)
for dim in current_dataset.dims:
    if dim.label == 'x': scaleX = (dim[0][1]-dim[0][0])*1000.  #in pm

ax_pm = ax_pixels.twinx()
x_1, x_2 = ax_pixels.get_ylim()

ax_pm.set_ylim(x_1*scaleX, x_2*scaleX)

ax_pixels.set_ylabel('drift [pixels]')
ax_pm.set_ylabel('drift [pm]')
ax_pixels.set_xlabel('image number');
plt.tight_layout()
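The right-hand axis converts pixels to picometres by multiplying with the pixel size and 1000. Assuming, as in the logging cells, that the scale is stored in nm per pixel, the conversion is just the following (the helper name `drift_to_pm` and the 0.016 nm pixel size are hypothetical).

```python
import numpy as np

def drift_to_pm(drift_pixels, scale_nm_per_pixel):
    """Hypothetical helper: convert drift from pixels to picometres (1 nm = 1000 pm)."""
    return np.asarray(drift_pixels, dtype=float) * scale_nm_per_pixel * 1000.0

print(drift_to_pm([0.5, -2.0], 0.016))
```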


In [None]:
print(h5_file.filename)
h5_file.close()


You can now open this file in the [Strain Analysis](5_strain_analysis.ipynb) notebook.

If something went wrong, go through the notebook step by step instead (the Step by Step cells use the open `h5_file`, so skip the cell above that closes it).

## Step by Step

If this is an image stack, we need to register and sum the images.

If this is not an image stack, we simply use the image you opened.


### Rigid Registration


In [None]:
import scipy as sp

stack = np.transpose(reduced_stack, axes=(0,1,2) )

RigReg ,drift = it.Rigid_Registration(stack)

RigReg_crop,crop  = it.crop_image_stack(RigReg, drift)
    

RigReg_image = np.sum(RigReg_crop, axis=2)
    
im = RigReg_image
plt.figure()
#plt.title(current_channel['title'][()] )
plt.imshow(im.T,extent = tags['extent'], origin = 'upper');
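`it.Rigid_Registration` returns the aligned stack and the drift; how pyTEMlib estimates the drift is not shown here. One common way to measure a rigid shift between two frames is phase correlation: the normalized cross-power spectrum of the two FFTs transforms back into a sharp peak at the offset. The sketch below, with hypothetical function names, finds whole-pixel shifts only; a production implementation would normally add sub-pixel refinement.

```python
import numpy as np

def phase_correlation_shift(reference, image):
    """Hypothetical helper: whole-pixel shift (axis 0, axis 1) that np.roll needs to map `image` onto `reference`."""
    cross_power = np.fft.fft2(reference) * np.conj(np.fft.fft2(image))
    cross_power /= np.abs(cross_power) + 1e-12     # keep only the phase
    correlation = np.fft.ifft2(cross_power).real
    peak = np.array(np.unravel_index(np.argmax(correlation), correlation.shape))
    shape = np.array(correlation.shape)
    wrap = peak > shape // 2
    peak[wrap] -= shape[wrap]                       # peaks past the middle are negative shifts
    return peak

def stack_shifts(stack):
    """Shift of every frame of a (frames, x, y) stack relative to the first frame."""
    return np.array([phase_correlation_shift(stack[0], frame) for frame in stack])

# Usage: a random reference and a copy moved by (3, -5) pixels
rng = np.random.default_rng(seed=1)
reference = rng.random((64, 64))
moved = np.roll(reference, shift=(3, -5), axis=(0, 1))
shift = phase_correlation_shift(reference, moved)
print(shift)
```

Normalizing by the magnitude discards the image contrast and keeps only the phase ramp caused by the shift, which is what makes the correlation peak sharp even for noisy lattice images.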


### Determine Quality of Rigid Registration

First, we fit a polynomial of degree **polynom_degree** to the x and y drift separately.

The fit helps to identify outliers, which could then be excluded or fixed (we have not had this problem for a while).

In the cell below, a second fit that uses only the stable images estimates the drift of the outlier images in the stack.
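The fitting logic of the next cell can be condensed into a small function. This is a sketch with a hypothetical name `fit_drift`, tested on synthetic drift (a parabola in x with one jump of 15 pixels, a straight line in y). The refit uses boolean indexing so that flagged frames are dropped from the second fit altogether.

```python
import numpy as np

def fit_drift(drift, polynom_degree=2, max_drift=4):
    """Hypothetical helper: fit x and y drift separately, flag outliers, refit on stable frames.

    drift: array of shape (frames, 2) in pixels.
    Returns the boolean mask of stable frames and the two refitted np.poly1d objects.
    """
    frames = np.arange(drift.shape[0], dtype=float)
    first_fit = [np.poly1d(np.polyfit(frames, drift[:, i], polynom_degree)) for i in range(2)]
    residual = np.abs(drift - np.column_stack([fit(frames) for fit in first_fit]))
    stable = np.all(residual < max_drift, axis=1)
    # boolean indexing removes the outliers from the second fit entirely
    second_fit = [np.poly1d(np.polyfit(frames[stable], drift[stable, i], polynom_degree))
                  for i in range(2)]
    return stable, second_fit

# Usage: synthetic drift, parabolic in x and linear in y, with a 15-pixel jump in frame 10
t = np.arange(20, dtype=float)
synthetic_drift = np.column_stack([0.05 * t**2, 0.5 * t])
synthetic_drift[10, 0] += 15.0
stable, (fit_x, fit_y) = fit_drift(synthetic_drift)
print(np.where(~stable)[0], fit_x(10.0), fit_y(10.0))
```

The refitted polynomials give the drift expected for frame 10 (5 pixels in both x and y here), which could replace the measured outlier value.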

In [None]:
polynom_degree = 2 # 1 is linear fit, 2 is parabolic fit, ...

max_drift = 4 ## pixels

x = np.linspace(0,drift.shape[0]-1,drift.shape[0])

line_fit_x = np.polyfit(x, drift[:,0], polynom_degree)
poly_x = np.poly1d(line_fit_x)
line_fit_y = np.polyfit(x, drift[:,1], polynom_degree)
poly_y = np.poly1d(line_fit_y)

difference_drift = np.zeros(drift.shape)
difference_drift[:,0] = np.abs(drift[:,0] - poly_x(x))
difference_drift[:,1] = np.abs(drift[:,1] - poly_y(x))

stable_images = np.all(difference_drift < max_drift, axis=1)

# refit using only the stable images (multiplying by the mask would pull outliers to (0, 0) instead of dropping them)
line_fit_x = np.polyfit(x[stable_images], drift[stable_images, 0], polynom_degree)
poly_x = np.poly1d(line_fit_x)
line_fit_y = np.polyfit(x[stable_images], drift[stable_images, 1], polynom_degree)
poly_y = np.poly1d(line_fit_y)


plt.figure()

plt.axhline(color = 'gray')
plt.plot(x, drift[:,0], label = 'drift x')
plt.plot(x, drift[:,1], label = 'drift y')

#plt.plot(x, np.rint(drift[:,0]), label = 'pixel drift x')
#plt.plot(x, np.rint(drift[:,1]), label = 'pixel drift y')


plt.plot(x, poly_x(x),  label = 'fit_drift_x')
plt.plot(x, poly_y(x),  label = 'fit_drift_y')

outlayer = []
remove =[]
for i in range(len(stable_images)):
    if not stable_images[i]:
        plt.scatter(x[i],0, color='red')#, label='outlayers')
        print(f'image {x[i]}: estimated drift x: {poly_x(x[i]):.2f}, y: {poly_y(x[i]):.2f}')
        outlayer.append([i,poly_x(x[i]),poly_y(x[i])])
        remove.append(i)

plt.legend();
ax_pixels = plt.gca()
ax_pixels.step(1, 1)

ax_pm = ax_pixels.twinx()
x_1, x_2 = ax_pixels.get_ylim()
scaleX = tags['scaleX']*1000
ax_pm.set_ylim(x_1*scaleX, x_2*scaleX)

ax_pixels.set_ylabel('drift [pixels]')
ax_pm.set_ylabel('drift [pm]')
ax_pixels.set_xlabel('image number');



### Log Rigid Registration

Please note that only the result of whichever of the two options above was run last is stored.
We also crop the stack and the summed image at this point.
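The crop stored as `Rigid_registration_crop` comes from `it.crop_image_stack`, whose exact rule is not shown here. A plausible rule, sketched below under the assumption that each frame was moved by its rounded drift with `np.roll`-like semantics, keeps only the region that contains valid data in every frame. The function name `common_crop` is hypothetical.

```python
import numpy as np

def common_crop(shape, drift):
    """Hypothetical helper: (xmin, xmax, ymin, ymax) of the region valid in every frame.

    ASSUMPTION: frame i was moved by its rounded drift (dx, dy) with np.roll-like
    semantics, so a positive dx leaves the first dx positions along x without valid
    data and a negative dx the last |dx| positions.
    """
    shifts = np.rint(np.asarray(drift, dtype=float)).astype(int)
    xmin = max(0, int(shifts[:, 0].max()))
    xmax = shape[0] + min(0, int(shifts[:, 0].min()))
    ymin = max(0, int(shifts[:, 1].max()))
    ymax = shape[1] + min(0, int(shifts[:, 1].min()))
    return xmin, xmax, ymin, ymax

# Usage: three frames of 100 x 80 pixels with hypothetical drifts
example_drift = [[0.0, 0.0], [2.0, -1.0], [-3.0, 1.6]]
print(common_crop((100, 80), example_drift))
```

The largest positive and the most negative shift along each axis set the two borders, so the crop shrinks by the total drift range.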

In [None]:
out_tags = {}
out_tags['analysis'] = 'Rigid Registration Undistorted'
out_tags['notebook']= __notebook__ 
out_tags['notebook_version']= __notebook_version__

out_tags['data_type'] = 'image_stack'
out_tags['data'] = RigReg_crop

out_tags['Rigid_registration_drift']=drift
out_tags['Rigid_registration_crop'] = crop

out_tags['spatial_origin_x'] = 0.
out_tags['spatial_origin_y'] = 0.
out_tags['spatial_scale_x'] = tags['scaleX']
out_tags['spatial_scale_y'] = tags['scaleY']
out_tags['spatial_size_x'] = RigReg_image.shape[0]
out_tags['spatial_size_y'] = RigReg_image.shape[1]
out_tags['spatial_units'] = 'nm'


## Log data
out_tags['name'] = 'rigid_registration_undistorted'
out_tags['title'] = out_tags['name']
stack_channel = ft.log_results(current_channel, out_tags)


current_dataset = stack_channel['nDim_Data']
view = ft.h5_plot(current_dataset)  # note this needs a view reference for interaction


In [None]:
h5_file.flush()
print(h5_file.filename)

### Non-Rigid Registration

Here we use the **Diffeomorphic Demons Non-Rigid Registration** as provided by **SimpleITK**.

Please cite:
* [SimpleITK](http://www.simpleitk.org/SimpleITK/project/parti.html) and
* [T. Vercauteren, X. Pennec, A. Perchant and N. Ayache, *Diffeomorphic Demons Using ITK's Finite Difference Solver Hierarchy*, The Insight Journal, 2007](http://hdl.handle.net/1926/510)

Please also cite this article:

[Yankovich, A., Berkels, B., Dahmen, W. et al. Picometre-precision analysis of scanning transmission electron microscopy images of platinum nanocatalysts. Nat Commun 5, 4155 (2014)](https://doi.org/10.1038/ncomms5155)


This can take a few minutes.
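SimpleITK's diffeomorphic demons is more elaborate than can be shown here, but the core idea of demons registration fits in a few lines. The sketch below implements Thirion's classic (not diffeomorphic) demons update with NumPy and SciPy: at every pixel the intensity difference pushes the displacement along the image gradient, and Gaussian smoothing of the displacement field acts as regularisation. It is an illustration on a synthetic blob with a hypothetical function name, not a replacement for `it.DemonReg`.

```python
import numpy as np
from scipy import ndimage

def demons_register(fixed, moving, iterations=50, sigma=1.0):
    """Hypothetical helper: classic (Thirion) demons, warping `moving` onto `fixed`."""
    fixed = np.asarray(fixed, dtype=float)
    moving = np.asarray(moving, dtype=float)
    grad_0, grad_1 = np.gradient(fixed)            # gradient along axis 0 and axis 1
    grad_sq = grad_0**2 + grad_1**2
    grid_0, grid_1 = np.indices(fixed.shape, dtype=float)
    disp_0 = np.zeros_like(fixed)
    disp_1 = np.zeros_like(fixed)

    def warp(u0, u1):
        # sample `moving` at position + displacement with linear interpolation
        return ndimage.map_coordinates(moving, [grid_0 + u0, grid_1 + u1], order=1, mode='nearest')

    for _ in range(iterations):
        diff = warp(disp_0, disp_1) - fixed
        denom = grad_sq + diff**2
        denom[denom == 0] = np.inf                 # no force where both terms vanish
        disp_0 -= diff * grad_0 / denom
        disp_1 -= diff * grad_1 / denom
        # Gaussian smoothing of the field acts as the regularisation
        disp_0 = ndimage.gaussian_filter(disp_0, sigma)
        disp_1 = ndimage.gaussian_filter(disp_1, sigma)
    return warp(disp_0, disp_1), (disp_0, disp_1)

# Usage: a Gaussian blob and a copy displaced by (2, -1) pixels
yy, xx = np.indices((64, 64))
fixed = np.exp(-((yy - 32)**2 + (xx - 32)**2) / (2 * 4.0**2))
moving = np.exp(-((yy - 34)**2 + (xx - 31)**2) / (2 * 4.0**2))
warped, field = demons_register(fixed, moving)
error_before = np.abs(moving - fixed).mean()
error_after = np.abs(warped - fixed).mean()
print(f'mean absolute error before: {error_before:.4f}, after: {error_after:.4f}')
```

The diffeomorphic variant used by SimpleITK composes the update with the current field through an exponential map so that the transformation stays invertible; the additive update above does not guarantee that.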

In [None]:
non_rigid_registered = it.DemonReg(current_dataset)

DemReg_image = np.sum(non_rigid_registered, axis=2)

plt.figure()
plt.imshow(DemReg_image.T);

### Log Non-Rigid Registration

Please note that you can always delete a **Log** group with:<br>
*del current_channel['Log_001']*, replacing `Log_001` with the Log group you want to remove.

In [None]:
out_tags={}

out_tags['analysis']= 'Non-Rigid Registration Undistorted'
out_tags['notebook']= __notebook__ 
out_tags['notebook_version']= __notebook_version__

out_tags['data'] = non_rigid_registered
out_tags['data_type'] = 'image_stack'
    
out_tags['spatial_origin_x'] = 0.
out_tags['spatial_origin_y'] = 0.
out_tags['spatial_scale_x'] = tags['scaleX']
out_tags['spatial_scale_y'] = tags['scaleY']
out_tags['spatial_size_x'] = DemReg_image.shape[0]
out_tags['spatial_size_y'] = DemReg_image.shape[1]
out_tags['spatial_units'] = 'nm'


out_tags['name'] = 'non-rigid_registration_undistorted'
out_tags['title'] = out_tags['name']

stack_channel = ft.log_results(current_channel, out_tags)

for key in current_channel:
    if 'Log' in key:
        if 'analysis' in current_channel[key]:
            print(f"{key} includes analysis: {current_channel[key]['analysis'][()]}")

plot = ft.h5_plot(stack_channel['nDim_Data'])

## Close File



In [None]:
h5_file.close()

You can now open this file in the [Strain Analysis](5_strain_analysis.ipynb) notebook.

In [None]:
plt.close('all')