In [2]:
from pyspark.sql import SparkSession
from pyspark.sql import Row
from pyspark.sql.functions import col, count, when, expr

# Start the Spark session for this evaluation (getOrCreate returns an
# already-active session when there is one).
spark = (
    SparkSession.builder
    .appName("EntityResolutionEvaluation")
    .getOrCreate()
)

# Toy evaluation set: one Row per record, pairing the ground-truth label with
# the label that was predicted for it.
data = [
    Row(true_label=actual, predicted_label=predicted)
    for actual, predicted in (
        (0, 1),
        (0, 0),
        (0, 1),
        (0, 0),
        (1, 1),
        (0, 0),
    )
]

# Turn the rows into a DataFrame with columns true_label / predicted_label.
df = spark.createDataFrame(data)

# Build the confusion matrix in a single global aggregation.
# Each cell counts the rows whose (true_label, predicted_label) pair matches;
# `when` without `otherwise` yields NULL for non-matching rows and `count`
# ignores NULLs, so only matching rows are tallied.
conf_matrix = df.agg(
    *[
        count(
            when(
                (col("true_label") == actual) & (col("predicted_label") == predicted),
                1,
            )
        ).alias(cell)
        # (column alias, true_label value, predicted_label value)
        for cell, actual, predicted in (
            ("TP", 1, 1),
            ("FN", 1, 0),
            ("FP", 0, 1),
            ("TN", 0, 0),
        )
    ]
)

conf_matrix.show()

# Compute precision, recall, and F1-score from the confusion-matrix counts.
#
#   Precision = TP / (TP + FP)
#   Recall    = TP / (TP + FN)
#   F1_Score  = 2 * Precision * Recall / (Precision + Recall)
#
# Every ratio is guarded with `when(denominator != 0, ...)`: a zero
# denominator (e.g. no predicted positives) yields NULL rather than raising a
# division-by-zero error when ANSI mode is enabled. A NULL Precision or Recall
# propagates to a NULL F1_Score.
predicted_positives = col("TP") + col("FP")
actual_positives = col("TP") + col("FN")
precision_plus_recall = col("Precision") + col("Recall")

metrics = (
    conf_matrix
    .select(
        when(predicted_positives != 0, col("TP") / predicted_positives).alias("Precision"),
        when(actual_positives != 0, col("TP") / actual_positives).alias("Recall"),
    )
    # F1 is added in a second step so it reads real Precision/Recall columns
    # instead of relying on lateral column alias resolution inside a single
    # select, which is only available on newer Spark versions (3.4+).
    .withColumn(
        "F1_Score",
        when(
            precision_plus_recall != 0,
            2 * (col("Precision") * col("Recall")) / precision_plus_recall,
        ),
    )
)

metrics.show()


+---+---+---+---+
| TP| FN| FP| TN|
+---+---+---+---+
|  1|  0|  2|  3|
+---+---+---+---+

+------------------+------+--------+
|         Precision|Recall|F1_Score|
+------------------+------+--------+
|0.3333333333333333|   1.0|     0.5|
+------------------+------+--------+

