Initialize the Spark session and define the helper functions needed to run name-similarity checks with an ML model

In [1]:
import findspark
findspark.init()

from pyspark import SparkContext
from pyspark.sql import SparkSession
from pyspark.sql import functions as F


# Build the session through the builder *before* any SparkContext exists.
# Previously a SparkContext was created first with only an app name, so
# getOrCreate() reused that running context and the master / driver-memory
# options below could no longer be applied to it.
spark = (
    SparkSession.builder
    .appName("MyApp2")
    .master('local[*]')
    .config('spark.executor.memory', '5g')
    .config('spark.driver.memory', '3g')
    .config("spark.driver.maxResultSize", "2g")
    .getOrCreate()
)

# Keep `sc` available for the cells that use the context directly
# (e.g. sc.broadcast).
sc = spark.sparkContext


import warnings
# Ignore every warning from here on: no category or module filter is given.
warnings.filterwarnings("ignore") 

from sentence_transformers import SentenceTransformer, util
# Sentence-embedding model that get_vector uses to encode the two names.
model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

def get_vector(a, b):
    """Return the cosine similarity between the embeddings of two texts.

    Each argument is encoded separately with the module-level ``model`` and
    the two embeddings are compared with ``util.pytorch_cos_sim``.  When
    either argument is falsy (``None`` or an empty string, for example) no
    encoding takes place and ``0.0`` is returned.
    """
    if not (a and b):
        return 0.0

    first_embedding = model.encode(a, convert_to_tensor=True)
    second_embedding = model.encode(b, convert_to_tensor=True)
    # The score for the single pair is read from position [0][0] of the result.
    similarity_matrix = util.pytorch_cos_sim(first_embedding, second_embedding)
    return float(similarity_matrix[0][0])



Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/10/30 11:52:26 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
23/10/30 11:52:27 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [2]:

from sentence_transformers import SentenceTransformer, util
# `model` is already loaded in the first cell; broadcast that instance rather
# than instantiating the same SentenceTransformer a second time.
broad_model = sc.broadcast(model)

def _pair_similarity(a, b):
    """Cosine similarity of the embeddings of two names; 0.0 if either is missing."""
    if not a or not b:
        return 0.0

    encoder = broad_model.value
    emb_a = encoder.encode(a, convert_to_tensor=True)
    emb_b = encoder.encode(b, convert_to_tensor=True)
    return float(util.pytorch_cos_sim(emb_a, emb_b)[0][0])


def _score_partition(rows):
    """Yield each row's values as a list with the name similarity appended."""
    for row in rows:
        yield [*row, _pair_similarity(row["fb_name"], row["gg_name"])]


def get_vector(partitionData, b=None):
    """Score name similarity for a whole partition or for a single pair.

    This definition replaces the two-argument ``get_vector(a, b)`` from the
    first cell, yet later cells still call ``get_vector(name_a, name_b)``.
    Both call shapes are therefore supported:

    * ``get_vector(rows)`` -- ``rows`` is an iterator of rows exposing
      ``row["fb_name"]`` and ``row["gg_name"]`` (the shape ``mapPartitions``
      passes in).  Returns an iterator of lists: the row's values followed by
      the similarity score.
    * ``get_vector(name_a, name_b)`` -- returns the similarity of the two
      names as a float.

    A missing or empty name scores ``0.0`` in both modes instead of being
    handed to the encoder.
    """
    # A string or None first argument, or any explicit second argument, can
    # only come from the pairwise call shape.
    if b is not None or partitionData is None or isinstance(partitionData, str):
        return _pair_similarity(partitionData, b)
    return _score_partition(partitionData)




In [3]:
# Load the three cleaned datasets from their parquet directories.
fb_df_clean, gg_df_clean, wb_df_clean = (
    spark.read.parquet(parquet_path)
    for parquet_path in ("fb_df", "gg_df", "wb_df")
)

                                                                                

In [4]:

# Pair every fb row with every gg row, then keep only the pairs that agree on
# category, country name, city, domain and phone (one filter per attribute).
fb_gg_cross = fb_df_clean.crossJoin(gg_df_clean)

cross_cat = fb_gg_cross.filter(F.col("fb_category") == F.col("gg_category"))
cross_cat_country = cross_cat.filter(F.col("fb_country_name") == F.col("gg_country_name"))
cross_cat_country_city = cross_cat_country.filter(F.col("fb_city") == F.col("gg_city"))
cross_cat_country_city_dom = cross_cat_country_city.filter(F.col("fb_domain") == F.col("gg_domain"))
cross_cat_country_city_dom_phone = cross_cat_country_city_dom.filter(F.col("fb_phone") == F.col("gg_phone"))

# Append the name-similarity score as a "vector" column.  The result is cached
# because the following cells count and display it again; without the cache
# every one of those actions would re-run the embedding model over all rows.
cross_cat_country_city_dom_phone_vec = (
    cross_cat_country_city_dom_phone.rdd
    .mapPartitions(get_vector)
    .toDF(cross_cat_country_city_dom_phone.columns + ["vector"])
    .cache()
)

print(cross_cat_country_city_dom_phone_vec.count())



13785


                                                                                



In [37]:
# Number of scored pairs whose name similarity exceeds 0.9.
high_similarity = F.col("vector") > 0.9
cross_cat_country_city_dom_phone_vec.where(high_similarity).count()

                                                                                

9099



In [None]:
# Display the scored pairs whose name similarity exceeds 0.9.
high_similarity = F.col("vector") > 0.9
cross_cat_country_city_dom_phone_vec.where(high_similarity).show()

Read the datasets

In [None]:
import random
from pyspark.sql.types import StructField, DoubleType, StructType

import warnings
import pandas as pd

# Suppress FutureWarning about iteritems
warnings.simplefilter(action='ignore', category=FutureWarning)

# Sampling fraction for gg_df_clean (fb_df_clean is sampled at twice this rate).
# It is initialised once, outside the loop: resetting it at the top of every
# iteration would undo the increment applied at the bottom.
fraction = 0.002

# `i` starts at 1 and the loop stops when it reaches 10, i.e. after the ninth
# sample that produced at least one match.
i = 1

while True:
    fb_df = fb_df_clean.sample(withReplacement=True, fraction=fraction*2, seed=random.randint(1, 100000))
    gg_df = gg_df_clean.sample(withReplacement=True, fraction=fraction, seed=random.randint(1, 100000))

    fb_gg_cross = fb_df.crossJoin(gg_df)

    schema = StructType(fb_gg_cross.schema.fields + [StructField("vector", DoubleType())])

    # Collect the sampled pairs to the driver and score every name pair locally.
    fb_gg_cross_vec = fb_gg_cross.toPandas()

    fb_gg_cross_vec["vector"] = fb_gg_cross_vec.apply(lambda x: get_vector(x["fb_name"], x["gg_name"]), axis=1)

    df = spark.createDataFrame(fb_gg_cross_vec, schema)

    # Build the threshold filter once and reuse it for the check and the display.
    matches = df.filter(F.col("vector") > 0.8)

    if matches.count() > 0:
        print(fb_df.count())
        print(gg_df.count())
        matches.show(200, truncate=False)
        i += 1

        if i == 10:
            break

    # Widen the next sample slightly.
    fraction += 0.0001


In [None]:
# Pair every row of fb_df_clean with every row of gg_df_clean; the cells below
# narrow this frame down with equality filters on fb_* / gg_* columns.
cross = fb_df_clean.crossJoin(gg_df_clean)

# print(cross.count())

# cross.show(5, truncate=False)

In [None]:
# Count the pairs that agree on both category and domain.
same_category = F.col("fb_category") == F.col("gg_category")
same_domain = F.col("fb_domain") == F.col("gg_domain")

cross_cat = cross.where(same_category)
cross_cat_dom = cross_cat.where(same_domain)
print(cross_cat_dom.count())



111251


                                                                                

In [None]:
# Count the pairs that agree on both category and country name.
same_category = F.col("fb_category") == F.col("gg_category")
same_country_name = F.col("fb_country_name") == F.col("gg_country_name")

cross_cat = cross.where(same_category)
cross_cat_country = cross_cat.where(same_country_name)
print(cross_cat_country.count())



42536979


                                                                                

In [None]:
# Count the pairs that agree on both category and city.
same_category = F.col("fb_category") == F.col("gg_category")
same_city = F.col("fb_city") == F.col("gg_city")

cross_cat = cross.where(same_category)
cross_cat_city = cross_cat.where(same_city)
print(cross_cat_city.count())

677073


In [None]:
# Count the pairs that agree on category, country name, city, domain and phone.
# Each intermediate frame keeps its own name, one per added condition.
same_category = F.col("fb_category") == F.col("gg_category")
same_country_name = F.col("fb_country_name") == F.col("gg_country_name")
same_city = F.col("fb_city") == F.col("gg_city")
same_domain = F.col("fb_domain") == F.col("gg_domain")
same_phone = F.col("fb_phone") == F.col("gg_phone")

cross_cat = cross.where(same_category)
cross_cat_country = cross_cat.where(same_country_name)
cross_cat_country_city = cross_cat_country.where(same_city)
cross_cat_country_city_dom = cross_cat_country_city.where(same_domain)
cross_cat_country_city_dom_phone = cross_cat_country_city_dom.where(same_phone)
print(cross_cat_country_city_dom_phone.count())

                                                                                

13785


In [None]:
# Count the pairs that agree on category, country name, city and domain.
same_category = F.col("fb_category") == F.col("gg_category")
same_country_name = F.col("fb_country_name") == F.col("gg_country_name")
same_city = F.col("fb_city") == F.col("gg_city")
same_domain = F.col("fb_domain") == F.col("gg_domain")

cross_cat = cross.where(same_category)
cross_cat_country = cross_cat.where(same_country_name)
cross_cat_country_city = cross_cat_country.where(same_city)
cross_cat_country_city_dom = cross_cat_country_city.where(same_domain)
print(cross_cat_country_city_dom.count())

23647


In [None]:
# Count the pairs that agree on both category and country code.
same_category = F.col("fb_category") == F.col("gg_category")
same_country_code = F.col("fb_country_code") == F.col("gg_country_code")

cross_cat = cross.where(same_category)
cross_cat_country = cross_cat.where(same_country_code)
print(cross_cat_country.count())



58281985


                                                                                

In [None]:
# Count the pairs that agree on both category and region code.
# NOTE: the result is stored in `cross_cat_country` although the condition
# compares region codes, not countries.
same_category = F.col("fb_category") == F.col("gg_category")
same_region_code = F.col("fb_region_code") == F.col("gg_region_code")

cross_cat = cross.where(same_category)
cross_cat_country = cross_cat.where(same_region_code)
print(cross_cat_country.count())

10799200


In [None]:
# Count the pairs that agree on both category and region name.
same_category = F.col("fb_category") == F.col("gg_category")
same_region_name = F.col("fb_region_name") == F.col("gg_region_name")

cross_cat = cross.where(same_category)
cross_cat_region = cross_cat.where(same_region_name)
print(cross_cat_region.count())

10755307


In [None]:
# Count the pairs with the same, non-null region name.  No category condition
# is applied here, despite the `cross_cat_` prefix of the variable name.
same_region_name = F.col("fb_region_name") == F.col("gg_region_name")
region_name_present = F.col("fb_region_name").isNotNull()

cross_cat_region = cross.where(same_region_name).where(region_name_present)
print(cross_cat_region.count())



3133129396


                                                                                

In [None]:
# Count the pairs that agree on both category and phone.
same_category = F.col("fb_category") == F.col("gg_category")
same_phone = F.col("fb_phone") == F.col("gg_phone")

cross_cat = cross.where(same_category)
cross_cat_phone = cross_cat.where(same_phone)
print(cross_cat_phone.count())

22274


In [None]:
# Count the pairs with the same phone; no category condition is applied here,
# despite the `cross_cat_` prefix of the variable name.
same_phone = F.col("fb_phone") == F.col("gg_phone")

cross_cat_phone = cross.where(same_phone)
print(cross_cat_phone.count())

147943


In [None]:
import random
from pyspark.sql.types import StructField, DoubleType, StructType

import warnings

# Candidate pairs that agree on both category and phone.  The frame is cached
# because it is evaluated twice below (count(), then toPandas()); without the
# cache the filtered cross join would be computed once per action.
cross_cat = cross.filter(F.col("fb_category") == F.col("gg_category"))
cross_cat_phone = cross_cat.filter(F.col("fb_phone") == F.col("gg_phone")).cache()
print(cross_cat_phone.count())

# Suppress FutureWarning about iteritems
warnings.simplefilter(action='ignore', category=FutureWarning)

# Original columns plus a double "vector" column for the similarity score.
schema = StructType(cross_cat_phone.schema.fields + [StructField("vector", DoubleType())])

# Collect the candidates to the driver and score every name pair locally.
cross_cat_phone_vec = cross_cat_phone.toPandas()

cross_cat_phone_vec["vector"] = cross_cat_phone_vec.apply(lambda x: get_vector(x["fb_name"], x["gg_name"]), axis=1)

# All scored pairs (kept for the later union) and the subset scoring above 0.8.
df_vec_cat_phone_orig = spark.createDataFrame(cross_cat_phone_vec, schema)

df_vec_cat_phone = df_vec_cat_phone_orig.filter(F.col("vector") > 0.8)

# Run the count once and reuse it for both the check and the printout.
match_count = df_vec_cat_phone.count()

if match_count > 0:
    print(match_count)

    df_vec_cat_phone.show(20, truncate=False)




22274


                                                                                

23/10/29 08:26:46 WARN TaskSetManager: Stage 53 contains a task of very large size (1159 KiB). The maximum recommended task size is 1000 KiB.


                                                                                

23/10/29 08:26:47 WARN TaskSetManager: Stage 56 contains a task of very large size (1159 KiB). The maximum recommended task size is 1000 KiB.
17170
+----------------------------+----------------------------------------------------------------------------------+------------+---------------+---------------+--------------------------------------------+-----------+--------------+----------------+-----------+----------------------+----------------------------+---------------------------------------------------------------+----------------------+-------------+---------------+---------------+-----------------------------------------+-----------+--------------+----------------+-----------+-------------------------------------------------------------------+------------------+
|fb_domain                   |fb_address                                                                        |fb_city     |fb_country_code|fb_country_name|fb_name                                     |fb_phone   |fb_regi

In [None]:
# Candidate pairs that agree on category, country name, city and domain.
cross_cat = cross.filter(F.col("fb_category") == F.col("gg_category"))
cross_cat_country = cross_cat.filter(F.col("fb_country_name") == F.col("gg_country_name"))
cross_cat_country_city = cross_cat_country.filter(F.col("fb_city") == F.col("gg_city"))
cross_cat_country_city_dom = cross_cat_country_city.filter(F.col("fb_domain") == F.col("gg_domain"))

import random
from pyspark.sql.types import StructField, DoubleType, StructType

import warnings
import pandas as pd

# Suppress FutureWarning about iteritems
warnings.simplefilter(action='ignore', category=FutureWarning)

# Original columns plus a double "vector" column for the similarity score.
schema = StructType(cross_cat_country_city_dom.schema.fields + [StructField("vector", DoubleType())])

# Collect the candidates to the driver and score every name pair locally.
cross_cat_country_city_dom_vec = cross_cat_country_city_dom.toPandas()

cross_cat_country_city_dom_vec["vector"] = cross_cat_country_city_dom_vec.apply(lambda x: get_vector(x["fb_name"], x["gg_name"]), axis=1)

# All scored pairs (kept for the later union) and the subset scoring above 0.8.
df_vec_orig = spark.createDataFrame(cross_cat_country_city_dom_vec, schema)

df_vec = df_vec_orig.filter(F.col("vector") > 0.8)

# Run the count once and reuse it for both the check and the printout.
match_count = df_vec.count()

if match_count > 0:
    print(match_count)

    df_vec.show(20, truncate=False)




                                                                                

23/10/29 09:06:07 WARN TaskSetManager: Stage 62 contains a task of very large size (1242 KiB). The maximum recommended task size is 1000 KiB.


                                                                                

23/10/29 09:06:08 WARN TaskSetManager: Stage 65 contains a task of very large size (1242 KiB). The maximum recommended task size is 1000 KiB.
16055
23/10/29 09:06:09 WARN TaskSetManager: Stage 68 contains a task of very large size (1242 KiB). The maximum recommended task size is 1000 KiB.
+----------------------------+---------------------------------------------------------------------+------------+---------------+---------------+-----------------------------------------------------------------+-----------+--------------+----------------+-----------+--------------------+----------------------------+---------------------------------------------------------------+--------------------+------------+---------------+---------------+--------------------------------------------+-----------+--------------+----------------+-----------+-------------------------------------------------------------------------------+------------------+
|fb_domain                   |fb_address                      

In [None]:


# Candidate pairs that agree on both category and domain.  The frame is cached
# because it is evaluated twice below (count(), then toPandas()); without the
# cache the filtered cross join would be computed once per action.
cross_cat = cross.filter(F.col("fb_category") == F.col("gg_category"))
cross_cat_dom = cross_cat.filter(F.col("fb_domain") == F.col("gg_domain")).cache()
print(cross_cat_dom.count())



import random
from pyspark.sql.types import StructField, DoubleType, StructType

import warnings
import pandas as pd

# Suppress FutureWarning about iteritems
warnings.simplefilter(action='ignore', category=FutureWarning)

# Original columns plus a double "vector" column for the similarity score.
schema = StructType(cross_cat_dom.schema.fields + [StructField("vector", DoubleType())])

# Collect the candidates to the driver and score every name pair locally.
cross_cat_dom_vec = cross_cat_dom.toPandas()

cross_cat_dom_vec["vector"] = cross_cat_dom_vec.apply(lambda x: get_vector(x["fb_name"], x["gg_name"]), axis=1)

# All scored pairs (kept for the later union) and the subset scoring above 0.8.
df_vec_cat_dom_orig = spark.createDataFrame(cross_cat_dom_vec, schema)

df_vec_cat_dom = df_vec_cat_dom_orig.filter(F.col("vector") > 0.8)

# Run the count once and reuse it for both the check and the printout.
match_count = df_vec_cat_dom.count()

if match_count > 0:
    print(match_count)

    df_vec_cat_dom.show(20, truncate=False)




111251


                                                                                

23/10/29 12:30:38 WARN TaskSetManager: Stage 75 contains a task of very large size (4631 KiB). The maximum recommended task size is 1000 KiB.


                                                                                

23/10/29 12:30:40 WARN TaskSetManager: Stage 78 contains a task of very large size (4631 KiB). The maximum recommended task size is 1000 KiB.


                                                                                

34076
23/10/29 12:30:41 WARN TaskSetManager: Stage 81 contains a task of very large size (4631 KiB). The maximum recommended task size is 1000 KiB.


[Stage 81:>                                                         (0 + 1) / 1]

23/10/29 12:30:45 WARN PythonRunner: Detected deadlock while completing task 0.0 in stage 81 (TID 318): Attempting to kill Python Worker
+----------------------------+----------------------------------------------------------------------------------+-----------+---------------+---------------+-----------------------------------------------------------------+-----------+--------------+--------------+-----------+----------------------+----------------------------+--------------------------------------------------------------------------------------------------+----------------------+-------------+---------------+---------------+--------------------------------------------+------------+--------------+----------------+-----------+--------------------------------------------------------------------------------------------------+------------------+
|fb_domain                   |fb_address                                                                        |fb_city    |fb_country_code|fb_c

                                                                                

In [None]:
# Stack the three scored candidate sets by column name and drop duplicate rows.
scored_by_cat_dom_and_full_chain = df_vec_cat_dom_orig.unionByName(df_vec_orig)
scored_all = scored_by_cat_dom_and_full_chain.unionByName(df_vec_cat_phone_orig)
orig = scored_all.distinct()

In [None]:
# Show the five lowest-scoring pairs among those with similarity above 0.75.
above_threshold = orig.where(F.col("vector") > 0.75)
above_threshold.orderBy(F.col("vector")).show(5, truncate=False)

23/10/29 13:08:14 WARN TaskSetManager: Stage 121 contains a task of very large size (4631 KiB). The maximum recommended task size is 1000 KiB.




+-------------------+-----------------------------------------------------+-------+---------------+--------------------+--------------------------------+-----------+--------------+--------------+-----------+-----------------------------+-------------------+------------------------------------------------------------------------------------------------------------------------------------------+-----------------------------+-------+---------------+--------------------+------------------------------------+-----------+--------------+--------------+-----------+------------------------------------------------------------------------------------+------------------+
|fb_domain          |fb_address                                           |fb_city|fb_country_code|fb_country_name     |fb_name                         |fb_phone   |fb_region_code|fb_region_name|fb_zip_code|fb_category                  |gg_domain          |gg_address                                                                  

                                                                                

In [None]:
# Collapse rows that share every fb_* / gg_* attribute listed below and the
# same similarity score, so that each such pair appears once.
grouped_df = orig.groupBy(
    "fb_domain", "fb_address", "fb_city", "fb_country_code", "fb_country_name",
    "fb_name", "fb_phone", "fb_region_code", "fb_region_name", "fb_zip_code",
    "gg_domain", "gg_address", "gg_city", "gg_country_code", "gg_country_name",
    "gg_name", "gg_phone", "gg_region_code", "gg_region_name", "gg_zip_code", 
    "gg_raw_address", "vector"
).agg(
    # Join the distinct fb_category values of the group with " & ".
    # Only fb_category is collected: gg_category is neither a grouping key nor
    # aggregated, so it is absent from the result.  (Every frame unioned into
    # `orig` was filtered on fb_category == gg_category.)
    # collect_set gives no ordering guarantee, so the set is sorted to keep the
    # concatenated string identical from run to run.
    F.concat_ws(" & ", F.sort_array(F.collect_set("fb_category"))).alias("categories"),
)

# Show the result
grouped_df.show(truncate=False)



23/10/29 13:39:09 WARN TaskSetManager: Stage 157 contains a task of very large size (4631 KiB). The maximum recommended task size is 1000 KiB.




+-----------------------+---------------------------------------------------------------------------+-------------------+---------------+---------------+-------------------------------+------------+--------------+--------------+-----------+-----------------------+--------------------------------------------------------------------------------+----------------+---------------+---------------+-----------------------------------------------------------+------------+--------------+--------------+-----------+------------------------------------------------+-------------------+--------------------------------------+
|fb_domain              |fb_address                                                                 |fb_city            |fb_country_code|fb_country_name|fb_name                        |fb_phone    |fb_region_code|fb_region_name|fb_zip_code|gg_domain              |gg_address                                                                      |gg_city         |gg_country_code|gg_c

                                                                                

In [None]:
# Show up to 100 grouped pairs with similarity above 0.90, lowest score first.
strong_matches = grouped_df.where(F.col("vector") > 0.90)
strong_matches.orderBy(F.col("vector")).show(100, truncate=False)

23/10/29 13:42:05 WARN TaskSetManager: Stage 193 contains a task of very large size (4631 KiB). The maximum recommended task size is 1000 KiB.




+-------------------------------+------------------------------------------------------------------------------------------------+------------------------+---------------+---------------+------------------------------------------------------------+-----------+--------------+-----------------------+-----------+-------------------------------+-------------------------------------------------------------------------------+------------------------+---------------+---------------+-------------------------------------------------------------+------------+--------------+-----------------------+-----------+-------------------------------------------------------------------------------+------------------+-------------------------------------------+
|fb_domain                      |fb_address                                                                                      |fb_city                 |fb_country_code|fb_country_name|fb_name                                                     |fb

                                                                                

In [None]:
# Stack the three above-threshold match sets by column name, drop duplicate
# rows and display the result.
matches_phone_and_dom = df_vec_cat_phone.unionByName(df_vec_cat_dom)
matches_all = matches_phone_and_dom.unionByName(df_vec)
matches_all.distinct().show()


23/10/29 07:45:56 WARN TaskSetManager: Stage 44 contains a task of very large size (1159 KiB). The maximum recommended task size is 1000 KiB.




+--------------------+--------------------+--------------------+---------------+---------------+--------------------+-----------+--------------+--------------+-----------+--------------------+--------------------+--------------------+--------------------+--------------------+---------------+---------------+--------------------+-----------+--------------+---------------+-----------+--------------------+------------------+
|           fb_domain|          fb_address|             fb_city|fb_country_code|fb_country_name|             fb_name|   fb_phone|fb_region_code|fb_region_name|fb_zip_code|         fb_category|           gg_domain|          gg_address|         gg_category|             gg_city|gg_country_code|gg_country_name|             gg_name|   gg_phone|gg_region_code| gg_region_name|gg_zip_code|      gg_raw_address|            vector|
+--------------------+--------------------+--------------------+---------------+---------------+--------------------+-----------+--------------+------

                                                                                

In [None]:
# df_vec.filter(F.col("vector") > 0.95).count()

In [None]:
# (fb column, gg column) pairs to compare, in reporting order.
column_pairs = [
    ("fb_phone", "gg_phone"),
    ("fb_domain", "gg_domain"),
    ("fb_country_name", "gg_country_name"),
    ("fb_region_name", "gg_region_name"),
    ("fb_country_code", "gg_country_code"),
    ("fb_region_code", "gg_region_code"),
    ("fb_category", "gg_category"),
    ("fb_city", "gg_city"),
    ("fb_zip_code", "gg_zip_code"),
]

# For each attribute, count the pairs in `cross` where both sides are present
# but hold different values; the fb column name is printed as the label.
for fb_column, gg_column in column_pairs:
    fb_value, gg_value = F.col(fb_column), F.col(gg_column)
    both_present = fb_value.isNotNull() & gg_value.isNotNull()
    differing = cross.filter(both_present & (fb_value != gg_value))
    print(fb_column, differing.count())

                                                                                

fb_phone 49451460947


                                                                                

fb_domain 84346289212


                                                                                

fb_country_name 33229118656


                                                                                

fb_region_name 41575495484


                                                                                

fb_country_code 42078913365


                                                                                

fb_region_code 41567425988


                                                                                

fb_category 76588187075


                                                                                

fb_city 44559004931




fb_zip_code 31156369283


                                                                                