In [4]:
from pyspark.sql import SparkSession
# Get the notebook's Spark session under the application name "UDFExample"
# (getOrCreate reuses an already-running session when there is one).
spark = (
    SparkSession.builder
    .appName("UDFExample")
    .getOrCreate()
)


In [3]:
# Sample employee records; each tuple follows the order given in `columns`.
columns = ["Name", "Age", "Department", "salary", "bonus"]
data = [
    ("John", 28, "Sales", 20000, 23),
    ("Jane", 33, "Marketing", 3000, 24),
    ("Jake", 29, "Finance", 4000, 25),
    ("Julie", 35, "HR", 5000, 26),
]

# Only column names are supplied, so Spark infers the column types from the values.
df = spark.createDataFrame(data, schema=columns)
df.show()

+-----+---+----------+------+-----+
| Name|Age|Department|salary|bonus|
+-----+---+----------+------+-----+
| John| 28|     Sales| 20000|   23|
| Jane| 33| Marketing|  3000|   24|
| Jake| 29|   Finance|  4000|   25|
|Julie| 35|        HR|  5000|   26|
+-----+---+----------+------+-----+



In [4]:
def total_pay(s, b):
    """Return the total pay: salary ``s`` plus bonus ``b``.

    Args:
        s: Salary value, or ``None`` for a missing value.
        b: Bonus value, or ``None`` for a missing value.

    Returns:
        ``s + b``, or ``None`` when either input is ``None``.
    """
    # When used as a Spark UDF, SQL NULLs arrive as None; propagate the null
    # instead of raising TypeError on `None + int` and failing the whole job.
    if s is None or b is None:
        return None
    return s + b

from pyspark.sql.functions import udf,col
from pyspark.sql.types import IntegerType

# Wrap total_pay as a UDF for the DataFrame API, declaring an integer result.
# The function is passed directly; a forwarding lambda adds nothing.
# Note: udf() only wraps the function. Calling it from SQL text additionally
# requires spark.udf.register(...).
total_payment = udf(total_pay, IntegerType())

# Add the UDF result as a new `totpay` column and display the frame.
df.withColumn('totpay', total_payment(df.salary, df.bonus)).show()




+-----+---+----------+------+-----+------+
| Name|Age|Department|salary|bonus|totpay|
+-----+---+----------+------+-----+------+
| John| 28|     Sales| 20000|   23| 20023|
| Jane| 33| Marketing|  3000|   24|  3024|
| Jake| 29|   Finance|  4000|   25|  4025|
|Julie| 35|        HR|  5000|   26|  5026|
+-----+---+----------+------+-----+------+



In [6]:
@udf(returnType=IntegerType())  # decorator form: the name below becomes a UDF object
def total_pay(s, b):
    """Return the total pay: salary ``s`` plus bonus ``b``.

    Because of the ``@udf`` decorator, the name ``total_pay`` is bound to a
    Spark UDF (declared return type: integer) rather than a plain function,
    replacing any earlier plain ``total_pay`` in the notebook namespace.

    Args:
        s: Salary value for one row, or ``None`` for SQL NULL.
        b: Bonus value for one row, or ``None`` for SQL NULL.

    Returns:
        ``s + b``, or ``None`` when either input is ``None``.
    """
    # Spark hands SQL NULLs to Python as None; return NULL for such rows
    # instead of raising TypeError and failing the job.
    if s is None or b is None:
        return None
    return s + b

# Keep every existing column and append the UDF result under the alias `tpay`.
df.select(
    '*',
    total_pay(df["salary"], df["bonus"]).alias('tpay'),
).show()


+-----+---+----------+------+-----+-----+
| Name|Age|Department|salary|bonus| tpay|
+-----+---+----------+------+-----+-----+
| John| 28|     Sales| 20000|   23|20023|
| Jane| 33| Marketing|  3000|   24| 3024|
| Jake| 29|   Finance|  4000|   25| 4025|
|Julie| 35|        HR|  5000|   26| 5026|
+-----+---+----------+------+-----+-----+



In [None]:
# Expose the DataFrame to Spark SQL under the view name `emps`.
df.createOrReplaceTempView("emps")

# Make the UDF callable from SQL as `tpay`.
# - The UDF object itself is passed, not its name as a string (a string is
#   not callable and is rejected).
# - total_pay was created with @udf(returnType=IntegerType()), so it already
#   carries its return type; register() does not accept returnType for an
#   existing UDF object.
spark.udf.register("tpay", total_pay)

# Run the query through spark.sql so the cell is plain Python (a `%sql` line
# followed by bare SQL does not run inside a Python cell). The view has no
# `id` column, so only columns that exist in `emps` are selected.
spark.sql("SELECT Name, tpay(salary, bonus) AS tpay FROM emps").show()

UDFs with complex data types (arrays)

In [5]:
from pyspark.sql.functions import array, col
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType,IntegerType

# Define a Python function that operates on arrays
def double_elements(arr):
    """Return a new list with every element of ``arr`` multiplied by 2.

    Args:
        arr: Iterable of numbers, or ``None`` for a missing array. Individual
            elements may also be ``None``.

    Returns:
        A list of doubled values with ``None`` elements kept as ``None``,
        or ``None`` when ``arr`` itself is ``None``.
    """
    # When used as a Spark UDF, a NULL array arrives as None; propagate it
    # instead of raising TypeError while iterating.
    if arr is None:
        return None
    # NULL elements inside the array also arrive as None and stay None.
    return [None if x is None else x * 2 for x in arr]

# Wrap double_elements as a UDF whose declared result is an array of integers.
double_elements_udf = udf(double_elements, returnType=ArrayType(IntegerType()))

# Build a one-column DataFrame ("numbers") where each row holds a list of ints.
# NOTE: this rebinds `data`, `columns` and `df`, which the employee example
# earlier in the notebook also uses.
columns = ["numbers"]
data = [
    ([1, 2, 3],),
    ([4, 5, 6],),
]
df = spark.createDataFrame(data, schema=columns)

df.show()

# Add the UDF output next to the input column as "doubled_numbers".
df_transformed = df.withColumn(
    "doubled_numbers",
    double_elements_udf(df["numbers"]),
)

# Display the input and doubled arrays side by side.
df_transformed.show()


+---------+
|  numbers|
+---------+
|[1, 2, 3]|
|[4, 5, 6]|
+---------+

+---------+---------------+
|  numbers|doubled_numbers|
+---------+---------------+
|[1, 2, 3]|      [2, 4, 6]|
|[4, 5, 6]|    [8, 10, 12]|
+---------+---------------+

