In [None]:
%matplotlib widget
%matplotlib widget
import matplotlib.pyplot as plt
import numpy as np
import os
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
import pandas as pd
from ipywidgets import fixed, Layout, Button, Box
import io
import random
import functools
from susc import visualization
import susc.widgets
import warnings
# Suppress all RuntimeWarnings globally for the rest of the session.
warnings.filterwarnings("ignore", category=RuntimeWarning)
# Disable pandas' SettingWithCopyWarning for chained assignments.
pd.options.mode.chained_assignment = None 
# Apply the dashboard's matplotlib style sheet, resolved against the working directory.
path = os.path.join(os.getcwd(),"dashstyle.mplstyle")
plt.style.use([path])
# Build a <link rel="icon"> tag pointing at figs/favicon.ico under the working directory.
href = os.path.join(os.getcwd(),'figs','favicon.ico')
favicon_link = f'<link rel="icon" type="image/x-icon" href="{href}">'
# NOTE(review): a bare expression is only rendered when it is the last line of a
# notebook cell, so this HTML object is created and immediately discarded. `href`
# is also a server-side filesystem path that a browser may be unable to fetch.
# Confirm whether display(HTML(...)) with a served URL was intended.
HTML(favicon_link)
###################
#Preliminary setup#
###################
# Directory that the Download button saves figures into. It is taken from the
# last line of pathfile.txt when that file exists. If the file is missing,
# unreadable or empty, the working directory '.' is used.
path_bash = '.'
try:
    with open('pathfile.txt', 'r') as path_file:
        for line in path_file:
            # Strip the line ending, including the '\r' left by Windows (CRLF) files.
            path_bash = line.rstrip('\r\n')
except (OSError, UnicodeDecodeError):
    path_bash = '.'


def download_button():
    """Create the 'Download' button used to save the current figure.

    Returns:
        An enabled ``widgets.Button`` with the 'info' style and a download
        icon. The caller attaches the click handler.
    """
    options = dict(
        description='Download',
        disabled=False,
        button_style='info',  # one of 'success', 'info', 'warning', 'danger' or ''
        tooltip='Download the current plot as a high res png file',
        icon='download',  # FontAwesome icon name without the `fa-` prefix
    )
    return widgets.Button(**options)

### Buttons to be used ###
# Upload control for the ensemble CSV files. Despite its name this is a
# FileUpload widget, not a dropdown. The name is kept because the widget
# wiring at the bottom of the notebook refers to it.
dropdown = widgets.FileUpload(
    tooltip='Upload your ensemble files here',
    multiple=True,  # several files may be uploaded at once
    accept='.csv',  # only CSV files are offered in the file picker
)

def display_side_by_side(combined):
    """Render several pandas Stylers on one line to save vertical space.

    Args:
        combined: mapping of label -> pandas ``Styler``. Only the values are
            rendered, in insertion order. The keys themselves are not shown.
            Each Styler's table styles are replaced in place with an 18px
            caption font.
    """
    caption_style = [{'selector': 'caption', 'props': [('font-size', '18px')]}]
    spacer = "\xa0\xa0\xa0"  # three non-breaking spaces after every table
    html = "".join(
        styler.set_table_styles(caption_style)._repr_html_() + spacer
        for styler in combined.values()
    )
    display(HTML(html))

In [None]:
#####################################
# Each widget is built from a pair of functions:
# XXX_widget : receives the values from the interface's controls and prepares them for analysis
# XXX_plot   : carries out the analysis and visualization
#####################################

def on_button_clicked(b, fig=None, name=None):
    """Click handler for the Download button: save ``fig`` as a 600 dpi image.

    Args:
        b: the clicked button, supplied by ipywidgets (unused).
        fig: matplotlib Figure to save. Bound with ``functools.partial``.
        name: requested file name. It is passed to ``visualization.naming``
            together with ``path_bash`` to obtain the final name, which is
            then written inside ``path_bash``.
    """
    target = os.path.join(path_bash, visualization.naming(name, path_bash))
    fig.savefig(target, dpi=600)

def coupling_plot(**kwargs):
    """Estimate each molecule's coupling and display it with an aggregation label.

    Keyword arguments:
        eps:   dielectric constant, converted to alpha = (eps - 1) / (eps + 1).
        nr:    refractive index; looked up but not used in the calculation.
        emi:   measured emission energy.
        stats: DataFrame with 'Molecule', 'Vacuum' and 'Susc' columns.

    Displays a DataFrame with 'Molecule', 'Coupling' and 'Aggregation'
    columns. Aggregation is 'H' when Coupling > 0, otherwise 'J' (including 0).

    NOTE(review): the stats table built in corr_plot has 'E_vac (eV)' and
    'χ (eV)' string columns, not 'Vacuum'/'Susc'. Confirm the expected
    input before wiring this function to that table.
    """
    eps = kwargs['eps']
    polarity = (eps - 1) / (eps + 1)
    _ = kwargs['nr']  # required key, currently unused
    emission = kwargs['emi']
    table = kwargs['stats']
    e_vac = table['Vacuum'].to_numpy()
    chi = table['Susc'].to_numpy()
    # Half the difference between the observed emission and the
    # polarity-corrected energy E_vac - chi * alpha.
    coupling = 0.5 * (emission - (e_vac - chi * polarity))
    result = pd.DataFrame({'Molecule': table['Molecule'], 'Coupling': coupling})
    result['Aggregation'] = ['H' if value > 0 else 'J' for value in result['Coupling']]
    display(result)


def _inference_table(films, molecule, fit):
    """Return a styled table of dielectric constants inferred for one molecule.

    Args:
        films: rows whose 'epsilon' is NaN (unknown polarity), with a
            'Solvent' column naming the film and one column per molecule.
        molecule: molecule column to read the emission energies from.
        fit: ``[opt, cov]`` from ``visualization.linear_fit`` for this molecule.
    """
    rows = []
    for film in films['Solvent'].unique():
        emi = films.loc[films['Solvent'] == film, molecule].to_numpy()
        median, lower, upper = visualization.get_dielectric(emi, fit)
        # 1240 / E converts an energy in eV to a wavelength in nm.
        rows.append([film, emi[0], f'{1240/emi[0]:.0f}', median, f'[{lower:.2f} , {upper:.2f}]'])
    table = pd.DataFrame(rows, columns=['Film', 'Emission (eV)', 'Emission (nm)', '\u03B5', 'Interval'])
    table = table.sort_values(by='\u03B5', ascending=True)
    return (
        table.style.hide(axis='index')
        .set_table_attributes("style='display:inline'")
        .set_caption(f'{molecule}')
        .format(lambda x: "{:.2f}".format(x) if isinstance(x, float) else x)
    )


def _draw_slope_heatmap(axis, samples, labels, maxsusc):
    """Draw one column of sorted slope samples per molecule on ``axis``.

    Nothing is drawn when ``samples`` is empty (no molecule selected).
    """
    if samples:
        heatmap = axis.imshow(np.vstack(samples).T, cmap='coolwarm')
        heatmap.set_interpolation('none')
        # Colour scale spans [0, maxsusc], matching the colour bar.
        heatmap.set_clim(0, maxsusc)
    # Stretch the image to fill the axes instead of using square pixels.
    axis.set_aspect('auto')
    # One column per molecule: label the x axis with the molecule names and
    # hide the y ticks (rows are only the rank of each sample).
    axis.set_xticks(np.arange(len(labels)))
    axis.set_xticklabels(labels)
    axis.set_yticks([])


def _draw_colour_bar(bar_axis, twin_axis, maxsusc):
    """Draw the vertical colour bar from maxsusc (top) to 0 (bottom), marking 'LE' on the twin axis."""
    bar_axis.imshow(np.linspace(maxsusc, 0.0, 100)[:, np.newaxis], cmap='coolwarm')
    bar_axis.set_aspect('auto')
    # With imshow's default origin='upper', row 0 (value maxsusc) is at the top
    # and row 99 (value 0) is at the bottom.
    bar_axis.set_yticks([0, 99])
    bar_axis.set_yticklabels([f"{maxsusc:.1f}", '0'])
    twin_axis.set_ylim(bar_axis.get_ylim())
    twin_axis.set_yticks([99])
    twin_axis.set_yticklabels(['LE'])
    twin_axis.set_ylabel('CT character', labelpad=-20)
    twin_axis.yaxis.set_label_position('right')
    bar_axis.set_xticks([])
    bar_axis.grid(False)


def corr_plot(**kwargs):
    """Fit emission energy against solvent polarity and plot the results.

    For each dataset and each selected molecule, ``visualization.linear_fit``
    is applied to the emission energies versus alpha = (eps - 1)/(eps + 1)
    over the selected solvents. The fitted line drawn is b - m*alpha.
    Rows whose 'epsilon' is NaN are treated as films, and their dielectric
    constant is inferred from the fit with ``visualization.get_dielectric``.

    Keyword arguments:
        ax:       three Axes: [0] fits, [1] slope-sample heatmap, [2] colour bar.
        ax2:      twin y axis of ``ax[2]``.
        datas:    DataFrames with 'Solvent' and 'epsilon' columns plus one
                  emission column per molecule.
        select:   molecule column names to analyse.
        solvents: solvent names to keep.
        maxsusc:  upper limit of the heatmap colour scale.

    Side effects: clears and redraws all axes, then displays the summary
    tables side by side.
    """
    ax = kwargs['ax']
    ax2 = kwargs['ax2']
    ax2.clear()
    for a in ax:
        # Return value unused; the call is kept in case the helper configures the axis.
        visualization.set_fontsize(a)
        a.clear()

    # 'stats' is inserted first so the properties table is displayed first.
    combined = {'stats': None}
    fits = {}
    stats_rows = []  # [molecule, "E_vac ± err", "chi ± err", numeric chi for sorting]
    slope_samples = []  # one descending-sorted array of sampled slopes per fit
    labels = []

    molecules = list(kwargs['select'])
    for data in kwargs['datas']:
        data = data[data['Solvent'].isin(kwargs['solvents'])]
        epsilons = data['epsilon'].to_numpy()
        alphas = (epsilons - 1) / (epsilons + 1)
        for molecule in molecules:
            emission = data[molecule].to_numpy()
            opt, cov = visualization.linear_fit(alphas, emission)
            m, b = opt
            fits[molecule] = [opt, cov]
            error = np.sqrt(np.diag(cov))
            ax[0].plot(alphas, -1 * m * alphas + b, label=molecule)
            ax[0].scatter(alphas, emission)
            # Draw (m, b) from the fit covariance and keep the slopes, sorted high -> low.
            draws = np.random.multivariate_normal([m, b], cov, size=1000)[:, 0]
            slope_samples.append(-np.sort(-draws))
            stats_rows.append([molecule, f'{b:.3f} ± {error[1]:.3f}', f'{m:.3f} ± {error[0]:.3f}', m])
            labels.append(molecule)
        # Rows without an epsilon are films whose dielectric constant is inferred.
        if np.isnan(epsilons).any():
            films = data[data['epsilon'].isna()]
            for molecule in molecules:
                combined[f'infer_{molecule}'] = _inference_table(films, molecule, fits[molecule])

    stats = pd.DataFrame(stats_rows, columns=['Molecule', 'E_vac (eV)', '\u03C7 (eV)', '_chi'])
    # Sort on the numeric slope. Sorting the formatted "x ± y" strings would be
    # lexicographic, e.g. '10.000' before '2.000' and '-0.100' before '-0.500'.
    stats = stats.sort_values(by='_chi', ascending=True).drop(columns='_chi')
    combined['stats'] = (
        stats.style.hide(axis='index')
        .set_table_attributes("style='display:inline'")
        .set_caption('Properties')
    )

    _draw_slope_heatmap(ax[1], slope_samples, labels, kwargs['maxsusc'])
    _draw_colour_bar(ax[2], ax2, kwargs['maxsusc'])

    ax[0].set_xlabel(r'$\alpha$')
    ax[0].set_ylabel('Energy (eV)')
    ax[1].set_ylabel('Electronic Character')
    ax[0].legend(loc='best')
    # Panel letters aligned to the left.
    ax[0].set_title('a)', loc='left')
    ax[1].set_title('b)', loc='left')
    display_side_by_side(combined)
    # TODO: coupling_plot (emission / eps / nr widgets) is not wired to this view.
    # It expects 'Vacuum'/'Susc' columns rather than the formatted strings above.
    
    
def corr_widget(datas):
    """Build the controls and figure for the susceptibility analysis.

    For each dataset, two SelectMultiple widgets (molecules, solvents) are
    added to a Tab, titled with the dataset's ``name`` and 'Solvents'. The
    tab, a maximum-susceptibility control and a Download button are
    displayed, then a 1x3 figure that ``corr_plot`` redraws whenever a
    control changes.

    Args:
        datas: DataFrames with a ``name`` attribute, a 'Solvent' column, an
            'epsilon' column and one emission column per molecule.
    """
    tab_children, tab_titles = [], []
    kw = {}
    for data in datas:
        # Every column except 'Solvent' and 'epsilon' holds a molecule's emission energies.
        molecules = [col for col in data.columns if col not in ('Solvent', 'epsilon')]
        select = widgets.SelectMultiple(
            options=molecules,
            value=molecules[:1],  # preselect the first molecule, if there is one
            description='Molecules',
            tooltip='Select states to plot susceptibility for',
            disabled=False,
        )
        solvents = data['Solvent'].to_numpy()
        select2 = widgets.SelectMultiple(
            options=solvents,
            value=list(solvents),
            description='Solvents',
            tooltip='Select solvents to include in the fit',
            disabled=False,
        )
        tab_children += [select, select2]
        tab_titles += [data.name, 'Solvents']
        # NOTE(review): kw is overwritten on each iteration, so only the last
        # dataset's selectors drive corr_plot. Confirm the intent for multiple uploads.
        kw['select'] = select
        kw['solvents'] = select2
    tab = widgets.Tab()
    tab.children = tab_children
    # Title each child by position, so the names stay aligned with the
    # (molecules, solvents) pairs when more than one dataset is uploaded.
    for index, title in enumerate(tab_titles):
        tab.set_title(index, title)
    maxsusc, kw = susc.widgets.maxsusc(kw)
    dump = download_button()
    box = widgets.HBox([tab, maxsusc, dump])
    display(box)
    fig, ax = plt.subplots(1, 3, figsize=(11, 4), gridspec_kw={'width_ratios': [1, 1, 0.05]})
    ax2 = ax[2].twinx()
    kw['ax2'] = fixed(ax2)
    kw['datas'] = fixed(datas)
    kw['ax'] = fixed(ax)
    wid = widgets.interactive_output(corr_plot, kw)
    fig.show()
    display(wid)
    dump.on_click(functools.partial(on_button_clicked, fig=fig, name='susceptibilities.png'))




#core function
def main(file_name):
    """Parse the uploaded CSV files and display the susceptibility dashboard.

    Args:
        file_name: value of the FileUpload widget, i.e. a sequence of dicts
            with at least 'name' and 'content' keys (presumably the
            ipywidgets >= 8 format; confirm). Files sharing a name are
            de-duplicated, and the last one wins.

    Each CSV must have 'Solvent' and 'epsilon' columns. Every other column
    holds emission values, and values above 100 are converted with 1240/x
    (nm -> eV). Nothing is displayed when no file has been uploaded.
    """
    uploads = {upload['name']: upload for upload in file_name}
    if not uploads:
        return
    datas = []
    for fname, upload in uploads.items():
        data = pd.read_csv(io.BytesIO(upload['content']))
        # All numerical values above 100 are treated as wavelengths and converted to eV.
        for col in data.columns:
            if col not in ('Solvent', 'epsilon'):
                data[col] = data[col].apply(lambda x: 1240/float(x) if float(x) > 100 else x)
        # Used as the tab title. splitext keeps dots inside the stem ("a.b.csv" -> "a.b").
        data.name = os.path.splitext(fname)[0]
        datas.append(data)

    w_corr = widgets.interactive(corr_widget, datas=fixed(datas))
    accordion = widgets.Accordion(children=[w_corr], selected_index=0)
    accordion.set_title(0, 'SUSCEPTIBILITY')
    display(accordion)

###################################################
#Initializing main function and displaying widgets
###################################################
# Re-run `main` whenever the upload widget's value changes.
i = widgets.interactive(main, file_name=dropdown)
# Show the first two children of the interactive container (presumably the
# upload control and its output area; confirm against the ipywidgets version).
v = widgets.VBox([i.children[0], i.children[1]])
display(v)