In [13]:
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession


from pyspark.sql.types import StringType, IntegerType, ArrayType
from pyspark.sql.functions import udf
from pyspark.sql import functions as F
from pyspark.sql.functions import collect_set, collect_list
import numpy as np



def create_spark_context(master_ip='127.0.0.1'):
    master_ip = 'spark://{}:7077'.format(master_ip)
    spark = SparkSession.builder \
        .master(master_ip)  \
        .enableHiveSupport() \
        .getOrCreate()
    
    sc = spark.sparkContext

    return (spark, sc)

spark, sc = create_spark_context()

In [14]:
input_filename = '../ressources/data/financial_sells_100000.csv'
separator = ','

init_flat_data = spark.read         \
    .option("sep", separator)  \
    .csv(input_filename, header=True)

init_flat_data = init_flat_data.fillna(0)
init_flat_data.printSchema()

root
 |-- _c0: string (nullable = true)
 |-- company_name: string (nullable = true)
 |-- company_id: string (nullable = true)
 |-- country: string (nullable = true)
 |-- sector_name: string (nullable = true)
 |-- order_id: string (nullable = true)
 |-- product_name: string (nullable = true)
 |-- NBI: string (nullable = true)
 |-- order_date: string (nullable = true)



In [15]:
rdd_product_by_group = init_flat_data  \
        .groupby(['sector_name', 'country'])   \
        .agg( collect_set('product_name').alias('all_product_group') )    

grouped_customers = init_flat_data  \
        .groupby(['sector_name', 'country', 'company_name', 'company_id'])    \
        .agg(collect_set('product_name').alias('product_bought')) 


# define two methods that calculate unbought products
calculate_diff = udf(lambda x, y: len(set(x) - set(y)) , IntegerType())
calculate_unbought_products = udf(lambda x, y: list(set(x) - set(y)) , ArrayType(StringType()))

#convert array to string
convert_array2str = udf(lambda x: ','.join(x) , StringType())

detection_result = grouped_customers    \
    .join(rdd_product_by_group, on=['sector_name', 'country'])  \
    .withColumn('diff_count', calculate_diff('all_product_group', 'product_bought')) \
    .withColumn('unbought_products', calculate_unbought_products('all_product_group', 'product_bought')) \
    .filter('diff_count > 0')


# remove array columns
detection_result = detection_result \
    .withColumn('unbought_products', convert_array2str('unbought_products'))  \
    .withColumn('product_bought', convert_array2str('product_bought'))  \
    .withColumn('all_product_group', convert_array2str('all_product_group')) 



print( detection_result.show(10) )

+--------------------+--------------+--------------------+----------+--------------------+--------------------+----------+--------------------+
|         sector_name|       country|        company_name|company_id|      product_bought|   all_product_group|diff_count|   unbought_products|
+--------------------+--------------+--------------------+----------+--------------------+--------------------+----------+--------------------+
|Metallurgical ind...|        France|  UnitedHealth Group|        49|Commodities,Asset...|Commodities,Expor...|         1|      Export Finance|
| Electrical industry|         Italy|              Nestle|        11|Life and Dammage ...|Life and Dammage ...|         1|    Leverage Finance|
|     Energy industry|        Suisse|Ping An Insurance...|        48|ALD Car Renting a...|ALD Car Renting a...|         1|           Financing|
|     Energy industry|        Suisse|          ExxonMobil|        13|Financing,General...|ALD Car Renting a...|         1|ALD Car Rentin

In [7]:
detection_result.coalesce(1) \
    .write   \
    .format("csv") \
    .mode('overwrite') \
    .option("header", "true") \
    .save('../ressources/data/2_cross_selling_output')

In [16]:
# save result in my postgresql database
mode = "overwrite"
table_name = 'algo_cross_selling_analysis'
url = "jdbc:postgresql://127.0.0.1:5432/financial_opportunities"
properties = {
    "user": "zouhairhajji",
    "password": '',
    "driver": "org.postgresql.Driver"
}
detection_result  \
        .write     \
        .jdbc(url=url, table=table_name, mode=mode, properties=properties)

----------------------------------------
Exception happened during processing of request from ('127.0.0.1', 54835)
Traceback (most recent call last):
  File "/usr/local/Cellar/python/3.7.2/Frameworks/Python.framework/Versions/3.7/lib/python3.7/socketserver.py", line 316, in _handle_request_noblock
    self.process_request(request, client_address)
  File "/usr/local/Cellar/python/3.7.2/Frameworks/Python.framework/Versions/3.7/lib/python3.7/socketserver.py", line 347, in process_request
    self.finish_request(request, client_address)
  File "/usr/local/Cellar/python/3.7.2/Frameworks/Python.framework/Versions/3.7/lib/python3.7/socketserver.py", line 360, in finish_request
    self.RequestHandlerClass(request, client_address, self)
  File "/usr/local/Cellar/python/3.7.2/Frameworks/Python.framework/Versions/3.7/lib/python3.7/socketserver.py", line 720, in __init__
    self.handle()
  File "/usr/local/lib/python3.7/site-packages/pyspark/accumulators.py", line 269, in handle
    poll(accum_u