# <ins>Dash application</ins>: **Longitudinal modelling of the co-development of depression and cardio-metabolic risk from childhood to young adulthood**

In [1]:
import pyreadr
import pandas as pd
import numpy as np

from dash import Dash, html, dcc, callback, Output, Input, dash_table
import dash_bootstrap_components as dbc
import dash_cytoscape as cyto

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

### Tab 1: generalized cross-lagged panel model 

Rscript 1 is designed to produce a single .RData file for each dep-cmr marker pair. This contains the following elements: 
- **`summ`**: a summary dataframe (with information about which marker is used and the timepoints included + mean ranges and number of observations)
- **`fit_meas`**: fit measures for every parameter combination, when the model converged.
- **`estimates`**: (unstandardized) estimates (+ bootsrapped SE, pvalues and CIs) for every parameter combination, when the model converged.
- **`failed`**: list of models that did not converge, with corresponding error or warning message.


In [10]:
def read_res1(depname, cmrname, path='/Users/Serena/Desktop/panel_network/results/'):
    res = pyreadr.read_r(f'{path}{depname}_{cmrname}.RData')
    summ = res['dat_summ']
    fitm = res['fit_meas'].T
    esti = res['estimates'].set_index('rep(f, nrow(es))')
    fail = res['failed']['V1'] #.rename(columns={'V1':'problem'})
    return(summ, fitm, esti, fail)

# summ, fitm, esti, fail = read_res1('sDEP','FMI') # use

Find the best fitting model (i.e., lowest AIC) and describe its structure using the `model_structure` matrix.

In [21]:
# Define the matrix of all possible paramter combinations. 
model_structure = pd.DataFrame(data= np.array([[1,0,0,1,0,1,0,1,1,1,1,1,1,1,1,1,1,1,1], # 1 = estimated, 0 = set to 0.
                                               [1,0,0,1,1,0,1,0,1,1,1,1,1,1,1,1,1,1,1],
                                               [1,0,1,0,0,1,1,1,0,1,1,1,1,1,1,1,1,1,1],
                                               [1,0,1,0,1,0,1,1,1,0,1,1,1,1,1,1,1,1,1],
                                               [1,1,1,1,1,1,1,1,1,1,0,0,1,0,1,0,1,1,1],
                                               [1,1,1,1,1,1,1,1,1,1,0,0,1,1,0,1,0,1,1],
                                               [1,1,1,1,1,1,1,1,1,1,0,1,0,0,1,1,1,0,1],
                                               [1,1,1,1,1,1,1,1,1,1,0,1,0,1,0,1,1,1,0]]), 
                               columns=['full_st','no_ma','no_maCL','no_maAR','no_ma_dep','no_ma_cmr',
                                        'no_maCL_dep','no_maCL_cmr','no_maAR_dep','no_maAR_cmr',
                                        'no_lt','no_ltCL','no_ltAR' ,'no_lt_dep','no_lt_cmr',
                                        'no_ltCL_dep','no_ltCL_cmr','no_ltAR_dep','no_ltAR_cmr'], 
                               index=['maCL_dep','maCL_cmr','maAR_dep','maAR_cmr',
                                      'ltCL_dep','ltCL_cmr','ltAR_dep','ltAR_cmr'])
# NOTE: this is the same as the mat matrix used to fit the models in Rscript 1.

def best_fit(depname, cmrname):
    fitm = read_res1(depname, cmrname)[1]
    # Best fitting model (lowest AIC)
    mod = fitm.index[fitm.aic == fitm.aic.min()][0]
    # Return a dataframe with its name and model structure
    return pd.DataFrame(model_structure[mod])
    

First, I plot the median and interquartile ranges of each measure included in the model, against time. This gives a more complete understandin of the data that is fed to the models.

In [4]:
def make_plot1(depname, cmrname):
    # load summary dataframe
    summ = read_res1(depname,cmrname)[0]
    
    # extract timepoints
    t_dep = [ float(x.split('_')[-1][:-1]) for x in summ.columns[:summ.shape[1]//2] ]
    t_cmr = [ float(x.split('_')[-1][:-1]) for x in summ.columns[summ.shape[1]//2:] ]

    # scatterplot function 
    def scat(t, name, fullname, shortname):
        means = summ.loc['Median', summ.columns.str.contains(name)]
        p = go.Scatter(x = t, y = means, 
                       error_y = dict(type='data', symmetric=False, # visible=True,
                                      array = summ.loc['3rd Qu.', summ.columns.str.contains(name)] - means,
                                      arrayminus = means - summ.loc['1st Qu.', summ.columns.str.contains(name)]),
                       name = fullname, text = [f'{shortname} {n}' for n in range(1,len(t)+1)],
                       marker = dict(size = 10, symbol = 'square',opacity = .8), opacity = .7,
                       hovertemplate = """ <b>%{text}</b> <br> Median: %{y:.2f} <br> Timepoint: %{x} years <br><extra></extra>""")
        return p

    fig = make_subplots(specs=[[{'secondary_y': True}]])
    
    fig.add_trace( scat(t_dep, depname, 'Depression score', 'DEP'), secondary_y=False)
    fig.add_trace( scat(t_cmr, cmrname, cmrname, cmrname), secondary_y=True )
    
    # Set y-axes
    fig.update_yaxes(title_text='<b>Depression score</b>', secondary_y=False, 
                     mirror=True, ticks='outside', showline=True, linecolor='black', gridcolor='lightgrey')
    fig.update_yaxes(title_text=f'<b>{cmrname}</b>', secondary_y=True,
                     mirror=True, ticks='outside', showline=True, linecolor='black', gridcolor='lightgrey')
    # Set x-axis 
    fig.update_xaxes(title_text='Years', mirror=True, ticks='outside', showline=True, linecolor='black', gridcolor='lightgrey')

    # Group timepoints in the background
    ymin = summ.min(axis=1)['1st Qu.']-1; ymax = summ.max(axis=1)['3rd Qu.']+2
    crosspoints = []
    for i in range(len(t_dep)):
        xmin = min(t_dep[i], t_cmr[i])-.2; xmax = max(t_dep[i], t_cmr[i])+.2
        # rectangles 
        crosspoints.append( dict(x0 = str(xmin), x1 = str(xmax), y0 = str(ymin), y1 = str(ymax),
                                 type='rect', xref='x', yref='y', fillcolor='lightgray', opacity=.3, line_width=0, layer='below') )  
        # text 
        fig.add_trace( go.Scatter(x=[xmin+.2], y=[ymax-1], mode='text', text=i+1, textposition='middle center',
                                  textfont_size=13, textfont_color='dimgray', showlegend=False) )
    # Background
    fig.update_layout(# title = dict(text='Included measures\n', font=dict(size=15), automargin=True, yref='paper'),
                      plot_bgcolor='white', shapes=crosspoints, margin=dict(l=20, r=20, t=20, b=20))
    
    return(fig)

Read in the results, select the best fitting model and construct its graph.
> `TODO`: make interactive parameter choice work and add fit measures to display 

In [5]:
def make_net1(depname, cmrname, width=1000, ar_terms=True, maar_terms=True, cl_terms=True, macl_terms=True):
    # read data
    summ, esti = read_res1(depname, cmrname)[0,2]
    
    # Best fitting model
    modstr = best_fit(depname, cmrname)
    
    # Extract the estimated paramters from the result files
    def extr_est(name, which, eta=False, moodel_output=esti):
        df = moodel_output.loc[modstr.columns[0]]
        if eta: l = list(round(df.loc[(df.lhs==f'eta_{name}')&(df.op=='=~'),'est'], 2))
        else:   l = list(round(df.loc[df.label.str.contains(name)].iloc[::-1]['est'], 2))
        return(l[which])

    # Ready to draw
    nt = summ.shape[1]//2 # Number of timepoints
    pos_top = 30; pos_bot = 550 # Vertical cohordinates (in pixel)
    vs = ['dep','cmr']
    
    e = [] # Initialize
   
    # Eta factors nodes
    for eta, pos in enumerate([pos_top, pos_bot]):
        e.append({'data': {'id':f'eta{eta}', 'label':'Eta'}, 'classes':'latent', 'position':{'x':width/2,'y':pos}})
    
    for i in range(1,nt+1): 
        for eta, v in enumerate(vs):
            
    # ===== Other nodes
            p = [pos_top+90, pos_top+190] if v=='dep' else [pos_bot-90,pos_bot-190] # define position
            e.extend([
                # Observed variables
                {'data': {'id':f'{v}{i}', 'label':f'{v.capitalize()} {i}'}, 'classes':'observed',
                 'position' : {'x':((width/nt)*i)-(width/nt)/2, 'y': p[0]} },
                # Impulses
                {'data': {'id':f'imp_{v}{i}', 'label':f'impulse {i}'}, 'classes':'latent',
                 'position' : {'x':((width/nt)*i)-(width/nt)/2, 'y': p[1]} }
            ])
        
    # ===== Edges: lambdas
            e.append({'data': {'source':f'eta{eta}', 'target':f'{v}{i}',
                               'label':'%.2f' % extr_est(f'{v}',i-1, eta=True) }})
            # impulses link
            e.append({'data': {'source':f'imp_{v}{i}', 'target':f'{v}{i}', 'label':'imp_link'}})
            
            if i < nt: 
                otherv = abs(eta-1)
                # maAR and AR terms
                if modstr[f'ltAR_{v}']: e.append({'data': {'source':f'{v}{i}', 'target':f'{v}{i+1}',
                                   'weight': extr_est(f'^AR_{v}', i-1), 'label': f'AR{i}'}})
                if modstr[f'maAR_{v}']: e.append({'data': {'source':f'imp_{v}{i}', 'target':f'{v}{i+1}',
                                   'weight': extr_est(f'^maAR_{v}',i-1), 'label': f'maAR{i}'}})
                # maCL and CL terms
                if modstr[f'ltCL_{v}']: e.append({'data': {'source':f'{v}{i}', 'target':f'{vs[otherv]}{i+1}',
                                   'weight': extr_est(f'^CL_{v}', i-1), 'label': f'CL{i}'}})
                if modstr[f'maCL_{v}']: e.append({'data': {'source':f'imp_{v}{i}', 'target':f'{vs[otherv]}{i+1}',
                                   'weight': extr_est(f'^maCL_{v}',i-1), 'label': f'maCL{i}'}})
                      
    return(e)

# ===================================================================================================================
# Also define the stile of the graph 
stylenet1=[ 
    # Nodes - shape & color
    {'selector':'.observed', 'style':{'shape':'rectangle', 'height':25, 'border-width':2,
                                      'background-color':'white', 'border-color':'k'}},
    {'selector':'.latent', 'style':{'shape':'round', 'height':20,'width':20, 'border-width':1,
                                    'background-color':'white', 'border-color':'silver'}},
    # Edges
    {'selector':'edge', 'style':{'target-arrow-shape':'vee', 'curve-style':'straight', 
                                 'width': 3, 'arrow-scale':1.2 }},
    
    {'selector':'[label *= "imp_link"]', 'style':{'width': 1, 'arrow-scale':.8 }},
    
    {'selector':'[source *= "eta"]', 'style':{'width': 1, 'arrow-scale':.8,
                                              'label':'data(label)','font-size':15,
                                              'text-background-color':'silver', 'text-background-opacity':.7 }},
]
# Set the color of each edge type and the distance between sorce and label displaying its weight (to avoid overlapping) 
d = {'AR':['red',60],'maAR':['orange',70],'CL':['green',230],'maCL':['lightblue',40]}

for c in d.keys():
    stylenet1.append({'selector':f'[label *= "{c}"]', 'style': {'line-color':d[c][0], 'target-arrow-color':d[c][0],
                                                    'text-background-color':d[c][0], 'text-background-opacity':.5,
                                                    'source-label':'data(weight)', 'source-text-offset': d[c][1],
                                                    'font-size':20, 'font-weight':'bold'}
                     }) # 'source-label':'data(weight)','source-text-offset': 50,'source-endpoint':['-10%','0%']


In [6]:
# def make_table(depname, cmrname):
#     summ,_,_,_ = read_res1(depname,cmrname)
#     times = pd.DataFrame([['Depression']+[x.split('_')[-1] for x in summ.columns[:summ.shape[1]//2]], 
#                ['Cardio-metabolic risk']+[x.split('_')[-1] for x in summ.columns[summ.shape[1]//2:]]])
#     # .rename(columns = {0:'Marker'}, index={0:'dep',1:'cmr'})
#     return(times)

### Tab 2: cross-lagged panel network model
Rscript 2 fits and returns the .RData file for the longitudinal cross-lagged panel network model.

### Tab 3: cross-sectional network models
Rscript 3 is designed to produce a single .RData file for each single-timepoint, cross-sectional network model. This contains the following elements:
- **`wm`**: dataframe with all edge weights
- **`ci`**: 95% confidence intervals for those weights
- **`n_obs`**: number of observations the network is based on


In [7]:
def read_res3(time, path='/Users/Serena/Desktop/panel_network/results/'):
    res = pyreadr.read_r(f'{path}crosnet_{time}.RData')
    # weight matrix
    wm = res['wm']; wm['link'] = wm.index; wm[['a','b']] = wm.link.str.split(' ', expand = True)
    wm = wm.loc[wm.a!=wm.b, ] # remove links to between an edge and itself
    wm = wm.reset_index()[['a','b','V1']].rename(columns={'a':'node1','b':'node2','V1':'weight'})
    wm['dir'] = ['neg' if x<0 else 'pos' for x in wm.weight]
    # centrality indices
    ci = res['ci']; ci['class'] = ['dep' if t else 'cmr' for t in ci.node.str.contains('DEP')]
    # number of observations 
    nobs = int(res['n_obs'].iloc[0])
    return(wm, ci, nobs)

wm,ci,n = read_res3('9.6y-9.8y')

wm_trim = wm.loc[abs(wm.weight)>0.01,].reset_index(drop=True)

nodes = [{'data': {'id':node, 'label':node}, 'classes':group } 
       for node,group in ci[['node','class']].itertuples(index=False) ]
edges = [{'data': {'source':a, 'target':b, 'weight':w, 'width':round(abs(w)*20,2)}, 'classes':c} 
       for a,b,w,c in wm_trim.itertuples(index=False)]
net = nodes+edges

  nobs = int(res['n_obs'].iloc[0])


## Set-up

#### App layout 
The application is structured into 3 main tabs. Add radio buttons to the app layout. 

#### Interactive graphs
I use the `plotly.express` library to build the interactive graphs. These are then assigned to the figure property of `dcc.Graph`, the compontent of the "Dash Core Components" module used to render interactive graphs.

#### Callback
Then, build the **callback** to create the interaction between the buttons and the chart. To work with the callback, import the callback module and the two arguments commonly used within the callback: Output and Input.

Both the RadioItems and the Graph components were given id names: used by the callback to identify the components.

The inputs and outputs of our app are the properties of a particular component. For example, input is the value property of the component that has the ID "controls-and-radio-item". Output is the figure property of the component with the ID "controls-and-graph", which is currently an empty dictionary (empty graph).

The callback function's argument col_chosen refers to the component property of the input. We build the chart inside the callback function. Every time the user selects a new radio item, the figure is rebuilt / updated. Return the graph at the end of the function. This assigns it to the figure property of the dcc.Graph, thus displaying the figure in the app.

In [9]:
app = Dash(__name__, external_stylesheets=[dbc.themes.LITERA])

app.layout = html.Div([
    # Title
    html.H1('Longitudinal modelling of the co-development of depression and cardio-metabolic risk from childhood to young adulthood',
             style={'textAlign':'center', 'font-weight':'bold'}),
    html.Br(), # space
    # Main body
    dbc.Row([
        dbc.Col(width=1), # add left margin
        dbc.Col([ 
            dcc.Tabs(id="tabs", value='tab-1', children=[
                dcc.Tab(label='Cross-lag panel model', value='tab-1'), # style={''}
                dcc.Tab(label='Cross-lag network analysis', value='tab-2'),
                dcc.Tab(label='Cross-sectional network analysis', value='tab-3') ]),
            html.Div(id='tabs-content') ]),
        dbc.Col(width=1), # add right margin
    ])
])

@callback(Output('tabs-content', 'children'), Input('tabs', 'value'))

def render_content(tab):
    if tab == 'tab-1':
        return  html.Div([ 
            html.Br(),
            html.Span(['This are the results of the generalized ', dbc.Badge('cross-lag panel model', color='royalblue'),
                       ' described as model 1 in the paper. You can select the input for the models below.']),
            html.Hr(),
            # Input 
            dbc.Row([dbc.Col(width=1), # add left margin
                     dbc.Col([
                         html.H5(children='Depression score', style={'textAlign':'left'}),
                         dcc.RadioItems(id='dep-selection',
                                        options=[{'label': 'Self-reported', 'value': 'sDEP'},
                                                 {'label': 'Maternal report', 'value': 'mDEP'}],
                                        value='sDEP',
                                        inputStyle={'margin-left':'20px','margin-right':'20px'}),
                         html.Br(),
                         html.H5(children='Cardio-metabolic marker', style={'textAlign':'left'}),
                         dcc.Dropdown(options=[{'label': 'Fat mass index (FMI)', 'value': 'FMI'},
                                               {'label': 'Body mass index (BMI)', 'value': 'BMI'},
                                               {'label': 'Total fat mass', 'value': 'total_fatmass'},
                                               {'label': 'Waist circumference', 'value': 'waist_circ'}],
                                      value='FMI', id='cmr-selection')
                     ]), 
                     dbc.Col(width=1),
                     dbc.Col([
                         html.H5(children='Model estimation', style={'textAlign':'left'}),
                         dbc.Accordion([
                             dbc.AccordionItem([
                                 dcc.RadioItems(id='ar-model',
                                        options=[{'label': 'AR only', 'value': 'ar'},
                                                 {'label': 'AR + maAR', 'value': 'ar+maar'},
                                                 {'label': 'maAR only', 'value': 'maar'}],
                                        value='ar', inline=True,
                                        inputStyle={'margin-left':'20px','margin-right':'20px'})
                             ], title='Autoregressive parameters'),
                             dbc.AccordionItem([
                                 dcc.RadioItems(id='ar-model',
                                        options=[{'label': 'CL only', 'value': 'cl'},
                                                 {'label': 'CL + maCL', 'value': 'cl+macl'},
                                                 {'label': 'maCL only', 'value': 'macl'}],
                                        value='cl', inline=True,
                                        inputStyle={'margin-left':'20px','margin-right':'20px'})
                             ], title='Cross-lag parameters'),
                             dbc.AccordionItem([
                                 html.P("TBD")
                             ], title='Other parameters')],  start_collapsed=True) ]), 
                     dbc.Col(width=1) # add right margin
                    ]),
            html.Hr(),
            # Table
            dcc.Graph(id='time-graph', figure = make_plot1('sDEP','FMI')),
            # dash_table.DataTable(id='time-table', data = make_table('sDEP', 'FMI').to_dict('records'), page_size=10),
            # Network
            cyto.Cytoscape(id='cyto-graph',
                layout={'name': 'preset', 'fit':False},
                style={'width': '100%', 'height': '1000px'},
                minZoom=0.8, maxZoom=1, # reduce the range of user zooming 
                elements = make_net1('sDEP', 'FMI'), 
                stylesheet=stylenet1)
        ])
    
    elif tab == 'tab-3':
        return html.Div([
            html.Br(), 
            html.Span(['This are the results of the cross-sectional ', dbc.Badge('network model', color='crimson'),
                       ' described as a follow-up analysis in the paper. You can select the timpoint of interest below.']),
            html.Br(), 
            html.Hr(),
            html.Span('Select a timepoint:\n '),
            dcc.Slider(9.7, 25, step=None, value=9.7,
                       marks={ 9.7: {'label': '\n9.6-9.8 years', 
                                         'style': {'transform':'rotate(45deg)','whitespace':'nowrap'}},
                              10.5: {'label': '\n10.5 years', 
                                     'style': {'transform':'rotate(45deg)','whitespace':'nowrap'}},
                              15.5: {'label': ' 15.5 years'},
                              23.8: {'label': ' 23.8 years', 'style': {'color': '#f50'}} }, included=False ),
            
            cyto.Cytoscape(id='cros-net', 
                           layout={'name': 'cose', 'fit':True, 'padding':1, 
                                   'tilingPaddingVertical': 100,'tilingPaddingHorizontal': 100,
                                   'animation':False,
                                   'nodeRepulsion': 1000000, # 'gravity':0,'gravityRange':100,
                                   'nodeDimensionsIncludeLabels':True,
                                   # 'idealEdgeLength':0.0001,
                                   'minNodeSpacing':50},
                style={'width': '80%','height': '100%','position': 'absolute',
                       'left': 0,'top': 390,'z-index': 999},
                           # style={'width': '80%', 'height': '700px'},
                minZoom=1.1, maxZoom=1.1, # reduce the range of user zooming 
                elements = net, 
                stylesheet=[
                    {'selector': 'node', 'style': {'label': 'data(label)'} },
                     # Edge opacty and width
                    {'selector': 'edge', 'style': {'opacity': 'data(weight)',
                                                   'width': 'data(width)'}},
                    # Color nodes by group
                    {'selector': '.dep', 'style': {'background-color': 'lightblue'} },
                    {'selector': '.cmr', 'style': {'background-color': 'pink'} },
                    # Color edges by positive/negative weights
                    {'selector': '.neg', 'style': {'line-color': 'red'} },
                    {'selector': '.pos', 'style': {'line-color': 'blue'} },
                   
                ])
        ])


# Add controls to build the interaction
@callback(
    Output('time-graph', 'figure'),
    Input('dep-selection', 'value'),
    Input('cmr-selection', 'value')
)
def update_plot(dep_selection, cmr_selection):
    return make_plot1(dep_selection, cmr_selection)

@callback(
    Output('cyto-graph', 'elements'),
    Input('dep-selection', 'value'),
    Input('cmr-selection', 'value')
)
def update_graph(dep_selection, cmr_selection):
    return make_net1(dep_selection, cmr_selection)

# @callback(
#     Output('time-table', 'data'),
#     Input('dep-selection', 'value'),
#     Input('cmr-selection', 'value')
# )
# def update_table(dep_selection, cmr_selection):
#     return make_table(dep_selection, cmr_selection).to_dict('records')


if __name__ == '__main__':
    app.run(debug=True, jupyter_mode="external")

Dash app running on http://127.0.0.1:8050/
