In [0]:
from pyspark.sql import SparkSession
# Obtain the notebook's SparkSession, reusing an already-running one if
# present, and label the application 'Proj LR'.
spark = (
    SparkSession.builder
    .appName('Proj LR')
    .getOrCreate()
)


In [0]:
# Load the registered 'cruise_ship_info' table as a DataFrame.
data = spark.table('cruise_ship_info')

In [0]:
# List the column names of the loaded table.
data.columns

Out[17]: ['Ship_name',
 'Cruise_line',
 'Age',
 'Tonnage',
 'passengers',
 'length',
 'cabins',
 'passenger_density',
 'crew']

In [0]:
from pyspark.ml.feature import VectorAssembler,StringIndexer
from pyspark.ml.regression import LinearRegression

In [0]:
# Pack the model inputs into a single vector column named 'features'.
# 'treated' is the StringIndexer output for 'Cruise_line'; 'crew' is left out
# because it is used as the regression label.
# NOTE(review): 'treated' is a category index fed in as a plain number, so the
# linear model will read it as an ordered quantity - consider one-hot encoding.
assembler = VectorAssembler(
    inputCols=[
        'treated',
        'Age',
        'Tonnage',
        'passengers',
        'length',
        'cabins',
        'passenger_density',
    ],
    outputCol='features',
)

In [0]:
# Configure an indexer that maps each 'Cruise_line' string to a numeric index
# stored in a new 'treated' column.
stringindexer = (
    StringIndexer()
    .setInputCol("Cruise_line")
    .setOutputCol("treated")
)

# Learn the label -> index mapping from the full table.
model_string_indexer = stringindexer.fit(data)

In [0]:
# Apply the fitted indexer to the table, appending the numeric 'treated'
# column derived from 'Cruise_line'.
treated=model_string_indexer.transform(data)

In [0]:
# Print the schema to check the column types after indexing.
treated.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: long (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)
 |-- treated: double (nullable = false)



In [0]:
# Run the VectorAssembler over the indexed frame; this appends the 'features'
# vector column built from the configured input columns.
transformed_assembler=assembler.transform(treated)

In [0]:
# Preview the assembled rows, including the new 'features' column.
transformed_assembler.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+-------+--------------------+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|treated|            features|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+-------+--------------------+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|   16.0|[16.0,6.0,30.2769...|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|   16.0|[16.0,6.0,30.2769...|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|    1.0|[1.0,26.0,47.262,...|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|    1.0|[1.0,11.0,110.0,2...|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|    1.0|[1.0,17.0,101.

In [0]:
# Keep only what the regression needs: the feature vector and the 'crew' label.
data_features = transformed_assembler.select('features', 'crew')

In [0]:
# Preview the (features, crew) pairs that will be split into train/test sets.
data_features.show()

+--------------------+----+
|            features|crew|
+--------------------+----+
|[16.0,6.0,30.2769...|3.55|
|[16.0,6.0,30.2769...|3.55|
|[1.0,26.0,47.262,...| 6.7|
|[1.0,11.0,110.0,2...|19.1|
|[1.0,17.0,101.353...|10.0|
|[1.0,22.0,70.367,...| 9.2|
|[1.0,15.0,70.367,...| 9.2|
|[1.0,23.0,70.367,...| 9.2|
|[1.0,19.0,70.367,...| 9.2|
|[1.0,6.0,110.2389...|11.5|
|[1.0,10.0,110.0,2...|11.6|
|[1.0,28.0,46.052,...| 6.6|
|[1.0,18.0,70.367,...| 9.2|
|[1.0,17.0,70.367,...| 9.2|
|[1.0,11.0,86.0,21...| 9.3|
|[1.0,8.0,110.0,29...|11.6|
|[1.0,9.0,88.5,21....|10.3|
|[1.0,15.0,70.367,...| 9.2|
|[1.0,12.0,88.5,21...| 9.3|
|[1.0,20.0,70.367,...| 9.2|
+--------------------+----+
only showing top 20 rows



In [0]:
# Linear regression estimator predicting 'crew' from the assembled 'features'
# vector (the features column is spelled out; it matches the default name).
lr_model = LinearRegression(featuresCol="features", labelCol="crew")

In [0]:
# Split the rows roughly 60/40 into training and held-out test sets.
# A fixed seed is passed so the split - and therefore the fitted model and the
# reported metrics - can be reproduced on a fresh Restart & Run All; without
# it a different random split is drawn on every run.
# NOTE(review): exact reproducibility also assumes the input DataFrame's
# partitioning/row order is stable between runs - confirm for this table.
train,test=data_features.randomSplit([0.6,0.4],seed=42)

In [0]:
# Fit the linear regression on the training split only.
fitted_model=lr_model.fit(train)

In [0]:
# Evaluate the fitted model on the held-out test split; the result exposes
# summary metrics such as r2.
metrics=fitted_model.evaluate(test)

In [0]:
# R^2 of the model on the test split.
metrics.r2

Out[43]: 0.8989159375834559

In [0]:
# Score the test split; transform() appends a 'prediction' column next to the
# actual 'crew' values.
predictions=fitted_model.transform(test)
# Preview actual vs. predicted crew sizes.
predictions.show()

+--------------------+-----+------------------+
|            features| crew|        prediction|
+--------------------+-----+------------------+
|[0.0,5.0,160.0,36...| 13.6|14.727758201915593|
|[0.0,6.0,158.0,43...| 13.6|13.617222107512712|
|[0.0,7.0,158.0,43...| 13.6| 13.54220936931763|
|[0.0,9.0,90.09,25...| 8.69| 9.161733052945902|
|[0.0,12.0,90.09,2...| 8.68| 8.707024597524306|
|[0.0,13.0,138.0,3...|11.76|12.575567006035287|
|[0.0,15.0,78.491,...|  6.6|  8.14199355033236|
|[0.0,16.0,78.491,...| 7.65| 8.126452283679711|
|[0.0,18.0,70.0,18...|  7.2| 7.827845607805646|
|[0.0,22.0,73.941,...| 8.22|  9.08031927634848|
|[0.0,25.0,73.192,...| 8.08| 8.572507873587588|
|[1.0,8.0,110.0,29...| 11.6| 12.10890552571268|
|[1.0,9.0,110.0,29...| 11.6| 12.09793898764022|
|[1.0,11.0,110.0,2...| 19.1|12.080705834175648|
|[1.0,12.0,88.5,21...|10.29| 9.346718944996752|
|[1.0,12.0,88.5,21...|  9.3|10.329831978613726|
|[1.0,13.0,101.509...| 11.5|11.100552030428188|
|[1.0,17.0,101.353...| 10.0|10.650368183