In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit, udf
from pyspark.sql.functions import sum,avg,max,min,mean,count
from pyspark.sql.types import IntegerType, DoubleType
# Obtain a SparkSession for this notebook; getOrCreate reuses an active one if present.
spark = (
    SparkSession.builder
    .appName("Spark DataFrames")
    .getOrCreate()
)

In [0]:
# Load the office dataset: first line is a header, comma-separated, column types inferred.
csv_options = {"inferSchema": "True", "header": "True", "delimiter": ","}
df = spark.read.options(**csv_options).csv("/FileStore/tables/OfficeDataProject.csv")
df.show()
df.printSchema()

In [0]:
# 1. Print the total number of employees in the company
# Headcount == row count (assumes one row per employee).
total_employees = df.count()
total_employees

In [0]:
# 2. Print the total number of departments in the company
# Number of distinct values in the `department` column.
df.select(col("department")).distinct().count()

In [0]:
# 3. Print the department names of the company
departments = df.select("department").distinct()
departments.show()

In [0]:
# 4. Print the total number of employees in each department
# GroupedData.count() yields a column named "count"; rename it for readability.
(df.groupBy("department")
    .count()
    .withColumnRenamed("count", "total_employees")
    .show())

In [0]:
# 5. Print the total number of employees in each state
per_state = df.groupBy("state").count().withColumnRenamed("count", "total_employees")
per_state.show()

In [0]:
# 6. Print the total number of employees in each state in each department
per_state_dept = (
    df.groupBy("state", "department")
    .count()
    .withColumnRenamed("count", "total_employees")
)
per_state_dept.show()

In [0]:
# 7. Print the minimum and maximum salaries in each department and sort salaries in ascending order
# `min`/`max` here are the pyspark.sql.functions aggregates imported at the top
# (they shadow the Python builtins in this notebook).
salary_range = df.groupBy("department").agg(
    min("salary").alias("min_salary"),
    max("salary").alias("max_salary"),
)
# orderBy on column names sorts ascending by default: min_salary first, ties by max_salary.
salary_range.orderBy("min_salary", "max_salary").show()

In [0]:
# 8. Print the names of employees working in NY state under Finance department whose bonuses are greater than the average bonuses of employees in NY state
def state_avg_salary(state, column="salary"):
    """Return the average of `column` over the rows of `df` whose `state` equals `state`.

    Args:
        state: Value matched against the `state` column (e.g. "NY").
        column: Numeric column to average. Defaults to "salary" so existing
            calls behave as before; another numeric column (e.g. a bonus
            column, if the dataset has one -- TODO confirm) may be passed.

    Returns:
        The average as a Python number, or None when no row matches `state`
        (or every matching value is null).
    """
    # A global aggregation (no groupBy) always produces exactly one row, so
    # first() is never None. The previous groupBy(...).collect()[0][0] raised
    # IndexError for a state with no employees.
    row = (
        df.filter(col("state") == state)
        .agg(avg(col(column)).alias("avg_value"))
        .first()
    )
    return row["avg_value"]

In [0]:
# Average salary across NY employees (quick sanity check).
state_avg_salary(state="NY")

In [0]:
# Q8 is about *bonuses*: compare each NY Finance employee's bonus against the
# average bonus of all NY employees (previously this compared salaries).
# NOTE(review): assumes the CSV has a numeric `bonus` column -- check df.printSchema().
# If NY has no rows the average is None; comparing against null matches nothing.
ny_avg_bonus = df.filter(col("state") == "NY").agg(avg(col("bonus"))).first()[0]
df.filter(
    (col("state") == "NY")
    & (col("department") == "Finance")
    & (col("bonus") > ny_avg_bonus)
).show()

In [0]:
# 9. Raise the salaries $500 of all employees whose age is greater than 45
def raise_500(x, y, min_age=45, amount=500):
    """Return salary `y` raised by `amount` when age `x` is strictly greater than `min_age`.

    Args:
        x: Employee age; may be None (a null in the DataFrame).
        y: Current salary; may be None.
        min_age: Exclusive age threshold (default 45, per the task).
        amount: Raise to add (default 500).

    Returns:
        `y + amount` for qualifying rows, otherwise `y` unchanged.
    """
    # Nulls pass through untouched: `None > 45` / `None + 500` would raise
    # TypeError, and inside a Spark UDF that fails the whole job.
    if x is None or y is None:
        return y
    return y + amount if x > min_age else y
  
# Wrap the plain-Python rule as a Spark UDF; passing the function directly is
# equivalent to the former `lambda x, y: raise_500(x, y)` wrapper.
# NOTE(review): IntegerType assumes `salary` is inferred as an integer column;
# if the schema shows a double, Spark turns the mismatched return value into null.
raise_500UDF = udf(raise_500, IntegerType())

In [0]:
# Apply the raise in a new DataFrame (df keeps the original salaries),
# then display just the employees older than 45.
age_col, salary_col = col("age"), col("salary")
df2 = df.withColumn("salary", raise_500UDF(age_col, salary_col))
df2.where(age_col > 45).show()

In [0]:
# Same employees from the unmodified df, for comparing salaries before the raise.
df.where(col("age") > 45).show()

In [0]:
# 10. Create a DF of all those employees whose age is greater than 45 and save them in a file
# Rows come from df2, i.e. with the raised salaries; "overwrite" replaces any
# previous output at the target path, and a header row is written.
over_45_df = df2.filter(col("age") > 45)
(over_45_df.write
    .mode("overwrite")
    .options(header="True")
    .csv("/FileStore/tables/MiniProject_DF/output"))