In [None]:
# Install pykeen - package that is used to set up the pipeline
pip install pykeen



In [None]:
# Make Google Drive available under /content/drive; the graph pickle is read from, and
# the trained model written to, paths below that mount point.
import google.colab.drive as drive

drive.mount('/content/drive')

# Import necessary packages
import pickle
import numpy as np
from pykeen.pipeline import pipeline
from pykeen.triples import TriplesFactory
from typing import List, Tuple
from pykeen.predict import predict_target

# Load the graph from pickle.
# NOTE: pickle.load can execute arbitrary code -- only load graph files you created / trust.
graph_path = "/content/drive/MyDrive/Colabs/data/graph_w_covid_genes_treatments_string_cif_coalesced_3.pkl"
with open(graph_path, "rb") as f:
    graph = pickle.load(f)

# Extract (head, relation, tail) label triples from the graph's edges.
# Edges without a 'relation' attribute fall back to the generic relation 'related_to'.
triples = [
    (str(source), str(data.get('relation', 'related_to')), str(target))
    for source, target, data in graph.edges(data=True)
]
if not triples:
    raise ValueError(f"No edges found in the graph loaded from {graph_path}")

# Convert triples to a (num_triples, 3) array of string labels
triples_array = np.array(triples)

# Split into training / validation / testing (80/10/10).
# All three splits must share ONE entity_to_id / relation_to_id mapping. Building a
# separate TriplesFactory per split assigns different integer IDs to the same label in
# each split, so the model would be evaluated against mismatched IDs. Splitting a single
# factory also shuffles the triples (graph edge order is not random). By default, PyKEEN's
# split also ensures that entities/relations seen in validation/testing occur in training.
full_factory = TriplesFactory.from_labeled_triples(triples_array)
training_factory, validation_factory, testing_factory = full_factory.split(
    [0.8, 0.1, 0.1],
    random_state=42,
)

# Labeled triples of each split, kept for inspection / backward compatibility.
training_triples = training_factory.triples
validation_triples = validation_factory.triples
testing_triples = testing_factory.triples

# Create a RGCN pipeline to train the model
pipeline_result = pipeline(
    training=training_factory,
    testing=testing_factory,
    validation=validation_factory,
    model='RGCN',
    training_loop='sLCWA',
    optimizer='Adam',
    model_kwargs={
        'embedding_dim': 100,
        'num_layers': 2,
    },
    training_kwargs={
        'num_epochs': 50,
        'batch_size': 6000,
    },
    random_seed=42,
    use_tqdm=True,
    device='cuda'
)

INFO:pykeen.pipeline.api:Using device: cuda
INFO:pykeen.nn.message_passing:No num_bases was provided. Using sqrt(num_relations)=11.
INFO:pykeen.nn.message_passing:No num_bases was provided. Using sqrt(num_relations)=11.
INFO:pykeen.nn.message_passing:No num_bases was provided. Using sqrt(num_relations)=11.
INFO:pykeen.nn.message_passing:No num_bases was provided. Using sqrt(num_relations)=11.
  (fwd): BasesDecomposition(
    (relation_representations): LowRankRepresentation(
      (bases): Embedding(
        (_embeddings): Embedding(11, 10000)
      )
    )
  )
  (bwd): BasesDecomposition(
    (relation_representations): LowRankRepresentation(
      (bases): Embedding(
        (_embeddings): Embedding(11, 10000)
      )
    )
  )
  (self_loop): Linear(in_features=100, out_features=100, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
) has parameters, but no reset_parameters.
  (fwd): BasesDecomposition(
    (relation_representations): LowRankRepresentation(
      (bases): Embeddi

Training epochs on cuda:0:   0%|          | 0/50 [00:00<?, ?epoch/s]

Training batches on cuda:0:   0%|          | 0/674 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/674 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/674 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/674 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/674 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/674 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/674 [00:00<?, ?batch/s]

In [None]:
# Import necessary modules
import os
import pickle

# Retrieve the trained model from the pipeline result
model = pipeline_result.model

# Define the path where the model will be saved
save_path = '/content/drive/MyDrive/my_models/trained_rgcn_new_graph_model.pkl'
directory = os.path.dirname(save_path)

# Create the target directory if needed. exist_ok=True replaces the racy
# "check os.path.exists, then makedirs" pattern. The guard skips makedirs('') when
# save_path has no directory component.
if directory:
    os.makedirs(directory, exist_ok=True)

# Save the model with pickle.
# NOTE(review): unpickling presumably needs compatible pykeen/torch versions and, since the
# model was trained on CUDA, a CUDA-capable runtime -- confirm before reusing elsewhere.
with open(save_path, 'wb') as f:
    pickle.dump(model, f)

In [None]:
# Get evaluation results (PyKEEN metrics the pipeline computed on the testing triples)
results = pipeline_result.metric_results.to_dict()
print("\nModel Evaluation Results:")
print(results)

# Extract the learned embeddings as NumPy arrays from the trained pipeline model.
# eval() disables the RGCN layers' dropout so the extracted embeddings are deterministic.
# no_grad() avoids building an autograd graph over the full-graph message passing.
# detach() is kept because a representation may hand back a trainable weight tensor directly.
import torch

trained_model = pipeline_result.model
trained_model.eval()
with torch.no_grad():
    entity_embeddings = trained_model.entity_representations[0](indices=None).detach().cpu().numpy()
    relation_embeddings = trained_model.relation_representations[0](indices=None).detach().cpu().numpy()


Model Evaluation Results:
{'head': {'optimistic': {'adjusted_geometric_mean_rank_index': 0.1387197586492378, 'inverse_harmonic_mean_rank': 0.0006024196282459551, 'variance': 691748385.8013566, 'inverse_median_rank': 2.356989652815424e-05, 'adjusted_inverse_harmonic_mean_rank': 0.000472985407946379, 'inverse_geometric_mean_rank': 3.4012427814031914e-05, 'inverse_arithmetic_mean_rank': 2.3222912661772754e-05, 'adjusted_arithmetic_mean_rank_index': 0.07182688188059483, 'count': 505082.0, 'standard_deviation': 26301.109972800703, 'adjusted_arithmetic_mean_rank': 0.9281746663420413, 'z_inverse_harmonic_mean_rank': 79.86278921712818, 'median_rank': 42427.0, 'arithmetic_mean_rank': 43060.92067624663, 'harmonic_mean_rank': 1659.9724728619256, 'z_arithmetic_mean_rank': 88.41456327743163, 'median_absolute_deviation': 32744.752597914725, 'geometric_mean_rank': 29401.017929906415, 'z_geometric_mean_rank': 98.74585013856847, 'hits_at_1': 0.00010097370327986347, 'hits_at_3': 0.00032271987518858324,

In [None]:
# COVID-related "Disease::" node labels used as tail entities for the
# Compound-treats-Disease predictions. The first group is named after SARS-CoV-2
# proteins (E, M, N, Spike, nsp*, orf*); the second group uses MeSH identifiers.
# Order matters: downstream results are reported in this order.
COV_disease_list = [
    *(
        f"Disease::SARS-CoV2 {protein}"
        for protein in (
            "E", "M", "N", "Spike",
            "nsp1", "nsp10", "nsp11", "nsp12", "nsp13", "nsp14", "nsp15",
            "nsp2", "nsp4", "nsp5", "nsp5_C145A", "nsp6", "nsp7", "nsp8", "nsp9",
            "orf10", "orf3a", "orf3b", "orf6", "orf7a", "orf8", "orf9b", "orf9c",
        )
    ),
    *(
        f"Disease::MESH:{mesh_id}"
        for mesh_id in (
            "D045169", "D045473", "D001351", "D065207",
            "D028941", "D058957", "D006517",
        )
    ),
]

In [None]:
def get_top_predictions(
    model,
    training_factory,
    relation: str,
    tail_entities: List[str],
    k: int = 20,
    head_prefix: str = 'Compound::',
) -> dict:
    """
    Rank candidate head entities for each ``(?, relation, tail)`` query and keep the top ``k``.

    For every tail label, PyKEEN's ``predict_target`` scores all head entities for the
    given relation. Only heads whose label starts with ``head_prefix`` (compounds by
    default) are kept, and the ``k`` highest-scoring ones are returned.

    Args:
        model: Trained PyKEEN model used to score the triples.
        training_factory: TriplesFactory whose label/ID mappings the model was trained with.
        relation: Relation label, e.g. "DRKG::Treats::Compound:Disease".
        tail_entities: Tail entity labels to query (here: COVID-related disease nodes).
        k: Maximum number of heads to keep per tail.
        head_prefix: Only head labels starting with this prefix are considered.

    Returns:
        dict mapping each tail label to a list of ``(head_label, score)`` tuples in
        descending score order. A tail is omitted, and a warning is printed, when no head
        matches the prefix or when prediction raises an exception.
    """
    results = {}

    for tail in tail_entities:
        try:
            # Score every candidate head for the (?, relation, tail) query
            predictions = predict_target(
                model=model,
                tail=tail,
                relation=relation,
                triples_factory=training_factory
            )
            df = predictions.df

            # Keep only heads of the requested type, best score first
            matching_df = df[df['head_label'].str.startswith(head_prefix)].sort_values('score', ascending=False)

            if matching_df.empty:
                print(f"Warning: No compound predictions found for {tail}")
                continue

            top_k_df = matching_df.head(k)
            results[tail] = list(zip(top_k_df['head_label'], top_k_df['score']))

        except KeyError as e:
            # Presumably an unknown entity/relation label for this factory's mappings
            print(f"Warning: Entity or relation not found in training data: {e}")
            continue

        except Exception as e:
            # Best-effort: report and move on to the next tail rather than abort the loop
            print(f"Error predicting for {tail}: {str(e)}")
            continue

    return results

# Get predictions for COVID diseases from the trained pipeline model, using the training
# factory whose label/ID mappings the model was trained with.
TOP_K = 20
relation = "DRKG::Treats::Compound:Disease"
predictions = get_top_predictions(
    pipeline_result.model,
    training_factory,
    relation,
    COV_disease_list,
    k=TOP_K,
)

# Print results
print(f"\nTop {TOP_K} predicted compounds for each disease:")
for disease, compounds in predictions.items():
    if compounds:  # Only print if we have compound predictions
        print(f"\n{disease}:")
        for i, (compound, score) in enumerate(compounds, 1):
            print(f"{i}. {compound} (score: {score:.4f})")


Top 20 predicted compounds for each disease:

Disease::SARS-CoV2 E:
1. Compound::CHEMBL2108712 (score: 0.6998)
2. Compound::CHEMBL2109592 (score: 0.5862)
3. Compound::brenda:135439 (score: 0.5665)
4. Compound::brenda:169662 (score: 0.4908)
5. Compound::CHEMBL3707379 (score: 0.4628)
6. Compound::DB05296 (score: 0.4618)
7. Compound::brenda:224348 (score: 0.4513)
8. Compound::MESH:D009842 (score: 0.4394)
9. Compound::rhea:58850 (score: 0.4357)
10. Compound::CHEBI:37007 (score: 0.4352)
11. Compound::CHEBI:29449 (score: 0.4168)
12. Compound::CHEBI:15805 (score: 0.4123)
13. Compound::MESH:C400974 (score: 0.4014)
14. Compound::brenda:223261 (score: 0.3865)
15. Compound::MESH:C541773 (score: 0.3809)
16. Compound::pubchem:5287470 (score: 0.3773)
17. Compound::bindingdb:50154292 (score: 0.3736)
18. Compound::MESH:C037178 (score: 0.3616)
19. Compound::DB08431 (score: 0.3614)
20. Compound::MESH:C060090 (score: 0.3565)

Disease::SARS-CoV2 M:
1. Compound::MESH:C555517 (score: 1.3289)
2. Compound::M

In [None]:
def get_compounds_with_treatments(graph):
    """
    Collect the compound nodes of ``graph`` that carry a 'treatment_data' attribute.

    A node qualifies when its identifier is a string starting with 'Compound::' and its
    attribute dict contains the key 'treatment_data'; the attribute's value is not
    inspected. Matches are returned in the order ``graph.nodes(data=True)`` yields them.

    Returns:
        tuple: (list of matching node labels, number of matching nodes)
    """
    matching = [
        label
        for label, attrs in graph.nodes(data=True)
        if isinstance(label, str)
        and label.startswith('Compound::')
        and 'treatment_data' in attrs
    ]
    return matching, len(matching)

# Gather the treatment-annotated compounds, then show the list followed by its size
compounds, num_compounds = get_compounds_with_treatments(graph)
print(compounds, num_compounds, sep="\n")

['Compound::DB01087', 'Compound::DB01072', 'Compound::DB06817', 'Compound::DB00773', 'Compound::DB01392', 'Compound::DB00539', 'Compound::DB07795', 'Compound::DB13997', 'Compound::DB09299', 'Compound::DB00608', 'Compound::DB14126', 'Compound::DB00091', 'Compound::DB00198', 'Compound::DB00811', 'Compound::DB09026', 'Compound::DB08934', 'Compound::DB00220', 'Compound::DB11779', 'Compound::DB00328', 'Compound::DB04786', 'Compound::DB00482', 'Compound::DB01015', 'Compound::DB09102', 'Compound::DB01264', 'Compound::DB01645', 'Compound::DB00822', 'Compound::DB09065', 'Compound::DB12610', 'Compound::DB00831', 'Compound::DB00457', 'Compound::DB04115', 'Compound::DB01041', 'Compound::DB00722', 'Compound::DB00213', 'Compound::DB00133', 'Compound::DB06803', 'Compound::DB12270', 'Compound::DB00319', 'Compound::DB11952', 'Compound::DB00302', 'Compound::DB00199', 'Compound::DB13925', 'Compound::DB00143', 'Compound::DB09061', 'Compound::DB14810', 'Compound::DB00673', 'Compound::DB00159', 'Compound::D

In [None]:
def analyze_predictions_vs_treatments(predictions, actual_compounds):
    """
    Measure, per disease, how many predicted compounds appear in a known treatment set.

    Args:
        predictions: dict mapping a disease label to a list of ``(compound, score)``
            tuples (the format returned by ``get_top_predictions``).
        actual_compounds: iterable of compound labels that have treatment data.

    Returns:
        dict mapping each disease with a non-empty prediction list to a dict with
        'overlap_count', 'overlap_percentage' (share of that disease's predictions,
        0-100) and 'overlapping_compounds' (sorted list). Diseases with no predictions
        are skipped. A human-readable summary is also printed per disease.
    """
    # Build the lookup set once instead of once per disease
    actual_set = set(actual_compounds)
    analysis_results = {}

    for disease, predicted_compounds in predictions.items():
        if not predicted_compounds:
            continue

        # Compound names only (scores dropped)
        predicted_set = {compound for compound, _ in predicted_compounds}

        # Sorted so printed output and stored results are deterministic across runs
        overlap = sorted(predicted_set & actual_set)

        # The real number of predictions is the denominator for both the count and
        # the percentage, so they agree for any k
        num_predicted = len(predicted_compounds)
        percentage = (len(overlap) / num_predicted) * 100

        analysis_results[disease] = {
            'overlap_count': len(overlap),
            'overlap_percentage': percentage,
            'overlapping_compounds': overlap
        }

        print(f"\n{disease}:")
        print(f"Number of predicted compounds in treatment database: {len(overlap)}/{num_predicted}")
        print(f"Percentage: {percentage:.2f}%")
        print("Overlapping compounds:")
        for compound in overlap:
            print(f"- {compound}")

    return analysis_results

# First get predictions from the trained pipeline model. They are recomputed here so this
# cell does not depend on the earlier prediction cell having been run.
relation = "DRKG::Treats::Compound:Disease"
predictions = get_top_predictions(
    pipeline_result.model,
    training_factory,
    relation,
    COV_disease_list
)

# Then get actual treatment compounds
actual_compounds, num_actual = get_compounds_with_treatments(graph)

# Analyze overlap
print(f"\nTotal number of compounds with actual treatment data: {num_actual}")
analysis = analyze_predictions_vs_treatments(predictions, actual_compounds)


Total number of compounds with actual treatment data: 451

Disease::SARS-CoV2 E:
Number of predicted compounds in treatment database: 0/20
Percentage: 0.00%
Overlapping compounds:

Disease::SARS-CoV2 M:
Number of predicted compounds in treatment database: 0/20
Percentage: 0.00%
Overlapping compounds:

Disease::SARS-CoV2 N:
Number of predicted compounds in treatment database: 0/20
Percentage: 0.00%
Overlapping compounds:

Disease::SARS-CoV2 Spike:
Number of predicted compounds in treatment database: 0/20
Percentage: 0.00%
Overlapping compounds:

Disease::SARS-CoV2 nsp1:
Number of predicted compounds in treatment database: 1/20
Percentage: 5.00%
Overlapping compounds:
- Compound::DB14810

Disease::SARS-CoV2 nsp10:
Number of predicted compounds in treatment database: 0/20
Percentage: 0.00%
Overlapping compounds:

Disease::SARS-CoV2 nsp11:
Number of predicted compounds in treatment database: 0/20
Percentage: 0.00%
Overlapping compounds:

Disease::SARS-CoV2 nsp12:
Number of predicted compo

Note: the final reported results were computed with the evaluation metrics described in the HGT notebook.