# Strain Mapping 

This notebook walks through the steps to calculate strain using pyxem. The data were taken from the paper:

```
Microstructure and microchemistry study of irradiation-induced precipitates in proton irradiated ZrNb alloys
Yu, Zefeng; Zhang, Chenyu; Voyles, Paul M.; He, Lingfeng; Liu, Xiang; Nygren, Kelly; Couet, Adrien
10.18126/2nj3-gyd8 
```

It shows a precipitate that forms under irradiation in the ZrNb sample, and the dataset shows the strain around one of these precipitates. The results in this notebook differ slightly from the published ones because the paper uses only two diffraction spots to calculate strain. Here we define a `basis` set of diffraction spots from an unstrained region of the sample and then use that basis set to refine the diffraction spots found in the rest of the dataset.

Then, at each probe position (x, y), we calculate a displacement gradient tensor that relates the set of found points to the basis.

By transforming that gradient tensor we can plot the strain components e11, e22 and e12 as well as the rotation (theta).

In this sample you can see that the precipitate is mostly under compressive strain, with some shear strain as well. Hot spots at the edge of the theta map also suggest the presence of dislocations.
<center><img src="Images/Strain Mapping.png" alt="StrainMapping" height="1000" width="1000"></center>

In [1]:
import pyxem as pxm
import hyperspy.api as hs
print(pxm.__version__)

0.21.0


In [2]:
# Load the data
s = pxm.data.zrnb_precipitate(allow_download=True, lazy=True)

In [3]:
s.axes_manager

| Navigation axis name | size | index | offset | scale | units |
|---|---|---|---|---|---|
| x | 60 | 0 | 0.5 | 0.9 | nm |
| y | 40 | 0 | 0.5 | 0.9 | nm |

| Signal axis name | size | offset | scale | units |
|---|---|---|---|---|
| kx | 256 | -6.564102564102564 | 0.0512820512820512 | nm^-1 |
| ky | 256 | -6.564102564102564 | 0.0512820512820512 | nm^-1 |


In [4]:
s.calibration.scale = 0.051  # reciprocal-space pixel size in nm^-1

In [5]:
# set axis labels
s.axes_manager.signal_axes[0].name="kx"
s.axes_manager.signal_axes[1].name="ky"

In [6]:
# Set the figure dpi so that plots show up nicely side by side (the best value differs between monitors)
import matplotlib.pyplot as plt
plt.rcParams['figure.dpi'] = 60

In [7]:
%matplotlib qt
s.plot()

In [8]:
# find the direct-beam position in every pattern
beam_shift = s.get_direct_beam_position(method="blur", sigma=15, half_square_width=30)



In [9]:
# compute the beam shift
beam_shift.compute()


In [10]:
# plot the beam shift
beam_shift.plot()

[<Axes: title={'center': 'x-shift'}, xlabel='x axis (nm)', ylabel='y axis (nm)'>,
 <Axes: title={'center': 'y-shift'}, xlabel='x axis (nm)', ylabel='y axis (nm)'>]

In [11]:
# fit a linear plane to the beam shift (models the smooth descan component)
linear_beam_shift = beam_shift.get_linear_plane(fit_corners=0.5)

In [12]:
# plot the linear beam shift
linear_beam_shift.plot()

[<Axes: title={'center': 'x-shift'}, xlabel='x axis (nm)', ylabel='y axis (nm)'>,
 <Axes: title={'center': 'y-shift'}, xlabel='x axis (nm)', ylabel='y axis (nm)'>]
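Fitting a plane to the beam shift is a plain least-squares problem. `get_linear_plane` is pyxem's method, and exactly how `fit_corners=0.5` selects the pixels it fits is not shown here; the sketch below is a generic NumPy plane fit, and fitting only the scan corners is our assumption, meant to mimic the idea of fitting away from features that shift the beam for physical reasons. `fit_linear_plane` is a hypothetical helper and the shift map is made up.

```python
import numpy as np

def fit_linear_plane(shift_map, mask=None):
    """Hypothetical helper: least-squares plane a*x + b*y + c fitted to a 2D map.

    Only pixels where `mask` is True enter the fit; the plane is then
    evaluated over the whole map."""
    ny, nx = shift_map.shape
    yy, xx = np.mgrid[:ny, :nx]
    if mask is None:
        mask = np.ones(shift_map.shape, dtype=bool)
    design = np.column_stack([xx[mask], yy[mask], np.ones(mask.sum())])
    (a, b, c), *_ = np.linalg.lstsq(design, shift_map[mask], rcond=None)
    return a * xx + b * yy + c

# Made-up x-shift map (40 x 60 probe positions): a linear descan ramp plus a
# local bump standing in for a real, sample-induced beam shift.
yy, xx = np.mgrid[:40, :60]
true_plane = 0.05 * xx - 0.02 * yy + 1.0
shift_map = true_plane.copy()
shift_map[15:25, 25:35] += 2.0

# Assumed mask: fit only on the four 10 x 10 corners, away from the bump.
corners = np.zeros(shift_map.shape, dtype=bool)
for rows in (slice(0, 10), slice(-10, None)):
    for cols in (slice(0, 10), slice(-10, None)):
        corners[rows, cols] = True

plane = fit_linear_plane(shift_map, corners)
print(np.allclose(plane, true_plane))
```

Fitting the whole map instead would let the bump pull the plane upward, which is why restricting the fit to "background" regions can be useful.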

In [13]:
# center the direct beam using the fitted (linear) beam shift
centered_s = s.center_direct_beam(shifts=linear_beam_shift, inplace=False)

In [14]:
centered_s

Title: (none) · SignalType: electron_diffraction

| | Array | Chunk |
|---|---|---|
| Bytes | 600.00 MiB | 16.00 MiB |
| Shape | (60, 40\|256, 256) | (8,8\|256,256) |
| Count | 162 Tasks | 40 Chunks |
| Type | float32 | numpy.ndarray |

Navigation axes: (60, 40) · Signal axes: (256, 256)


## Let's Average the Signal a Bit...

Sometimes we see variable intensity in the disks. We can correct some of this by simply Gaussian filtering the data in real space (along the scan axes only).  
For strain mapping we can keep this very local (sigma = (1, 1, 0, 0): one pixel along each scan axis and none along the detector axes) without losing much spatial resolution. Of course, you can also just take longer exposures to get a better signal-to-noise ratio!  

There are other things you can do as well.  
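Filtering "in real space" just means the Gaussian has a nonzero width along the two scan axes and zero width along the two detector axes. A minimal sketch with `scipy.ndimage.gaussian_filter` on a tiny, made-up in-memory 4D array (the notebook itself uses the `dask_image` version for lazy data); the array shape and the single bright pixel are assumptions for illustration:

```python
import numpy as np
from scipy.ndimage import gaussian_filter

# Made-up 4D-STEM stack with axes (scan_y, scan_x, ky, kx).
data = np.zeros((9, 9, 16, 16))
data[4, 4, 8, 8] = 1.0  # one bright detector pixel at one probe position

# sigma = 1 pixel along both scan axes, 0 along both detector axes;
# scipy skips filtering entirely along axes whose sigma is 0.
smoothed = gaussian_filter(data, sigma=(1, 1, 0, 0))

# The intensity is shared with neighbouring probe positions...
print(smoothed[4, 4, 8, 8], smoothed[4, 5, 8, 8])
# ...but never leaks into other detector pixels of a pattern.
print(np.count_nonzero(smoothed[:, :, 8, 7]))  # 0
```

Because the detector axes are untouched, disk edges stay sharp; only the noise between neighbouring probe positions is averaged.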

In [13]:
s.plot()

In [14]:
from dask_image.ndfilters import gaussian_filter  # works on lazy (dask) signals
#from scipy.ndimage import gaussian_filter  # use this instead for signals loaded into memory

In [15]:
s = pxm.data.zrnb_precipitate(allow_download=True, lazy=True)

In [16]:
#s.compute()

In [17]:
# filter the dataset using dask_image
filtered = s.filter(gaussian_filter, sigma=(1,1,0,0)) # in pixels 

In [18]:
# plot the filtered dataset (let's compare the two!)

hs.plot.plot_signals([s, filtered], navigator=s.navigator,
                     vmax="99th", cmap='magma')
plt.close("all")

In [19]:
plt.close("all")

## Aside: I want to try a Hough Transform!

Okay, I am personally a little bit hesitant about the Hough transform. Most of the time, the reason people see better performance from the Hough transform than from template matching is that they are not using a cross-correlation that properly normalizes for noise. The Hough transform also requires a very high signal-to-noise ratio and requires you to mask your data in some way (which also explains the better performance). Template matching will handle intensity fluctuations in the disks as long as you normalize correctly (which pyxem currently does quite well, in my opinion).
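To see what "properly normalizes" buys you, here is a small sketch using scikit-image's `match_template`, which computes a zero-normalized cross-correlation. It is a stand-in for illustration, not pyxem's `template_match_disk` implementation; the synthetic pattern, the noise level and the disk radius are made up. Scaling and offsetting the pattern intensity leaves the correlation map unchanged, and the peak still lands on the disk centre.

```python
import numpy as np
from skimage.feature import match_template

def disk_image(shape, centre, radius):
    """Float image with a filled disk of ones; everything else is zero."""
    yy, xx = np.mgrid[:shape[0], :shape[1]]
    return (np.hypot(yy - centre[0], xx - centre[1]) <= radius).astype(float)

rng = np.random.default_rng(0)
radius = 11
template = disk_image((25, 25), (12, 12), radius)  # disk plus a zero border

# Made-up noisy pattern with one disk centred at (row, col) = (60, 70).
pattern = disk_image((128, 128), (60, 70), radius) + rng.normal(0, 0.2, (128, 128))

# Zero-normalized cross-correlation; pad_input=False gives "valid" output.
ncc = match_template(pattern, template)
ncc_rescaled = match_template(5.0 * pattern + 3.0, template)

peak = np.unravel_index(np.argmax(ncc), ncc.shape)
centre = (int(peak[0]) + 12, int(peak[1]) + 12)  # valid-mode index -> template centre
print(centre)
print(np.allclose(ncc, ncc_rescaled, atol=1e-6))
```

Subtracting the local mean and dividing by the local standard deviation is what makes a dim disk and a bright disk score the same, so no extra masking is needed.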

You can find the peaks in the `circular_hough` signal in the same way we normally find peaks. This functionality isn't currently built into `pyxem`, but if you are interested, raise an issue; it is an easy addition.
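As a sketch of what that peak finding could look like, scikit-image provides `hough_circle_peaks` for picking maxima out of a `hough_circle` accumulator. The edge image below is synthetic (two circle perimeters of radius 11 drawn with `skimage.draw.circle_perimeter`), and the `min_xdistance`/`min_ydistance` values are arbitrary choices, not tuned for this dataset:

```python
import numpy as np
from skimage.draw import circle_perimeter
from skimage.transform import hough_circle, hough_circle_peaks

# Synthetic edge image: perimeters of two circles of radius 11.
edges = np.zeros((128, 128), dtype=bool)
for r0, c0 in [(40, 50), (90, 80)]:
    rr, cc = circle_perimeter(r0, c0, 11)
    edges[rr, cc] = True

radii = np.array([11])
hspace = hough_circle(edges, radii)  # shape (1, 128, 128): one accumulator per radius

# Pick the two strongest, well-separated maxima in the accumulator.
accums, cx, cy, rad = hough_circle_peaks(
    hspace, radii, min_xdistance=5, min_ydistance=5, total_num_peaks=2
)
found = sorted((int(r), int(c)) for r, c in zip(cy, cx))  # (row, col) pairs
print(found)
```

Note that `hough_circle_peaks` returns column (`cx`) before row (`cy`), so the pairs are swapped when converting to (row, col).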

In [20]:
from skimage.transform import hough_circle
from skimage.feature import canny
import numpy as np

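def hough_circle_single_rad(img, radius, **kwargs):
    # hough_circle returns one accumulator per radius, shape (n_radii, h, w);
    # take [0] so the output keeps the 2D shape of the input pattern
    return hough_circle(img, radius, **kwargs)[0]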

In [21]:
canny_img = filtered.map(canny, sigma=3, low_threshold=.6, high_threshold=.8, inplace=False, use_quantiles=True)

In [22]:
canny_img.plot()


In [23]:
circular_hough = canny_img.map(hough_circle_single_rad, radius=11, inplace=False)

In [24]:
canny_img.plot()
circular_hough.plot(axes_manager=canny_img.axes_manager)


## Filtering with Disk Template Matching

Next we can use template matching before finding the diffraction vectors in the dataset. I like to do this lazily and then adjust the parameters. The disk radius (`disk_r`) can be read from the size of the direct beam, but it is also good to view the template-matching result to make sure that things worked correctly. If `disk_r` is too small you might end up with a valley at the center of the disk, and if it is too large you end up with a plateau at the center.

As shown below, the ideal radius is about 11 pixels.
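Why a wrong radius gives a valley or a plateau can be reproduced on a synthetic, noise-free disk. This sketch uses scikit-image's `match_template` (with `pad_input=True`, so each output pixel corresponds to the template centre) rather than pyxem's `template_match_disk`, so the numbers will not match the plots here exactly; the 96-pixel pattern, the disk position and the template border are assumptions.

```python
import numpy as np
from skimage.feature import match_template

def disk(radius, size, centre):
    """Float (size, size) image with a filled disk of ones at (centre, centre)."""
    yy, xx = np.mgrid[:size, :size]
    return (np.hypot(yy - centre, xx - centre) <= radius).astype(float)

# Made-up noise-free pattern: one disk of radius 11 at pixel (48, 48).
pattern = disk(11, 96, 48)

responses = {}
for r in (5, 11, 15):
    template = disk(r, 2 * r + 3, r + 1)  # disk plus a thin zero border
    responses[r] = match_template(pattern, template, pad_input=True)

# Response at the disk centre, 2 pixels away, and the global maximum.
for r, resp in responses.items():
    print(r, round(resp[48, 48], 3), round(resp[48, 50], 3), round(resp.max(), 3))
```

With r = 5 the template sits entirely inside the disk near its centre, where the window is uniform and the normalized correlation drops to zero (the valley). With r = 15 the disk fits anywhere inside the larger template for a few pixels of offset, giving identical values (the plateau). Only r = 11 gives a single sharp maximum of 1 at the centre.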

In [25]:
plt.close("all")

In [26]:
s.plot()

In [27]:
# let's see what effect different disk radii have:
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(15, 5))
one_pattern=filtered.inav[5,5]
hs.plot.plot_images([one_pattern.template_match_disk(r) for r in [5,10,15]], 
                    axes_decor="off",
                    scalebar="all",
                    label=["Radius=5pix","Radius=10pix","Radius=15pix"], fig=fig)
plt.show()

In [28]:
# template matching using a disk.  
temp = filtered.template_match_disk(disk_r=11, subtract_min=False)

In [29]:
# pass navigator and plot
temp.navigator= s.navigator
temp.plot()

## Peak Finding 

Now we can work out a good threshold for peak finding. You could use the interactive peak finding in HyperSpy, but I tend to just adjust the `vmin` value when plotting until I get a reasonable minimum.
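Thresholding the template-matched image and keeping local maxima that are far enough apart is the basic idea behind this step. Here is a rough analogue with scikit-image's `peak_local_max` (not necessarily what `get_diffraction_vectors` calls internally), where `threshold_abs` and `min_distance` play the roles of the `threshold_abs=0.4` and `distance=10` arguments used below; the blob image is made up:

```python
import numpy as np
from skimage.feature import peak_local_max

yy, xx = np.mgrid[:64, :64]

def blob(r0, c0, height, sigma=2.0):
    """Gaussian spot standing in for a template-matching peak."""
    return height * np.exp(-((yy - r0) ** 2 + (xx - c0) ** 2) / (2 * sigma**2))

# Made-up template-matched pattern: two strong spots and one weak one.
response = blob(20, 20, 1.0) + blob(20, 44, 0.8) + blob(44, 32, 0.3)

# Keep maxima above 0.4 that are at least 10 pixels apart.
peaks = peak_local_max(response, min_distance=10, threshold_abs=0.4)
print(sorted(map(tuple, peaks.tolist())))
```

The weak spot (height 0.3) falls below the threshold and is dropped, which is exactly what adjusting `vmin` while plotting helps you judge.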

In [30]:
# let's look at the histogram for a small region of probe positions
temp.inav[3:10, 3:10].get_histogram().plot()




In [31]:
# Let's plot the data with an adjusted vmin
temp.plot(vmin=.4)

In [32]:
# get the diffraction vectors
vect = temp.get_diffraction_vectors(threshold_abs=0.4,
                        distance=10, get_intensity=False)



In [33]:
#display the diffraction vectors
vect

Title: (none) · SignalType: diffraction_vectors

| | Array | Chunk |
|---|---|---|
| Bytes | 18.75 kiB | 512 B |
| Shape | (60, 40\|ragged) | (8,8\|) |
| Count | 567 Tasks | 40 Chunks |
| Type | object | numpy.ndarray |

Navigation axes: (60, 40) · Signal axes: ragged


In [34]:
#plot the diffraction vectors
m = vect.to_markers( facecolors="none",
                     edgecolor="w", 
                     sizes=[30,])
s.plot()
s.add_marker(m,)


In [35]:
plt.close("all")

In [36]:
# let's subpixel-refine the vectors and compare!
vect_sub = vect.subpixel_refine(filtered,"cross-correlation",
                     disk_r=11,
                     upsample_factor=2, square_size=26)

m = vect.to_markers( facecolors="none",
                     edgecolor="w", 
                     sizes=[30,])

m2 = vect_sub.to_markers( facecolors="none",
                     edgecolor="g", 
                     sizes=[30,])
s.plot()
s.add_marker(m,)
s.add_marker(m2,)

## Setting a Basis and Determining Strain

First we filter the vectors based on their magnitude. This gets rid of the zero beam and the weaker peaks farther out!
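The magnitude filter amounts to keeping the vectors whose length falls inside a ring. A NumPy sketch of that idea, using a hypothetical helper `filter_by_magnitude` (pyxem's `filter_magnitude` may treat the bounds slightly differently) and made-up vectors in nm^-1:

```python
import numpy as np

def filter_by_magnitude(vectors, min_magnitude, max_magnitude):
    """Hypothetical helper: keep rows of an (n, 2) array whose length is in [min, max]."""
    mag = np.linalg.norm(vectors, axis=1)
    return vectors[(mag >= min_magnitude) & (mag <= max_magnitude)]

vectors = np.array([
    [0.0, 0.0],      # direct (zero-order) beam, |g| = 0
    [-2.18, -3.33],  # first-order spot, |g| ~ 3.98 nm^-1
    [0.26, -4.26],   # first-order spot, |g| ~ 4.27 nm^-1
    [5.9, 2.1],      # weaker, farther-out spot, |g| ~ 6.26 nm^-1
])
kept = filter_by_magnitude(vectors, 3.5, 4.5)
print(kept)
```

Only the two first-order spots survive, matching the intent of `min_magnitude=3.5, max_magnitude=4.5` below.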

In [37]:
# filter the magnitude of the vectors
vect_filtered = vect_sub.filter_magnitude(min_magnitude=3.5, max_magnitude=4.5)

In [38]:
# compute the (lazy) vectors so they are held in memory
#lazy_vect = vect_filtered.as_lazy()
vect_filtered.compute()


Define the basis vectors at a probe position far from the region of interest (an unstrained region).

In [39]:
# get a basis from the unstrained region
basis = vect_filtered.inav[4,4]

Next, filter the vectors to keep only those around a basis vector. The `distance` parameter defines the maximum distance a found point can be from a basis point and still be associated with it.

If multiple points are found, the closest point is used. If no points are found, `np.nan` is returned and that spot is ignored.
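The matching rule just described (closest point within `distance`, otherwise NaN) can be sketched in NumPy. `match_to_basis` is a hypothetical helper, not pyxem's `filter_basis`, and `distance=0.2` nm^-1 is an arbitrary value; the basis and found (ky, kx) vectors are copied from the `basis` and `filtered_data` outputs in this notebook.

```python
import numpy as np

def match_to_basis(vectors, basis, distance):
    """Hypothetical helper: for each basis vector, pick the closest found vector
    within `distance`; rows with no match are filled with NaN.

    vectors: (m, 2) found vectors at one probe position
    basis:   (n, 2) reference vectors
    returns: (n, 2) array aligned row by row with `basis`
    """
    out = np.full(basis.shape, np.nan)
    if len(vectors) == 0:
        return out
    # Pairwise distances between basis and found vectors, shape (n, m).
    d = np.linalg.norm(basis[:, None, :] - vectors[None, :, :], axis=-1)
    nearest = np.argmin(d, axis=1)
    close = d[np.arange(len(basis)), nearest] <= distance
    out[close] = vectors[nearest[close]]
    return out

basis = np.array([
    [0.2564102564102564, -4.256410256410256],
    [-2.1794871794871797, -3.333333333333333],
    [1.8974358974358978, -3.358974358974359],
    [3.384615384615385, -2.4871794871794872],
    [-3.1538461538461537, -1.6410256410256405],
    [-4.076923076923077, 0.0512820512820511],
])
found = np.array([
    [-2.1794871794871797, -3.333333333333333],
    [1.9230769230769234, -3.333333333333333],
    [-3.1538461538461537, -1.5897435897435894],
    [-4.0256410256410255, 0.07692307692307665],
])
print(match_to_basis(found, basis, distance=0.2))
```

Rows 0 and 3 come back as NaN because no found spot lies near those basis vectors, the same pattern as the empty rows in the `filtered_data` output.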

In [41]:
# filter based on the basis
filtered_data = vect_filtered.filter_basis(basis)


In [42]:
# display filtered data
filtered_data

DiffractionVectors2D, title: (none), dimensions: (60, 40|2, 10) · Current index: (0, 0)

| | ky | kx |
|---|---|---|
| 0 | | |
| 1 | -2.1794871794871797 | -3.333333333333333 |
| 2 | 1.9230769230769234 | -3.333333333333333 |
| 3 | | |
| 4 | -3.1538461538461537 | -1.5897435897435894 |
| 5 | -4.0256410256410255 | 0.07692307692307665 |


In [43]:
# white for current data, red for basis.
m = filtered_data.to_markers(edgecolor="w",facecolor="none", sizes=(40), lw=4)
basis_markers = hs.plot.markers.Points(basis.data[0][:,::-1], edgecolor="r",facecolor="none", sizes=(40), lw=3 )

filtered.plot(vmax="99th")
filtered.add_marker((m, basis_markers))


## Fitting an Ellipse for the Strain

Let's get a tensor strain map. Basically, we determine the best-fit linear (elliptical) transformation that maps the basis set of vectors onto the strained vectors. We can also compute the residual and use it to improve the fits or to identify areas where we are less confident about the fit.
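The fit itself is a small least-squares problem per probe position. Below is a generic small-strain sketch in NumPy, not pyxem's code: pyxem's `get_DisplacementGradientMap` and `get_strain_maps` do this for every probe position, but their exact conventions (fit direction, axis order, and whether the rotation comes from a polar decomposition) are not shown in this notebook. `fit_gradient_tensor` and `small_strain` are hypothetical helpers, and `D_true` is a made-up distortion. One thing to keep in mind: these vectors live in reciprocal space. If a real-space lattice is deformed by F, its reciprocal vectors transform by inv(F).T, so to first order the strain read straight off the diffraction vectors has the opposite sign to the real-space strain; the last lines convert with F = inv(D).T.

```python
import numpy as np

def fit_gradient_tensor(unstrained, strained):
    """Hypothetical helper: least-squares 2x2 D with strained[i] ~= D @ unstrained[i].

    Rows where either vector is NaN (unmatched spots) are dropped first.
    Returns D and the residuals (strained - D @ unstrained) of the kept rows.
    """
    ok = ~(np.isnan(unstrained).any(axis=1) | np.isnan(strained).any(axis=1))
    u, s = unstrained[ok], strained[ok]
    x, *_ = np.linalg.lstsq(u, s, rcond=None)  # solves u @ x ~= s
    return x.T, s - u @ x

def small_strain(D):
    """Hypothetical helper: small-strain split of D = I + G into e11, e22, e12, theta."""
    G = D - np.eye(2)
    eps = 0.5 * (G + G.T)              # symmetric part: strain
    theta = 0.5 * (G[1, 0] - G[0, 1])  # antisymmetric part: rotation (radians)
    return eps[0, 0], eps[1, 1], eps[0, 1], theta

# Basis (ky, kx) vectors in nm^-1, copied from the `basis` output in this notebook.
basis_vectors = np.array([
    [0.2564102564102564, -4.256410256410256],
    [-2.1794871794871797, -3.333333333333333],
    [1.8974358974358978, -3.358974358974359],
    [3.384615384615385, -2.4871794871794872],
    [-3.1538461538461537, -1.6410256410256405],
    [-4.076923076923077, 0.0512820512820511],
])

# Made-up distortion of the diffraction vectors: +1% along the first column,
# -2% along the second, plus a small symmetric shear.
D_true = np.array([[1.01, 0.002], [0.002, 0.98]])
strained = basis_vectors @ D_true.T
strained[0] = np.nan  # pretend one spot was not found

D, residuals = fit_gradient_tensor(basis_vectors, strained)
print(np.round(D, 6))
print(small_strain(D))  # strain of the diffraction vectors (reciprocal space)

# Real-space deformation gradient from the reciprocal-space fit.
F = np.linalg.inv(D).T
print(small_strain(F))
```

With real data the residuals are not zero; their size per probe position is what the error map below visualizes.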


In [44]:
basis

DiffractionVectors, title: (none), dimensions: (|ragged) · Current index: ()

| | ky | kx |
|---|---|---|
| 0 | 0.2564102564102564 | -4.256410256410256 |
| 1 | -2.1794871794871797 | -3.333333333333333 |
| 2 | 1.8974358974358978 | -3.358974358974359 |
| 3 | 3.384615384615385 | -2.4871794871794872 |
| 4 | -3.1538461538461537 | -1.6410256410256405 |
| 5 | -4.076923076923077 | 0.0512820512820511 |


In [45]:
from pyxem.generators.displacement_gradient_tensor_generator import get_DisplacementGradientMap

# Let's get a tensor strain map.
strain_map, residual = get_DisplacementGradientMap(filtered_data, basis.data[0], return_residuals=True)


In [46]:
# get the magnitude of the residual
std_err = (residual**2).sum(axis=-1)**0.5
std_err.set_signal_type()

In [47]:
#plot the error.
std_err.plot()

In [48]:
# visualize the error and determine how to better fit the data.
m = filtered_data.to_markers(edgecolor="w",facecolor="none", sizes=(40))
basis_markers = hs.plot.markers.Points(basis.data[0][:,::-1], edgecolor="r",facecolor="none", sizes=(40) )

filtered.plot(navigator=std_err, vmax="99th")
filtered.add_marker((m, basis_markers))


In [51]:
# get the strain maps.
maps = strain_map.get_strain_maps()


## Visualizing the Strain

In [52]:
# use plot images to show the images.
import matplotlib.pyplot as plt
f = plt.figure(figsize=(7, 7))
hs.plot.plot_images(maps, per_row=2, fig=f,
                    label=["e11", "e22", "e12", "theta"],
                    tight_layout=True, cmap="hot", axes_decor="off", scalebar="all", scalebar_color="black")
# save the strain-map figure (f, created above) before showing it
f.savefig("strainmaps.png", dpi=300)
plt.show()

In [53]:
m = filtered_data.to_markers(edgecolor="w",facecolor="none", sizes=(45), linewidth=4)
basis_markers = hs.plot.markers.Points(basis.data[0][:,::-1], edgecolor="r",facecolor="none", sizes=(45), linewidth=4 )

filtered.plot(navigator=std_err,vmax="99th", navigator_kwds =dict(cmap="hot"))
filtered.add_marker((m, basis_markers))

