In [3]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count, countDistinct, when

# Create (or reuse, via getOrCreate) the Spark session for this notebook.
spark = SparkSession.builder.appName("drug_associations").getOrCreate()

# Input files, kept as constants so the full-size dataset can be swapped in.
NODES_PATH = "nodes_test_small.csv"
EDGES_PATH = "edges_test_small.csv"

# Read the node and edge lists; the first row of each CSV is a header.
nodes = spark.read.csv(NODES_PATH, header=True)  # node list: id, name, kind
edges = spark.read.csv(EDGES_PATH, header=True)  # edge list: source, metaedge, target

# Peek at the first few nodes.
nodes.take(4)

[Row(id='Anatomy::UBERON:0000002', name='uterine cervix', kind='Anatomy'),
 Row(id='Disease::DOID:0050156', name='idiopathic pulmonary fibrosis', kind='Disease'),
 Row(id='Gene::1', name='A1BG', kind='Gene'),
 Row(id='Compound::DB00014', name='Goserelin', kind='Compound')]

In [5]:
# Bring the first four edge rows back to the driver for a quick look.
edges.limit(4).collect()

[Row(source='Disease::DOID:0050156', metaedge='DdG', target='Gene::1'),
 Row(source='Compound::DB00035', metaedge='CuG', target='Gene::1'),
 Row(source='Compound::DB00035', metaedge='CrC', target='Compound::DB00014'),
 Row(source='Compound::DB00014', metaedge='CtD', target='Disease::DOID:0050156')]

In [6]:
# Named predicates, resolved against whichever DataFrame they are applied to.
is_compound_source = col("source").startswith("Compound::")
is_gene_target = col("target").startswith("Gene::")
is_disease_target = col("target").startswith("Disease::")

# Keep only edges that go from a Compound to either a Gene or a Disease.
# NOTE(review): assumes compound-related edges store the compound as `source`
# (true for CuG / CtD in the sample above) — confirm for other metaedges.
filtered_edges = edges.filter(is_compound_source & (is_gene_target | is_disease_target))

# Count *distinct* genes and diseases per drug. A compound can be linked to the
# same gene by several metaedges, so counting rows would overcount. `when`
# yields NULL for non-matching rows, and countDistinct ignores NULLs.
result = filtered_edges.groupBy("source").agg(
    countDistinct(when(is_gene_target, col("target"))).alias("Number_of_Genes"),
    countDistinct(when(is_disease_target, col("target"))).alias("Number_of_Diseases"),
)

# Display the result in a stable order (groupBy output order is not guaranteed),
# so the rendered table is reproducible on Restart & Run All.
result.orderBy("source").show()

+-----------------+---------------+------------------+
|           source|Number_of_Genes|Number_of_Diseases|
+-----------------+---------------+------------------+
|Compound::DB00035|              1|                 0|
|Compound::DB00014|              1|                 1|
+-----------------+---------------+------------------+

