In [7]:
import torch
import numpy as np
import pandas as pd
import json
from ipywidgets import interact, FloatSlider, Button, HBox, VBox, Output
from IPython.display import clear_output
from joblib import load
from modules.cGAN_samplers import noise_sampler
from modules.functions import decode, calculate_entropy_mixing
import matplotlib.pyplot as plt
from joblib import load
import plotly.graph_objects as go

In [2]:
# Load the trained conditional-GAN generator (TorchScript) on the CPU.
model = torch.jit.load('saved_cGAN/generator.pt', map_location='cpu')
# This notebook only runs inference: force evaluation mode so that dropout /
# batch-norm layers (if the generator has any) do not behave as in training,
# whatever mode the module happened to be saved in.
model.eval()

# Hyperparameters saved alongside the generator. JSON is UTF-8 by specification,
# so the encoding is explicit rather than dependent on the platform locale.
with open('saved_cGAN/gan_hyperparameters.json', 'r', encoding='utf-8') as fid:
    gan_hyperparameters = json.load(fid)
selected_props = gan_hyperparameters['selected_props']  # conditioning property names (one slider each)
latent_dim = gan_hyperparameters['latent_dim']          # passed to noise_sampler as the noise size
prop_dim = gan_hyperparameters['prop_dim']              # not referenced in the visible cells
elem_list = gan_hyperparameters['elem_list']            # element labels used for decoding and plotting

# Scaler pipeline whose 'Scaler' step normalises slider inputs for the generator,
# and the min-max scaler used to put predicted properties on a common plot scale.
# NOTE: joblib files are pickles -- only load trusted, locally produced artifacts.
scaler_pipe = load('saved_cGAN/scale_kde_pipe.joblib')
scaler_minmax = load('saved_cGAN/min_max_scaler.joblib')

In [3]:
# Surrogate regressors are deserialised lazily and cached per property, so a
# "Generate" click does not re-read every model file from disk.
# NOTE: the cache is not invalidated if a surrogate file is overwritten while the
# kernel is running; re-run this cell to force a reload.
_surrogate_cache = {}


def _load_surrogate(prop):
    """Return the surrogate model for ``prop``, loading it from disk on first use."""
    if prop not in _surrogate_cache:
        _surrogate_cache[prop] = load(f'saved_surrogates/surrogate_{prop}.joblib')
    return _surrogate_cache[prop]


def generate_alloy(prop_arr):
    """Generate one alloy composition conditioned on the requested property values.

    Parameters
    ----------
    prop_arr : array-like
        Target property values, presumably one per entry of ``selected_props`` and
        in that order (they are reshaped into a single row and passed to the
        pipeline's 'Scaler' step) -- TODO confirm against the scaler's fit.

    Returns
    -------
    dict
        ``'Composition'``: reduced formula of the decoded composition;
        ``'composition_arr'``: the generator's raw output for this sample;
        plus one entry per name in ``selected_props``: ``delta_S`` is computed with
        ``calculate_entropy_mixing`` from the decoded composition, every other
        property is predicted by its surrogate model from the raw generator output.
    """
    input_data = np.array(prop_arr).reshape(1, -1).astype('float32')
    scaled_input = torch.from_numpy(scaler_pipe['Scaler'].transform(input_data))
    noise = torch.from_numpy(noise_sampler(1, latent_dim))
    with torch.no_grad():
        generated = model(noise, scaled_input).numpy()
    composition = decode(generated[0], elem_list)
    dict_generated = {'Composition': composition.reduced_formula,
                      'composition_arr': generated[0]}
    for prop in selected_props:
        if prop == 'delta_S':
            dict_generated[prop] = calculate_entropy_mixing(composition).item()
        else:
            dict_generated[prop] = _load_surrogate(prop).predict(generated).item()
    return dict_generated


In [8]:
# Demo dataset: supplies the column layout of the results table and the slider ranges.
demo_df = pd.read_csv('dataset/demo_dataset_true.csv', index_col=0)
# Results table (initially empty) that the widget appends generated alloys to.
df = pd.DataFrame(columns=demo_df.columns)


def _slider_config(prop_name):
    """FloatSlider keyword arguments spanning the demo data's range for one property.

    The bounds and the initial value (midpoint of the unrounded range) are rounded
    to 2 decimals; the slider label is the upper-cased property name.
    """
    column = demo_df.loc[:, prop_name].values
    lowest = np.min(column)
    highest = np.max(column)
    return {
        'min': np.round(lowest, 2),
        'max': np.round(highest, 2),
        'value': np.round((lowest + highest) / 2, 2),
        'step': 0.01,
        'description': prop_name.upper(),
    }


# Slider settings keyed by lower-cased property name, in ``selected_props`` order.
dict_of_props = {prop_name.lower(): _slider_config(prop_name) for prop_name in selected_props}
class record():
    """Interactive widget for generating HEA compositions from target property values.

    Shows one FloatSlider per entry of ``dict_of_props`` (the first three in the
    left column, the rest in the right), a *Reset* and a *Generate* button, and two
    radar (Scatterpolar) plots: generated compositions, and their predicted
    properties scaled with ``scaler_minmax``. Each generated alloy is appended as a
    row to the module-level ``df``; each alloy contributes exactly one trace per plot.
    """

    def __init__(self):
        # ``display`` is injected by IPython in notebooks; importing it makes the
        # dependency explicit.
        from IPython.display import display

        self.objects = [FloatSlider(**params) for params in dict_of_props.values()]
        self.reset_button = Button(description='Reset', disabled=False, button_style='',
                                   tooltip='Reset the slider values')
        self.reset_button.on_click(self.on_reset_button_clicked)
        self.generate_button = Button(description='Generate', disabled=False, button_style='',
                                      tooltip='Generate HEA composition with selected conditions')
        self.generate_button.on_click(self.on_generate_button_clicked)
        self.output = Output()
        self.prop_output = Output()
        controls = HBox([VBox([*self.objects[:3]]), VBox([*self.objects[3:]])])
        buttons = HBox([self.reset_button, self.generate_button])
        outputs = HBox([self.output, self.prop_output])
        display(VBox([controls, buttons, outputs]))
        # Figures persist across clicks; titles are set once (reset only clears traces).
        self.fig = go.Figure()
        self.fig.update_layout(title=self._centered_title("Composition"))
        self.prop_fig = go.Figure()
        self.prop_fig.update_layout(title=self._centered_title("Property (Scaled)"))

    @staticmethod
    def _centered_title(text):
        """Plotly layout title dict placing ``text`` centred near the top of a figure."""
        return {'text': text, 'y': 0.9, 'x': 0.5, 'xanchor': 'center', 'yanchor': 'top'}

    @staticmethod
    def _hide_duplicate_legend_entries(fig):
        """Keep only the first legend entry per trace name (repeated formulas share one)."""
        seen = set()
        fig.for_each_trace(
            lambda trace: trace.update(showlegend=False) if trace.name in seen else seen.add(trace.name))

    def on_generate_button_clicked(self, b):
        """Generate an alloy for the current slider values, record it and update both plots."""
        arr = np.array([slider.value for slider in self.objects], dtype=float)
        data = generate_alloy(arr)
        df.loc[len(df)] = data

        # Composition radar: add a trace for the new alloy only, plotted with its own
        # composition; earlier alloys' traces are already in the figure.
        with self.output:
            clear_output(True)
            self.fig.add_trace(go.Scatterpolar(
                r=data['composition_arr'],
                theta=elem_list,
                fill='toself',
                name=data['Composition'],
                showlegend=True,
            ))
            self._hide_duplicate_legend_entries(self.fig)
            self.fig.show()

        # Property radar: scale just the newly appended row. Assumes scaler_minmax is a
        # per-feature (row-independent) scaler such as sklearn's MinMaxScaler, so this
        # matches scaling the whole table.
        with self.prop_output:
            clear_output(True)
            new_row = df.iloc[[-1], 1:].values.reshape(-1, len(selected_props))
            vals_scaled = scaler_minmax.transform(new_row)
            self.prop_fig.add_trace(go.Scatterpolar(
                r=vals_scaled[0],
                theta=selected_props,
                fill='toself',
                name=data['Composition'],
                showlegend=True,
            ))
            self._hide_duplicate_legend_entries(self.prop_fig)
            self.prop_fig.show()

    def on_reset_button_clicked(self, b):
        """Restore slider defaults, empty the results table and clear both plots."""
        # self.objects was built from dict_of_props.values() in order, so pair by position.
        for slider, params in zip(self.objects, dict_of_props.values()):
            slider.value = params['value']
        # ``df`` is a module-level table shared with the generate handler, so it is
        # emptied in place rather than rebound.
        df.drop(df.index, inplace=True)
        with self.output:
            clear_output()
        with self.prop_output:
            clear_output()
        self.fig.data = []
        self.prop_fig.data = []

In [9]:
# Build and display the widget. Binding it to a name keeps a handle to the instance
# and stops the cell from echoing the object's repr as output.
recorder = record()

VBox(children=(HBox(children=(VBox(children=(FloatSlider(value=-1.03, description='DELTA_S', max=-0.28, min=-1…

<__main__.record at 0x7fe0c2a20640>