In [None]:
%matplotlib widget
import matplotlib.pyplot as plt
import numpy as np
import os
import ipywidgets as widgets
from IPython.display import display, clear_output
import pandas as pd
from ipywidgets import fixed, Layout, Button, Box
import io
import visualization
import nemo.analysis
import nemo.tools
# Dashboard matplotlib style sheet, looked up in the current working directory.
path = os.path.join(os.getcwd(), "dashstyle.mplstyle")
plt.style.use([path])

###################
# Preliminary setup
###################


### Widgets used by the dashboard ###
# File manager: upload control that accepts one or more `.lx` files.
dropdown = widgets.FileUpload(
    multiple=True,
    accept='.lx',
)

In [None]:
#####################################
# Each widget panel is built from two functions:
# XXX_widget : builds the interface controls and passes their values, prepared for analysis, to XXX_plot
# XXX_plot   : carries out the analysis and visualization
#####################################

def diag_plot(**kwargs):
    """Redraw the transition diagrams and display the rate tables.

    Called by ``widgets.interactive_output`` (see ``diag_widget``) whenever
    one of the controls changes.

    Keyword arguments (all supplied by ``diag_widget``):
        eps, nr : solvent dielectric constant and refractive index values.
        cutoff  : probability cutoff as a fraction; table rows whose ``Prob``
                  column (a percentage) exceeds ``100*cutoff`` are shown,
                  as are transitions involving S0.
        files   : DataFrames read from the uploaded files, each carrying a
                  ``.name`` attribute.
        names   : dict mapping ``file.name`` -> molecule label.
        molecs  : ordered molecule labels, one per axis.
        ax      : sequence of matplotlib axes, one per entry of ``molecs``.
        ens     : forwarded to ``nemo.analysis.rates`` as ``ensemble_average``.
        legend  : if true the legend lists the plotted entries, otherwise it
                  only shows the eps/n title.
    """
    dielec = (kwargs['eps'], kwargs['nr'])
    cutoff = kwargs['cutoff']
    files = kwargs['files']
    names = kwargs['names']
    axs = kwargs['ax']
    # '\\epsilon' puts a literal backslash in the string for mathtext without
    # relying on Python's deprecated handling of the unknown escape '\e'.
    title = f'$\\epsilon ={dielec[0]:.3f}$\n$n={dielec[1]:.3f}$'
    table_format = {'Rate':'{:.2e} s-1','Error':'{:.2e} s-1','Prob':'{:.2f}%','AvgDE+L':'{:.3f} eV','AvgSOC':'{:.3f} meV','AvgSigma':'{:.3f} eV','AvgConc':'{:.1f}%'}

    for ax in axs:
        ax.clear()
        ax.axis('off')
        ax.set_xticklabels([])
        # Iterate over a copy: removing artists from the live container while
        # iterating it can skip elements.
        for t in list(ax.texts):
            t.remove()

    for ax, molec in zip(axs, kwargs['molecs']):
        ax.set_title(molec)
        try:
            for file in files:
                if names[file.name] != molec:
                    continue
                state = file['ensemble'][0]
                data, _ = nemo.analysis.rates(state, dielec, data=file, ensemble_average=kwargs['ens'])
                # Drop any parenthesised suffix from the column names ('X(...)' -> 'X').
                # Not in place, so the frame returned by nemo is left untouched.
                data = data.rename(columns=lambda c: c.split('(')[0])
                # Keep rows above the cutoff or involving S0, but never rows with a
                # missing probability. (The former `Prob != np.nan` test was always
                # True because NaN never compares equal to anything.)
                keep = (data.Prob > 100*cutoff) | data.Transition.str.contains('S0', regex=False, na=False)
                data_show = data[keep & data.Prob.notna()]
                visualization.plot_transitions(data, ax, cutoff)
                display(widgets.HTML(value = f'<p style="font-size:18px;text-align:left"><b>{names[file.name]} {state} Ensemble</b></p>'))
                display(data_show.style.hide(axis='index').background_gradient().format(table_format))
            visualization.write_energies(ax)
        except Exception as err:
            # Best effort: one failing molecule must not stop the other panels
            # from being drawn, but the failure is reported rather than hidden.
            print(f'Could not draw diagram for {molec}: {err!r}')

    for ax in axs:
        ax.relim()
    # Share one vertical range across all panels so the diagrams are comparable.
    top = max(ax.get_ylim()[1] for ax in axs)
    bot = min(ax.get_ylim()[0] for ax in axs)
    for ax in axs:
        ax.set_ylim([bot, 1.1*top])
        if kwargs['legend']:
            ax.legend(title=title, title_fontsize=12, fontsize=12, loc='best', frameon=False)
        else:
            ax.legend(handles=[], title=title, title_fontsize=12, fontsize=12, loc='best', frameon=False)
    # Clear this call's output as soon as the next redraw produces output.
    clear_output(wait=True)
###################################

def diag_widget(files, names):
    """Build the controls and figure for the DIAGRAM panel.

    Parameters
    ----------
    files : list of DataFrames (each with a ``.name`` attribute); ``body``
        passes only the files whose ensemble is not S0.
    names : dict mapping ``file.name`` -> molecule label for every uploaded
        file.

    Control changes trigger ``diag_plot`` through ``interactive_output``.
    """
    eps, nr = visualization.eps_nr()

    # One panel per molecule that has data in `files`, in first-seen order.
    # `names` also lists S0-only files, which diag_plot never draws, so
    # deriving panels from it produced blank axes.
    molecs = list(dict.fromkeys(names[file.name] for file in files))
    if not molecs:
        display(widgets.HTML(value='<p>No excited-state ensembles were uploaded.</p>'))
        return

    ensemble = widgets.Checkbox(
        value=False,
        description='Ensemble Avg',
        disabled=False,
        indent=False,
    )
    cutoff = widgets.FloatSlider(
        value=0.1,
        min=0,
        max=1,
        step=0.05,
        description='Cutoff:',
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='.2f',
    )
    legend = widgets.Checkbox(
        value=False,
        description='Display rates',
        disabled=False,
        indent=False,
    )

    display(widgets.HBox([widgets.VBox([eps, nr, cutoff]), widgets.VBox([ensemble, legend])]))

    # squeeze=False always returns a 2-D array of axes, so one molecule needs
    # no special case.
    _, axes = plt.subplots(1, len(molecs), squeeze=False)
    kw = {
        'legend': legend,
        'ens': ensemble,
        'ax': fixed(list(axes[0])),
        'eps': eps,
        'nr': nr,
        'cutoff': cutoff,
        'files': fixed(files),
        'names': fixed(names),
        'molecs': fixed(molecs),
    }
    display(widgets.interactive_output(diag_plot, kw))


    

def spec_plot(**kwargs):
    """Plot normalized absorption/emission spectra and tabulate their peaks.

    Called by ``widgets.interactive_output`` from ``spec_widget``. Besides
    ``ax``, ``files``, ``names``, ``eps``, ``nr`` and ``wave`` (plot against
    wavelength instead of energy), ``kwargs`` holds one entry per file, keyed
    by ``file.name``, with the selected spectrum types ('Absorption' and/or
    'Emission').
    """
    # Conversion between photon energy (eV) and wavelength (nm): nm = 1239.8 / eV.
    ev_nm = 1239.8
    ax = kwargs['ax']
    names = kwargs['names']
    wave = kwargs['wave']
    ax.clear()
    dielec = (kwargs['eps'], kwargs['nr'])
    rows = []
    for file in kwargs['files']:
        state = file['ensemble'][0]
        for tipo in kwargs[file.name]:
            if tipo == 'Emission':
                _, emi = nemo.analysis.rates(state, dielec, data=file)
                x = emi['Energy'].values
                y = emi['Diffrate'].values
            elif tipo == 'Absorption':
                abspec = nemo.analysis.absorption(state, dielec, data=file, save=False)
                x = abspec[:, 0]
                y = abspec[:, 1]
            else:
                # Only the two options offered by spec_widget are understood;
                # skip anything else instead of reusing x/y from a previous pass.
                continue
            y = y/max(y)  # normalize to unit peak height
            ax.plot(ev_nm/x if wave else x, y, label=f'{names[file.name]} {tipo[:3]} {state}')
            peak = visualization.get_peak(y, x)
            rows.append([f'{tipo[:3]} {names[file.name]}', state, peak, ev_nm/peak])
    stats = pd.DataFrame(rows, columns=['Spectrum','State','Peak (eV)','Peak (nm)'])
    display(stats.style.hide(axis='index').background_gradient().format({'Peak (eV)':'{:.2f}','Peak (nm)':'{:.0f}'}))
    ax.set_ylim(bottom=0)
    ax.set_ylabel('Normalized Intensity')
    ax.set_xlabel('Wavelength (nm)' if wave else 'Energy (eV)')
    # '\\epsilon' gives mathtext its backslash without the invalid escape '\e'.
    ax.legend(title=f'$\\epsilon ={dielec[0]:.3f}$\n$n={dielec[1]:.3f}$')
    # Clear this call's output as soon as the next redraw produces output.
    clear_output(wait=True)

def spec_widget(files, names):
    """Build the controls and figure for the SPECTRA panel.

    Each file gets its own tab with a multi-select of spectrum types. S0
    ensembles offer only absorption; the others offer absorption and emission,
    with the last option preselected. Control changes trigger ``spec_plot``.
    """
    kw = {}
    selectors = []
    for file in files:
        is_ground = file['ensemble'][0] == 'S0'
        options = ['Absorption'] if is_ground else ['Absorption', 'Emission']
        selector = widgets.SelectMultiple(
            options=options,
            value=[options[-1]],
            description='Spectra',
            disabled=False,
        )
        selectors.append(selector)
        kw[file.name] = selector
    eps, nr = visualization.eps_nr()

    kw['wave'] = widgets.Checkbox(
        value=False,
        description='Wavelength (nm)',
        disabled=False,
        indent=False,
    )

    tab = widgets.Tab()
    tab.children = selectors
    # Tab titles read "<molecule label> <ensemble state>".
    for idx, file in enumerate(files):
        tab.set_title(idx, names[file.name] + ' ' + file['ensemble'][0])
    controls = widgets.VBox([eps, nr, kw['wave']])
    display(widgets.HBox([tab, controls]))

    _, ax = plt.subplots(figsize=(11, 4))
    kw.update(files=fixed(files), eps=eps, nr=nr, ax=fixed(ax), names=fixed(names))
    display(widgets.interactive_output(spec_plot, kw))

def corr_plot(**kwargs):
    """Plot histograms of the per-state correction columns (``d_<state>``).

    Called by ``widgets.interactive_output`` from ``corr_widget``. ``kwargs``
    holds ``ax``, ``files``, ``names``, ``gran`` (log10 of the histogram bin
    width in eV), the ``eps``/``nr`` control values, and one entry per file
    keyed by ``file.name`` with the selected state labels.
    """
    ax = kwargs['ax']
    ax.clear()
    bin_width = 10**kwargs['gran']
    names = kwargs['names']
    # NOTE(review): eps/nr are received but do not affect this plot. A solvent
    # rescaling alpha = (get_alpha(eps) - get_alpha(nr**2)) / get_alpha(nr0**2)
    # was computed here, but its only use had been commented out, so the dead
    # computation was removed.
    for file in kwargs['files']:
        for header in kwargs[file.name]:
            ds = file['d_' + header.lower()].to_numpy()
            hist, bins = visualization.spectrum(ds, bin_width)
            ax.plot(bins, hist, label=f'Ens: {names[file.name]} - {header}')
    ax.set_xlim(left=0)
    ax.set_ylim(bottom=0)
    ax.set_xlabel('State Specific Correction (eV)')
    ax.legend(loc='best')
    # Clear this call's output as soon as the next redraw produces output.
    clear_output(wait=True)

def corr_widget(files, names):
    """Build the controls and figure for the DIABATIZATION panel.

    Parameters
    ----------
    files : list of DataFrames (each with a ``.name`` attribute); one tab of
        state selectors per file.
    names : dict mapping ``file.name`` -> molecule label (used in tab titles).

    Control changes trigger ``corr_plot`` through ``interactive_output``.
    """
    kw = {}
    selectors = []
    for file in files:
        # Correction columns are named 'd_<state>'. Match the prefix only (the
        # old substring test also accepted names like 'xd_s1'), and keep the
        # whole suffix so corr_plot's 'd_' + state.lower() finds the column.
        states = [col[2:].upper() for col in file.columns if col.startswith('d_')]
        selector = widgets.SelectMultiple(
            options=states,
            # Empty selection (instead of IndexError) when there are no d_ columns.
            value=states[:1],
            description='States',
            disabled=False,
        )
        selectors.append(selector)
        kw[file.name] = selector
    gran_slider = widgets.FloatSlider(
        value=-2,
        max=-1,
        min=-3,
        step=1,
        description='Bin $10^x$ (eV)',
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='.0f',
    )
    # Start the solvent controls from the first file's eps/nr when there is one.
    if files:
        eps, nr = visualization.eps_nr(eps0=files[0]['eps'][0], nr0=files[0]['nr'][0])
    else:
        eps, nr = visualization.eps_nr()
    kw['eps'] = eps
    kw['nr'] = nr
    kw['gran'] = gran_slider
    tab = widgets.Tab()
    tab.children = selectors
    # Tab titles read "<molecule label> <ensemble state>".
    for idx, file in enumerate(files):
        tab.set_title(idx, names[file.name] + ' ' + file['ensemble'][0])
    display(widgets.HBox([tab, widgets.VBox([gran_slider, eps, nr])]))
    _, ax = plt.subplots()
    kw['files'] = fixed(files)
    kw['ax'] = fixed(ax)
    kw['names'] = fixed(names)
    display(widgets.interactive_output(corr_plot, kw))


def body(**kwargs):
    """Build the analysis accordion once the 'Read File' toggle is on.

    ``kwargs`` holds ``run`` (toggle state) and ``files`` (the DataFrames read
    in ``main``). It also holds, keyed by each ``file.name``, the molecule
    label typed by the user.
    """
    if not kwargs['run']:
        return

    files = kwargs['files']
    names, excited = {}, []
    for file in files:
        names[file.name] = kwargs[file.name]
        # Only non-S0 ensembles feed the transition diagram.
        if file['ensemble'][0] != 'S0':
            excited.append(file)

    panels = [
        ('DIAGRAM', diag_widget, excited),
        ('SPECTRA', spec_widget, files),
        ('DIABATIZATION', corr_widget, files),
    ]
    children = [
        widgets.interactive(builder, files=fixed(group), names=fixed(names))
        for _, builder, group in panels
    ]
    accordion = widgets.Accordion(children=children, selected_index=0)
    for idx, (title, _, _) in enumerate(panels):
        accordion.set_title(idx, title)
    display(accordion)

#core function
def main(file_name):
    """Read the uploaded files and show the molecule-naming tab.

    Parameters
    ----------
    file_name : dict
        Value of the upload widget: maps each uploaded file name to a dict
        whose ``'content'`` entry holds the raw file bytes (CSV text).

    Each file becomes a DataFrame whose ``.name`` is the file name without its
    final extension. A Text widget lets the user edit the molecule label. When
    the 'Read File' toggle is switched on, ``body`` builds the analysis panels.
    """
    if not file_name:
        return

    kw, name_widgets, frames = {}, [], []
    for upload_name, upload in file_name.items():
        # bytes() also accepts a memoryview. 'utf-8-sig' decodes plain UTF-8
        # identically but strips a leading byte-order mark that would otherwise
        # be glued to the first column name.
        text = bytes(upload['content']).decode('utf-8-sig')
        frame = pd.read_csv(io.StringIO(text))
        # Strip only the final extension: split('.')[0] turned 'mol.v1.lx' and
        # 'mol.v2.lx' into the same key 'mol', so one silently replaced the other.
        frame.name = os.path.splitext(upload_name)[0]
        wid = widgets.Text(
            value=frame.name,
            placeholder=upload_name,
            description='Molecule:',
            disabled=False,
            continuous_update=False,
        )
        kw[frame.name] = wid
        name_widgets.append(wid)
        frames.append(frame)

    run_but = widgets.ToggleButton(
        value=False,
        description='Read File',
        disabled=False,
        button_style='success',
        tooltip='Description',
        icon='check',
    )
    grid = widgets.GridBox(name_widgets, layout=widgets.Layout(grid_template_columns="repeat(3, 350px)"))
    tab = widgets.Tab()
    tab.children = (widgets.HBox([grid, run_but]),)
    tab.set_title(0, 'NAMES')
    kw['run'] = run_but

    w_body = widgets.interactive_output(body, {'files': fixed(frames), **kw})
    display(tab, w_body)

###################################################
# Initialize the main function and display the widgets
###################################################
display(widgets.HTML(value=r'<p style="font-size:24px"><b>NEMO VISUALIZATION</b></p>'))
# Re-run `main` whenever the upload widget's value changes.
i = widgets.interactive(main, file_name=dropdown)
# Lay out the first two children of the interactive widget
# (presumably the upload control and main's output area).
v = widgets.VBox(list(i.children[:2]))
display(v)