In [1]:
from pyspark.sql import SparkSession

In [2]:
spark = SparkSession. \
builder. \
appName("SparkTransformations"). \
config("spark.sql.warehouse.dir","/user/itv012857/warehourse"). \
enableHiveSupport(). \
master("yarn"). \
getOrCreate()

In [3]:
spark

<h3> 1. Create a dataframe with the following data </h3>

[("Spring",12.3),("Summer",10.5),("Autumn",8.2),("Winter",15.1)]

In [4]:
windspeed_list = [("Spring",12.3),
        ("Summer",10.5),
        ("Autumn",8.2),
        ("Winter",15.1)]

In [5]:
windspeed_list

[('Spring', 12.3), ('Summer', 10.5), ('Autumn', 8.2), ('Winter', 15.1)]

In [6]:
windspeed_df = spark.createDataFrame(windspeed_list).toDF("season","windspeed")

In [7]:
windspeed_df.show()

+------+---------+
|season|windspeed|
+------+---------+
|Spring|     12.3|
|Summer|     10.5|
|Autumn|      8.2|
|Winter|     15.1|
+------+---------+



In [8]:
windspeed_df.printSchema()

root
 |-- season: string (nullable = true)
 |-- windspeed: double (nullable = true)



<H4> Alternate way is to define the schema and then pass the schema while creating the DataFrame

In [9]:
windspeed_schema = "season string, windspeed float"

In [10]:
windspeed_df2 = spark.createDataFrame(windspeed_list,schema = windspeed_schema)

In [11]:
windspeed_df2.show()

+------+---------+
|season|windspeed|
+------+---------+
|Spring|     12.3|
|Summer|     10.5|
|Autumn|      8.2|
|Winter|     15.1|
+------+---------+



In [12]:
windspeed_df2.printSchema()

root
 |-- season: string (nullable = true)
 |-- windspeed: float (nullable = true)



<h4> 2. Consider the library management dataset located at the following path: /public/trendytech/datasets/library_data.json. Using PySpark, load the data into a DataFrame and enforce schema using StructType.

In [13]:
! hadoop fs -head /public/trendytech/datasets/library_data.json

{"library_name": "Central Library","location": "City Center","books": [{"book_id": "B001","book_name": "The Great Gatsby","author": "F. Scott Fitzgerald","copies_available": 5},{"book_id": "B002","book_name": "To Kill a Mockingbird","author": "Harper Lee","copies_available": 3}],"members": [{"member_id": "M001","member_name": "John Smith","age": 28,"books_borrowed": ["B001"]},{"member_id": "M002","member_name": "Emma Johnson","age": 35,"books_borrowed": []}]},
{"library_name": "Community Library","location": "Suburb","books": [{"book_id": "B003","book_name": "1984","author": "George Orwell","copies_available": 2},{"book_id": "B004","book_name": "Pride and Prejudice","author": "Jane Austen","copies_available": 4}],"members": [{"member_id": "M003","member_name": "Michael Brown","age": 42,"books_borrowed": ["B003","B004"]},{"member_id": "M004","member_name": "Sophia Davis","age": 31,"books_borrowed": ["B004"]}]}


In [14]:
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, ArrayType

In [15]:
library_schema = StructType(
                            [  StructField("library_name",StringType(),nullable=False),
                               StructField("location",StringType(),nullable=False),
                               StructField("books",ArrayType(StructType(
                                                                [StructField("book_id",StringType(),nullable=False),
                                                                StructField("book_name",StringType(),nullable=False),
                                                                StructField("author",StringType(),nullable=False),
                                                                StructField("copies_available",IntegerType(),nullable=False)]
                                                              ))
                                               , nullable=False
                                          ),
                             StructField("members",ArrayType(StructType(
                                                                [StructField("member_id",StringType(),nullable=False),
                                                                StructField("member_name",StringType(),nullable=False),
                                                                StructField("age",IntegerType(),nullable=False),
                                                                StructField("books_borrowed",ArrayType(StringType()),nullable=False)]
                                                              ))
                                               , nullable=False)
                            ]
                            )

In [16]:
library_df = spark.read.json("/public/trendytech/datasets/library_data.json", schema=library_schema)

In [17]:
library_df.show(truncate = False)

+-----------------+-----------+------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------+
|library_name     |location   |books                                                                                           |members                                                                    |
+-----------------+-----------+------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------+
|Central Library  |City Center|[{B001, The Great Gatsby, F. Scott Fitzgerald, 5}, {B002, To Kill a Mockingbird, Harper Lee, 3}]|[{M001, John Smith, 28, [B001]}, {M002, Emma Johnson, 35, []}]             |
|Community Library|Suburb     |[{B003, 1984, George Orwell, 2}, {B004, Pride and Prejudice, Jane Austen, 4}]                   |[{M003, Michael Brown, 42, [B003, B004]}, {M004, Sop

In [18]:
library_df.printSchema()

root
 |-- library_name: string (nullable = true)
 |-- location: string (nullable = true)
 |-- books: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- book_id: string (nullable = true)
 |    |    |-- book_name: string (nullable = true)
 |    |    |-- author: string (nullable = true)
 |    |    |-- copies_available: integer (nullable = true)
 |-- members: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- member_id: string (nullable = true)
 |    |    |-- member_name: string (nullable = true)
 |    |    |-- age: integer (nullable = true)
 |    |    |-- books_borrowed: array (nullable = true)
 |    |    |    |-- element: string (containsNull = true)



<H3> 3. Given the dataset (/public/trendytech/datasets/train.csv), create a DataFrame using PySpark and perform the following operations: </H3>

a) Drop the columns passenger_name and age from the dataset.

b) Count the number of rows after removing duplicates of columns train_number and ticket_number.

c) Count the number of unique train names.

In [19]:
! hadoop fs -head /public/trendytech/datasets/train.csv

train_number,train_name,seats_available,passenger_name,age,ticket_number,seat_number
123,Express,100,John,25,T123,A1
123,Express,100,Emma,30,T124,B2
456,Superfast,150,Michael,35,T125,C3
456,Superfast,150,Sophia,40,T126,D4
789,Local,50,William,28,T127,E5
789,Local,50,Sophia,32,T128,F6
789,Local,50,Oliver,45,T129,G7


In [20]:
train_schema = "train_number int, train_name string, seats_available int,passenger_name string, age int,ticket_number string,seat_number string"

In [21]:
train_df = spark.read.csv("/public/trendytech/datasets/train.csv",header = True, schema= train_schema)

In [22]:
train_df.show()

+------------+----------+---------------+--------------+---+-------------+-----------+
|train_number|train_name|seats_available|passenger_name|age|ticket_number|seat_number|
+------------+----------+---------------+--------------+---+-------------+-----------+
|         123|   Express|            100|          John| 25|         T123|         A1|
|         123|   Express|            100|          Emma| 30|         T124|         B2|
|         456| Superfast|            150|       Michael| 35|         T125|         C3|
|         456| Superfast|            150|        Sophia| 40|         T126|         D4|
|         789|     Local|             50|       William| 28|         T127|         E5|
|         789|     Local|             50|        Sophia| 32|         T128|         F6|
|         789|     Local|             50|        Oliver| 45|         T129|         G7|
+------------+----------+---------------+--------------+---+-------------+-----------+



In [23]:
train_df.printSchema()

root
 |-- train_number: integer (nullable = true)
 |-- train_name: string (nullable = true)
 |-- seats_available: integer (nullable = true)
 |-- passenger_name: string (nullable = true)
 |-- age: integer (nullable = true)
 |-- ticket_number: string (nullable = true)
 |-- seat_number: string (nullable = true)



In [24]:
train_df2 = train_df.drop("passenger_name","age")

In [25]:
train_df2.show()

+------------+----------+---------------+-------------+-----------+
|train_number|train_name|seats_available|ticket_number|seat_number|
+------------+----------+---------------+-------------+-----------+
|         123|   Express|            100|         T123|         A1|
|         123|   Express|            100|         T124|         B2|
|         456| Superfast|            150|         T125|         C3|
|         456| Superfast|            150|         T126|         D4|
|         789|     Local|             50|         T127|         E5|
|         789|     Local|             50|         T128|         F6|
|         789|     Local|             50|         T129|         G7|
+------------+----------+---------------+-------------+-----------+



In [26]:
train_df.count()

7

In [27]:
train_df.dropDuplicates(["train_number","ticket_number"]).count()

7

In [28]:
train_df.select("train_name").distinct().count()

3

<H4> 4. You are working as a Data Engineer in a large retail company. </H4>The company has a dataset named "sales_data.json" that contains sales records from various stores. The dataset is stored in JSON format and may have some corrupt or malformed records due to occasional data quality issues. Your task is to read the "sales_data.json" dataset (/public/trendytech/datasets/sales_data.json) using PySpark, utilizing different read modes to handle corrupt records. You need to create a DataFrame using PySpark and perform the following operations:

    Read the dataset using the "permissive" mode and count the number of records read.

    Read the dataset using the "dropmalformed" mode and display the number of malformed records.

    Read the dataset using the "failfast" mode.

In [29]:
! hadoop fs -cat /public/trendytech/datasets/sales_data.json

{"store_id": 1, "product": "Apple", "quantity": 10, "revenue": 100.0}
{"store_id": 2, "product": "Banana", "quantity": 15, "revenue": 75.0}
{"store_id": 3, "product": "Orange", "quantity": 12, "revenue": 90.0}
{"store_id": 4, "product": "Mango", "quantity": 8, "revenue": 120.0}
{"store_id": 5, "product": "Grape", "quantity": 20, "revenue": 150.0}
{"store_id": 6, "product": "Watermelon", "quantity": 5, "revenue": 50.0}
{"store_id": 7, "product": "Strawberry", "quantity": 18, "revenue": 108.0}
{"store_id": 8, "product": "Pineapple", "quantity": 14, "revenue": 140.0}
{"store_id": 9, "product": "Cherry", "quantity": 7, "revenue": 105.0}
{"store_id": 10, "product": "Pear", "quantity": 9, "revenue": 81.0}
{"store_id": 11, "product": "Blueberry", "quantity": 11, "revenue": 88.0}
{"store_id": 12, "product": "Kiwi", "quantity": 16, "revenue": 128.0}
{"store_id": 13, "product": "Peach", "quantity": 13, "revenue": 91.0}
{"store_id": 14, "product": "Plum", "quantity": 6, "revenue": 54.0}
{"store_i

In [30]:
sales_schema = "store_id int, product string, quantity int, revenue float"

In [31]:
sales_permissive = spark.read.json("/public/trendytech/datasets/sales_data.json", \
                                   schema= sales_schema, \
                                   mode = "permissive"
                                  )

In [32]:
sales_permissive.count()

22

In [33]:
sales_permissive.show(22)

+--------+----------+--------+-------+
|store_id|   product|quantity|revenue|
+--------+----------+--------+-------+
|       1|     Apple|      10|  100.0|
|       2|    Banana|      15|   75.0|
|       3|    Orange|      12|   90.0|
|       4|     Mango|       8|  120.0|
|       5|     Grape|      20|  150.0|
|       6|Watermelon|       5|   50.0|
|       7|Strawberry|      18|  108.0|
|       8| Pineapple|      14|  140.0|
|       9|    Cherry|       7|  105.0|
|      10|      Pear|       9|   81.0|
|      11| Blueberry|      11|   88.0|
|      12|      Kiwi|      16|  128.0|
|      13|     Peach|      13|   91.0|
|      14|      Plum|       6|   54.0|
|      15|     Lemon|      10|   70.0|
|      16| Raspberry|      17|  136.0|
|      17|   Coconut|       4|   80.0|
|      18|   Avocado|      11|   99.0|
|      19|Blackberry|       8|   64.0|
|      20|         G|    null|    NaN|
|    null|      null|    null|   null|
|      22|Watermelon|       5|   null|
+--------+----------+----

In [34]:
sales_dropmalformed =  spark.read.json("/public/trendytech/datasets/sales_data.json", \
                                   schema= sales_schema, \
                                   mode = "dropmalformed"
                                  )

In [35]:
sales_dropmalformed.count()

21

In [36]:
sales_dropmalformed.show(50)

+--------+----------+--------+-------+
|store_id|   product|quantity|revenue|
+--------+----------+--------+-------+
|       1|     Apple|      10|  100.0|
|       2|    Banana|      15|   75.0|
|       3|    Orange|      12|   90.0|
|       4|     Mango|       8|  120.0|
|       5|     Grape|      20|  150.0|
|       6|Watermelon|       5|   50.0|
|       7|Strawberry|      18|  108.0|
|       8| Pineapple|      14|  140.0|
|       9|    Cherry|       7|  105.0|
|      10|      Pear|       9|   81.0|
|      11| Blueberry|      11|   88.0|
|      12|      Kiwi|      16|  128.0|
|      13|     Peach|      13|   91.0|
|      14|      Plum|       6|   54.0|
|      15|     Lemon|      10|   70.0|
|      16| Raspberry|      17|  136.0|
|      17|   Coconut|       4|   80.0|
|      18|   Avocado|      11|   99.0|
|      19|Blackberry|       8|   64.0|
+--------+----------+--------+-------+



In [37]:
sales_failfast =  spark.read.json("/public/trendytech/datasets/sales_data.json", \
                                   schema= sales_schema, \
                                   mode = "failfast"
                                  )

In [38]:
sales_failfast.show()

Py4JJavaError: An error occurred while calling o173.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 21.0 failed 4 times, most recent failure: Lost task 0.3 in stage 21.0 (TID 422) (w01.itversity.com executor 2): org.apache.spark.SparkException: Malformed records are detected in record parsing. Parse Mode: FAILFAST. To process malformed records as null result, try setting the option 'mode' as 'PERMISSIVE'.
	at org.apache.spark.sql.catalyst.util.FailureSafeParser.parse(FailureSafeParser.scala:70)
	at org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource$.$anonfun$readFile$9(JsonDataSource.scala:144)
	at scala.collection.Iterator$$anon$11.nextCur(Iterator.scala:484)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:490)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
	at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.hasNext(FileScanRDD.scala:93)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:755)
	at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:345)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:898)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:898)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:131)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:497)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1439)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:500)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.sql.catalyst.util.BadRecordException: java.lang.RuntimeException: Failed to parse a value for data type int (current token: VALUE_STRING).
	at org.apache.spark.sql.catalyst.json.JacksonParser.parse(JacksonParser.scala:492)
	at org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource$.$anonfun$readFile$7(JsonDataSource.scala:140)
	at org.apache.spark.sql.catalyst.util.FailureSafeParser.parse(FailureSafeParser.scala:60)
	... 23 more
Caused by: java.lang.RuntimeException: Failed to parse a value for data type int (current token: VALUE_STRING).
	at org.apache.spark.sql.catalyst.json.JacksonParser$$anonfun$failedConversion$1.applyOrElse(JacksonParser.scala:375)
	at org.apache.spark.sql.catalyst.json.JacksonParser$$anonfun$failedConversion$1.applyOrElse(JacksonParser.scala:355)
	at scala.runtime.AbstractPartialFunction.apply(AbstractPartialFunction.scala:38)
	at org.apache.spark.sql.catalyst.json.JacksonParser$$anonfun$$nestedInanonfun$makeConverter$4$1.applyOrElse(JacksonParser.scala:184)
	at org.apache.spark.sql.catalyst.json.JacksonParser$$anonfun$$nestedInanonfun$makeConverter$4$1.applyOrElse(JacksonParser.scala:184)
	at org.apache.spark.sql.catalyst.json.JacksonParser.parseJsonToken(JacksonParser.scala:343)
	at org.apache.spark.sql.catalyst.json.JacksonParser.$anonfun$makeConverter$4(JacksonParser.scala:184)
	at org.apache.spark.sql.catalyst.json.JacksonParser.org$apache$spark$sql$catalyst$json$JacksonParser$$convertObject(JacksonParser.scala:397)
	at org.apache.spark.sql.catalyst.json.JacksonParser$$anonfun$$nestedInanonfun$makeStructRootConverter$3$1.applyOrElse(JacksonParser.scala:96)
	at org.apache.spark.sql.catalyst.json.JacksonParser$$anonfun$$nestedInanonfun$makeStructRootConverter$3$1.applyOrElse(JacksonParser.scala:95)
	at org.apache.spark.sql.catalyst.json.JacksonParser.parseJsonToken(JacksonParser.scala:343)
	at org.apache.spark.sql.catalyst.json.JacksonParser.$anonfun$makeStructRootConverter$3(JacksonParser.scala:95)
	at org.apache.spark.sql.catalyst.json.JacksonParser.$anonfun$parse$2(JacksonParser.scala:467)
	at org.apache.spark.util.Utils$.tryWithResource(Utils.scala:2622)
	at org.apache.spark.sql.catalyst.json.JacksonParser.parse(JacksonParser.scala:462)
	... 25 more

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2258)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2207)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2206)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2206)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1079)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1079)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1079)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2445)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2387)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2376)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:868)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2196)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2217)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2236)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:472)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:425)
	at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:47)
	at org.apache.spark.sql.Dataset.collectFromPlan(Dataset.scala:3696)
	at org.apache.spark.sql.Dataset.$anonfun$head$1(Dataset.scala:2722)
	at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:3687)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$5(SQLExecution.scala:103)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:163)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:90)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:64)
	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3685)
	at org.apache.spark.sql.Dataset.head(Dataset.scala:2722)
	at org.apache.spark.sql.Dataset.take(Dataset.scala:2929)
	at org.apache.spark.sql.Dataset.getRows(Dataset.scala:301)
	at org.apache.spark.sql.Dataset.showString(Dataset.scala:338)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)


<H3> 5. Questions regarding hospital dataset

In [39]:
! hadoop fs -head /public/trendytech/datasets/hospital.csv

patient_id,admission_date,discharge_date,diagnosis,doctor_id,total_cost
1,01-01-2022,2022-01-10,Pneumonia,101,5000.00
2,02-05-2022,2022-02-09,Appendicitis,102,7000.00
3,03-12-2022,2022-03-18,Fractured Arm,103,3500.00
4,04-02-2022,2022-04-08,Heart Attack,104,15000.00
5,05-05-2022,2022-05-07,Influenza,105,2500.00
6,06-10-2022,2022-06-15,Appendicitis,106,8000.00
7,07-20-2022,2022-07-25,Pneumonia,107,5500.00
8,08-25-2022,2022-09-01,Heart Attack,108,20000.00
9,09-15-2022,2022-09-22,Fractured Leg,109,6000.00
10,10-05-2022,2022-10-10,Appendicitis,110,7500.00
11,11-02-2022,2022-11-05,Influenza,111,2800.00
12,12-10-2022,2022-12-18,Pneumonia,112,6000.00
13,01-02-2023,2023-01-09,Heart Attack,113,18000.00
14,02-14-2023,2023-02-18,Appendicitis,114,7200.00
15,03-20-2023,2023-03-28,Fractured Arm,115,3800.00
16,04-05-2023,2023-04-11,Influenza,116,2700.00
17,05-08-2023,2023-05-11,Heart Attack,117,16000.00
18,06-15-2023,2023-06-20,Pneumonia,118,4800.00
19,07-22-2023,2023-07-27,Fractured Leg,119,6500.00


In [48]:
hospital_df = spark.read. \
format("csv"). \
option("header","true"). \
option("inferSchema","true"). \
load("/public/trendytech/datasets/hospital.csv")

In [49]:
hospital_df.show()

+----------+--------------+--------------+-------------+---------+----------+
|patient_id|admission_date|discharge_date|    diagnosis|doctor_id|total_cost|
+----------+--------------+--------------+-------------+---------+----------+
|         1|    01-01-2022|    2022-01-10|    Pneumonia|      101|    5000.0|
|         2|    02-05-2022|    2022-02-09| Appendicitis|      102|    7000.0|
|         3|    03-12-2022|    2022-03-18|Fractured Arm|      103|    3500.0|
|         4|    04-02-2022|    2022-04-08| Heart Attack|      104|   15000.0|
|         5|    05-05-2022|    2022-05-07|    Influenza|      105|    2500.0|
|         6|    06-10-2022|    2022-06-15| Appendicitis|      106|    8000.0|
|         7|    07-20-2022|    2022-07-25|    Pneumonia|      107|    5500.0|
|         8|    08-25-2022|    2022-09-01| Heart Attack|      108|   20000.0|
|         9|    09-15-2022|    2022-09-22|Fractured Leg|      109|    6000.0|
|        10|    10-05-2022|    2022-10-10| Appendicitis|      11

In [50]:
hospital_df.printSchema()

root
 |-- patient_id: integer (nullable = true)
 |-- admission_date: string (nullable = true)
 |-- discharge_date: string (nullable = true)
 |-- diagnosis: string (nullable = true)
 |-- doctor_id: integer (nullable = true)
 |-- total_cost: double (nullable = true)



In [51]:
hospital_df2 = hospital_df.drop("doctor_id")

In [53]:
hospital_df2.show()

+----------+--------------+--------------+-------------+----------+
|patient_id|admission_date|discharge_date|    diagnosis|total_cost|
+----------+--------------+--------------+-------------+----------+
|         1|    01-01-2022|    2022-01-10|    Pneumonia|    5000.0|
|         2|    02-05-2022|    2022-02-09| Appendicitis|    7000.0|
|         3|    03-12-2022|    2022-03-18|Fractured Arm|    3500.0|
|         4|    04-02-2022|    2022-04-08| Heart Attack|   15000.0|
|         5|    05-05-2022|    2022-05-07|    Influenza|    2500.0|
|         6|    06-10-2022|    2022-06-15| Appendicitis|    8000.0|
|         7|    07-20-2022|    2022-07-25|    Pneumonia|    5500.0|
|         8|    08-25-2022|    2022-09-01| Heart Attack|   20000.0|
|         9|    09-15-2022|    2022-09-22|Fractured Leg|    6000.0|
|        10|    10-05-2022|    2022-10-10| Appendicitis|    7500.0|
|        11|    11-02-2022|    2022-11-05|    Influenza|    2800.0|
|        12|    12-10-2022|    2022-12-18|    Pn

In [54]:
hospital_df3 = hospital_df2.withColumnRenamed("total_cost","hospital_bill")

In [55]:
hospital_df3.show()

+----------+--------------+--------------+-------------+-------------+
|patient_id|admission_date|discharge_date|    diagnosis|hospital_bill|
+----------+--------------+--------------+-------------+-------------+
|         1|    01-01-2022|    2022-01-10|    Pneumonia|       5000.0|
|         2|    02-05-2022|    2022-02-09| Appendicitis|       7000.0|
|         3|    03-12-2022|    2022-03-18|Fractured Arm|       3500.0|
|         4|    04-02-2022|    2022-04-08| Heart Attack|      15000.0|
|         5|    05-05-2022|    2022-05-07|    Influenza|       2500.0|
|         6|    06-10-2022|    2022-06-15| Appendicitis|       8000.0|
|         7|    07-20-2022|    2022-07-25|    Pneumonia|       5500.0|
|         8|    08-25-2022|    2022-09-01| Heart Attack|      20000.0|
|         9|    09-15-2022|    2022-09-22|Fractured Leg|       6000.0|
|        10|    10-05-2022|    2022-10-10| Appendicitis|       7500.0|
|        11|    11-02-2022|    2022-11-05|    Influenza|       2800.0|
|     

#### Calculate the duration of stay

In [84]:
from pyspark.sql.functions import to_date, datediff, expr

In [72]:
hospital_df4 = hospital_df3. \
                withColumn("admission_date",to_date("admission_date","MM-dd-yyyy")). \
                withColumn("discharge_date",to_date("discharge_date","yyyy-MM-dd"))

In [73]:
hospital_df4.show()

+----------+--------------+--------------+-------------+-------------+
|patient_id|admission_date|discharge_date|    diagnosis|hospital_bill|
+----------+--------------+--------------+-------------+-------------+
|         1|    2022-01-01|    2022-01-10|    Pneumonia|       5000.0|
|         2|    2022-02-05|    2022-02-09| Appendicitis|       7000.0|
|         3|    2022-03-12|    2022-03-18|Fractured Arm|       3500.0|
|         4|    2022-04-02|    2022-04-08| Heart Attack|      15000.0|
|         5|    2022-05-05|    2022-05-07|    Influenza|       2500.0|
|         6|    2022-06-10|    2022-06-15| Appendicitis|       8000.0|
|         7|    2022-07-20|    2022-07-25|    Pneumonia|       5500.0|
|         8|    2022-08-25|    2022-09-01| Heart Attack|      20000.0|
|         9|    2022-09-15|    2022-09-22|Fractured Leg|       6000.0|
|        10|    2022-10-05|    2022-10-10| Appendicitis|       7500.0|
|        11|    2022-11-02|    2022-11-05|    Influenza|       2800.0|
|     

In [65]:
hospital_df4.printSchema()

root
 |-- patient_id: integer (nullable = true)
 |-- admission_date: date (nullable = true)
 |-- discharge_date: date (nullable = true)
 |-- diagnosis: string (nullable = true)
 |-- doctor_id: integer (nullable = true)
 |-- total_cost: double (nullable = true)



In [77]:
hospital_df5 = hospital_df4. \
    withColumn("duration_of_stay",datediff("discharge_date","admission_date"))

In [78]:
hospital_df5.show()

+----------+--------------+--------------+-------------+-------------+----------------+
|patient_id|admission_date|discharge_date|    diagnosis|hospital_bill|duration_of_stay|
+----------+--------------+--------------+-------------+-------------+----------------+
|         1|    2022-01-01|    2022-01-10|    Pneumonia|       5000.0|               9|
|         2|    2022-02-05|    2022-02-09| Appendicitis|       7000.0|               4|
|         3|    2022-03-12|    2022-03-18|Fractured Arm|       3500.0|               6|
|         4|    2022-04-02|    2022-04-08| Heart Attack|      15000.0|               6|
|         5|    2022-05-05|    2022-05-07|    Influenza|       2500.0|               2|
|         6|    2022-06-10|    2022-06-15| Appendicitis|       8000.0|               5|
|         7|    2022-07-20|    2022-07-25|    Pneumonia|       5500.0|               5|
|         8|    2022-08-25|    2022-09-01| Heart Attack|      20000.0|               7|
|         9|    2022-09-15|    2

In [86]:
hospital_df6 = hospital_df5.withColumn("adjusted_total_cost", 
                                        expr("CASE \
                                                 WHEN diagnosis='Heart Attack' THEN hospital_bill * 1.5 \
                                                 WHEN diagnosis='Appendicitis' THEN hospital_bill * 1.2 \
                                                 ELSE hospital_bill \
                                                 END \
                                             ")
                                      )

In [87]:
hospital_df6.show()

+----------+--------------+--------------+-------------+-------------+----------------+-------------------+
|patient_id|admission_date|discharge_date|    diagnosis|hospital_bill|duration_of_stay|adjusted_total_cost|
+----------+--------------+--------------+-------------+-------------+----------------+-------------------+
|         1|    2022-01-01|    2022-01-10|    Pneumonia|       5000.0|               9|             5000.0|
|         2|    2022-02-05|    2022-02-09| Appendicitis|       7000.0|               4|             8400.0|
|         3|    2022-03-12|    2022-03-18|Fractured Arm|       3500.0|               6|             3500.0|
|         4|    2022-04-02|    2022-04-08| Heart Attack|      15000.0|               6|            22500.0|
|         5|    2022-05-05|    2022-05-07|    Influenza|       2500.0|               2|             2500.0|
|         6|    2022-06-10|    2022-06-15| Appendicitis|       8000.0|               5|             9600.0|
|         7|    2022-07-20| 

In [89]:
hospital_df6.select("patient_id", "diagnosis", "hospital_bill", "adjusted_total_cost").show()

+----------+-------------+-------------+-------------------+
|patient_id|    diagnosis|hospital_bill|adjusted_total_cost|
+----------+-------------+-------------+-------------------+
|         1|    Pneumonia|       5000.0|             5000.0|
|         2| Appendicitis|       7000.0|             8400.0|
|         3|Fractured Arm|       3500.0|             3500.0|
|         4| Heart Attack|      15000.0|            22500.0|
|         5|    Influenza|       2500.0|             2500.0|
|         6| Appendicitis|       8000.0|             9600.0|
|         7|    Pneumonia|       5500.0|             5500.0|
|         8| Heart Attack|      20000.0|            30000.0|
|         9|Fractured Leg|       6000.0|             6000.0|
|        10| Appendicitis|       7500.0|             9000.0|
|        11|    Influenza|       2800.0|             2800.0|
|        12|    Pneumonia|       6000.0|             6000.0|
|        13| Heart Attack|      18000.0|            27000.0|
|        14| Appendiciti