In [1]:
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data, Dataset
from sklearn.model_selection import train_test_split

In [2]:
# Load csv files
df_molecules_species = pd.read_csv("./data/lotus_aggregated.csv")
df_molecules_species = df_molecules_species[['organism_taxonomy_09species', 'structure_smiles_2D']]
df_species_features = pd.read_csv("./data/species_features_dummy.csv", index_col=0) # Considering species name as index
df_species_features = df_species_features[df_species_features.columns[::-1]]
df_molecules_features = pd.read_csv("./data/mol_dummy_rdkit.csv", index_col=0) # Considering molecule structure as index

In [3]:
# Convert Boolean values to integers
df_species_features = df_species_features.astype(int)
df_molecules_features = df_molecules_features.astype(int)

In [4]:
# Here, we'll use a simple co-occurrence matrix as our graph, where species and molecules are nodes
species = df_molecules_species["organism_taxonomy_09species"].unique()
molecules = df_molecules_species["structure_smiles_2D"].unique()
num_species = len(species)
num_molecules = len(molecules)

In [5]:
species = pd.DataFrame(data={
    'organism_taxonomy_09species' : species,
    "species_ID": pd.RangeIndex(len(species))
})

In [6]:
molecules = pd.DataFrame(data={
    'structure_smiles_2D' : molecules,
    "molecules_ID": pd.RangeIndex(len(molecules))
})

In [7]:
species_to_species_id = pd.merge(df_molecules_species['organism_taxonomy_09species'],
         species,
         left_on='organism_taxonomy_09species',
        right_on='organism_taxonomy_09species',
         how='left')

In [8]:
mol_to_mol_id = pd.merge(df_molecules_species['structure_smiles_2D'],
         molecules,
         left_on='structure_smiles_2D',
        right_on='structure_smiles_2D',
         how='left')

In [9]:
species_to_species_id = torch.from_numpy(species_to_species_id['species_ID'].values)
mol_to_mol_id = torch.from_numpy(mol_to_mol_id['molecules_ID'].values)

In [10]:
edge_index_species_to_mol = torch.stack([species_to_species_id,mol_to_mol_id], dim=0)

In [26]:
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T

In [27]:
data = HeteroData()

In [28]:
# Save node indices:
data["species"].node_id = torch.arange(len(species))
data["molecule"].node_id = torch.arange(len(molecules))

In [29]:
data["species"].x = torch.from_numpy(df_species_features.values).to(torch.float)
data["molecule"].x = torch.from_numpy(df_molecules_features.values).to(torch.float)
data["species", "has", "molecule"].edge_index = edge_index_species_to_mol
data = T.ToUndirected()(data)

In [30]:
transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    disjoint_train_ratio=0.3,
    neg_sampling_ratio=2.0,
    add_negative_train_samples=False,
    edge_types=("species", "has", "molecule"),
    rev_edge_types=("molecule", "rev_has", "species"), 
)
train_data, val_data, test_data = transform(data)

In [76]:
from torch_geometric.nn import SAGEConv, to_hetero
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.nn import GATConv

class GNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.conv1 = GATConv(hidden_channels, hidden_channels, add_self_loops=False)
        self.conv2 = GATConv(hidden_channels, hidden_channels, add_self_loops=False)

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x
# Our final classifier applies the dot-product between source and destination
class Classifier(torch.nn.Module):
    def forward(self, x_species: Tensor, x_molecule: Tensor, edge_label_index: Tensor) -> Tensor:
        # Convert node embeddings to edge-level representations:
        edge_feat_species = x_species[edge_label_index[0]]
        edge_feat_molecule = x_molecule[edge_label_index[1]]
        # Apply dot-product to get a prediction per supervision edge:
        edge_logits = (edge_feat_species * edge_feat_molecule).sum(dim=-1)
        # Apply sigmoid function to get probabilities between 0 and 1
        edge_probs = torch.sigmoid(edge_logits)
        return edge_probs

class Model(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        # Since the dataset does not come with rich features, we also learn two
        # embedding matrices for users and movies:
        self.movie_lin = torch.nn.Linear(1024, hidden_channels)
        self.species_emb = torch.nn.Embedding(data["species"].num_nodes, hidden_channels)
        self.molecule_emb = torch.nn.Embedding(data["molecule"].num_nodes, hidden_channels)
        # Instantiate homogeneous GNN:
        self.gnn = GNN(hidden_channels)
        # Convert GNN model into a heterogeneous variant:
        self.gnn = to_hetero(self.gnn, metadata=data.metadata())
        self.classifier = Classifier()
        
    def forward(self, data: HeteroData) -> Tensor:
        x_dict = {
          "species": self.species_emb(data["species"].node_id),
          "molecule": self.movie_lin(data["molecule"].x) + self.molecule_emb(data["molecule"].node_id),
        } 
        # `x_dict` holds feature matrices of all node types
        # `edge_index_dict` holds all edge indices of all edge types
        x_dict = self.gnn(x_dict, data.edge_index_dict)
        pred_prob = self.classifier(
            x_dict["species"],
            x_dict["molecule"],
            data["species", "has", "molecule"].edge_label_index,
        )
        return pred_prob

model = Model(hidden_channels=128)

In [77]:
import tqdm
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: '{device}'")
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# Convert train_data to device
train_data = train_data.to(device)

for epoch in range(1, 200):
    optimizer.zero_grad()
    pred = model(train_data)
    ground_truth = train_data["species", "has", "molecule"].edge_label
    loss = F.binary_cross_entropy(pred, ground_truth)
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}")


Device: 'cpu'
Epoch: 001, Loss: 1.1392
Epoch: 002, Loss: 0.4642
Epoch: 003, Loss: 0.3276
Epoch: 004, Loss: 0.2260
Epoch: 005, Loss: 0.1903
Epoch: 006, Loss: 0.1781
Epoch: 007, Loss: 0.1709
Epoch: 008, Loss: 0.1645
Epoch: 009, Loss: 0.1581
Epoch: 010, Loss: 0.1518
Epoch: 011, Loss: 0.1454
Epoch: 012, Loss: 0.1389
Epoch: 013, Loss: 0.1325
Epoch: 014, Loss: 0.1261
Epoch: 015, Loss: 0.1197
Epoch: 016, Loss: 0.1133
Epoch: 017, Loss: 0.1069
Epoch: 018, Loss: 0.1007
Epoch: 019, Loss: 0.0945
Epoch: 020, Loss: 0.0885
Epoch: 021, Loss: 0.0825
Epoch: 022, Loss: 0.0768
Epoch: 023, Loss: 0.0713
Epoch: 024, Loss: 0.0659
Epoch: 025, Loss: 0.0608
Epoch: 026, Loss: 0.0560
Epoch: 027, Loss: 0.0514
Epoch: 028, Loss: 0.0471
Epoch: 029, Loss: 0.0431
Epoch: 030, Loss: 0.0393
Epoch: 031, Loss: 0.0359
Epoch: 032, Loss: 0.0327
Epoch: 033, Loss: 0.0298
Epoch: 034, Loss: 0.0271
Epoch: 035, Loss: 0.0247
Epoch: 036, Loss: 0.0224
Epoch: 037, Loss: 0.0204
Epoch: 038, Loss: 0.0186
Epoch: 039, Loss: 0.0170
Epoch: 040,

In [78]:
from sklearn.metrics import roc_auc_score

val_data = val_data.to(device)

with torch.no_grad():
    pred = model(val_data)
    ground_truth = val_data["species", "has", "molecule"].edge_label

# Move prediction and ground truth to CPU and convert to numpy arrays for AUC computation
pred = pred.cpu().numpy()
ground_truth = ground_truth.cpu().numpy()

# Compute AUC
auc = roc_auc_score(ground_truth, pred)

print()
print(f"Validation AUC: {auc:.4f}")



Validation AUC: 0.4156


In [66]:
val_data["species", "has", "molecule"]

{'edge_index': tensor([[ 4849,  3405,  2580,  ...,  1364,  2265,  8922],
        [ 8290,  3328,  2496,  ..., 11811,   229, 10483]]), 'edge_label': tensor([1., 1., 1.,  ..., 0., 0., 0.]), 'edge_label_index': tensor([[7095, 6301, 5531,  ...,  718, 9926, 5778],
        [ 978, 7084,  369,  ..., 7989,  659, 8027]])}

In [59]:
pred

array([1.        , 0.59894294, 0.9999999 , ..., 0.61078596, 0.57677233,
       0.9976546 ], dtype=float32)

In [75]:
ground_truth

array([1., 1., 1., ..., 0., 0., 0.], dtype=float32)