In [199]:
# Configuration: relative path of the JSON web-log file that is loaded
# into a Spark DataFrame later in this notebook.
JSON_FILE = "./assets/sample.json"

In [205]:
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

# Start the Spark session for this notebook; getOrCreate() hands back an
# already-running session instead of building a second one.
spark = (
    SparkSession.builder
    .appName("PySpark SQL App")
    .config("", "")
    .getOrCreate()
)

In [227]:
# Read the JSON web log (multiline option enabled) into a DataFrame
# and print a preview of its rows.
web_log_data = (
    spark.read
    .option("multiline", "true")
    .json(JSON_FILE)
)
web_log_data.show()

+-------------+---------+-------+
|       domain|timestamp|user_id|
+-------------+---------+-------+
|    apple.com|111111110| 123456|
|   google.com|010111110| 123456|
| facebook.com|010101111| 123456|
|       amazon|010101111| 123456|
|microsoft.com|010111111| 123456|
|    apple.com|110101010| 234567|
|   google.com|011101010| 234567|
|  netflix.com|010111110| 234567|
| telegram.com|010101111|    431|
|  pokemon.com|010101110|    431|
|  digimon.com|111101010|    431|
+-------------+---------+-------+



In [229]:
# Distinct values of the `domain` column, one row per domain name
unique_domains = web_log_data.select("domain").distinct()

In [207]:
# One row per user_id, carrying a `web_log` struct with two parallel
# lists: the domains and the timestamps collected for that user.
# NOTE(review): the two lists are collected independently - confirm they
# stay aligned if either column can contain nulls.
grouped_web_log_data = (
    web_log_data
    .groupBy("user_id")
    .agg(
        F.struct(
            F.collect_list(web_log_data["domain"]).alias("domains"),
            F.collect_list(web_log_data["timestamp"]).alias("timestamps"),
        ).alias("web_log")
    )
)

In [225]:
import pyspark.sql.types as T
from functools import cmp_to_key, reduce
from operator import mul
from typing import List

# Scale a vector to unit (L2) length
def get_normalized(vector):
    """Return ``vector`` scaled so that its Euclidean (L2) norm is 1.

    Args:
        vector: Iterable of numbers.

    Returns:
        A new list with every element divided by the vector's L2 norm.
        An empty input gives ``[]``. A vector whose norm is zero gives a
        list of ``0.0`` of the same length, because it has no direction
        to preserve and dividing by the norm would fail.
    """
    # Materialise once so one-shot iterables can be traversed twice
    values = list(vector)

    # Summing from 0 squares *every* element, including the first one
    norm = sum(x ** 2 for x in values) ** 0.5

    if norm == 0:
        return [0.0] * len(values)

    return [x / norm for x in values]

# Similarity score between two timestamps
def get_timestamp_cosine_similarity(
        timestamp_a,
        timestamp_b
    ):
    """Return the dot product of two timestamps after normalisation.

    Each timestamp is iterated element by element (for a string, character
    by character), every element is converted to ``float``, and the
    resulting vector is passed through ``get_normalized``. The products of
    the paired elements are then summed.

    Args:
        timestamp_a: Iterable whose elements are convertible to ``float``.
        timestamp_b: Iterable whose elements are convertible to ``float``.

    Returns:
        The sum of element-wise products of the two normalised vectors.
        Pairing stops at the shorter of the two inputs.
    """
    digits_a = list(map(float, timestamp_a))
    digits_b = list(map(float, timestamp_b))

    unit_a = get_normalized(digits_a)
    unit_b = get_normalized(digits_b)

    # zip() pairs positions; extra elements of the longer vector are ignored
    return sum(x * y for x, y in zip(unit_a, unit_b))

# Build the per-user domain-to-domain similarity matrix
@F.udf(
    returnType=T.MapType(
        T.StringType(),
        T.MapType(T.StringType(), T.FloatType(), False),
        False,
    )
)
def get_user_correlation_matrix(
        domains: T.ArrayType,
        timestamps: T.ArrayType
    ):
    """Spark UDF: pairwise timestamp similarity between a user's domains.

    The domains and timestamps are paired by position and the pairs are
    sorted. Each domain is then mapped to a nested dict of the domains
    that come after it in sorted order, with the value returned by
    ``get_timestamp_cosine_similarity`` for the two timestamps. The last
    domain in sorted order therefore maps to an empty dict.

    NOTE(review): entries are keyed by domain name, so if the same domain
    appears more than once for a user the later entry replaces the
    earlier one - confirm that is acceptable.

    Args:
        domains: Sequence of domain names for one user.
        timestamps: Sequence of timestamps, positionally matching
            ``domains``.

    Returns:
        ``{domain: {later_domain: similarity, ...}, ...}``
    """
    n_domains = len(domains)

    pairs = [[domain, stamp] for domain, stamp in zip(domains, timestamps)]
    pairs.sort()

    # Correlation matrix per user (only the part above the diagonal)
    matrix = {}
    for row in range(n_domains):
        row_domain, row_stamp = pairs[row]
        matrix[row_domain] = {
            pairs[col][0]: get_timestamp_cosine_similarity(
                row_stamp,
                pairs[col][1]
            )
            for col in range(row + 1, n_domains)
        }

    return matrix
    
# Add a `correlation_matrix` column computed by the UDF from each user's
# collected domains and timestamps.
correlated_web_log_data = grouped_web_log_data.withColumn(
    "correlation_matrix",
    get_user_correlation_matrix(
        grouped_web_log_data.web_log.domains,
        grouped_web_log_data.web_log.timestamps,
    ),
)

# Print the matrices in full (no column truncation)
correlated_web_log_data.select("correlation_matrix").show(truncate=False)

+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|correlation_matrix                                                                                                                                                                                                                                                                                                                                      |
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [197]:
# Stop the Spark session created earlier in this notebook; cells that use
# `spark` need a new session to be created before they can run again.
spark.stop()