In [1]:
import torch
import numpy as np
import pandas as pd
import json
from ipywidgets import FloatSlider, Button, HBox, VBox, Output
from IPython.display import clear_output
from heagan.tools.cGAN_samplers import noise_sampler
from heagan.tools.functions import decode, calculate_entropy_mixing
from joblib import load
import plotly.graph_objects as go
import plotly.express as px
from pymatgen.core import Composition
from importlib import resources
import onnxruntime

In [2]:
with resources.files('heagan.saved_cGAN').joinpath('generator.pt') as fname:
    model = torch.jit.load(fname,map_location='cpu')
with resources.files('heagan.saved_cGAN').joinpath('gan_hyperparameters.json').open('r') as f:
    gan_hyperparameters = json.load(f)

selected_props = gan_hyperparameters['selected_props']
latent_dim = gan_hyperparameters['latent_dim']
prop_dim = gan_hyperparameters['prop_dim']
elem_list = gan_hyperparameters['elem_list']

In [3]:
with resources.files('heagan.saved_cGAN').joinpath('scale_kde_pipe.joblib') as fname:
    scaler_pipe = load(fname)
#with resources.files('heagan.saved_cGAN').joinpath('min_max_scaler.joblib') as fname:
#    scaler_minmax = load(fname)

In [4]:
with resources.files('heagan.dataset').joinpath('demo_dataset_true.csv') as fname:
    demo_df = pd.read_csv(fname,index_col=0)

dict_of_props = {}

units_dict = {'delta_S' : '- DELTA_S/R',
              'd_param' : 'D PARAMETER',
              'price' : 'PRICE ($/kg)',
              'FT' : 'FT (MPa/m^0.5)',
              'density' : 'DENSITY (g/cm^3)',
              'hardness' : 'HARDNESS (GPa)',
              'uts1200C' : 'UTS@1200C (GPa)'}

style = {'description_width': 'initial', 'font_variant':"small-caps"}

for p in selected_props:
    vals = demo_df.loc[:,p].values
    dict_of_props[units_dict[p]] = {'min':np.round(np.min(vals),2),
                        'max':np.round(np.max(vals),2),
                        'value':np.round((np.min(vals)+np.max(vals))/2,2),
                        'step':0.01,
                        'description':units_dict[p],
                        'style' : style
                        }

colors = px.colors.qualitative.Alphabet

In [8]:
def generate_alloy(prop_arr):
    input_data = np.array(prop_arr).reshape(1,-1).astype('float32')
    scaled_input = torch.from_numpy(scaler_pipe['Scaler'].transform(input_data))
    noise = torch.from_numpy(noise_sampler(1,latent_dim))
    with torch.no_grad():
        generated = model(noise,scaled_input).numpy()
    composition = decode(generated[0], elem_list)
    dict_generated = {
        'Composition':Composition({
            el: frac*100 for el, frac in composition.as_dict().items()}).__str__(),
            'composition_arr':generated[0]}
    for i,prop in enumerate(selected_props):
        if prop == 'delta_S':
            dict_generated[prop] = calculate_entropy_mixing(composition).item()
        else:
            model_surrogate = load(f'heagan/saved_surrogates/surrogate_{prop}.joblib')
            dict_generated[prop] = model_surrogate.predict(generated).item()
    return dict_generated


In [9]:
class Demonstrator:
    def __init__(self):
        self.objects = [FloatSlider(**params) for params in dict_of_props.values()]
        self.reset_button = Button(description='Reset',disabled=False,button_style='',tooltip='Reset the slider values')
        self.reset_button.on_click(self.on_reset_button_clicked)
        self.generate_button = Button(description='Generate',disabled=False,button_style='',tooltip='Generate HEA composition with selected conditions')
        self.generate_button.on_click(self.on_generate_button_clicked)
        self.output = Output()
        self.prop_output = Output()
        controls_one = VBox([*self.objects[:3]])
        controls_two = VBox([*self.objects[3:]])
        controls = HBox([controls_one, controls_two])
        buttons = HBox([self.reset_button, self.generate_button])
        outputs = HBox([self.output,self.prop_output])
        display(VBox([controls,buttons, outputs]))
        self.fig = go.Figure()
        self.prop_fig = go.Figure()
        self.df = pd.DataFrame(columns=demo_df.columns)
        print('initialized')


    def on_generate_button_clicked(self,b):
        print('generating')
        arr = np.zeros(len(self.objects))
        for i,sl in enumerate(self.objects):
            arr[i] = sl.value
        data = generate_alloy(arr)
        self.df.loc[len(self.df)]=data

        compositionString = ' '.join([f'{el}{round(Composition(data["Composition"]).as_dict()[el]):<2}' if el in Composition(data['Composition']).as_dict() else '    ' for el in elem_list])
        #composition
        with self.output:

            clear_output(True)
            self.fig.add_trace(go.Scatterpolar(
            r=data['composition_arr'],
            theta=elem_list,
            name=compositionString,
            fill='toself',
            opacity=0.75,
            showlegend = True
                ))

            self.fig.update_layout(
                width = 500,
                height = 600,
                legend=dict(
                  orientation = "h",
                  yanchor="bottom",
                  y=-1.1,
                  xanchor="left",
                  x=0),
                font=dict(
                family="Courier New, monospace",
                size=14,
                color="Black"
            ),
            title={
        'text': "Composition",
        'y':0.9,
        'x':0.5,
        'xanchor': 'center',
        'yanchor': 'top'})
            names = set()
            self.fig.for_each_trace(lambda trace:trace.update(showlegend=False) if (trace.name in names) else names.add(trace.name))
            self.fig.show()

        #properties
        with self.prop_output:
            clear_output(True)
            vals_scaled = self.df.iloc[:,1:].values.reshape(-1, len(selected_props))
            self.prop_fig.data = []
            for i in range(self.df.shape[0]):

                self.prop_fig.add_trace(go.Scatterpolar(
                r=vals_scaled[i],
                theta=selected_props,
                fill='toself',
                opacity=0.75,
                showlegend = False,
                hovertext=[f'Value: {str(np.round(x,2))}' for x in self.df.iloc[i,1:].values]
                ))

            self.prop_fig.update_layout(
            width = 400,
            height = 400,
            legend=dict(
                  orientation = "h",
                  yanchor="bottom",
                  y=-1.1,
                  xanchor="left",
                  x=0),
                font=dict(
                family="Courier New, monospace",
                size=14,
                color="Black"
            ),
            title={
                'text': "Property (Scaled)",
                'y':0.85,
                'x':0.5,
                'xanchor': 'center',
                'yanchor': 'top',
                })
            names = set()
            self.prop_fig.for_each_trace(lambda trace:trace.update(showlegend=False) if (trace.name in names) else names.add(trace.name))
            self.prop_fig.show()

    def on_reset_button_clicked(self,b):
        print('resetting')
        for slider in self.objects:
            slider.value = dict_of_props[slider.description]['value']
        self.df.drop(self.df.index, inplace=True)
        with self.output:
            clear_output()
        with self.prop_output:
            clear_output()
        self.fig.data = []
        self.prop_fig.data = []

In [10]:
Demonstrator()

VBox(children=(HBox(children=(VBox(children=(FloatSlider(value=-1.03, description='- DELTA_S/R', max=-0.28, mi…

initialized


<__main__.Demonstrator at 0x284848520>

generating


KeyError: 224