In [0]:
import pyspark
from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression

spark = SparkSession.builder.appName('lrex').getOrCreate()

In [0]:
cruise_data = spark.read.csv('/FileStore/tables/cruise_ship_info.csv', inferSchema=True, header=True)

In [0]:
cruise_data.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)



In [0]:
cruise_data.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|
|    Ecstasy|   Carnival| 22|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Elation|   Carnival| 15|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Fantasy|   Carnival| 23| 

In [0]:
cruise_data.columns

Out[18]: ['Ship_name',
 'Cruise_line',
 'Age',
 'Tonnage',
 'passengers',
 'length',
 'cabins',
 'passenger_density',
 'crew']

In [0]:
assembler = VectorAssembler(inputCols=['Age', 'Tonnage', 'passengers', 'length', 'cabins', 'passenger_density', 'crew'],
                            outputCol='features')

In [0]:
output = assembler.transform(cruise_data)

In [0]:
output.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)
 |-- features: vector (nullable = true)



In [0]:
output.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+--------------------+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|            features|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+--------------------+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|[6.0,30.276999999...|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|[6.0,30.276999999...|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|[26.0,47.262,14.8...|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|[11.0,110.0,29.74...|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|[17.0,101.353,26....|
|    Ecstasy|   Carnival| 22|            70.367|     20.

In [0]:
final_output_data = output.select('features', 'crew')

In [0]:
final_output_data.show()

+--------------------+----+
|            features|crew|
+--------------------+----+
|[6.0,30.276999999...|3.55|
|[6.0,30.276999999...|3.55|
|[26.0,47.262,14.8...| 6.7|
|[11.0,110.0,29.74...|19.1|
|[17.0,101.353,26....|10.0|
|[22.0,70.367,20.5...| 9.2|
|[15.0,70.367,20.5...| 9.2|
|[23.0,70.367,20.5...| 9.2|
|[19.0,70.367,20.5...| 9.2|
|[6.0,110.23899999...|11.5|
|[10.0,110.0,29.74...|11.6|
|[28.0,46.052,14.5...| 6.6|
|[18.0,70.367,20.5...| 9.2|
|[17.0,70.367,20.5...| 9.2|
|[11.0,86.0,21.24,...| 9.3|
|[8.0,110.0,29.74,...|11.6|
|[9.0,88.5,21.24,9...|10.3|
|[15.0,70.367,20.5...| 9.2|
|[12.0,88.5,21.24,...| 9.3|
|[20.0,70.367,20.5...| 9.2|
+--------------------+----+
only showing top 20 rows



In [0]:
train_data, test_data = final_output_data.randomSplit([0.7, 0.3])

In [0]:
lr = LinearRegression(labelCol='crew')

In [0]:
lr_model = lr.fit(train_data)

In [0]:
test_results = lr_model.evaluate(test_data)

In [0]:
test_results.residuals.show()

+--------------------+
|           residuals|
+--------------------+
|1.598721155460225...|
|4.218847493575595...|
|4.218847493575595...|
|5.329070518200751...|
|-7.10542735760100...|
|1.776356839400250...|
|1.421085471520200...|
|1.421085471520200...|
|                 0.0|
|-1.77635683940025...|
|1.421085471520200...|
|-1.77635683940025...|
|8.881784197001252...|
|1.065814103640150...|
|-5.32907051820075...|
|3.552713678800501...|
|-3.64153152077051...|
|2.398081733190338...|
|7.993605777301127...|
|9.769962616701378...|
+--------------------+
only showing top 20 rows



In [0]:
test_results.meanSquaredError

Out[31]: 3.6806253634712e-28

In [0]:
test_results.r2

Out[30]: 1.0

In [0]:
final_output_data.describe().show()

+-------+-----------------+
|summary|             crew|
+-------+-----------------+
|  count|              158|
|   mean|7.794177215189873|
| stddev|3.503486564627034|
|    min|             0.59|
|    max|             21.0|
+-------+-----------------+



In [0]:
unlabeled_data = test_data.select('features')

In [0]:
unlabeled_data.show()

+--------------------+
|            features|
+--------------------+
|[5.0,160.0,36.34,...|
|[6.0,30.276999999...|
|[6.0,30.276999999...|
|[6.0,90.0,20.0,9....|
|[6.0,112.0,38.0,9...|
|[7.0,89.6,25.5,9....|
|[9.0,113.0,26.74,...|
|[10.0,77.0,20.16,...|
|[10.0,90.09,25.01...|
|[10.0,91.62700000...|
|[10.0,105.0,27.2,...|
|[11.0,91.0,20.32,...|
|[11.0,108.977,26....|
|[11.0,138.0,31.14...|
|[12.0,88.5,21.24,...|
|[12.0,108.865,27....|
|[13.0,25.0,3.82,5...|
|[13.0,30.27699999...|
|[13.0,61.0,13.8,7...|
|[13.0,63.0,14.4,7...|
+--------------------+
only showing top 20 rows



In [0]:
predictions = lr_model.transform(unlabeled_data)

In [0]:
predictions.show()

+--------------------+------------------+
|            features|        prediction|
+--------------------+------------------+
|[5.0,160.0,36.34,...|13.599999999999984|
|[6.0,30.276999999...|3.5499999999999576|
|[6.0,30.276999999...|3.5499999999999576|
|[6.0,90.0,20.0,9....| 8.999999999999995|
|[6.0,112.0,38.0,9...|10.900000000000007|
|[7.0,89.6,25.5,9....| 9.869999999999997|
|[9.0,113.0,26.74,...|12.379999999999987|
|[10.0,77.0,20.16,...| 8.999999999999986|
|[10.0,90.09,25.01...|              8.58|
|[10.0,91.62700000...| 9.000000000000002|
|[10.0,105.0,27.2,...|10.679999999999986|
|[11.0,91.0,20.32,...| 9.990000000000002|
|[11.0,108.977,26....|11.999999999999991|
|[11.0,138.0,31.14...|11.849999999999989|
|[12.0,88.5,21.24,...|10.290000000000004|
|[12.0,108.865,27....|10.999999999999996|
|[13.0,25.0,3.82,5...|2.9500000000000366|
|[13.0,30.27699999...| 3.999999999999976|
|[13.0,61.0,13.8,7...| 5.999999999999992|
|[13.0,63.0,14.4,7...|  5.30999999999999|
+--------------------+------------