In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
import networkx as nx
import itertools
import pandas as pd
import numpy as np
from stellargraph import StellarGraph
from tensorflow import keras
import tensorflow as tf
import category_encoders as ce
import stellargraph
import stellargraph as sg
from stellargraph.data import EdgeSplitter
from stellargraph.mapper import HinSAGELinkGenerator

import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import mixed_precision

In [2]:
g = nx.read_graphml("./graph/test_graph.gml")
species_features_dummy = pd.read_csv("./data/species_features.csv.gz", index_col=0)
molecule_features_dummy = pd.read_csv("./data/molecule_features.csv.gz", index_col=0).astype("int8")
df_agg = pd.read_csv("./data/lotus_agg_test.csv.gz", index_col=0)

In [3]:
df_agg_train = pd.read_csv("./data/lotus_agg_train.csv.gz", index_col=0)

In [4]:
g_train = nx.read_graphml("./graph/train_graph.gml")
g_merged = nx.compose(g_train, g)

In [5]:
G = StellarGraph.from_networkx(g_merged,
                               node_features={'species': species_features_dummy,
                                              'molecule':molecule_features_dummy})

In [6]:
df= pd.concat([df_agg_train, df_agg]).reset_index(drop=True)

In [7]:
df

Unnamed: 0,organism_name,structure_smiles_2D,reference_wikidata,organism_taxonomy_08genus,organism_taxonomy_07tribe,organism_taxonomy_06family,organism_taxonomy_05order,organism_taxonomy_04class,organism_taxonomy_03phylum,organism_taxonomy_02kingdom,organism_taxonomy_01domain,structure_taxonomy_classyfire_01kingdom,structure_taxonomy_classyfire_02superclass,structure_taxonomy_classyfire_03class,structure_taxonomy_classyfire_04directparent,total_papers_molecule,total_papers_species
0,Aaptos,COc1cc2c3c(ccnc3c1OC)NC=C2,2,Aaptos,,Suberitidae,Suberitida,Demospongiae,Porifera,Metazoa,Eukaryota,Organic compounds,Organoheterocyclic compounds,Diazanaphthalenes,Naphthyridines,16,2
1,Aaptos aaptos,CC(C)CCCC(C)C1CCC2C3=CCC4CC(O)CCC4(C)C3CCC21C,1,Aaptos,,Suberitidae,Suberitida,Demospongiae,Porifera,Metazoa,Eukaryota,Organic compounds,Lipids and lipid-like molecules,Steroids and steroid derivatives,Cholesterols and derivatives,42,38
2,Aaptos aaptos,CC(C)CCCC(C)C1CCC2C3CCC4CC(O)CCC4(C)C3CCC12C,1,Aaptos,,Suberitidae,Suberitida,Demospongiae,Porifera,Metazoa,Eukaryota,Organic compounds,Lipids and lipid-like molecules,Steroids and steroid derivatives,Cholesterols and derivatives,94,38
3,Aaptos aaptos,COC1=Cc2ccnc3c2C1(O)C(CC(C)=O)=NC=C3,1,Aaptos,,Suberitidae,Suberitida,Demospongiae,Porifera,Metazoa,Eukaryota,Organic compounds,Organoheterocyclic compounds,Azepines,Azepines,1,38
4,Aaptos aaptos,COC1=Cc2ccnc3c2C1=C(OC)NC=C3,1,Aaptos,,Suberitidae,Suberitida,Demospongiae,Porifera,Metazoa,Eukaryota,Organic compounds,Organoheterocyclic compounds,Azepines,Azepines,1,38
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
438067,Amycolatopsis,CCC(C)C1C(=O)N2NCCCC2C(=O)N(C)C(C(C)(C)O)C(=O)...,3,Amycolatopsis,,Pseudonocardiaceae,Pseudonocardiales,,Actinobacteria,,Bacteria,Organic compounds,Organic acids and derivatives,Carboxylic acids and derivatives,Oligopeptides,3,147
438068,Centipeda minima,COc1cc(-c2oc3cc(O)cc(O)c3c(=O)c2OC)ccc1O,2,Centipeda,Athroismeae,Asteraceae,Asterales,Magnoliopsida,Streptophyta,Archaeplastida,Eukaryota,Organic compounds,Phenylpropanoids and polyketides,Flavonoids,3-O-methylated flavonoids,48,77
438069,Solidago wrightii,CC(C)C1CC=C2C(O)CCC(C)C2(C)C1OC(=O)C=Cc1ccccc1,1,Solidago,Astereae,Asteraceae,Asterales,Magnoliopsida,Streptophyta,Archaeplastida,Eukaryota,Organic compounds,Lipids and lipid-like molecules,Prenol lipids,"Eremophilane, 8,9-secoeremophilane and furoere...",1,4
438070,Vitis aestivalis,O=C(C=Cc1ccc(O)c(O)c1)OC(C(=O)O)C(O)C(=O)O,3,Vitis,Viteae,Vitaceae,Vitales,Magnoliopsida,Streptophyta,Archaeplastida,Eukaryota,Organic compounds,Phenylpropanoids and polyketides,Cinnamic acids and derivatives,Coumaric acids and derivatives,111,61


In [8]:
# Fetch unique species and molecules and their respective features
unique_species_df = df.drop_duplicates(subset=['organism_name'])
unique_molecules_df = df.drop_duplicates(subset=['structure_smiles_2D'])

In [9]:
# Fetch the corresponding features
species_features_df = unique_species_df[['organism_taxonomy_01domain',
                                         'organism_taxonomy_02kingdom',
                                         'organism_taxonomy_03phylum',
                                         'organism_taxonomy_04class',
                                         'organism_taxonomy_05order',
                                         'organism_taxonomy_06family',
                                         'organism_taxonomy_07tribe',
                                         'organism_taxonomy_08genus',
                                         'organism_name']]

molecule_features_df = unique_molecules_df[['structure_taxonomy_classyfire_01kingdom',
                                            'structure_taxonomy_classyfire_02superclass',
                                            'structure_taxonomy_classyfire_03class',
                                            'structure_taxonomy_classyfire_04directparent']]

In [10]:
molecule_features_df.index = [i for i in unique_molecules_df['structure_smiles_2D']]
species_features_df.index = [i for i in unique_species_df['organism_name']]

In [11]:
new_row = {'structure_taxonomy_classyfire_01kingdom':'Organic compounds',
           'structure_taxonomy_classyfire_02superclass':'Organic acids and derivatives',
           'structure_taxonomy_classyfire_03class':'Carboxylic acids and derivatives',
           'structure_taxonomy_classyfire_04directparent':'Amino acids, peptides, and analogues'}

In [12]:
molecule_features_df=molecule_features_df.append(
    pd.DataFrame([new_row],
                 index=['CC(C)C[C@@H]1NC(=O)C2C[C@]3(OC(C)=O)[C@H](N2C1=O)N(C(C)=O)C1=C3C=CC=C1'],
                 columns=molecule_features_df.columns)
)

In [13]:
g_merged.add_node('CC(C)C[C@@H]1NC(=O)C2C[C@]3(OC(C)=O)[C@H](N2C1=O)N(C(C)=O)C1=C3C=CC=C1',
                 label='molecule')

In [14]:
# create features for species
encoder_species = ce.BinaryEncoder(cols=[col for col in species_features_df.columns])
species_features_dummy = encoder_species.fit_transform(species_features_df)


encoder_molecule = ce.BinaryEncoder(cols=[col for col in molecule_features_df.columns])
molecule_features_dummy = encoder_molecule.fit_transform(molecule_features_df)

species_features_dummy.index = species_features_df.index
molecule_features_dummy.index = molecule_features_df.index

In [15]:
G = StellarGraph.from_networkx(g_merged,
                               node_features={'species': species_features_dummy,
                                              'molecule':molecule_features_dummy})

In [16]:
print(G.info())
G.check_graph_for_ml()

StellarDiGraph: Directed multigraph
 Nodes: 184991, Edges: 876144

 Node types:
  molecule: [148191]
    Features: float32 vector, length 27
    Edge types: molecule-present_in->species
  species: [36800]
    Features: float32 vector, length 80
    Edge types: species-has->molecule

 Edge types:
    species-has->molecule: [438072]
        Weights: all 1 (default)
        Features: none
    molecule-present_in->species: [438072]
        Weights: all 1 (default)
        Features: none


In [17]:
model = tf.keras.models.load_model("./model/batch_128_layer_1024_m_to_s", compile=True)

In [18]:
test_mol= np.array([['CC(C)C[C@@H]1NC(=O)C2C[C@]3(OC(C)=O)[C@H](N2C1=O)N(C(C)=O)C1=C3C=CC=C1', 'Aspergillus ustus']])

In [19]:
test_flow = HinSAGELinkGenerator(G,
                                 batch_size=128,
                                num_samples=[3,1],
                                head_node_types=["molecule", "species"]).flow(test_mol,
                                                                              np.ones(len(test_mol)).reshape(-1,1))

In [20]:
def predict(model, flow, iterations=10):
    predictions = []
    _ = model.predict(flow, workers=-1).flatten()
    for _ in range(iterations):
        predictions.append(model.predict(flow, workers=-1).flatten())

    return np.mean(predictions, axis=0)

In [21]:
predict(model, test_flow)



array([0.65889513], dtype=float32)

In [22]:
new_line_species = {'organism_taxonomy_01domain': 'Eukaryota',
                    'organism_taxonomy_02kingdom': 'Metazoa',
                    'organism_taxonomy_03phylum': 'Chordata',
                    'organism_taxonomy_04class': 'Ascidiacea',
                    'organism_taxonomy_05order': 'Stolidobranchia',
                    'organism_taxonomy_06family' : 'Styelidae',
                    'organism_taxonomy_07tribe': np.nan,
                    'organism_taxonomy_08genus': 'Polyandrocarpa',
                    'organism_name': 'Polyandrocarpa sparsa'}

In [23]:
g_merged.add_node('Polyandrocarpa sparsa', label='species')

In [24]:
species_features_df=species_features_df.append(
    pd.DataFrame([new_line_species],
                 index=['Polyandrocarpa sparsa'],
                 columns=species_features_df.columns)
)

In [25]:
new_row = {'structure_taxonomy_classyfire_01kingdom':'Organic compounds',
           'structure_taxonomy_classyfire_02superclass':'Lignans, neolignans and related compounds',
           'structure_taxonomy_classyfire_03class':np.nan,
           'structure_taxonomy_classyfire_04directparent':'Lignans, neolignans and related compounds'}

In [26]:
molecule_features_df=molecule_features_df.append(
    pd.DataFrame([new_row],
                 index=['CN1C[C@H](Cl)[C@]23Cc4ccc(Br)cc4N(C)C2=NC[C@@H](Cl)[C@]32C1=Nc1ccccc21'],
                 columns=molecule_features_df.columns)
)
g_merged.add_node('CN1C[C@H](Cl)[C@]23Cc4ccc(Br)cc4N(C)C2=NC[C@@H](Cl)[C@]32C1=Nc1ccccc21',
                 label='molecule')

In [27]:
# create features for species
encoder_species = ce.BinaryEncoder(cols=[col for col in species_features_df.columns])
species_features_dummy = encoder_species.fit_transform(species_features_df)


encoder_molecule = ce.BinaryEncoder(cols=[col for col in molecule_features_df.columns])
molecule_features_dummy = encoder_molecule.fit_transform(molecule_features_df)

species_features_dummy.index = species_features_df.index
molecule_features_dummy.index = molecule_features_df.index

In [28]:
test_mol = np.array([['CN1C[C@H](Cl)[C@]23Cc4ccc(Br)cc4N(C)C2=NC[C@@H](Cl)[C@]32C1=Nc1ccccc21',
                     'Polyandrocarpa sparsa']])

In [29]:
G = StellarGraph.from_networkx(g_merged,
                               node_features={'species': species_features_dummy,
                                              'molecule':molecule_features_dummy})

In [30]:
test_flow = HinSAGELinkGenerator(G,
                                 batch_size=128,
                                num_samples=[3,1],
                                head_node_types=["molecule", "species"]).flow(test_mol,
                                                                              np.ones(len(test_mol)).reshape(-1,1))

In [31]:
predict(model, test_flow)



array([0.41193002], dtype=float32)

In [32]:
model_s_to_m = tf.keras.models.load_model("./model/batch_128_layer_1024_s_to_m/", compile=True)

In [33]:
test_sp = np.array([['Polyandrocarpa sparsa',
                     'CN1C[C@H](Cl)[C@]23Cc4ccc(Br)cc4N(C)C2=NC[C@@H](Cl)[C@]32C1=Nc1ccccc21']])

In [34]:
test_flow_sp = HinSAGELinkGenerator(G,
                                 batch_size=128,
                                num_samples=[3,1],
                                head_node_types=["species", "molecule"]).flow(test_sp,
                                                                              np.ones(len(test_sp)).reshape(-1,1))

In [35]:
out_mol = predict(model_s_to_m, test_flow_sp)



In [36]:
out_sp = predict(model, test_flow)



In [37]:
(out_mol+out_sp)/2

array([0.69998956], dtype=float32)

In [38]:
def add_nodes_to_graph(G: nx.DiGraph, data: pd.DataFrame)-> nx.DiGraph :
    '''
    This funnction will add the missing node to the graph before prediction. Input should be a pandas dataframe
    with one column with the species, and the other with the molecules that will be predicted. There must be a
    column name to the input data. 
    '''
    
    assert len(data.columns) == 2, "Input data must have 2 columns !"
    if df.columns.isnull().any():
        raise ValueError("The input DataFrame has unnamed columns ! Columns should be named.")
    
    first_column_set = set(data.iloc[:, 0])
    second_column_set = set(data.iloc[:, 1])
    
    for name in first_column_set:
        if name not in G:
            G.add_node(name, label=data.columns[0])
            
    for name in second_column_set:
        if name not in G:
            G.add_node(name, label=data.columns[1])
    
    return G

In [39]:
def check_which_model_to_use(G: nx.DiGraph, data: pd.DataFrame):
    '''
    This function will split the input data in 3 different typesof list. 
    If the molecule queried is present but not the species, it should use model_s_to_m.
    If molecule is not known but the species is, it should use model_m_to_s.
    If both are unknown, it should run both and average them out. 
    '''
    assert len(data.columns) == 2, "Input data must have 2 columns ! "
    
    if data.columns[0] != 'molecule' and data.columns[0] != 'species':
        raise ValueError("First column must be named either : 'molecule' or 'species' !")
        
    if data.columns[1] != 'molecule' and data.columns[1] != 'species':
        raise ValueError("Second column must be named either : 'molecule' or 'species' ! ")

    
    m_to_s = np.empty((1,2))
    s_to_m = np.empty((1,2))
    both = np.empty((1,2))
    
    for i, row in data.iterrows():
        if (row['molecule'] not in G) and (row['species'] not in G):
            both = np.append(both, row.values.reshape(1,2), axis=0)
        
        elif (row['molecule'] not in G) and (row['species'] in G):
            m_to_s = np.append(m_to_s, row.values.reshape(1,2), axis=0)
            
        elif (row['molecule'] in G) and (row['species'] not in G):
            s_to_m = np.append(s_to_m, row.values.reshape(1,2), axis=0)
            
        else:
            both = np.append(both, row.values.reshape(1,2), axis=0)
    
    return m_to_s[1:], s_to_m[1:], both[1:]

In [40]:
species_features_dummy = pd.read_csv("./data/species_features.csv.gz", index_col=0)
molecule_features_dummy = pd.read_csv("./data/molecule_features.csv.gz", index_col=0).astype("int8")
df_agg = pd.read_csv("./data/lotus_agg_test.csv.gz", index_col=0)

In [41]:
species_feat = species_features_dummy[species_features_dummy.index.isin(df_agg.organism_name)]
molecule_feat = molecule_features_dummy[molecule_features_dummy.index.isin(df_agg.structure_smiles_2D)]

df_agg_train = pd.read_csv("./data/lotus_agg_train.csv.gz", index_col=0)

species_unique_to_test_set = df_agg[~df_agg.organism_name.isin(df_agg_train.organism_name)]
molecules_unique_to_test_set = df_agg[~df_agg.structure_smiles_2D.isin(df_agg_train.structure_smiles_2D)]

In [42]:
species_unique_to_test_set = species_unique_to_test_set[['organism_name', 'structure_smiles_2D']].to_numpy()

In [43]:
species_unique_to_test_set = pd.DataFrame(species_unique_to_test_set, columns=['species', 'molecule'])

In [44]:
m, s, both = check_which_model_to_use(g_merged, species_unique_to_test_set)

In [45]:
def nx_to_stellargraph(g: nx.DiGraph,
                       molecule_features: pd.core.frame.DataFrame,
                       species_features:pd.core.frame.DataFrame) -> stellargraph.core.graph.StellarDiGraph:
    
    if len(g.nodes()) != (len(molecule_features.index) + len(species_features.index)):
        raise Exception("Number of nodes does not match number of features ! Please check your graph or your features.")
        
    G = StellarGraph.from_networkx(g,
                                   node_features={'species': species_features,
                                                  'molecule': molecule_features})
    G.check_graph_for_ml()
    print(G.info())
    
    return G

In [46]:
type(G)

stellargraph.core.graph.StellarDiGraph

# TODO

In [47]:
def create_flows(graph: stellargraph.core.graph.StellarDiGraph,
                 data: pd.DataFrame,
                ):
    
    
    ### TODO 
    if data.columns[0] == 'molecule':
        flow_sp = HinSAGELinkGenerator(
            G,
            batch_size=128,
            num_samples=[3,1],
            head_node_types=["species", "molecule"]).flow(,
                                                          np.ones(len(test_sp)).reshape(-1,1))
    return 0

SyntaxError: invalid syntax (2009748232.py, line 12)

In [None]:
def predict_using_both_models(model_m_to_s,
                              model_s_to_m,
                              flow_m: stellargraph.mapper.sequences.LinkSequence,
                              flow_s: stellargraph.mapper.sequences.LinkSequence) -> np.ndarray:
    
    #do prediction both ways and average them. 
    a = predict(model_m_to_s, flow_m)
    b = predict(model_s_to_m, flow_s)
    
    return (a+b)/2

In [None]:
type(G)