In [49]:
import pandas as pd
import numpy as np

# Switch between the organic-only experiment files and the full-dataset files.
ONLY_ORGANIC = True

# Resolve both input paths first so the reads below stay short.
predictions_path = 'only_organic_predictions.csv' if ONLY_ORGANIC else 'predictions.csv'
predictions_embedding_path = 'only_organic_predictions_embedding.csv' if ONLY_ORGANIC else 'predictions_embedding.csv'

predictions = pd.read_csv(predictions_path)
predictions_embedding = pd.read_csv(predictions_embedding_path)
predictions.head()

Unnamed: 0.1,Unnamed: 0,species,chemical,prediction
0,0,https://cfpub.epa.gov/ecotox/taxon/1,https://cfpub.epa.gov/ecotox/cas/100414,0.004222
1,1,https://cfpub.epa.gov/ecotox/taxon/1,https://cfpub.epa.gov/ecotox/cas/100425,0.019484
2,2,https://cfpub.epa.gov/ecotox/taxon/1,https://cfpub.epa.gov/ecotox/cas/100447,0.013535
3,3,https://cfpub.epa.gov/ecotox/taxon/1,https://cfpub.epa.gov/ecotox/cas/100516,0.006229
4,4,https://cfpub.epa.gov/ecotox/taxon/1,https://cfpub.epa.gov/ecotox/cas/100527,0.020521


In [50]:
# Load the measured effect data; its 'species' and 'chemical' columns are the
# keys used to join it onto the prediction frames in the next cell.
effect_data = pd.read_csv('effect_data_extra.csv')
effect_data.head()

Unnamed: 0.2,Unnamed: 0,Unnamed: 0.1,species,chemical,conc (mol/L),species_divisions,species_others,subClassOf,smiles,smiles_clusters
0,0,0,https://cfpub.epa.gov/ecotox/taxon/1,https://cfpub.epa.gov/ecotox/cas/10025919,3.051629,https://cfpub.epa.gov/ecotox/group/Fish,https://cfpub.epa.gov/ecotox/group/StandardTes...,,Cl[Sb](Cl)Cl,4
1,1,1,https://cfpub.epa.gov/ecotox/taxon/1,https://cfpub.epa.gov/ecotox/cas/10028156,5.681105,https://cfpub.epa.gov/ecotox/group/Fish,https://cfpub.epa.gov/ecotox/group/StandardTes...,,[O-][O+]=O,4
2,2,2,https://cfpub.epa.gov/ecotox/taxon/1,https://cfpub.epa.gov/ecotox/cas/100414,3.398977,https://cfpub.epa.gov/ecotox/group/Fish,https://cfpub.epa.gov/ecotox/group/StandardTes...,http://purl.obolibrary.org/obo/CHEBI_33832,CCC1=CC=CC=C1,2
3,3,3,https://cfpub.epa.gov/ecotox/taxon/1,https://cfpub.epa.gov/ecotox/cas/100425,3.512146,https://cfpub.epa.gov/ecotox/group/Fish,https://cfpub.epa.gov/ecotox/group/StandardTes...,"http://purl.obolibrary.org/obo/CHEBI_134179,ht...",C=CC1=CC=CC=C1,2
4,4,4,https://cfpub.epa.gov/ecotox/taxon/1,https://cfpub.epa.gov/ecotox/cas/10043013,3.127255,https://cfpub.epa.gov/ecotox/group/Fish,https://cfpub.epa.gov/ecotox/group/StandardTes...,,O=S1(=O)O[Al]2OS(=O)(=O)O[Al](O1)OS(=O)(=O)O2,4


In [51]:
# Attach the effect-data columns to every prediction row (left join, so
# predictions without a matching effect row are kept with NaN values).
# Re-running this cell merges again and adds suffixed duplicate columns.
merge_keys = ['species', 'chemical']
predictions = predictions.merge(effect_data, how='left', on=merge_keys)
predictions_embedding = predictions_embedding.merge(effect_data, how='left', on=merge_keys)

In [52]:
import sys
from SPARQLWrapper import SPARQLWrapper, JSON
# Base URI of the ECOTOX identifiers. Not referenced again in the visible cells;
# the 'cas/' prefix is spelled out literally where the lookup table is built.
namespace = 'https://cfpub.epa.gov/ecotox/'

# Public Wikidata SPARQL endpoint that the queries below are sent to.
endpoint_url = "https://query.wikidata.org/sparql"

# Selects, for every item that has both properties, the wdt:P231 value with
# dashes removed (?cas) and the wdt:P2067 value (?mw).
# presumably P231 is the CAS number and P2067 the mass - confirm on Wikidata.
query = """select ?cas ?mw where {
  ?c wdt:P231 ?castmp ;
     wdt:P2067 ?mw .
  bind(replace(?castmp,'-','') as ?cas)
}"""

def get_results(endpoint_url, query):
    """Send `query` to the SPARQL endpoint at `endpoint_url`.

    Returns whatever SPARQLWrapper's ``convert()`` produces for a JSON
    return format (callers index it as ``["results"]["bindings"]``).
    """
    major, minor = sys.version_info[0], sys.version_info[1]
    user_agent = "WDQS-example Python/%s.%s" % (major, minor)
    # TODO adjust user agent; see https://w.wiki/CX6
    client = SPARQLWrapper(endpoint_url, agent=user_agent)
    client.setQuery(query)
    client.setReturnFormat(JSON)
    response = client.query()
    return response.convert()

results = get_results(endpoint_url, query)

# Lookup table: ECOTOX chemical URI -> molecular weight as a float.
# If the same ?cas appears in several bindings, the last one wins.
cas_uri_prefix = 'https://cfpub.epa.gov/ecotox/cas/'
mw = {
    cas_uri_prefix + binding['cas']['value']: float(binding['mw']['value'])
    for binding in results["results"]["bindings"]
}


In [53]:
# Look up the molecular weight of every row's chemical in `mw`.
# A chemical that is missing from `mw` raises KeyError here.
for frame in (predictions, predictions_embedding):
    frame['molecular_weight'] = [mw[chemical_uri] for chemical_uri in frame['chemical']]

In [54]:
def f(x, col):
    """Convert the values in ``x[col]`` to a mass concentration.

    ``x[col]`` is treated as a negative base-10 logarithm of a molar
    concentration: it is exponentiated as ``10**(-value)``, multiplied by
    ``x['molecular_weight']`` and scaled by 1e3. Callers store the result in
    columns labelled mg/L (assumes molecular_weight is in g/mol - TODO confirm).
    """
    molar_concentration = 10 ** (-x[col])
    return 1e3 * molar_concentration * x['molecular_weight']

# Add mg/L versions of the predicted and the measured concentration to both frames.
for frame in (predictions, predictions_embedding):
    frame['predicted conc (mg/L)'] = f(frame, 'prediction')
    frame['true conc (mg/L)'] = f(frame, 'conc (mol/L)')

In [55]:
def hazard_function(c):
    """Map a concentration `c` (mg/L) to a hazard category label.

    Returns the string 'NaN' for a NaN input; otherwise the first category
    whose upper bound is not exceeded:
    'Category 1' (<= 1, very toxic), 'Category 2' (<= 10, toxic),
    'Category 3' (<= 100, harmful), else 'Category 4' (maybe harmful).
    """
    if np.isnan(c):
        return 'NaN'
    upper_bounds = (
        (1, 'Category 1'),    # Very toxic
        (10, 'Category 2'),   # Toxic
        (100, 'Category 3'),  # Harmful
    )
    for bound, label in upper_bounds:
        if c <= bound:
            return label
    return 'Category 4'  # Maybe harmful

# (target column, source column) pairs; applied to both frames in this order.
hazard_columns = (
    ('predicted hazard', 'predicted conc (mg/L)'),
    ('true hazard', 'true conc (mg/L)'),
)
for frame in (predictions, predictions_embedding):
    for target_column, source_column in hazard_columns:
        frame[target_column] = [hazard_function(value) for value in frame[source_column].values]


In [56]:
# Non-null row counts per predicted hazard category, for every other column.
predictions.groupby('predicted hazard').count()

Unnamed: 0_level_0,Unnamed: 0_x,species,chemical,prediction,Unnamed: 0_y,Unnamed: 0.1,conc (mol/L),species_divisions,species_others,subClassOf,smiles,smiles_clusters,molecular_weight,predicted conc (mg/L),true conc (mg/L),true hazard
predicted hazard,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
Category 4,3796,3796,3796,3796,3796,3796,3796,3796,1774,2565,3796,3796,3796,3796,3796,3796


In [57]:
from rdflib import Graph, URIRef
import numpy as np
import glob 
# Build one knowledge graph from all reduced KG files plus the physical-properties file.
graph = Graph()
kg_file_pattern = 'only_organic_reduced_kgs/reduced_*' if ONLY_ORGANIC else 'reduced_kgs/reduced_*'
for filename in glob.glob(kg_file_pattern):
    # Graph.parse() replaces the deprecated Graph.load(), which newer rdflib
    # releases no longer provide. The RDF format is taken from the file extension.
    graph.parse(filename, format=filename.split('.')[-1])
graph.parse('only_organic_physical_properties.ttl' if ONLY_ORGANIC else 'physical_properties.ttl', format='ttl')

# Sorted so that every entity / relation gets a reproducible integer id.
entities = sorted(set(graph.subjects()) | set(graph.objects()))
relations = sorted(set(graph.predicates()))

entity_mappings = {entity: index for index, entity in enumerate(entities)}
inverse_entity_mappings = dict(enumerate(entities))
relation_mappings = {relation: index for index, relation in enumerate(relations)}

# Integer-encoded (subject id, predicate id, object id) rows, one per triple in the graph.
triples = np.asarray([
    (entity_mappings[subj], relation_mappings[pred], entity_mappings[obj])
    for subj, pred, obj in graph
])

In [58]:
import sys  
sys.path.insert(0, './')
from embedding_model import ComplEx

In [59]:
# Rebuild the ComplEx model over the same entity/relation vocabularies and
# restore its weights from disk.
# NOTE(review): ComplEx comes from the local embedding_model module; the
# load_weights/get_layer calls suggest a Keras-style model - confirm.
weights_path = 'only_organic_model.tf' if ONLY_ORGANIC else 'model.tf'
embedding_model = ComplEx(entities, relations)
embedding_model.load_weights(weights_path)

# First weight tensor of the 'entity_embedding' layer as a NumPy array; later
# cells index its rows with entity_mappings.
entity_layer = embedding_model.get_layer('entity_embedding')
entity_matrix = entity_layer.weights[0].numpy()



In [60]:
# Unique species / chemical URIs occurring in the predictions.
# Kept as sorted lists rather than bare sets: the position of each URI is used
# as a row/column index of the distance matrices, and the iteration order of a
# set of strings changes between kernel restarts, which made those indices
# non-reproducible. Lists still support len(), iteration and enumerate().
species = sorted(set(predictions.species))
chemicals = sorted(set(predictions.chemical))

In [61]:
%%time
import tqdm.notebook as tq
from itertools import product
from scipy.spatial import distance

def lf(x, y):
    """L2 norm of the difference between two embedding vectors."""
    return np.linalg.norm(x - y, ord=2)


def _pairwise_embedding_distances(uris):
    """Return a len(uris) x len(uris) array of `lf` distances.

    Entry (i, j) is the distance between the embeddings of the i-th and j-th
    URI in the iteration order of `uris`; a progress bar is shown while the
    pairs are evaluated.
    """
    size = len(uris)
    flat_distances = [
        lf(entity_matrix[entity_mappings[URIRef(first)]],
           entity_matrix[entity_mappings[URIRef(second)]])
        for first, second in tq.tqdm(product(uris, uris), total=size**2)
    ]
    return np.asarray(flat_distances).reshape((size, size))


distance_matrix_species = _pairwise_embedding_distances(species)
distance_matrix_chemicals = _pairwise_embedding_distances(chemicals)

HBox(children=(FloatProgress(value=0.0, max=828100.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=139129.0), HTML(value='')))


CPU times: user 16.1 s, sys: 2.49 s, total: 18.6 s
Wall time: 15.8 s


In [62]:
# Sanity check: the matrix should be square with one row per entry of `species`.
distance_matrix_species.shape

(910, 910)

In [63]:
# URI <-> matrix index lookups. They follow the iteration order of
# `species` / `chemicals`, the same order in which the distance matrices were filled.
species_mappings = dict(zip(species, range(len(species))))
chemical_mappings = dict(zip(chemicals, range(len(chemicals))))
inverse_species_mappings = dict(enumerate(species))
inverse_chemical_mappings = dict(enumerate(chemicals))

In [64]:
# Graph.parse() is used instead of the deprecated Graph.load(), which newer
# rdflib releases no longer provide; both read the file into the graph.

# Taxonomy-only graph, used to look up species names and shared species facts.
taxonomy = Graph()
taxonomy.parse('./only_organic_reduced_kgs/reduced_taxonomy.nt' if ONLY_ORGANIC else './reduced_kgs/reduced_taxonomy.nt', format='nt')

# Graph holding every reduced KG file, used for chemical labels and shared
# compound facts. The RDF format is taken from each file's extension.
chemical_graph = Graph()
for filename in glob.glob('only_organic_reduced_kgs/reduced_*' if ONLY_ORGANIC else 'reduced_kgs/reduced_*'):
    chemical_graph.parse(filename, format=filename.split('.')[-1])

In [65]:
from rdflib.namespace import RDFS
def taxon_name(uri):
    """Return the latin name stored for `uri` in the `taxonomy` graph.

    Looks up the first object of the ``https://cfpub.epa.gov/ecotox/latinName``
    predicate for the given subject and returns it as a string. When no such
    triple exists, or the lookup fails for any ordinary reason, `uri` itself
    is returned unchanged (best-effort behaviour).
    """
    try:
        names = taxonomy.objects(subject=URIRef(uri),
                                 predicate=URIRef('https://cfpub.epa.gov/ecotox/latinName'))
        first_name = next(iter(names), None)
    except Exception:
        # Keep the best-effort fallback, but let KeyboardInterrupt/SystemExit
        # propagate instead of swallowing them with a bare except.
        return uri
    return uri if first_name is None else str(first_name)
    
def chemical_name(uri):
    """Return a short label for `uri` from the `chemical_graph`.

    Takes the first ``rdfs:label`` object of the given subject and returns the
    part after its last '/'. When no label exists, or the lookup fails for any
    ordinary reason, `uri` itself is returned unchanged (best-effort behaviour).
    """
    try:
        labels = chemical_graph.objects(subject=URIRef(uri), predicate=RDFS.label)
        first_label = next(iter(labels), None)
    except Exception:
        # Keep the best-effort fallback, but let KeyboardInterrupt/SystemExit
        # propagate instead of swallowing them with a bare except.
        return uri
    if first_label is None:
        return uri
    return str(first_label).split('/')[-1]
    

In [66]:
import sys
from collections import defaultdict
from SPARQLWrapper import SPARQLWrapper, JSON

# Public Wikidata SPARQL endpoint (same value as assigned earlier in the notebook).
endpoint_url = "https://query.wikidata.org/sparql"

def get_results(endpoint_url, query):
    """Execute `query` on the SPARQL endpoint `endpoint_url` and return the converted JSON response."""
    user_agent = "WDQS-example Python/%s.%s" % tuple(sys.version_info[:2])
    # TODO adjust user agent; see https://w.wiki/CX6
    wrapper = SPARQLWrapper(endpoint_url, agent=user_agent)
    wrapper.setQuery(query)
    wrapper.setReturnFormat(JSON)
    raw_result = wrapper.query()
    return raw_result.convert()

def wikidata_explanation(list_of_uris):
    """Return the Wikidata (predicate, object) pairs shared by the given chemicals.

    The last path segment of every URI is used as the identifier to match
    against the dash-stripped wdt:P231 value. All non-literal statements of
    the matching items are fetched, grouped per identifier, and intersected.

    Identifiers for which the endpoint returns nothing do not take part in the
    intersection. If the query returns no bindings at all, an empty set is
    returned (previously this raised TypeError from ``set.intersection()``).
    """
    # NOTE(review): ?cas is supplied through VALUES and also assigned by BIND
    # below - confirm the endpoint accepts this combination.
    quoted_ids = ' '.join('"' + uri.split('/')[-1] + '"' for uri in list_of_uris)

    query = """select ?cas ?p ?o where {
                  values ?cas {%s}
                  ?chem wdt:P231 ?castmp ;
                          ?p ?o .
                  bind (replace(?castmp,"-","") as ?cas)
                  filter (!isLiteral(?o))
                }""" % quoted_ids

    d = defaultdict(set)

    results = get_results(endpoint_url, query)

    for result in results["results"]["bindings"]:
        d[result['cas']['value']].add((result['p']['value'], result['o']['value']))

    if not d:
        # set.intersection() needs at least one argument.
        return set()

    return set.intersection(*d.values())


In [67]:
def _category_number(label):
    """Integer after the last whitespace in a hazard label, e.g. 'Category 3' -> 3."""
    return int(label.split()[-1])


# Absolute difference between the true and the predicted category number.
predictions_embedding['categorical error'] = [
    abs(_category_number(true_label) - _category_number(predicted_label))
    for true_label, predicted_label in zip(predictions_embedding['true hazard'].values,
                                           predictions_embedding['predicted hazard'].values)
]

# Per-row absolute difference between the measured and the predicted value.
predictions_embedding['log-mae (mol/L)'] = (
    predictions_embedding['conc (mol/L)'] - predictions_embedding['prediction']
).abs()

In [68]:
# Order rows from smallest to largest absolute error. The frame is modified in
# place, so cells that displayed it earlier now show stale row order.
predictions_embedding.sort_values(by='log-mae (mol/L)',axis=0,inplace=True,ascending=True)

In [69]:
# Number of nearest neighbours requested per entity; n+1 indices are taken
# because the entity itself normally comes first (distance 0 to itself).
n = 3

# Rows whose true hazard is anything other than 'Category 4', in the frame's current order.
rows_of_interest = predictions_embedding[predictions_embedding['true hazard'] != 'Category 4']

for i, (row_label, sample) in enumerate(rows_of_interest.iterrows()):

    true_haz = sample['true hazard']
    pred_haz = sample['predicted hazard']

    # Indices of the n+1 closest species / chemicals in embedding space.
    species_distances = distance_matrix_species[species_mappings[sample['species']]]
    chemical_distances = distance_matrix_chemicals[chemical_mappings[sample['chemical']]]
    exp_species = np.argsort(species_distances)[:n+1]
    exp_chemical = np.argsort(chemical_distances)[:n+1]

    print(f'True hazard: {true_haz}, predicted: {pred_haz}. log-mae (mol/L):', sample['log-mae (mol/L)'])

    # Only used by the commented-out prints further down.
    tn = taxon_name(sample['species'])
    cn = chemical_name(sample['chemical'])

    sp = [URIRef(inverse_species_mappings[idx]) for idx in exp_species]
    cp = [URIRef(inverse_chemical_mappings[idx]) for idx in exp_chemical]

    # (predicate, object) pairs that every neighbour has in common.
    explanation_s = set.intersection(*(set(taxonomy.predicate_objects(subject=node)) for node in sp))
    explanation_c = set.intersection(*(set(chemical_graph.predicate_objects(subject=node)) for node in cp))

    #explanation_s = wikidata_explanation(map(str,sp))
    #explanation_c = wikidata_explanation(map(str,cp))

    print('Explanation')
    # NOTE(review): taxon_name() is given the whole (predicate, object) pair
    # here, not a single URI - confirm whether mapping it over the pair's
    # elements was intended.
    print('Close species common facts:\n', '\t' + '\n\t'.join(','.join(taxon_name(fact)) for fact in explanation_s))
    print('Close compound common facts:\n', '\t' + '\n\t'.join(','.join(fact) for fact in explanation_c))

    #print(f'{tn} close to',[taxon_name(inverse_species_mappings[i]) for i in exp_species])
    #print(f'{cn} close to',[chemical_name(inverse_chemical_mappings[i]) for i in exp_chemical])
    #print('')

    # Stop after the seventh row (i runs 0..6).
    if i > 5:
        break

True hazard: Category 3, predicted: Category 4. log-mae (mol/L): 2.5143664806154518
Explanation
Close species common facts:
 	https://www.ncbi.nlm.nih.gov/taxonomy/rank,https://www.ncbi.nlm.nih.gov/taxonomy/rank/species
	http://www.w3.org/1999/02/22-rdf-syntax-ns#type,https://www.ncbi.nlm.nih.gov/taxonomy/division/1
Close compound common facts:
 	http://id.nlm.nih.gov/mesh/vocab#allowableQualifier,http://id.nlm.nih.gov/mesh/Q000037
	http://id.nlm.nih.gov/mesh/vocab#allowableQualifier,http://id.nlm.nih.gov/mesh/Q000032
	http://id.nlm.nih.gov/mesh/vocab#allowableQualifier,http://id.nlm.nih.gov/mesh/Q000652
	http://id.nlm.nih.gov/mesh/vocab#allowableQualifier,http://id.nlm.nih.gov/mesh/Q000266
	http://id.nlm.nih.gov/mesh/vocab#allowableQualifier,http://id.nlm.nih.gov/mesh/Q000145
	http://id.nlm.nih.gov/mesh/vocab#allowableQualifier,http://id.nlm.nih.gov/mesh/Q000008
	http://id.nlm.nih.gov/mesh/vocab#allowableQualifier,http://id.nlm.nih.gov/mesh/Q000138
	http://id.nlm.nih.gov/mesh/vocab#al

In [70]:
from sklearn.metrics import r2_score

# R^2 of the predictions against the measured values: (baseline frame, embedding frame).
tuple(
    r2_score(frame['conc (mol/L)'], frame['prediction'])
    for frame in (predictions, predictions_embedding)
)

(-5.180131454790831, -5.1575973244339925)

In [71]:
def hazard_category_metric(true,pred):
    """Mean asymmetric distance between true and predicted hazard categories.

    Both arguments are iterables of labels whose last space-separated token is
    the category number (e.g. 'Category 2'). For every pair the absolute
    difference of the numbers is taken; when the true number is larger than
    the predicted one (the prediction is the more severe category) the
    difference is divided by 3, otherwise it counts in full.
    """
    def _category_numbers(labels):
        return np.asarray([int(label.split(' ')[-1]) for label in labels])

    difference = _category_numbers(true) - _category_numbers(pred)
    magnitude = abs(difference)
    penalties = np.where(difference > 0, magnitude / 3, magnitude)

    return np.mean(penalties)
    
# Category metric for (baseline frame, embedding frame).
tuple(
    hazard_category_metric(frame['true hazard'], frame['predicted hazard'])
    for frame in (predictions, predictions_embedding)
)


(1.3788198103266596, 1.3788198103266596)