In [1]:
from os.path import exists
from pathlib import Path
from pyspark.sql import Row, SparkSession, Window, \
    functions as f, types as t
from graphframes import GraphFrame
import shutil
import zstandard as zstd

# Upper bound on any player's bacon number, presumably taken from an earlier
# run of this notebook. add_bacon_path follows MAX_BACON_NUMBER + 1 edges per
# player, so if the real maximum grows, the exported paths are silently cut short.
MAX_BACON_NUMBER = 9  # Previously calculated

In [6]:
def create_sparksession(
        app_name="MLBaconNumber",
        graphframes_jar="/usr/lib/MLBaconNumber/graphframes-0.8.1-spark3.0-s_2.12.jar"):
    """Create (or reuse) a local Spark session with the GraphFrames jar attached.

    Args:
        app_name: Spark application name.
        graphframes_jar: path to the GraphFrames jar, passed through
            ``spark.jars``. The default file name indicates a build for
            Spark 3.0 / Scala 2.12; override it when the installed versions
            differ or the jar lives elsewhere.

    Returns:
        The session from ``getOrCreate()``. If a session already exists it is
        returned instead, and static settings such as ``spark.jars`` may not
        take effect.
    """
    return (SparkSession.builder.master("local[*]")
                                .appName(app_name)
                                .config("spark.jars", graphframes_jar)
                                .getOrCreate())


def get_roster(spark, compressed_roster_path):
    """Load the roster DataFrame from a zstd-compressed CSV file.

    The archive is decompressed to a fixed scratch file under /tmp. An
    existing file there is reused as-is. The result is then parsed by
    ``read_roster_csv``.
    """
    scratch_csv = "/tmp/decompress_roster.csv"
    decompress_file(compressed_roster_path, scratch_csv)
    roster_df = read_roster_csv(spark, scratch_csv)
    return roster_df


def get_teams(spark, compressed_teams_path):
    """Load the teams DataFrame from a zstd-compressed CSV file.

    The compressed path is first resolved to an absolute path. The archive is
    then decompressed to a fixed scratch file under /tmp, and an existing file
    there is reused. The result is parsed by ``read_teams_csv``.
    """
    scratch_csv = "/tmp/decompress_teams.csv"
    source_archive = Path(compressed_teams_path).resolve()
    decompress_file(source_archive, scratch_csv)
    teams_df = read_teams_csv(spark, scratch_csv)
    return teams_df


def decompress_file(compressed_path, decompressed_path):
    """Decompress a zstd file, skipping the work if the output already exists.

    ``decompressed_path`` acts as a cache. When it exists, nothing is done, so
    a changed input is not picked up until that file is deleted. The data is
    first written to ``<decompressed_path>.part`` and renamed into place only
    once decompression has finished. An interrupted or failed run therefore
    never leaves a truncated file that a later call would treat as complete.

    Args:
        compressed_path: path (str or Path) of the zstd-compressed input.
        decompressed_path: path (str or Path) of the decompressed output.
    """
    if exists(decompressed_path):
        return

    partial_path = Path("{}.part".format(decompressed_path))
    try:
        with open(compressed_path, "rb") as filein, \
                open(partial_path, "wb") as fileout:
            decompressor = zstd.ZstdDecompressor()
            decompression_result = decompressor.copy_stream(filein, fileout)
        # Publish the output only after it is fully written and closed.
        partial_path.replace(decompressed_path)
    except BaseException:
        # Never leave a half-written scratch file behind.
        if partial_path.exists():
            partial_path.unlink()
        raise
    print("Size (compressed,uncompressed): {}".format(decompression_result))


def read_roster_csv(spark, path):
    """Parse the roster CSV at ``path`` with an explicit schema.

    No header option is passed, so every line is read as data. The bats,
    throws and position columns are dropped, leaving year, player_id,
    last_name, first_name and team_id. team_id is replaced by the team code
    with the year appended, so each team-season gets its own id. The result
    is repartitioned into 16 partitions.
    """
    column_specs = (
        ("year", t.IntegerType()),
        ("player_id", t.StringType()),
        ("last_name", t.StringType()),
        ("first_name", t.StringType()),
        ("bats", t.StringType()),
        ("throws", t.StringType()),
        ("team_id", t.StringType()),
        ("position", t.StringType()),
    )
    schema = t.StructType(
        [t.StructField(name, dtype, True) for name, dtype in column_specs])

    roster_df = spark.read.csv(path, schema)
    roster_df = roster_df.drop("bats", "throws", "position")
    season_team_id = f.concat("team_id", "year")
    return roster_df.withColumn("team_id", season_team_id).repartition(16)


def read_teams_csv(spark, path):
    """Parse the teams CSV at ``path`` and keep only the columns used later.

    The file is read with an explicit 48-column schema and no header option.
    The output has:
      * team_id: team_id_retro concatenated with year_id. This presumably
        matches the season-qualified team_id built in read_roster_csv.
      * year_id, name, lg_id: copied unchanged.
    The result is repartitioned into 16 partitions.
    """
    int_t, str_t, dbl_t = t.IntegerType(), t.StringType(), t.DoubleType()
    # (column name, type) in file order; every column is nullable.
    columns = [
        ("year_id", int_t), ("lg_id", str_t), ("team_id", str_t),
        ("franch_id", str_t), ("div_id", str_t),
        ("rank", int_t), ("g", int_t), ("g_home", int_t), ("w", int_t),
        ("l", int_t), ("div_win", str_t), ("wc_win", str_t),
        ("lg_win", str_t), ("ws_win", str_t),
        ("r", int_t), ("ab", int_t), ("h", int_t), ("_2b", int_t),
        ("_3b", int_t), ("hr", int_t), ("bb", int_t), ("so", int_t),
        ("sb", int_t), ("cs", int_t), ("hbp", int_t), ("sf", int_t),
        ("ra", int_t), ("er", int_t), ("era", dbl_t), ("cg", int_t),
        ("sho", int_t), ("sv", int_t), ("ip_outs", int_t), ("h_a", int_t),
        ("hr_a", int_t), ("bb_a", int_t), ("so_a", int_t),
        ("e", int_t), ("dp", int_t), ("fp", dbl_t),
        ("name", str_t), ("park", str_t), ("attendance", int_t),
        ("bpf", int_t), ("ppf", int_t), ("team_id_br", str_t),
        ("team_id_lahman45", str_t), ("team_id_retro", str_t),
    ]
    schema = t.StructType(
        [t.StructField(name, dtype, True) for name, dtype in columns])

    teams_df = spark.read.csv(path, schema)
    season_team_id = f.concat("team_id_retro", "year_id").alias("team_id")
    return (teams_df
            .select(season_team_id, "year_id", "name", "lg_id")
            .repartition(16))


def write_csv(df, filename, output_dir="../../data"):
    """Write ``df`` as one CSV file (with a header row) to ``output_dir/filename``.

    Spark writes a directory of part files. The DataFrame is therefore
    repartitioned to a single partition and written to a scratch directory
    ``/tmp/<filename>``. The one resulting ``.csv`` part file is copied to the
    destination. The scratch directory is removed afterwards, even when
    writing or copying fails.

    Args:
        df: Spark DataFrame to export.
        filename: name of the output file. It is also used as the scratch
            directory name.
        output_dir: destination directory. Relative paths are resolved
            against the current working directory. Defaults to
            ``../../data``, the directory the notebook reads its inputs from.

    Raises:
        FileNotFoundError: if Spark produced no ``.csv`` part file.
    """
    temp_dir = Path("/tmp", filename)
    output_file = Path(output_dir, filename)
    try:
        df.repartition(1).write.csv(str(temp_dir.absolute()),
                                    mode="overwrite", header=True)
        part_file = next(temp_dir.glob("**/*.csv"), None)
        if part_file is None:
            raise FileNotFoundError(
                "Spark wrote no CSV part file under {}".format(temp_dir))
        shutil.copyfile(part_file, output_file.resolve())
    finally:
        # Always drop the scratch directory, including after a failed write.
        shutil.rmtree(temp_dir, ignore_errors=True)


def build_mlb_graph(spark, roster):
    """Build a GraphFrame linking players, teams and the ``first_game`` vertex.

    Vertices: team vertices, player vertices and the single ``first_game``
    vertex, unioned by column name.
    Edges (directed): team -> player and player -> team for every roster row,
    plus player -> first_game for the players listed in get_first_game_edges.
    """
    vertices = (get_team_vertices(roster)
                .unionByName(get_player_vertices(roster))
                .unionByName(get_first_game_vertex(spark)))

    player_col = f.col("player_id")
    team_col = f.col("team_id")
    team_to_player = roster.select(team_col.alias("src"),
                                   player_col.alias("dst"))
    player_to_team = roster.select(player_col.alias("src"),
                                   team_col.alias("dst"))
    edges = (team_to_player
             .unionByName(player_to_team)
             .unionByName(get_first_game_edges(spark)))

    return GraphFrame(vertices, edges)


def calculate_first_game_shortest_paths(mlb_graph):
    """Run ``shortestPaths`` on ``mlb_graph`` with ``first_game`` as the only landmark."""
    landmark_ids = ["first_game"]
    return mlb_graph.shortestPaths(landmarks=landmark_ids)


def build_first_game_tree(spark, roster, roster_with_shortest_paths):
    """Build a graph whose only edges follow shortest paths to ``first_game``.

    Edges:
      * player -> teammate, from get_shortest_path_edges. Each edge points
        from a player to a teammate strictly closer to ``first_game`` and
        carries that function's ``team_id`` label.
      * player -> first_game for the players in get_first_game_edges,
        labelled ``team_id = "PLAYED_IN"``.
    Vertices are the players plus ``first_game``; team vertices are omitted.
    """
    # Write the teammate edges to parquet once and read them back, so every
    # downstream use sees the same edge set. Without this, results were
    # observed to differ before and after the union. NOTE(review): the likely
    # cause is a non-deterministic aggregate (e.g. ``first``) being
    # re-evaluated on recomputation; confirm.
    checkpoint_dir = str(Path("/tmp", "teammate_edges").absolute())
    get_shortest_path_edges(roster_with_shortest_paths).write.parquet(
        checkpoint_dir, mode="overwrite")
    teammate_edges = spark.read.parquet(checkpoint_dir)

    played_in_edges = (get_first_game_edges(spark)
                       .withColumn("team_id", f.lit("PLAYED_IN")))
    edges = teammate_edges.unionByName(played_in_edges)

    vertices = get_player_vertices(roster).unionByName(
        get_first_game_vertex(spark))
    return GraphFrame(vertices, edges)


def calculate_bacon_numbers(roster, shortest_path_tree):
    """Compute each player's bacon number from the shortest-path tree.

    bacon_number is the tree distance to ``first_game`` minus one, so players
    with a direct edge to ``first_game`` get 0. It is presumably null for
    vertices with no path. Results are inner-joined with ``roster`` and
    collapsed to one row per (player_id, last_name, first_name,
    bacon_number), with the latest ``year`` as ``max_year``.
    """
    hops_to_first_game = f.col("distances").getItem("first_game")
    per_vertex = (shortest_path_tree
                  .shortestPaths(landmarks=["first_game"])
                  .select(f.col("id").alias("player_id"),
                          (hops_to_first_game - 1).alias("bacon_number")))
    per_roster_row = per_vertex.join(roster, on="player_id", how="inner")
    return (per_roster_row
            .groupBy("player_id", "last_name", "first_name", "bacon_number")
            .agg(f.max("year").alias("max_year")))


def add_bacon_path(bacon_numbers, shortest_path_tree, max_bacon_number=None):
    """Append each player's path through ``shortest_path_tree`` as columns.

    Starting from ``v0 = player_id``, each step left-joins the tree edges on
    the current vertex. It adds ``e{i+1}`` (the edge's ``team_id`` label) and
    ``v{i+1}`` (the edge's ``dst``). ``max_bacon_number + 1`` steps are taken,
    which is enough to reach ``first_game`` from a player with that bacon
    number. Columns past the end of a shorter path are null. Longer paths are
    truncated.

    Args:
        bacon_numbers: DataFrame with a ``player_id`` column and no
            ``src``/``dst``/``team_id`` columns (those are dropped each step).
        shortest_path_tree: GraphFrame whose edges have src, dst and team_id.
        max_bacon_number: largest bacon number to follow. ``None`` (the
            default) uses the module constant ``MAX_BACON_NUMBER``, looked up
            at call time.

    Returns:
        ``bacon_numbers`` plus columns v0, e1, v1, ..., e{n+1}, v{n+1}, where
        n = max_bacon_number.
    """
    if max_bacon_number is None:
        max_bacon_number = MAX_BACON_NUMBER

    tree_edges = shortest_path_tree.edges.cache()
    tree_edges.count()  # materialize the cache before the repeated joins

    paths = bacon_numbers.withColumn("v0", f.col("player_id"))
    for step in range(max_bacon_number + 1):
        current_vertex = "v" + str(step)
        paths = (paths
                 .join(tree_edges, on=(f.col(current_vertex) == f.col("src")),
                       how="leftouter")
                 .withColumn("e" + str(step + 1), f.col("team_id"))
                 .withColumn("v" + str(step + 1), f.col("dst"))
                 # Drop the edge columns so the next join sees fresh ones.
                 .drop("src", "dst", "team_id"))
    return paths


def get_team_vertices(roster):
    """Build distinct team vertices: id=team_id, max_year=year, blank names.

    The name columns are empty strings so the frame has the same columns as
    the player vertices and can be unioned with them by name.
    """
    empty_text = f.lit("")
    team_rows = roster.select(
        f.col("team_id").alias("id"),
        f.col("year").alias("max_year"),
        empty_text.alias("last_name"),
        empty_text.alias("first_name"),
    )
    return team_rows.distinct()


def get_player_vertices(roster):
    """Build one vertex per (player_id, last_name, first_name) combination.

    Output columns: id, last_name, first_name, and max_year (the latest
    ``year`` seen for that combination).
    NOTE(review): a player_id recorded under differing name spellings would
    yield several vertices with the same id; confirm roster names are
    consistent.
    """
    vertex_keys = [f.col("player_id").alias("id"),
                   f.col("last_name"),
                   f.col("first_name")]
    latest_year = f.max("year").alias("max_year")
    return roster.groupBy(*vertex_keys).agg(latest_year)


def get_first_game_vertex(spark):
    """Return a one-row DataFrame holding the synthetic ``first_game`` vertex.

    Its columns match the player/team vertices (id, max_year, last_name,
    first_name). max_year is 1871, presumably the season of the first game.
    """
    first_game = Row(id="first_game", max_year=1871, last_name="", first_name="")
    return spark.createDataFrame([first_game])


def get_first_game_edges(spark):
    """Return player -> ``first_game`` edges for the hard-coded player ids.

    Columns: ``src`` (player id) and ``dst`` (always ``"first_game"``).
    """
    first_game_players = (
        "whitd102", "kimbg101", "paboc101", "allia101", "white104", "prata101",
        "sutte101", "carlj102", "bassj101", "selmf101", "mathb101", "foraj101",
        "goldw101", "lennb101", "caret101", "mince101", "mcdej101", "kellb105",
    )
    edge_rows = [(player_id, "first_game") for player_id in first_game_players]
    return spark.createDataFrame(edge_rows, ["src", "dst"])


def add_shortest_paths_to_roster(roster, shortest_paths):
    """Annotate each roster row with its team's and player's distance to first_game.

    Two columns are added, both taken from
    ``shortest_paths.distances["first_game"]``:
      * team_shortest_path: distance of the row's team_id vertex.
      * player_shortest_path: distance of the row's player_id vertex.
    Left joins keep every roster row; a distance is null when there is no
    match. Output columns are roster's columns followed by the two new ones.
    """
    first_game_distance = shortest_paths.select(
        f.col("id"),
        f.col("distances").getItem("first_game").alias("distance"))
    # Two independently renamed projections, so neither join has to
    # disambiguate the same frame being joined twice.
    team_distance = first_game_distance.select(
        f.col("id").alias("team_id"),
        f.col("distance").alias("team_shortest_path"))
    player_distance = first_game_distance.select(
        f.col("id").alias("player_id"),
        f.col("distance").alias("player_shortest_path"))

    annotated = (roster
                 .join(team_distance, on="team_id", how="left")
                 .join(player_distance, on="player_id", how="left"))
    return annotated.select(*roster.columns,
                            "team_shortest_path", "player_shortest_path")


def get_shortest_path_edges(roster_shortest_paths):
    """Pick, for each player, one teammate who is strictly closer to first_game.

    The roster is self-joined on team_id. Pairs of distinct players are kept
    where the right-hand player has a smaller ``player_shortest_path``. One
    edge is emitted per left-hand player:
      * src: that player.
      * dst: the chosen teammate.
      * team_id: a team the two shared.
    Players with no closer teammate (or a null distance) get no edge.

    The teammate and the team are chosen together, as the smallest
    (team_id, teammate id) pair. The label is therefore always a team on which
    src and dst actually played together, and the choice is deterministic.
    Choosing them separately (``first`` for dst, ``min`` for team_id) could
    pair a teammate with a team they never shared. ``first`` is also
    non-deterministic after a shuffle.
    """
    left = roster_shortest_paths.alias("left")
    right = roster_shortest_paths.alias("right")
    closer_teammates = (left.join(right, on="team_id", how="inner")
                        .where(f.col("left.player_id") != f.col("right.player_id"))
                        .where(f.col("left.player_shortest_path")
                               > f.col("right.player_shortest_path")))
    # min over a struct compares team_id first, then the teammate's id.
    best_edge = f.min(f.struct(f.col("team_id"),
                               f.col("right.player_id").alias("dst"))).alias("edge")
    return (closer_teammates
            .groupBy(f.col("left.player_id").alias("src"))
            .agg(best_edge)
            .select("src",
                    f.col("edge.dst").alias("dst"),
                    f.col("edge.team_id").alias("team_id")))

In [3]:
# Input archives and output file names. The ../../data paths are relative to
# the kernel's working directory (presumably the notebook's own folder).
compressed_roster = "../../data/roster.csv.zst"
compressed_teams = "../../data/teams.csv.zst"
# NOTE(review): shortest_path_filename is not referenced by any cell in this
# notebook; remove it or add the cell that writes these distances.
shortest_path_filename = "first_game_distances.csv"
bacon_numbers_filename = "bacon_numbers.csv"
teams_filename = "teams.csv"

# Local Spark session with the GraphFrames jar on the classpath.
spark = create_sparksession()

22/12/10 03:46:15 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [4]:
# Load the roster (decompressed on first use), build the player/team graph,
# and compute every vertex's distance to the synthetic "first_game" vertex.
roster = get_roster(spark, compressed_roster)

mlb_graph = build_mlb_graph(spark, roster)
first_game_paths = calculate_first_game_shortest_paths(mlb_graph)

Size (compressed,uncompressed): (887421, 4019466)




In [7]:
# Annotate roster rows with team/player distances, reduce the graph to one
# shortest-path edge per player, then compute bacon numbers and the explicit
# path columns (v0, e1, v1, ...), and export them as a single CSV.
roster_with_shortest_paths = add_shortest_paths_to_roster(
    roster, first_game_paths)
shortest_path_tree = build_first_game_tree(
    spark, roster, roster_with_shortest_paths)
bacon_numbers = calculate_bacon_numbers(roster, shortest_path_tree)
bacon_numbers = add_bacon_path(bacon_numbers, shortest_path_tree)

write_csv(bacon_numbers, bacon_numbers_filename)

                                                                                

In [8]:
# Decompress and load the team table, then export it as a single CSV.
teams = get_teams(spark, compressed_teams)
write_csv(teams, teams_filename)

Size (compressed,uncompressed): (194241, 584967)
