In [None]:
%matplotlib widget
%matplotlib widget
import matplotlib.pyplot as plt
import numpy as np
import os
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
import pandas as pd
from ipywidgets import fixed, Layout, Button, Box
import io
import random
import functools
from susc import visualization
import susc.widgets
import warnings
# Silence every RuntimeWarning for the rest of the session
warnings.filterwarnings("ignore", category=RuntimeWarning)
# Disable pandas' SettingWithCopyWarning (chained-assignment check)
pd.options.mode.chained_assignment = None
# Apply the dashboard's matplotlib style sheet from the working directory
path = os.path.join(os.getcwd(), "dashstyle.mplstyle")
plt.style.use([path])
# Browser-tab icon. A bare HTML(...) expression is only rendered when it is the
# last expression of a cell, so it has to be passed to display() explicitly.
# NOTE(review): href is a server filesystem path; whether the front end can
# fetch it (and honors a <link rel="icon"> in an output area) should be confirmed.
href = os.path.join(os.getcwd(), 'figs', 'favicon.ico')
favicon_link = f'<link rel="icon" type="image/x-icon" href="{href}">'
display(HTML(favicon_link))
###################
#Preliminary setup#
###################
# Directory where downloaded figures are written. When pathfile.txt exists, its
# last line supplies the directory; if the file is missing, unreadable or empty,
# fall back to the current directory. Initializing first guarantees that
# path_bash is always defined, even for an empty pathfile.txt.
path_bash = '.'
try:
    with open('pathfile.txt', 'r') as p:
        for line in p:
            # Text-mode lines carry at most one trailing newline.
            path_bash = line.rstrip('\n')
except (OSError, UnicodeDecodeError):
    path_bash = '.'


def download_button():
    """Create the 'Download' button used to save the current figure.

    Returns
    -------
    widgets.Button
        An enabled, info-styled button showing a download icon. No click
        handler is attached here.
    """
    return widgets.Button(
        description='Download',
        tooltip='Download the current plot as a high res png file',
        button_style='info',  # one of 'success', 'info', 'warning', 'danger', ''
        icon='download',  # FontAwesome name without the `fa-` prefix
        disabled=False,
    )

### Widgets used by the interface ###
# File picker for the ensemble CSV files; several files may be chosen at once.
dropdown = widgets.FileUpload(
    tooltip='Upload your ensemble files here',
    multiple=True,
    accept='.csv',
)

def display_side_by_side(combined):
    """Render several styled tables next to each other in one HTML output.

    Displaying them inline rather than stacked saves vertical space.

    Parameters
    ----------
    combined : dict
        Mapping of key -> pandas Styler. Keys are ignored. Tables are emitted in
        insertion order, each with an 18px caption font and followed by three
        non-breaking spaces as a separator.
    """
    html_parts = (
        styler.set_table_styles(
            [{'selector': 'caption', 'props': [('font-size', '18px')]}]
        )._repr_html_() + "\xa0\xa0\xa0"
        for styler in combined.values()
    )
    display(HTML("".join(html_parts)))

In [None]:
#####################################
# widgets are included through two functions:
# XXX_widget : receives the values from the interface's buttons and processes them into a form suitable for analysis
# XXX_func   : carries out the analysis and visualization
#####################################

def on_button_clicked(b, fig=None, name=None):
    """Save *fig* as a 600-dpi image inside ``path_bash``.

    Intended as a Button ``on_click`` callback with *fig* and *name* bound via
    ``functools.partial``; *b* (the clicked button) is not used. The final file
    name is produced by ``visualization.naming`` from *name* and ``path_bash``.
    """
    target = os.path.join(path_bash, visualization.naming(name, path_bash))
    fig.savefig(target, dpi=600)



def corr_plot(**kwargs):
    """Fit and plot the susceptibility model for every selected molecule.

    Invoked by ``widgets.interactive_output`` whenever a solvent selection
    changes. Keyword arguments (assembled by ``corr_widget``):

    ax : pair of matplotlib Axes
        ``ax[0]`` receives data and fitted curves, ``ax[1]`` the residuals.
    datas : list of pandas.DataFrame
        Each frame has 'Solvent', 'epsilon' and 'nr' columns plus one emission
        column per molecule.
    solvents_<molecule> : iterable of str
        Solvent names selected for that molecule.

    Molecules with fewer than 3 selected solvents are skipped. When the
    selection used for a dataset contains rows without an epsilon value, the
    dielectric constant of every such row is inferred from each molecule's fit
    and shown in an extra table. Both axes are cleared and redrawn, and the
    tables are displayed as HTML.
    """
    non_molecule = ('Solvent', 'epsilon', 'nr', 'solvent')
    ax = kwargs['ax']
    for a in ax:
        # The return value was never used; the call is kept in case it
        # configures the axes.
        visualization.set_fontsize(a)
        a.clear()
    # 'stats' is inserted first so it is displayed before any inference table
    # (dicts preserve insertion order); its value is filled in at the end.
    combined = {'stats': None}
    stats_rows, chi_values = [], []
    for data in kwargs['datas']:
        molecules = [col for col in data.columns if col not in non_molecule]
        # Reset per dataset: a dataset without molecule columns must neither
        # raise NameError nor reuse the previous dataset's epsilons/fits.
        fits = {}
        epsilons = np.array([])
        for molecule in molecules:
            data_mol = data[data['Solvent'].isin(kwargs[f'solvents_{molecule}'])]
            epsilons = data_mol['epsilon'].to_numpy()
            if len(epsilons) < 3:
                continue
            nr = data_mol['nr'].to_numpy()
            # NOTE(review): selected rows with NaN epsilon are passed to the fit;
            # presumably visualization.linear_fit ignores them - confirm.
            alphas_st = (epsilons - 1)/(epsilons + 1)
            alphas_opt = (nr**2 - 1)/(nr**2 + 1)
            emission = data_mol[molecule].to_numpy()
            opt, cov = visualization.linear_fit((alphas_st, alphas_opt), emission)
            chi, e_vac = opt

            # Keep the covariance labelled by parameter name.
            cov = pd.DataFrame(cov, columns=['chi', 'e_vac'], index=['chi', 'e_vac'])
            fits[molecule] = [opt, cov]
            error = np.sqrt(np.diag(cov))
            function = visualization.model((alphas_st, alphas_opt), chi, e_vac)
            x = 2 * alphas_st - alphas_opt
            ax[0].plot(x, function, label=molecule)
            ax[0].scatter(x, emission)
            ax[1].scatter(x, emission - function, marker='x', s=100, linewidth=2)

            chi_text = visualization.format_number(chi, error[0], '')
            e_vac_text = visualization.format_number(e_vac, error[1], '')
            stats_rows.append([molecule, e_vac_text, chi_text])
            # Numeric chi is kept separately for sorting the table.
            chi_values.append(chi)
        # Inference runs when the last molecule's selection includes rows
        # whose epsilon is unknown (NaN).
        if np.isnan(epsilons).any():
            inference = data[data['epsilon'].isna()]
            for molecule in molecules:
                if molecule not in fits:
                    # Fewer than 3 solvents selected: no fit to invert.
                    continue
                infer_rows = []
                for film in inference['Solvent'].unique():
                    film_rows = inference[inference['Solvent'] == film]
                    emi = film_rows[molecule].to_numpy()
                    nrs = film_rows['nr'].to_numpy()
                    # Median and lower/upper bounds of the inferred epsilon.
                    median, lower, upper = visualization.get_dielectric(emi, fits[molecule], nr=nrs)
                    infer_rows.append([film, emi[0], f'{1240/emi[0]:.0f}', median, f'[{lower:.2f} , {upper:.2f}]'])
                infer = pd.DataFrame(infer_rows, columns=['Film', 'Emission (eV)', 'Emission (nm)', '\u03B5', 'Interval'])
                # Lowest inferred epsilon first.
                infer = infer.sort_values(by='\u03B5', ascending=True)
                combined[f'infer_{molecule}'] = (
                    infer.style.hide(axis='index')
                    .set_table_attributes("style='display:inline'")
                    .set_caption(f'{molecule}')
                    .format(lambda x: "{:.2f}".format(x) if isinstance(x, float) else x)
                )

    stats = pd.DataFrame(stats_rows, columns=['Molecule', '&lt;E_vac&gt; (eV)', '<\u03C7> (eV)'])
    # Sort by the numeric chi: the table column holds formatted strings, whose
    # lexicographic order is wrong for negative numbers and differing widths.
    stats = stats.iloc[np.argsort(np.asarray(chi_values, dtype=float), kind='stable')]
    combined['stats'] = stats.style.hide(axis='index').set_table_attributes("style='display:inline'").set_caption('Properties')
    ax[0].set_xlabel(r'$2\alpha_{st} - \alpha_{opt}$')
    ax[0].set_ylabel('Energy (eV)')
    ax[0].legend(loc='best')
    ax[1].set_xlabel(r'$2\alpha_{st} - \alpha_{opt}$')
    ax[1].set_ylabel('Residuals (eV)')
    # Panel labels, left-aligned.
    ax[0].set_title('a)', loc='left')
    ax[1].set_title('b)', loc='left')
    display_side_by_side(combined)
    
    
def corr_widget(datas):
    """Build the per-molecule solvent selectors and the live susceptibility plot.

    For each molecule column of every DataFrame in *datas*, a SelectMultiple
    (all solvents pre-selected) is created and placed in its own tab, next to
    a Download button. ``corr_plot`` is wired to the selectors through
    ``widgets.interactive_output`` so the figure and tables refresh whenever a
    selection changes. The Download button saves the figure as
    'susceptibilities.png' via ``on_button_clicked``.
    """
    excluded = ('Solvent', 'epsilon', 'nr', 'solvent')
    selectors = []
    plot_kwargs = {}
    all_molecules = []
    for data in datas:
        # Every column that is not solvent metadata is a molecule.
        molecules = [col for col in data.columns if col not in excluded]
        all_molecules.extend(molecules)
        solvents = data['Solvent'].to_numpy()
        for molecule in molecules:
            selector = widgets.SelectMultiple(
                options=solvents,
                value=list(solvents),
                description='Solvents',
                tooltip='Select states to plot susceptibility for',
                disabled=False,
            )
            selectors.append(selector)
            plot_kwargs[f'solvents_{molecule}'] = selector

    tab = widgets.Tab()
    tab.children = selectors
    # Exactly one selector per molecule, so tab index and molecule align.
    for index, title in enumerate(all_molecules):
        tab.set_title(index, title)

    dump = download_button()
    display(widgets.HBox([tab, dump]))

    fig, ax = plt.subplots(1, 2, figsize=(11, 4))
    plot_kwargs.update(
        molecules=fixed(all_molecules),
        datas=fixed(datas),
        ax=fixed(ax),
    )
    output = widgets.interactive_output(corr_plot, plot_kwargs)
    fig.show()
    display(output)
    dump.on_click(functools.partial(on_button_clicked, fig=fig, name='susceptibilities.png'))




#core function
def main(file_name):
    file_name = {file_name[i]['name']:file_name[i] for i in range(len(file_name))}
    input_list = list(file_name.keys())
    if len(file_name) > 0:
        datas = []
        for file in input_list:
            data = file_name[file]['content']
            data = io.BytesIO(data)
            data = pd.read_csv(data)
            # all numerical values above 100 are considered have to be converted to eV
            for col in data.columns:
                if col != 'Solvent' and col != 'epsilon' and col != 'nr':
                    data[col] = data[col].apply(lambda x: 1240/float(x) if float(x) > 100 else x)
            data.name = file.split('.')[0]
            datas.append(data)
        
        w_corr = widgets.interactive(corr_widget,datas=fixed(datas))    
        
        accordion = widgets.Accordion(children=[w_corr], selected_index=0)
        accordion.set_title(0, 'SUSCEPTIBILITY')
        display(accordion)
        
    else:
        pass 

###################################################
# Start the dashboard: wire main() to the uploader
###################################################
# interactive() calls main() again whenever the FileUpload value changes.
i = widgets.interactive(main, file_name=dropdown)
# Show the first two children of the interactive widget - presumably the
# upload control followed by main()'s output area - stacked vertically.
v = widgets.VBox(list(i.children[:2]))
display(v)