In [15]:
"""
Compute ROC data for EWS and ML predictions from the model trained by Bury et al.
Export the data for plotting.
Adapted from the code published by Bury et al. (2021), 'Deep learning for early warning signals of tipping points', PNAS.
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sklearn.metrics as metrics
import scipy.stats as stats
import os

# Run for early or late predictions
bool_pred_early = False

# Change your test model
# test_models are: 'SIRwhiteN', 'SIRenvN', 'SIRdemN', 'SEIR'
test_model = 'SIRwhiteN'

#--------
# Import EWS and ML data
#--------------

# Import EWS data
df_ews_forced = pd.read_csv('../data/ews/df_ews_forced_{}.csv'.format(test_model))
df_ews_null = pd.read_csv('../data/ews/df_ews_null_{}.csv'.format(test_model))

# Keep only the EWS rows whose 'Variable' column equals 'I'
df_ews_forced = df_ews_forced[df_ews_forced['Variable']=='I']
df_ews_null = df_ews_null[df_ews_null['Variable']=='I']

# Import kendall tau data
df_ktau_forced = pd.read_csv('../data/ews/df_ktau_forced_{}.csv'.format(test_model))
df_ktau_null = pd.read_csv('../data/ews/df_ktau_null_{}.csv'.format(test_model))

# Import ML prediction data
df_ml_forced = pd.read_csv('../data/ml_pred_Bury/df_ml_forced_Bury_{}.csv'.format(test_model))
df_ml_null = pd.read_csv('../data/ml_pred_Bury/df_ml_null_Bury_{}.csv'.format(test_model))

# Add column for truth values (1 for forced, 0 for null)
df_ktau_forced['truth value'] = 1
df_ktau_null['truth value'] = 0

df_ml_forced['truth value'] = 1
df_ml_null['truth value'] = 0

# Make sure the output directory exists. exist_ok=True avoids the
# check-then-create race of testing os.path.exists first.
os.makedirs('../data/roc', exist_ok=True)


#---------------------------
# Get predictions from trajectories
#--------------------------

# Time interval relative to transition point for where to make predictions
# as proportion of dataset
if bool_pred_early:
    pred_interval_rel = np.array([0.6,0.8])
else:
    # Late interval for predictions
    pred_interval_rel = np.array([0.8,1])

# Number of evenly spaced predictions taken from each trajectory.
# We do this so some trajectories don't input more data to the ROC
# than others.
n_predictions = 10


def _prediction_window(df_ews, tsid, interval_rel):
    """Return (t_pred_start, t_pred_end) for the trajectory `tsid`.

    The start time is the first 'Time' value of the trajectory in `df_ews`.
    The end of the series is taken as the last 'Time' value that still has a
    non-missing 'residuals' entry. `interval_rel` holds the two proportions
    of that span that bound the prediction window.

    Raises IndexError if `df_ews` has no rows (or no residuals) for `tsid`.
    """
    df_tsid = df_ews[df_ews['tsid'] == tsid]
    t_start = df_tsid['Time'].iloc[0]
    # Where the residuals end
    t_transition = df_tsid[['Time', 'residuals']].dropna()['Time'].iloc[-1]
    span = t_transition - t_start
    return t_start + span*interval_rel[0], t_start + span*interval_rel[1]


def _sample_window(df_pred, tsid, t_lo, t_hi, n_rows):
    """Return `n_rows` evenly spaced rows of `df_pred` inside [t_lo, t_hi].

    Only rows belonging to `tsid` are considered. Row positions are rounded
    to the nearest integer, so a window holding fewer than `n_rows` rows
    yields repeated rows. Returns None when the window holds no rows at all
    (indexing an empty frame would otherwise raise an IndexError).
    """
    in_window = df_pred[
        (df_pred['tsid'] == tsid)
        & (df_pred['Time'] >= t_lo)
        & (df_pred['Time'] <= t_hi)
    ]
    if in_window.empty:
        return None
    idx = np.round(np.linspace(0, len(in_window) - 1, n_rows)).astype(int)
    return in_window.iloc[idx]


def _collect_predictions(df_ews, df_ktau, df_ml, interval_rel, n_rows):
    """Sample ktau and ML predictions for every trajectory listed in `df_ml`.

    Returns two lists of DataFrames (ktau samples, ML samples), one entry per
    trajectory that has data inside its prediction window. Trajectories with
    an empty window are reported on stdout and skipped.
    """
    ktau_samples = []
    ml_samples = []
    for tsid in df_ml['tsid'].unique():
        t_lo, t_hi = _prediction_window(df_ews, tsid, interval_rel)
        for label, df_pred, samples in (
            ('ktau', df_ktau, ktau_samples),
            ('ML', df_ml, ml_samples),
        ):
            sampled = _sample_window(df_pred, tsid, t_lo, t_hi, n_rows)
            if sampled is None:
                print('Warning: no {} predictions in [{}, {}] for tsid {}; skipped'.format(
                    label, t_lo, t_hi, tsid))
                continue
            samples.append(sampled)
    return ktau_samples, ml_samples


# Initialise lists
list_df_ktau_preds = []
list_df_ml_preds = []

# Get predictions from forced trajectories first, then from null trajectories
for df_ews_src, df_ktau_src, df_ml_src in (
    (df_ews_forced, df_ktau_forced, df_ml_forced),
    (df_ews_null, df_ktau_null, df_ml_null),
):
    ktau_samples, ml_samples = _collect_predictions(
        df_ews_src, df_ktau_src, df_ml_src, pred_interval_rel, n_predictions)
    list_df_ktau_preds.extend(ktau_samples)
    list_df_ml_preds.extend(ml_samples)

# Concatenate data
df_ktau_preds = pd.concat(list_df_ktau_preds)
df_ml_preds = pd.concat(list_df_ml_preds)


#-------------------
# Get data on ML favoured bifurcation for each trajectory
#-------------------

# For each prediction, select the bifurcation that the ML gives greatest weight to
df_ml_preds['fav_bif'] = df_ml_preds[['fold_prob','hopf_prob','branch_prob','null_prob']].idxmax(axis=1)


def _count_favoured_bifurcations(df_preds, truth_value):
    """Count the favoured bifurcation among predictions with `truth_value`.

    Returns a one-row DataFrame with columns 'fold', 'hopf', 'branch' and
    'null'. A class that is never favoured gets a count of 0.
    """
    counts = df_preds[df_preds['truth value'] == truth_value]['fav_bif'].value_counts()
    return pd.DataFrame({
        'fold': [counts.get('fold_prob', 0)],
        'hopf': [counts.get('hopf_prob', 0)],
        'branch': [counts.get('branch_prob', 0)],
        'null': [counts.get('null_prob', 0)],
    })


# Make sure the output directory exists before any export
os.makedirs('../data/roc', exist_ok=True)

# NOTE(review): the file names below say 'Bauch' although the predictions are
# read from the 'ml_pred_Bury' files -- confirm this is intended and does not
# overwrite the output of another script.

# Count each bifurcation choice for forced trajectories and export
df_counts_forced = _count_favoured_bifurcations(df_ml_preds, 1)

filepath_forced = '../data/roc/df_bif_pred_counts_forced_Bauch_{}_{}.csv'.format(
    'early' if bool_pred_early else 'late', test_model)

df_counts_forced.to_csv(filepath_forced, index=False)

# Count each bifurcation choice for null trajectories and export
df_counts_null = _count_favoured_bifurcations(df_ml_preds, 0)

filepath_null = '../data/roc/df_bif_pred_counts_null_Bauch_{}_{}.csv'.format(
    'early' if bool_pred_early else 'late', test_model)

df_counts_null.to_csv(filepath_null, index=False)

# Export ml prediction file
# NOTE(review): this file name carries no early/late tag, so an early run and
# a late run for the same test model write to the same file -- confirm.
df_ml_preds.to_csv('../data/roc/df_ml_pred_Bauch_{}.csv'.format(test_model),
                   index=False)

print('Exported bifurcation count data to data/roc')

#--------------------
# Functions to compute ROC
#---------------------

# Function to compute ROC data from truth and indicator vals
# and return a df.
def roc_compute(truth_vals, indicator_vals):
    """Build a DataFrame describing the ROC curve of an indicator.

    Parameters
    ----------
    truth_vals : labels passed to ``metrics.roc_curve`` as the true classes.
    indicator_vals : scores passed to ``metrics.roc_curve``.

    Returns
    -------
    DataFrame with columns 'fpr', 'tpr' and 'thresholds' (one row per point
    returned by ``metrics.roc_curve``) and 'auc', the area under that curve
    repeated on every row.
    """
    # ROC curve and thresholds from sklearn
    false_pos_rate, true_pos_rate, cutoffs = metrics.roc_curve(
        truth_vals, indicator_vals)

    roc_table = pd.DataFrame({
        'fpr': false_pos_rate,
        'tpr': true_pos_rate,
        'thresholds': cutoffs,
    })

    # Area under the curve: a single scalar broadcast down the column
    roc_table['auc'] = metrics.auc(false_pos_rate, true_pos_rate)

    return roc_table

#---------------------
## Compute ROC data
#---------------------

# One entry per indicator:
# (label written to the 'ews' column, frame holding the predictions,
#  column used as the indicator value)
roc_specs = [
    ('ML bif', df_ml_preds, 'bif_prob'),
    ('Variance', df_ktau_preds, 'ktau_variance'),
    ('Lag-1 AC', df_ktau_preds, 'ktau_ac'),
]

# Build one ROC dataframe per indicator, tagged with its label
list_roc = []
for ews_label, df_source, indicator_col in roc_specs:
    df_roc = roc_compute(df_source['truth value'], df_source[indicator_col])
    df_roc['ews'] = ews_label
    list_roc.append(df_roc)

# Stack the ROC dataframes into a single table
df_roc_full = pd.concat(list_roc, ignore_index=True)

# Export ROC data
interval_tag = 'early' if bool_pred_early else 'late'
filepath = '../data/roc/df_roc_cr_trans_Bauch_{}_{}.csv'.format(
    interval_tag, test_model)
df_roc_full.to_csv(filepath, index=False)

print('Exported ROC data to {}'.format(filepath))

#-------------
# Plotly fig
#----------------

import plotly.graph_objects as go

fig = go.Figure()
df_roc = df_roc_full

# Rounded AUC of each indicator, in plotting order
AUC_list = []

# One ROC trace per indicator: (value of the 'ews' column, extra Scatter
# arguments). Only the ML trace sets the drawing mode explicitly.
trace_specs = [
    ('ML bif', {'mode': 'lines'}),
    ('Variance', {}),
    ('Lag-1 AC', {}),
]

for ews_label, extra_scatter_args in trace_specs:
    df_trace = df_roc[df_roc['ews'] == ews_label]
    auc_rounded = df_trace.round(2)['auc'].iloc[0]
    fig.add_trace(
        go.Scatter(
            x=df_trace['fpr'],
            y=df_trace['tpr'],
            name='{} (AUC={})'.format(ews_label, auc_rounded),
            **extra_scatter_args
        )
    )
    AUC_list.append(auc_rounded)

# Dashed black reference line y=x, kept out of the legend
diagonal = np.linspace(0, 1, 100)
fig.add_trace(
    go.Scatter(
        x=diagonal,
        y=diagonal,
        showlegend=False,
        line={'color': 'black', 'dash': 'dash'},
    )
)

fig.update_xaxes(title='False positive rate', range=[-0.01, 1])
fig.update_yaxes(title='True positive rate')

fig.update_layout(
    legend={'x': 0.5, 'y': 0},
    width=600,
    height=600,
    title='',
)

# Export the rounded AUC values, one row per indicator
AUC = pd.DataFrame(AUC_list, columns=[test_model])
AUC.to_csv('../data/roc/AUC_Bury_{}.csv'.format(test_model), index=False)

# Frequency of each favoured bifurcation, read from the one-row count tables
bif_letters = ['F', 'H', 'T', 'N']
count_columns = ['fold', 'hopf', 'branch', 'null']

data_forced = {
    'Category': list(bif_letters),
    'Frequency': [df_counts_forced[col][0] for col in count_columns],
}
df_forced = pd.DataFrame(data_forced)

data_null = {
    'Category': list(bif_letters),
    'Frequency': [df_counts_null[col][0] for col in count_columns],
}
df_null = pd.DataFrame(data_null)

# Draw both frequency bar charts on the secondary axes (x2, y2)
for df_freq, bar_name in (
    (df_forced, 'Frequency among Trans. Sims.'),
    (df_null, 'Frequency among Null Sims.'),
):
    fig.add_trace(
        go.Bar(
            x=df_freq['Category'],
            y=df_freq['Frequency'],
            name=bar_name,
            xaxis='x2',
            yaxis='y2',
        )
    )

# Place the secondary axes in a small inset region of the figure
fig.update_layout(
    xaxis2={'domain': [0.8, 1.0], 'anchor': 'y2'},
    yaxis2={'domain': [0.3, 0.5], 'anchor': 'x2'},
)

fig.show()

for summary in (df_counts_forced, df_counts_null, AUC):
    print(summary)



Exported bifurcation count data to data/roc
Exported ROC data to ../data/roc/df_roc_cr_trans_Bauch_late_SIRwhiteN.csv


   fold  hopf  branch  null
0    51     1      45     3
   fold  hopf  branch  null
0    41     0       8    51
   SIRwhiteN
0       0.93
1       0.83
2       0.74
