# Find training data

In [None]:
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

from inference_utils.pytorch_data_utils import check_training_data
# TODO(review): ecoCAIT_for_inference is called in the cells below but no import
# for it is visible in this notebook; add its import here.

In [None]:
# Organism groups and model variants; a model version is named '<model>_<species>'.
species_groups = ['fish', 'algae', 'invertebrates']
models = ['EC50EC10', 'EC50', 'EC10']

# Effect codes per model version. Every inner dict maps a code to itself and
# keeps the listed order (cells in this notebook take the first key as default).
effectordering = {
    version: {code: code for code in codes}
    for version, codes in [
        ('EC50_algae', ['POP']),
        ('EC10_algae', ['POP']),
        ('EC50EC10_algae', ['POP']),
        ('EC50_invertebrates', ['MOR', 'ITX']),
        ('EC10_invertebrates', ['MOR', 'DVP', 'ITX', 'REP', 'MPH', 'POP']),
        ('EC50EC10_invertebrates', ['MOR', 'DVP', 'ITX', 'REP', 'MPH', 'POP']),
        ('EC50_fish', ['MOR']),
        ('EC10_fish', ['MOR', 'DVP', 'ITX', 'REP', 'MPH', 'POP', 'GRO']),
        ('EC50EC10_fish', ['MOR', 'DVP', 'ITX', 'REP', 'MPH', 'POP', 'GRO']),
    ]
}

# Endpoint codes per model version: single-endpoint models expose only their
# own endpoint, the combined model exposes EC50 followed by EC10.
endpointordering = {
    f'{model_name}_{group}': {code: code for code in codes}
    for group in ['algae', 'invertebrates', 'fish']
    for model_name, codes in [
        ('EC50', ['EC50']),
        ('EC10', ['EC10']),
        ('EC50EC10', ['EC50', 'EC10']),
    ]
}

# Fallback exposure duration per species group (presumably hours -- TODO confirm).
default_durations = dict(algae=72, fish=96, invertebrates=48)

## Get CLS embeddings

In [None]:
# Load the 'dataset' sheet of the preprocessed workbook that all later cells query.
# NOTE(review): absolute cluster path -- only resolves on the original machine.
raw_data = pd.read_excel(
    '/cephyr/users/skall/Alvis/Ecotoxformer/Inference_2/ecoCAIT/data/development/Preprocessed_complete_data.xlsx',
    sheet_name='dataset',
)

In [None]:
SMILES_COLUMN_NAME = 'SMILES_Canonical_RDKit'
# NOTE(review): absolute cluster path -- only resolves on the original machine.
MODEL_WEIGHTS_PATH = '/cephyr/users/skall/Alvis/Ecotoxformer/Inference_2/ecoCAIT/ecoCAIT/'

# One row per unique structure. This does not depend on the model, so it is
# computed once instead of copying and de-duplicating raw_data per model.
unique_structures = raw_data.drop_duplicates(subset=[SMILES_COLUMN_NAME])

# For every species group / model combination, run the fine-tuned model once
# (default duration, first endpoint and first effect of that model version)
# and store the CLS embedding of every unique structure.
# NOTE(review): ecoCAIT_for_inference is not imported in the visible cells.
for SPECIES in tqdm(species_groups):
    for model in models:
        MODEL_VERSION = f'{model}_{SPECIES}'
        EXPOSURE_DURATION = default_durations[SPECIES]
        # First key of each ordering dict is used as the representative setting.
        PREDICTION_ENDPOINT = next(iter(endpointordering[MODEL_VERSION]))
        PREDICTION_EFFECT = next(iter(effectordering[MODEL_VERSION]))

        ecocait = ecoCAIT_for_inference(model_version=MODEL_VERSION, path_to_model_weights=MODEL_WEIGHTS_PATH)
        ecocait.load_fine_tuned_model()

        results = ecocait.predict_toxicity(
            SMILES=unique_structures[SMILES_COLUMN_NAME].tolist(),
            exposure_duration=EXPOSURE_DURATION,
            endpoint=PREDICTION_ENDPOINT,
            effect=PREDICTION_EFFECT,
            return_cls_embeddings=True,
        )
        results = results.reset_index(drop=True)
        # Store embeddings as float32 arrays to keep the pickles small.
        results['CLS_embeddings'] = results['CLS_embeddings'].apply(lambda x: np.asarray(x, dtype=np.float32))
        results = results[['SMILES_Canonical_RDKit', 'CLS_embeddings']]
        results.to_pickle(f'./data/predictions/{MODEL_VERSION}_CLS_embeddings.pkl.zip', compression='zip')



## Get predictions

In [None]:
SMILES_COLUMN_NAME = 'SMILES_Canonical_RDKit'
# NOTE(review): absolute cluster path -- only resolves on the original machine.
MODEL_WEIGHTS_PATH = '/cephyr/users/skall/Alvis/Ecotoxformer/Inference_2/ecoCAIT/ecoCAIT/'

# Structures to predict for: one row per unique SMILES. The full raw_data is
# kept untouched for the duration lookup below; previously the de-duplicated
# frame overwrote `data` inside the loop, so every lookup after the first one
# per model ran on a different (de-duplicated) table.
unique_structures = raw_data.drop_duplicates(subset=[SMILES_COLUMN_NAME])

# Column blocks collected per (model version, endpoint, effect) and
# concatenated once at the end instead of growing a frame inside the loop.
prediction_frames = []

# NOTE(review): ecoCAIT_for_inference is not imported in the visible cells.
for SPECIES in tqdm(species_groups):
    for model in models:
        MODEL_VERSION = f'{model}_{SPECIES}'
        ecocait = ecoCAIT_for_inference(model_version=MODEL_VERSION, path_to_model_weights=MODEL_WEIGHTS_PATH)
        ecocait.load_fine_tuned_model()

        for PREDICTION_ENDPOINT in endpointordering[MODEL_VERSION]:
            for PREDICTION_EFFECT in effectordering[MODEL_VERSION]:
                print(MODEL_VERSION, PREDICTION_ENDPOINT, PREDICTION_EFFECT)

                # Most frequent duration recorded for this species/endpoint/effect
                # combination; fall back to the species default when there is
                # no such record (explicit check instead of a bare `except:`).
                duration_counts = raw_data.loc[
                    (raw_data.species_group == SPECIES)
                    & (raw_data.endpoint == PREDICTION_ENDPOINT)
                    & (raw_data.effect == PREDICTION_EFFECT),
                    'Duration_Value',
                ].value_counts()
                if len(duration_counts) > 0:
                    EXPOSURE_DURATION = duration_counts.index[0]
                else:
                    EXPOSURE_DURATION = default_durations[SPECIES]

                # NOTE(review): embeddings are requested but not kept in this cell;
                # confirm whether return_cls_embeddings=False gives the same
                # prediction columns before switching it off.
                results = ecocait.predict_toxicity(
                    SMILES=unique_structures[SMILES_COLUMN_NAME].tolist(),
                    exposure_duration=EXPOSURE_DURATION,
                    endpoint=PREDICTION_ENDPOINT,
                    effect=PREDICTION_EFFECT,
                    return_cls_embeddings=True,
                )
                results = results.reset_index(drop=True)
                results['exposure_duration'] = EXPOSURE_DURATION

                # The SMILES column is taken once, unprefixed, from the first run.
                # Previously all columns were renamed first and then selected by
                # their old names, which raised a KeyError on the first pass.
                if not prediction_frames:
                    prediction_frames.append(results[['SMILES_Canonical_RDKit']])

                # Every prediction/duration pair is prefixed with its combination,
                # including the first one, so all columns follow one naming scheme.
                prefix = f'{MODEL_VERSION}_{PREDICTION_ENDPOINT}_{PREDICTION_EFFECT}'
                value_columns = ['predictions log10(mg/L)', 'exposure_duration']
                prediction_frames.append(
                    results[value_columns].rename(
                        columns={column: f'{prefix} {column}' for column in value_columns}
                    )
                )

# Rows line up by position: every run predicts the same structures in the same order.
combined_results = pd.concat(prediction_frames, axis=1)

In [None]:
# Inspect the combined prediction table produced by the prediction loop.
combined_results

In [None]:
# Downcast every exposure-duration column to float32 to shrink the table.
duration_columns = [name for name in combined_results.columns if 'exposure_duration' in name]
for name in duration_columns:
    combined_results[name] = combined_results[name].astype(np.float32)

In [None]:
# Summarise dtypes and memory use; 'deep' also measures the contents of object columns.
combined_results.info(memory_usage='deep')

In [None]:
# Persist the combined predictions as a zip-compressed pickle.
combined_results.to_pickle(
    './data/predictions/combined_predictions.pkl.zip',
    compression='zip',
)

## Get training data matches

In [None]:
# Reload the training data and the combined predictions saved earlier.
# NOTE(review): `training_data` is not referenced again in the visible cells
# before it is reassigned further down -- confirm it is still needed here.
training_data = pd.read_pickle(
    './data/Preprocessed_complete_data_fixed_smiles_format.zip',
    compression='zip',
)
all_preds = pd.read_pickle(
    './data/predictions/combined_predictions.pkl.zip',
    compression='zip',
)

In [None]:
# Inspect the reloaded predictions before the training-data match columns are added.
all_preds

Unnamed: 0,SMILES_Canonical_RDKit,EC50EC10_fish_EC50_MOR predictions log10(mg/L),exposure_duration,EC50EC10_fish_EC50_DVP predictions log10(mg/L),EC50EC10_fish_EC50_DVP exposure_duration,EC50EC10_fish_EC50_ITX predictions log10(mg/L),EC50EC10_fish_EC50_ITX exposure_duration,EC50EC10_fish_EC50_REP predictions log10(mg/L),EC50EC10_fish_EC50_REP exposure_duration,EC50EC10_fish_EC50_MPH predictions log10(mg/L),...,EC10_invertebrates_EC10_DVP predictions log10(mg/L),EC10_invertebrates_EC10_DVP exposure_duration,EC10_invertebrates_EC10_ITX predictions log10(mg/L),EC10_invertebrates_EC10_ITX exposure_duration,EC10_invertebrates_EC10_REP predictions log10(mg/L),EC10_invertebrates_EC10_REP exposure_duration,EC10_invertebrates_EC10_MPH predictions log10(mg/L),EC10_invertebrates_EC10_MPH exposure_duration,EC10_invertebrates_EC10_POP predictions log10(mg/L),EC10_invertebrates_EC10_POP exposure_duration
0,O=[N+]([O-])c1ccc(Cl)cc1,1.233817,96.0,1.133595,114.0,1.162595,48.0,1.049389,168.0,1.005820,...,-0.173498,72.0,0.752605,48.0,-0.766412,504.0,0.002427,96.0,-0.128783,48.0
1,Nc1ccc([N+](=O)[O-])cc1,1.760723,96.0,1.640172,114.0,1.667832,48.0,1.540111,168.0,1.474763,...,-1.629592,72.0,1.174659,48.0,-2.579684,504.0,-1.311189,96.0,-1.653334,48.0
2,O=[N+]([O-])c1ccc(O)cc1,1.335025,96.0,1.201835,114.0,1.258662,48.0,1.023190,168.0,0.925750,...,0.662401,72.0,1.140378,48.0,-0.017691,504.0,0.710306,96.0,0.649748,48.0
3,CN(C)c1ccc(C=O)cc1,1.595901,96.0,1.455373,114.0,1.504574,48.0,1.227222,168.0,1.086821,...,-0.658417,72.0,-0.056982,48.0,-0.846574,504.0,-0.586412,96.0,-0.592586,48.0
4,O=[N+]([O-])c1ccc([N+](=O)[O-])cc1,-0.220858,96.0,-0.282735,114.0,-0.216847,48.0,-0.396102,168.0,-0.474600,...,0.646538,72.0,1.130999,48.0,-0.111782,504.0,0.697391,96.0,0.628116,48.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6503,CCC(C(=O)O)c1ccc(N2C(=O)c3ccccc3C2=O)cc1,1.664032,96.0,1.423662,114.0,1.541090,48.0,0.904201,168.0,0.544423,...,0.225479,72.0,0.788441,48.0,-0.696985,504.0,0.361918,96.0,0.260004,48.0
6504,NC(=O)NC1NC(=O)NC1=O,1.987361,96.0,1.837992,114.0,1.884547,48.0,1.680902,168.0,1.569726,...,1.636588,72.0,1.881917,48.0,1.370613,504.0,1.651322,96.0,1.591306,48.0
6505,S=C(SSSSSSC(=S)N1CCCCC1)N1CCCCC1,0.789767,96.0,0.424589,114.0,0.657936,48.0,-0.345595,168.0,-0.635702,...,-1.106904,72.0,0.207116,48.0,-1.382495,504.0,-0.979049,96.0,-1.016408,48.0
6506,CC1CCC(C(C)C)CC1,0.660052,96.0,0.551487,114.0,0.596186,48.0,0.347313,168.0,0.212513,...,0.295813,72.0,0.835263,48.0,-0.567684,504.0,0.384176,96.0,0.322851,48.0


In [None]:
# Run check_training_data for every model/species/endpoint/effect combination.
# After each call, any generic 'species match' / 'endpoint match' / 'effect match'
# column (presumably added by check_training_data -- confirm) is tagged with the
# combination it belongs to, so the next call cannot collide with it.
for model in models:
    for species in tqdm(species_groups):
        MODELTYPE = f'{model}_{species}'
        for endpoint in endpointordering[MODELTYPE]:
            for effect in effectordering[MODELTYPE]:
                all_preds = check_training_data(all_preds, model, species, endpoint, effect)

                tag = f'{model}_{species}_{endpoint}_{effect}'
                # rename() ignores keys that are not present, so no column scan is needed.
                all_preds.rename(
                    columns={
                        flag: f'{tag} {flag}'
                        for flag in ('species match', 'endpoint match', 'effect match')
                    },
                    inplace=True,
                )


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# Inspect the predictions again, now with the per-combination match columns.
all_preds

Unnamed: 0,SMILES_Canonical_RDKit,EC50EC10_fish_EC50_MOR predictions log10(mg/L),exposure_duration,EC50EC10_fish_EC50_DVP predictions log10(mg/L),EC50EC10_fish_EC50_DVP exposure_duration,EC50EC10_fish_EC50_ITX predictions log10(mg/L),EC50EC10_fish_EC50_ITX exposure_duration,EC50EC10_fish_EC50_REP predictions log10(mg/L),EC50EC10_fish_EC50_REP exposure_duration,EC50EC10_fish_EC50_MPH predictions log10(mg/L),...,EC10_invertebrates_EC10_ITX endpoint match,EC10_invertebrates_EC10_ITX effect match,EC10_invertebrates_EC10_REP endpoint match,EC10_invertebrates_EC10_REP effect match,EC10_invertebrates_EC10_MPH endpoint match,EC10_invertebrates_EC10_MPH effect match,EC10_invertebrates_EC10_POP endpoint match,EC10_invertebrates_EC10_POP effect match,EC10_algae_EC10_POP endpoint match,EC10_algae_EC10_POP effect match
0,O=[N+]([O-])c1ccc(Cl)cc1,1.233817,96.0,1.133595,114.0,1.162595,48.0,1.049389,168.0,1.005820,...,1,1,1,1,1,0,1,0,0,0
1,Nc1ccc([N+](=O)[O-])cc1,1.760723,96.0,1.640172,114.0,1.667832,48.0,1.540111,168.0,1.474763,...,1,1,1,0,1,0,1,1,1,1
2,O=[N+]([O-])c1ccc(O)cc1,1.335025,96.0,1.201835,114.0,1.258662,48.0,1.023190,168.0,0.925750,...,1,1,1,1,1,0,1,0,1,1
3,CN(C)c1ccc(C=O)cc1,1.595901,96.0,1.455373,114.0,1.504574,48.0,1.227222,168.0,1.086821,...,0,0,0,0,0,0,0,0,0,0
4,O=[N+]([O-])c1ccc([N+](=O)[O-])cc1,-0.220858,96.0,-0.282735,114.0,-0.216847,48.0,-0.396102,168.0,-0.474600,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6503,CCC(C(=O)O)c1ccc(N2C(=O)c3ccccc3C2=O)cc1,1.664032,96.0,1.423662,114.0,1.541090,48.0,0.904201,168.0,0.544423,...,0,0,0,0,0,0,0,0,1,1
6504,NC(=O)NC1NC(=O)NC1=O,1.987361,96.0,1.837992,114.0,1.884547,48.0,1.680902,168.0,1.569726,...,0,0,0,0,0,0,0,0,1,1
6505,S=C(SSSSSSC(=S)N1CCCCC1)N1CCCCC1,0.789767,96.0,0.424589,114.0,0.657936,48.0,-0.345595,168.0,-0.635702,...,0,0,0,0,0,0,0,0,1,1
6506,CC1CCC(C(C)C)CC1,0.660052,96.0,0.551487,114.0,0.596186,48.0,0.347313,168.0,0.212513,...,0,0,0,0,0,0,0,0,1,1


In [None]:
# Write the annotated table back, overwriting the combined predictions file it was loaded from.
all_preds.to_pickle(
    './data/predictions/combined_predictions.pkl.zip',
    compression='zip',
)

## Add errors from 10x10 CV

In [None]:
from development.figures.figure_utils.preprocess_data import Preprocess10x10Fold, GroupDataForPerformance
from development.development_utils.preprocessing.Get_data_for_model import PreprocessData
from tqdm.notebook import tqdm

In [None]:
import pandas as pd
# NOTE(review): this reads from ./data/tutorials/, while the cells above write
# combined_predictions.pkl.zip to ./data/predictions/ -- confirm the intended file.
training_data = pd.read_pickle(
    './data/tutorials/combined_predictions.pkl.zip',
    compression='zip',
)
# Provide a 'SMILES' column (copy of the RDKit-canonical strings) before
# handing the frame to PreprocessData(...).GetCanonicalSMILES().
training_data = training_data.assign(SMILES=training_data['SMILES_Canonical_RDKit'].copy())
training_data = PreprocessData(training_data).GetCanonicalSMILES()

In [None]:
def match(x, error_dict):
    """Look up ``x`` in ``error_dict``, returning ``None`` when it cannot be found.

    Args:
        x: Lookup key (the notebook passes canonical SMILES strings).
        error_dict: Mapping from key to stored value.

    Returns:
        ``error_dict[x]`` if the lookup succeeds, otherwise ``None``.
    """
    try:
        return error_dict[x]
    # Only lookup failures mean "no match": a missing key, or a key/mapping that
    # cannot be used for lookup (e.g. an unhashable key). A bare `except:` would
    # also swallow KeyboardInterrupt and SystemExit.
    except (KeyError, TypeError):
        return None

In [None]:
# For each species group and model, load the cross-validation predictions,
# aggregate them with GroupDataForPerformance, and attach the resulting
# L1 error to every structure in training_data by canonical SMILES.
for species_group in ['fish', 'invertebrates','algae']:
    for model in tqdm(['EC50','EC10','EC50EC10']):
        model_name = f'{model}_{species_group}'
        # Only the combined model's result file carries the '_withoverlap' tag.
        overlap_tag = '_withoverlap' if model == 'EC50EC10' else ''
        cv_path = f'./data/results/{model_name}{overlap_tag}_predictions_100x_CV_RDkit.zip'
        cvpreds = Preprocess10x10Fold(name=model_name, uselogdata=True, full_filepath=cv_path)

        wavgcv = GroupDataForPerformance(cvpreds)
        wavgcv['SMILES'] = wavgcv['Canonical_SMILES_figures'].copy()
        wavgcv = PreprocessData(wavgcv).GetCanonicalSMILES()

        # SMILES -> L1 error lookup; a later duplicate SMILES overwrites an earlier one.
        smiles_keys = wavgcv.SMILES_Canonical_RDKit.tolist()
        l1_errors = wavgcv.L1error.tolist()
        error_dict = dict(zip(smiles_keys, l1_errors))

        # Structures without a CV error get None (see match()).
        training_data[f'{model_name} L1Error'] = training_data.SMILES_Canonical_RDKit.apply(
            lambda smiles: match(smiles, error_dict)
        )

In [None]:
# Save the predictions together with the per-model CV errors as a zip-compressed pickle.
training_data.to_pickle(
    './data/tutorials/combined_predictions_and_errors.pkl.zip',
    compression='zip',
)