In [15]:
"""
Compute ROC data for EWS and ML predictions for SIR trained model
Export data for plotting
Adapted from the code published with Bury et al. (2021), Deep learning for early warning signals of tipping points, PNAS.
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sklearn.metrics as metrics
import scipy.stats as stats
import os

# Run for early or late predictions
bool_pred_early = False

# Change your test model
# test_models are: 'SIRwhiteN', 'SIRenvN', 'SIRdemN', 'SEIR'
test_model = 'SEIR'
classifier_length = 500                             # classifier_length options: 500, 100

#--------
# Import EWS and ML data
#--------

# Import EWS data
df_ews_forced = pd.read_csv('../data/ews/df_ews_forced_{}.csv'.format(test_model))
df_ews_null = pd.read_csv('../data/ews/df_ews_null_{}.csv'.format(test_model))

# Keep only the EWS rows whose 'Variable' column equals 'I'
df_ews_forced = df_ews_forced[df_ews_forced['Variable']=='I']
df_ews_null = df_ews_null[df_ews_null['Variable']=='I']

# Import kendall tau data
df_ktau_forced = pd.read_csv('../data/ews/df_ktau_forced_{}.csv'.format(test_model))
df_ktau_null = pd.read_csv('../data/ews/df_ktau_null_{}.csv'.format(test_model))

# Import ML prediction data
df_ml_forced = pd.read_csv('../data/ml_pred_SIR_CL{}/df_ml_forced_{}_len{}.csv'.format(classifier_length, test_model, classifier_length))
df_ml_null = pd.read_csv('../data/ml_pred_SIR_CL{}/df_ml_null_{}_len{}.csv'.format(classifier_length, test_model, classifier_length))

# Add column for truth values (1 for forced, 0 for null)
df_ktau_forced['truth value'] = 1
df_ktau_null['truth value'] = 0

df_ml_forced['truth value'] = 1
df_ml_null['truth value'] = 0

# Output directory for the CSV exports below. exist_ok=True replaces the
# separate exists() check, so there is no window between checking and creating
# and the cell can be re-run safely.
os.makedirs('../data/roc', exist_ok=True)

#---------------------------
# Get predictions from trajectories
#--------------------------

# Time interval relative to transition point for where to make predictions
# as proportion of dataset
if bool_pred_early:
    pred_interval_rel = np.array([0.6,0.8])
else:
    # Late interval for predictions
    pred_interval_rel = np.array([0.8, 1])

# Number of trailing rows kept per trajectory (same value for every tsid)
n_predictions = 5


def _select_final_preds(df_ews, df_ktau, df_ml, interval_rel, n_last):
    """
    Pick the last `n_last` Kendall tau and ML rows for every trajectory.

    Parameters
    ----------
    df_ews : DataFrame with 'tsid', 'Time' and 'residuals' columns; used only
        to locate each trajectory's start time and its last time with a
        non-missing residual.
    df_ktau : DataFrame with 'tsid' and 'Time' columns (Kendall tau values).
    df_ml : DataFrame with a 'tsid' column (ML predictions); its unique
        'tsid' values define which trajectories are processed.
    interval_rel : two-element sequence giving the start and end of the
        prediction window as fractions of the span from the start time to
        the last time with a residual.
    n_last : number of trailing rows to keep per trajectory.

    Returns
    -------
    (ktau_frames, ml_frames) : two lists of DataFrames, one entry per tsid,
        in the order the tsids first appear in `df_ml`.
    """
    ktau_frames = []
    ml_frames = []

    for tsid in df_ml['tsid'].unique():

        # Start time and last time at which 'Time' and 'residuals' are both present
        ews_ts = df_ews[df_ews['tsid'] == tsid]
        t_start = ews_ts['Time'].iloc[0]
        t_transition = ews_ts[['Time', 'residuals']].dropna()['Time'].iloc[-1]

        # Prediction window in absolute time
        t_pred_start = t_start + (t_transition - t_start) * interval_rel[0]
        t_pred_end = t_start + (t_transition - t_start) * interval_rel[1]

        # NOTE(review): the ML rows are NOT restricted to the prediction
        # window; the last n_last rows are taken regardless, so
        # bool_pred_early does not affect the ML predictions. Confirm this
        # is intended.
        ml_frames.append(df_ml[df_ml['tsid'] == tsid].tail(n_last))

        in_window = (
            (df_ktau['tsid'] == tsid)
            & (df_ktau['Time'] >= t_pred_start)
            & (df_ktau['Time'] <= t_pred_end)
        )
        ktau_frames.append(df_ktau[in_window].tail(n_last))

    return ktau_frames, ml_frames


# Initialise lists
list_df_ktau_preds = []
list_df_ml_preds = []

# Forced trajectories first, then null trajectories (order fixes row order downstream)
for df_ews, df_ktau, df_ml in (
    (df_ews_forced, df_ktau_forced, df_ml_forced),
    (df_ews_null, df_ktau_null, df_ml_null),
):
    ktau_frames, ml_frames = _select_final_preds(
        df_ews, df_ktau, df_ml, pred_interval_rel, n_predictions)
    list_df_ktau_preds.extend(ktau_frames)
    list_df_ml_preds.extend(ml_frames)


# Concatenate data
df_ktau_preds = pd.concat(list_df_ktau_preds)
df_ml_preds = pd.concat(list_df_ml_preds)


#-------------------
# Get data on ML favoured bifurcation for each trajectory
#-------------------

# Tag every prediction with the probability column that has the larger value
df_ml_preds['fav_bif'] = df_ml_preds[['null_prob','branch_prob']].idxmax(axis=1)

# 'early' / 'late' tag used in both count filenames
pred_timing = 'early' if bool_pred_early else 'late'

# --- Forced trajectories (truth value 1) ---
is_forced = df_ml_preds['truth value'] == 1
counts_forced = df_ml_preds.loc[is_forced, 'fav_bif'].value_counts()

# A class that was never favoured is absent from the counts, hence the 0 default
df_counts_forced = pd.DataFrame({
    'null': [counts_forced.get('null_prob', 0)],
    'branch': [counts_forced.get('branch_prob', 0)],
})

filepath_forced_counts = '../data/roc/df_bif_pred_counts_forced_{}_SIR_trained_{}_CL{}.csv'.format(
    pred_timing, test_model, classifier_length)
df_counts_forced.to_csv(filepath_forced_counts, index=False)

# --- Null trajectories (truth value 0) ---
is_null = df_ml_preds['truth value'] == 0
counts_null = df_ml_preds.loc[is_null, 'fav_bif'].value_counts()

df_counts_null = pd.DataFrame({
    'null': [counts_null.get('null_prob', 0)],
    'branch': [counts_null.get('branch_prob', 0)],
})

filepath_null_counts = '../data/roc/df_bif_pred_counts_null_{}_SIR_trained_{}_CL{}.csv'.format(
    pred_timing, test_model, classifier_length)
df_counts_null.to_csv(filepath_null_counts, index=False)

# Export the selected ML predictions, including the new 'fav_bif' column
df_ml_preds.to_csv(
    '../data/roc/df_ml_pred_SIR_trained_{}_CL{}.csv'.format(test_model, classifier_length),
    index=False)

print('Exported bifurcation count data to data/roc')


#--------------------
# Functions to compute ROC
#---------------------

def roc_compute(truth_vals, indicator_vals):
    """
    Compute ROC data from truth and indicator values and return it as a DataFrame.

    Parameters
    ----------
    truth_vals : array-like
        Truth labels passed to sklearn's roc_curve (callers in this notebook
        use 1 for forced and 0 for null trajectories).
    indicator_vals : array-like
        Indicator scores passed to sklearn's roc_curve.

    Returns
    -------
    pandas.DataFrame
        Columns 'fpr', 'tpr' and 'thresholds' as returned by roc_curve, plus
        an 'auc' column holding the area under the curve repeated on every row.
    """
    # ROC curve and thresholds from sklearn
    false_pos, true_pos, cutoffs = metrics.roc_curve(truth_vals, indicator_vals)

    roc_table = pd.DataFrame({
        'fpr': false_pos,
        'tpr': true_pos,
        'thresholds': cutoffs,
    })

    # Scalar AUC is broadcast down the column
    roc_table['auc'] = metrics.auc(false_pos, true_pos)

    return roc_table


#---------------------
## Compute ROC data
#---------------------

# (label written to the 'ews' column, prediction frame, indicator column),
# listed in the order the curves are appended and exported
roc_specs = [
    ('ML bif', df_ml_preds, 'bif_prob'),
    ('Variance', df_ktau_preds, 'ktau_variance'),
    ('Lag-1 AC', df_ktau_preds, 'ktau_ac'),
]

# One ROC dataframe per indicator
list_roc = []
for ews_label, df_preds, indicator_col in roc_specs:
    df_roc = roc_compute(df_preds['truth value'], df_preds[indicator_col])
    df_roc['ews'] = ews_label
    list_roc.append(df_roc)


# Stack the per-indicator ROC tables under a fresh 0..n-1 index
df_roc_full = pd.concat(list_roc, ignore_index=True)


# Export ROC data
filepath = '../data/roc/df_roc_cr_trans_{}_SIR_trained_{}_CL{}.csv'.format(
    'early' if bool_pred_early else 'late', test_model, classifier_length)

df_roc_full.to_csv(filepath, index=False)

print('Exported ROC data to {}'.format(filepath))

#-------------
# Plotly fig
#----------------

import plotly.graph_objects as go

fig = go.Figure()
df_roc = df_roc_full

# One ROC trace per indicator. Only the ML trace sets mode='lines' explicitly;
# the other two rely on Plotly's default mode.
trace_specs = [
    ('ML bif', {'mode': 'lines'}),
    ('Variance', {}),
    ('Lag-1 AC', {}),
]

AUC_list = []
for ews_label, extra_kwargs in trace_specs:
    df_trace = df_roc[df_roc['ews'] == ews_label]
    # AUC is constant within an indicator, so the first row is enough
    auc_rounded = df_trace.round(2)['auc'].iloc[0]
    fig.add_trace(
        go.Scatter(
            x=df_trace['fpr'],
            y=df_trace['tpr'],
            name='{} (AUC={})'.format(ews_label, auc_rounded),
            **extra_kwargs
        )
    )
    AUC_list.append(auc_rounded)

# Line y=x
diagonal = np.linspace(0, 1, 100)
fig.add_trace(
    go.Scatter(
        x=diagonal,
        y=diagonal,
        showlegend=False,
        line={'color': 'black', 'dash': 'dash'},
    )
)

# Axis titles are set before the inset axes (x2/y2) are configured below
fig.update_xaxes(title='False positive rate', range=[-0.01, 1])
fig.update_yaxes(title='True positive rate')

fig.update_layout(
    legend=dict(x=0.5, y=0.0),
    width=600,
    height=600,
    title='',
)

# Export the rounded AUC values (row order: ML bif, Variance, Lag-1 AC)
AUC = pd.DataFrame(AUC_list, columns=[test_model])
AUC.to_csv(
    '../data/roc/AUC_SIR_{}_CL{}.csv'.format(test_model, classifier_length),
    index=False)

# Favoured-class frequencies: 'T' takes the 'branch' count, 'N' the 'null' count
df_forced = pd.DataFrame({
    'Category': ['T', 'N'],
    'Frequency': [df_counts_forced['branch'][0], df_counts_forced['null'][0]],
})
df_null = pd.DataFrame({
    'Category': ['T', 'N'],
    'Frequency': [df_counts_null['branch'][0], df_counts_null['null'][0]],
})

# Bar traces drawn on a second pair of axes (x2/y2)
bar_specs = (
    (df_forced, 'Frequency among Trans. Sims.'),
    (df_null, 'Frequency among Null Sims.'),
)
for df_freq, legend_name in bar_specs:
    fig.add_trace(
        go.Bar(
            x=df_freq['Category'],
            y=df_freq['Frequency'],
            name=legend_name,
            xaxis='x2',
            yaxis='y2',
        )
    )

# Place the second axes as a small inset inside the main plot area
fig.update_layout(
    xaxis2=dict(domain=[0.85, 0.95], anchor='y2'),
    yaxis2=dict(domain=[0.3, 0.5], anchor='x2'),
)

fig.show()

# Text dump of the exported tables and the selected predictions
for printed in (df_counts_forced, df_counts_null, AUC, df_ml_preds, df_ktau_preds):
    print(printed)

Exported bifurcation count data to data/roc
Exported ROC data to ../data/roc/df_roc_cr_trans_late_SIR_trained_SEIR_CL500.csv


   null  branch
0    14      36
   null  branch
0    50       0
   SEIR
0  1.00
1  0.88
2  0.94
     null_prob  branch_prob  bif_prob  tsid    Time  truth value      fav_bif
45    0.797941     0.202059  0.202059     1   663.0            1    null_prob
46    0.863858     0.136142  0.136142     1   673.0            1    null_prob
47    0.872486     0.127514  0.127514     1   683.0            1    null_prob
48    0.645075     0.354925  0.354925     1   693.0            1    null_prob
49    0.327371     0.672629  0.672629     1   703.0            1  branch_prob
..         ...          ...       ...   ...     ...          ...          ...
495   0.995240     0.004760  0.004760    10  1450.0            0    null_prob
496   0.995424     0.004576  0.004576    10  1460.0            0    null_prob
497   0.998184     0.001816  0.001816    10  1470.0            0    null_prob
498   0.998899     0.001101  0.001101    10  1480.0            0    null_prob
499   0.999159     0.000841  0.000841    10  1

In [16]:
# print(df_ml_preds[0:50])
# print(df_ktau_preds[50:100])