In [1]:
from pyspark.sql import DataFrame, SparkSession, Window
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType

In [2]:
# Shared Spark session used by every puzzle part below.
spark = (
    SparkSession.builder
    .appName("advent-of-code-2024")
    .getOrCreate()
)

In [3]:
def calculate_sum_of_multiplications(df: DataFrame) -> int:
    """Sum the products of every well-formed ``mul(X,Y)`` instruction in ``df``.

    An instruction is well formed when ``X`` and ``Y`` are 1-3 digit integers
    with no surrounding whitespace, e.g. ``mul(44,46)``. Anything else
    (``mul(4*``, ``mul ( 2 , 4 )``, ``mul[3,7]``) is ignored.

    Args:
        df: DataFrame with a string column ``value``. Every row is scanned
            independently for instructions.

    Returns:
        The sum of ``X * Y`` over all instructions found in all rows, or ``0``
        when there are none (``F.sum`` over zero rows would otherwise yield
        NULL, i.e. ``None``).
    """
    instruction_regex = r"mul\((\d{1,3}),(\d{1,3})\)"
    products = df.select(
        # idx 0 returns the entire match, so each array element is one
        # complete "mul(X,Y)" instruction.
        F.explode(
            F.regexp_extract_all(
                F.col("value"), F.lit(instruction_regex), F.lit(0)
            )
        ).alias("instruction")
    ).select(
        # Each exploded instruction matched the regex above, so groups 1 and 2
        # are always present and always numeric.
        (
            F.regexp_extract("instruction", instruction_regex, 1).cast(IntegerType())
            * F.regexp_extract("instruction", instruction_regex, 2).cast(
                IntegerType()
            )
        ).alias("product")
    )
    return products.agg(F.coalesce(F.sum("product"), F.lit(0))).collect()[0][0]

In [4]:
def part_1(input_file_name: str) -> int:
    """Return the sum of every well-formed ``mul(X,Y)`` product in the input.

    Args:
        input_file_name: Path of the puzzle input text file. Each line is
            loaded as one row of the ``value`` column.
    """
    return calculate_sum_of_multiplications(spark.read.text(input_file_name))


# Validate against the worked example from the puzzle statement first.
assert part_1("test-input-1.txt") == 161

part_1_solution = part_1("input.txt")
print(f"Solution: {part_1_solution}")

Solution: 153469856


In [5]:
def part_2(input_file_name: str) -> int:
    """Sum the ``mul(X,Y)`` products that are enabled by ``do()``/``don't()``.

    Multiplications start enabled. ``don't()`` disables them and ``do()``
    re-enables them, and that state carries over from one input line to the
    next. The lines are therefore first reassembled, in file order, into a
    single memory string.

    Splitting that string on ``do()`` yields segments that each begin in the
    enabled state. The first segment is enabled because multiplications
    start enabled; every other segment directly follows a ``do()``. Within
    each segment, only the text before its first ``don't()`` is enabled.

    Args:
        input_file_name: Path of the puzzle input text file.

    Returns:
        The sum of all enabled products.
    """
    lines = spark.read.text(input_file_name).withColumn(
        "line_id", F.monotonically_increasing_id()
    )
    # collect_list makes no ordering guarantee, so carry the line id along and
    # sort on it. monotonically_increasing_id increases in read order but is NOT
    # consecutive across partitions; only its ordering is relied on here.
    # NOTE(review): assumes a single input file whose rows are read in file
    # order — confirm if multiple files are ever passed.
    # Lines are joined with "\n" rather than "" so that a line break can never
    # be part of an instruction (no phantom mul(...) across two lines).
    memory = lines.agg(
        F.array_join(
            F.transform(
                F.sort_array(F.collect_list(F.struct("line_id", "value"))),
                lambda line: line["value"],
            ),
            "\n",
        ).alias("memory")
    )
    enabled_segments = memory.select(
        F.explode(F.split(F.col("memory"), F.lit(r"do\(\)"))).alias("do_segment")
    ).select(
        # Keep only the part of each segment that precedes its first don't().
        F.split(F.col("do_segment"), F.lit(r"don't\(\)"))
        .getItem(0)
        .alias("value")
    )
    return calculate_sum_of_multiplications(enabled_segments)


# Validate against the worked example from the puzzle statement first.
assert part_2("test-input-2.txt") == 48

part_2_solution = part_2("input.txt")
print(f"Solution: {part_2_solution}")

Solution: 77055967
