In [1]:
import os, sys
# Environment for the Spark/YARN client, applied in one call before any
# pyspark import: Hadoop and YARN config directories, the Python executable
# names for PySpark (worker and driver), and the Hadoop user name.
# NOTE(review): the user name is hard-coded to one account — confirm this is
# intended before sharing the notebook.
os.environ.update({
    'HADOOP_CONF_DIR': '/etc/hadoop/conf',
    'YARN_CONF_DIR': '/etc/hadoop/conf',
    'PYSPARK_PYTHON': 'python3.9',
    'PYSPARK_DRIVER_PYTHON': 'python3.9',
    'HADOOP_USER_NAME': 'ssenigov',
})

from pyspark.sql import SparkSession
from pyspark import SparkContext, SparkConf

In [2]:
# Describe the application first: its name and the YARN master.
conf = SparkConf()
conf.setAppName('DistributeSum')
conf.setMaster('yarn')

# Reuse an already running session if there is one, otherwise start a new one.
spark = (
    SparkSession.builder
    .config(conf=conf)
    .getOrCreate()
)

# Print the application id, with the label padded to 40 characters.
print("app_id".ljust(40), spark.sparkContext.applicationId)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/01/07 21:33:10 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/01/07 21:33:13 WARN DomainSocketFactory: The short-circuit local reads feature cannot be used because libhadoop cannot be loaded.
25/01/07 21:33:13 WARN Client: Neither spark.yarn.jars nor spark.yarn.archive is set, falling back to uploading libraries under SPARK_HOME.


app_id                                   application_1727681258360_0100


In [3]:
# Managers: ids 1..7 (the upper bound of spark.range is exclusive),
# exposed to SQL as the temp view "managers".
df_managers = (
    spark.range(1, 8)
    .withColumnRenamed('id', 'manager_id')
)
df_managers.createOrReplaceTempView("managers")

# Total plan: a single-row frame holding the amount to distribute,
# exposed to SQL as the temp view "total_plan".
df_total_plan = spark.createDataFrame([{'total_plan': 100}])
df_total_plan.createOrReplaceTempView("total_plan")

# Show both inputs, managers first.
df_managers.show()
df_total_plan.show()

                                                                                

+----------+
|manager_id|
+----------+
|         1|
|         2|
|         3|
|         4|
|         5|
|         6|
|         7|
+----------+



[Stage 1:>                                                          (0 + 1) / 1]

+----------+
|total_plan|
+----------+
|       100|
+----------+



                                                                                

In [4]:
# Approach 1 - round-robin over unit "lines".
# The total plan is exploded into one row per unit (1..total_plan); each unit
# is assigned to slot (line_plan - 1) % manager_cnt + 1, and the units are then
# counted per slot.
# NOTE: the resulting manager_id is a computed 1-based slot number, not a value
# read from managers.manager_id; the two coincide here only because the
# managers view holds the ids 1..N.
# Changes versus the naive form:
#   * `where total_plan >= 1` - sequence(1, n) counts DOWN when n < 1
#     (e.g. [1, 0] for n = 0), which would create bogus plan lines.
#   * explicit `cross join` - the one-row manager count is deliberately paired
#     with every line, without relying on implicit cartesian joins being enabled.
sql = """
 with cte_man_cnt as (
   select count(1) manager_cnt
    from managers),
 cte_lines_plan as (
   select explode(sequence(1, total_plan)) line_plan
    from total_plan
   where total_plan >= 1),
 cte_plan_to_manager as (
   select line_plan, manager_cnt, (line_plan-1)%manager_cnt + 1 as manager_id
    from cte_lines_plan cross join cte_man_cnt)
 select manager_id, count(1) manager_plan
  from cte_plan_to_manager
 group by manager_id
 order by manager_id
"""
spark.sql(sql).show()



+----------+------------+
|manager_id|manager_plan|
+----------+------------+
|         1|          15|
|         2|          15|
|         3|          14|
|         4|          14|
|         5|          14|
|         6|          14|
|         7|          14|
+----------+------------+



                                                                                

In [5]:
# Approach 2 - arithmetic split, no row explosion.
# Every manager receives floor(total_plan / manager_cnt); the remainder (diff)
# is handed out one unit each to the first `diff` managers by manager_id order.
# Changes versus the original form:
#   * the CTEs are named cte_managers / cte_plan (with an explicit `as`), so the
#     first CTE no longer shadows the `managers` temp view it selects from and
#     the naming matches the other queries in this notebook;
#   * the pairing of the single total_plan row with every manager is an
#     explicit `cross join`.
# The output columns are unchanged: manager_id, total_plan, diff, rn, manager_plan.
sql = """
 with cte_managers as (
  select manager_id, count(1) over(partition by 1) manager_cnt
   from managers ),
  cte_plan as (
  select manager_id, total_plan, floor(total_plan/manager_cnt) pre_manager_plan,
   total_plan - floor(total_plan/manager_cnt)*manager_cnt diff,
   row_number() over (partition by 1 order by manager_id) rn
  from total_plan cross join cte_managers )
 select manager_id, total_plan, diff, rn,
   pre_manager_plan + (case when rn<=diff then 1 else 0 end) manager_plan
 from cte_plan
 order by manager_id
 """
spark.sql(sql).show()

+----------+----------+----+---+------------+
|manager_id|total_plan|diff| rn|manager_plan|
+----------+----------+----+---+------------+
|         1|       100|   2|  1|          15|
|         2|       100|   2|  2|          15|
|         3|       100|   2|  3|          14|
|         4|       100|   2|  4|          14|
|         5|       100|   2|  5|          14|
|         6|       100|   2|  6|          14|
|         7|       100|   2|  7|          14|
+----------+----------+----+---+------------+



In [6]:
# Approach 3 - let NTILE split the unit lines into manager_cnt buckets.
# The bucket count has to be a constant inside the SQL text (it cannot be read
# from a column), so the number of managers is counted on the driver first and
# interpolated into the query. The value is an integer returned by count(),
# not user input.
manager_cnt = spark.table('managers').count()

# `where total_plan >= 1` guards against sequence(1, n) counting down when
# n < 1 (e.g. [1, 0] for n = 0), which would create bogus plan lines.
sql = f"""
 with cte_lines_plan as (
  select explode(sequence(1, total_plan)) line_plan
   from total_plan
  where total_plan >= 1),
 cte_man_plan as (
  select line_plan,
    NTILE({manager_cnt}) over(partition by 1 order by line_plan) manager_id
    /* NTILE needs a constant bucket count; it cannot be taken from a column */
   from cte_lines_plan)
 select manager_id, count(1)
   from cte_man_plan
 group by manager_id
 order by manager_id
"""
spark.sql(sql).show()

+----------+--------+
|manager_id|count(1)|
+----------+--------+
|         1|      15|
|         2|      15|
|         3|      14|
|         4|      14|
|         5|      14|
|         6|      14|
|         7|      14|
+----------+--------+



In [7]:
# Stop the session so the YARN application does not keep holding cluster
# resources after the notebook has finished. Re-running the session cell
# (getOrCreate) starts a new one.
spark.stop()