## Imports and dependencies 
Importing `template_matching` is optional. When it is available, the diffraction spots are placed according to simulated diffraction patterns, which is more realistic than random placement. However, it is unknown how much this matters for the training of the network.

In [None]:
#import template_matching
%matplotlib qt
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

When template matching is not imported, this function is used instead to generate random positions for the diffraction spots.

In [None]:
def create_fake_lib(number_of_spots, size = (256, 256)):
    """Generate random integer positions for fake diffraction spots.

    Parameters
    ----------
    number_of_spots : int
        Number of coordinate pairs to generate.
    size : tuple of int, optional
        Exclusive upper bounds for the first and second coordinate,
        respectively. Defaults to ``(256, 256)``.

    Returns
    -------
    list of tuple of int
        ``number_of_spots`` pairs ``(c0, c1)`` with ``0 <= c0 < size[0]``
        and ``0 <= c1 < size[1]``. Empty when ``number_of_spots`` is 0.
    """
    # Each coordinate is bounded by its own entry of `size`, so that
    # non-square sizes are respected. The draw order (first coordinate,
    # then second) is kept so seeded results are unchanged for square sizes.
    return [
        (np.random.randint(size[0]), np.random.randint(size[1]))
        for _ in range(number_of_spots)
    ]

In [None]:
def gauss(x, y, px, py, A, sigma):
    """Evaluate a separable 2D Gaussian-shaped peak centred on (px, py).

    The value is ``A * exp(-(x - px)**2 / sigma) * exp(-(y - py)**2 / sigma)``.
    Note that ``sigma`` divides the squared distance directly (it is not
    squared and there is no factor of 2).

    x, y may be scalars or arrays; the result follows numpy broadcasting.
    """
    # Squared offsets from the peak centre along each axis
    dx_squared = (x - px)**2
    dy_squared = (y - py)**2

    # Fall-off along each axis, combined multiplicatively
    falloff_x = np.exp(-dx_squared / sigma)
    falloff_y = np.exp(-dy_squared / sigma)
    return A * falloff_x * falloff_y

In [None]:
def sim_diff_pat(N, real_diffpat = False):
    """Simulate noisy diffraction-like images and matching binary spot masks.

    Parameters
    ----------
    N : int
        Number of images (and corresponding masks) to generate.
    real_diffpat : bool, optional
        If True, spot positions are taken from a library of simulated
        diffraction patterns built with ``template_matching``; that module
        must then be available in the notebook namespace. If False
        (default), spot positions are random, see ``create_fake_lib``.

    Returns
    -------
    images : numpy.ndarray
        Float array of shape ``(N, 256, 256)``: normally distributed noise
        plus Gaussian spots, multiplied by a broad Gaussian centred on
        (128, 128) that fades the edges.
    masks : numpy.ndarray
        ``int8`` array of shape ``(N, 256, 256)``: 1 where the summed
        (noise-free, unfaded) spot intensity is at least 3, otherwise 0.
    """
    side = 256          # images and masks are side x side pixels
    center = side // 2  # centre of the edge-fading Gaussian

    TM = None
    if real_diffpat:
        elements = ["muscovite", "quartz"]
        TM = template_matching.Template_matching(elements = elements)

        TM.create_lib(1.0, deny_new = False, minimum_intensity=1E-20,
                      max_excitation_error=78E-4, force_new = False, camera_length=129,
                      half_radius = 128, reciprocal_radius=1.3072443077127025,
                      accelerating_voltage = 80, diffraction_calibration = 0.010212846154005488,
                      precession_angle=1.0)

    images = np.zeros((N, side, side))
    masks = np.zeros((N, side, side))
    x, y = np.meshgrid(np.arange(side), np.arange(side))

    for n in tqdm(range(N)):
        if real_diffpat:
            # Pick a random element, then a random simulated pattern of that element
            element = elements[np.random.randint(len(elements))]
            scatters_lib = TM.library[element]["pixel_coords"]
            scatters = scatters_lib[np.random.randint(len(scatters_lib))]
        else:
            # Between 0 and 29 randomly placed spots
            scatters = create_fake_lib(np.random.randint(30))

        # Background: normally distributed noise with a random mean in [1, 10)
        # and a random standard deviation in [1, 9)
        images[n] = np.random.normal(1 + 9 * np.random.rand(), 1 + 8 * np.random.rand(), (side, side))

        # Random numbers used to generate the amplitude and width of each diffraction spot
        A = np.random.rand(len(scatters))
        sigma = np.random.rand(len(scatters))
        for i, scatter in enumerate(scatters):
            # Evaluate each spot once and add it to both the image and the
            # (still noise-free) mask accumulator
            spot = gauss(x, y, scatter[0], scatter[1], A[i]*10 + 5, sigma[i]*20 + 3)
            images[n] += spot
            masks[n] += spot

        # Multiplying the image with a broad Gaussian to fade away its edges
        images[n] *= gauss(x, y, center, center, 1, 10000 * np.random.rand() + 3000)

        # Threshold the summed spot intensity into a binary mask
        masks[n] = np.where(masks[n] < 3, 0, 1)

    masks = np.array(masks, dtype = np.int8)
    #np.save(..., images) # Save the masks and images for training of the network
    #np.save(..., masks)
    return images, masks
    


In [None]:
# Generate a small batch of five image/mask pairs with randomly placed spots
images, masks = sim_diff_pat(N=5, real_diffpat=False)

In [None]:
# Show every generated image followed by its mask, each in its own figure
for img, mask in zip(images, masks):
    plt.figure()
    plt.imshow(img)
    plt.figure()
    plt.imshow(mask)
    