In [None]:
# Star import of the nanoRSM helpers. Later cells also use os, np, json, tifffile,
# plt, manimation and tqdm without importing them, so they presumably come from
# this module as well — TODO confirm and import them explicitly.
from nanorsm_parallel import *
# NOTE(review): db_key is not referenced by any visible cell; confirm whether it is still needed.
db_key = 'new' # for scans done with the old flyscan, use 'old'. Default is 'new'

In [None]:
# Load elemental (fluorescence) image data for every scan, align it with pystackreg,
# and generate a transform matrix that is reused to align the diffraction data.
# Globals set here and read by later cells: data_path, output_path, sid_list, elem,
# im_row, im_col, trans_matrix.
%matplotlib qt
# Scan id range string and the interval passed to get_sid_list; edit per experiment.
scan = '346726-346736'
interval = 2
# Hard-coded beamline paths (note the '//' separators); edit per experiment.
parent_path = '//nsls2//data//hxn//legacy//users//2025Q2//Marschilok_2025Q2//'
data_path = f"{parent_path}xrf//"
output_path = f"{parent_path}nanoRSM//{scan}_{interval}//"
if not os.path.exists(output_path):
    os.makedirs(output_path)
    print(f"Created directory: {output_path}")


sid_list = get_sid_list([scan],interval)
# Element whose maps are used to compute the alignment (recorded in parameters.json).
elem = 'Ni_K'
file_list = [
    f"{data_path}output_tiff_scan2D_{sid}//detsum_{elem}_norm.tiff"
    for sid in sid_list
]
im_stack = load_ims(file_list)
# Stack is (frame, row, col); im_row/im_col later become the scan dimensions.
num_frame,im_row,im_col = np.shape(im_stack)
# Only trans_matrix is used by later cells; im_stack_aligned is not read again.
im_stack_aligned, trans_matrix = align_im_stack_v1(im_stack,'AFFINE') # use pystackreg

# Re-apply the transform with the sub-pixel interpolator and inspect it visually,
# to verify the alignment is done correctly.
im_stack_test = interp_sub_pix_v1(im_stack,trans_matrix) # verify the alignment is done correctly
slider_view(im_stack_test)


In [None]:
# Align every element in elem_list with the transform computed above, save each
# aligned stack as a float32 ImageJ TIFF, and build `stack`: one frame-summed 2D
# map per element, shape (len(elem_list), row, col).
# This cell is required even for a single element: later cells read `elem_list`
# (parameters.json, RSM) and `stack` (RSM).
# The loop variable is `elem_name`, not `elem`, so the alignment element `elem`
# set earlier is not overwritten before it is written to parameters.json.
elem_list = ['Ni_K', 'Mn_K','Co_K']
stack = []
for elem_name in elem_list:
    file_list = [
        f"{data_path}output_tiff_scan2D_{sid}//detsum_{elem_name}_norm.tiff"
        for sid in sid_list
    ]
    im_stack = load_ims(file_list)
    im_stack_test = interp_sub_pix_v1(im_stack, trans_matrix)
    # Sum over frames to get this element's 2D map.
    stack.append(np.sum(im_stack_test, axis=0))
    tifffile.imwrite(f"{output_path}{elem_name}.tiff",im_stack_test.astype(np.float32),imagej=True)
    print(f"save {elem_name} to {output_path}{elem_name}.tiff")
stack = np.stack(stack, axis=0)

In [None]:
# Sum the diffraction patterns of all scans in sid_list and pick the detector ROI
# from the summed image. det_name and roi are reused by later cells.
det_name = 'merlin1'

tot = sum_all_h5_data_db_parallel(sid_list, det = det_name)
# log(x + 1) compresses the dynamic range for display and is 0 (not -inf) at zero counts.
# NOTE(review): select_roi presumably needs an interactive backend (qt set above) — confirm.
roi = select_roi(np.log(tot+1))

In [None]:
# Collect acquisition parameters (from the database) and user settings, and save
# them to parameters.json for future reference. The RSM cell reads values back
# from `params`, so this cell must run before it.

scan_row = im_row
scan_col = im_col
mon_name = 'sclr1_ch4'
threshold = None
roi_offset = [0,0]
data_store = 'reduced' # 'reduced' stores less data; 'full' can exceed 100 GB
microscope = 'mll' # choose 'mll' or 'zp'

params_db = read_params_db(sid_list,microscope=microscope,det=det_name)
# NOTE(review): `elem` is the loop variable of `for elem in elem_list` in the
# multi-element cell; if that loop still uses the name `elem`, this records the last
# entry of elem_list rather than the element actually used for alignment.
params_user = {"scan ids": sid_list,
               "fluorescence data path": data_path,
               "output path": output_path,
               "element list": elem_list,
               "element for alignment": elem,
               "alignment matrix": trans_matrix,
               "scan dimensions": [scan_row,scan_col],
               "detector name": det_name,
               "detector roi": roi,
               "threshold": threshold,
               "monitor": mon_name,
               "roi offset": roi_offset,
               "data store": data_store
              }
# Dict union (Python 3.9+): on duplicate keys the user value overrides the database value.
params = params_db | params_user

# Recreate the output folder in case the first cell was not run in this session.
if not os.path.exists(output_path):
    os.makedirs(output_path)
    print(f"Created directory: {output_path}")
    
# Presumably converts numpy arrays/scalars to JSON-serializable types — confirm.
params = convert_numpy(params)

with open(f"{output_path}parameters.json", "w") as f:
    json.dump(params, f, indent=4)
    print(f"write parameters to {output_path}parameters.json")


In [None]:
# Load the diffraction patterns (cropped to `roi`) and align them with the same
# transform that was computed from the fluorescence maps.

diff_data = load_h5_data_db_parallel(sid_list,det=det_name,mon=mon_name,roi=roi,threshold=threshold,max_workers = 1)
# Legacy alternative loader; `files` is not defined in any visible cell.
# diff_data = nanorsm.load_h5_files_parallel(files,roi)
sz = diff_data.shape
# Split axis 1 into (scan_row, scan_col); this requires sz[1] == scan_row * scan_col.
# Assumes the loader returns (n_scans, n_positions, det_rows, det_cols) — TODO confirm.
diff_data = np.reshape(diff_data,(sz[0],scan_row,scan_col,sz[2],sz[3]))
# Apply the fluorescence-derived sub-pixel alignment to the diffraction data.
diff_data = interp_sub_pix_v1(diff_data,trans_matrix)


In [None]:
### Transform to Cartesian crystal coordinates (z along hkl, x along the rocking direction)

# Experiment geometry and storage mode, read back from the saved parameters.
energy, delta, gamma = params['energy'], params['delta'], params['gamma']
num_angle, th_step = params['number of angles'], params['angle step']
pix, det_dist = params['pixel size'], params['detector distance']
offset = np.asarray(params['roi offset'])
data_store = params['data store']

# Strain-fitting options passed to rsm.calcSTRAIN.
method = {
    'fit_type': 'com',      # fitting method: center of mass 'com', or 'peak'
    'shape': 'gaussian',    # peak shape: 'gaussian', 'lorentzian', or 'voigt'
    'n_peaks': [1, 1, 1],   # number of peaks along qx, qy, and qz
    'mask': 'tot',          # reference image for masking
    'mask threshold': 0.1,  # pixels with values below maximum*threshold are set to zero
}

# Build the RSM object, then transform detector -> crystal coordinates.
rsm = RSM(diff_data, energy, delta, gamma, num_angle, th_step, pix, det_dist, offset, stack, elem_list)
rsm.calcRSM('cryst', data_store, desired_workers=5)
# Strain: 'com' (center of mass) is a simple estimate; note it carries an arbitrary offset.
rsm.calcSTRAIN(method)
# Display, then save the results.
rsm.disp()
rsm.save(output_path)
# Optionally pickle the whole object (only ever unpickle files you trust):
# save_file = f"{output_path}all_data.obj"
# pickle.dump(rsm, open(save_file,'wb'),protocol = 4)

In [None]:
# Interactive exploration of the RSM results with a log scale.
# NOTE(review): presumably opens a GUI window (qt backend selected earlier) — confirm.
rsm.run_interactive(scale='log')

In [None]:
# Movie metadata and output settings for rsm.generate_movie.
desc = dict(
    title='Movie',
    artist='hyan',
    comment='test',
    save_file=f"{output_path}movie.mp4",
    fps=15,
    dpi=100,
)

rsm.generate_movie(desc, scale='log')

In [None]:
def create_movie(desc, names,im_stack,label,data_4D,path,cmap='jet',color='white',clim=None):
    """Write an MP4 that walks a marker along `path` and shows the matching 4D frames.

    The figure is a 3-column grid. It holds one panel per 2D map in ``im_stack``,
    each with a marker at the current scan position. These are followed by one
    panel per 4D dataset, showing the 2D frame ``data_4D[k][row, col]`` at that
    position.

    Parameters
    ----------
    desc : dict
        Movie metadata and output settings. Keys read: 'title', 'artist',
        'comment', 'save_file', 'fps', 'dpi'. Example::

            desc = {'title': 'Movie', 'artist': 'hyan', 'comment': 'Blanket film',
                    'save_file': 'movie_blanket_film.mp4', 'fps': 15, 'dpi': 100}
    names : sequence of str
        Titles of the individual images in ``im_stack``.
    im_stack : ndarray
        Stack of 2D images indexed as ``im_stack[i]``. Only the first
        ``min(len(names), im_stack.shape[0])`` images are shown.
    label : str or sequence of str
        Title(s) of the 4D-dataset panels. A single string titles every panel.
    data_4D : 4D ndarray or sequence of 4D arrays
        Dataset(s) indexed as ``[row, col, ...]``, e.g. ``[row, col, qx, qz]``.
        A single 4D ndarray is treated as one dataset.
    path : sequence of (row, col)
        Scan positions to visit; one movie frame is written per position.
    cmap : str
        Colormap for all panels.
    color : str
        Color of the position marker.
    clim : tuple or None
        Color limits for the 4D panels. If None, each frame is autoscaled.

    Raises
    ------
    ValueError
        If there is nothing to plot (no images and no 4D datasets).
    """
    # Accept the single-dataset form: a bare 4D array is one dataset, and a bare
    # string label titles every dataset panel (indexing a str would yield
    # single-character titles).
    if isinstance(data_4D, np.ndarray) and data_4D.ndim == 4:
        data_4D = [data_4D]
    if isinstance(label, str):
        label = [label] * len(data_4D)

    n = min(len(names), im_stack.shape[0])   # number of 2D-map panels
    m = min(len(label), len(data_4D))        # number of 4D-frame panels
    total = n + m
    if total == 0:
        raise ValueError("create_movie: nothing to plot (no images and no 4D datasets)")

    ncols = 3
    nrows = int(np.ceil(total / ncols))
    fig, axs = plt.subplots(
        nrows, ncols,
        figsize=(3 * ncols, 3 * nrows),
        gridspec_kw={'wspace': 0.1, 'hspace': 0.2},
        squeeze=False,
    )
    axs = axs.ravel()

    # Map panels: static image plus a position marker. The marker starts empty
    # and is moved on every frame; artists are reused rather than redrawn.
    markers = []
    for ax, name, img in zip(axs[:n], names[:n], im_stack[:n]):
        ax.imshow(img, cmap=cmap, aspect='auto')
        marker, = ax.plot([], [], 'o', color=color, ms=5)
        markers.append(marker)
        ax.set_title(name, fontsize=9)
        ax.axis('off')

    # 4D panels start at scan position (0, 0) and are updated in place.
    frame_images = []
    for j in range(m):
        ax = axs[n + j]
        frame_images.append(ax.imshow(data_4D[j][0, 0], cmap=cmap, clim=clim, aspect='auto'))
        ax.set_title(label[j], fontsize=9)
        ax.axis('off')

    # Hide any unused grid cells.
    for ax in axs[total:]:
        ax.axis('off')

    # A single colorbar, attached to the last 4D panel (only if one exists).
    if frame_images:
        cbar = fig.colorbar(frame_images[-1], ax=axs[total - 1], fraction=0.046, pad=0.02)
        cbar.ax.tick_params(labelsize=8)

    fig.tight_layout()
    plt.show(block=False)

    def update_fig(row, col):
        """Move every marker to (row, col) and show the 4D frames at that position."""
        for mk in markers:
            mk.set_data([col], [row])  # image coordinates: x = column, y = row
        for frame_im, dataset in zip(frame_images, data_4D):
            frame_im.set_data(dataset[row, col])
            if clim is None:
                # imshow without clim autoscales to its data; keep that per frame.
                frame_im.autoscale()

    FFMpegWriter = manimation.writers['ffmpeg']
    metadata = dict(title=desc['title'], artist=desc['artist'],
                    comment=desc['comment'])
    writer = FFMpegWriter(fps=desc['fps'], metadata=metadata)

    # writer.saving() finishes the writer on exit, so finish() is not called again.
    with writer.saving(fig, desc['save_file'], dpi=desc['dpi']):
        writer.grab_frame()  # first frame: position (0, 0) frames, no marker yet
        for pos in tqdm(path, desc='Progress'):
            update_fig(pos[0], pos[1])
            writer.grab_frame()