In [1]:
import pandas as pd
import numpy as np
# Read the two prediction tables from the working directory, in the same order as before.
predictions, predictions_embedding = (
    pd.read_csv(csv_path)
    for csv_path in ('predictions.csv', 'predictions_embedding.csv')
)
# Preview the first rows of the first table.
predictions.head()

Unnamed: 0.1,Unnamed: 0,species,chemical,prediction
0,0,https://cfpub.epa.gov/ecotox/taxon/1,https://cfpub.epa.gov/ecotox/cas/10025919,3.409423
1,1,https://cfpub.epa.gov/ecotox/taxon/1,https://cfpub.epa.gov/ecotox/cas/10028156,4.745417
2,2,https://cfpub.epa.gov/ecotox/taxon/1,https://cfpub.epa.gov/ecotox/cas/100414,3.118691
3,3,https://cfpub.epa.gov/ecotox/taxon/1,https://cfpub.epa.gov/ecotox/cas/100425,3.371569
4,4,https://cfpub.epa.gov/ecotox/taxon/1,https://cfpub.epa.gov/ecotox/cas/10043013,3.91503


In [2]:
import sys
from SPARQLWrapper import SPARQLWrapper, JSON
# Base IRI of the ECOTOX identifiers. It is not referenced in the cells visible
# here; the CAS prefix is spelled out literally where the lookup table is built.
namespace = 'https://cfpub.epa.gov/ecotox/'

# SPARQL endpoint the query below is sent to.
endpoint_url = "https://query.wikidata.org/sparql"

# Selects, for every item carrying both wdt:P231 and wdt:P2067, the P231 value
# with hyphens removed (?cas) together with the P2067 value (?mw).
# NOTE(review): P231 is presumably Wikidata's "CAS Registry Number" and P2067
# "mass" — confirm. The query returns the bare P2067 amount without its unit,
# so downstream code assumes every value is in g/mol — TODO verify.
query = """select ?cas ?mw where {
  ?c wdt:P231 ?castmp ;
     wdt:P2067 ?mw .
  bind(replace(?castmp,'-','') as ?cas)
}"""

def get_results(endpoint_url, query):
    """Execute a SPARQL query and return the endpoint's JSON response, converted.

    Parameters
    ----------
    endpoint_url : str
        URL of the SPARQL endpoint to contact.
    query : str
        SPARQL query text.

    Returns
    -------
    The object produced by ``SPARQLWrapper(...).query().convert()`` with the
    return format set to JSON.
    """
    # TODO adjust user agent; see https://w.wiki/CX6
    python_major_minor = tuple(sys.version_info[:2])
    client = SPARQLWrapper(
        endpoint_url,
        agent="WDQS-example Python/%s.%s" % python_major_minor,
    )
    client.setQuery(query)
    client.setReturnFormat(JSON)
    response = client.query()
    return response.convert()

results = get_results(endpoint_url, query)

# Lookup table: ECOTOX CAS IRI -> molecular weight as float.
# If the same CAS value appears in several bindings, the last one wins.
mw = {
    'https://cfpub.epa.gov/ecotox/cas/' + binding['cas']['value']: float(binding['mw']['value'])
    for binding in results["results"]["bindings"]
}


In [3]:
# Attach each chemical's molecular weight from the Wikidata lookup table.
# Series.map with a dict yields NaN for chemicals that are missing from `mw`
# instead of aborting the whole cell with a KeyError; the NaN then propagates
# to the concentration column, where hazard_function labels it 'NaN'.
predictions['molecular_weight'] = predictions['chemical'].map(mw)

In [4]:
# Convert the predicted value to a mass concentration:
#   10**(-prediction) * molecular_weight * 1e3
# NOTE(review): this assumes `prediction` is -log10 of a molar concentration
# (mol/L) and molecular weights are in g/mol, giving mg/L — confirm.
predictions['conc (mg/L)'] = 1e3 * 10**(-predictions['prediction']) * predictions['molecular_weight']

# Look the molecular weight up from each embedding row's OWN chemical.
# Multiplying by predictions['molecular_weight'] pairs rows across two
# different frames by index label, which is only correct if both CSVs list the
# same (species, chemical) pairs in the same order; any mismatch silently
# yields wrong or NaN concentrations.
predictions_embedding['conc (mg/L)'] = (
    1e3
    * 10**(-predictions_embedding['prediction'])
    * predictions_embedding['chemical'].map(mw)
)

In [5]:
def hazard_function(c):
    """Bin a concentration (mg/L) into a hazard label.

    Returns the string 'NaN' when ``c`` is NaN. Otherwise returns the label of
    the first inclusive upper bound that ``c`` does not exceed:
    <= 1 -> 'Category 1' (very toxic), <= 10 -> 'Category 2' (toxic),
    <= 100 -> 'Category 3' (harmful), anything larger -> 'Category 4'
    (maybe harmful).
    """
    if np.isnan(c):
        return 'NaN'

    # (inclusive upper bound in mg/L, label), checked in ascending order
    bands = (
        (1, 'Category 1'),    # Very toxic
        (10, 'Category 2'),   # Toxic
        (100, 'Category 3'),  # Harmful
    )
    for upper_bound, label in bands:
        if c <= upper_bound:
            return label
    return 'Category 4'  # Maybe harmful

# Label every row of both tables with its hazard category.
predictions['Hazard'] = [
    hazard_function(conc) for conc in predictions['conc (mg/L)'].values
]
predictions_embedding['Hazard'] = [
    hazard_function(conc) for conc in predictions_embedding['conc (mg/L)'].values
]

In [6]:
# Per-category row counts; count() reports the number of non-null cells in
# every remaining column for each 'Hazard' group.
predictions.groupby('Hazard').count()

Unnamed: 0_level_0,Unnamed: 0,species,chemical,prediction,molecular_weight,conc (mg/L)
Hazard,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
Category 1,1246,1246,1246,1246,1246,1246
Category 2,2039,2039,2039,2039,2039,2039
Category 3,1800,1800,1800,1800,1800,1800
Category 4,2036,2036,2036,2036,2036,2036


In [7]:
from rdflib import Graph, URIRef
import numpy as np
import glob 
# Merge every reduced knowledge-graph file into a single rdflib graph.
graph = Graph()
for filename in glob.glob('reduced_kgs/reduced_*'):
    # Graph.load() was only a thin wrapper around Graph.parse() and no longer
    # exists in rdflib >= 6, so call parse() directly (works on old and new
    # versions). The file extension is passed through as the rdflib format name.
    graph.parse(filename, format=filename.split('.')[-1])

# Sorting gives every entity / relation a stable integer id across runs.
entities = sorted(set(graph.subjects()) | set(graph.objects()))
relations = sorted(set(graph.predicates()))

# term -> integer id (position in the sorted lists above)
entity_mappings = {e: i for i, e in enumerate(entities)}
relation_mappings = {r: i for i, r in enumerate(relations)}

# All triples of the graph encoded as (subject id, relation id, object id).
triples = np.asarray([
    (entity_mappings[s], relation_mappings[p], entity_mappings[o])
    for s, p, o in graph
])

In [8]:
import sys  
# Put the current working directory first on the import path so that the
# import below resolves from there.
# NOTE(review): presumably embedding_model.py sits next to this notebook — confirm.
sys.path.insert(0, './')
from embedding_model import ComplEx

In [9]:
# Build the ComplEx model from the entity and relation lists, then restore
# saved weights from 'model.tf'.
embedding_model = ComplEx(entities,relations)
embedding_model.load_weights('model.tf')
# First weight tensor of the 'entity_embedding' layer as a NumPy array.
# NOTE(review): later cells index its rows with entity_mappings, so row order
# is assumed to follow `entities` — confirm against embedding_model.py.
entity_matrix = embedding_model.get_layer('entity_embedding').weights[0].numpy()

In [65]:
# Distinct species / chemical IRIs that occur in the predictions table.
species = set(predictions['species'])
chemicals = set(predictions['chemical'])

In [66]:
%%time
import tqdm.notebook as tq

def _distances_to_all_entities(uris):
    """Return one row per URI holding its scaled L2 distance to every entity embedding.

    Each row is ``||entity_matrix[idx] - entity_matrix||_2`` taken along the
    last axis; the stacked result is divided by the length of the first row of
    ``entity_matrix``.
    """
    rows = []
    for uri in tq.tqdm(uris):
        anchor = entity_matrix[entity_mappings[URIRef(uri)]]
        rows.append(np.linalg.norm(anchor - entity_matrix, ord=2, axis=-1))
    return np.asarray(rows) / len(entity_matrix[0])

distance_matrix_species = _distances_to_all_entities(species)
distance_matrix_chemicals = _distances_to_all_entities(chemicals)

HBox(children=(FloatProgress(value=0.0, max=1449.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=520.0), HTML(value='')))


CPU times: user 24.7 s, sys: 18.2 s, total: 42.9 s
Wall time: 42.9 s


In [67]:
# Sanity check: one row per queried species, one column per row of entity_matrix.
distance_matrix_species.shape

(1449, 59953)

In [68]:
# IRI -> row position, using the same iteration order as `species` / `chemicals`
# (the order in which the distance-matrix rows were built).
inverse_species_mappings = dict(zip(species, range(len(species))))
inverse_chemical_mappings = dict(zip(chemicals, range(len(chemicals))))

In [69]:
n = 3          # number of nearest neighbours to report per entity
max_rows = 12  # the original loop ran for row positions 0..11 before breaking


def _nearest_entities(distances, k):
    """Return the IRIs (plain strings) of the ``k`` closest entities, skipping rank 0.

    ``distances`` is one row of a distance matrix (one value per entity id).
    Rank 0 is dropped on the assumption that it is the query entity itself at
    distance 0. Ids are resolved through ``entities``, the sorted list that
    ``entity_mappings`` enumerates, so no separately maintained
    ``inverse_entity_mappings`` dict is needed. That name is not defined in any
    visible cell, so the original lookup raised NameError on a fresh kernel.
    """
    closest_ids = np.argsort(distances)[1:k + 1]
    return [str(entities[entity_id]) for entity_id in closest_ids]


for _, row in predictions_embedding.head(max_rows).iterrows():
    species_row = distance_matrix_species[inverse_species_mappings[row['species']]]
    chemical_row = distance_matrix_chemicals[inverse_chemical_mappings[row['chemical']]]

    print(row['Hazard'])
    print('Species close to', _nearest_entities(species_row, n))
    print('Chemical close to', _nearest_entities(chemical_row, n))

Category 4
Species close to ['https://cfpub.epa.gov/ecotox/taxon/1695', 'https://cfpub.epa.gov/ecotox/taxon/10501', 'https://cfpub.epa.gov/ecotox/taxon/1716']
Chemical close to ['https://cfpub.epa.gov/ecotox/cas/7803523', 'http://purl.obolibrary.org/obo/CHEBI_36897', 'http://purl.obolibrary.org/obo/CHEBI_38307']
Category 1
Species close to ['https://cfpub.epa.gov/ecotox/taxon/1695', 'https://cfpub.epa.gov/ecotox/taxon/10501', 'https://cfpub.epa.gov/ecotox/taxon/1716']
Chemical close to ['https://cfpub.epa.gov/ecotox/cas/122394', 'https://cfpub.epa.gov/ecotox/cas/532321', 'https://cfpub.epa.gov/ecotox/cas/78115']
Category 3
Species close to ['https://cfpub.epa.gov/ecotox/taxon/1695', 'https://cfpub.epa.gov/ecotox/taxon/10501', 'https://cfpub.epa.gov/ecotox/taxon/1716']
Chemical close to ['https://cfpub.epa.gov/ecotox/cas/103651', 'https://cfpub.epa.gov/ecotox/cas/111693', 'https://cfpub.epa.gov/ecotox/cas/97881']
Category 3
Species close to ['https://cfpub.epa.gov/ecotox/taxon/1695', 'h