In [58]:
from pyspark.sql.types import StructType, StructField, StringType, ArrayType, Row
from functools import reduce

In [59]:
filename = './test_data/events.json'

# NOTE(review): assumes a SparkSession named `spark` already exists in the
# kernel — it is not created in any visible cell.
# multiLine=True lets a single JSON record span several lines (e.g. a
# pretty-printed file) instead of requiring one object per line.
df = spark.read.json(filename, multiLine=True)

# Show the inferred schema, then a preview of the first rows.
df.printSchema()
df.show()

root
 |-- id: string (nullable = true)
 |-- results: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- term: string (nullable = true)
 |-- transaction_id: string (nullable = true)
 |-- type: string (nullable = true)

+------+--------------------+-----------+--------------------+--------------+
|    id|             results|       term|      transaction_id|          type|
+------+--------------------+-----------+--------------------+--------------+
|  null|                null|      omnis|9e4bdbef-86fe-4ac...|        SEARCH|
|  null|[X00034, X00084, ...|       null|9e4bdbef-86fe-4ac...|SEARCH_RESULTS|
|X00022|                null|       null|9e4bdbef-86fe-4ac...|   REPORT_VIEW|
|X00084|                null|       null|9e4bdbef-86fe-4ac...|   REPORT_VIEW|
|  null|                null|      quasi|067bf59a-2f53-459...|        SEARCH|
|  null|[X00042, X00031, ...|       null|067bf59a-2f53-459...|SEARCH_RESULTS|
|  null|                null|   sapiente|e0bf0ee2-3b0a

In [60]:
# Explicit schema for one event record, matching the schema Spark inferred
# above: every column is a nullable string except `results`, which is a
# nullable array of nullable strings.
event_schema = StructType([
    StructField("id", StringType(), nullable=True),
    StructField("results", ArrayType(StringType(), containsNull=True), nullable=True),
    StructField("term", StringType(), nullable=True),
    StructField("transaction_id", StringType(), nullable=True),
    StructField("type", StringType(), nullable=True),
])

# A transaction id (`t_id`) plus an array of event records.
# NOTE(review): not referenced by any visible cell — presumably meant for a
# per-transaction DataFrame; confirm or remove.
group_schema = StructType([
    StructField("t_id", StringType(), nullable=True),
    StructField("events", ArrayType(event_schema, containsNull=True), nullable=True),
])

In [61]:
# Bucket every event Row by its transaction_id, keep only the buckets
# (dropping the key), and turn each bucket into a plain Python list of Rows.
grouped = (
    df.rdd
    .groupBy(lambda event: event['transaction_id'])
    .values()
    .map(list)
)

print('grouped', type(grouped))

# first() is an action, so this line actually executes the groupBy shuffle.
print('grouped.first()', type(grouped.first()))

def intersection(lst1, lst2):
    """Return the distinct elements that occur in both ``lst1`` and ``lst2``.

    Duplicates are removed, and the order of the returned list is unspecified
    because it is built from a set. Elements must be hashable.
    """
    shared = set(lst1).intersection(set(lst2))
    return list(shared)

def mapper(group):
    """Summarise all events of one transaction into a single Row.

    Parameters
    ----------
    group : list of Row
        The event Rows that share one ``transaction_id`` (one bucket from the
        grouping cell). It must be non-empty because the first element
        supplies the transaction id.

    Returns
    -------
    Row
        ``transaction_id``: taken from the first event in the group.
        ``reports_viewed``: the non-null ``id`` of each event, in group order.
        ``search_results``: every element of every non-null ``results``
        array, flattened in group order.
        ``success``: True if at least one viewed report id also appears
        among the search results.
    """
    transaction_id = group[0]['transaction_id']
    reports_viewed = [row['id'] for row in group if row['id'] is not None]
    search_results = [
        result
        for row in group
        if row['results'] is not None
        for result in row['results']
    ]
    # isdisjoint stops at the first common element and never builds the
    # intermediate intersection list just to measure its length.
    success = not set(reports_viewed).isdisjoint(search_results)
    return Row(
        transaction_id=transaction_id,
        reports_viewed=reports_viewed,
        search_results=search_results,
        success=success,
    )

# map() is lazy; collect() is the action that runs the job and returns every
# summary Row to the driver as a Python list, so the whole result must fit in
# driver memory.
mapped = grouped.map(mapper)
res = mapped.collect()

# One summary Row per line.
if res:
    print("\n".join(str(summary) for summary in res))

grouped <class 'pyspark.rdd.PipelinedRDD'>
grouped.first() <class 'list'>
Row(transaction_id='9e4bdbef-86fe-4ace-aa61-d29a64bbb5f4', reports_viewed=['X00022', 'X00084'], search_results=['X00034', 'X00084', 'X00009', 'X00071'], success=True)
Row(transaction_id='067bf59a-2f53-459b-8ad0-2cb4a9639519', reports_viewed=[], search_results=['X00042', 'X00031', 'X00095'], success=False)
Row(transaction_id='e0bf0ee2-3b0a-4095-9483-28aecd5e1437', reports_viewed=[], search_results=[], success=False)
Row(transaction_id='507e386e-d29a-4063-9621-d51dc91063cf', reports_viewed=[], search_results=['X00097'], success=False)
Row(transaction_id='029c1fe7-722a-45f3-86f8-a88d57057567', reports_viewed=[], search_results=[], success=False)
Row(transaction_id='f92a19c9-7918-40b1-9cbd-0d066b115326', reports_viewed=[], search_results=['X00087', 'X00030', 'X00045'], success=False)
Row(transaction_id='6d030326-b720-433c-8e96-03b7a3f3cf15', reports_viewed=[], search_results=[], success=False)
Row(transaction_id='418