In [29]:
from pyspark.sql import *
from pyspark.ml.feature import *
from pyspark.ml.regression import *

In [5]:
spark = SparkSession.builder.appName("Amith").getOrCreate()
df = spark.read.csv("tips.csv",header=True,inferSchema=True)
df.show(5)

+----------+----+------+------+---+------+----+
|total_bill| tip|   sex|smoker|day|  time|size|
+----------+----+------+------+---+------+----+
|     16.99|1.01|Female|    No|Sun|Dinner|   2|
|     10.34|1.66|  Male|    No|Sun|Dinner|   3|
|     21.01| 3.5|  Male|    No|Sun|Dinner|   3|
|     23.68|3.31|  Male|    No|Sun|Dinner|   2|
|     24.59|3.61|Female|    No|Sun|Dinner|   4|
+----------+----+------+------+---+------+----+
only showing top 5 rows



In [18]:
indexer = StringIndexer(inputCols=["sex","smoker","day","time"],outputCols=["sex_indexed","smoker_indexed","day_indexed","time_indexed"])
df_transformed = indexer.fit(df).transform(df)
df_transformed = df_transformed.drop(*["sex","smoker","day","time"])
df_transformed.show(5)

+----------+----+----+-----------+--------------+-----------+------------+
|total_bill| tip|size|sex_indexed|smoker_indexed|day_indexed|time_indexed|
+----------+----+----+-----------+--------------+-----------+------------+
|     16.99|1.01|   2|        1.0|           0.0|        1.0|         0.0|
|     10.34|1.66|   3|        0.0|           0.0|        1.0|         0.0|
|     21.01| 3.5|   3|        0.0|           0.0|        1.0|         0.0|
|     23.68|3.31|   2|        0.0|           0.0|        1.0|         0.0|
|     24.59|3.61|   4|        1.0|           0.0|        1.0|         0.0|
+----------+----+----+-----------+--------------+-----------+------------+
only showing top 5 rows



In [24]:
agg = VectorAssembler(inputCols=["tip","size","sex_indexed","smoker_indexed","day_indexed","time_indexed"],outputCol="inputs")
df_transformed_2 = agg.transform(df_transformed)
df_transformed_2 = df_transformed_2.select(*["inputs","total_bill"])
df_transformed_2.show(5)

+--------------------+----------+
|              inputs|total_bill|
+--------------------+----------+
|[1.01,2.0,1.0,0.0...|     16.99|
|[1.66,3.0,0.0,0.0...|     10.34|
|[3.5,3.0,0.0,0.0,...|     21.01|
|[3.31,2.0,0.0,0.0...|     23.68|
|[3.61,4.0,1.0,0.0...|     24.59|
+--------------------+----------+
only showing top 5 rows



In [27]:
train_data,test_data = df_transformed_2.randomSplit([0.75,0.25])

In [36]:
regr = LinearRegression(featuresCol="inputs",labelCol="total_bill")
model = regr.fit(train_data)
eval = model.evaluate(test_data)

In [38]:
pred_data = eval.predictions
pred_data.show(5)

+--------------------+----------+------------------+
|              inputs|total_bill|        prediction|
+--------------------+----------+------------------+
|(6,[0,1],[1.25,2.0])|     10.07|13.115960420484729|
|(6,[0,1],[1.97,2.0])|     12.02|15.242995704191618|
| (6,[0,1],[2.0,2.0])|     13.37| 15.33162217434607|
|(6,[0,1],[3.27,2.0])|     17.78| 19.08347607755128|
|(6,[0,1],[3.39,2.0])|     11.61|19.437981958169093|
+--------------------+----------+------------------+
only showing top 5 rows



In [40]:
print(eval.r2,eval.meanAbsoluteError,eval.meanSquaredError)

0.7060107701718896 3.829080341051187 27.613362768838822


In [46]:
pred_data.columns

['inputs', 'total_bill', 'prediction']

In [41]:
pred_data.printSchema()

root
 |-- inputs: vector (nullable = true)
 |-- total_bill: double (nullable = true)
 |-- prediction: double (nullable = false)



In [45]:
df.columns

['total_bill', 'tip', 'sex', 'smoker', 'day', 'time', 'size']

In [42]:
df.printSchema()

root
 |-- total_bill: double (nullable = true)
 |-- tip: double (nullable = true)
 |-- sex: string (nullable = true)
 |-- smoker: string (nullable = true)
 |-- day: string (nullable = true)
 |-- time: string (nullable = true)
 |-- size: integer (nullable = true)



In [43]:
df_transformed.columns

['total_bill',
 'tip',
 'size',
 'sex_indexed',
 'smoker_indexed',
 'day_indexed',
 'time_indexed']

In [44]:
df_transformed.printSchema()

root
 |-- total_bill: double (nullable = true)
 |-- tip: double (nullable = true)
 |-- size: integer (nullable = true)
 |-- sex_indexed: double (nullable = false)
 |-- smoker_indexed: double (nullable = false)
 |-- day_indexed: double (nullable = false)
 |-- time_indexed: double (nullable = false)

