In [0]:
from pyspark.sql.functions import *
from pyspark.sql import Window
from pyspark.sql.types import *

def getMaxID(table):
    """Return the largest value of the ``ID`` column of *table*.

    An empty table, or one whose IDs are all null, has no maximum; ``0`` is
    returned in that case so new IDs can be numbered from ``1`` upward.

    Args:
        table: a DataFrame with an ``ID`` column.

    Returns:
        The maximum ``ID``, or ``0`` when there is none.
    """
    # A global aggregation always yields exactly one row.
    summaryRow = table.agg({"ID": "max"}).first()
    highest = summaryRow[0]
    if highest is None:
        return 0
    return highest

def unifyWithIncremendedID(table,appendedTable,columns,maxID):
    """Append *appendedTable* to *table*, de-duplicate, and fill in missing IDs.

    Rows are duplicates when they agree on every column named in *columns*
    (null values compare equal to each other). Within each duplicate group a
    row that already carries an ``ID`` is kept in preference to one whose
    ``ID`` is null, so existing IDs are never reassigned. Every surviving row
    with a null ``ID`` is numbered ``maxID + 1, maxID + 2, ...``.

    Args:
        table: existing DataFrame with an ``ID`` column.
        appendedTable: rows to add; ``union`` matches columns by position, so
            its column order must match *table*'s.
        columns: names of the columns that identify a row.
        maxID: current largest ``ID`` in *table*; new IDs start after it.

    Returns:
        The combined, de-duplicated DataFrame (lazy; nothing is written).
    """
    combined = table.union(appendedTable)

    # dropDuplicates keeps an arbitrary row per group and could discard the
    # row holding an existing ID, renumbering that member. Rank explicitly:
    # non-null IDs first (lowest wins), new null-ID rows only as a fallback.
    dedupWindow = Window.partitionBy(*columns).orderBy(col("ID").asc_nulls_last())
    deduplicated = (combined
                    .withColumn("_dedupRank", row_number().over(dedupWindow))
                    .filter(col("_dedupRank") == 1)
                    .drop("_dedupRank"))

    # With nulls sorted first, row_number() values 1..k fall exactly on the k
    # rows that still lack an ID; rows with an ID keep it through coalesce.
    numberingWindow = Window.orderBy(col("ID").asc_nulls_first())
    return deduplicated.withColumn(
        "ID",
        coalesce(col("ID"), row_number().over(numberingWindow) + lit(maxID)),
    )

def readTable(fileName):
    """Read ``<fileName>.parquet`` from the level-3 directory and cache it.

    Depends on the notebook-level globals ``spark``, ``container``,
    ``storageAccountName`` and ``dirs`` (``dirs["lvl3"]`` is the directory
    prefix inside the container).

    Args:
        fileName: table name without the ``.parquet`` extension.

    Returns:
        The DataFrame, marked with ``cache()``.
    """
    path = ("abfss://" + container + "@" + storageAccountName
            + ".dfs.core.windows.net/" + dirs["lvl3"] + fileName + ".parquet")
    # "header" and "inferSchema" are CSV reader options; Parquet files carry
    # their own schema and the Parquet source ignores them, so none are set.
    # NOTE(review): callers in this file later overwrite this same path; the
    # cache is what keeps the data readable afterwards - confirm it is fully
    # materialized (and not evicted) before the overwrite.
    return (spark.read.format("parquet")
            .load(path)
            .cache()
           )
    
def loadTable(table,fileName):
    """Write *table* to ``<fileName>.parquet`` in the level-3 directory.

    Any existing data at that path is replaced (``mode("overwrite")``).
    Depends on the notebook-level globals ``container``,
    ``storageAccountName`` and ``dirs``.

    Args:
        table: DataFrame to persist.
        fileName: table name without the ``.parquet`` extension.

    Returns:
        Whatever ``DataFrameWriter.save`` returns.
    """
    path = ("abfss://" + container + "@" + storageAccountName
            + ".dfs.core.windows.net/" + dirs["lvl3"] + fileName + ".parquet")
    # "header" is a CSV writer option with no effect on Parquet output.
    return (table.write
            .format("parquet")
            .mode("overwrite")
            .save(path))
    
def updateDimensionTable(config,data,database):
    """Merge new members from *data* into the stored ``database`` dimension table.

    The dimension's attribute columns are the ``targetName`` values of the
    *config* entries whose ``targetBase`` equals *database*. Those columns are
    taken from *data* with a null ``ID`` and merged into the stored table by
    ``unifyWithIncremendedID``, which numbers rows lacking an ID after the
    current maximum. The merged table is written back over the stored file.

    Args:
        config: iterable of rows exposing ``asDict()`` with ``targetBase``
            and ``targetName`` keys.
        data: source DataFrame containing the attribute columns.
        database: dimension table name (also the file name it is stored as).

    Returns:
        The merged dimension DataFrame.
    """
    existing = readTable(database)
    currentMax = getMaxID(existing)

    dimensionColumns = []
    for entry in config:
        fields = entry.asDict()
        if fields["targetBase"] == database:
            dimensionColumns.append(fields["targetName"])

    # New candidate members start without an ID; IDs are assigned on merge.
    candidates = data.withColumn("ID", lit(None).cast(LongType()))
    candidates = candidates.select(["ID"] + dimensionColumns).cache()

    merged = unifyWithIncremendedID(existing, candidates, dimensionColumns, currentMax)
    loadTable(merged, database)
    return merged

def updateFactsTable(config,data,databases,dimTables,configFileName):
    """Map each row of *data* to dimension IDs and merge into the facts table.

    For every ``(database, dimTable)`` pair, *data* is inner-joined to
    *dimTable* on the columns configured for that database (``targetName`` of
    the *config* entries whose ``targetBase`` matches). The dimension's
    ``ID`` column is then renamed to ``<database>_ID``. Rows with no match in
    some dimension are dropped by the inner join. The resulting
    ``<database>_ID`` columns are merged into the facts table stored under
    the part of *configFileName* before the first ``.``. Rows without an ID
    are numbered after the current maximum, and the result is written back.

    Args:
        config: iterable of rows exposing ``asDict()`` with ``targetBase``
            and ``targetName`` keys.
        data: source DataFrame with the attribute columns of every dimension.
        databases: dimension names, in the same order as *dimTables*.
        dimTables: dimension DataFrames, each with an ``ID`` column.
        configFileName: name whose stem is the facts table name.

    Returns:
        The merged facts DataFrame.
    """
    tableName = configFileName.split(".")[0]
    outputTable = readTable(tableName)
    maxID = getMaxID(outputTable)

    factsTable = data
    for database, dimTable in zip(databases, dimTables):
        keyColumns = [conf.asDict()["targetName"] for conf in config
                      if conf.asDict()["targetBase"] == database]
        # Null-safe equality: dimension tables built from the same data keep
        # members whose attribute values are null, but a plain == never
        # matches null, so those fact rows were silently lost by the join.
        cond = [factsTable[name].eqNullSafe(dimTable[name]) for name in keyColumns]
        factsTable = (factsTable.join(dimTable, cond, "inner")
                      .withColumnRenamed("ID", database + "_ID"))

    selectedColumns = [database + "_ID" for database in databases]
    factsTable = (factsTable.withColumn("ID", lit(None).cast(LongType()))
                            .select(["ID"] + selectedColumns)
                            .cache()
                 )
    outputTable = unifyWithIncremendedID(outputTable, factsTable, selectedColumns, maxID)
    loadTable(outputTable, tableName)
    return outputTable