In [None]:
%matplotlib widget
%matplotlib widget
import matplotlib.pyplot as plt
import numpy as np
import os
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
import pandas as pd
from ipywidgets import fixed, Layout, Button, Box
import io
import nemoview.visualization as visualization
import nemoview.widgets
import nemo.analysis
import nemo.tools
import random
import functools
from nemo import __version__ as nemo_version
from nemoview import __version__ as nemoview_version
# silence pandas' SettingWithCopyWarning globally for this dashboard
pd.options.mode.chained_assignment = None
# matplotlib style sheet shipped next to the notebook
path = os.path.join(os.getcwd(),"dashstyle.mplstyle")
plt.style.use([path])
# Favicon link. An HTML object only renders when displayed; as a bare
# expression in the middle of a cell its value was silently discarded.
href = os.path.join(os.getcwd(),'figs','favicon.ico')
favicon_link = f'<link rel="icon" type="image/x-icon" href="{href}">'
display(HTML(favicon_link))
###################
#Preliminary setup#
###################
# Directory where downloaded figures are written (used by on_button_clicked).
# pathfile.txt may provide it: its last non-blank line wins. Defaults to '.'
# so that a missing, unreadable or empty pathfile still leaves it defined.
path_bash = '.'
try:
    with open('pathfile.txt','r') as p:
        for line in p:
            entry = line.split('\n')[0]
            if entry.strip():
                path_bash = entry
except (OSError, UnicodeDecodeError):
    path_bash = '.'


def download_button():
    """Return a fresh 'info'-styled Download button for saving the current figure."""
    return widgets.Button(
        description='Download',
        tooltip='Download the current plot as a high res png file',
        icon='download',       # FontAwesome name, without the `fa-` prefix
        button_style='info',   # one of 'success', 'info', 'warning', 'danger', ''
        disabled=False,
    )

###Buttons to be used####
#file manager
# Upload widget for the ensemble files: only '.lx' files, several at once.
dropdown = widgets.FileUpload(
    multiple=True,
    accept='.lx',
    tooltip='Upload your ensemble files here',
)

def display_side_by_side(combined):
    """Render several pandas Stylers inline, one after another, to save vertical space.

    combined maps a key (not used for rendering) to a Styler. Each Styler gets
    an 18px caption font, is converted to HTML and followed by three
    non-breaking spaces; the concatenation is shown as a single HTML output.
    """
    gap = "\xa0\xa0\xa0"
    pieces = []
    for styler in combined.values():
        styler.set_table_styles([{'selector': 'caption', 'props': [('font-size', '18px')]}])
        pieces.append(styler._repr_html_() + gap)
    display(HTML("".join(pieces)))

In [None]:
#####################################
# widgets are included through two functions:
# XXX_widget : builds the interface's controls and passes their values on to XXX_plot
# XXX_plot   : carries out the analysis and visualization
#####################################

def on_button_clicked(b, fig=None, name=None):
    """Download-button callback: save fig at 600 dpi inside path_bash.

    b is the clicked button (unused). name is the requested file name; the
    final name is obtained from visualization.naming(name, path_bash).
    """
    target_dir = path_bash
    final_name = visualization.naming(name, target_dir)
    fig.savefig(os.path.join(target_dir, final_name), dpi=600)

def check_dielec(dielec):
    """Validate the (eps, nr) pair used by every plot.

    Returns True, after showing a red HTML warning, when nr**2 > eps (the
    caller should then abort); returns False otherwise.
    """
    invalid = dielec[1]**2 > dielec[0]
    if invalid:
        display(HTML('<div style="color: red; font-size: 20px;">Warning: n<sub>r</sub><sup>2</sup> must be &le; &epsilon;</div>'))
    return bool(invalid)

def diag_plot(**kwargs):
    """Draw one transition diagram per molecule and tabulate its rates.

    Expected kwargs (wired up by diag_widget):
        eps, nr : static dielectric constant and refractive index
        cutoff  : rows with Prob above 100*cutoff (%) are tabulated
        ens     : forwarded to nemo.analysis.rates as ensemble_average
        legend  : when true, share one y-range across panels and draw legends
        ax      : sequence of Axes, one per molecule
        files   : ensemble DataFrames; names maps file.name -> molecule name
        molecs  : ordered unique molecule names
    Returns early (after a warning) when nr**2 > eps. A molecule whose panel
    fails is skipped with a warning so the other panels still render.
    """
    import warnings

    eps = kwargs['eps']
    nr  = kwargs['nr']
    dielec = (eps,nr)
    if check_dielec(dielec):
        return
    cutoff = kwargs['cutoff']
    files = kwargs['files']
    names = kwargs['names']
    axs = kwargs['ax']
    molecs = kwargs['molecs']
    combined = {}
    for ax in axs:
        fontsize = visualization.set_fontsize(ax)
        ax.clear()
        ax.axis('off')
        ax.set_xticklabels([])
        # iterate over a copy: removing artists while iterating the live list skips items
        for t in list(ax.texts):
            t.remove()

    # headers for the rates table, applied positionally (&lt;/&gt; are HTML-escaped)
    new_column_names = ['Transition', 'Rate(1/s)', 'Error(1/s)','Prob(%)','&lt;Gap&gt;(eV)','&lt;SOC&gt;(meV)','&lt;&sigma;&gt;(eV)','&lt;Conc&gt;(%)']
    fmts = ['{:6}','{:.2e}','{:.2e}','{:.2f}','{:.3f}','{:.3f}','{:.3f}','{:.1f}']

    for j, molec in enumerate(molecs):
        axs[j].set_title(molec,loc='left',fontsize=fontsize)
        try:
            for file in files:
                if names[file.name] != molec:
                    continue
                state = file['ensemble'][0]
                data, _ = nemo.analysis.rates(state,dielec,data=file,ensemble_average=kwargs['ens'])
                # drop everything from the first '(' on in each column name
                data.rename(columns=lambda x: x.split('(')[0], inplace=True)
                # Keep transitions above the cutoff, plus transitions involving S0
                # that have a defined probability. (`x != np.nan` is always True,
                # so the original comparison never filtered anything.)
                keep = (data['Prob'] > 100*cutoff) | (
                    data.Transition.str.contains('S0', regex=False) & data['Prob'].notna()
                )
                # explicit copy: the edits below must not touch `data`
                data_show = data[keep].copy()
                visualization.plot_transitions(data,axs[j],cutoff)
                data_show = data_show.rename(columns=dict(zip(data_show.columns, new_column_names)))
                # concentrations above 100 % are displayed as 100 %
                data_show['&lt;Conc&gt;(%)'] = data_show['&lt;Conc&gt;(%)'].clip(upper=100)
                formatter = dict(zip(data_show.columns, fmts))
                key = f'{names[file.name]} {state}'
                combined[key] = (
                    data_show.style.hide(axis='index')
                    .set_table_attributes("style='display:inline'")
                    .set_caption(f'{key} Ensemble')
                    .background_gradient()
                    .format(formatter)
                )
            axs[j].set_ylim(bottom=-0.15)
            visualization.write_energies(axs[j])
        except Exception as err:
            # best effort: one failing molecule must not blank the other panels
            warnings.warn(f'Diagram for {molec!r} could not be drawn: {err}')
    for ax in axs:
        ax.relim()
    top = [ax.get_ylim()[1] for ax in axs]
    bot = [ax.get_ylim()[0] for ax in axs]
    if kwargs['legend']:
        for ax in axs:
            ax.set_ylim([min(bot),1.1*max(top)])
            ax.legend(fontsize=fontsize, loc='upper right',frameon=False,bbox_to_anchor=(1.0,1.3))
    # '\\epsilon' (escaped) avoids the invalid '\e' escape; '\n' is a real newline
    axs[-1].text(1,0, f'$\\epsilon ={dielec[0]:.3f}$\n$n={dielec[1]:.3f}$', transform=axs[-1].transAxes, fontsize=fontsize, verticalalignment='top', horizontalalignment='right')
    display_side_by_side(combined)
    clear_output(wait=True)
###################################

def diag_widget(files,names):
    """Build the DIAGRAM panel: controls, one subplot per molecule and a Download button.

    files : ensembles forwarded (fixed) to diag_plot
    names : mapping file.name -> molecule name chosen by the user
    """
    kw = {}

    # unique molecule names, kept in first-seen order
    molecs = []
    for mol in names.values():
        if mol not in molecs:
            molecs.append(mol)

    eps, nr, kw = nemoview.widgets.eps_nr(kw)
    ensemble, kw = nemoview.widgets.ensemble(kw)
    cutoff, kw = nemoview.widgets.cutoff(kw)
    legend, kw = nemoview.widgets.legend(kw)
    dump = download_button()
    display(widgets.HBox([
        widgets.VBox([eps, nr, cutoff]),
        widgets.VBox([ensemble, legend]),
        dump,
    ]))

    fig, ax = plt.subplots(1, len(molecs), figsize=(11, 4))
    # plt.subplots returns a bare Axes for a single column; diag_plot expects a sequence
    axes = [ax] if len(molecs) == 1 else ax
    kw['ax'] = fixed(axes)
    kw['files'] = fixed(files)
    kw['names'] = fixed(names)
    kw['molecs'] = fixed(molecs)
    output = widgets.interactive_output(diag_plot, kw)
    fig.show()
    display(output)
    dump.on_click(functools.partial(on_button_clicked, fig=fig, name='diagram.png'))
    
def spec_plot(**kwargs):
    """Plot normalized absorption/emission spectra and tabulate derived quantities.

    Expected kwargs (wired up by spec_widget):
        ax, ax2      : main Axes and its twiny() companion
        files, names : ensemble DataFrames; names maps file.name -> molecule name
        <file.name>  : selected spectrum types for that file ('Absorption'/'Emission')
        eps, nr, kappa, maxsusc, wave, net, decomp, nstates, miny : control values
    Shows tables of spectral peaks, radiative lifetimes and Förster radii, plus
    per-geometry emission contributions when `net` is on. Returns early (after
    a warning) when nr**2 > eps.
    """
    ax = kwargs['ax']
    fontsize = visualization.set_fontsize(ax)
    ax2 = kwargs['ax2']
    names = kwargs['names']
    ax.clear()
    ax2.clear()
    ax2.set_xlim([0,kwargs['maxsusc']])
    dielec = (kwargs['eps'],kwargs['nr'])
    if check_dielec(dielec):
        return
    STATS, ABS, EMI, RF, LIFE = [], [], [], [], []
    combined, combined2 = {}, {}
    lw = fontsize/5
    for file in kwargs['files']:
        state = file['ensemble'][0]
        # level index as text ('1' for 'S1'); taking state[1:] keeps multi-digit levels intact
        level = state[1:]
        mol = names[file.name]
        for tipo in kwargs[file.name]:
            if tipo == 'Emission':
                res, emi, breakdown = nemo.analysis.rates(state,dielec,data=file,detailed=True)
                emi2 = emi.copy()
                emi2.name = f'{mol} {state}'
                EMI.append(emi2)
                x = emi['Energy'].values
                y = emi['Diffrate'].values
                err = emi['Error'].values
                label = f'{mol}: {state[0]}$_{{{level}}}\\: \\to \\:$S$_0$'
                # positional access: the matching row need not carry index label 0
                radiative = res[res.Transition == state.upper()+'->S0']
                rate = radiative['Rate(s^-1)'].iloc[0]
                error = radiative['Error(s^-1)'].iloc[0]
                LIFE.append([mol,state,1/rate,(1/rate)*(error/rate)])
                if kwargs['net']:
                    # each geometry's contribution to the emission, in percent of the column sum
                    showdown = breakdown[[state.upper()+'->S0']].apply(lambda x: 100*x/x.sum(), axis=0)
                    suscs = breakdown[['chi_'+state.lower()]].to_numpy()
                    showdown['chi_'+state.lower()] = suscs
                    showdown.index = file['geometry'].astype(int)
                    # order rows by their row sum, largest first
                    showdown = showdown.reindex(showdown.sum(axis=1).sort_values(ascending=False).index)
                    num = min(5,len(showdown))
                    combined2[f'{mol} {state}'] = showdown.iloc[:num,:].style.set_table_attributes("style='display:inline'").set_caption(f'{mol} {state}').background_gradient().format({state.upper()+'->S0':'{:.2f}%', 'chi_'+state.lower():'{:.2f}'})
            elif tipo == 'Absorption':
                nstates = int(kwargs['nstates'])
                # a 0 from the control is passed on as -1 (presumably "all states")
                if nstates == 0:
                    nstates = -1
                abs_spec, breakdown = nemo.analysis.absorption(state,dielec,data=file,save=False,detailed=True,nstates=nstates)
                abs2 = abs_spec.copy()
                abs2.name = f'{mol} {state}'
                ABS.append(abs2)
                abs_spec = abs_spec.to_numpy()
                # column 0: x; last column: error; second-to-last: plotted spectrum;
                # the columns in between are the per-transition decomposition
                x = abs_spec[:,0]
                y = abs_spec[:,1:]
                err = abs_spec[:,-1]
                dec = y[:,:-2]/np.max(y)
                y = y[:,-2]
                label = f'{mol}: {state[0]}$_{{{level}}}\\: \\to \\: \\sum_n^{{{dec.shape[1]+int(level)}}}$ {state[0]}$_n$'

            # normalize to the spectrum maximum (new arrays: the DataFrame-backed ones stay untouched)
            norm = np.max(y)
            err = err/norm
            y = y/norm
            x,y,err = visualization.relevant(x,y,err,kwargs['miny'])
            # 1239.8 / E(eV) converts to wavelength in nm
            xplot = 1239.8/x if kwargs['wave'] else x
            ax.plot(xplot,y,label=label,lw=lw)
            ax.fill_between(xplot,y-err,y+err,alpha=0.5)
            if kwargs['net']:
                visualization.network_spectrum(breakdown,(ax,ax2),state,tipo[:3].lower(),kwargs['wave'])

            if tipo == 'Absorption' and kwargs['decomp']:
                cmap = plt.get_cmap('magma')
                colors = [cmap(c) for c in np.linspace(0,0.99,dec.shape[1])]
                for i in range(dec.shape[1]):
                    # decomposition column i ends on level int(level)+i+1; the plot
                    # label and the STATS row now use the same final level
                    final = int(level)+i+1
                    dec_label = f'{mol}: {state[0]}$_{{{level}}}\\: \\to \\:${state[0]}$_{{{final}}}$'
                    ax.plot(xplot,dec[:,i],color=colors[i],label=dec_label)
                    peak = visualization.get_peak(dec[:,i],x)
                    STATS.append([f'{tipo[:3]} {mol} {state} to {state[0]}{final}',state,peak,1239.8/peak])
            peak = visualization.get_peak(y,x)
            STATS.append([f'{tipo[:3]} {mol}',state,peak,1239.8/peak])
    stats = pd.DataFrame(STATS,columns=['Spectrum','State','Peak (eV)','Peak (nm)'])
    combined['stats'] = stats.style.hide(axis='index').set_table_attributes("style='display:inline'").set_caption('Spectral Peaks').background_gradient().format({'Peak (eV)':'{:.2f}','Peak (nm)':'{:.0f}'})
    # Förster radius for every (absorption, emission) pair
    for abso in ABS:
        for emis in EMI:
            radius, dradius = visualization.radius(abso,emis,kwargs['kappa'])
            RF.append([emis.name,abso.name,radius,dradius])
    life = pd.DataFrame(LIFE,columns=['Molecule','State','Lifetime (s)','Error (s)'])
    combined['life'] = life.style.hide(axis='index').set_table_attributes("style='display:inline'").set_caption('Radiative Lifetimes').background_gradient().format({'Lifetime (s)':'{:.2e}','Error (s)':'{:.2e}'})
    rf = pd.DataFrame(RF,columns=['Donor','Acceptor','Radius (&#8491;)','Error (&#8491;)'])
    rf = rf.sort_values(by=['Donor','Acceptor'])
    combined['radii'] = rf.style.hide(axis='index').set_table_attributes("style='display:inline'").set_caption(f'Förster Radii (&kappa;<sup>2</sup>={kwargs["kappa"]:.2f})').background_gradient().format({'Radius (&#8491;)':'{:.1f}','Error (&#8491;)':'{:.1f}'})
    ax.set_ylabel('Normalized Intensity')
    if kwargs['wave']:
        ax.set_xlabel('Wavelength (nm)',fontsize=fontsize)
    else:
        ax.set_xlabel('Energy (eV)',fontsize=fontsize)
    ax.set_ylim(bottom=0)
    # escaped '\\epsilon' avoids the invalid '\e' escape; '\n' is a real newline
    title = f'$\\epsilon ={dielec[0]:.3f}$\n$n={dielec[1]:.3f}$'
    if kwargs['net']:
        ax2.set_visible(True)
        ax2.set_ylim(0.0,1.5)
        ax2.grid(False)
        ax2.set_xlabel('Susceptibility (eV)',fontsize=fontsize)
        ax.set_ylim(0,1.5)
        display_side_by_side(combined2)
    else:
        ax2.set_visible(False)
    # legend outside the plot area
    ax.legend(title=title,bbox_to_anchor=(1.01, 1), loc='upper left', borderaxespad=0.,fontsize=fontsize)
    # draw the main axes above the twin so its curves stay visible
    ax.set_zorder(ax2.get_zorder()+1)
    ax.patch.set_visible(False)
    display_side_by_side(combined)
    
def spec_widget(files,names):
    """Build the SPECTRA panel: per-ensemble spectrum selectors plus global controls.

    Each file gets a SelectMultiple keyed by file.name. S0 ensembles offer only
    'Absorption'; the others offer 'Absorption' and 'Emission', with the last
    option preselected. spec_plot redraws whenever a control changes.
    """
    kw = {}
    selectors = []
    for ens in files:
        if ens['ensemble'][0] == 'S0':
            choices = ['Absorption']
        else:
            choices = ['Absorption', 'Emission']
        chooser = widgets.SelectMultiple(
            options=choices,
            value=[choices[-1]],
            description='Spectra',
            tooltip='Select spectra to plot',
            disabled=False,
        )
        selectors.append(chooser)
        kw[ens.name] = chooser

    eps, nr, kw = nemoview.widgets.eps_nr(kw)
    kappa, kw = nemoview.widgets.kappa(kw)
    maxsusc, kw = nemoview.widgets.maxsusc(kw)
    wave, kw = nemoview.widgets.wave(kw)
    net, kw = nemoview.widgets.net(kw)
    decomp, kw = nemoview.widgets.decomp(kw)
    nstates, kw = nemoview.widgets.nstates(kw)
    miny, kw = nemoview.widgets.miny(kw)

    tab = widgets.Tab()
    tab.children = selectors
    dump = download_button()
    # one tab per ensemble, titled "<molecule> <state>"
    for idx, ens in enumerate(files):
        tab.set_title(idx, names[ens.name] + ' ' + ens['ensemble'][0])
    display(widgets.HBox([
        tab,
        widgets.VBox([eps, nr, nstates, miny]),
        widgets.VBox([kappa, maxsusc, net, wave, decomp]),
        widgets.VBox([dump]),
    ]))

    fig, ax = plt.subplots(figsize=(11, 5))
    ax2 = ax.twiny()
    kw['ax2'] = fixed(ax2)
    kw['files'] = fixed(files)
    kw['ax'] = fixed(ax)
    kw['names'] = fixed(names)
    kw['miny'] = miny
    output = widgets.interactive_output(spec_plot, kw)
    fig.show()
    display(output)
    dump.on_click(functools.partial(on_button_clicked, fig=fig, name='spectra.png'))


def corr_plot(**kwargs):
    """Plot solvent-susceptibility distributions for the selected states.

    Expected kwargs (wired up by corr_widget):
        ax           : three Axes -- histogram, heatmap, colour scale
        ax2          : twinx() of the colour-scale Axes
        files, names : ensemble DataFrames; names maps file.name -> molecule name
        <file.name>  : selected state labels for that file (e.g. 'S1', 'T2')
        gran, maxsusc, eps, nr : control values
    Returns early (after a warning) when nr**2 > eps, and after clearing the
    axes when no state is selected.
    """
    ax = kwargs['ax']
    ax2 = kwargs['ax2']
    ax2.clear()
    for a in ax:
        fontsize = visualization.set_fontsize(a)
        a.clear()
    dielec = (kwargs['eps'],kwargs['nr'])
    if check_dielec(dielec):
        return
    gran = 10**kwargs['gran']
    names = kwargs['names']
    alphast2   = nemo.tools.get_alpha(kwargs['eps'])
    alphaopt2  = nemo.tools.get_alpha(kwargs['nr']**2)
    yticks = []
    heat_rows = []  # one descending-sorted resample per selected state (heatmap columns)
    for file in kwargs['files']:
        state = file['ensemble'][0]
        alphaopt1  = nemo.tools.get_alpha(file['nr'][0]**2)
        Singlets = file[[i for i in file.columns.values if 'e_s' in i and 'osc' not in i]].to_numpy()
        Triplets = file[[i for i in file.columns.values if 'e_t' in i and 'osc' not in i]].to_numpy()
        Ss_s     = file[[i for i in file.columns.values if 'd_s' in i]].to_numpy()
        Ss_t     = file[[i for i in file.columns.values if 'd_t' in i]].to_numpy()

        # per geometry, order the states by E - (alphaopt2/alphaopt1)*d
        deltaS = Singlets - (alphaopt2/alphaopt1)*Ss_s
        deltaT = Triplets - (alphaopt2/alphaopt1)*Ss_t
        dss = np.take_along_axis(Ss_s,np.argsort(deltaS),axis=1)
        dst = np.take_along_axis(Ss_t,np.argsort(deltaT),axis=1)
        # Columns S1..Sn, T1..Tn, scaled ONCE by 1/alphaopt1 -- the same scaling
        # applied to the ensemble's own state below (it was divided twice before).
        Ds = pd.DataFrame(
            np.concatenate((dss,dst),axis=1)/alphaopt1,
            columns=[f'S{i+1}' for i in range(dss.shape[1])]+[f'T{i+1}' for i in range(dst.shape[1])],
        )
        # the ensemble's own state is re-ordered using the static (eps) alpha instead
        if 'S' in state and state != 'S0':
            deltaS = Singlets - (alphast2/alphaopt1)*Ss_s
            dss = np.take_along_axis(Ss_s,np.argsort(deltaS),axis=1)
            Ds[state] = dss[:,int(state[1:])-1]/alphaopt1
        elif 'T' in state:
            deltaT = Triplets - (alphast2/alphaopt1)*Ss_t
            dst = np.take_along_axis(Ss_t,np.argsort(deltaT),axis=1)
            Ds[state] = dst[:,int(state[1:])-1]/alphaopt1
        for header in kwargs[file.name]:
            ds = Ds[header].to_numpy()
            # resample (with replacement, unseeded) to a common length, sorted high to low
            heat_rows.append(-np.sort(-np.array(random.choices(ds,k=10000))))
            lab = f'{names[file.name]} {header[0]}$_{{{header[1:]}}}$@{{{state[0]}$_{{{state[1:]}}}$}}'
            yticks.append(f'{names[file.name]}\n {header[0]}$_{{{header[1:]}}}$@{{{state[0]}$_{{{state[1:]}}}$}}')
            hist,bins = visualization.spectrum(ds,gran)
            ax[0].plot(bins,hist,label=lab,lw=fontsize/5)
    if not heat_rows:
        # nothing selected: keep the cleared axes instead of failing on an empty heatmap
        return
    heat = np.vstack(heat_rows)
    heatmap = ax[1].imshow(heat.T, cmap='coolwarm')
    heatmap.set_interpolation('none')
    ax[1].set_aspect('auto')
    # one labelled column per selected state
    ax[1].set_xticks(np.arange(len(yticks)))
    ax[1].set_xticklabels(yticks)
    ax[1].set_yticks([])
    heatmap.set_clim(0,kwargs['maxsusc'])
    # hand-made colour scale from maxsusc (top) to 0 (bottom)
    ax[2].imshow(np.linspace(kwargs['maxsusc'],0.0,100)[:,np.newaxis], cmap='coolwarm')
    ax[2].set_aspect('auto')
    ax[2].set_yticks([0,99])
    ax[2].set_yticklabels([f"{kwargs['maxsusc']:.1f}",'0'])
    ax2.set_ylim(ax[2].get_ylim())
    ax2.set_yticks([99])
    ax2.set_yticklabels(['LE'])
    ax2.set_ylabel('CT character', labelpad=-20)
    ax[2].set_xticks([])
    ax[2].grid(False)
    ax[0].set_xlim(left=0)
    ax[0].set_ylim(bottom=0)
    ax[0].set_xlabel('Solvent Susceptibility (eV)')
    ax[1].set_ylabel('Ensemble Composition')
    eps = kwargs['eps']
    nr  = kwargs['nr']
    # escaped '\\epsilon' avoids the invalid '\e' escape; '\n' is a real newline
    title = f'$\\epsilon ={eps:.3f}$\n$n={nr:.3f}$'
    ax[0].legend(loc='best',title=title)
    # panel letters, left-aligned
    ax[0].set_title('a)',loc='left')
    ax[1].set_title('b)',loc='left')
    
def corr_widget(files,names):
    """Build the SUSCEPTIBILITY panel: per-ensemble state selectors plus controls.

    The state options for a file come from its columns containing 'd_'
    (a column 'd_xy' yields the option 'XY'); the first option is preselected.
    corr_plot redraws whenever a control changes.
    """
    kw = {}
    selectors = []
    for ens in files:
        states = [col.split('_')[1].upper() for col in ens.columns if 'd_' in col]
        chooser = widgets.SelectMultiple(
            options=states,
            value=[states[0]],
            description='States',
            tooltip='Select states to plot susceptibility for',
            disabled=False,
        )
        selectors.append(chooser)
        kw[ens.name] = chooser
    gran_slider, kw = nemoview.widgets.gran_slider(kw)
    maxsusc, kw = nemoview.widgets.maxsusc(kw)
    eps, nr, kw = nemoview.widgets.eps_nr(kw)

    tab = widgets.Tab()
    tab.children = selectors
    dump = download_button()
    # one tab per ensemble, titled "<molecule> <state>"
    for idx, ens in enumerate(files):
        tab.set_title(idx, names[ens.name] + ' ' + ens['ensemble'][0])
    display(widgets.HBox([tab, widgets.VBox([gran_slider, eps, nr, maxsusc]), dump]))

    # panels: histogram, heatmap and a narrow colour scale
    fig, ax = plt.subplots(1, 3, figsize=(11, 4), gridspec_kw={'width_ratios': [1, 1, 0.05]})
    ax2 = ax[2].twinx()
    kw['ax2'] = fixed(ax2)
    kw['files'] = fixed(files)
    kw['ax'] = fixed(ax)
    kw['names'] = fixed(names)
    output = widgets.interactive_output(corr_plot, kw)
    display(output)
    fig.show()
    dump.on_click(functools.partial(on_button_clicked, fig=fig, name='susceptibilities.png'))

def netw_plot(**kwargs):
    """Draw the susceptibility network of the selected transitions, one panel per molecule.

    Expected kwargs (wired up by netw_widget):
        eps, nr      : dielectric constant and refractive index
        files, names : ensemble DataFrames; names maps file.name -> molecule name
        molecs       : ordered unique molecule names
        ax, ax2      : primary Axes and their twinx() partners, one per molecule
        <molecule>   : selected transitions for that molecule, formatted 'A~>B'
    Also tabulates, per ensemble, the 10 geometries with the largest row sum of
    percentage contributions. Returns early (after a warning) when nr**2 > eps.
    A molecule whose panel fails is skipped with a warning.
    """
    import warnings

    combined = {}
    eps = kwargs['eps']
    nr  = kwargs['nr']
    dielec = (eps,nr)
    if check_dielec(dielec):
        return
    files = kwargs['files']
    names = kwargs['names']
    axs = kwargs['ax']
    axs2 = kwargs['ax2']
    molecs = kwargs['molecs']
    # side passed to plot_network for singlet- vs triplet-origin transitions
    mapa = {'S':'left','T':'right'}
    for a in axs:
        a.clear()
    for a in axs2:
        a.clear()

    for j, molec in enumerate(molecs):
        axs[j].set_title(molec,loc='left')
        selected = kwargs[molec]
        initials = [i.split('~>')[0] for i in selected]
        try:
            for file in files:
                state = file['ensemble'][0]
                if names[file.name] == molec and state in initials:
                    _, _, breakdown = nemo.analysis.rates(state,dielec,data=file,detailed=True)
                    possible = [i for i in selected if i.split('~>')[0] == state]
                    # each geometry's contribution, in percent of its column sum
                    showdown = breakdown[possible].apply(lambda x: 100*x/x.sum(), axis=0)
                    showdown.index = file['geometry'].astype(int)
                    # order rows by their row sum, largest first
                    showdown = showdown.reindex(showdown.sum(axis=1).sort_values(ascending=False).index)
                    num = min(10,len(showdown))
                    combined[f'{names[file.name]} {state}'] = showdown.iloc[:num,:].style.set_table_attributes("style='display:inline'").set_caption(f'{names[file.name]} {state} Ensemble').background_gradient().format('{:.2f}%')
                    for trans in possible:
                        visualization.plot_network(breakdown,axs[j],mapa[state[0]],trans)
        except Exception as err:
            # best effort: one failing molecule must not blank the other panels
            warnings.warn(f'Network for {molec!r} could not be drawn: {err}')
    for Axs in [axs,axs2]:
        for ax in Axs:
            # keep only the two vertical spines, thick and colour-coded (left / right)
            ax.spines['top'].set_visible(False)
            ax.spines['bottom'].set_visible(False)
            ax.grid(False)
            ax.set_xticks([])
            ax.set_xlim(-1,1)
            ax.spines['left'].set_linewidth(5)
            ax.spines['right'].set_linewidth(5)
            ax.spines['left'].set_color('#4477AA')
            ax.spines['right'].set_color('#EE6677')
            ax.set_ylabel('Susceptibility (eV)')
            ax.tick_params(axis='both', which='both', length=0)
            ax.relim()
    # One shared y-range for every panel, taken from the primary axes. It is
    # computed once here instead of on every iteration of the loop above.
    top = [ax.get_ylim()[1] for ax in axs]
    bot = [ax.get_ylim()[0] for ax in axs]
    for Axs in [axs,axs2]:
        for ax in Axs:
            ax.set_ylim([min(bot),1.1*max(top)])
    for ax in axs:
        ax.legend()
    display_side_by_side(combined)

def netw_widget(files,names):
    """Build the NETWORK panel: per-molecule transition selectors and eps/n controls.

    The options for a molecule are the transitions containing '~' reported by
    nemo.analysis.rates for each of its non-S0 ensembles (evaluated at
    eps = n = 1), excluding those ending in S0. A molecule with no such
    transition gets an empty selector instead of raising IndexError.
    """
    import warnings

    kw = {}
    eps, nr, kw = nemoview.widgets.eps_nr(kw)
    WIDS = []
    # unique molecule names, kept in first-seen order
    molecs = list(dict.fromkeys(names.values()))

    for molec in molecs:
        transitions = []
        for file in files:
            try:
                if names[file.name] == molec:
                    state = file['ensemble'][0]
                    if state != 'S0':
                        data, _ = nemo.analysis.rates(state,(1,1),data=file)
                        iscs = data[data.Transition.str.contains('~')].Transition.tolist()
                        transitions.extend(i for i in iscs if '~>S0' not in i)
            except Exception as err:
                # best effort: skip an ensemble whose rates cannot be computed
                warnings.warn(f'Skipping {file.name} in the network panel: {err}')
        select = widgets.SelectMultiple(
            options=transitions,
            # preselect the first transition, or nothing when there is none
            value=transitions[:1],
            description='Transitions',
            disabled=False,
        )
        WIDS.append(select)
        kw[molec] = select

    dump = download_button()
    tab = widgets.Tab()
    tab.children = WIDS
    # one tab per molecule
    for idx, molec in enumerate(molecs):
        tab.set_title(idx, molec)

    vbox = widgets.VBox([eps,nr])
    hbox = widgets.HBox([tab,vbox,dump])
    display(hbox)

    fig, ax = plt.subplots(1,len(molecs),figsize=(11,4))
    # plt.subplots returns a bare Axes for a single column; netw_plot expects a sequence
    Ax = [ax] if len(molecs) == 1 else ax
    kw['ax'] = fixed(Ax)
    ax2 = [a.twinx() for a in Ax]
    kw['ax2'] = fixed(ax2)
    kw['files'] = fixed(files)
    kw['names'] = fixed(names)
    kw['molecs'] = fixed(molecs)
    wid = widgets.interactive_output(netw_plot,kw)
    fig.show()
    display(wid)
    dump.on_click(functools.partial(on_button_clicked,fig=fig,name='network.png'))


def body(**kwargs):
    """Build the four analysis panels once the "Read File" toggle is on.

    kwargs carries 'run' (toggle state), 'files' (parsed ensembles) and one
    molecule-name string per file.name. S0 ensembles are excluded from the
    DIAGRAM panel; every other panel receives all files.
    """
    if not kwargs['run']:
        return
    files = kwargs['files']
    names = {ens.name: kwargs[ens.name] for ens in files}
    excited = [ens for ens in files if ens['ensemble'][0] != 'S0']

    panels = [
        ('DIAGRAM', diag_widget, excited),
        ('SPECTRA', spec_widget, files),
        ('SUSCEPTIBILITY', corr_widget, files),
        ('NETWORK', netw_widget, files),
    ]
    children = [
        widgets.interactive(builder, files=fixed(subset), names=fixed(names))
        for _, builder, subset in panels
    ]
    accordion = widgets.Accordion(children=children, selected_index=0)
    for idx, (title, _, _) in enumerate(panels):
        accordion.set_title(idx, title)
    display(accordion)

#core function
def main(file_name):
    """Parse the uploaded ensemble files and show the molecule-naming tab.

    file_name : value of the FileUpload widget -- a sequence of dicts, each
        with 'name' (uploaded file name) and 'content' (raw CSV bytes).

    Each file becomes a DataFrame whose ``.name`` is the file name without its
    final extension, and gets a Text box for the user to assign a molecule
    name. Toggling "Read File" hands everything to ``body``. Nothing is shown
    when no file has been uploaded.
    """
    uploads = {item['name']: item for item in file_name}
    if not uploads:
        return
    name_boxes, ensembles, kw = [], [], {}
    for fname, item in uploads.items():
        data = pd.read_csv(io.BytesIO(item['content']))
        # Strip only the final extension: 'mol.S1.lx' -> 'mol.S1'. The former
        # split('.')[0] mapped it to 'mol', colliding with e.g. 'mol.T1.lx'.
        data.name = os.path.splitext(fname)[0]
        wid = widgets.Text(
            value=data.name,
            placeholder=fname,
            description='Molecule:',
            disabled=False,
            continuous_update=False,
        )
        kw[data.name] = wid
        name_boxes.append(wid)
        ensembles.append(data)

    run_but = widgets.ToggleButton(
        value=False,
        description='Read File',
        disabled=False,
        button_style='success', # 'success', 'info', 'warning', 'danger' or ''
        tooltip='Description',
        icon='check',
    )
    # short explanation of the naming convention
    html_text = widgets.HTML(
        value="<b>Ensemble files belonging to the same molecule should have the same name</b>",
        placeholder='naming convention',
        description='',
    )
    h   = widgets.GridBox(name_boxes, layout=widgets.Layout(grid_template_columns="repeat(3, 350px)"))
    h1  = widgets.HBox([h,run_but])
    v   = widgets.VBox([html_text, h1])
    tab = widgets.Tab()
    tab.children = (v,)
    tab.set_title(0,'NAMES')
    kw['run'] = run_but

    w_body = widgets.interactive_output(body,{'files':fixed(ensembles), **kw})
    display(tab,w_body)

###################################################
#Initializing main function and displaying widgets
###################################################
# read the logo bytes and close the file immediately
with open(os.path.join(os.getcwd(),'figs','nemoview.png'), 'rb') as logo_file:
    logo = widgets.Image(value=logo_file.read(), format='png')
# height/width ratio of the logo (presumably its 1134x529 pixel size), shown 300 px wide
ratio = 529/1134
width = 300
logo.layout.width = str(width)+'px'
logo.layout.height = str(int(ratio*width))+'px'
# version banner; the bold tag is now closed with </b> (was a second <b>)
n1 = widgets.HTML(value = f'<b>NEMO: {nemo_version.__version__} NEMOview: {nemoview_version.__version__}</b>')
# children[0] is the upload control, children[1] the output area of main
i = widgets.interactive(main, file_name=dropdown)
v = widgets.VBox([logo,n1,i.children[0],i.children[1]])
display(v)