In [0]:
# Install required packages (Databricks will run this at the start of the job)
# You can use %pip in a notebook cell or subprocess in a script
try:
    from datasketch import MinHash, MinHashLSH
    from Levenshtein import ratio as levenshtein_ratio
except ImportError:
    import sys
    import subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "datasketch", "python-Levenshtein"])
    from datasketch import MinHash, MinHashLSH
    from Levenshtein import ratio as levenshtein_ratio

import re
import unicodedata
from collections import defaultdict

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit, udf, concat_ws, lower, regexp_replace
from pyspark.sql.types import StringType

# Spark session: getOrCreate returns the active session if there is one,
# otherwise it starts a new one named "DeduplicationJob".
spark = (
    SparkSession.builder
    .appName("DeduplicationJob")
    .getOrCreate()
)

# Normalization functions
# Compiled once; preprocess_text runs per row inside a Spark UDF.
_NON_ALNUM_RE = re.compile(r"[^a-z0-9\s]")
_WHITESPACE_RE = re.compile(r"\s+")


def preprocess_text(text):
    """Normalize a free-text value for fuzzy matching.

    Steps, in order:
      1. ``None`` (a SQL NULL reaching the UDF) becomes ``""``. It is not
         turned into the literal token ``"none"``, which would make every
         null record look alike.
      2. Case-fold the text (``"Straße"`` -> ``"strasse"``).
      3. Strip accents via NFKD decomposition plus removal of combining
         marks, so ``"Zürich"`` -> ``"zurich"`` instead of losing the letter.
      4. Remove every character that is not an ASCII letter, digit or
         whitespace.
      5. Collapse whitespace runs to a single space and trim the ends.

    Args:
        text: Value to normalize. Non-string values are converted with
            ``str()``.

    Returns:
        The normalized string. It may be ``""``.
    """
    if text is None:
        return ""
    decomposed = unicodedata.normalize("NFKD", str(text).casefold())
    unaccented = "".join(ch for ch in decomposed if not unicodedata.combining(ch))
    # Drop punctuation *before* collapsing whitespace. Otherwise "a - b" would
    # keep the double space left behind by the removed "-".
    cleaned = _NON_ALNUM_RE.sub("", unaccented)
    return _WHITESPACE_RE.sub(" ", cleaned).strip()

# Country aliases -> canonical country name. Keys must already be in the form
# produced by preprocess_text (lower-case ASCII letters/digits and single
# spaces), because lookups happen after that normalization.
COUNTRY_MAP = {
    # Switzerland: ISO 3166-1 alpha-2 / alpha-3 codes and local-language names
    "ch": "switzerland",
    "che": "switzerland",
    "suisse": "switzerland",
    "schweiz": "switzerland",
    "svizzera": "switzerland",
    # Liechtenstein
    "li": "liechtenstein",
    "lie": "liechtenstein",
    # United States
    "us": "united states",
    "usa": "united states",
    "united states of america": "united states",
    # United Kingdom ("uk" is a widespread non-ISO alias)
    "gb": "united kingdom",
    "gbr": "united kingdom",
    "uk": "united kingdom",
    # Germany
    "de": "germany",
    "deu": "germany",
    "deutschland": "germany",
}

def preprocess_country_names(country):
    """Normalize a country value and canonicalize known codes/aliases.

    The value is first cleaned with ``preprocess_text``. If the cleaned form
    is a key of ``COUNTRY_MAP``, the mapped canonical name is returned;
    otherwise the cleaned form is returned unchanged.
    """
    cleaned = preprocess_text(country)
    if cleaned in COUNTRY_MAP:
        return COUNTRY_MAP[cleaned]
    return cleaned

# Spark UDF wrappers around the normalizers; both produce string columns.
# The lambdas look up the module-level functions each time they are called.
preprocess_text_udf = udf(f=lambda value: preprocess_text(value), returnType=StringType())
preprocess_country_names_udf = udf(
    f=lambda value: preprocess_country_names(value), returnType=StringType()
)

# Source table holding the raw customer records.
input_table = "my_database.my_schema.dedup_input"

# Upper-case source column -> lower-case name used by the rest of the job
# (including the pandas stage).
columns_map = {
    "FUSION_CUSTOMER_NAME": "fusion_customer_name",
    "ADDRESS_LINE_1": "address_line_1",
    "POSTAL_CODE": "postal_code",
    "CITY": "city",
    "COUNTRY": "country",
    "ID": "id",
    "SOURCE_SYSTEM": "source_system"
}

# Keep only the mapped columns and rename them in a single projection.
df_spark = spark.table(input_table).select(
    *[col(source_name).alias(target_name) for source_name, target_name in columns_map.items()]
)

# Normalize the text columns with the Python UDFs. There is no
# address_line_2, so full_address is just the normalized address_line_1;
# country values are additionally mapped through COUNTRY_MAP.
df_spark = (
    df_spark
    .withColumn("full_address", preprocess_text_udf(col("address_line_1")))
    .withColumn("fusion_customer_name", preprocess_text_udf(col("fusion_customer_name")))
    .withColumn("city", preprocess_text_udf(col("city")))
    .withColumn("country", preprocess_country_names_udf(col("country")))
)

# Collect to a pandas DataFrame: the MinHash/LSH deduplication below runs in
# plain Python on this frame.
df = df_spark.toPandas()

# Deduplication algorithm: weighted MinHash signatures per row, candidate matching via MinHash LSH
def create_weighted_minhash(row, columns, weights):
    """Build a weighted MinHash signature for one record.

    MinHash estimates Jaccard similarity between *sets*: updating a MinHash
    with a value it has already seen changes nothing. Repeating a column's
    text ``weight`` times therefore has no effect on the signature. Instead,
    every whitespace token is counted once per occurrence, multiplied by the
    weight of the column it came from. The k-th copy of a token is then hashed
    as the distinct element ``"<token>|<k>"`` (multiset expansion). The
    resulting signature estimates the weighted Jaccard similarity of the
    token counts, so heavier columns actually dominate the match.

    Args:
        row: Record to hash (e.g. a pandas Series). Columns it does not
            contain are skipped; values are converted with ``str()``.
        columns: Names of the columns to include.
        weights: Integer weight per column, parallel to ``columns``. A weight
            <= 0 excludes the column.

    Returns:
        A ``MinHash`` with 128 permutations (the LSH index must use the same
        number).

    Raises:
        ValueError: If ``columns`` and ``weights`` differ in length.
    """
    if len(columns) != len(weights):
        raise ValueError(
            f"columns and weights must have the same length "
            f"({len(columns)} != {len(weights)})"
        )

    token_counts = defaultdict(int)
    for column_name, weight in zip(columns, weights):
        # Non-positive weights contribute nothing, as repeating text 0 times would.
        if weight <= 0 or column_name not in row:
            continue
        for token in str(row[column_name]).split():
            token_counts[token] += weight

    minhash = MinHash(num_perm=128)
    for token, count in token_counts.items():
        # "|<k>" makes each copy a distinct set element. The suffix is digits
        # only, so "<token>|<k>" maps back to exactly one (token, k) pair.
        for copy_index in range(count):
            minhash.update(f"{token}|{copy_index}".encode("utf8"))
    return minhash

def _connected_components(n_items, linked):
    """Partition positions ``0..n_items-1`` into connected components.

    ``linked(pos)`` must return an iterable of positions linked to ``pos``.
    Each component is a list of ascending positions. Components are ordered
    by their smallest member.
    """
    parent = list(range(n_items))

    def find(pos):
        # Path halving keeps the union-find trees shallow.
        while parent[pos] != pos:
            parent[pos] = parent[parent[pos]]
            pos = parent[pos]
        return pos

    for pos in range(n_items):
        for other in linked(pos):
            root_a, root_b = find(pos), find(other)
            if root_a != root_b:
                # Attach the larger root under the smaller one, so every root
                # is the smallest position of its component.
                parent[max(root_a, root_b)] = min(root_a, root_b)

    # Scanning positions in ascending order reaches each root (the component's
    # smallest member) first, which fixes the component order.
    components = defaultdict(list)
    for pos in range(n_items):
        components[find(pos)].append(pos)
    return list(components.values())


def weighted_clustering(df, columns, weights, threshold):
    """Cluster near-duplicate rows of ``df`` using weighted MinHash + LSH.

    Every row gets a signature from ``create_weighted_minhash`` and is
    inserted into a ``MinHashLSH`` index built with ``threshold`` percent.
    Each row is linked to every row its LSH query returns. Clusters are the
    connected components of those links: if A links to B and B links to C,
    all three end up in the same cluster.

    Args:
        df: pandas DataFrame with an ``id`` column and every column named in
            ``columns``. Rows are addressed by position, so any index works.
        columns: Columns to hash. ``columns[0]`` also drives the link score.
        weights: Per-column weights, parallel to ``columns``.
        threshold: LSH similarity threshold in percent (divided by 100).

    Returns:
        ``(cluster_assignments, driver_ids, cluster_sizes, link_scores)``:

        * ``cluster_assignments``: per-row cluster number, starting at 1 and
          numbered in order of each cluster's first row.
        * ``driver_ids``: per-row ``id`` of its cluster's first row (the driver).
        * ``cluster_sizes``: dict mapping cluster number -> number of rows.
        * ``link_scores``: per-row Levenshtein ratio (0-100) between the row's
          ``columns[0]`` value and the driver's.

    Side effect:
        For non-empty input, sets ``df["Row_Id"]`` to a copy of ``df["id"]``.
    """
    n_rows = len(df)
    if n_rows == 0:
        return [], [], {}, []

    lsh = MinHashLSH(threshold=threshold / 100, num_perm=128)
    minhashes = []
    # Key the index by row position (not index label) so that every lookup
    # below is positional.
    for pos, (_, row) in enumerate(df.iterrows()):
        minhash = create_weighted_minhash(row, columns, weights)
        lsh.insert(pos, minhash)
        minhashes.append(minhash)

    clusters = _connected_components(n_rows, lambda pos: lsh.query(minhashes[pos]))

    # Extract the needed columns once, rather than building a row Series per access.
    ids = df["id"].tolist()
    score_values = df[columns[0]].tolist()

    cluster_assignments = [None] * n_rows
    driver_ids = [None] * n_rows
    link_scores = [None] * n_rows
    cluster_sizes = {}

    for cluster_id, members in enumerate(clusters, start=1):
        driver_pos = members[0]
        # str() lets non-string cells (e.g. NaN) be compared as text.
        driver_text = str(score_values[driver_pos])
        for pos in members:
            cluster_assignments[pos] = cluster_id
            driver_ids[pos] = ids[driver_pos]
            link_scores[pos] = levenshtein_ratio(driver_text, str(score_values[pos])) * 100
        cluster_sizes[cluster_id] = len(members)

    # Row_Id is part of the output table. Assign it as one whole column, so it
    # matches "id" row for row and keeps that column's dtype.
    df["Row_Id"] = ids
    return cluster_assignments, driver_ids, cluster_sizes, link_scores

# Matching configuration for the single "combined" pass:
#   columns   - normalized columns fed into each row's MinHash
#   weights   - per-column weights, parallel to columns
#   threshold - LSH similarity threshold, in percent
thresholds = {
    "combined": {
        "columns": ["fusion_customer_name", "full_address", "city", "country"],
        "weights": [5, 5, 2, 1],
        "threshold": 80,
    },
}

# Run clustering with the combined configuration.
combined_config = thresholds["combined"]
combined_columns, weights, threshold = (
    combined_config["columns"],
    combined_config["weights"],
    combined_config["threshold"],
)
cluster_ids, driver_ids, cluster_sizes, link_scores = weighted_clustering(
    df, combined_columns, weights, threshold
)

# Attach the clustering output as new columns. Every list is positionally
# aligned with df's rows; cluster size is looked up per row's cluster number.
result_columns = {
    "Cluster Id": cluster_ids,
    "Driver Id": driver_ids,
    "Cluster Size": [cluster_sizes[assigned] for assigned in cluster_ids],
    "LinkScore": link_scores,
}
for column_name, column_values in result_columns.items():
    df[column_name] = column_values

# Hand the pandas result back to Spark and replace the output table's contents.
# NOTE(review): several output columns contain spaces ("Cluster Id",
# "Driver Id", "Cluster Size"). If the target is a Delta table without column
# mapping enabled, such names may be rejected. Confirm the table format.
df_spark_out = spark.createDataFrame(df)
output_table = "my_database.my_schema.dedup_output"
(
    df_spark_out.write
    .mode("overwrite")
    .saveAsTable(output_table)
)
print(f"Clustered data saved to {output_table}")