## Notebook for extracting a graph from TSV

This first section uses a dataset we extracted with no node features.
We have extracted data in /data/drug_interactions.tsv with the following fields:

- drug_interaction_id: id of drug A
- name: name of drug A
- description: interaction info of drug A with drug B
- drugbank_id: id of drug B

Now we want to extract a graph whose nodes are the drugs, with an edge for each drug_interaction_id-drugbank_id pair.

In [None]:
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD

In [76]:
// read in data 
val lines = sc.textFile("/home/jovyan/work/data/drug_interactions.tsv")
// skip header
val header = lines.first() // extract header
val data = lines.filter(row => row != header) // filter out header

lines: org.apache.spark.rdd.RDD[String] = /home/jovyan/work/data/drug_interactions.tsv MapPartitionsRDD[283] at textFile at <console>:38
header: String = drug_interaction_id	name	description	drugbank_id
data: org.apache.spark.rdd.RDD[String] = MapPartitionsRDD[284] at filter at <console>:41


In [2]:
data.take(2)

res0: Array[String] = Array(DB06605	Apixaban	Apixaban may increase the anticoagulant activities of Lepirudin.	DB00001, DB06695	Dabigatran etexilate	Dabigatran etexilate may increase the anticoagulant activities of Lepirudin.	DB00001)


In [3]:
val lines = data.map(line => line.split("\t"))
lines.take(4)

lines: org.apache.spark.rdd.RDD[Array[String]] = MapPartitionsRDD[3] at map at <console>:27
res1: Array[Array[String]] = Array(Array(DB06605, Apixaban, Apixaban may increase the anticoagulant activities of Lepirudin., DB00001), Array(DB06695, Dabigatran etexilate, Dabigatran etexilate may increase the anticoagulant activities of Lepirudin., DB00001), Array(DB01254, Dasatinib, The risk or severity of bleeding and hemorrhage can be increased when Dasatinib is combined with Lepirudin., DB00001), Array(DB01609, Deferasirox, The risk or severity of gastrointestinal bleeding can be increased when Lepirudin is combined with Deferasirox., DB00001))


In [4]:
// read in all drugs (drug_interactions.tsv only contains drugs that have interactions)
val linesDrugs = sc.textFile("/home/jovyan/work/data/drug_features.csv")
// skip header
val header = linesDrugs.first() // extract header from the drugs file, not from the interactions RDD
val data = linesDrugs.filter(row => row != header) // filter out header
linesDrugs.take(4)

linesDrugs: org.apache.spark.rdd.RDD[String] = /home/jovyan/work/data/drug_features.csv MapPartitionsRDD[5] at textFile at <console>:30
header: String = drug_id	type	group	interactions
data: org.apache.spark.rdd.RDD[String] = MapPartitionsRDD[6] at filter at <console>:33
res2: Array[String] = Array(drug_id	type	group	interactions, DB00001	biotech	[approved]	[DB06605, DB06695, DB01254, DB01609, DB01586, DB02123, DB02659, DB02691, DB03619, DB04348, DB05990, DB06777, DB08833, DB08834, DB08857, DB11622, DB11789, DB09075, DB09053, DB08935, DB06228, DB06206, DB09070, DB00932, DB00013, DB00163, DB09030, DB01381, DB01181, DB00468, DB00908, DB00675, DB00539, DB00806, DB00686, DB00583, DB00255, DB00269, DB00286, DB0...


In [2]:
var all_drugs = spark.read.options(Map("inferSchema"->"true","delimiter"->"\t", "header"->"true"))
  .csv("/home/jovyan/work/data/drug_features.csv")
all_drugs.show(3)

+-------+-------+----------+--------------------+
|drug_id|   type|     group|        interactions|
+-------+-------+----------+--------------------+
|DB00001|biotech|[approved]|[DB06605, DB06695...|
|DB00002|biotech|[approved]|[DB00012, DB00016...|
|DB00003|biotech|[approved]|                null|
+-------+-------+----------+--------------------+
only showing top 3 rows



all_drugs: org.apache.spark.sql.DataFrame = [drug_id: string, type: string ... 2 more fields]


In [3]:
val distinct_drugs = all_drugs.select("drug_id").distinct()
distinct_drugs.count()

distinct_drugs: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [drug_id: string]
res2: Long = 13580


In [4]:
var drugs = distinct_drugs.select("drug_id").rdd
                    .map(x => x(0).toString) // prevent Array type
drugs.take(4)

drugs: org.apache.spark.rdd.RDD[String] = MapPartitionsRDD[32] at map at <console>:27
res3: Array[String] = Array(DB00194, DB00741, DB00846, DB00912)


In [5]:
drugs.count()

res4: Long = 13580


We produce a drug -> node map that pairs each drug ID with a numeric node ID (the node IDs will be used as input to our graph).

In [6]:
val drug2NodeMap = drugs.zipWithIndex()

drug2NodeMap: org.apache.spark.rdd.RDD[(String, Long)] = ZippedWithIndexRDD[33] at zipWithIndex at <console>:26


In [7]:
drug2NodeMap.take(5)

res5: Array[(String, Long)] = Array((DB00194,0), (DB00741,1), (DB00846,2), (DB00912,3), (DB01357,4))


In [8]:
drug2NodeMap.count()

res6: Long = 13580
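The node IDs assigned here depend on the order in which `distinct()` happens to return the rows, and that order is not guaranteed to be the same on every run (the `node_id` shown for DB00001 differs between later outputs in this notebook). Below is a minimal sketch, on plain Scala collections rather than RDDs, of assigning IDs after sorting so that the same set of drugs always gets the same mapping. `StableNodeIds` and `assign` are hypothetical names used only for illustration.

```scala
object StableNodeIds {
  // Assign consecutive Long IDs to the distinct drug IDs in sorted order,
  // so the same input set always yields the same mapping.
  def assign(drugIds: Seq[String]): Map[String, Long] =
    drugIds.distinct.sorted.zipWithIndex
      .map { case (id, i) => id -> i.toLong }
      .toMap
}
```

On an RDD, the analogous step would presumably be to sort before indexing, for example `drugs.sortBy(x => x).zipWithIndex()`.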


In [9]:
val df_drug_node_map = spark.createDataFrame(drug2NodeMap).toDF("drug_id", "node_id")

df_drug_node_map: org.apache.spark.sql.DataFrame = [drug_id: string, node_id: bigint]


In [10]:
// save map for later use
df_drug_node_map
   .repartition(1)
   .write.format("com.databricks.spark.csv")
   .option("header", "true")
   .save("/home/jovyan/work/data/drug2NodeMap.csv")

Now we want to map the drugs in our drug interactions dataset to these node IDs.

In [11]:
var interactions_df = spark.read.options(Map("inferSchema"->"true","delimiter"->"\t", "header"->"true"))
  .csv("/home/jovyan/work/data/drug_interactions.tsv")
interactions_df.show(5)

+-------------------+--------------------+--------------------+-----------+
|drug_interaction_id|                name|         description|drugbank_id|
+-------------------+--------------------+--------------------+-----------+
|            DB06605|            Apixaban|Apixaban may incr...|    DB00001|
|            DB06695|Dabigatran etexilate|Dabigatran etexil...|    DB00001|
|            DB01254|           Dasatinib|The risk or sever...|    DB00001|
|            DB01609|         Deferasirox|The risk or sever...|    DB00001|
|            DB01586|Ursodeoxycholic acid|The risk or sever...|    DB00001|
+-------------------+--------------------+--------------------+-----------+
only showing top 5 rows



interactions_df: org.apache.spark.sql.DataFrame = [drug_interaction_id: string, name: string ... 2 more fields]


In [12]:
var interactions_rdd = interactions_df.rdd.zipWithIndex()
interactions_rdd.take(1)

interactions_rdd: org.apache.spark.rdd.RDD[(org.apache.spark.sql.Row, Long)] = ZippedWithIndexRDD[58] at zipWithIndex at <console>:26
res9: Array[(org.apache.spark.sql.Row, Long)] = Array(([DB06605,Apixaban,Apixaban may increase the anticoagulant activities of Lepirudin.,DB00001],0))


In [13]:
val interactions_rdd2 = interactions_rdd.map(x => (x._1(0).toString, x._1(1).toString, x._1(2).toString, x._1(3).toString, x._2))
interactions_rdd2.take(1)

interactions_rdd2: org.apache.spark.rdd.RDD[(String, String, String, String, Long)] = MapPartitionsRDD[59] at map at <console>:26
res10: Array[(String, String, String, String, Long)] = Array((DB06605,Apixaban,Apixaban may increase the anticoagulant activities of Lepirudin.,DB00001,0))


In [14]:
interactions_df = spark.createDataFrame(interactions_rdd2).toDF("drug_interaction_id", "name", "description", "drugbank_id", "row_number")
interactions_df.show(3)

+-------------------+--------------------+--------------------+-----------+----------+
|drug_interaction_id|                name|         description|drugbank_id|row_number|
+-------------------+--------------------+--------------------+-----------+----------+
|            DB06605|            Apixaban|Apixaban may incr...|    DB00001|         0|
|            DB06695|Dabigatran etexilate|Dabigatran etexil...|    DB00001|         1|
|            DB01254|           Dasatinib|The risk or sever...|    DB00001|         2|
+-------------------+--------------------+--------------------+-----------+----------+
only showing top 3 rows



interactions_df: org.apache.spark.sql.DataFrame = [drug_interaction_id: string, name: string ... 3 more fields]


In [15]:
val drugAs = interactions_df.select("drug_interaction_id", "row_number")
drugAs.show(2)

+-------------------+----------+
|drug_interaction_id|row_number|
+-------------------+----------+
|            DB06605|         0|
|            DB06695|         1|
+-------------------+----------+
only showing top 2 rows



drugAs: org.apache.spark.sql.DataFrame = [drug_interaction_id: string, row_number: bigint]


In [16]:
drugAs.count()

res13: Long = 2668185


In [17]:
var drugAWithNodeIDs = drugAs.join(df_drug_node_map, $"drug_interaction_id" === $"drug_id", "left")
drugAWithNodeIDs.count()

drugAWithNodeIDs: org.apache.spark.sql.DataFrame = [drug_interaction_id: string, row_number: bigint ... 2 more fields]
res14: Long = 2668185


In [18]:
drugAWithNodeIDs = drugAWithNodeIDs.withColumnRenamed("drug_interaction_id","drug_A_id")
           .withColumnRenamed("node_id","drug_A_node_id").drop("drug_id")
drugAWithNodeIDs.show(3)

+---------+----------+--------------+
|drug_A_id|row_number|drug_A_node_id|
+---------+----------+--------------+
|  DB00194|    315697|             0|
|  DB00741|      1855|             1|
|  DB00741|     13266|             1|
+---------+----------+--------------+
only showing top 3 rows



drugAWithNodeIDs: org.apache.spark.sql.DataFrame = [drug_A_id: string, row_number: bigint ... 1 more field]


In [19]:
val drugBs = interactions_df.select("drugbank_id", "row_number")
var drugBWithNodeIDs = drugBs.join(df_drug_node_map, $"drugbank_id" === $"drug_id", "left")
                        .withColumnRenamed("node_id","drug_B_node_id").drop("drug_id")
drugBWithNodeIDs.count()

drugBs: org.apache.spark.sql.DataFrame = [drugbank_id: string, row_number: bigint]
drugBWithNodeIDs: org.apache.spark.sql.DataFrame = [drugbank_id: string, row_number: bigint ... 1 more field]
res16: Long = 2668185


In [20]:
drugBWithNodeIDs.show(1)

+-----------+----------+--------------+
|drugbank_id|row_number|drug_B_node_id|
+-----------+----------+--------------+
|    DB00194|     77879|             0|
+-----------+----------+--------------+
only showing top 1 row



In [21]:
// join drugAWithNodeIDs and drugBWithNodeIDs to get the edges
var edgesData = drugAWithNodeIDs.join(drugBWithNodeIDs, Seq("row_number"), "left")
edgesData.count()

edgesData: org.apache.spark.sql.DataFrame = [row_number: bigint, drug_A_id: string ... 3 more fields]
res18: Long = 2668185


In [25]:
edgesData = edgesData.drop("row_number")
edgesData.show(3)

+---------+--------------+-----------+--------------+
|drug_A_id|drug_A_node_id|drugbank_id|drug_B_node_id|
+---------+--------------+-----------+--------------+
|  DB09030|         10315|    DB00001|          9529|
|  DB00468|          4259|    DB00001|          9529|
|  DB00056|            61|    DB00001|          9529|
+---------+--------------+-----------+--------------+
only showing top 3 rows



In [26]:
// save edges data for later use
edgesData
   .repartition(1)
   .write.format("com.databricks.spark.csv")
   .option("header", "true")
   .save("/home/jovyan/work/data/edges.csv")
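The row-number joins in this section amount to looking up both endpoints of every interaction in the drug-to-node map. Below is a minimal plain-Scala sketch of that lookup, not the notebook's Spark code. `EdgeLookup` and `toEdges` are hypothetical names. Unlike the left joins used in the cells, this sketch drops a pair when either drug is missing from the map instead of keeping a null node ID.

```scala
object EdgeLookup {
  // Translate (drugA, drugB) ID pairs into (nodeA, nodeB) pairs.
  // A pair is skipped when either endpoint has no entry in nodeOf.
  def toEdges(pairs: Seq[(String, String)],
              nodeOf: Map[String, Long]): Seq[(Long, Long)] =
    pairs.flatMap { case (a, b) =>
      for {
        na <- nodeOf.get(a)
        nb <- nodeOf.get(b)
      } yield (na, nb)
    }
}
```

Whether unmatched drugs should be dropped or kept is a design choice. Dropping them keeps every edge usable as a pair of valid node IDs.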

## Graph extraction with Features

To use GCNs for link prediction, we need to extract some features for each node (drug). 
For this part, we use a dataset /data/drug_features.csv with the following fields:

    - drug_id: id of drug 
    - type: type of the drug (biotech or small molecule)
    - group: list of groups the drug belongs to (e.g. approved, investigational)
    - target_info: list of info extracted directly from xml, needed for extracting target gene name(s)
    - enzyme_info: list of info extracted directly from xml, needed for extracting enzyme gene name(s)
    - interactions: list of all drug ids this drug interacts with

Now we want to extract a graph whose nodes are the drugs together with their features, with an edge for each drug_interaction_id-drugbank_id pair.

In [27]:
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD

import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD


In [28]:
var df = spark.read.options(Map("inferSchema"->"true","delimiter"->"\t", "header"->"true"))
  .csv("/home/jovyan/work/data/drug_features.csv")
df.show(8)

+-------+--------------+--------------------+--------------------+
|drug_id|          type|               group|        interactions|
+-------+--------------+--------------------+--------------------+
|DB00001|       biotech|          [approved]|[DB06605, DB06695...|
|DB00002|       biotech|          [approved]|[DB00012, DB00016...|
|DB00003|       biotech|          [approved]|                null|
|DB00004|       biotech|[approved, invest...|[DB00012, DB00016...|
|DB00005|       biotech|[approved, invest...|[DB01281, DB00026...|
|DB00006|small molecule|[approved, invest...|[DB06605, DB06695...|
|DB00007|small molecule|[approved, invest...|[DB09066, DB09083...|
|DB00008|       biotech|[approved, invest...|[DB06643, DB00005...|
+-------+--------------+--------------------+--------------------+
only showing top 8 rows



df: org.apache.spark.sql.DataFrame = [drug_id: string, type: string ... 2 more fields]


In [29]:
import org.apache.spark.ml.feature.OneHotEncoder
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.feature.VectorIndexer
import org.apache.spark.sql.SparkSession

import org.apache.spark.ml.feature.OneHotEncoder
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.feature.VectorIndexer
import org.apache.spark.sql.SparkSession


#### 1. type feature

In [40]:
// let's see what values we have here
df.select("type").distinct().show()

+--------------+
|          type|
+--------------+
|          null|
|       biotech|
|small molecule|
+--------------+



In [30]:
// add boolean is_biotech and is_small_molecule columns (both are null where type is null)
df = df.withColumn("is_biotech", col("type") === "biotech")
df = df.withColumn("is_small_molecule", col("type") === "small molecule")
df.show(8)

+-------+--------------+--------------------+--------------------+----------+-----------------+
|drug_id|          type|               group|        interactions|is_biotech|is_small_molecule|
+-------+--------------+--------------------+--------------------+----------+-----------------+
|DB00001|       biotech|          [approved]|[DB06605, DB06695...|      true|            false|
|DB00002|       biotech|          [approved]|[DB00012, DB00016...|      true|            false|
|DB00003|       biotech|          [approved]|                null|      true|            false|
|DB00004|       biotech|[approved, invest...|[DB00012, DB00016...|      true|            false|
|DB00005|       biotech|[approved, invest...|[DB01281, DB00026...|      true|            false|
|DB00006|small molecule|[approved, invest...|[DB06605, DB06695...|     false|             true|
|DB00007|small molecule|[approved, invest...|[DB09066, DB09083...|     false|             true|
|DB00008|       biotech|[approved, inves

df: org.apache.spark.sql.DataFrame = [drug_id: string, type: string ... 4 more fields]
df: org.apache.spark.sql.DataFrame = [drug_id: string, type: string ... 4 more fields]
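The `type` column also contains nulls, and in Spark SQL comparing a null with `===` gives null rather than false. If strict 0/1 features are wanted, the encoding could look like the plain-Scala sketch below, which maps a missing or unknown type to `(0, 0)`. `TypeFlags` and `encode` are hypothetical names, and this is only one possible convention.

```scala
object TypeFlags {
  // Returns (is_biotech, is_small_molecule) as 0/1 values.
  // A null or unrecognised type gives (0, 0).
  def encode(drugType: String): (Int, Int) = drugType match {
    case "biotech"        => (1, 0)
    case "small molecule" => (0, 1)
    case _                => (0, 0)
  }
}
```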


#### 2. group feature
This feature is a list of values from: ['withdrawn', 'illicit', 'vet_approved', 'investigational', 'approved', 'experimental', 'nutraceutical'] (extracted using pandas)

In [31]:
df = df.withColumn("withdrawn", col("group").contains("withdrawn"))
df = df.withColumn("illicit", col("group").contains("illicit"))
df = df.withColumn("vet_approved", col("group").contains("vet_approved"))
df = df.withColumn("investigational", col("group").contains("investigational")) // removed stray apostrophe that prevented any match
df = df.withColumn("approved", col("group").rlike("\\bapproved\\b")) // exact token: vet_approved alone must not count as approved
df = df.withColumn("experimental", col("group").contains("experimental"))
df = df.withColumn("nutraceutical", col("group").contains("nutraceutical"))
df.show(8)

+-------+--------------+--------------------+--------------------+----------+-----------------+---------+-------+------------+---------------+--------+------------+-------------+
|drug_id|          type|               group|        interactions|is_biotech|is_small_molecule|withdrawn|illicit|vet_approved|investigational|approved|experimental|nutraceutical|
+-------+--------------+--------------------+--------------------+----------+-----------------+---------+-------+------------+---------------+--------+------------+-------------+
|DB00001|       biotech|          [approved]|[DB06605, DB06695...|      true|            false|    false|  false|       false|          false|    true|       false|        false|
|DB00002|       biotech|          [approved]|[DB00012, DB00016...|      true|            false|    false|  false|       false|          false|    true|       false|        false|
|DB00003|       biotech|          [approved]|                null|      true|            false|    false|

df: org.apache.spark.sql.DataFrame = [drug_id: string, type: string ... 11 more fields]
df: org.apache.spark.sql.DataFrame = [drug_id: string, type: string ... 11 more fields]
df: org.apache.spark.sql.DataFrame = [drug_id: string, type: string ... 11 more fields]
df: org.apache.spark.sql.DataFrame = [drug_id: string, type: string ... 11 more fields]
df: org.apache.spark.sql.DataFrame = [drug_id: string, type: string ... 11 more fields]
df: org.apache.spark.sql.DataFrame = [drug_id: string, type: string ... 11 more fields]
df: org.apache.spark.sql.DataFrame = [drug_id: string, type: string ... 11 more fields]
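Because `group` is stored as a bracketed string such as `[approved, investigational]`, a plain substring test for `approved` would also match `vet_approved`. Below is a minimal sketch of exact-token flags on plain Scala strings. `GroupFlags`, `tokens` and `flags` are hypothetical names, and the group list is the one given in the text above.

```scala
object GroupFlags {
  val groups: Seq[String] = Seq("withdrawn", "illicit", "vet_approved",
    "investigational", "approved", "experimental", "nutraceutical")

  // Parse "[approved, vet_approved]" into the set of exact tokens.
  // A null input gives the empty set.
  def tokens(raw: String): Set[String] =
    Option(raw)
      .map(_.trim.stripPrefix("[").stripSuffix("]"))
      .map(_.split(",").map(_.trim).filter(_.nonEmpty).toSet)
      .getOrElse(Set.empty[String])

  // One boolean per known group name.
  def flags(raw: String): Map[String, Boolean] = {
    val present = tokens(raw)
    groups.map(g => g -> present.contains(g)).toMap
  }
}
```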


#### Create Nodes Dataset

In [32]:
// val df2 = df.withColumn("interactions", explode(array(col("interactions"))))
// df2.show(8)

+-------+--------------+--------------------+--------------------+----------+-----------------+---------+-------+------------+---------------+--------+------------+-------------+
|drug_id|          type|               group|        interactions|is_biotech|is_small_molecule|withdrawn|illicit|vet_approved|investigational|approved|experimental|nutraceutical|
+-------+--------------+--------------------+--------------------+----------+-----------------+---------+-------+------------+---------------+--------+------------+-------------+
|DB00001|       biotech|          [approved]|[DB06605, DB06695...|      true|            false|    false|  false|       false|          false|    true|       false|        false|
|DB00002|       biotech|          [approved]|[DB00012, DB00016...|      true|            false|    false|  false|       false|          false|    true|       false|        false|
|DB00003|       biotech|          [approved]|                null|      true|            false|    false|

df2: org.apache.spark.sql.DataFrame = [drug_id: string, type: string ... 11 more fields]


In [33]:
val df_drug_node_map = spark.read.options(Map("delimiter"->",", "header"->"true"))
  .csv("/home/jovyan/work/data/drug2NodeMap.csv")
df_drug_node_map.show(7)

+-------+-------+
|drug_id|node_id|
+-------+-------+
|DB00194|      0|
|DB00741|      1|
|DB00846|      2|
|DB00912|      3|
|DB01357|      4|
|DB01460|      5|
|DB01979|      6|
+-------+-------+
only showing top 7 rows



df_drug_node_map: org.apache.spark.sql.DataFrame = [drug_id: string, node_id: string]


In [34]:
df.count()

res28: Long = 13608


In [35]:
val node_features = df.join(df_drug_node_map, Seq("drug_id"), "left")
node_features.show(5)

+-------+-------+--------------------+--------------------+----------+-----------------+---------+-------+------------+---------------+--------+------------+-------------+-------+
|drug_id|   type|               group|        interactions|is_biotech|is_small_molecule|withdrawn|illicit|vet_approved|investigational|approved|experimental|nutraceutical|node_id|
+-------+-------+--------------------+--------------------+----------+-----------------+---------+-------+------------+---------------+--------+------------+-------------+-------+
|DB00001|biotech|          [approved]|[DB06605, DB06695...|      true|            false|    false|  false|       false|          false|    true|       false|        false|   9529|
|DB00002|biotech|          [approved]|[DB00012, DB00016...|      true|            false|    false|  false|       false|          false|    true|       false|        false|   3685|
|DB00003|biotech|          [approved]|                null|      true|            false|    false|  

node_features: org.apache.spark.sql.DataFrame = [drug_id: string, type: string ... 12 more fields]


In [36]:
node_features.count()

res30: Long = 13608


In [37]:
node_features
   .repartition(1)
   .write.format("com.databricks.spark.csv")
   .option("header", "true")
   .save("/home/jovyan/work/data/node_features.csv")

### Target and Enzyme gene names
Now we will join our node features with our dataset containing the target and enzyme gene names.

In [4]:
var node_features = spark.read.options(Map("inferSchema"->"true","delimiter"->",", "header"->"true"))
  .csv("/home/jovyan/work/data/node_features.csv")
node_features.show(8)

+-------+--------------+--------------------+--------------------+----------+-----------------+---------+-------+------------+---------------+--------+------------+-------------+-------+
|drug_id|          type|               group|        interactions|is_biotech|is_small_molecule|withdrawn|illicit|vet_approved|investigational|approved|experimental|nutraceutical|node_id|
+-------+--------------+--------------------+--------------------+----------+-----------------+---------+-------+------------+---------------+--------+------------+-------------+-------+
|DB00001|       biotech|          [approved]|[DB06605, DB06695...|      true|            false|    false|  false|       false|          false|    true|       false|        false|   3022|
|DB00002|       biotech|          [approved]|[DB00012, DB00016...|      true|            false|    false|  false|       false|          false|    true|       false|        false|   1436|
|DB00003|       biotech|          [approved]|                null

node_features: org.apache.spark.sql.DataFrame = [drug_id: string, type: string ... 12 more fields]


In [14]:
var df_gname = spark.read.options(Map("inferSchema"->"true","delimiter"->"\t", "header"->"true"))
  .csv("/home/jovyan/work/data/drug_gname.tsv")
df_gname.show(8)

+-----------+--------------------+------------+
|drugbank_id|        target-gname|enzyme-gname|
+-----------+--------------------+------------+
|    DB00001|                  F2|        null|
|    DB00002|EGFR,FCGR3B,C1QA,...|        null|
|    DB00003|                null|        null|
|    DB00004|   IL2RA,IL2RB,IL2RG|        null|
|    DB00005|TNF,FCGR1A,FCGR2A...|        null|
|    DB00006|                  F2|         MPO|
|    DB00007|               GNRHR|        null|
|    DB00008|       IFNAR2,IFNAR1|      CYP1A2|
+-----------+--------------------+------------+
only showing top 8 rows



df_gname: org.apache.spark.sql.DataFrame = [drugbank_id: string, target-gname: string ... 1 more field]


target-gname and enzyme-gname each hold a (possibly empty) comma-separated list of gene names. We will create one field per gene name holding a 0 or 1 for each drug row, depending on whether that gene appears among the drug's targets, and then do the same for enzymes.

In [15]:
df_gname = df_gname.withColumn("target-gname", split($"target-gname", ","))
df_gname = df_gname.withColumn("enzyme-gname", split($"enzyme-gname", ","))
df_gname.show(10)

+-----------+--------------------+------------+
|drugbank_id|        target-gname|enzyme-gname|
+-----------+--------------------+------------+
|    DB00001|                [F2]|        null|
|    DB00002|[EGFR, FCGR3B, C1...|        null|
|    DB00003|                null|        null|
|    DB00004|[IL2RA, IL2RB, IL...|        null|
|    DB00005|[TNF, FCGR1A, FCG...|        null|
|    DB00006|                [F2]|       [MPO]|
|    DB00007|             [GNRHR]|        null|
|    DB00008|    [IFNAR2, IFNAR1]|    [CYP1A2]|
|    DB00009|[PLG, FGA, PLAUR,...|        null|
|    DB00010|             [GHRHR]|        null|
+-----------+--------------------+------------+
only showing top 10 rows



df_gname: org.apache.spark.sql.DataFrame = [drugbank_id: string, target-gname: array<string> ... 1 more field]
df_gname: org.apache.spark.sql.DataFrame = [drugbank_id: string, target-gname: array<string> ... 1 more field]
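The one-field-per-gene encoding described above is a multi-hot encoding. Below is a minimal sketch on plain Scala collections, assuming each drug's gene names have already been split into a list. `MultiHot`, `vocabulary` and `encode` are hypothetical names, not part of the notebook.

```scala
object MultiHot {
  // Sorted vocabulary of every gene name that occurs in any row.
  def vocabulary(rows: Seq[Seq[String]]): Vector[String] =
    rows.flatten.distinct.sorted.toVector

  // One 0/1 entry per vocabulary gene; a drug with no genes yields all zeros.
  def encode(genes: Seq[String], vocab: Vector[String]): Vector[Int] = {
    val present = genes.toSet
    vocab.map(g => if (present(g)) 1 else 0)
  }
}
```

The same vocabulary-then-encode steps would be run once for the target column and once for the enzyme column, giving two separate groups of 0/1 fields.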


In [None]:
#########

In [233]:
var rdd  = df.select("interactions").rdd.map(r => r(0))
// var df3 = spark.createDataFrame(rdd).toDF()
// df3.show(5)

rdd: org.apache.spark.rdd.RDD[Any] = MapPartitionsRDD[416] at map at <console>:46


In [211]:
df.show(5)

+-------+-------+--------------------+--------------------+
|drug_id|   type|               group|        interactions|
+-------+-------+--------------------+--------------------+
|DB00001|biotech|          [approved]|[DB06605, DB06695...|
|DB00002|biotech|          [approved]|[DB00012, DB00016...|
|DB00003|biotech|          [approved]|                null|
|DB00004|biotech|[approved, invest...|[DB00012, DB00016...|
|DB00005|biotech|[approved, invest...|[DB01281, DB00026...|
+-------+-------+--------------------+--------------------+
only showing top 5 rows



In [256]:


val r = udf((input: StructField) => input.split(",").map(_.toString))


<console>: 44: error: value split is not a member of org.apache.spark.sql.types.StructField

In [252]:
df.withColumn("test", r("interactions")).show(5)

<console>: 50: error: type mismatch;

In [263]:
df.withColumn(
        "test",
        split(col("interactions"), ",\\*").cast("Array<String>").alias("ev")
 ).show(5)

+-------+-------+--------------------+--------------------+--------------------+
|drug_id|   type|               group|        interactions|                test|
+-------+-------+--------------------+--------------------+--------------------+
|DB00001|biotech|          [approved]|[DB06605, DB06695...|[[DB06605, DB0669...|
|DB00002|biotech|          [approved]|[DB00012, DB00016...|[[DB00012, DB0001...|
|DB00003|biotech|          [approved]|                null|                null|
|DB00004|biotech|[approved, invest...|[DB00012, DB00016...|[[DB00012, DB0001...|
|DB00005|biotech|[approved, invest...|[DB01281, DB00026...|[[DB01281, DB0002...|
+-------+-------+--------------------+--------------------+--------------------+
only showing top 5 rows



In [257]:
df.printSchema

root
 |-- drug_id: string (nullable = true)
 |-- type: string (nullable = true)
 |-- group: string (nullable = true)
 |-- interactions: string (nullable = true)
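The schema shows that `interactions` is a plain string such as `[DB06605, DB06695]`, not an array column, which would explain why the list-style operations tried in the surrounding cells do not behave as hoped. Below is a minimal sketch of the parsing step on a plain Scala string. `InteractionList` and `parse` are hypothetical names.

```scala
object InteractionList {
  // Turn "[DB06605, DB06695]" into Seq("DB06605", "DB06695").
  // A null, "" or "[]" input gives an empty Seq.
  def parse(raw: String): Seq[String] =
    if (raw == null) Seq.empty
    else raw.trim.stripPrefix("[").stripSuffix("]")
      .split(",").map(_.trim).filter(_.nonEmpty).toSeq
}
```

On the DataFrame itself, a comparable result could probably be obtained by stripping the brackets with `regexp_replace` and then applying `split` from `org.apache.spark.sql.functions`.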



In [None]:
df = df.withColumn("interactions", array_to_string_udf(df["interactions_as_arr"]))


In [214]:
df.select("interactions").rdd.take(5)

res165: Array[org.apache.spark.sql.Row] = Array([[DB06605, DB06695, DB01254, DB01609, DB01586, DB02123, DB02659, DB02691, DB03619, DB04348, DB05990, DB06777, DB08833, DB08834, DB08857, DB11622, DB11789, DB09075, DB09053, DB08935, DB06228, DB06206, DB09070, DB00932, DB00013, DB00163, DB09030, DB01381, DB01181, DB00468, DB00908, DB00675, DB00539, DB00806, DB00686, DB00583, DB00255, DB00269, DB00286, DB00783, DB00977, DB01357, DB04573, DB04574, DB04575, DB07931, DB09317, DB09318, DB09369, DB09381, DB11478, DB11674, DB12487, DB13143, DB13386, DB13418, DB13952, DB13953, DB13954, DB13956, DB15334, DB15335, DB09211, DB00159, DB00244, DB00328, DB00461, DB00465, DB00469, DB00482, DB00500, DB00533, DB00554, DB00573, DB00580, DB00586, DB00605, DB00712, DB00749, DB00784, DB00788, DB00795, DB00812, ...


In [183]:
df.withColumn("list_interactions_id", List(col("interactions"))).show(2)

<console>: 45: error: type mismatch;

In [231]:
df.printSchema()

root
 |-- drug_id: string (nullable = true)
 |-- type: string (nullable = true)
 |-- group: string (nullable = true)
 |-- interactions: string (nullable = true)



In [266]:
var rdd = df.select("interactions").rdd

rdd: org.apache.spark.rdd.RDD[org.apache.spark.sql.Row] = MapPartitionsRDD[448] at rdd at <console>:46


In [273]:
var t = rdd.map(x => List(x))

t: org.apache.spark.rdd.RDD[List[org.apache.spark.sql.Row]] = MapPartitionsRDD[454] at map at <console>:46


In [280]:
var index = t.zipWithIndex()

index: org.apache.spark.rdd.RDD[(List[org.apache.spark.sql.Row], Long)] = ZippedWithIndexRDD[463] at zipWithIndex at <console>:46


In [286]:
rdd.count()

res204: Long = 13608


In [277]:
var myIndex = 5
var values = (t.zipWithIndex()
            .filter(_._2==myIndex))

myIndex: Int = 5
values: org.apache.spark.rdd.RDD[(List[org.apache.spark.sql.Row], Long)] = MapPartitionsRDD[462] at filter at <console>:50


In [279]:
values.take(1)

res198: Array[(List[org.apache.spark.sql.Row], Long)] = Array((List([[DB06605, DB06695, DB01254, DB01609, DB01586, DB02123, DB02659, DB02691, DB03619, DB04348, DB05990, DB06777, DB08833, DB08834, DB08857, DB11622, DB11789, DB09075, DB09053, DB08935, DB06228, DB06206, DB09070, DB00932, DB00013, DB00163, DB09030, DB01381, DB01181, DB00468, DB00908, DB00675, DB00539, DB00806, DB00686, DB00583, DB00255, DB00269, DB00286, DB00783, DB00977, DB01357, DB04573, DB04574, DB04575, DB07931, DB09317, DB09318, DB09369, DB09381, DB11478, DB11674, DB12487, DB13143, DB13386, DB13418, DB13952, DB13953, DB13954, DB13956, DB15334, DB15335, DB09211, DB00159, DB00244, DB00328, DB00461, DB00465, DB00469, DB00482, DB00500, DB00533, DB00554, DB00573, DB00580, DB00586, DB00605, DB00712, DB00749, DB00784, DB00788...


In [None]:
// experiment code for one hot encoding/categorising features 

In [41]:
val df = spark.createDataFrame(Seq(
  ("a", 1.0),
  ("b", 0.0),
  ("a", 1.0),
  ("a", 2.0),
  ("c", 1.0),
  ("d", 0.0)
)).toDF("category1", "categoryIndex2")

df.show()

+---------+--------------+
|category1|categoryIndex2|
+---------+--------------+
|        a|           1.0|
|        b|           0.0|
|        a|           1.0|
|        a|           2.0|
|        c|           1.0|
|        d|           0.0|
+---------+--------------+



df: org.apache.spark.sql.DataFrame = [category1: string, categoryIndex2: double]


In [44]:
// string fields must be converted to numeric indices before they can be one-hot encoded
val indexedDf = new StringIndexer().setInputCol("category1").setOutputCol("categoryIndex1").fit(df).transform(df);
indexedDf.show();


+---------+--------------+--------------+
|category1|categoryIndex2|categoryIndex1|
+---------+--------------+--------------+
|        a|           1.0|           0.0|
|        b|           0.0|           1.0|
|        a|           1.0|           0.0|
|        a|           2.0|           0.0|
|        c|           1.0|           2.0|
|        d|           0.0|           3.0|
+---------+--------------+--------------+



indexedDf: org.apache.spark.sql.DataFrame = [category1: string, categoryIndex2: double ... 1 more field]


In [45]:
val encoder = new OneHotEncoder()
    .setDropLast(false) // don't drop the last category (Spark drops it by default)
    .setInputCols(Array("categoryIndex1"))
    .setOutputCols(Array("categoryVec1"))

val model = encoder.fit(indexedDf)
val encoded = model.transform(indexedDf)
encoded.show()


+---------+--------------+--------------+-------------+
|category1|categoryIndex2|categoryIndex1| categoryVec1|
+---------+--------------+--------------+-------------+
|        a|           1.0|           0.0|(4,[0],[1.0])|
|        b|           0.0|           1.0|(4,[1],[1.0])|
|        a|           1.0|           0.0|(4,[0],[1.0])|
|        a|           2.0|           0.0|(4,[0],[1.0])|
|        c|           1.0|           2.0|(4,[2],[1.0])|
|        d|           0.0|           3.0|(4,[3],[1.0])|
+---------+--------------+--------------+-------------+



encoder: org.apache.spark.ml.feature.OneHotEncoder = oneHotEncoder_b4f946501dc4
model: org.apache.spark.ml.feature.OneHotEncoderModel = OneHotEncoderModel: uid=oneHotEncoder_b4f946501dc4, dropLast=false, handleInvalid=error, numInputCols=1, numOutputCols=1
encoded: org.apache.spark.sql.DataFrame = [category1: string, categoryIndex2: double ... 2 more fields]
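To make the two steps above concrete, here is a plain-Scala sketch of what the index-then-encode pipeline computes. It assumes the indexer orders labels by descending frequency and breaks ties alphabetically, which matches the outputs shown in this notebook. `LabelIndex`, `fit` and `oneHot` are hypothetical names, and the dense vector is only an illustration of the sparse `(size,[index],[1.0])` values printed by Spark.

```scala
object LabelIndex {
  // Index labels by descending frequency, ties broken alphabetically.
  def fit(labels: Seq[String]): Map[String, Int] =
    labels.groupBy(identity)
      .map { case (label, xs) => (label, xs.size) }
      .toSeq
      .sortBy { case (label, n) => (-n, label) }
      .map(_._1)
      .zipWithIndex
      .toMap

  // Dense one-hot vector of length `size` with 1.0 at position `index`.
  def oneHot(index: Int, size: Int): Vector[Double] =
    Vector.tabulate(size)(i => if (i == index) 1.0 else 0.0)
}
```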


In [None]:
// now we want to figure out how to 

In [104]:
// how to categorise strings
var sampleData = Seq(
("India", 30, 2, 1991),
("Us", 30, 1, 1987),
("India", 30, 2, 1992),
("China", 30, 3, 1993),
("India", 30, 2, 1980),
("Us", 30, 1, 1990),
("India", 30, 2, 1982),
("China", 30, 3, 1994)
);

val spark = SparkSession
.builder()
.appName("test")
.config("spark.master", "local")
.getOrCreate();

val sampleDataDf = spark.createDataFrame(sampleData).toDF("country", "points", "contestant", "year");
val sampleIndexedDf = new StringIndexer().setInputCol("country").setOutputCol("country_index").fit(sampleDataDf).transform(sampleDataDf);
sampleDataDf.show();
sampleIndexedDf.show();

+-------+------+----------+----+
|country|points|contestant|year|
+-------+------+----------+----+
|  India|    30|         2|1991|
|     Us|    30|         1|1987|
|  India|    30|         2|1992|
|  China|    30|         3|1993|
|  India|    30|         2|1980|
|     Us|    30|         1|1990|
|  India|    30|         2|1982|
|  China|    30|         3|1994|
+-------+------+----------+----+

+-------+------+----------+----+-------------+
|country|points|contestant|year|country_index|
+-------+------+----------+----+-------------+
|  India|    30|         2|1991|          0.0|
|     Us|    30|         1|1987|          2.0|
|  India|    30|         2|1992|          0.0|
|  China|    30|         3|1993|          1.0|
|  India|    30|         2|1980|          0.0|
|     Us|    30|         1|1990|          2.0|
|  India|    30|         2|1982|          0.0|
|  China|    30|         3|1994|          1.0|
+-------+------+----------+----+-------------+



sampleData: Seq[(String, Int, Int, Int)] = List((India,30,2,1991), (Us,30,1,1987), (India,30,2,1992), (China,30,3,1993), (India,30,2,1980), (Us,30,1,1990), (India,30,2,1982), (China,30,3,1994))
spark: org.apache.spark.sql.SparkSession = org.apache.spark.sql.SparkSession@2e2996fe
sampleDataDf: org.apache.spark.sql.DataFrame = [country: string, points: int ... 2 more fields]
sampleIndexedDf: org.apache.spark.sql.DataFrame = [country: string, points: int ... 3 more fields]
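
The `country_index` values above are consistent with StringIndexer's default ordering: labels are ranked by descending frequency, so the most frequent label (India, 4 rows) gets 0.0. A minimal plain-Scala sketch of that ranking follows. `frequencyIndex` is a hypothetical helper, and the alphabetical tie-break is an assumption that matches the output above (China before Us).

```scala
// Plain-Scala illustration of frequency-descending label indexing.
// `frequencyIndex` is a hypothetical helper, not part of the Spark API.
def frequencyIndex(labels: Seq[String]): Map[String, Double] =
  labels
    .groupBy(identity)
    .map { case (label, occurrences) => (label, occurrences.size) }
    .toSeq
    // most frequent first; ties broken alphabetically (assumed)
    .sortBy { case (label, count) => (-count, label) }
    .zipWithIndex
    .map { case ((label, _), index) => label -> index.toDouble }
    .toMap

val countries = Seq("India", "Us", "India", "China", "India", "Us", "India", "China")
val countryIndex = frequencyIndex(countries)
// India -> 0.0, China -> 1.0, Us -> 2.0
```

The indices carry no meaning beyond rank by frequency, so they should be treated as category ids rather than as ordered numeric values.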


In [105]:
// lookup table mapping each country name to a node id
val sampleData2 = Seq(
("Us", 15),
("India", 3),
("China", 7)
);
val sampleData2Df = spark.createDataFrame(sampleData2).toDF("country_name", "node_id");

sampleData2: Seq[(String, Int)] = List((Us,15), (India,3), (China,7))
sampleData2Df: org.apache.spark.sql.DataFrame = [country_name: string, node_id: int]


In [106]:
sampleData2Df.show()

+------------+-------+
|country_name|node_id|
+------------+-------+
|          Us|     15|
|       India|      3|
|       China|      7|
+------------+-------+



In [108]:
// right join: keep every row of sampleData2Df and attach its node_id to the matching sample rows
val join_df = sampleDataDf.join(sampleData2Df, $"country" === $"country_name", "right")
join_df.show()

+-------+------+----------+----+------------+-------+
|country|points|contestant|year|country_name|node_id|
+-------+------+----------+----+------------+-------+
|     Us|    30|         1|1990|          Us|     15|
|     Us|    30|         1|1987|          Us|     15|
|  India|    30|         2|1982|       India|      3|
|  India|    30|         2|1980|       India|      3|
|  India|    30|         2|1992|       India|      3|
|  India|    30|         2|1991|       India|      3|
|  China|    30|         3|1994|       China|      7|
|  China|    30|         3|1993|       China|      7|
+-------+------+----------+----+------------+-------+



join_df: org.apache.spark.sql.DataFrame = [country: string, points: int ... 4 more fields]


In [18]:
// define sample data for one-hot encoding a list-valued column
val sampleData = Seq(
("Scandinavia", List("Norway", "Denmark", "Sweden")),
("South Asia", List("India", "Pakistan", "Bangladesh")),
);

val spark = SparkSession
.builder()
.appName("test")
.config("spark.master", "local")
.getOrCreate();

val sampleDataDf = spark.createDataFrame(sampleData).toDF("region", "countries");
sampleDataDf.show()

+-----------+--------------------+
|     region|           countries|
+-----------+--------------------+
|Scandinavia|[Norway, Denmark,...|
| South Asia|[India, Pakistan,...|
+-----------+--------------------+



sampleData: Seq[(String, List[String])] = List((Scandinavia,List(Norway, Denmark, Sweden)), (South Asia,List(India, Pakistan, Bangladesh)))
spark: org.apache.spark.sql.SparkSession = org.apache.spark.sql.SparkSession@2e089964
sampleDataDf: org.apache.spark.sql.DataFrame = [region: string, countries: array<string>]


In [19]:
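// attempt to convert the countries field to categorical
// NOTE: this fails (see the exception below) because VectorIndexer expects a Vector column, not array<string>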
val sampleIndexedDf = new VectorIndexer().setInputCol("countries").setOutputCol("countries_categorical").fit(sampleDataDf).transform(sampleDataDf);
sampleDataDf.show();
sampleIndexedDf.show();

java.lang.IllegalArgumentException:  requirement failed: Column countries must be of type struct<type:tinyint,size:int,indices:array<int>,values:array<double>> but was actually array<string>.
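
The exception above says that VectorIndexer only accepts a Vector column, so an `array<string>` column has to be turned into a vector first. One way to think about the target representation is a multi-hot vector over the vocabulary of all country names. The plain-Scala sketch below (no Spark) shows that encoding. `multiHot` and `vocabulary` are hypothetical names for illustration. Within Spark, `CountVectorizer` with `setBinary(true)` is one transformer that accepts `array<string>` input and may be worth trying here.

```scala
// Plain-Scala illustration of multi-hot encoding a list-valued field.
// `multiHot` is a hypothetical helper, not a Spark API.
val regions = Seq(
  ("Scandinavia", List("Norway", "Denmark", "Sweden")),
  ("South Asia", List("India", "Pakistan", "Bangladesh"))
)

// vocabulary: every distinct country, sorted so vector positions are deterministic
val vocabulary: Vector[String] = regions.flatMap(_._2).distinct.sorted.toVector

// 1.0 at each vocabulary position whose term appears in `items`, 0.0 elsewhere
def multiHot(items: Seq[String], vocabulary: Vector[String]): Vector[Double] =
  vocabulary.map(term => if (items.contains(term)) 1.0 else 0.0)

val encodedRegions = regions.map { case (region, countries) =>
  region -> multiHot(countries, vocabulary)
}
```

Each region ends up with one vector of length `vocabulary.size`, which is the shape a Vector-typed column would need.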