Takes a folder of .ovf files (of skyrmion) simulated from mumax and does two things: 
1. Create ground truth labels 
2. Creats simulated LTEM images

The data structure should be as the following
- training_set_XXX (main directory)
    - image NAME and number 00XXX will be same across all .mx3, .ovf, truth image, and LTEM training image
    - mx3s
        - "NAME_00XXX_simdetails.mx3"
    - magnetizations (containing the .ovf files output from mumax)
        - "NAME_00XXX_simdetails.ovf"
        - simdetails includes DMI value, B value, etc. 
    - training_images (LTEM images will be placed here)
        - "NAME_00XXX_LTEM_imdetails.tif"
        - imdetails will include defocus, Tx, Ty, etc. 
    - training_labels (ground truth labels will be placed here)
        - "NAME_00XXX_LABEL_labdetails.tif"


Notes: 
* images should all be saved as inidividuals, can be batched later. 
    - if saving them as stacks would cause problems as they'll get too big. 

In [3]:
# PyLorentz2 environment 
%matplotlib widget
%load_ext autoreload
%autoreload 2

from pathlib import Path 
import sys
sys.path.append("/home/bendera/Lorentz_folder/AlecBender/mumax_training_files/SkyrmNet-main/hipl-main")
sys.path.append("/home/bendera/Lorentz_folder/AlecBender/PyLorentz/SimLTEM")
sys.path.append("/home/bendera/Lorentz_folder/AlecBender/PyLorentz/PyTIE")

from image_helpers import * 
from sim_helper import load_ovf
import tifffile
import matplotlib.pyplot as plt
from skimage import data
from skimage.filters import threshold_triangle

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
wd = Path("/home/bendera/Lorentz_folder/AlecBender/mumax_training_files/presentation_files"
          ).expanduser().resolve()
magdir = wd / "magnetizations"
labeldir = wd / "label_images"
imdir = wd / "training_images"
print(magdir.exists(), labeldir.exists(), imdir.exists())

True True False


# Creating the ground truths 
should be pretty straightforward. Determine a cutoff of the magnetization (e.g. 0.95), threshold each Mz, save all images as .tif

In [3]:
def truth_from_ovf(path, thresh, background=None, show=False):
    erosion = True
    thresholding = False
    watershed = False
    #assert (erosion == True) and (thresholding == True), "Only use a single truth filter"
    #assert (erosion == False) and (thresholding == False), "Use a truth filter"
    
    mx, my, mz, del_px, _zscale = load_ovf(path, sim='norm', v=0)
    mz = np.mean(mz, axis=0)
    
    if background is None: 
        # get background from net mz magnetization
        if mz.sum() > 0:
            background = 'pos'
        elif mz.sum() < 0: 
            background = 'neg'
            mz *= -1
        else:
            print("Background is zero... dafuq?")
            return 
        print("Background is: ", background)
    #from skimage.filters import try_all_threshold

    #fig, ax = try_all_threshold(mz, figsize=(10, 8), verbose=False)
    #plt.show()
    
    if erosion:
        truth = np.where(mz < thresh, 1, 0)
        # for some reason int8 throws an error with tifffile
        truth = ndi.binary_erosion(truth).astype('uint8')
        if show: 
            show_im(mz, f'original', simple=True)
            show_im(truth, f'truth, {thresh}', simple=True)
            show_2D(np.mean(mx, axis=0), np.mean(my, axis=0), mz, color=True, a=0)
        
        if watershed:
            watershed_segmentation(truth)
        
        print(truth, del_px)
        return truth, del_px 

    elif thresholding:

        thresh_triangle = threshold_triangle(mz)
        truth = np.where(mz < thresh_triangle, 1, 0)
        truth = (truth).astype('uint8')

        '''
        fig, axes = plt.subplots(ncols=3, figsize=(8, 2.5))
        ax = axes.ravel()
        ax[0] = plt.subplot(1, 3, 1)
        ax[1] = plt.subplot(1, 3, 2)
        ax[2] = plt.subplot(1, 3, 3, sharex=ax[0], sharey=ax[0])

        ax[0].imshow(mz, cmap=plt.cm.gray)
        ax[0].set_title('original')
        ax[0].axis('off')

        ax[1].hist(mz.ravel(), bins=256)
        ax[1].set_title('Histogram')
        ax[1].axvline(thresh, color='r')

        ax[2].imshow(truth, cmap=plt.cm.gray)
        ax[2].set_title('truth')
        ax[2].axis('off')

        plt.show()
        '''

        if watershed:
            watershed_segmentation(truth)
        
        #print(truth, del_px)
        return truth, del_px 
    


Watershed and random walker for segmentation
============================================

This example compares two segmentation methods in order to separate two
connected disks: the watershed algorithm, and the random walker algorithm.

Both segmentation methods require seeds, that are pixels belonging
unambigusouly to a reagion. Here, local maxima of the distance map to the
background are used as seeds.


In [4]:
def watershed_segmentation (image):
    import numpy as np
    from skimage.morphology import watershed
    from skimage.feature import peak_local_max
    from skimage import measure
    from skimage.segmentation import random_walker
    import matplotlib.pyplot as plt
    from scipy import ndimage
    
    # Now we want to separate the two objects in image
    # Generate the markers as local maxima of the distance
    # to the background
    distance = ndimage.distance_transform_edt(image)
    local_maxi = peak_local_max(
        distance, indices=False, footprint=np.ones((3, 3)), labels=image)
    markers = measure.label(local_maxi)
    labels_ws = watershed(-distance, markers, mask=image)

    markers[~image] = -1
    labels_rw = random_walker(image, markers)

    plt.figure(figsize=(12, 3.5))
    plt.subplot(141)
    plt.imshow(image, cmap='gray', interpolation='nearest')
    plt.axis('off')
    plt.title('image')
    plt.subplot(142)
    plt.imshow(-distance, interpolation='nearest')
    plt.axis('off')
    plt.title('distance map')
    plt.subplot(143)
    plt.imshow(labels_ws, cmap='nipy_spectral', interpolation='nearest')
    plt.axis('off')
    plt.title('watershed segmentation')
    plt.subplot(144)
    plt.imshow(labels_rw, cmap='nipy_spectral', interpolation='nearest')
    plt.axis('off')
    plt.title('random walker segmentation')

    plt.tight_layout()
    plt.show()

In [None]:
# thresh 0.99 too much, do 0.98, e.g. j=510

In [6]:
plt.close('all')
single_check = False
#j = 20
thresh = 0.98
ovf_files = list(magdir.glob("*.ovf"))
ovf_files.sort()
# sparse skyrmions i=100, dense i=0, stripes i=18

i = 1
tot = len(ovf_files)
for ovf in ovf_files: 
    if single_check: 
        ovf = ovf_files[j]
        print(ovf.stem)
    if i % 100 == 0: 
        print(f"{i}/{tot}", end="\r")
    ovfname = ovf.stem 
    deets = ovfname.split('_')
    set_name = deets[0] + deets[1]
    ovf_num = deets[2] 
    rest_name = deets[3:]
    tifname = f"{set_name}_{ovf_num}_LABEL_th{thresh}.tif"
    
    label, del_px = truth_from_ovf(ovf, thresh, background='pos',
                                   show=True if single_check else False)
    # save 
    res = 1/del_px
    tifffile.imsave(str(labeldir / tifname),
                    label,
                    imagej = True,
                    resolution = (res, res),
                    metadata={'unit': 'nm'}
                    )
    i += 1
    if single_check:
        break
print("===============\nDone\n================")

Done


# Creating the LTEM images

In [7]:
from comp_phase import mansPhi
from scipy.spatial.transform import Rotation as R
from scipy.constants import mu_0
from microscopes import Microscope

In [8]:
# simulate microscope-like defocus
def sim_im(phi, pscope, defocus): 
    obj_wave = np.cos(phi) + 1j * np.sin(phi)
    dy, dx = phi.shape
    qq = dist(dy, dx, shift=True)
    pscope.defocus = defocus
    im_def = pscope.getImage(obj_wave, qq, del_px)
    return norm_image(im_def)

### single example

In [9]:
path = ovf_files[0]
mx, my, mz, del_px, zscale = load_ovf(path, sim='norm', v=0)
zdim = mz.shape[0]
deets = path.stem.split('_')
Ms = 1.45e+05
b0 = Ms * mu_0

# code below can be used for changing
'''
for d in deets: 
    if d.startswith("Ms"):
        Ms = float(d[2:])
        b0 = Ms * mu_0
'''
    
set_name = deets[0]
ovf_num = deets[1] 
# show_2D(np.mean(mx, axis=0), np.mean(my, axis=0), np.mean(mz, axis=0), color=True, a=0)
mx, my, mz = mx.sum(axis=0), my.sum(axis=0), mz.sum(axis=0)


In [None]:
theta_x = 20 # degrees, tilt around x axis
theta_y = 0 

Tx = R.from_rotvec(np.deg2rad(theta_x) * np.array([1,0,0]))
Ty = R.from_rotvec(np.deg2rad(theta_y) * np.array([0,1,0]))

beam_z = [0,0,1]
beam_dir = np.around(Tx.apply(Ty.apply(beam_z)),5)
print('beam direction: ', beam_dir)
print(f"angle from normal: {np.rad2deg(np.arctan2(beam_dir[1],beam_dir[2])):.1f}")

phi0 = 2.07e7 #Gauss*nm^2 flux quantum
pre_B = 2*np.pi*b0*zscale*del_px/(zdim*phi0)

mphi = mansPhi(mx, my, mz, beam=beam_dir) * pre_B
# show_im(mz, 'mz', simple=True)
show_im(mphi, 'phase shift')


In [None]:
ALTEM = Microscope(E=200e3,Cs = 200.0e3, theta_c = 0.01e-3, def_spr = 80.0)
im_un = norm_image(sim_im(mphi, ALTEM, -1_000_000))
show_im(im_un)

## LTEM images for all

In [17]:
single_check = False
# j = 7
ovf_files = list(magdir.glob("*.ovf"))
ovf_files.sort()
# sparse skyrmions i=100, dense i=0, stripes i=18

theta_x = 20 # degrees, tilt around x axis
theta_y = 0 
ALTEM = Microscope(E=200e3,Cs = 200.0e3, theta_c = 0.01e-3, def_spr = 80.0)
defocus = -2_000_000

# applies theta tilt for each axis
Tx = R.from_rotvec(np.deg2rad(theta_x) * np.array([1,0,0]))
Ty = R.from_rotvec(np.deg2rad(theta_y) * np.array([0,1,0]))

beam_z = [0,0,1]
beam_dir = np.around(Tx.apply(Ty.apply(beam_z)),5)

# print('beam direction: ', beam_dir)
# print(f"angle from normal: {np.rad2deg(np.arctan2(beam_dir[1],beam_dir[2])):.1f}")

phi0 = 2.07e7 #Gauss*nm^2 flux quantum

i = 0
tot = len(ovf_files)
for ovf in ovf_files: 
    if i % 100 == 0: 
        print(f"{i}/{tot}", end="\r")

    if single_check:
        ovf = ovf_files[j]

    ovf_name = ovf.stem 
    deets = ovf_name.split('_')
    set_name = deets[0] + deets[1]
    ovf_num = deets[2] 
    rest_name = deets[3:]
    tif_name = f"{set_name}_{ovf_num}_LTEM_Tx{theta_x}_Ty{theta_y}_df{defocus*1e-6}mm.tif"

    # initialize Msat for b0
    for d in deets: 
        if d.startswith("AvgMs"): # (Msat from regions) | (Msat from uniform)
            Ms = float(d[5:])
            b0 = Ms * mu_0
            break 
        elif d.startswith("Ms"):
            Ms = float(d[2:])
            b0 = Ms * mu_0
            break 

    # create LTEM image
    mx, my, mz, del_px, zscale = load_ovf(ovf, sim='norm', v=0)
    zdim = mz.shape[0]
    mx, my, mz = mx.sum(axis=0), my.sum(axis=0), mz.sum(axis=0)
    pre_B = 2*np.pi*b0*zscale*del_px/(zdim*phi0)
    mphi = mansPhi(mx, my, mz, beam=beam_dir) * pre_B # scale intensity of the phase shift
    im_def = norm_image(sim_im(mphi, ALTEM, defocus)).astype('float32') # doesnt support float 16
    
    # save 
    res = 1/del_px
    tifffile.imsave(str(imdir / tif_name),
                    im_def,
                    imagej = True,
                    resolution = (res, res),
                    metadata={'unit': 'nm'}
                    )
    i += 1
    if single_check:
        show_2D(mx, my, mz, color=True, a=0)
        show_im(im_def, simple=True)
        break
print("================\nDone\n================")

Done
