In [None]:
import os, subprocess
import numpy as np
import xarray as xr
import glob as glob
import matplotlib as mpl
from zipfile import ZipFile, ZIP_DEFLATED
from typing import Union
from pathlib import Path
from datetime import datetime

import matplotlib.pyplot as plt
from tqdm.notebook import trange
from scipy.interpolate import RBFInterpolator

# Render mathtext (e.g. $...$ in labels) with the regular text font rather than matplotlib's default italic math font
mpl.rcParams['mathtext.default'] = 'regular'

# Interactive (ipympl widget) figures for every plot created below
%matplotlib widget

# Wafer simulation functions

Func 1: Create dataset with coordinates
- Functionality: make a rectangular grid, create dataset with coordinates
- Inputs: shape (circle or square), diameter (use motor units like mm), resolution (beam size, match units of diameter)
- Returns: dataset with coordinates, attributes of shape, diameter, resolution (to be used in other functions)

Func 2: Calculate elemental composition at each point
- Functionality: Set elemental composition at discrete points (manually entered or calculated at distance from center/edge)
- Inputs: 1) dataset (shape, diameter, and resolution are attributes that can be accessed!) 2) num_positions (N) 3) elements (E) 4) element_comps (N, E) 5) method: manual or equidistant 6) smoothing factor = 1.0
    - How should method-specific arguments be handled? E.g., manual requires the ((xs, ys), N) coordinates; equidistant requires the % distance from center to edge
    - Add data variable for elemental_weights with 0s
        - if shape == circle, set outside circle to NaNs
- Returns: dataset with elements and element_weights, with NaNs for weights outside shape (if shape == circle)
    - Additional attributes are: elements, element_comps + positions, smoothing factor

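Func 3: Interpolate phase_weights and I(Q) at each point
- Functionality: interpolate phase weights from each point's element weights using the phase diagram, then interpolate I(Q) from those phase weights (and add noise to the patterns)
- Inputs: dataset, phase diagram dataset (needed for interpolation)
- Returns: dataset (ground_truth) with phases, phase_weights, and I(Q) at each point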


In [None]:
def get_repo_root() -> Path:
    """
    Locate the project root directory.

    Resolution order:
    1. The top level of the enclosing Git work tree (``git rev-parse --show-toplevel``).
    2. The ``PROJECT_ROOT`` environment variable, if set and non-empty (user-expanded, resolved).
    3. The current working directory (a notice is printed).

    Returns
    -------
    Path
    """
    # 1) If running inside a Git clone, ask Git
    try:
        root = subprocess.check_output(
            ["git", "rev-parse", "--show-toplevel"],
            text=True,
            # keep git's "fatal: not a git repository" message out of the notebook output
            stderr=subprocess.DEVNULL,
        ).strip()
        if root:
            return Path(root)
    except (OSError, subprocess.SubprocessError, UnicodeDecodeError):
        # git not installed (OSError) or not inside a work tree (non-zero exit): fall through
        pass

    # 2) Optional: allow an override via env var (only consulted when Git gives no answer)
    project_root = os.getenv("PROJECT_ROOT")
    if project_root:
        return Path(project_root).expanduser().resolve()

    # 3) Fallback: current working directory
    print('Root directory not found, using current working directory')
    return Path.cwd()

In [None]:
def create_dataset_with_coords(shape, diameter, resolution) -> xr.Dataset:
    """
    Build a flat (one row per grid point) dataset of x/y sample positions on a square grid.

    The grid is centred on the origin and spans ``-diameter//2 .. diameter//2`` on both axes
    with ``ceil(diameter / resolution)`` points per side. reshape_ds()/reshape_ds_simple()
    rebuild the same axes from the stored attributes, so the two expressions must stay in sync.

    Parameters
    ----------
    shape : str
        'circle' or 'square'. For 'circle', grid points at distance >= diameter/2 from the
        origin lie outside the sample; they stay in the grid and are only excluded from
        ``points_inside``.
    diameter : float
        Diameter (circle) or side length (square) in motor coordinate units (typically mm).
    resolution : float
        Beam size (same units as diameter).

    Returns
    -------
    xr.Dataset
        Coordinates ``x``, ``y`` (dim ``points``) and ``xy`` (dims ``points``, ``tuple_index``),
        plus attrs description, shape, shape_center, shape_diameter, resolution, points_inside.

    Raises
    ------
    ValueError
        If ``shape`` is not 'circle' or 'square', or diameter/resolution are not positive.
    """
    if shape not in ('circle', 'square'):
        raise ValueError(f"shape must be 'circle' or 'square' (currently = {shape!r})")
    if diameter <= 0 or resolution <= 0:
        raise ValueError(f"diameter and resolution must be positive (diameter = {diameter}, resolution = {resolution})")

    center = (0, 0)

    # Create coordinate grid. NOTE: floor division makes the axis slightly asymmetric for odd
    # diameters; kept as-is because reshape_ds() rebuilds the axes with the identical expression.
    pts_per_side = int(np.ceil(diameter / resolution))  # number of points per side
    x = np.linspace(-diameter//2, diameter//2, pts_per_side)
    y = np.linspace(-diameter//2, diameter//2, pts_per_side)
    xv, yv = np.meshgrid(x, y)  # rows follow y, columns follow x
    coords = np.column_stack([xv.ravel(), yv.ravel()])

    if shape == 'circle':
        distances = np.sqrt((coords[:, 0] - center[0])**2 + (coords[:, 1] - center[1])**2)
        outside = distances >= (diameter / 2)
        points_inside = int(np.count_nonzero(~outside))
    else:
        # the grid never extends past +/- diameter/2, so every point lies on a square sample
        points_inside = int(coords.shape[0])

    ds = xr.Dataset(
        coords={
            "x": ("points", coords[:, 0]),
            "y": ("points", coords[:, 1]),
            "xy": (("points", "tuple_index"), coords)
        }
    )

    ds.attrs['description'] = f"Simulated dataset with shape = {shape}, diameter = {diameter}, resolution = {resolution}, and points inside = {points_inside}."
    ds.attrs['shape'] = shape
    ds.attrs['shape_center'] = center
    ds.attrs['shape_diameter'] = diameter
    ds.attrs['resolution'] = resolution
    ds.attrs['points_inside'] = points_inside

    return ds


def calc_elemental_comps(dataset: xr.Dataset, num_compositions: int, elements_list, discrete_compositions, positions='calculated',
                         calc_dist_scale = 100, deg_rotation=-90,
                         discrete_comp_coords = None,
                         find_on_grid=True,
                         smoothing_factor=1.0) -> xr.Dataset:
    """
    Assign an elemental composition to every grid point by blending a few discrete compositions.

    Each discrete composition is anchored at a position on the wafer. Grid points that coincide
    exactly with an anchor take that composition verbatim. Every other point receives the sum of
    all anchor compositions weighted by ``exp(-distance / smoothing_factor)``, normalised to 1.
    Points outside a circular sample are set to NaN.

    Parameters
    ----------
    dataset : xr.Dataset
        Xarray dataset created from create_dataset_with_coords()
    num_compositions : int
        Number of discrete compositions (anchors) to be set
    elements_list : list of strings
        Element names (e.g., ['Al', 'Li', 'Fe']) matching the column order of discrete_compositions
    discrete_compositions : array_like, shape (num_compositions, len(elements_list))
        Fractions between 0 and 1; each row must sum to 1 (within 1e-9)
    positions : {'calculated', 'manual'}
        'calculated': anchors are equally spaced in angle around the shape centre at
        ``calc_dist_scale`` percent of the radius (uses calc_dist_scale, deg_rotation;
        only implemented for circles).
        'manual': anchors are given by discrete_comp_coords.
    calc_dist_scale : float
        Percent distance from shape centre to edge for calculated anchors (0 = centre, 100 = edge)
    deg_rotation : float
        Rotation in degrees applied to calculated anchors; the default -90 puts the first
        anchor on the negative-y axis
    discrete_comp_coords : array_like, shape (num_compositions, 2), optional
        x, y anchor coordinates; required for 'manual', must be None for 'calculated'
    find_on_grid : bool
        If truthy, snap each anchor to the nearest grid point; otherwise use the exact positions
    smoothing_factor : float
        Length scale of the exponential decay (same units as the grid coordinates)

    Returns
    -------
    dataset : xr.Dataset
        Copy of ``dataset`` with coordinates ``elements`` (points, elements) and
        ``element_weights`` (points, weights), plus attrs 'composition_centers (coords)',
        'composition_centers (weights)' and 'smoothing_factor'

    Raises
    ------
    ValueError
        On inconsistent shapes, invalid positions / calc_dist_scale, missing manual coordinates,
        rows not summing to 1, unsupported shapes for calculated anchors, or an inside point whose
        normalised weights cannot sum to 1 (e.g. every decay term underflowed to 0).
    """
    composition_sum_tol = 1e-9  # allowed |1 - row sum| for the user-supplied compositions
    weight_tolerance = 1e-8     # allowed |1 - row sum| after normalisation (residual gets absorbed)

    discrete_compositions = np.asarray(discrete_compositions)
    if discrete_comp_coords is not None:
        discrete_comp_coords = np.asarray(discrete_comp_coords)

    #
    # Input validation
    #
    if positions not in ('calculated', 'manual'):
        raise ValueError(f"positions must be 'calculated' or 'manual' (currently = {positions!r})")

    if num_compositions != discrete_compositions.shape[0]:
        raise ValueError(f"num_compositions ({num_compositions}) must be an integer matching discrete_compositions.shape[0] ({discrete_compositions.shape[0]})")

    if discrete_compositions.shape != (num_compositions, len(elements_list)):
        raise ValueError(f"discrete_compositions must have shape of (num_compositions, len(elements_list)) - currently {discrete_compositions.shape}")

    if positions == 'calculated':
        if not (0 <= calc_dist_scale <= 100):
            raise ValueError(f"Positions are 'calculated' so calc_dist_scale must be between 0 and 100 % (currently = {calc_dist_scale})")
        if discrete_comp_coords is not None:
            raise ValueError(f"discrete_comp_coords is not None, but positions are 'calculated' - must have 'manual' positions or set discrete_comp_coords to None")
    else:
        if discrete_comp_coords is None:
            raise ValueError("positions are 'manual' so discrete_comp_coords must be provided with shape (num_compositions, 2)")
        if discrete_comp_coords.shape != (num_compositions, 2):
            raise ValueError(f"discrete_comp_coords must have shape (num_compositions, 2) - current shape is {discrete_comp_coords.shape}")

    for num, sublist in enumerate(discrete_compositions):
        if not np.isclose(np.sum(sublist), 1.0, rtol=0, atol=composition_sum_tol):
            raise ValueError(f"Sublist {num} in discrete_compositions does not sum to 1. Ensure sublist elements sum to 1 and try again.\nSublist = {sublist}. Sum = {np.sum(sublist)}")

    #
    # Anchor (composition centre) positions
    #
    shape = dataset.attrs['shape']
    if positions == 'calculated':
        if shape == 'circle':
            radius = dataset.attrs['shape_diameter'] / 2
            center = dataset.attrs['shape_center']  # assuming always centering at origin

            # endpoint=False so 0 and 360 degrees are not both used
            anglesrad = np.linspace(0, 2 * np.pi, num_compositions, endpoint=False)
            angles = np.radians(np.degrees(anglesrad) + deg_rotation)

            anchor_radius = radius * calc_dist_scale * 0.01
            x_positions = center[0] + anchor_radius * np.cos(angles)
            y_positions = center[1] + anchor_radius * np.sin(angles)
        elif shape == 'square':
            raise ValueError("I still have to write equidistant point calculations for a square - come back later")
        else:
            raise ValueError(f"Calculated positions are not supported for shape {shape!r}")
    else:
        x_positions = discrete_comp_coords[:, 0]
        y_positions = discrete_comp_coords[:, 1]

    x_coords = dataset['x'].values
    y_coords = dataset['y'].values
    coords = np.column_stack((x_coords, y_coords))
    num_points = coords.shape[0]

    # Snap anchors to the nearest grid point, or use the positions directly
    if find_on_grid:
        comp_points = np.zeros((num_compositions, 2))
        for i, (x_pos, y_pos) in enumerate(zip(x_positions, y_positions)):
            distances = np.sqrt((x_coords - x_pos)**2 + (y_coords - y_pos)**2)
            comp_points[i] = coords[int(distances.argmin())]
    else:
        comp_points = np.column_stack((x_positions, y_positions))

    #
    # Composition weights
    #
    weights = np.zeros((num_points, len(elements_list)))

    # Grid points exactly on an anchor take that anchor's composition unchanged
    fixed_points = np.zeros(num_points, dtype=bool)
    for i, (x_center, y_center) in enumerate(comp_points):
        matching_points = (coords[:, 0] == x_center) & (coords[:, 1] == y_center)
        weights[matching_points] = discrete_compositions[i]
        fixed_points |= matching_points

    # All other points: exponentially distance-weighted sum of every anchor composition
    free_points = ~fixed_points
    for i, (x_center, y_center) in enumerate(comp_points):
        distances = np.sqrt((coords[:, 0] - x_center) ** 2 + (coords[:, 1] - y_center) ** 2)
        decay = np.exp(-distances / smoothing_factor)
        weights[free_points] += decay[free_points, np.newaxis] * discrete_compositions[i]

    # Normalise rows to sum to 1 (rows whose decay underflowed to 0 become NaN and are caught below)
    with np.errstate(divide='ignore', invalid='ignore'):
        weights = weights / weights.sum(axis=1, keepdims=True)

    # Points outside a circular sample get NaN weights and are excluded from the sum check
    if shape == 'circle':
        diameter = dataset.attrs['shape_diameter']
        center = dataset.attrs['shape_center']  # assuming always centering at origin
        outside = np.sqrt((x_coords - center[0])**2 + (y_coords - center[1])**2) >= (diameter / 2)
    else:
        outside = np.zeros(num_points, dtype=bool)
    inside = ~outside

    # Rounding leaves row sums a few ulps off 1, which is annoying for interpolation: absorb the
    # residual into each row's largest component (so no weight can be pushed below 0).
    residual = 1.0 - weights.sum(axis=1)
    bad_rows = inside & ~(np.abs(residual) <= weight_tolerance)  # NaN residuals count as bad
    if bad_rows.any():
        index = int(np.flatnonzero(bad_rows)[0])
        raise ValueError(f'Summed weight for index {index} outside tolerance for adjustment - inspect calculation of weights')
    inside_rows = np.flatnonzero(inside)
    weights[inside_rows, weights[inside_rows].argmax(axis=1)] += residual[inside_rows]

    weights[outside, :] = np.nan

    # Assign coordinates and weights to the dataset
    repeated_elements = np.tile(elements_list, (num_points, 1))

    dataset = dataset.assign_coords(
        elements=(("points", "elements"), repeated_elements),
        element_weights=(("points", "weights"), weights)
    )
    dataset.attrs['composition_centers (coords)'] = [str(pos) for pos in comp_points]
    dataset.attrs['composition_centers (weights)'] = [str(comp) for comp in discrete_compositions]
    dataset.attrs['smoothing_factor'] = smoothing_factor

    return dataset


def interpolate_phase_weights(element_weights_wafer, element_weights_PD, phase_weights_PD):
    """
    Interpolate phase fractions onto wafer points from a phase diagram sampled in composition space.

    A single linear-kernel RBFInterpolator is fitted to all phase columns at once. The RBF
    system depends only on the phase-diagram sites, so fitting each column separately would
    repeat the same solve. Negative interpolants are clipped to 0 and every row is renormalised
    to sum to 1. Rows whose clipped values are all 0 stay all-zero.

    Parameters
    ----------
    element_weights_wafer : ndarray, shape (n_wafer, n_elements)
        Compositions to evaluate; rows containing NaN (e.g. outside a circular sample) yield NaN.
    element_weights_PD : ndarray, shape (n_pd, n_elements)
        Compositions at which the phase diagram is known.
    phase_weights_PD : ndarray, shape (n_pd, n_phases)
        Phase fractions at those compositions.

    Returns
    -------
    ndarray, shape (n_wafer, n_phases)
    """
    # NaN for every row until filled (rows with NaN inputs are never filled)
    phase_weights_wafer = np.full((element_weights_wafer.shape[0], phase_weights_PD.shape[1]), np.nan)

    valid = ~np.isnan(element_weights_wafer).any(axis=1)
    if valid.any():
        rbf_interpolator = RBFInterpolator(element_weights_PD, phase_weights_PD, kernel='linear')
        # Ensure non-negative values
        phase_weights_wafer[valid] = np.clip(rbf_interpolator(element_weights_wafer[valid]), 0, None)

    # Normalize the rows to sum to 1.0; avoid division by zero by leaving all-zero rows as-is
    row_sums = np.sum(phase_weights_wafer, axis=1, keepdims=True)
    row_sums[row_sums == 0] = 1
    return phase_weights_wafer / row_sums

def interpolate_iq_wafer(phase_weights_wafer, phase_weights_PD, iq_PD_ionly, epsilon=1e-10, rng=None):
    """
    Interpolate intensity patterns onto wafer points from phase-diagram patterns in phase-weight space.

    The phase-diagram sites are jittered ONCE by ``epsilon * N(0, 1)`` to avoid singular-matrix
    failures in the RBF solve. One linear-kernel RBFInterpolator is then fitted to all intensity
    channels together, so every channel is interpolated against the same jittered sites.

    Parameters
    ----------
    phase_weights_wafer : ndarray, shape (n_wafer, n_phases)
        Phase weights to evaluate; rows containing NaN yield NaN patterns.
    phase_weights_PD : ndarray, shape (n_pd, n_phases)
        Phase weights of the phase-diagram patterns.
    iq_PD_ionly : ndarray, shape (n_pd, n_q)
        Intensities only (no Q column).
    epsilon : float
        Jitter scale.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of the jitter (for reproducibility); defaults to the global ``np.random`` state.

    Returns
    -------
    ndarray, shape (n_wafer, n_q)
        Best-effort: if the interpolation fails the error is printed and the result stays NaN.
    """
    # Initialize iq_wafer_ionly with NaNs
    iq_wafer_ionly = np.full((phase_weights_wafer.shape[0], iq_PD_ionly.shape[1]), np.nan)

    # Small perturbation of the sites to avoid singular matrix issues - drawn once for all channels
    if rng is None:
        jitter = np.random.randn(*phase_weights_PD.shape)
    else:
        jitter = rng.standard_normal(phase_weights_PD.shape)
    phase_weights_PD_perturbed = phase_weights_PD + epsilon * jitter

    valid = ~np.isnan(phase_weights_wafer).any(axis=1)
    if not valid.any():
        return iq_wafer_ionly

    try:
        rbf_interpolator = RBFInterpolator(phase_weights_PD_perturbed, iq_PD_ionly, kernel='linear')
        iq_wafer_ionly[valid] = rbf_interpolator(phase_weights_wafer[valid])
    except Exception as e:
        print(f"Error interpolating intensities: {e}")

    return iq_wafer_ionly

def interpolate_and_addtods(dataset, dataset_DRNets, noise_percentage=0.01):
    """
    Interpolate phase weights and noisy I(Q) patterns from a phase-diagram dataset onto the wafer.

    Pipeline:
    1. Element weights -> phase weights (interpolate_phase_weights).
    2. Phase weights -> intensities (interpolate_iq_wafer).
    3. Gaussian noise is added to every intensity pattern that contains no NaN.

    Parameters
    ----------
    dataset : xr.Dataset
        Wafer dataset with 'element_weights' (e.g. from calc_elemental_comps()).
    dataset_DRNets : xr.Dataset
        Phase-diagram dataset providing 'element_weights', 'phase_weights', 'phase_names' and 'iq'.
        'iq' is read as iq[:, 1, :] for intensities and iq[0][0] for the Q values.
    noise_percentage : float
        Noise standard deviation as a percentage of each pattern's maximum |intensity|.

    Returns
    -------
    xr.Dataset
        New dataset with:
        - coordinate 'phase_names' (points, names);
        - data variable 'phase_weights' (points, phase_weights);
        - data variable 'iq' (points, tuple_index, q_points), where slot 0 holds Q and slot 1 the
          noisy intensities;
        - attrs['Q'].
    """

    def add_noise_by_percentage(signal, noise_percentage):
        """Return ``signal`` plus zero-mean Gaussian noise whose std is ``noise_percentage`` % of max |signal|."""
        noise_std = (noise_percentage / 100) * np.max(np.abs(signal))
        return signal + np.random.normal(0, noise_std, signal.shape)

    iq_PD = dataset_DRNets['iq'].values
    q_points = iq_PD[0][0]
    phase_weights_PD = dataset_DRNets['phase_weights'].values

    phase_weights_wafer = interpolate_phase_weights(
        dataset['element_weights'].values,
        dataset_DRNets['element_weights'].values,
        phase_weights_PD,
    )
    intensities = interpolate_iq_wafer(phase_weights_wafer, phase_weights_PD, iq_PD[:, 1, :])

    # Only fully defined patterns get noise (NaN rows come from points without element weights,
    # e.g. outside a circular sample)
    for row_idx in range(intensities.shape[0]):
        if np.isnan(intensities[row_idx]).any():
            continue
        intensities[row_idx] = add_noise_by_percentage(signal=intensities[row_idx], noise_percentage=noise_percentage)

    # Pack (Q, I) pairs per point: slot 0 = Q grid, slot 1 = intensities
    iq_wafer = np.empty((intensities.shape[0], 2, len(q_points)))
    iq_wafer[:, 0, :] = q_points
    iq_wafer[:, 1, :] = intensities

    phase_names = dataset_DRNets['phase_names'][0].values
    result = dataset.assign_coords(
        phase_names=(("points", "names"), np.tile(phase_names, (phase_weights_wafer.shape[0], 1))),
    )
    result['phase_weights'] = (('points', 'phase_weights'), phase_weights_wafer)
    result['iq'] = (('points', 'tuple_index', 'q_points'), iq_wafer)

    # Q array also kept as an attribute (does not need to be stored with intensity)
    result.attrs['Q'] = q_points

    return result

def reshape_ds(dataset):
    """
    Reshape a flat (points, ...) wafer dataset onto a (y, x, ...) grid, keeping composition metadata.

    Identical to reshape_ds_simple() except that it also copies 'composition_centers (coords)',
    'composition_centers (weights)' and 'smoothing_factor' (written by calc_elemental_comps()).
    Attributes keep the order: description ... points_inside, the three composition attrs, Q.

    Raises
    ------
    KeyError
        If any composition attribute is missing (e.g. datasets built with set_elemental_comp_LR());
        use reshape_ds_simple() for those.
    """
    composition_keys = ('composition_centers (coords)', 'composition_centers (weights)', 'smoothing_factor')
    missing = [key for key in composition_keys if key not in dataset.attrs]
    if missing:
        raise KeyError(f"dataset is missing attrs {missing}; use reshape_ds_simple() for datasets without composition centers")

    reshaped_ds = reshape_ds_simple(dataset)

    # Insert the composition attrs before 'Q' so 'Q' remains the last attribute
    q_values = reshaped_ds.attrs.pop('Q')
    for key in composition_keys:
        reshaped_ds.attrs[key] = dataset.attrs[key]
    reshaped_ds.attrs['Q'] = q_values

    return reshaped_ds

def reshape_ds_simple(dataset):
    """
    Reshape a flat (points, ...) wafer dataset onto a regular (y, x, ...) grid.

    The grid axes are rebuilt from attrs['shape_diameter'] and attrs['resolution'] with the same
    expression create_dataset_with_coords() uses. The point axis is unflattened in row-major
    order (y outer, x inner), matching that function's meshgrid/ravel order.

    Returns
    -------
    xr.Dataset
        Data variables:
        - 'element_weights' (y, x, element_weight);
        - 'phase_weights' (y, x, phase_weight);
        - 'iq' (y, x, intensity): intensities only, the Q grid is in attrs['Q'].
        Metadata attributes are copied from ``dataset``; 'shape_diameter' is stored as 'shape_width'.
    """
    src_attrs = dataset.attrs
    diameter = src_attrs['shape_diameter']
    n_side = int(np.ceil(diameter / src_attrs['resolution']))  # number of points per side
    x_axis = np.linspace(-diameter//2, diameter//2, n_side)
    y_axis = np.linspace(-diameter//2, diameter//2, n_side)
    n = len(x_axis)

    ew_flat = dataset.element_weights.data
    element_da = xr.DataArray(
        np.reshape(ew_flat, (n, n, ew_flat.shape[1])),
        dims=["y", "x", "element_weight"],
        coords={"y": y_axis, "x": x_axis,
                "element_weight": ("element_weight", [str(name) for name in dataset.elements.data[0]])},
    )

    pw_flat = dataset.phase_weights.data
    phase_da = xr.DataArray(
        np.reshape(pw_flat, (n, n, pw_flat.shape[1])),
        dims=["y", "x", "phase_weight"],
        coords={"y": y_axis, "x": x_axis,
                "phase_weight": ("phase_weight", [str(name) for name in dataset.phase_names.data[0]])},
    )

    iq_flat = dataset.iq.data
    iq_da = xr.DataArray(
        np.reshape(iq_flat[:, 1, :], (n, n, iq_flat.shape[2])),  # slot 1 = intensities
        dims=["y", "x", "intensity"],
        coords={"y": y_axis, "x": x_axis},
    )

    reshaped = xr.Dataset({'element_weights': element_da, 'phase_weights': phase_da, 'iq': iq_da})

    reshaped.attrs.update({
        'description': src_attrs['description'],
        'elements': [str(name) for name in dataset.elements.data[0]],
        'phases': [str(name) for name in dataset.phase_names.data[0]],
        'shape': src_attrs['shape'],
        'shape_center': src_attrs['shape_center'],
        'shape_width': src_attrs['shape_diameter'],
        'resolution': src_attrs['resolution'],
        'points_inside': src_attrs['points_inside'],
        'Q': src_attrs['Q'],
    })

    return reshaped

def save_ds(dataset: xr.Dataset, path: str, prefix: str, suffix: str, datetimestamp=True, remove_nc=False, drop_element_weights=True):
    """
    Write ``dataset`` to ``<path>/<name>.nc`` plus a deflate-compressed ``<path>/<name>.zip`` of it.

    ``<name>`` is built as follows:
    - ``prefix``;
    - then ``_<DDMonYYYY_HH-MM-SS>`` (time of saving) if ``datetimestamp``;
    - then ``suffix`` if it is a non-empty string.

    Parameters
    ----------
    dataset : xr.Dataset
    path : str or path-like
        Output directory.
    prefix, suffix : str
        File name parts (see above).
    datetimestamp : bool
        Append the save time to the name.
    remove_nc : bool
        Delete the .nc after zipping (save space, keep GitHub happy).
    drop_element_weights : bool
        Drop the 'element_weight' coordinate and 'element_weights' data variable if present.

    Raises
    ------
    NameError
        If the target .nc or .zip already exists. The .zip is checked too because with
        ``remove_nc=True`` only the archive of a previous save remains.
    """
    if drop_element_weights:
        # for DRNets generated datasets
        if 'element_weight' in dataset.coords:
            dataset = dataset.drop_vars('element_weight')

        if 'element_weights' in dataset.data_vars:
            dataset = dataset.drop_vars('element_weights')

    # file name setup - build the stem directly instead of str.replace('.nc', ...), which would
    # also rewrite any '.nc' occurring inside the prefix
    stem = f'{prefix}'
    if datetimestamp:
        # time of saving ds, not creation (creation is multi-step so not timestamping)
        stem += '_' + datetime.now(tz=None).strftime("%d%b%Y_%H-%M-%S")
    if suffix and isinstance(suffix, str):
        stem += suffix

    nc_name = f'{stem}.nc'
    nc_path = os.path.join(path, nc_name)
    zip_path = os.path.join(path, f'{stem}.zip')

    # never overwrite an existing output
    for target in (nc_path, zip_path):
        if os.path.isfile(target):
            raise NameError(f"file already exists ({target}) - edit prefix, suffix, or change argument 'datetimestamp' to True")

    dataset.to_netcdf(nc_path)
    with ZipFile(zip_path, 'w', ZIP_DEFLATED) as archive:
        archive.write(nc_path, arcname=nc_name)

    if remove_nc:
        # save space, keep GitHub happy by removing .nc files
        os.remove(nc_path)


def ds_plot2D(dataset: xr.Dataset, dataarray: str, marker='s', marker_size=2, cmap='bone_r',vmin=0,vmax=1, background_color='xkcd:eggshell', save=False, save_dir=None):
    """
    Plot each slice along the 3rd dimension of a gridded (y, x, component) variable as its own figure.

    Intended for outputs of reshape_ds()/reshape_ds_simple(), e.g. 'element_weights' or
    'phase_weights'.

    Parameters
    ----------
    dataset : xr.Dataset
    dataarray : str
        Name of a 3-D data variable with dims ordered (y, x, component).
    marker, marker_size :
        Unused; kept for backward compatibility.
    cmap, vmin, vmax :
        Passed to DataArray.plot().
    background_color : str or None
        Axes face colour, visible where the data are NaN; falsy to keep the default.
    save : bool
        Save a 300-dpi PNG per slice (only for 'phase_weights' and 'element_weights').
        Requires attrs['elements'], and also attrs['phases'] for phase weights. Phase names are
        parsed as '<...>+<name>_<...>'.
    save_dir : str or path-like, optional
        Output directory; defaults to the notebook-global ``simwafer_path``.

    Raises
    ------
    NameError
        If ``dataarray`` is not a data variable of ``dataset``.
    ValueError
        If the variable is not 3-D or would produce more than 20 figures.
    """
    max_figures = 20

    if dataarray not in dataset.data_vars:
        raise NameError(f'dataarray inputted ({dataarray}) is not a data variable in dataset inputted ({dataset})\nAvailable data variables: {dataset.data_vars}')

    data = dataset[dataarray]
    if data.ndim != 3:
        raise ValueError(f'{dataarray} must be 3-dimensional (y, x, component) - got dims {data.dims}')

    n_slices = data.shape[2]
    if n_slices > max_figures:
        raise ValueError(f'3rd dimension of array is larger than {max_figures} - this would generate {n_slices} figures\nThis is not the appropriate plotting tool for the inputted array')

    for i in range(n_slices):
        fig, ax = plt.subplots()
        if background_color:
            ax.set_facecolor(background_color)
        data[:, :, i].plot(ax=ax, cmap=cmap, vmin=vmin, vmax=vmax)  # expects y = dim[0], x = dim[1]

        if not save or dataarray not in ('phase_weights', 'element_weights'):
            continue

        out_dir = simwafer_path if save_dir is None else save_dir
        elements_tag = "".join(dataset.attrs['elements'])
        if dataarray == 'phase_weights':
            label = dataset.attrs['phases'][i].split('+')[1].split('_')[0]
        else:
            label = dataset.attrs['elements'][i]
        # single-quoted keys inside a double-quoted f-string: valid on all Python 3 versions
        fig.savefig(f"{out_dir}/ds_{elements_tag}_{label}_weight.png", dpi=300)


# simpler approach to setting element weights

# num_compositions is 2, discrete_compositions should be len 2 (L / R of x_boundary)

def set_elemental_comp_LR(dataset: xr.Dataset, num_compositions, elements_list, discrete_compositions, x_boundary='center'):
    """
    Assign one composition to points left of a vertical boundary and another to points right of it.

    Parameters
    ----------
    dataset : xr.Dataset
        Dataset created by create_dataset_with_coords().
    num_compositions : int
        Must be 2 (left composition, right composition).
    elements_list : list of str
        Element names matching the columns of ``discrete_compositions``.
    discrete_compositions : array_like, shape (2, len(elements_list))
        Row 0 is used for x < boundary and row 1 for x > boundary; each row must sum to 1 (within
        1e-9). Points lying exactly on the boundary get the mean of the two rows, so every point
        carries a valid composition.
    x_boundary : 'center' or float
        'center' uses the x coordinate of attrs['shape_center']; a number must lie strictly inside
        the range of the dataset's x coordinates.

    Returns
    -------
    xr.Dataset
        Copy of ``dataset`` with coordinates ``elements`` (points, elements) and
        ``element_weights`` (points, weights); weights are NaN outside a circular sample.

    Raises
    ------
    ValueError
        On inconsistent shapes, num_compositions != 2, rows not summing to 1, or an
        out-of-bounds x_boundary.
    TypeError
        If x_boundary is neither 'center' nor comparable to the x coordinates.
    """
    discrete_compositions = np.asarray(discrete_compositions)

    # some error checking
    if num_compositions != discrete_compositions.shape[0]:
        raise ValueError(f"num_compositions ({num_compositions}) must be an integer matching discrete_compositions.shape[0] ({discrete_compositions.shape[0]})")

    if discrete_compositions.shape != (num_compositions, len(elements_list)):
        raise ValueError(f"discrete_compositions must have shape of (num_compositions, len(elements_list)) - currently {discrete_compositions.shape}")

    if num_compositions != 2:
        raise ValueError(f"num_compositions must be 2 (left and right of x_boundary) - currently {num_compositions}")

    for num, sublist in enumerate(discrete_compositions):
        if not np.isclose(np.sum(sublist), 1.0, rtol=0, atol=1e-9):
            raise ValueError(f"Sublist {num} in discrete_compositions does not sum to 1. Ensure sublist elements sum to 1 and try again.\nSublist = {sublist}. Sum = {np.sum(sublist)}")

    x_values = dataset.x.data

    # validate x_boundary
    if x_boundary == 'center':
        x_center, _ = dataset.attrs['shape_center']
    else:
        try:
            in_bounds = x_values.min() < x_boundary < x_values.max()
        except TypeError:
            raise TypeError(f"x_boundary must be 'center' or a numeric value within x bounds - input is {x_boundary} (type={type(x_boundary)})")
        if not in_bounds:
            raise ValueError(f'x_boundary {x_boundary} is out of bounds')
        x_center = x_boundary

    # setting weights based on boundary
    num_points = x_values.shape[0]
    weights = np.zeros((num_points, len(elements_list)))

    left = x_values < x_center
    right = x_values > x_center
    weights[left] = discrete_compositions[0]
    weights[right] = discrete_compositions[1]
    # points exactly on the boundary would otherwise keep all-zero (invalid) weights
    weights[~(left | right)] = discrete_compositions.mean(axis=0)

    # set weights to NaN where points are outside
    if dataset.attrs['shape'] == 'circle':
        diameter = dataset.attrs['shape_diameter']
        center = dataset.attrs['shape_center']  # assuming always centering at origin

        distances = np.sqrt((x_values - center[0])**2 + (dataset.y.data - center[1])**2)
        weights[distances >= (diameter / 2), :] = np.nan

    # Assign coordinates and weights to the dataset
    repeated_elements = np.tile(elements_list, (num_points, 1))

    dataset = dataset.assign_coords(
        elements=(("points", "elements"), repeated_elements),
        element_weights=(("points", "weights"), weights)
    )

    return dataset


def generate_symmetric_points_range(num_points, total_range):
    """
    Return ``num_points`` evenly spaced values symmetric about 0.

    - Odd ``num_points``: 0 is included and the outermost points sit at +/- ``total_range``.
    - Even ``num_points``: 0 is excluded. With step = total_range / (num_points // 2), the points
      closest to 0 sit at +/- step/2 and the outermost at +/- (total_range - step/2).
    - ``num_points`` of 1 returns [0.], and 0 returns an empty array.

    Raises
    ------
    ValueError
        If ``num_points`` is negative or the result has the wrong length.
    """
    if num_points < 0:
        raise ValueError(f'num_points must be non-negative (currently = {num_points})')

    half_points = num_points // 2

    if num_points % 2 == 0:
        if half_points == 0:
            symmetric_points = np.array([], dtype=float)
        else:
            # step only matters (and is only well defined) for even counts
            step_size = total_range / half_points
            positive_points = np.linspace(step_size / 2, total_range - step_size / 2, half_points)
            symmetric_points = np.concatenate((-positive_points[::-1], positive_points))
    else:
        positive_points = np.linspace(0, total_range, half_points + 1)
        symmetric_points = np.concatenate((-positive_points[1:][::-1], positive_points))

    if len(symmetric_points) != num_points:
        raise ValueError(f'num_points = {num_points}, but len symmetric_points = {len(symmetric_points)}')

    return symmetric_points

def generate_symmetric_points_step(num_points, step_size):
    """
    Return ``num_points`` values spaced ``step_size`` apart, symmetric about 0.

    - Odd counts include 0.
    - Even counts straddle 0 at +/- step_size/2.

    Raises
    ------
    ValueError
        If the generated array does not have ``num_points`` entries.
    """
    n_half = num_points // 2
    span = step_size * n_half  # largest magnitude for odd counts

    if num_points % 2 != 0:
        right = np.linspace(0, span, n_half + 1)
        left = -right[:0:-1]  # mirror everything except 0
    else:
        right = np.linspace(step_size / 2, span - step_size / 2, n_half)
        left = -right[::-1]

    points = np.concatenate((left, right))

    if len(points) != num_points:
        raise ValueError(f'num_points = {num_points}, but len symmetric_points = {len(points)}')

    return points


def ingest_data(filepath,
                    dataset_description: str,
                    dataset_shape: str,
                    dataset_resolution: Union[float, int],
                    dataset_Q_values: list,
                    add_noise_percentage: Union[float, int]):
    """
    Build a wafer ``xr.Dataset`` (phase-weight maps + summed I(Q) maps) from a saved file.

    The file is read with ``np.load(filepath, allow_pickle=True)`` and must hold a
    sequence of dict-like entries with keys ``'name'`` (phase label), ``'weights'``
    (square 2D weight map; all maps must share one shape) and ``'pattern'`` (1D
    pattern with one value per entry of ``dataset_Q_values``).

    Parameters:
        filepath: Path to the file. SECURITY: loading uses pickle, which can execute
            arbitrary code - only pass files from a trusted source.
        dataset_description (str): Stored in ``attrs['description']``.
        dataset_shape (str): ``'circle'`` or ``'square'``. For ``'circle'``, every grid
            point whose distance from (0, 0) is >= the largest y coordinate is set to
            NaN in all weight maps, and therefore also in ``iq``.
        dataset_resolution (float | int): Grid spacing used to build the ``x``/``y``
            coordinates via ``generate_symmetric_points_step``.
        dataset_Q_values (list): Q axis of the patterns; stored in ``attrs['Q']``.
        add_noise_percentage (float | int): If truthy, each pattern gets Gaussian noise
            (``add_noise_by_percentage``) before the maps are combined.

    Returns:
        xr.Dataset with
          * ``phase_weights`` - dims (y, x, phase_weight); ``phase_weight`` coord = names
          * ``iq`` - dims (y, x, intensity): sum over phases of weight * pattern
        and attrs ``description``, ``phases``, ``shape``, ``shape_center``,
        ``shape_width``, ``resolution``, ``Q``.

    Raises:
        ValueError: bad ``dataset_shape``, empty file, non-2D / non-square / mismatched
            weight maps, a pattern length != ``len(dataset_Q_values)``, or a failed
            internal shape-consistency check.
    """
    # check that args are as expected
    if dataset_shape not in ('circle', 'square'):
        raise ValueError(f"dataset_shape must be 'circle' or 'square' - input is {dataset_shape}")

    # read the file. NOTE: allow_pickle=True executes pickled code - trusted files only.
    data = np.load(filepath, allow_pickle=True)
    if len(data) == 0:
        raise ValueError(f"no phase entries found in {filepath}")

    # assumes name, weights, and pattern are dict keys
    names = [entry['name'] for entry in data]
    weights = [np.asarray(entry['weights']) for entry in data]
    patterns = [entry['pattern'] for entry in data]

    # add noise if arg given (module-level helper; same RNG draw order as before)
    if add_noise_percentage:
        patterns = [add_noise_by_percentage(pattern, add_noise_percentage) for pattern in patterns]

    # Error checking:
    # 1) every weight map must be a square 2D array, all with the same shape
    # 2) length of patterns must match the length of dataset_Q_values
    n_q = len(dataset_Q_values)
    for i, (weight_map, pattern) in enumerate(zip(weights, patterns)):
        if weight_map.ndim != 2:
            raise ValueError(f"index {i} in weights must be a 2D map - got shape {weight_map.shape}")
        if weight_map.shape[0] != weight_map.shape[1]:
            raise ValueError(f"index {i} in weights has a non-square shape - x points = {weight_map.shape[0]}, y points = {weight_map.shape[1]}")
        if weight_map.shape != weights[0].shape:
            raise ValueError(f"index {i} in weights has shape {weight_map.shape}, but index 0 has shape {weights[0].shape} - all weight maps must match")
        if len(pattern) != n_q:
            raise ValueError(f"index {i} in patterns has a length different from dataset_Q_values - Q len = {n_q}, patterns[{i}] len = {len(pattern)}")

    # grid coordinates, symmetric about 0 with spacing dataset_resolution
    pts_per_side = weights[0].shape[0]
    x = generate_symmetric_points_step(pts_per_side, dataset_resolution)
    y = generate_symmetric_points_step(pts_per_side, dataset_resolution)
    N = len(x)

    # (phase, row, col) stack. Cast to float so the NaN masking below also works for
    # integer-typed maps (assigning NaN into an int array raises).
    weights_arr = np.array(weights, dtype=float)
    patterns_arr = np.array(patterns)  # (phase, Q)

    # Flip the row axis. Presumably the stored maps use image convention (row 0 =
    # max y) while the y coordinate is ascending - TODO confirm against the producer.
    weights_arr = weights_arr[:, ::-1, :]

    if dataset_shape == 'circle':
        # x and y are generated symmetric about 0; these are sanity checks.
        if abs(y.max()) != abs(y.min()):
            raise ValueError(f"abs value of y.max and y.min should be equal but abs y.max = {abs(y.max())}, abs y.min = {abs(y.min())}")

        if abs(x.max()) != abs(x.min()):
            raise ValueError(f"abs value of x.max and x.min should be equal but abs x.max = {abs(x.max())}, abs x.min = {abs(x.min())}")

        if abs(x.max()) != abs(y.max()):
            raise ValueError(f"abs value of x.max and y.max should be equal but abs x.max = {abs(x.max())}, abs y.max = {abs(y.max())}")

        radius = y.max()

        # 2D grids indexed [row (y), col (x)], matching the weight maps' layout.
        xv, yv = np.meshgrid(x, y)
        outside = np.sqrt(xv**2 + yv**2) >= radius  # points on the rim are excluded too
        weights_arr[:, outside] = np.nan

    # iq[q, r, c] = sum_p patterns[p, q] * weights[p, r, c].
    # einsum avoids materializing a (phase, Q, row, col) intermediate; NaN weights
    # still propagate into iq.
    iq_arr = np.einsum('pq,prc->qrc', patterns_arr, weights_arr)

    # (phase, row=y, col=x) -> (y, x, phase)
    weights_arr_T = weights_arr.transpose(1, 2, 0)

    if weights_arr_T.shape != (N, N, len(data)):
        raise ValueError(f'weights array has a different shape than expected - expected = ({N, N, len(data)}), actual = ({weights_arr_T.shape})')

    # (Q, row=y, col=x) -> (y, x, Q)
    iq_arr_T = iq_arr.transpose(1, 2, 0)

    if iq_arr_T.shape != (N, N, len(dataset_Q_values)):
        raise ValueError(f'iq array has a different shape than expected - expected = ({N, N, len(dataset_Q_values)}), actual = ({iq_arr_T.shape})')

    # Dataarray and dataset creation
    da_phase = xr.DataArray(weights_arr_T, dims=["y", "x", "phase_weight"], coords=dict(y=y, x=x, phase_weight=("phase_weight", names)))
    da_iq = xr.DataArray(iq_arr_T, dims=["y", "x", "intensity"], coords=dict(y=y, x=x))

    dataset = xr.Dataset({
        'phase_weights': da_phase,
        'iq': da_iq
        })

    # add attributes
    dataset.attrs['description'] = dataset_description
    dataset.attrs['phases'] = names
    dataset.attrs['shape'] = dataset_shape
    dataset.attrs['shape_center'] = (0,0)
    dataset.attrs['shape_width'] = float(abs(x.max()) + (abs(x.min())))
    dataset.attrs['resolution'] = dataset_resolution
    dataset.attrs['Q'] = dataset_Q_values

    return dataset


def add_noise_by_percentage(signal, noise_percentage, rng=None):
    """
    Add zero-mean Gaussian noise scaled to a percentage of the signal's peak magnitude.

    Parameters:
        signal (array_like): The original signal; converted with ``np.asarray`` so
            lists are accepted.
        noise_percentage (float): Noise standard deviation as a percentage of
            ``max(|signal|)`` (e.g. 1.0 -> sigma = 1% of the peak magnitude).
        rng (numpy.random.Generator, optional): Source of randomness for reproducible
            noise. If None (default), the legacy global ``np.random`` state is used
            exactly as before, so ``np.random.seed`` still controls the result.

    Returns:
        noisy_signal (numpy.ndarray): ``signal + noise``, same shape as ``signal``.
        An empty input is returned as an empty float array; previously it raised,
        because ``np.max`` has no identity for a zero-size array.
    """
    signal = np.asarray(signal)
    if signal.size == 0:
        return signal.astype(float)

    # Find the maximum magnitude in the signal
    max_value = np.max(np.abs(signal))

    # Noise standard deviation as a percentage of that peak magnitude
    noise_std = (noise_percentage / 100) * max_value

    sampler = np.random if rng is None else rng
    noise = sampler.normal(0, noise_std, signal.shape)

    return signal + noise


def add_validcoords_labels(reshaped_ds, weight_rounding=8, weight_cutoff=0.01, num_points_radius_reduced=3, short_legend_names=False):
    """
    Attach ground-truth phase-combination labels and valid measurement coordinates.

    The grid is flattened row-major from ``np.meshgrid(x, y)``. Points are classified as:
      * valid region - ``'circle'`` datasets: points strictly closer to
        ``attrs['shape_center']`` than
        ``attrs['shape_width'] / 2 - attrs['resolution'] * num_points_radius_reduced``;
        any other shape: every grid point. Previously the non-circle path raised
        NameError because ``coords_valid`` was never defined.
      * labelled - points in the valid region whose phase weights contain no NaN.
    For each labelled point, the phases whose weight (rounded to ``weight_rounding``
    decimals) is strictly greater than ``weight_cutoff`` are collected in phase order.
    Each distinct combination gets an integer label in order of first appearance.

    Parameters:
        reshaped_ds (xr.Dataset): Needs ``x``/``y`` coords, a ``phase_weights``
            variable whose data is laid out as (y, x, phase) - assumed, matching
            ``ingest_data``'s output; TODO confirm for ``reshape_ds`` - a
            ``phase_weight`` coordinate holding the phase names, and
            ``attrs['shape']``. Circles additionally need ``attrs['shape_center']``,
            ``attrs['shape_width']`` and ``attrs['resolution']``.
        weight_rounding (int): Decimals used to round weights before thresholding.
        weight_cutoff (float): Strict lower bound for a phase to count as present.
        num_points_radius_reduced (int | float): Number of ``resolution`` steps by which
            the circle radius is shrunk.
        short_legend_names (bool): If True, shorten names of the form
            ``'<a>+<short>_<rest>'`` to ``'<short>'``. Names without ``'+'`` are cut at
            the first ``'_'`` instead of raising IndexError.

    Returns:
        xr.Dataset: ``reshaped_ds`` itself (modified in place), with new variables
        ``ground_truth_labels`` (dim ``dim_1d``, one label per labelled point) and
        ``coords_valid`` (dims ``dim_x``, ``dim_y``; one ``(x, y)`` row per valid-region
        point). It also gets attrs ``ground_truth_uniquecombs`` and
        ``ground_truth_phasespresent``, which store each combination via ``str(list)``
        because ragged lists cannot be stored as attrs or variables.
    """
    x = reshaped_ds.x.data
    y = reshaped_ds.y.data
    xv, yv = np.meshgrid(x, y)
    # Row-major flattening: flat index k <-> (x[k % len(x)], y[k // len(x)]).
    coords = np.column_stack([xv.ravel(), yv.ravel()])

    if reshaped_ds.attrs['shape'] == 'circle':
        center = reshaped_ds.attrs['shape_center']
        distances = np.sqrt((coords[:, 0] - center[0])**2 + (coords[:, 1] - center[1])**2)
        reduced_radius = reshaped_ds.attrs['shape_width'] / 2 - (reshaped_ds.attrs['resolution'] * num_points_radius_reduced)
        in_valid_region = ~(distances >= reduced_radius)
    else:
        in_valid_region = np.ones(coords.shape[0], dtype=bool)
    coords_valid = coords[in_valid_region]

    # One row of phase weights per flattened grid point (same row-major order as coords).
    weights_3d = reshaped_ds.phase_weights.data
    phase_weights = weights_3d.reshape(weights_3d.shape[0] * weights_3d.shape[1], weights_3d.shape[2])

    # Boolean masks replace the former O(n^2) "index not in list" / coordinate-matching filters.
    labelled = in_valid_region & ~np.isnan(phase_weights).any(axis=1)
    phase_weights_rd = np.round(phase_weights[labelled], weight_rounding)

    # tolist() yields plain Python scalars/str, so str(list) in the attrs renders as
    # "['A', 'B']" rather than NumPy 2's "[np.str_('A'), ...]".
    phase_names = reshaped_ds.phase_weight.data.tolist()
    if short_legend_names:
        # Intended for un-shortened DRNets names: keep the text between the first '+'
        # and the next '+' (if any), then cut at the first '_'.
        phase_names = [
            (name.split('+')[1] if '+' in name else name).split('_')[0]
            for name in map(str, phase_names)
        ]

    is_present = phase_weights_rd > weight_cutoff
    phases_present = [
        [name for name, present in zip(phase_names, row) if present]
        for row in is_present
    ]

    # Label each combination (order matters) by its first-appearance index.
    label_of = {}
    unique_combinations = []
    ground_truth_labels = []
    for combination in phases_present:
        key = tuple(combination)
        if key not in label_of:
            label_of[key] = len(unique_combinations)
            unique_combinations.append(combination)
        ground_truth_labels.append(label_of[key])

    # Create DataArrays without coordinates
    da_labels = xr.DataArray(ground_truth_labels, dims=['dim_1d'])
    da_coords_valid = xr.DataArray(coords_valid, dims=['dim_x', 'dim_y'])

    reshaped_ds['ground_truth_labels'] = da_labels
    reshaped_ds['coords_valid'] = da_coords_valid

    reshaped_ds.attrs['ground_truth_uniquecombs'] = [str(comb) for comb in unique_combinations] # can't store inhomogeneous shapes as attrs or vars
    reshaped_ds.attrs['ground_truth_phasespresent'] = [str(phases) for phases in phases_present] # can't store inhomogeneous shapes as attrs or vars

    return reshaped_ds


# Paths — edit as needed (e.g., if the repository root is not found)

In [None]:
# Resolve data directories relative to the repository root
root_dir = get_repo_root()
# Input: phase diagram datasets (e.g. Al-Li-Fe_dataset.nc)
phasediagram_path = root_dir.joinpath("Data", "phasediagram_datasets")
# Output: simulated wafer datasets written by save_ds
simwafer_path = root_dir.joinpath("Data", "simulatedwafer_datasets")
print(
    f"Root directory: {root_dir}\n"
    f"Phasediagram datasets path: {phasediagram_path}\n"
    f"Simulated wafer datasets path: {simwafer_path}"
)

# Example: Al-Li-Fe combinatorial library

In [None]:
# Load the Al-Li-Fe phase diagram dataset from the phase diagram directory
file = 'Al-Li-Fe_dataset.nc'
ds_AlLiFe = xr.open_dataset(phasediagram_path / file)

In [None]:
# Coordinate grid for a circular wafer: diameter and resolution share units (e.g. mm)
testds = create_dataset_with_coords(
    shape='circle',
    diameter=60,
    resolution=0.4,
)

In [None]:
# Elemental weights from three anchor compositions (one row per anchor)
testds = calc_elemental_comps(
    dataset=testds,
    num_compositions=3,
    elements_list=['Al', 'Li', 'Fe'],
    discrete_compositions=np.array([
        [0.2, 0.2, 0.6],
        [0.2, 0.6, 0.2],
        [0.6, 0.2, 0.2],
    ]),
    positions='calculated',
    calc_dist_scale=100,
    deg_rotation=-90,
    discrete_comp_coords=None,
    find_on_grid=True,
    smoothing_factor=5.0,
)

In [None]:
# Interpolate phase weights and I(Q) onto the wafer and add them to the dataset.
# Note: noise is added (noise_percentage=0.01 here).
testds = interpolate_and_addtods(
    dataset=testds,
    dataset_DRNets=ds_AlLiFe,
    noise_percentage=0.01,
)

In [None]:
# Put the per-point dataset onto a meshgrid layout
reshaped_ds = reshape_ds(
    testds
)

In [None]:
# Ground-truth labels, phase combinations and valid coordinates used by the experiment
reshaped_ds = add_validcoords_labels(
    reshaped_ds=reshaped_ds,
    weight_rounding=8,
    weight_cutoff=0.01,
    num_points_radius_reduced=3,
    short_legend_names=True,
)

In [None]:
# Write the simulated wafer to the output directory (timestamped name)
save_ds(
    dataset=reshaped_ds,
    path=simwafer_path,
    prefix='ds_AlLiFe_complex',
    suffix='',
    datetimestamp=True,
    remove_nc=False,
)


In [None]:
# Map of elemental weights across the wafer
# (alternative background: 'xkcd:eggshell')
ds_plot2D(
    dataset=reshaped_ds,
    dataarray='element_weights',
    marker='s',
    marker_size=2,
    cmap='bone_r',
    vmin=0,
    vmax=1,
    background_color='ivory',
)

In [None]:
# Map of phase weights across the wafer
# (alternative background: 'xkcd:eggshell')
ds_plot2D(
    dataset=reshaped_ds,
    dataarray='phase_weights',
    marker='s',
    marker_size=2,
    cmap='bone_r',
    vmin=0,
    vmax=1,
    background_color='ivory',
)

In [None]:
# Plot of summed diffraction pattern intensity.
# min_count=1 keeps points that are NaN at every Q (e.g. outside a circular wafer) as NaN.
# With the default skipna sum they would become 0 and render as "zero intensity".
# The previous expand_dims added a size-1 dim that xarray's plot squeezes away anyway.
summed_data = reshaped_ds.iq.sum(dim='intensity', keep_attrs=True, min_count=1)
fig, ax = plt.subplots()
summed_data.plot(ax=ax, cmap='bone_r')
ax.set_title('Summed XRD intensity on wafer');

# Creating three libraries

## Al-Li-Fe

### Make dataset

In [None]:
# Al-Li-Fe: load phase diagram -> grid -> elemental comps -> phases/I(Q) -> labels
file = 'Al-Li-Fe_dataset.nc'
ds_AlLiFe = xr.open_dataset(phasediagram_path / file)

testds = create_dataset_with_coords(shape='circle', diameter=60, resolution=0.4)

testds = calc_elemental_comps(
    dataset=testds,
    num_compositions=3,
    elements_list=['Al', 'Li', 'Fe'],
    discrete_compositions=np.array([
        [0.2, 0.2, 0.6],
        [0.2, 0.6, 0.2],
        [0.6, 0.2, 0.2],
    ]),
    positions='calculated',
    calc_dist_scale=100,
    deg_rotation=-90,
    discrete_comp_coords=None,
    find_on_grid=True,
    smoothing_factor=5.0,
)

# Note: noise is added (noise_percentage=0.01)
testds = interpolate_and_addtods(dataset=testds, dataset_DRNets=ds_AlLiFe, noise_percentage=0.01)

reshaped_ds = reshape_ds(testds)

# Labels, phase combinations and valid coordinates needed in the experiment
reshaped_ds = add_validcoords_labels(
    reshaped_ds=reshaped_ds,
    weight_rounding=8,
    weight_cutoff=0.01,
    num_points_radius_reduced=3,
    short_legend_names=True,
)


In [None]:
# Save the Al-Li-Fe wafer.
# Known issue: fails if the target file already exists and is open (possibly due to
# how the file is accessed on disk).
save_ds(
    dataset=reshaped_ds,
    path=simwafer_path,
    prefix='ds_AlLiFe_complex',
    suffix='',
    datetimestamp=True,
    remove_nc=False,
    drop_element_weights=False,
)

### Vis

In [None]:
# Elemental weight maps, saved to disk (alternative background: 'xkcd:eggshell')
ds_plot2D(
    dataset=reshaped_ds,
    dataarray='element_weights',
    marker='s',
    marker_size=2,
    cmap='bone_r',
    vmin=0,
    vmax=1,
    background_color='ivory',
    save=True,
)

In [None]:
# Phase weight maps, saved to disk (alternative background: 'xkcd:eggshell')
ds_plot2D(
    dataset=reshaped_ds,
    dataarray='phase_weights',
    marker='s',
    marker_size=2,
    cmap='bone_r',
    vmin=0,
    vmax=1,
    background_color='ivory',
    save=True,
)

## Bi-Cu-V

In [None]:
# Bi-Cu-V: load phase diagram -> grid -> elemental comps -> phases/I(Q) -> labels -> save
file = 'Bi-Cu-V_dataset.nc'
ds_BiCuV = xr.open_dataset(phasediagram_path / file)

testds = create_dataset_with_coords(shape='circle', diameter=60, resolution=0.4)

testds = calc_elemental_comps(
    dataset=testds,
    num_compositions=3,
    elements_list=['Bi', 'Cu', 'V'],
    discrete_compositions=np.array([
        [0.2, 0.2, 0.6],
        [0.2, 0.6, 0.2],
        [0.6, 0.2, 0.2],
    ]),
    positions='calculated',
    calc_dist_scale=100,
    deg_rotation=-90,
    discrete_comp_coords=None,
    find_on_grid=True,
    smoothing_factor=5.0,
)

# Note: noise is added (noise_percentage=0.01)
testds = interpolate_and_addtods(dataset=testds, dataset_DRNets=ds_BiCuV, noise_percentage=0.01)

reshaped_ds = reshape_ds(testds)

reshaped_ds = add_validcoords_labels(
    reshaped_ds=reshaped_ds,
    weight_rounding=8,
    weight_cutoff=0.01,
    num_points_radius_reduced=3,
    short_legend_names=True,
)

# Known issue: saving fails if the target file already exists and is open.
save_ds(
    dataset=reshaped_ds,
    path=simwafer_path,
    prefix='ds_BiCuV_complex',
    suffix='',
    datetimestamp=True,
    remove_nc=False,
    drop_element_weights=False,
)


In [None]:
# Elemental and phase weight maps, saved to disk (alternative background: 'xkcd:eggshell')
ds_plot2D(
    dataset=reshaped_ds,
    dataarray='element_weights',
    marker='s',
    marker_size=2,
    cmap='bone_r',
    vmin=0,
    vmax=1,
    background_color='ivory',
    save=True,
)
ds_plot2D(
    dataset=reshaped_ds,
    dataarray='phase_weights',
    marker='s',
    marker_size=2,
    cmap='bone_r',
    vmin=0,
    vmax=1,
    background_color='ivory',
    save=True,
)

## Li-Sr-Al

In [None]:
# Li-Sr-Al: load phase diagram -> grid -> elemental comps -> phases/I(Q) -> labels -> save
file = 'Li-Sr-Al_dataset.nc'
ds_LiSrAl = xr.open_dataset(phasediagram_path / file)

testds = create_dataset_with_coords(shape='circle', diameter=60, resolution=0.4)

testds = calc_elemental_comps(
    dataset=testds,
    num_compositions=3,
    elements_list=['Li', 'Sr', 'Al'],
    discrete_compositions=np.array([
        [0.2, 0.2, 0.6],
        [0.2, 0.6, 0.2],
        [0.6, 0.2, 0.2],
    ]),
    positions='calculated',
    calc_dist_scale=100,
    deg_rotation=-90,
    discrete_comp_coords=None,
    find_on_grid=True,
    smoothing_factor=5.0,
)

# Note: noise is added (noise_percentage=0.01)
testds = interpolate_and_addtods(dataset=testds, dataset_DRNets=ds_LiSrAl, noise_percentage=0.01)

reshaped_ds = reshape_ds(testds)

reshaped_ds = add_validcoords_labels(
    reshaped_ds=reshaped_ds,
    weight_rounding=8,
    weight_cutoff=0.01,
    num_points_radius_reduced=3,
    short_legend_names=True,
)

# Known issue: saving fails if the target file already exists and is open.
save_ds(
    dataset=reshaped_ds,
    path=simwafer_path,
    prefix='ds_LiSrAl_complex',
    suffix='',
    datetimestamp=True,
    remove_nc=False,
    drop_element_weights=False,
)


### Vis

In [None]:
# Elemental and phase weight maps, saved to disk (alternative background: 'xkcd:eggshell')
ds_plot2D(
    dataset=reshaped_ds,
    dataarray='element_weights',
    marker='s',
    marker_size=2,
    cmap='bone_r',
    vmin=0,
    vmax=1,
    background_color='ivory',
    save=True,
)
ds_plot2D(
    dataset=reshaped_ds,
    dataarray='phase_weights',
    marker='s',
    marker_size=2,
    cmap='bone_r',
    vmin=0,
    vmax=1,
    background_color='ivory',
    save=True,
)