In [1]:
from pyspark.sql.types import * 
from pyspark.sql import SparkSession
from pyspark import SparkConf, SparkContext
from graphframes import *

In [2]:
# 启动 Spark （如果你已经启动就不需要）
spark = SparkSession.builder.master("local[*]") \
   .appName("test") \
   .enableHiveSupport() \
   .getOrCreate()

In [3]:
def create_transport_graph(): 
    node_fields = [
        StructField("id", StringType(), True),
        StructField("latitude", FloatType(), True),
        StructField("longitude", FloatType(), True),
        StructField("population", IntegerType(), True)
    ]
    nodes = spark.read.csv("transport-nodes.csv", header=True,
                           schema=StructType(node_fields))
    rels = spark.read.csv("transport-relationships.csv", header=True)
    reversed_rels = (rels.withColumn("newSrc", rels.dst)
                     .withColumn("newDst", rels.src)
                     .drop("dst", "src")
                     .withColumnRenamed("newSrc", "src")
                     .withColumnRenamed("newDst", "dst")
                     .select("src", "dst", "relationship", "cost"))
    relationships = rels.union(reversed_rels) 
    return GraphFrame(nodes, relationships)

In [4]:
# 生成图
g = create_transport_graph()

In [15]:
g.vertices.show()

+----------------+---------+---------+----------+
|              id| latitude|longitude|population|
+----------------+---------+---------+----------+
|       Amsterdam| 52.37919| 4.899431|    821752|
|         Utrecht|52.092876|  5.10448|    334176|
|        Den Haag|52.078663| 4.288788|    514861|
|       Immingham| 53.61239| -0.22219|      9642|
|       Doncaster| 53.52285| -1.13116|    302400|
|Hoek van Holland|  51.9775|  4.13333|      9382|
|      Felixstowe| 51.96375|   1.3511|     23689|
|         Ipswich| 52.05917|  1.15545|    133384|
|      Colchester| 51.88921|  0.90421|    104390|
|          London|51.509865|-0.118092|   8787892|
|       Rotterdam|  51.9225|  4.47917|    623652|
|           Gouda| 52.01667|  4.70833|     70939|
+----------------+---------+---------+----------+



In [12]:
g.edges.show()

+----------------+----------------+------------+----+
|             src|             dst|relationship|cost|
+----------------+----------------+------------+----+
|       Amsterdam|         Utrecht|       EROAD|  46|
|       Amsterdam|        Den Haag|       EROAD|  59|
|        Den Haag|       Rotterdam|       EROAD|  26|
|       Amsterdam|       Immingham|       EROAD| 369|
|       Immingham|       Doncaster|       EROAD|  74|
|       Doncaster|          London|       EROAD| 277|
|Hoek van Holland|        Den Haag|       EROAD|  27|
|      Felixstowe|Hoek van Holland|       EROAD| 207|
|         Ipswich|      Felixstowe|       EROAD|  22|
|      Colchester|         Ipswich|       EROAD|  32|
|          London|      Colchester|       EROAD| 106|
|           Gouda|       Rotterdam|       EROAD|  25|
|           Gouda|         Utrecht|       EROAD|  35|
|        Den Haag|           Gouda|       EROAD|  32|
|Hoek van Holland|       Rotterdam|       EROAD|  33|
|         Utrecht|       Ams

In [5]:
# 顶点筛选
g.vertices.filter("population > 100000 and population < 300000").sort("population").show()

+----------+--------+---------+----------+
|        id|latitude|longitude|population|
+----------+--------+---------+----------+
|Colchester|51.88921|  0.90421|    104390|
|   Ipswich|52.05917|  1.15545|    133384|
+----------+--------+---------+----------+



In [20]:
from_expr = "id='Den Haag'"
to_expr = "population > 100000 and population < 300000 and id <> 'Den Haag'"
result = g.bfs(from_expr, to_expr)
 
columns = [column for column in result.columns if not column.startswith("e")]
result.select(columns).take(1)

[Row(from=Row(id='Den Haag', latitude=52.07866287231445, longitude=4.288787841796875, population=514861), v1=Row(id='Hoek van Holland', latitude=51.977500915527344, longitude=4.13332986831665, population=9382), v2=Row(id='Felixstowe', latitude=51.963748931884766, longitude=1.351099967956543, population=23689), to=Row(id='Ipswich', latitude=52.05916976928711, longitude=1.1554499864578247, population=133384))]

In [22]:
result.collect()

[Row(from=Row(id='Den Haag', latitude=52.07866287231445, longitude=4.288787841796875, population=514861), e0=Row(src='Den Haag', dst='Hoek van Holland', relationship='EROAD', cost='27'), v1=Row(id='Hoek van Holland', latitude=51.977500915527344, longitude=4.13332986831665, population=9382), e1=Row(src='Hoek van Holland', dst='Felixstowe', relationship='EROAD', cost='207'), v2=Row(id='Felixstowe', latitude=51.963748931884766, longitude=1.351099967956543, population=23689), e2=Row(src='Felixstowe', dst='Ipswich', relationship='EROAD', cost='22'), to=Row(id='Ipswich', latitude=52.05916976928711, longitude=1.1554499864578247, population=133384))]