In [0]:
# Fully qualified names of the per-model extraction tables, grouped by split.
# The order inside each group matters to consumers that pair tables with
# models by position.
# NOTE(review): the test split spells the third table "llama_4_maverick_v1"
# while the train split spells it "llama4_maverick_v1" — confirm both names
# against the catalog.
_EXTRACT_TABLES_BY_SPLIT = {
    "test": (
        "workspace.default.mimic_cxr_test_set_label_explanation_extract_databricks_qwen3_next_80b_a3b_instruct_v3",
        "workspace.default.mimic_cxr_test_set_label_explanation_extract_gpt_5_1_v2",
        "workspace.default.mimic_cxr_test_set_label_explanation_extract_llama_4_maverick_v1",
    ),
    "train": (
        "workspace.default.mimic_cxr_train_set_label_explanation_extract_databricks_qwen3_next_80b_a3b_instruct_v3",
        "workspace.default.mimic_cxr_train_set_label_explanation_extract_gpt_5_1_v2",
        "workspace.default.mimic_cxr_train_set_label_explanation_extract_llama4_maverick_v1",
    ),
}

# Expose each split as its own plain list, as the rest of the notebook expects.
test_tables = list(_EXTRACT_TABLES_BY_SPLIT["test"])
train_tables = list(_EXTRACT_TABLES_BY_SPLIT["train"])

In [0]:
from pyspark.sql import functions as F
from pyspark.sql.window import Window

def build_consensus_dataset(table_names):
    """Build a consensus-labelled dataset from the three per-model tables.

    Every table contributes one vote (a ``label`` plus a ``confidence``) per
    ``(subject_id, study_id)``. A study is kept when either

    * all models voted for the same label, or
    * two models agree and the average confidence of the agreeing models is
      strictly higher than the confidence of the dissenting model.

    Studies with three different labels are dropped. Studies that did not
    receive exactly one vote per model in total (e.g. missing from one of the
    tables) are dropped as well: without this guard a partially covered study
    would be reported as "unanimous", and a 1-vs-1 split would be resolved
    arbitrarily by ``row_number``.

    Assumes each source table holds at most one row per study — TODO confirm;
    duplicate rows would be counted as additional votes.

    Args:
        table_names: Iterable with one table name per model, in the same order
            as ``model_names`` below (tables are paired with models by
            position).

    Returns:
        A Spark DataFrame with the columns ``subject_id``, ``study_id``,
        ``findings``, ``impression``, ``label``, ``explanation`` and
        ``confidence``, one row per kept study. ``confidence`` is the average
        confidence of the models that voted for the consensus label; the text
        columns come from the agreeing row whose model tag sorts first.

    Raises:
        ValueError: If ``table_names`` does not contain exactly one table per
            model.
    """
    model_names = [
        "databricks_qwen3_next_80b_a3b_instruct_v3",
        "gpt_5_1_v2",
        "llama4_maverick_v1",
    ]
    study_keys = ["subject_id", "study_id"]
    n_models = len(model_names)

    # Materialise first so generators work and the length can be validated;
    # zip() would otherwise silently truncate or ignore surplus tables.
    table_names = list(table_names)
    if len(table_names) != n_models:
        raise ValueError(
            "expected {} tables (one per model), got {}".format(
                n_models, len(table_names)
            )
        )

    # 1) Load each model table (including confidence), tag its rows with the
    #    model name and union everything into one long "votes" frame.
    dfs = []
    for tbl, model in zip(table_names, model_names):
        df = (
            spark.table(tbl)
                 .select(
                     "subject_id",
                     "study_id",
                     "findings",
                     "impression",
                     "label",
                     "explanation",
                     "confidence",
                 )
                 .withColumn("model", F.lit(model))
        )
        dfs.append(df)

    union_df = dfs[0]
    for df in dfs[1:]:
        union_df = union_df.unionByName(df)

    # 2) Per (subject_id, study_id, label): vote count + avg(confidence)
    label_stats = (
        union_df
        .groupBy(*study_keys, "label")
        .agg(
            F.count("*").alias("label_count"),
            F.avg("confidence").alias("avg_conf"),
        )
    )

    # 3) Per (subject_id, study_id): number of distinct labels and total votes
    study_stats = (
        label_stats
        .groupBy(*study_keys)
        .agg(
            F.count("*").alias("num_labels"),
            F.sum("label_count").alias("num_votes"),
        )
    )

    # Only studies with a full set of votes can be "unanimous" or "2-vs-1".
    joined = (
        label_stats
        .join(study_stats, on=study_keys, how="inner")
        .filter(F.col("num_votes") == n_models)
    )

    # ---------- Unanimous case (every model gave the same label) ----------
    unanimous = (
        joined
        .filter(F.col("num_labels") == 1)
        .select(
            "subject_id",
            "study_id",
            F.col("label").alias("consensus_label"),
            F.col("avg_conf").alias("consensus_confidence"),
        )
    )

    # ---------- Two-label cases: majority + higher avg confidence ----------
    # Rank labels by vote count within each study. With a full set of three
    # votes and two labels the counts are necessarily 2 and 1, so the ranking
    # has no ties.
    w = Window.partitionBy(*study_keys).orderBy(F.col("label_count").desc())

    ranked = (
        joined
        .filter(F.col("num_labels") == 2)
        .withColumn("rank", F.row_number().over(w))
    )

    # Majority (count 2) and minority (count 1)
    majority = (
        ranked
        .filter(F.col("rank") == 1)
        .select(
            "subject_id",
            "study_id",
            F.col("label").alias("consensus_label"),
            F.col("avg_conf").alias("majority_conf"),
        )
    )

    minority = (
        ranked
        .filter(F.col("rank") == 2)
        .select(
            "subject_id",
            "study_id",
            F.col("avg_conf").alias("minority_conf"),
        )
    )

    # Keep only if the majority has a strictly higher avg confidence than the
    # minority; a null on either side makes the comparison null and drops the
    # study.
    majority_kept = (
        majority
        .join(minority, on=study_keys, how="inner")
        .filter(F.col("majority_conf") > F.col("minority_conf"))
        .select(
            "subject_id",
            "study_id",
            "consensus_label",
            F.col("majority_conf").alias("consensus_confidence"),
        )
    )

    # ---------- Combine unanimous + 2-vs-1 high-confidence ----------
    consensus_labels = unanimous.unionByName(majority_kept)

    # 4) Join back to get findings/impression/explanation from the rows that
    #    voted for the consensus label.
    winners = (
        consensus_labels
        .join(union_df, on=study_keys, how="inner")
        .where(F.col("label") == F.col("consensus_label"))
    )

    # For the explanation, pick the agreeing row whose model tag sorts first.
    w2 = Window.partitionBy(*study_keys).orderBy("model")

    # rn == 1 already leaves exactly one row per study, so no distinct() (and
    # its extra shuffle over the long text columns) is needed afterwards.
    final_df = (
        winners
        .withColumn("rn", F.row_number().over(w2))
        .filter(F.col("rn") == 1)
        .select(
            "subject_id",
            "study_id",
            "findings",
            "impression",
            F.col("consensus_label").alias("label"),
            "explanation",
            F.col("consensus_confidence").alias("confidence"),
        )
    )

    return final_df


In [0]:
# Build the consensus DataFrame for each split from its per-model tables.
consensus_train_df = build_consensus_dataset(train_tables)
consensus_test_df = build_consensus_dataset(test_tables)

# Preview the first ten rows of each result.
display(consensus_train_df.limit(10))
display(consensus_test_df.limit(10))

# Save both results as tables, replacing any existing table and its schema.
(
    consensus_train_df.write
    .mode("overwrite")
    .option("overwriteSchema", "true")
    .saveAsTable("workspace.default.mimic_cxr_train_set_label_explanation_consensus_v1")
)

(
    consensus_test_df.write
    .mode("overwrite")
    .option("overwriteSchema", "true")
    .saveAsTable("workspace.default.mimic_cxr_test_set_label_explanation_consensus_v1")
)

In [0]:
# Print the column names and types of both consensus DataFrames so the
# schemas of the train and test results can be compared by eye.
consensus_train_df.printSchema()
consensus_test_df.printSchema()