In [None]:
import os
import uuid
from array import array
from pyspark.sql import DataFrame
import pyspark.sql.functions as f
from pyspark.sql.types import StringType,BooleanType,StructType,StructField,IntegerType, DecimalType
from pyspark.sql.functions import lit
from decimal import Decimal

f_uuid = f.udf(lambda: str(uuid.uuid4()), StringType())


In [None]:
cosmosEndpoint = "https://xxxxxx.documents.azure.com:443/"
cosmosMasterKey = "*******"
cosmosDatabaseName = "*******"
cosmosContainerName = "*******"

cfg = {
  "spark.cosmos.accountEndpoint" : cosmosEndpoint,
  "spark.cosmos.accountKey" : cosmosMasterKey,
  "spark.cosmos.database" : cosmosDatabaseName,
  "spark.cosmos.container" : cosmosContainerName,
}
# Configure Catalog Api to be used
spark.conf.set("spark.sql.catalog.cosmosCatalog", "com.azure.cosmos.spark.CosmosCatalog")
spark.conf.set("spark.sql.catalog.cosmosCatalog.spark.cosmos.accountEndpoint", cosmosEndpoint)
spark.conf.set("spark.sql.catalog.cosmosCatalog.spark.cosmos.accountKey", cosmosMasterKey)
spark.conf.set("spark.cosmos.throughputControl.enabled",True)
spark.conf.set("spark.cosmos.throughputControl.targetThroughput",20000)

def write_to_cosmos_graph(df: DataFrame):
        
    df.write\
   .format("cosmos.oltp")\
   .options(**cfg)\
    .mode("Append")\
   .save()

In [None]:
def create_vertex_df(
    df: DataFrame,
    vertex_properties_col_name: list, partition_col: str,
    vertex_label: str,id: str
):
  columns = [id, partition_col,"label"]
  columns.extend(['nvl2({x}, array(named_struct("id", uuid(), "_value", {x})), NULL) AS {x}'.format(x=x) for x in vertex_properties_col_name])
  if "label" in df.columns:
    df=df.withColumn("label",df[vertex_label])
  else:
    df=df.withColumn("label",f.lit(vertex_label))
 
  return df.selectExpr(*columns).withColumnRenamed(id,"id")
  

In [None]:
def create_edge_df(srcdf: DataFrame, destdf: DataFrame, label: str, partition_col: str, 
                   vertexidcol: str, sinkcol: str, sinklabel: str, vertexlabel: str, sinkpartitioncol: str,srcjoincol: str,destjoincol: str,isedgetable: bool):
  if(isedgetable):
      #we have edge table
      if(sinklabel in srcdf.columns):
        srcdf=srcdf.withColumn("_sinkLabel",srcdf[sinklabel])
      else:
        srcdf=srcdf.withColumn("_sinkLabel",f.lit(sinklabel))
      if(vertexlabel in srcdf.columns):
        srcdf=srcdf.withColumn("_vertexLabel",srcdf[vertexlabel])
      else:
        srcdf=srcdf.withColumn("_vertexLabel",f.lit(vertexlabel))
      srcdf=srcdf.selectExpr("_sinkLabel","_vertexLabel",srcjoincol,partition_col)
      destdf=destdf.selectExpr(label,destjoincol,vertexidcol,sinkcol,sinkpartitioncol)
      df=srcdf.join(destdf,srcdf[srcjoincol]==destdf[destjoincol],"inner")
      if("label" in df.columns):
        df=df.withColumn("label",df[label])
      else:
        df=df.withColumn("label",f.lit(label))
      df=df.withColumn("_sink",df[sinkcol]).withColumn("_sinkPartition",df[sinkpartitioncol]).withColumn("_vertexId",df[vertexidcol])\
          .withColumn("id",f_uuid()).withColumn("_isEdge",f.lit(True))
  else:
    destdf=destdf.withColumn("_sink",destdf[sinkcol]).withColumn("_sinkPartition",destdf[sinkpartitioncol]).select(destjoincol,"_sink","_sinkPartition")
    srcdf=srcdf.withColumn("_vertexId",srcdf[vertexidcol]).select(srcjoincol,"_vertexId",partition_col)
    df=srcdf.join(destdf,srcdf[srcjoincol]==destdf[destjoincol],"inner")
    df=df.withColumn("label",f.lit(label)).withColumn("id",f_uuid()).withColumn("_sinkLabel",f.lit(sinklabel))\
        .withColumn("_vertexLabel",f.lit(vertexlabel)).withColumn("_isEdge",f.lit(True))
 
  columns=["label","_sink","_sinkLabel","_vertexId","_vertexLabel","_isEdge","_sinkPartition",partition_col,"id"]
  return df.selectExpr(*columns)
  

In [None]:
#vertex_airroutes
import pandas as pd
df=spark.createDataFrame(pd.read_csv("https://raw.githubusercontent.com/krlawrence/graph/master/sample-data/air-routes-latest-nodes.csv"))

airroutes=df.withColumn("srno",df["~id"]).withColumnRenamed("~id","id").withColumnRenamed("~label","label").withColumnRenamed("code:string","code")\
  .withColumnRenamed("desc:string","desc").withColumnRenamed("country:string","country").withColumnRenamed("city:string","city")\
  .selectExpr("cast(srno as string) srno","cast(id as string) id","label","code","desc","country","city")

airroutes.show()



+----+---+-------+----+--------------------+-------+---------------+
|srno| id|  label|code|                desc|country|           city|
+----+---+-------+----+--------------------+-------+---------------+
|   0|  0|version|0.87|Air Routes Data -...|   null|           null|
|   1|  1|airport| ATL|Hartsfield - Jack...|     US|        Atlanta|
|   2|  2|airport| ANC|Anchorage Ted Ste...|     US|      Anchorage|
|   3|  3|airport| AUS|Austin Bergstrom ...|     US|         Austin|
|   4|  4|airport| BNA|Nashville Interna...|     US|      Nashville|
|   5|  5|airport| BOS|        Boston Logan|     US|         Boston|
|   6|  6|airport| BWI|Baltimore/Washing...|     US|      Baltimore|
|   7|  7|airport| DCA|Ronald Reagan Was...|     US|Washington D.C.|
|   8|  8|airport| DFW|Dallas/Fort Worth...|     US|         Dallas|
|   9|  9|airport| FLL|Fort Lauderdale/H...|     US|Fort Lauderdale|
|  10| 10|airport| IAD|Washington Dulles...|     US|Washington D.C.|
|  11| 11|airport| IAH|George Bush

In [None]:
#edges_airroutes
import pandas as pd
df=spark.createDataFrame(pd.read_csv("https://raw.githubusercontent.com/krlawrence/graph/master/sample-data/air-routes-latest-edges.csv"))

airroutesedges=df.withColumn("srno",df["~id"]).withColumnRenamed("~id","id").withColumnRenamed("~label","label").withColumnRenamed("~from","from")\
  .withColumnRenamed("~to","to").withColumnRenamed("dist:int","dist")\
  .selectExpr("id","cast(from as string) from","cast(to as string) to","label","dist","srno")

airroutesedges.show()


+----+----+---+-----+------+----+
|  id|from| to|label|  dist|srno|
+----+----+---+-----+------+----+
|3748|   1|  3|route| 809.0|3748|
|3749|   1|  4|route| 214.0|3749|
|3750|   1|  5|route| 945.0|3750|
|3751|   1|  6|route| 576.0|3751|
|3752|   1|  7|route| 546.0|3752|
|3753|   1|  8|route| 729.0|3753|
|3754|   1|  9|route| 581.0|3754|
|3755|   1| 10|route| 533.0|3755|
|3756|   1| 11|route| 688.0|3756|
|3757|   1| 12|route| 759.0|3757|
|3758|   1| 13|route|1941.0|3758|
|3759|   1| 14|route| 761.0|3759|
|3760|   1| 15|route| 404.0|3760|
|3761|   1| 16|route| 596.0|3761|
|3762|   1| 17|route| 906.0|3762|
|3763|   1| 18|route| 606.0|3763|
|3764|   1| 19|route| 546.0|3764|
|3765|   1| 20|route|1580.0|3765|
|3766|   1| 21|route| 356.0|3766|
|3767|   1| 22|route|2180.0|3767|
+----+----+---+-----+------+----+
only showing top 20 rows



In [None]:
#Vertex
vertex_airroutes = create_vertex_df(
    df=airroutes,
    vertex_properties_col_name=["code","desc","country","code"],
  vertex_label = "label",id="id",partition_col="srno"
)

vertex_airroutes.display()



id,srno,label,code,desc,country,code.1
0,0,version,"List(List(f6a5d74f-b89a-4354-8f8f-877ad7386455, 0.87))","List(List(e5c58445-55fa-476a-8928-4c57608c8a53, Air Routes Data - Version: 0.87 Generated: 2021-08-31 14:58:59 UTC; Graph created by Kelvin R. Lawrence; Please let me know of any errors you find in the graph or routes that should be added.))",,"List(List(8670683d-baa1-495a-8291-97ddf4c72b16, 0.87))"
1,1,airport,"List(List(1be05c75-0cfa-45cc-ab5f-915c446c88b6, ATL))","List(List(7a922ff4-cf43-4555-a451-88273abbedd7, Hartsfield - Jackson Atlanta International Airport))","List(List(d45d9aa9-f5a9-42ec-8879-37b771fdf221, US))","List(List(e95ddbec-affe-4c7f-9de0-041c2795c424, ATL))"
2,2,airport,"List(List(0274731b-c2f6-41f6-9214-a964f5bef4dd, ANC))","List(List(c8cf1a8e-8a47-41c3-ab6e-6ed632a7fb3c, Anchorage Ted Stevens))","List(List(4b20bc37-ab58-4af2-a3dd-ef3f07ea0d1c, US))","List(List(2ee6d96d-f944-4b81-b206-11a7a9732c6a, ANC))"
3,3,airport,"List(List(074d8055-1a9f-43bd-ab33-991173853d8c, AUS))","List(List(e80a0296-33ce-45cb-886c-0194d4c784ff, Austin Bergstrom International Airport))","List(List(42ede357-89f8-4dd5-8b52-e6d825950232, US))","List(List(7de18efe-da2d-42ae-a313-0618c43fb0a2, AUS))"
4,4,airport,"List(List(d1dee5dd-b3cc-4a87-bc22-25fba53d0bc0, BNA))","List(List(3944b926-5b14-4908-ad65-09d9f0e5e2d3, Nashville International Airport))","List(List(11558344-b2f1-444f-bd8a-eb9de6dd3778, US))","List(List(d8a64b04-7381-4bda-93e6-88ab72950cb0, BNA))"
5,5,airport,"List(List(04d8b36c-ee4d-4b89-9fff-8f36af989d14, BOS))","List(List(5d1ce53e-9f68-4782-90fc-38d49a4cdb0c, Boston Logan))","List(List(6138ae86-ed9d-42d8-9825-7b2944f8c81a, US))","List(List(6bffb56b-dbdd-41bd-b595-047b163248a9, BOS))"
6,6,airport,"List(List(73acadc1-995e-40d2-bdc1-3c9a6e1469ff, BWI))","List(List(80cb2d2c-f959-4100-a2d9-cb61c3dae591, Baltimore/Washington International Airport))","List(List(91ad68bc-f1b9-4c07-9b0c-cf8ecb4a5ae7, US))","List(List(f3153c39-2c40-4a1f-88b1-92fb34bd4447, BWI))"
7,7,airport,"List(List(a1381b27-7c5b-4a82-b9e3-32f1de5ce30a, DCA))","List(List(07374bcf-849a-47e0-9e67-ba1311dab8b5, Ronald Reagan Washington National Airport))","List(List(59227d5f-08ac-4abd-976e-f67648237324, US))","List(List(75fc3fc7-cd9e-4eaa-8e2b-ada6b28ce138, DCA))"
8,8,airport,"List(List(725d2110-1352-4762-8030-89ab0d65102e, DFW))","List(List(2c4efd69-50ab-4612-8883-d44eb462b54f, Dallas/Fort Worth International Airport))","List(List(68726b0d-bdfe-4c77-807f-843a2ccc15f0, US))","List(List(fd4b7072-533a-4db6-afae-d364aaf5e125, DFW))"
9,9,airport,"List(List(54bfedb3-7e59-44ce-894b-97b1b81d5e33, FLL))","List(List(d279a148-a8c3-4319-94fe-0953fbcbb30a, Fort Lauderdale/Hollywood International Airport))","List(List(3bb7ea33-4971-4e85-864e-9fbc7d732fa6, US))","List(List(0d4bc0dc-0316-4109-b061-e54860ea2717, FLL))"


In [None]:
edges_airroutes=create_edge_df(airroutes,airroutesedges,"label","srno","from","to","label","label","to","srno","from",True)

edges_airroutes.schema

#edges_airroutes.show()

Out[62]: StructType(List(StructField(label,StringType,true),StructField(_sink,StringType,true),StructField(_sinkLabel,StringType,true),StructField(_vertexId,StringType,true),StructField(_vertexLabel,StringType,true),StructField(_isEdge,BooleanType,false),StructField(_sinkPartition,StringType,true),StructField(srno,StringType,true),StructField(id,StringType,true)))

In [None]:
#Write Vertex
write_to_cosmos_graph(vertex_airroutes)


In [None]:
#Write Edges
write_to_cosmos_graph(edges_airroutes)