In [23]:
from pyspark.sql import SparkSession

# Create (or reuse, via getOrCreate) the SparkSession used throughout this notebook.
spark = (
    SparkSession.builder
    .appName("241211_03_SparkSQL_UDF")
    .getOrCreate()
)


In [24]:
# Sample rows as (product, date, price): dates are "YYYY-MM-DD" strings, prices are ints.
columns = ["product", "date", "price"]
datas = [
    ("A", "2022-04-16", 31200),
    ("B", "2022-04-17", 41200),
    ("C", "2022-04-11", 31500),
    ("D", "2022-04-12", 21500),
    ("E", "2022-04-13", 51000),
]

In [3]:
# Build a Spark DataFrame from the sample rows (column names from `columns`) and print it.
df = spark.createDataFrame(data=datas, schema=columns)
df.show()

[Stage 0:>                                                          (0 + 1) / 1]

+-------+----------+-----+
|product|      date|price|
+-------+----------+-----+
|      A|2022-04-16|31200|
|      B|2022-04-17|41200|
|      C|2022-04-11|31500|
|      D|2022-04-12|21500|
|      E|2022-04-13|51000|
+-------+----------+-----+



                                                                                

### UDF

UDF(User Defined Function, 사용자 정의 함수)

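1. 파이썬 함수로 정의한다.
2. `spark.udf.register(이름, 함수, 반환 타입)`으로 등록한다. 반환 타입을 생략하면 `StringType`이 사용된다.
3. 등록된 UDF는 해당 SparkSession 안에서만 유효하며, 세션이 종료되면 사라진다.
4. SQL 문 안에서 내장 함수처럼 호출해 사용한다.
5. SQL의 NULL 값은 파이썬 함수에 `None`으로 전달되므로, 함수에서 `None`을 처리해야 한다.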

In [4]:
from pyspark.sql.types import LongType

def squared(n):
    """Return the square of ``n``; this notebook registers it as the SQL UDF ``squared``.

    Spark hands SQL NULL values to a Python UDF as ``None``. Those are
    propagated as ``None`` (NULL) instead of failing with ``TypeError``
    on ``None * None``.

    Note: when registered with ``LongType`` (signed 64-bit), squares of values
    with ``abs(n) > 3_037_000_499`` do not fit in the declared result type.
    """
    if n is None:
        return None
    return n * n

In [7]:
# Register the `product` view here so this query does not depend on a later cell
# having been run first (it must survive Restart Kernel -> Run All).
df.createOrReplaceTempView('product')
spark.sql(
    '''
    select * from product
    '''
).show()

+-------+----------+-----+
|product|      date|price|
+-------+----------+-----+
|      A|2022-04-16|31200|
|      B|2022-04-17|41200|
|      C|2022-04-11|31500|
|      D|2022-04-12|21500|
|      E|2022-04-13|51000|
+-------+----------+-----+



In [10]:
# Make `squared` callable from Spark SQL; LongType() declares the UDF's result type.
spark.udf.register(name='squared', f=squared, returnType=LongType())

24/12/11 16:17:06 WARN SimpleFunctionRegistry: The function squared replaced a previously registered function.


<function __main__.squared(n)>

In [11]:
# Expose `df` to Spark SQL under the name `product`, replacing any existing view of that name.
df.createOrReplaceTempView(name='product')

In [13]:
# Apply the `squared` UDF to every price in the `product` view.
(
    spark
    .sql('select price, squared(price) from product')
    .show()
)

+-----+--------------+
|price|squared(price)|
+-----+--------------+
|31200|     973440000|
|41200|    1697440000|
|31500|     992250000|
|21500|     462250000|
|51000|    2601000000|
+-----+--------------+



In [18]:
def read_number(n):
    """Spell out an integer in Sino-Korean numerals, e.g. 123 -> '일백이십삼'.

    Digits are read in groups of four (만, 억, 조, 경), so every value of a
    Spark LongType column is supported. The leading '일' before 십/백/천/만 is
    kept (31200 -> '삼만일천이백') to match this notebook's original output.

    Parameters
    ----------
    n : int or None
        Value to read. ``None`` (how Spark passes SQL NULL to a UDF) yields ``None``.

    Returns
    -------
    str or None
        The Korean reading: ``'영'`` for 0, prefixed with ``'마이너스'`` for
        negative values.

    Raises
    ------
    ValueError
        If ``abs(n) >= 10**20`` (beyond the largest unit, 경).
    """
    if n is None:
        return None
    if n == 0:
        return "영"

    sign = "마이너스" if n < 0 else ""
    remaining = abs(n)

    digit_units = ["", "십", "백", "천"]        # positions inside one 4-digit group
    group_units = ["", "만", "억", "조", "경"]  # one unit per 4-digit group
    nums = '일이삼사오육칠팔구'

    groups = []  # readings of non-zero 4-digit groups, least significant first
    for group_unit in group_units:
        if remaining == 0:
            break
        remaining, chunk = divmod(remaining, 10000)
        if chunk == 0:
            # An all-zero group is silent, and so is its unit (100000000 -> '일억').
            continue
        words = []
        for digit_unit in digit_units:
            chunk, r = divmod(chunk, 10)
            if r > 0:
                words.append(nums[r - 1] + digit_unit)
        groups.append("".join(reversed(words)) + group_unit)

    if remaining:
        raise ValueError(f"read_number only supports abs(n) < 10**20, got {n}")
    return sign + "".join(reversed(groups))

In [19]:
# Sanity-check the helper in plain Python before registering it with Spark.
assert read_number(123) == '일백이십삼'
read_number(123)

'일백이십삼'

In [20]:
# Make `read_number` callable from Spark SQL; 'string' is the DDL name of the default StringType.
spark.udf.register('read_number', read_number, 'string')

<function __main__.read_number(n)>

In [21]:
# Apply the `read_number` UDF to every price in the `product` view and print the result.
readings = spark.sql('select price, read_number(price) from product')
readings.show()

+-----+------------------+
|price|read_number(price)|
+-----+------------------+
|31200|      삼만일천이백|
|41200|      사만일천이백|
|31500|      삼만일천오백|
|21500|      이만일천오백|
|51000|          오만일천|
+-----+------------------+



In [25]:
# Stop the SparkSession; session-scoped UDFs and temp views (e.g. `product`) go away with it.
spark.stop()