In [3]:
import org.apache.spark.sql.functions._
import org.apache.spark.storage.StorageLevel

spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
spark.conf.set("spark.sql.analyzer.failAmbiguousSelfJoin", "false")
spark.conf.set("spark.sql.shuffle.partitions", "4")

val users = spark.read.option("header", "true")
                        .option("inferSchema", "true")
                        .csv("/home/iceberg/data/events.csv")
                        .where($"user_id".isNotNull)

users.createOrReplaceTempView("events")

val devices = spark.read.option("header", "true")
                        .option("inferSchema", "true")
                        .csv("/home/iceberg/data/devices.csv")

devices.createOrReplaceTempView("devices")

val executionDate = "2023-01-01"

//Caching here should be < 5 GBs or used for broadcast join
//You need to tune executor memory otherwise it'll spill to disk and be slow
//Don't really try using any of the other StorageLevel besides MEMORY_ONLY

val eventsAggregated = spark.sql(f"""
  SELECT user_id, 
          device_id, 
        COUNT(1) as event_counts, 
        COLLECT_LIST(DISTINCT host) as host_array
  FROM events
  GROUP BY 1,2
""").cache()

// eventsAggregated.write.mode("overwrite").saveAsTable("bootcamp.events_aggregated_staging")

spark.sql(f"""
    CREATE TABLE IF NOT EXISTS bootcamp.events_aggregated_staging (
        user_id BIGINT,
        device_id BIGINT,
        event_counts BIGINT,
        host_array ARRAY<STRING>
    )
    PARTITIONED BY (ds STRING)
""")


val usersAndDevices = users
  .join(eventsAggregated, eventsAggregated("user_id") === users("user_id"))
  .groupBy(users("user_id"))
  .agg(
    users("user_id"),
    max(eventsAggregated("event_counts")).as("total_hits"),
    collect_list(eventsAggregated("device_id")).as("devices")
  )

val devicesOnEvents = devices
      .join(eventsAggregated, devices("device_id") === eventsAggregated("device_id"))
      .groupBy(devices("device_id"), devices("device_type"))
      .agg(
        devices("device_id"),
        devices("device_type"),
         collect_list(eventsAggregated("user_id")).as("users")
      )

devicesOnEvents.explain()
usersAndDevices.explain()

devicesOnEvents.take(1)
usersAndDevices.take(1)

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- ObjectHashAggregate(keys=[device_id#598, device_type#601], functions=[collect_list(user_id#568, 0, 0)])
   +- ObjectHashAggregate(keys=[device_id#598, device_type#601], functions=[partial_collect_list(user_id#568, 0, 0)])
      +- Project [device_id#598, device_type#601, user_id#568]
         +- SortMergeJoin [device_id#598], [device_id#569], Inner
            :- Sort [device_id#598 ASC NULLS FIRST], false, 0
            :  +- Exchange hashpartitioning(device_id#598, 4), ENSURE_REQUIREMENTS, [plan_id=735]
            :     +- Filter isnotnull(device_id#598)
            :        +- FileScan csv [device_id#598,device_type#601] Batched: false, DataFilters: [isnotnull(device_id#598)], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/home/iceberg/data/devices.csv], PartitionFilters: [], PushedFilters: [IsNotNull(device_id)], ReadSchema: struct<device_id:int,device_type:string>
            +- Sort [device_id#569 ASC NULLS FIRS

import org.apache.spark.sql.functions._
import org.apache.spark.storage.StorageLevel
users: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [user_id: int, device_id: int ... 4 more fields]
devices: org.apache.spark.sql.DataFrame = [device_id: int, browser_type: string ... 2 more fields]
executionDate: String = 2023-01-01
eventsAggregated: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [user_id: int, device_id: int ... 2 more fields]
usersAndDevices: org.apache.spark.sql.DataFrame = [user_id: int, user_id: int ... 2 more fields]
devicesOnEvents: org.apache.spark.sql.DataFrame = [device_id: int, device_type: string ... 3 more fields]
res1: Array[org.apache.spark.sql.Row] = Array([-2147470439,-2147470439,3,WrappedArray(378988111, 378988111, 378988111)])


In [19]:
eventsAggregated.unpersist()

res18: eventsAggregated.type = [user_id: int, device_id: int ... 2 more fields]
