In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import utils
from data_parameters import data_param
import json

Start by getting the SI preprocessed (normalized, zero padded, aligned, ...) and saving it to a TFRecord so that we don't have to repeat these operations every time. We use an SI from Chen, B., Gauquelin, N., Strkalj, N. et al. Signatures of enhanced out-of-plane polarization in asymmetric BaTiO3 superlattices integrated on silicon. Nat Commun 13, 265 (2022). https://doi.org/10.1038/s41467-021-27898-x


In [None]:
# Run the conversion script on the 'SI_Chen' input with extension 'dm3'.
# NOTE(review): the meaning of --offset=20 is defined inside SI_to_TFRecord.py
# and is not visible from this notebook -- confirm before changing it.
%run IO_help/SI_to_TFRecord.py --name=SI_Chen --extension=dm3 --offset=20
# Path stem (no extension) of the converted data; later cells read
# f'{filename}.tfrecords' and f'{filename}_metadata.json'.
filename = 'SI_data/SI_Chen'

Choose a neural network; here we pick the ensemble since it is the most robust.

In [None]:
# Load the trained ensemble. `custom_loss` has to be supplied through
# custom_objects so Keras can resolve it while deserializing the model.
model = tf.keras.models.load_model(
    'trained element identification models/2ViT_3UNet_ensemble',
    custom_objects={"custom_loss": utils.custom_loss},
)
# Per-element cutoff: a model output strictly above this value counts as a
# positive identification when the element maps are built.
threshold = 0.75

Choose a selection criterion to suppress likely false positives, for example only consider elements that occur in more than 1% of the spectra. You can set this to zero to see all positive identifications.

In [None]:
# Minimum fraction of scan positions at which an element must be flagged for it
# to be kept (0.01 == 1 %). The comparison is strict, so 0 keeps every element
# that was flagged at least once.
concentration_threshold = 0.01

Data reader

In [None]:
# Reader configuration, taken from the shared data parameters.
spectrum_length = data_param['spectrum_length']  # samples per stored spectrum
per_sys = data_param['per_sys']                  # element labels, indexed by output position
N_elem = len(per_sys)                            # number of element labels
batch_size = 4072                                # spectra handed to the model per predict call
def parse_tfr_element(element):
    """Decode one serialized TFRecord example.

    Args:
        element: a single serialized example read from the TFRecord file.

    Returns:
        Tuple ``(spectrum, idx_0, idx_1, sum)``: the spectrum reshaped to
        ``[spectrum_length]`` (float64), the two int64 indices stored with
        it, and the float32 value stored under ``'sum'``.
    """
    # The schema has to mirror the structure used when the record was written.
    schema = {
        'raw_spec': tf.io.FixedLenFeature([], tf.string),
        'idx_0': tf.io.FixedLenFeature([], tf.int64),
        'idx_1': tf.io.FixedLenFeature([], tf.int64),
        'sum': tf.io.FixedLenFeature([], tf.float32),
    }
    parsed = tf.io.parse_single_example(element, schema)

    # The spectrum is stored as a serialized float64 tensor; deserialize it and
    # pin its shape to the configured spectrum length.
    spectrum = tf.reshape(
        tf.io.parse_tensor(parsed['raw_spec'], out_type=tf.float64),
        shape=[spectrum_length],
    )
    return (spectrum, parsed['idx_0'], parsed['idx_1'], parsed['sum'])

def get_dataset(filename):
    """Build a dataset of parsed records from a TFRecord file.

    Args:
        filename: path of the TFRecord file to read.

    Returns:
        A ``tf.data`` dataset in which every record has been passed through
        ``parse_tfr_element``.
    """
    records = tf.data.TFRecordDataset(filename)
    return records.map(parse_tfr_element)

Make a series of maps (one for each element) with binary predictions

In [None]:
# Metadata stored next to the TFRecord: grid extent along both index axes,
# plus step sizes and units (read here, not used in this cell).
with open(f'{filename}_metadata.json') as f:
    metadata = json.load(f)
idx0_range = metadata['idx_0_range']
idx1_range = metadata['idx_1_range']
idx0_step = metadata['idx_0_step']
idx1_step = metadata['idx_1_step']
units = metadata['idx_0_units']

# Im  : the stored per-spectrum 'sum' value at every (idx_0, idx_1) position.
# map : binary prediction maps, one slice per element (1 = predicted present).
# NOTE(review): `map` shadows the Python builtin; the name is kept because the
# plotting cell reads it.
Im = np.zeros(shape=(idx0_range, idx1_range))
map = np.zeros(shape=(idx0_range, idx1_range, N_elem), dtype=np.float64)

coord = get_dataset(f'{filename}.tfrecords').batch(batch_size, drop_remainder=False)

# One batched pass over the file fills both arrays, so the TFRecord is parsed
# only once.
for spec, IDX0, IDX1, intensity in coord:
    IDX0 = IDX0.numpy()
    IDX1 = IDX1.numpy()
    Im[IDX0, IDX1] = intensity.numpy()

    # `workers` / `use_multiprocessing` only apply to generator or Sequence
    # input and are not accepted by newer Keras releases, so they are omitted.
    # Assumes a single output of shape (batch, N_elem) -- TODO confirm for
    # other models.
    pred = model.predict(spec, verbose=0)

    # (row, element) pairs whose score exceeds the threshold; each row is
    # mapped back to its scan position.
    rows, elems = np.nonzero(pred > threshold)
    map[IDX0[rows], IDX1[rows], elems] = 1

# Keep the indices of elements flagged in more than `concentration_threshold`
# of all scan positions.
c = np.flatnonzero(map.mean(axis=(0, 1)) > concentration_threshold).tolist()

Visualize the predictions. Be sure to check the saved PDF (`mapping_result.pdf`), because the inline image might not display some details properly.

In [None]:
elements_per_row = 5
# Always lay out at least one row: plt.subplots rejects a zero-row grid, which
# would otherwise happen when no element passes the concentration threshold.
s = max(1, int(np.ceil(len(c)/elements_per_row)))

fig, ax = plt.subplots(s,elements_per_row,figsize=(elements_per_row,s*5))

# Normalized grayscale background, identical for every panel.
background = Im.T/np.max(Im)

# np.ravel gives the panels in row-major order whether plt.subplots returned a
# 1-D or a 2-D array, so one loop covers both layouts.
for n, panel in enumerate(np.ravel(ax)):
    if n < len(c):
        k = c[n]
        heatmap = map[:,:,k]
        # Hide non-detections so the background stays visible there.
        heatmap = np.where(heatmap == 0., np.nan, heatmap)
        panel.imshow(background,cmap = "gray")
        im = panel.imshow(heatmap.T,cmap = 'Reds',alpha=0.7)
        # NOTE(review): creating and immediately removing a colorbar looks
        # redundant, but colorbar creation appears to adjust the color
        # normalization of a single-valued image; kept as in the original.
        cb = plt.colorbar(im)
        panel.set_title(per_sys[k])
        cb.remove()
    # Unused panels are blanked as well.
    panel.axis('off')

plt.tight_layout()
plt.savefig('mapping_result.pdf',dpi=600)