In [1]:
from pyspark.sql import SparkSession, functions as F

In [2]:
# Single local Spark session (all available cores via "local[*]"), used by the
# cells below to read the puzzle inputs.
spark = (
    SparkSession.builder
    .appName("advent-of-code-2021")
    .master("local[*]")
    .getOrCreate()
)

Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
21/12/04 08:57:42 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# Worked example from the puzzle statement (checked by the asserts) and the real
# puzzle input; spark.read.text yields one row per line in a `value` column.
df_test, df = (spark.read.text(path) for path in ("test-input.txt", "input.txt"))

In [7]:
# Part 1
def part_1(df):
    """Return the power consumption: gamma rate * epsilon rate.

    Each row of ``df`` holds one binary number as a string in its ``value``
    column. The gamma rate is built from the most common bit in every
    position (a tie counts as ``1``). The epsilon rate uses the least common
    bit, i.e. it is the bitwise complement of gamma.

    Raises:
        ValueError: if ``df`` has no rows.
    """
    first_row = df.first()
    if first_row is None:
        raise ValueError("Cannot compute power consumption of an empty input")
    # Width of the diagnostic numbers, taken from the first line.
    number_of_bits = len(first_row["value"])

    # One Spark job yields both the row count and the number of 1-bits in each
    # position (F.substring positions are 1-based).
    aggregations = [F.count(F.lit(1)).alias("number_of_rows")]
    aggregations.extend(
        F.sum(F.substring("value", i + 1, 1).cast("int")).alias(f"sum_{i}")
        for i in range(number_of_bits)
    )
    totals = df.agg(*aggregations).first()
    number_of_rows = totals["number_of_rows"]

    # Exact integer test instead of float division + rounding: 1 is the most
    # common bit when at least half of the rows have it, so ties resolve to 1.
    gamma_bits = [
        1 if 2 * totals[f"sum_{i}"] >= number_of_rows else 0
        for i in range(number_of_bits)
    ]
    epsilon_bits = [1 - bit for bit in gamma_bits]

    gamma = int("".join(map(str, gamma_bits)), 2)
    epsilon = int("".join(map(str, epsilon_bits)), 2)

    return gamma * epsilon

assert part_1(df_test) == 198

print(f"Solution: {part_1(df)}")

Solution: 3549854


In [65]:
# Part 2
def _get_rating(df, rating_type):
    """Return the binary string of one life-support rating.

    Rows of ``df`` hold binary numbers as strings in a ``value`` column. Bit
    positions are processed left to right, and only the rows whose bit
    matches the selection criterion are kept:

    * ``rating_type == "oxygen"``: the most common bit (ties -> 1);
    * any other value (``"co2"``): the least common bit (ties -> 0).

    Filtering stops as soon as a single row remains. A position where every
    remaining row has the same bit is skipped: filtering on the absent
    "least common" value would otherwise discard every row.

    Raises:
        ValueError: if ``df`` is empty or filtering does not end with exactly
            one row (for example, duplicate lines in the input).
    """
    first_row = df.first()
    if first_row is None:
        raise ValueError(f"No rating found for {rating_type}")
    number_of_bits = len(first_row["value"])

    filtered_df = df
    for i in range(number_of_bits):
        bit = F.substring("value", i + 1, 1).cast("int")
        # Row count and number of 1-bits for this position in one Spark job.
        stats = filtered_df.agg(
            F.count(F.lit(1)).alias("rows"),
            F.sum(bit).alias("ones"),
        ).first()
        number_of_rows, ones = stats["rows"], stats["ones"]
        if number_of_rows == 1:
            break

        zeros = number_of_rows - ones
        if ones == 0 or zeros == 0:
            # All remaining rows agree on this bit: keep them all.
            continue

        most_common = 1 if ones >= zeros else 0
        filter_value = most_common if rating_type == "oxygen" else 1 - most_common
        filtered_df = filtered_df.filter(bit == filter_value)

    if filtered_df.count() == 1:
        return filtered_df.first()["value"]
    raise ValueError(f"No rating found for {rating_type}")

def part_2(df):
    """Return the life support rating: oxygen generator rating * CO2 scrubber rating."""
    oxygen_rating = int(_get_rating(df, "oxygen"), 2)
    co2_rating = int(_get_rating(df, "co2"), 2)
    return oxygen_rating * co2_rating

assert part_2(df_test) == 230

print(f"Solution: {part_2(df)}")

Solution: 3765399
