# Spark-native Similarity Scoring with Parameter Sampling

This notebook merges the optimized Spark solution with parameter sampling so that **each row in the metrics table corresponds to one sampled parameter set**.

Key properties:
- No `toPandas()` in the loop (avoids Java heap space errors)
- Similarities computed **once** and cached
- Weights/penalties applied per configuration using Spark expressions
- Metrics aggregated in Spark


In [None]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import types as T
from pyspark.storagelevel import StorageLevel
import jellyfish
import numpy as np


## Base Similarity UDFs (raw, unweighted)

In [None]:
def jw_raw(col1, col2):
    """Return the unweighted Jaro-Winkler similarity of two values.

    Both values are converted with ``str`` before comparison. ``None`` is
    returned when either value is ``None`` or an empty string, so callers
    can tell "not comparable" apart from a low similarity.
    """
    if col1 is None or col2 is None:
        return None
    if col1 == '' or col2 == '':
        return None
    left, right = str(col1), str(col2)
    similarity = jellyfish.jaro_winkler_similarity(left, right)
    return float(similarity)

def hamming_raw(col1, col2):
    """Return the unweighted Hamming similarity of two values.

    Both values are converted with ``str``; the result is one minus the
    Hamming distance divided by the length of the longer string. ``None``
    is returned when either value is ``None`` or an empty string.
    """
    if col1 is None or col2 is None:
        return None
    if col1 == '' or col2 == '':
        return None
    left = str(col1)
    right = str(col2)
    longest = max(len(left), len(right))
    if not longest:
        # Both string forms are empty: there is nothing to normalise by.
        return None
    distance = float(jellyfish.hamming_distance(left, right))
    return 1.0 - (distance / longest)

def overlap_raw(col1, col2):
    """Return the unweighted exact-match score of two values.

    The score is ``1.0`` when the ``str`` forms of both values are equal and
    ``0.0`` otherwise. ``None`` is returned when either value is ``None`` or
    an empty string.
    """
    if col1 is None or col2 is None:
        return None
    if col1 == '' or col2 == '':
        return None
    # float(True) == 1.0 and float(False) == 0.0.
    return float(str(col1) == str(col2))

# Register the three raw similarity functions as Spark UDFs, each declared
# with a DoubleType return type.
udf_jw_raw, udf_hamming_raw, udf_overlap_raw = (
    F.udf(fn, T.DoubleType()) for fn in (jw_raw, hamming_raw, overlap_raw)
)


## Compute and Cache Base Similarities (run once)

In [None]:
# Add one raw (unweighted) similarity column per compared field pair.
base_df = link_df.withColumn(
    'jw_nome_raw', udf_jw_raw(F.col('nome_a'), F.col('nome_b'))
)
base_df = base_df.withColumn(
    'jw_nome_mae_raw', udf_jw_raw(F.col('nome_mae_a'), F.col('nome_mae_b'))
)
base_df = base_df.withColumn(
    'ham_dt_nasc_raw', udf_hamming_raw(F.col('dt_nasc_a'), F.col('dt_nasc_b'))
)
base_df = base_df.withColumn(
    'ov_sexo_raw', udf_overlap_raw(F.col('sexo_a'), F.col('sexo_b'))
)

# Persist so the UDF results are reused by every parameter set instead of
# being recomputed; the count() action forces the cache to be populated now.
base_df = base_df.persist(StorageLevel.MEMORY_AND_DISK)
base_df.count()


## Parameter Sampling

In [None]:
def sample_param_sets(cfg, n, seed=42):
    """Draw ``n`` random parameter sets from the per-field value grids.

    Args:
        cfg: Configuration mapping. ``cfg['dataset']['fields']`` maps each
            field name to a spec with ``'weight'`` and ``'penalty'`` entries,
            each of which holds ``'low'``, ``'high'`` and ``'step'``.
        n: Number of parameter sets to draw.
        seed: Seed for the NumPy random generator, for reproducible draws.

    Returns:
        A list of ``n`` dicts. Each dict has a ``w_<field>`` and a
        ``p_<field>`` float entry for every configured field, drawn
        uniformly from that field's weight and penalty grid.

    Raises:
        ValueError: If a weight or penalty grid contains no values.
    """
    rng = np.random.default_rng(seed)

    # The grids depend only on cfg, so build them once rather than once per
    # sample. The order (weight then penalty, field by field) matches the
    # order in which values are drawn below.
    grids = []
    for name, spec in cfg['dataset']['fields'].items():
        for prefix, kind in (('w', 'weight'), ('p', 'penalty')):
            bounds = spec[kind]
            # np.arange excludes its stop value; the small epsilon keeps
            # 'high' in the grid when it falls exactly on a step.
            values = np.arange(bounds['low'], bounds['high'] + 1e-9, bounds['step'])
            if values.size == 0:
                raise ValueError(
                    f"empty {kind} grid for field {name!r}: "
                    f"low={bounds['low']}, high={bounds['high']}, step={bounds['step']}"
                )
            grids.append((f'{prefix}_{name}', values))

    return [
        {key: float(rng.choice(values)) for key, values in grids}
        for _ in range(n)
    ]


## Apply Parameters and Compute Total Score

In [None]:
def apply_params(df, params):
    """Add the weighted ``sim_*`` columns for one parameter set.

    For each compared field, the new column holds the field's penalty
    (``params['p_<field>']``) when the raw similarity is null, and the raw
    similarity multiplied by the field's weight (``params['w_<field>']``)
    otherwise.

    Args:
        df: DataFrame containing the raw similarity columns.
        params: Mapping with ``w_*`` and ``p_*`` entries for ``nome``,
            ``nome_mae``, ``dt_nasc`` and ``sexo``.

    Returns:
        ``df`` with ``sim_nome``, ``sim_nome_mae``, ``sim_dt_nasc`` and
        ``sim_sexo`` added.
    """
    # (field name, raw similarity column) in the order the columns are added.
    raw_columns = (
        ('nome', 'jw_nome_raw'),
        ('nome_mae', 'jw_nome_mae_raw'),
        ('dt_nasc', 'ham_dt_nasc_raw'),
        ('sexo', 'ov_sexo_raw'),
    )
    result = df
    for field, raw_name in raw_columns:
        raw = F.col(raw_name)
        penalty = F.lit(params[f'p_{field}'])
        weighted = raw * F.lit(params[f'w_{field}'])
        result = result.withColumn(
            f'sim_{field}',
            F.when(raw.isNull(), penalty).otherwise(weighted),
        )
    return result

def with_total_score(df, params):
    """Add ``total_score``: the sum of the ``sim_*`` columns over the weight sum.

    Args:
        df: DataFrame containing ``sim_nome``, ``sim_nome_mae``,
            ``sim_dt_nasc`` and ``sim_sexo``.
        params: Mapping with the ``w_nome``, ``w_nome_mae``, ``w_dt_nasc``
            and ``w_sexo`` weights used as the normaliser.

    Returns:
        ``df`` with a ``total_score`` column. When the weights sum to zero
        the score is undefined and the column is NULL for every row.
    """
    score_max = params['w_nome'] + params['w_nome_mae'] + params['w_dt_nasc'] + params['w_sexo']
    if score_max == 0:
        # Dividing by a zero literal gives NULL with ANSI mode off but raises
        # a divide-by-zero error with ANSI mode on; emit NULL explicitly so
        # the result does not depend on the session configuration.
        return df.withColumn('total_score', F.lit(None).cast(T.DoubleType()))

    weighted_sum = (
        F.col('sim_nome')
        + F.col('sim_nome_mae')
        + F.col('sim_dt_nasc')
        + F.col('sim_sexo')
    )
    return df.withColumn('total_score', weighted_sum / F.lit(score_max))


## Metrics per Parameter Set

In [None]:
def compute_metrics(df, threshold):
    """Aggregate confusion-matrix counts and derived metrics into one row.

    A row is predicted positive when ``total_score >= threshold`` and
    negative when ``total_score < threshold``; ``match_status`` (1 or 0) is
    the ground truth. Rows whose ``total_score`` or ``match_status`` is null
    satisfy none of the four conditions and are not counted.

    Args:
        df: DataFrame with ``match_status`` and ``total_score`` columns.
        threshold: Score cut-off for predicting a match.

    Returns:
        A single-row DataFrame with ``VP``, ``FP``, ``FN``, ``VN``,
        ``precision``, ``recall``, ``specificity`` and ``accuracy``. A metric
        whose denominator is zero is NULL.
    """
    is_match = F.col('match_status') == 1
    is_non_match = F.col('match_status') == 0
    predicted_pos = F.col('total_score') >= threshold
    predicted_neg = F.col('total_score') < threshold

    def _count(condition, name):
        # Count the rows for which the condition is true.
        return F.sum(F.when(condition, 1).otherwise(0)).alias(name)

    def _ratio(numerator, denominator):
        # Guarded division: NULL for a zero denominator instead of the
        # divide-by-zero error raised when ANSI SQL mode is enabled.
        return F.when(denominator != 0, numerator / denominator)

    agg = df.agg(
        _count(is_match & predicted_pos, 'VP'),
        _count(is_non_match & predicted_pos, 'FP'),
        _count(is_match & predicted_neg, 'FN'),
        _count(is_non_match & predicted_neg, 'VN'),
    )

    vp, fp, fn, vn = F.col('VP'), F.col('FP'), F.col('FN'), F.col('VN')
    return (
        agg
        .withColumn('precision', _ratio(vp, vp + fp))
        .withColumn('recall', _ratio(vp, vp + fn))
        .withColumn('specificity', _ratio(vn, vn + fp))
        .withColumn('accuracy', _ratio(vp + vn, vp + fp + fn + vn))
    )


## Experiment Loop (each row = one parameter set)

In [None]:
# Sample the parameter sets; each one becomes one row of metrics_df.
rows = sample_param_sets(cfg, n=1000, seed=7)

# Each configuration produces a single aggregated row. Collecting that row
# keeps the result tiny on the driver and avoids chaining one union per
# configuration, which builds a very deep query plan whose every branch
# would re-aggregate base_df each time metrics_df is used.
metrics_rows = []
metrics_schema = None
for params in rows:
    scored = apply_params(base_df, params)
    scored = with_total_score(scored, params)
    m = compute_metrics(scored, threshold=0.8)
    # Tag the metrics row with the parameter values that produced it.
    for k, v in params.items():
        m = m.withColumn(k, F.lit(v))
    if metrics_schema is None:
        # Every configuration yields the same columns; reuse the first schema
        # so all-NULL metric columns do not break schema inference.
        metrics_schema = m.schema
    metrics_rows.append(m.first())

metrics_df = SparkSession.builder.getOrCreate().createDataFrame(
    metrics_rows, schema=metrics_schema
)

metrics_df.show(5)
