# Figure generation

## Imports

In [1]:
from figure_utils.fishbait_figures import PlotKFoldSingleBarUsingWAvgPreds, PlotKFoldComboBarUsingWAvgPreds, PlotKFoldResidualHistUsingWAvgPreds
from figure_utils.fishbait_figures import PlotBaseModelLossfunResults, PlotPCA_CLSProjection
from figure_utils.fishbait_figures import PlotQSARresidualScatter, PlotQSARcompBarUsingWAvgPredsInterersect, PlotQSARcompScatter, PlotQSARCoverageComboBar
from figure_utils.preprocess_data import Preprocess10x10Fold, GroupDataForPerformance
from figure_utils.preprocess_qsar import LoadQSAR, MatchQSAR, PrepareQSARData

In [2]:
import numpy as np
import pandas as pd

## Make figures

In [3]:
# Run configuration consumed by the function-based cells below.
SPECIES_GROUP = 'fish'
ENDPOINT = 'EC10'
# Built from ENDPOINT so the model name cannot drift from the endpoint it belongs to.
NAME = f'{ENDPOINT}_{SPECIES_GROUP}'
# ['all'] is expanded to the full list of effect codes inside PlotEndpointSpecific.
EFFECT_FOR_QSAR_COMP = ['all']
# True tags the QSAR figures with 'AD', False with 'allpreds'.
INSIDE_AD = True
DURATIONS = ['short exposure', 'medium exposure', 'long exposure']  # <= 96h, <= 1week, >1week

### Non-endpoint-specific

In [4]:
def PlotNonEndpointSpecific(combomodel, species_group, save_figures: bool=False):
    """Draw the figures that are not tied to a single endpoint.

    Calls, in this order:
      1. PlotKFoldComboBarUsingWAvgPreds   (only when species_group == 'fish')
      2. PlotKFoldSingleBarUsingWAvgPreds
      3. PlotQSARCoverageComboBar, once with inside_AD=True and once with False
      4. PlotBaseModelLossfunResults       (only when species_group == 'fish')

    Parameters
    ----------
    combomodel
        Forwarded to PlotKFoldComboBarUsingWAvgPreds and used in its file name.
    species_group
        Species group; selects the 'EC50_<group>' / 'EC10_<group>' names and
        enables the two fish-only figures.
    save_figures
        When True each plotting helper receives a path under
        ./figures_for_publication/, otherwise it receives None.
    """
    fish_only = species_group == 'fish'

    if fish_only:
        combo_path = f"./figures_for_publication/combo_model_{combomodel}_performance_10x10_CV" if save_figures else None
        PlotKFoldComboBarUsingWAvgPreds(combo_path, combomodel, species_group)

    single_path = f"./figures_for_publication/{species_group}_single_model_performance_10x10_CV" if save_figures else None
    PlotKFoldSingleBarUsingWAvgPreds(single_path, ec50_name=f'EC50_{species_group}', ec10_name=f'EC10_{species_group}')

    # The AD tag only ever ends up in the file name of the coverage figure.
    for inside_AD, AD in ((True, 'AD'), (False, 'allpreds')):
        coverage_path = f"./figures_for_publication/{species_group}_QSAR_coverage_comparison_{AD}" if save_figures else None
        PlotQSARCoverageComboBar(coverage_path, inside_AD=inside_AD, species_group=species_group)

    if fish_only:
        lossfun_path = f"./figures_for_publication/base_model_and_lossfun" if save_figures else None
        PlotBaseModelLossfunResults(lossfun_path)
        
    

In [None]:
# Preview the non-endpoint-specific figures for the configured species group without writing files.
PlotNonEndpointSpecific(
    f'EC50EC10_{SPECIES_GROUP}_withoverlap',
    species_group=SPECIES_GROUP,
    save_figures=False,
)

### Endpoint-specific

In [12]:
def PlotEndpointSpecific(endpoint, effect_for_qsar_comp, name, inside_AD, durations, save_figures: bool=False, species_group=None):
    """Print cross-validation error statistics and draw the endpoint-specific figures.

    Steps, in order:
      1. Load the predictions for `name` with Preprocess10x10Fold, keep the rows
         whose `effect` is in `effect_for_qsar_comp`, aggregate them with
         GroupDataForPerformance and print mean/median L1 error, RMSE and
         standard error of the aggregated predictions.
      2. Draw the residual histogram.
      3. Load and prepare the ECOSAR / VEGA / TEST predictions, match them to
         the filtered predictions with MatchQSAR, compute each tool's residual
         against `labels`, and draw the comparison bar, scatter and residual
         scatter figures.

    Parameters
    ----------
    endpoint
        Endpoint name, e.g. 'EC10'; also the sub-folder used for saved figures.
    effect_for_qsar_comp
        List of effect codes to keep. ['all'] is expanded to
        ['MOR','DVP','ITX','MPH','REP','POP','GRO'].
    name
        Model/result name passed to Preprocess10x10Fold and used in file names.
    inside_AD
        Forwarded to PrepareQSARData; also selects the 'AD' / 'allpreds' file-name tag.
        (AD presumably stands for applicability domain - confirm in figure_utils.)
    durations
        Forwarded to MatchQSAR as `duration`.
    save_figures
        When True the figures are saved under ./figures_for_publication/<endpoint>/,
        otherwise every plotting helper receives None as its save path.
    species_group
        Species group forwarded to the QSAR helpers. Defaults to the notebook-level
        SPECIES_GROUP constant when omitted.
    """
    # The body used to read an undeclared global `species_group`, which only exists
    # after the manual cells further down have run. Resolve it explicitly instead.
    if species_group is None:
        species_group = SPECIES_GROUP

    AD = 'AD' if inside_AD else 'allpreds'
    if effect_for_qsar_comp == ['all']:
        effect_for_qsar_comp = ['MOR','DVP','ITX','MPH','REP','POP','GRO']
    effect_tag = ''.join(effect_for_qsar_comp)

    def _savepath(stem):
        # Plotting helpers are given None when nothing should be written to disk.
        return f"./figures_for_publication/{endpoint}/{stem}" if save_figures else None

    # ----------- fishbAIT ------------------------
    avg_predictions = Preprocess10x10Fold(name=name)
    avg_predictions = avg_predictions[avg_predictions.effect.isin(effect_for_qsar_comp)]
    weighted_avg_predictions = GroupDataForPerformance(avg_predictions=avg_predictions)

    mean_L1 = weighted_avg_predictions.L1error.mean()
    median_L1 = weighted_avg_predictions.L1error.median()
    RMSE = np.sqrt((weighted_avg_predictions.residuals**2).mean())
    se = np.std(weighted_avg_predictions.L1error)/np.sqrt(len(weighted_avg_predictions))
    print(f'''
    (weighted avg) Avg prediction error: {mean_L1}
    (weighted avg) Median prediction error: {median_L1}
    (weighted avg) RMSE: {RMSE}
    (weighted avg) standard error: {se}
    ''')

    PlotKFoldResidualHistUsingWAvgPreds(_savepath(f"{name}_weightedavgpreds_10x10fold_hist"), name=name, endpoint=endpoint)

    # ----------- QSAR ------------------------
    # NOTE(review): LoadQSAR is unpacked as (ECOSAR, VEGA, TEST) while PrepareQSARData
    # takes and returns (ECOSAR, TEST, VEGA); kept as in the original - confirm against figure_utils.
    ECOSAR, VEGA, TEST = LoadQSAR(endpoint=endpoint, species_group=species_group)
    ECOSAR, TEST, VEGA = PrepareQSARData(ECOSAR, TEST, VEGA, inside_AD=inside_AD, remove_experimental=True)
    TEST_dict = dict(zip(TEST.Canonical_SMILES_figures, TEST.value))
    VEGA_dict = dict(zip(VEGA.Canonical_SMILES_figures, VEGA.value))
    ECOSAR_dict = dict(zip(ECOSAR.Canonical_SMILES_figures, ECOSAR.value))

    qsar_preds = MatchQSAR(avg_predictions, ECOSAR_dict, TEST_dict, VEGA_dict, endpoint=endpoint, duration=durations, species_group=species_group)
    # Residual of each QSAR tool against the label; to_numpy() sidesteps index alignment.
    labels = qsar_preds['labels'].to_numpy()
    for qsar_model in ('ECOSAR', 'VEGA', 'TEST'):
        qsar_preds[f'{qsar_model}_residuals'] = qsar_preds[qsar_model] - labels
    weighted_avg_qsar_preds = GroupDataForPerformance(qsar_preds)
    # TEST columns are blanked for EC10 and for algae before plotting.
    if endpoint == 'EC10' or species_group == 'algae':
        weighted_avg_qsar_preds[['TEST', 'TEST_residuals']] = None

    PlotQSARcompBarUsingWAvgPredsInterersect(_savepath(f"{name}_weightedavgpreds_10x10CV_QSAR_comp_{AD}_SMILES_{effect_tag}_effect"), predictions=weighted_avg_qsar_preds, endpoint=endpoint, species_group=species_group)
    PlotQSARcompScatter(_savepath(f"{name}_weightedavgpreds_10x10CV_QSAR_scatter_{AD}_SMILES_{effect_tag}_effect"), predictions=weighted_avg_qsar_preds, endpoint=endpoint)
    PlotQSARresidualScatter(_savepath(f"{name}_10x10CV_QSAR_residualscatter_{AD}_SMILES_{effect_tag}_effect"), predictions=qsar_preds, endpoint=endpoint)
    PlotQSARresidualScatter(_savepath(f"{name}_weightedavgpreds_10x10CV_QSAR_residualscatter_{AD}_SMILES_{effect_tag}_effect"), predictions=weighted_avg_qsar_preds, endpoint=endpoint)

In [None]:
# Preview the endpoint-specific figures for the configured run without writing files.
PlotEndpointSpecific(
    endpoint=ENDPOINT,
    effect_for_qsar_comp=EFFECT_FOR_QSAR_COMP,
    name=NAME,
    inside_AD=INSIDE_AD,
    durations=DURATIONS,
    save_figures=False,
)

In [None]:
# PCA / CLS projection figure for the configured endpoint; savepath=None means no file path is given.
PlotPCA_CLSProjection(
    savepath=None,
    endpoint=ENDPOINT,
    flipaxis=False,
)

## Manual figure generation

In [3]:
# Settings for the step-by-step cells below (lower-case counterparts of the config constants).
species_group, endpoint = 'fish', 'EC10'
name = f'EC10_{species_group}'

In [None]:
# Residual histogram for the manually selected model; None as first argument means no save path.
PlotKFoldResidualHistUsingWAvgPreds(
    None,
    name,
    endpoint,
)

In [None]:
# Single-model performance bars for the EC50 and EC10 models of the selected species group.
PlotKFoldSingleBarUsingWAvgPreds(
    None,
    ec50_name=f'EC50_{species_group}',
    ec10_name=f'EC10_{species_group}',
)

### QSAR

In [4]:
# Load the predictions for `name`, keep only the listed effect codes,
# then aggregate them with GroupDataForPerformance.
avg_predictions = Preprocess10x10Fold(name=name)
effect_mask = avg_predictions.effect.isin(['MOR','DVP','ITX','MPH','REP','POP','GRO'])
avg_predictions = avg_predictions[effect_mask]
weighted_avg_predictions = GroupDataForPerformance(avg_predictions=avg_predictions)

In [5]:
# Same values as the configuration cell near the top, repeated so the manual
# section can be run on its own.
INSIDE_AD = True
DURATIONS = [
    'short exposure',   # <= 96h
    'medium exposure',  # <= 1week
    'long exposure',    # >1week
]

In [6]:
# ----------- QSAR ------------------------
# Load the predictions of the three QSAR tools for the selected endpoint and species group.
ECOSAR, VEGA, TEST = LoadQSAR(
    endpoint=endpoint,
    species_group=species_group,
)

In [None]:
# Overwrites the loaded frames in place, so re-run the LoadQSAR cell before re-running this one.
# NOTE(review): argument/return order here is (ECOSAR, TEST, VEGA), unlike LoadQSAR's (ECOSAR, VEGA, TEST).
ECOSAR, TEST, VEGA = PrepareQSARData(
    ECOSAR,
    TEST,
    VEGA,
    inside_AD=INSIDE_AD,
    remove_experimental=True,
)

In [8]:
# SMILES -> predicted value lookup for each QSAR tool.
ECOSAR_dict = dict(zip(ECOSAR.Canonical_SMILES_figures, ECOSAR.value))
VEGA_dict = dict(zip(VEGA.Canonical_SMILES_figures, VEGA.value))
TEST_dict = dict(zip(TEST.Canonical_SMILES_figures, TEST.value))

qsar_preds = MatchQSAR(
    avg_predictions,
    ECOSAR_dict,
    TEST_dict,
    VEGA_dict,
    endpoint=endpoint,
    duration=DURATIONS,
    species_group=species_group,
)

# Residual of every QSAR tool against the label column.
qsar_models = ['ECOSAR','VEGA','TEST']
residual_columns = ['ECOSAR_residuals','VEGA_residuals','TEST_residuals']
qsar_preds[residual_columns] = qsar_preds[qsar_models] - qsar_preds[['labels']].to_numpy()

weighted_avg_qsar_preds = GroupDataForPerformance(qsar_preds)
# TEST columns are blanked for EC10 and for algae.
if endpoint == 'EC10' or species_group == 'algae':
    weighted_avg_qsar_preds[['TEST', 'TEST_residuals']] = None

In [None]:
# QSAR comparison figures; None as first argument means no save path.
PlotQSARcompBarUsingWAvgPredsInterersect(
    None,
    predictions=weighted_avg_qsar_preds,
    endpoint=endpoint,
    species_group=species_group,
)
PlotQSARcompScatter(
    None,
    predictions=weighted_avg_qsar_preds,
    endpoint=endpoint,
)
# Residual scatter twice: first on the matched predictions, then on the aggregated ones.
PlotQSARresidualScatter(
    None,
    predictions=qsar_preds,
    endpoint=endpoint,
)
PlotQSARresidualScatter(
    None,
    predictions=weighted_avg_qsar_preds,
    endpoint=endpoint,
)