# <ins>Dash application</ins>: **Longitudinal modelling of the co-development of depression and cardio-metabolic risk from childhood to young adulthood**

In [22]:
import os, pyreadr, itertools
import pandas as pd
import numpy as np
import json
import re

from dash import Dash, html, dcc, callback, Output, Input, State, ctx, no_update, dash_table, get_asset_url
import dash_bootstrap_components as dbc
import dash_cytoscape as cyto
import textwrap

# import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# pd.set_option('display.max_rows', None)

In [1]:
from backend_funcs import get_label, desc_plot, model_structure, read_res1, best_fit1, make_plot1, make_net1, stylenet1, make_table1, read_res3, timemarks3, make_net3, stylenet3

In [3]:
a,b,c,d = read_res1('SDEP','FMI')

In [73]:
modstr = pd.DataFrame(model_structure['full']).T

df = c.loc[modstr.index[0]]

name = 'CL'

df.loc[df.label.str.contains(name)].iloc[::-1][['est','sign']].round(2).iloc[0]


est     0.013674
sign            
Name: full, dtype: object

In [202]:
def make_table1(depname, cmrname, which_model='maCL_dep-maCL_cmr-maAR_dep-maAR_cmr'):
    '''Input: names of the depression report (sDEP = self or mDEP = parental reports) and cardio-metabolic risk (CMR) marker; model structure.
       Extracts the fit measures for the specified model and stores into a table to be displayed next to the graph. 
    '''
    fitm = read_res1(depname, cmrname)[1]
    
    if which_model == 'best':
        which_model = fitm.index[fitm.aic == fitm.aic.min()][0]# Best fitting model (lowest AIC)
    
    dt = pd.DataFrame(fitm.loc[which_model, ['npar', 'df', 'chisq', 'pvalue','cfi', 'tli','rmsea','srmr','aic', 'bic', ]])
    dt = dt.rename(columns={ which_model:' '}).round(3)
    dt.insert(loc=0, column='Fit measures', value=['Number of parameters','Degrees of freedom','\u03C7\u00b2','P-value','CFI','TLI','RMSEA','SRMR','AIC','BIC'])

    def format_row(x, form='{0:.2f}'):
        try: return form.format(x)
        except: return x

    dt.loc[['npar','df']] = dt.loc[['npar','df']].applymap(format_row, form='{0:.0f}')
    dt.loc[['chisq','cfi','tli','rmsea','srmr']] = dt.loc[['chisq','cfi','tli','rmsea','srmr']].applymap(format_row, form='{0:.2f}')
    dt.loc[['pvalue']] = dt.loc[['pvalue']].applymap(format_row, form='{0:.3f}')
    dt.loc[['aic','bic']] = dt.loc[['aic','bic']].applymap(format_row, form='{0:.1f}')

    return(dt)

d = make_table1('sDEP','FMI')

d


Unnamed: 0_level_0,Fit measures,Unnamed: 2_level_0
rownames,Unnamed: 1_level_1,Unnamed: 2_level_1
npar,Number of parameters,47.0
df,Degrees of freedom,43.0
chisq,χ²,449.61
pvalue,P-value,0.0
cfi,CFI,0.99
tli,TLI,0.98
rmsea,RMSEA,0.03
srmr,SRMR,0.02
aic,AIC,308746.4
bic,BIC,309083.9


In [39]:
extr_est(which=1, imp_corr=True)

KeyError: 1

In [33]:
pd.set_option('display.max_rows', None)
c.loc[c['pvalue']==0.050] # .round(3) 

Unnamed: 0_level_0,lhs,op,rhs,label,est,se,z,pvalue,ci.lower,ci.upper,sign
"rep(f, nrow(es))",Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1


In [2]:
PATH = '/Users/Serena/Desktop/panel_network/results/'

## Tab 1: generalized cross-lagged panel model 

Rscript 1 is designed to produce a single .RData file for each dep-cmr marker pair. This contains the following elements: 
- **`summ`**: a summary dataframe (with information about which marker is used and the timepoints included + mean ranges and number of observations)
- **`fit_meas`**: fit measures for every parameter combination, when the model converged.
- **`estimates`**: (unstandardized) estimates (+ bootsrapped SE, pvalues and CIs) for every parameter combination, when the model converged.
- **`failed`**: list of models that did not converge, with corresponding error or warning message.


In [35]:
def read_res1(depname, cmrname, path=PATH+'mod1/'):
    res = pyreadr.read_r(f'{path}{depname}_{cmrname}.RData')
    summ = res['dat_summ']
    fitm = res['fit_meas'].T
    esti = res['estimates'].set_index('rep(f, nrow(es))')
    fail = list(res['failed'].index) # TODO: report warning message ...
    return(summ, fitm, esti, fail)
# use
# summ, fitm, esti, fail = read_res1('sDEP','FMI') 
# summ = read_res1('sDEP','BMI')[0]

### Descriptives plot 
First, plot the **median** and **interquartile ranges** of each measure included in the model, against time. This gives a more complete understanding of the data that is fed into the models.

In [None]:
def make_plot1(depname, cmrname):
    # load summary dataframe
    summ = read_res1(depname,cmrname)[0]
    
    # extract timepoints
    t_dep = [ float(x.split('_')[-1][:-1]) for x in summ.columns[:summ.shape[1]//2] ]
    t_cmr = [ float(x.split('_')[-1][:-1]) for x in summ.columns[summ.shape[1]//2:] ]

    # scatterplot function 
    def scat(t, name, fullname, shortname):
        means = summ.loc['Median', summ.columns.str.contains(name)]
        p = go.Scatter(x = t, y = means, 
                       error_y = dict(type='data', symmetric=False, # visible=True,
                                      array = summ.loc['3rd Qu.', summ.columns.str.contains(name)] - means,
                                      arrayminus = means - summ.loc['1st Qu.', summ.columns.str.contains(name)]),
                       name = fullname, text = [f'{shortname} {n}' for n in range(1,len(t)+1)],
                       marker = dict(size = 10, symbol = 'square',opacity = .8), opacity = .7,
                       hovertemplate = """ <b>%{text}</b> <br> Median: %{y:.2f} <br> Timepoint: %{x} years <br><extra></extra>""")
        return p

    fig = make_subplots(specs=[[{'secondary_y': True}]])
    
    fig.add_trace( scat(t_dep, depname, 'Depression score', 'DEP'), secondary_y=False)
    fig.add_trace( scat(t_cmr, cmrname, cmrname, cmrname), secondary_y=True )

    # Set y-axes
    def yrange(name):
        sub = summ.filter(like=name)
        ymin = sub.min(axis=1)['1st Qu.']; ymax = sub.max(axis=1)['3rd Qu.']
        range = ymax-ymin
        y_max_lower = ymax
        ymin = ymin - (range/10); ymax = ymax + (range/10)
        return [ymin, ymax, y_max_lower]
    
    fig.update_yaxes(title_text='<b>Depression score</b>', secondary_y=False, range=yrange(depname)[:2], 
                     mirror=True, ticks='outside', showline=True, linecolor='black', gridcolor='lightgrey')
    fig.update_yaxes(title_text=f'<b>{cmrname}</b>', secondary_y=True, range=yrange(cmrname)[:2], 
                     mirror=True, ticks='outside', showline=True, linecolor='black', gridcolor='lightgrey')
    # Set x-axis 
    fig.update_xaxes(title_text='Years', mirror=True, ticks='outside', showline=True, linecolor='black', gridcolor='lightgrey')

    # Group "cross-sectional" timepoints using grey background
    # ymin = summ.min(axis=1)['1st Qu.']-1; ymax = summ.max(axis=1)['3rd Qu.']+2
    crosspoints = []
    for i in range(len(t_dep)):
        xmin = min(t_dep[i], t_cmr[i])-.2; xmax = max(t_dep[i], t_cmr[i])+.2
        # rectangles 
        crosspoints.append( dict(x0 = str(xmin), x1 = str(xmax), y0 = 0, y1 = 1, xref='x', yref='paper', 
                                 type='rect', fillcolor='lightgray', opacity=.3, line_width=0, layer='below') )  
        # text 
        fig.add_trace( go.Scatter(x=[xmin + (xmax-xmin)/2], y=[yrange(depname)[2]], mode='text', text=i+1, 
                                  textposition='top center',
                                  textfont_size=13, textfont_color='dimgray', showlegend=False) )
    # Background
    fig.update_layout(# title = dict(text='Included measures\n', font=dict(size=15), automargin=True, yref='paper'),
                      plot_bgcolor='white', shapes=crosspoints, margin=dict(l=20, r=20, t=20, b=20))
    
    return(fig)

### Best fitting model
Find the best fitting model (i.e., lowest AIC) and describe its structure using the `model_structure` matrix.

In [51]:
# Define the matrix of all possible paramter combinations. 
model_structure = pd.read_csv(PATH+'../mats/model_structure.csv').set_index('Unnamed: 0')

# m = pd.DataFrame(itertools.product([0, 1], repeat=4)).T # all possible combinations
# NOTE: this is the same as the mat matrix used to fit the models in Rscript 1.

def best_fit(depname, cmrname, list1=None):
    fitm = read_res1(depname, cmrname)[1]
    # Best fitting model (lowest AIC)
    mod = fitm.index[fitm.aic == fitm.aic.min()][0]
    if list1: # Return list of parameters estimated in the model 
        return list(model_structure.index[(model_structure[mod] > 0) & (model_structure.index.str.contains(list1))])
    else: # Return a dataframe with its name and model structure
        return pd.DataFrame(model_structure[mod])
    

In [56]:
model_structure.index

Index(['maCL_dep', 'maCL_cmr', 'maAR_dep', 'maAR_cmr', 'ltCL_dep', 'ltCL_cmr',
       'ltAR_dep', 'ltAR_cmr'],
      dtype='object', name='Unnamed: 0')

### Graph diagram and model fit table 
Read in the results, select the best fitting model and construct its graph.

In [None]:
def make_net1(depname, cmrname, which_model = 'best', width=1000):
    # read data
    summ,_,esti,fail = read_res1(depname, cmrname)
    
    # Best fitting model
    if which_model == 'best': modstr = best_fit(depname, cmrname).T
    elif which_model in fail: 
        return 'fail' # modstr = best_fit(depname, cmrname).T
    else: modstr = pd.DataFrame(model_structure[which_model]).T
    
    # Extract the estimated paramters from the result files
    def extr_est(name=None, which=None, lamb=False, eta_corr=False, imp_corr=False, model_output=esti):
        df = model_output.loc[modstr.index[0]]
        if   lamb:     l = list(round(df.loc[(df.lhs==f'eta_{name}')&(df.op=='=~'),'est'], 2))[which]
        elif eta_corr: l = df.loc[(df.lhs=='eta_dep') & (df.rhs=='eta_cmr'), 'est'][0]
        elif imp_corr: l = df.loc[ df.label.str.contains('comv'), 'est'][which]
        else:          l = list(round(df.loc[df.label.str.contains(name)].iloc[::-1]['est'], 2))[which]
        return(l)

    # Get timepoints of measurement from summary
    timedic = dict(zip([f'dep{i+1}' if 'DEP' in n else f'cmr{i-summ.shape[1]//2+1}' for i,n in enumerate(summ.columns)],
                       [float(i.split('_')[-1][:-1]) for i in summ.columns]))

    # Ready to draw
    nt = summ.shape[1]//2 # Number of timepoints
    
    pos_top = 30; pos_bot = 550 # Vertical cohordinates (in pixel)
    vs = ['dep','cmr']
    
    e = [] # Initialize
   
    for eta, pos in enumerate([pos_top, pos_bot]):
        # Eta factors nodes
        e.append({'data': {'id':f'eta{eta}', 'label':'Eta'}, 'classes':'latent', 'position':{'x':width/2,'y':pos}})

    #  Eta factors correlations 
    e.append({'data': {'source':'eta0', 'target':'eta1', 'firstname':'eta_corr', 'label':'%.2f' % extr_est(eta_corr=True) }})
    
    for i in range(1,nt+1): 
        
        e.append({'data': {'source':f'imp_dep{i}', 'target':f'imp_cmr{i}', 'firstname':'imp_corr', 
                           'label':'%.2f' % extr_est(which=i-1, imp_corr=True) }})
        
        for eta, v in enumerate(vs):
            
    # ===== Other nodes
            p = [pos_top+90, pos_top+190] if v=='dep' else [pos_bot-90,pos_bot-190] # define position
            e.extend([
                # Observed variables
                {'data': {'id':f'{v}{i}', 'label':summ.columns[(i-1)+(nt*eta)], 'firstname':f'{v.upper()} {i}'}, 
                 'classes':'observed',
                 'position' : {'x':((width/nt)*i)-(width/nt)/2, 'y': p[0]} },
                # Impulses
                {'data': {'id':f'imp_{v}{i}', 'label': f'impulse {i}'}, 
                 'classes':'latent',
                 'position' : {'x':((width/nt)*i)-(width/nt)/2, 'y': p[1]} }
            ])
        
    # ===== Edges: lambdas
            e.append({'data': {'source':f'eta{eta}', 'target':f'{v}{i}', 'firstname':'lambda',
                               'label':'%.2f' % extr_est(f'{v}',i-1, lamb=True) }})
            # impulses link
            e.append({'data': {'source':f'imp_{v}{i}', 'target':f'{v}{i}', 'firstname':'imp_link'}})
            
            if i < nt: 
                otherv = abs(eta-1)
                # maAR and AR terms
                if modstr[f'ltAR_{v}'][0]: e.append({'data': {'source':f'{v}{i}', 'target':f'{v}{i+1}',
                                   'weight':'%.2f' % extr_est(f'^AR_{v}', i-1) * (timedic[f'{v}{i+1}'] - timedic[f'{v}{i}']), 
                                   'label': f'AR{i}', 'firstname':'direct'}})
                if modstr[f'maAR_{v}'][0]: e.append({'data': {'source':f'imp_{v}{i}', 'target':f'{v}{i+1}',
                                   'weight':'%.2f' % extr_est(f'^maAR_{v}',i-1), 'label': f'maAR{i}', 'firstname':'direct'}})
                # maCL and CL terms
                if modstr[f'ltCL_{v}'][0]: e.append({'data': {'source':f'{v}{i}', 'target':f'{vs[otherv]}{i+1}',
                                   'weight':'%.2f' % extr_est(f'^CL_{v}', i-1), 'label': f'CL{i}', 'firstname':'direct'}})
                if modstr[f'maCL_{v}'][0]: e.append({'data': {'source':f'imp_{v}{i}', 'target':f'{vs[otherv]}{i+1}',
                                   'weight':'%.2f' % extr_est(f'^maCL_{v}',i-1), 'label': f'maCL{i}', 'firstname':'direct'}})
                      
    return e

# ===================================================================================================================
# Also define the stile of the graph 
stylenet1=[ 
    # Nodes - shape & color
    {'selector':'.observed',
     'style':{'shape':'rectangle', 'height':25, 'width':60, 'border-width':2,'background-color':'white', 'border-color':'k', 
              'content':'data(firstname)','text-color':'grey','text-halign':'center','text-valign':'center'}},
    
    {'selector':'.latent', 
     'style':{'shape':'round', 'height':20, 'width':20, 'border-width':1,'background-color':'white', 'border-color':'silver'}},
    
    # Edges
    {'selector':'edge[firstname *= "direct"]', # directed paths 
     'style':{'curve-style':'straight', 'target-arrow-shape':'vee', 'width': 3, 'arrow-scale':1.2 }},
    
    {'selector':'edge[firstname *= "imp_link"]', # impulses links
     'style':{'curve-style':'straight', 'target-arrow-shape':'vee', 'width': 1, 'arrow-scale':.8 }},

    # Correlations
    {'selector':'edge[firstname *= "eta_corr"]',
     'style':{'curve-style':'unbundled-bezier','target-arrow-shape':'vee','source-arrow-shape':'vee', 'width': 1, 
              'control-point-distances': [-400,-500,-520,-500,-400],'control-point-weights': [0.01, 0.20, 0.5, 0.80, 0.99],
              'label':'data(label)','font-size':15, 'text-background-color':'silver', 'text-background-opacity':.7 }},
    
    {'selector':'edge[firstname *= "imp_corr"]', 
     'style':{'curve-style':'unbundled-bezier','target-arrow-shape':'vee','source-arrow-shape':'vee','width': 1, 
              'label':'data(label)','font-size':15, 'text-background-color':'silver', 'text-background-opacity':.7 }},
    
    # Lambdas
    {'selector':'edge[firstname *= "lambda"]', 
     'style':{'curve-style':'straight', 'target-arrow-shape':'vee', 'width': 1, 'arrow-scale':.8,
              'label':'data(label)','font-size':15, 'text-background-color':'silver', 'text-background-opacity':.7 }},

    # Dashed lines 
    {'selector':'edge[weight < 0.01][weight > -0.01]', 'style':{'line-style':'dashed'}},
]
# Set the color of each edge type and the distance between sorce and label displaying its weight (to avoid overlapping) 
d = {'AR':['red',40],'maAR':['orange',70],'CL':['green',220],'maCL':['lightblue',40]}

for c in d.keys():
    stylenet1.append({'selector':f'[label *= "{c}"]', 
                      'style': {'line-color':d[c][0], 'target-arrow-color':d[c][0],
                                'source-label':'data(weight)', 'source-text-offset': d[c][1],'font-size':20, 'font-weight':'bold',
                                'text-background-color':d[c][0], 'text-background-opacity':.5 }
                     })


In [None]:
def make_table1(depname, cmrname, model='best'):
    fitm = read_res1(depname, cmrname)[1]
    if model=='best':
        model = fitm.index[fitm.aic == fitm.aic.min()][0]# Best fitting model (lowest AIC)
    subs = pd.DataFrame(fitm.loc[model, ['npar', 'df', 'chisq', 'pvalue','cfi', 'tli','rmsea','srmr','aic', 'bic', ]])
    subs = subs.rename(columns={ model:' '}).round(3)
    subs.insert(loc=0, column='Fit measures', value=['Number of parameters','Degrees of freedom','\u03C7\u00b2','P-value','CFI','TLI','RMSEA','SRMR','AIC','BIC'])
    return(subs)

## Tab 2: cross-lagged panel network model
Rscript 2 fits and returns the .RData file for the longitudinal cross-lagged panel network model.

## Tab 3: cross-sectional network models
Rscript 3 is designed to produce a single .RData file for each single-timepoint, cross-sectional network model. This contains the following elements:
- **`wm`**: dataframe with all edge weights
- **`ci`**: 95% confidence intervals for those weights
- **`fit`**: fit measures number of observations the network is based on
- **`layout`**: the spring graphical disposition computed by `qgraph`


In [14]:
labels = {'DEP01':'Felt miserable or unhappy',
          'DEP02':'Didn\'t enjoy anything at all',
          'DEP03':'Felt so tired they just sat around and did nothing',
          'DEP04':'Was very restless',
          'DEP05':'Felt they were no good any more',
          'DEP06':'Cried a lot',
          'DEP07':'Found it hard to think properly or concentrate',
          'DEP08':'Hated themselves',
          'DEP09':'Felt they were a bad person',
          'DEP10':'Felt lonely',
          'DEP11':'Thought nobody really loved them',
          'DEP12':'Thought they would never be as good as other people',
          'DEP13':'Felt they did everything wrong',
          'DEP_score':'Total depression score',
          'height':'Height',
          'weight':'Weight',
          'BMI':'Body mass index',
          'waist_circ':'Waist circumference',
          'waist_hip_ratio':'Waist/hip ratio',
          'total_fatmass':'Total fat mass',
          'total_leanmass':'Total lean mass',
          'trunk_fatmass':'Trunk fat mass',
          'android_fatmass':'Android fat mass',
          'FMI':'Fat mass index',
          'LMI':'Lean mass index',
          'TFI':'Trunk fat mass index',
          'liver_fat':'Liver fat',
          'SBP':'Systolic blood pressure',
          'DBP':'Diastolic blood pressure',
          'PWV':'Pulse wave velocity',
          'IMT':'Intima-media thickness',
          'heart_rate':'Heart rate',
          'LVM':'Left ventricular mass',
          'RWT':'Relative wall thickness',
          'FS':'Fractional shortening',
          'tot_chol':'Total cholesterol',
          'HDL_chol':'HDL-cholesterol',
          'LDL_chol':'LDL-cholesterol',
          'insulin':'Insulin',
          'triglyc':'Triglycerides',
          'glucose':'Glucose',
          'CRP':'C-reactive protein',
          'IL_6':'Interleaukin-6'}

def get_label(var): 
    # Remove year from var name 
    var_name = '_'.join(var.split('_')[:-1])
    # Remove self or maternal report from depression item 
    if 'DEP' in var_name: var_name = var_name[1:]
    # get label
    lab = labels[var_name]
    return lab

In [None]:
def read_res3(time, path=PATH+'mod3/'):
    '''Input: timepoint of interest. 
       Open the .RData file created by Rscript 2.CLNPM (one for each timepoint). This contains the following elements:
       - wm: dataframe with all edge weights
       - ci: 95% confidence intervals for those weights
       - fit: fit measures + number of observations the network is based on.
       - layout: the spring graphical disposition computed by `qgraph`
       Use: wm, ci, fit, lay = read_res3('9.8y-9.8y') -OR- summ = read_res1('sDEP','BMI')[0]
    '''
    res = pyreadr.read_r(f'{path}crosnet_{time}y.RData')
    # weight matrix
    wm = res['wm']; wm['link'] = wm.index; wm[['a','b']] = wm.link.str.split(' ', expand = True)
    wm = wm.loc[wm.a!=wm.b, ] # remove links to between an edge and itself
    wm = wm.reset_index()[['a','b','V1']].rename(columns={'a':'node1','b':'node2','V1':'weight'})
    wm['dir'] = ['neg' if x<0 else 'pos' for x in wm.weight]
    # centrality indices
    ci = res['ci']; ci['class'] = ['dep' if t else 'cmr' for t in ci.node.str.contains('DEP')]
    # fit measures and number of observations 
    fit = res['fit'].T.round(3)
    # layout computed by qgraph (spring algorithm) 
    lay = res['layout'].rename(columns={0:'x_og',1:'y_og'})
    # rescale to pixels 
    lay['x'] = np.interp(lay['x_og'], (lay['x_og'].min(), lay['x_og'].max()), (10, 510))
    lay['y'] = np.interp(lay['y_og'], (lay['y_og'].min(), lay['y_og'].max()), (10, 510))
    
    return wm, ci, fit, lay


times = [file.split('_')[1][:-7] for file in os.listdir(PATH+'mod3') if file.endswith('.RData')]
times.sort() # just for readability, not necessary 

# create marks dict to use in slider, also provides a map for make_net3
timemarks3 = dict()
for t in times:
    where = round(np.mean([float(n) for n in t.split('-')]),1)
    timemarks3[where] = {'label': f'\n{t} years', 'style': {'transform':'rotate(45deg)', 'whitespace':'nowrap'} } # 'color': '#f50'


def make_net3(timepoint):
    '''Input: timepoint of interest. 
       Creates a network structure and the table with centrality indices.
    '''
    lab = timemarks3[timepoint]['label'][1:-6] # get label and remove '\n' in the beginning and ' years' at the end
    
    wm,ci,_,lay = read_res3(lab) # read in data 
    
    # tim estimates?
    wm_trim = wm.loc[abs(wm.weight)>0.01,].reset_index(drop=True)
    
    nodes = [{'data': {'id':node, 'label': '\n'.join(textwrap.wrap(get_label(node), width=20))}, 'classes':group, 
             'position':{'x':lay.loc[node,'x'], 'y':lay.loc[node,'y'] }} 
       for node,group in ci[['node','class']].itertuples(index=False) ]
    
    edges = [{'data': {'source':a, 'target':b, 'weight':w, 'width':round(abs(w)*20,2)}, 'classes':c} 
       for a,b,w,c in wm_trim.itertuples(index=False)]
    
    network = nodes+edges

    # Centrality indices tab
    ci_tab = ci.round(2).drop(columns='class')
    ci_tab.insert(1, 'Node', [get_label(name) for name in ci_tab['node']]) # Replace node name with its label 
    ci_tab = ci_tab.drop(columns='node')

    return network, ci_tab

stylenet3 = [ {'selector': 'node', 'style': {'label': 'data(label)', 'text-wrap':'wrap'} },
           # Edge opacty and width
           {'selector': 'edge', 'style': {'opacity': 'data(weight)', 'width': 'data(width)'}},
           # Color nodes by group
           {'selector': '.dep', 'style': {'background-color': 'lightblue'} },
           {'selector': '.cmr', 'style': {'background-color': 'pink'} },
           # Color edges by positive/negative weights
           {'selector': '.neg', 'style': {'line-color': 'red'} },
           {'selector': '.pos', 'style': {'line-color': 'blue'} } ]

#### Descriptives plot 

In [None]:
# Opening Descriptives JSON file
with open('../mats/descrip.json') as jf:
    desc = json.load(jf)

In [None]:
def make_plot3(var):
    # Read in descriptives from file 
    d = desc[var]
    sum, count, dens = (pd.DataFrame.from_dict(x) for x in d)

    # Initialize figure 
    fig = go.Figure()

    if bool(re.match(r'.DEP[0-9]+|alcohol|canabis|smoking',var)):
        
        likert = ['Not true','Sometimes','True']
        colors = ['lightblue','salmon','crimson']
        
        x_lab = '<br>'.join(textwrap.wrap(get_label(var), width=20)) # Wrap label text 

        # Stacked histogram 
        for i, c in enumerate(count['count'].iloc[:-1]):
            
            fig.add_trace(go.Bar(x=[0], y=[c], name=likert[i], marker_color=colors[i], opacity=0.6, 
                                 marker_line_width=1.5, marker_line_color='white',
                                 customdata =  [count['prop_noNA'].iloc[i]*100],
                         hovertemplate = """<br> n = <b>%{y}</b> (%{customdata:.1f}%) <br><extra></extra>"""))
            
            fig.update_layout(barmode='stack', autosize=False, width=400, height=500, margin=dict(l=20, r=20, t=20, b=20),
                              yaxis=dict(title_text='Count'),
                              xaxis=dict(ticktext=['',x_lab,''], tickvals=[-0.5, 0, 0.5], tickfont=dict(size=15)) ) 
        return fig

    else: 
        x_lab = get_label(var)
        
        # Distribution plot 
        fig.add_trace(go.Scatter(x=dens['x'], y=dens['dens'], fill='tozeroy', mode='lines', line_color='salmon'))
        # add mean and median 
        fig.add_vline(x=sum['mean'][0], line_width=1.5, line_color='crimson', name='meanvalue')
        fig.add_vline(x=sum['50%'][0], line_width=1, line_dash='dash', line_color='crimson', name='medianvalue')
        # Adjust x-axis 
        fig.update_xaxes(title_text=x_lab, range=[round(sum['min']),round(sum['max'])], 
                        ticks='outside', showline=True, linecolor='black', gridcolor='lightgrey')
        fig.update_layout(autosize=False, width=600, height=300, margin=dict(l=20, r=20, t=20, b=20),
                          yaxis=dict(showticklabels=False, showline=True, linecolor='black'),   
                         )
        
        return fig

## Set-up

#### App layout 
The application is structured into 4 main tabs.

In [None]:
def badge_it(text, color):
    return dbc.Badge(text, color=color, style={'padding':'4px 5px'})

def bold_it(text):
    return html.Span(f' {text} ', style={'font-weight':'bold'})

def undeline_it(text):
    return html.Span(f'{text}', style={'text-decoration':'underline'})

def wrap_it(n_spaces=1):
    return html.Span( [html.Br()] * n_spaces )

def param_checklist(depname, cmrname, p='lt'):
    pref = '' if p == 'lt' else 'ma'
    cols = ['crimson','green'] if p=='lt' else ['orange', 'lightblue']
    position = 'left' if p=='lt' else 'right'
    return  html.Div(style={'width':'50%','height':'65%','float':position},
                  children=[ dcc.Checklist(id =f'{p}-checklist',
                   options=[{'label': html.Span([badge_it(f'{pref}AR', cols[0]),' depression']),   'value': f'{p}AR_dep'},
                            {'label': html.Span([badge_it(f'{pref}AR', cols[0]),' cardio-metabolic risk']),'value': f'{p}AR_cmr'},
                            {'label': html.Span([badge_it(f'{pref}CL', cols[1]),' depression \u290F cardio-metab.']),'value': f'{p}CL_dep'},
                            {'label': html.Span([badge_it(f'{pref}CL', cols[1]),' cardio-metab. \u290F depression']),'value': f'{p}CL_cmr'}],
                    value = best_fit(depname, cmrname, list1=p),                          
                inputStyle={ 'margin-left':'20px','margin-right':'20px'}, labelStyle = {'display': 'block'}) ])


In [None]:
app = Dash(__name__, external_stylesheets=[dbc.themes.LITERA, dbc.icons.BOOTSTRAP], suppress_callback_exceptions=True)

app.layout = dbc.Container([
    # Title
    html.H1('Longitudinal modelling of the co-development of depression and cardio-metabolic risk from childhood to young adulthood',
             style={'textAlign':'center', 'font-weight':'bold'}),
    html.Br(), # space
    # Main body
    dbc.Row([ dbc.Col([ 
        dcc.Tabs(id="tabs", 
                 children=[dcc.Tab(label='Data overview', value='tab-0'),
                           dcc.Tab(label='Cross-lag panel model', value='tab-1'), # style={''}
                           dcc.Tab(label='Cross-lag network analysis', value='tab-2'),
                           dcc.Tab(label='Cross-sectional network analysis', value='tab-3') ], 
                 value='tab-0'),
        html.Div(id='tabs-content') ], # App content
        width={'size': 10, 'offset': 1}), # add left and right margin
    ]) ], fluid=True )

# -----------------------------------------------------------------------------------------------------------------------
@callback( Output('tabs-content', 'children'), Input('tabs', 'value') )

def render_content(tab):
    if tab == 'tab-0': # ================================================================================================
        return html.Div([ html.Br(),
            html.Div(['Overview of the data available from the', html.Span(' ALSPAC ', style={'font-weight':'bold'}), 'cohort.']),
            html.Hr(),
            html.Div( html.Img(src=get_asset_url('timeline.png'), style={'width':'100%'})),
            dbc.Row( dbc.Col( [html.Img(src=get_asset_url('samplesizes.png'))], width={'size': 10, 'offset': 1}))
        ])
        
    elif tab == 'tab-1': # ==============================================================================================
        return  html.Div([ html.Br(),
            html.Div(['Results of the generalized', bold_it('cross-lag panel model'), 
                      'described as model 1 in the paper.',wrap_it(2),'Using the selection pane below, you can decide which depression report (i.e., self or \
                       parental reports) and cardio-metabolic risk (CMR) marker you want to model. You can then inspect the variables included in the \
                       model by clicking on the inspect icon or on the graph nodes directly. Check the table on the right for info on the model fit.',
                       wrap_it(), undeline_it('Note'),': by default, the best fitting model is presented (i.e. lowest AIC), but the parameter conbination \
                       can be constumized using the tickboxes on the right (don\'t forget to hit the ', badge_it('Update model', 'silver'),
                      ' button to see the changes).']),
            html.Hr(),
            # Input 
            dbc.Row([dbc.Col([
                         html.H5(children='Depression score', style={'textAlign':'left'}),
                         dcc.RadioItems(id='dep-selection',
                                        options=[{'label': 'Self-reported',  'value':'sDEP'},
                                                 {'label': 'Maternal report','value':'mDEP'}], value='sDEP', 
                                       inputStyle={'margin-left':'20px','margin-right':'20px'}),
                         html.Br(),
                         html.H5(children='Cardio-metabolic marker', style={'textAlign':'left'}),
                         dcc.Dropdown(id='cmr-selection', 
                                      options=[{'label': 'Fat mass index (FMI)', 'value': 'FMI'},
                                               {'label': 'Body mass index (BMI)', 'value': 'BMI'},
                                               {'label': 'Total fat mass', 'value': 'total_fatmass'},
                                               {'label': 'Waist circumference', 'value': 'waist_circ'}],
                                      value='FMI') ], width={'size': 5, 'offset': 1}), 
                      dbc.Col([
                         html.H5(children='Model estimation', style={'textAlign':'left'}),
                         param_checklist('sDEP', 'FMI', p='lt'),
                         param_checklist('sDEP', 'FMI', p='ma'),
                         html.Div( dbc.Button('Update model', id='update-button', color='secondary', n_clicks=0,
                                               style={'font-weight':'bold', 'background-color':'silver','padding':'4px 10px'}), 
                                  style={'width':'25%','height':'35%','float':'right'}),
                         html.Div( id='failed-model', style={'color':'red', 'width':'60%','height':'25%','float':'left'}),
                     ], width={'size': 5, 'offset': 1}), 
                    ]),
            html.Hr(),
            # Time plot 
            html.Div([ # html.I(className="bi bi-info-circle-fill me-2", style={'color':'black'}),
                dbc.Accordion([ dbc.AccordionItem([ dcc.Graph(id='time-graph', figure = make_plot1('sDEP','FMI'))],
                                                       title='Inspect the variables included in this model')], start_collapsed=True)]),
            dbc.Row([
                # Network
                dbc.Col([cyto.Cytoscape(id='cyto-graph',
                                    layout={'name': 'preset', 'fit':False},
                                     style={'width': '100%', 'height': '1000px'}, minZoom=1, maxZoom=1, # reduce the range of user zooming 
                                  elements=make_net1('sDEP', 'FMI'), 
                                stylesheet=stylenet1)], width=9), 
                 # Table
                dbc.Col([ html.Br(), html.Div(id='fitm-table', 
                                        children=[dbc.Table.from_dataframe(df = make_table1('sDEP', 'FMI'), 
                                                                           color='light', striped=True, bordered=True, hover=True, size='lg')]) ],
                        width=3) ]), 
                          
            # Pop variable descriptives
                dbc.Offcanvas(id='pop1', children=[ dcc.Graph()],
                                title='Lab name', is_open=False, placement='end', style={'width': 700})

        ])
    
    elif tab == 'tab-2': # ==============================================================================================
        return html.Div([ 
            html.Br(),
            html.Div(['Results of the', bold_it('cross-lag network analyis'), 'performed using the variables listed below.']),
            html.Hr(),
            html.Div( [html.Img(src=get_asset_url('tempPlot.png'), style={'width':'33%'}),
                       html.Img(src=get_asset_url('contempPlot.png'), style={'width':'33%'}),
                       html.Img(src=get_asset_url('betweenPlot.png'), style={'width':'33%'})])
            # dbc.Row( dbc.Col( [html.Img(src=get_asset_url('samplesizes.png'))], width={'size': 10, 'offset': 1}))
        ])
     
    elif tab == 'tab-3': # ==============================================================================================
        return html.Div([
            html.Br(), 
            html.Span(['Results of each', bold_it('cross-sectional network model'), 'conducted as follow-up analyses. \
            Please select the timepoint of interest from the slides to visualize the network structure and correspoding centrality indices.']),
            html.Br(), html.Hr(),
            # Slider
            html.Span(children=[ html.Div('Select a timepoint:'), html.Br(style={'line-height':'5'}) ]),
            dcc.Slider(id='cros-net-slider', min=9.7, max=24.2, step=None, value=9.7, marks=timemarks3, included=False ),
            dbc.Row([
                # Network
                dbc.Col([cyto.Cytoscape(id='cros-net', layout={'name':'preset'},
                          style={'width':'50%', 'height':'100%', 'position':'absolute', 'left':150, 'top':370, 'z-index':999},
                          minZoom=1, maxZoom=1, # reduce the range of user zooming 
                          elements = make_net3(9.7)[0], 
                          stylesheet= stylenet3)], width=6), 
                # Table
                dbc.Col([ html.Br(),
                          # dbc.Popover(id='pop', children=dbc.PopoverBody('okok'), target='cros-net', trigger='click'),
                          html.Pre(id='display-labels'),
                          dash_table.DataTable(id='ci-table', columns=[ {'name': i, 'id': i} for i in make_net3(9.7)[1].columns ],
                                               sort_action='custom', sort_mode='single', sort_by=[], 
                                               fixed_columns={'headers': True, 'data': 1}, # Fix node name column 
                                               # style_as_list_view=True, # Remove vertical lines between columns 
                                               style_header={'fontWeight':'bold'},
                                               style_cell={'fontSize':20, 'font-family':'sans-serif'},
                                               style_cell_conditional=[{'if': {'column_id': 'Node'}, 'width': '250px'}],
                                               style_data={'whiteSpace':'normal', 'height': 'auto','lineHeight':'20px', 
                                                           # 'minWidth': '100%',},
                                                           'minWidth': '100px', 'width': '100px', 'maxWidth': '100px'}, 
                                               style_table={'overflowX': 'auto', 'minWidth': '100%'})
                           ], width={'size': 5, 'offset': 1}),
                 # Pop variable descriptives
                 dbc.Offcanvas(id='pop3', children=[ dcc.Graph()],
                                title='Lab name', is_open=False, placement='end', style={'width': 700}), # backdrop
            ]) 
            
        ])

# Control variable selection, updating the scatterplot and the ticks on the model selection pane
@callback(
    Output('time-graph', 'figure'),
    Output('lt-checklist', 'value'),
    Output('ma-checklist', 'value'),
    Input('dep-selection', 'value'),
    Input('cmr-selection', 'value')
)
def update_time_plot(dep_selection, cmr_selection):
    return make_plot1(dep_selection, cmr_selection), best_fit(dep_selection, cmr_selection, 'lt'), best_fit(dep_selection, cmr_selection, 'ma')

# Based on variable and model selection display graph (or notify if model didn't converge
@callback(
    Output('cyto-graph', 'elements'),
    Output('fitm-table', 'children'),
    Output('failed-model','children'),
    Input('dep-selection', 'value'),
    Input('cmr-selection', 'value'),
    Input('update-button', 'n_clicks'),
    State('lt-checklist', 'value'),
    State('ma-checklist', 'value') # prevent_initial_call=True
)

def update_graph(dep_selection, cmr_selection, n_clicks, lt_checklist, ma_checklist):
    
    if ctx.triggered_id == 'update-button':
        
        checked = lt_checklist + ma_checklist
        series = pd.Series([ 1 if x in checked else 0 for x in model_structure.index], index=model_structure.index)
        retrieve_model = model_structure.columns[ model_structure.eq(series, axis=0).all() ]
        
        if len(retrieve_model) == 0: 
            return no_update, no_update, 'Sorry, I did not estimate this model.' 
            
        model_name = retrieve_model[0]
        
        update_graph = make_net1(dep_selection, cmr_selection, which_model=model_name)
        
        if update_graph == 'fail': # prevents any single output updating
             return no_update, no_update, 'Sorry, this model did not converge.' 

        tab_df = make_table1(dep_selection, cmr_selection, model=model_name)
        update_tab = dbc.Table.from_dataframe(df = tab_df, color='light', striped=True, bordered=True, hover=True, size='lg')
        
        return update_graph, update_tab, None

    tab_df = make_table1(dep_selection, cmr_selection)
    update_tab = dbc.Table.from_dataframe(df = tab_df, color='light', striped=True, bordered=True, hover=True, size='lg')

    return make_net1(dep_selection, cmr_selection), update_tab, None

# Dispaly info upon tapping on node
@callback(
    # Output('display-labels', 'children'),
    Output('pop1', 'is_open'),
    Output('pop1', 'title'),
    Output('pop1', 'children'),
    Input('cyto-graph', 'tapNodeData'),
    [State('pop1', 'is_open')],
) 
def displayTapNodeData1(node, is_open):
    if node:
        titl = node['label']
        
        if titl=='Eta':
            return is_open, no_update, no_update
            
        plot = dcc.Graph(figure=make_plot3(titl))
        
        return not is_open, titl, plot
        
    return is_open, no_update, no_update

# TAB 3: display cross-sectional network based on selected timpoint ------------------------------------------------
@callback(
    Output('cros-net', 'elements'),
    Output('ci-table', 'data'),
    Input('cros-net-slider', 'value'),
    Input('ci-table', 'sort_by')
)
def update_crosnet(timepoint, sort_by):
    
    if ctx.triggered_id == 'ci-table':
        
        dff = make_net3(timepoint)[1].sort_values( sort_by[0]['column_id'],
                                                  ascending=sort_by[0]['direction'] == 'desc', inplace=False )
        
        return no_update, dff.to_dict('records')
    
    return make_net3(timepoint)[0], make_net3(timepoint)[1].to_dict('records')

# Dispaly info upon tapping on node
@callback(
    # Output('display-labels', 'children'),
    Output('pop3', 'is_open'),
    Output('pop3', 'title'),
    Output('pop3', 'children'),
    Input('cros-net', 'tapNodeData'),
    [State('pop3', 'is_open')],
) 
def displayTapNodeData3(node, is_open):
    if node:
        titl = node['label']
        plot = dcc.Graph(figure=make_plot3(node['id']))
        return not is_open, titl, plot
        
    return is_open, no_update, no_update

if __name__ == '__main__':
    app.run(debug=True, jupyter_mode="external")