In [2]:
#Import Libraries
import pandas as pd
from pyspark.sql import SparkSession 
from pyspark.ml.regression import LinearRegression
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler, StringIndexer

In [3]:
spark=SparkSession.builder.appName('cruise_ship').getOrCreate()

21/10/11 23:50:28 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
21/10/11 23:50:28 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
21/10/11 23:50:28 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.
21/10/11 23:50:28 WARN Utils: Service 'SparkUI' could not bind on port 4042. Attempting port 4043.
21/10/11 23:50:28 WARN Utils: Service 'SparkUI' could not bind on port 4043. Attempting port 4044.


In [13]:
data = spark.read.csv('cruise_ship_info.csv', inferSchema=True,
                     header=True)

In [15]:
data.head(1)[0]

Row(Ship_name='Journey', Cruise_line='Azamara', Age=6, Tonnage=30.276999999999997, passengers=6.94, length=5.94, cabins=3.55, passenger_density=42.64, crew=3.55)

In [16]:
data.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)



In [55]:
#EDA
data.groupBy('Cruise_line').count().show()

+-----------------+-----+
|      Cruise_line|count|
+-----------------+-----+
|            Costa|   11|
|              P&O|    6|
|           Cunard|    3|
|Regent_Seven_Seas|    5|
|              MSC|    8|
|         Carnival|   22|
|          Crystal|    2|
|           Orient|    1|
|         Princess|   17|
|        Silversea|    4|
|         Seabourn|    3|
| Holland_American|   14|
|         Windstar|    3|
|           Disney|    2|
|        Norwegian|   13|
|          Oceania|    3|
|          Azamara|    2|
|        Celebrity|   10|
|             Star|    6|
|  Royal_Caribbean|   23|
+-----------------+-----+



In [17]:
#PREPROCESS THE DATA
data.columns

['Ship_name',
 'Cruise_line',
 'Age',
 'Tonnage',
 'passengers',
 'length',
 'cabins',
 'passenger_density',
 'crew']

In [18]:
string_indexer = StringIndexer(inputCol='Cruise_line', outputCol='Cruise_line_digital')

In [21]:
output_data=string_indexer.fit(data)

In [22]:
output = output_data.transform(data)

In [None]:
#output = string_indexer.fit(data).transform(data)

In [24]:
output.columns

['Ship_name',
 'Cruise_line',
 'Age',
 'Tonnage',
 'passengers',
 'length',
 'cabins',
 'passenger_density',
 'crew',
 'Cruise_line_digital']

In [25]:
assembeler = VectorAssembler(inputCols=['Age',
 'Tonnage',
 'passengers',
 'length',
 'cabins',
 'passenger_density',
 'Cruise_line_digital'
], outputCol='features')

In [26]:
final_output = assembeler.transform(output)

In [32]:
final_output.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)
 |-- Cruise_line_digital: double (nullable = false)
 |-- features: vector (nullable = true)



In [33]:
final_output.head(1)

[Row(Ship_name='Journey', Cruise_line='Azamara', Age=6, Tonnage=30.276999999999997, passengers=6.94, length=5.94, cabins=3.55, passenger_density=42.64, crew=3.55, Cruise_line_digital=16.0, features=DenseVector([6.0, 30.277, 6.94, 5.94, 3.55, 42.64, 16.0]))]

In [34]:
final_data = final_output.select('features','crew')

In [35]:
final_data.show()

+--------------------+----+
|            features|crew|
+--------------------+----+
|[6.0,30.276999999...|3.55|
|[6.0,30.276999999...|3.55|
|[26.0,47.262,14.8...| 6.7|
|[11.0,110.0,29.74...|19.1|
|[17.0,101.353,26....|10.0|
|[22.0,70.367,20.5...| 9.2|
|[15.0,70.367,20.5...| 9.2|
|[23.0,70.367,20.5...| 9.2|
|[19.0,70.367,20.5...| 9.2|
|[6.0,110.23899999...|11.5|
|[10.0,110.0,29.74...|11.6|
|[28.0,46.052,14.5...| 6.6|
|[18.0,70.367,20.5...| 9.2|
|[17.0,70.367,20.5...| 9.2|
|[11.0,86.0,21.24,...| 9.3|
|[8.0,110.0,29.74,...|11.6|
|[9.0,88.5,21.24,9...|10.3|
|[15.0,70.367,20.5...| 9.2|
|[12.0,88.5,21.24,...| 9.3|
|[20.0,70.367,20.5...| 9.2|
+--------------------+----+
only showing top 20 rows



In [37]:
#TRAIN_TEST SPLIT
train_data, test_data = final_data.randomSplit([0.7, 0.3])

In [38]:
train_data.describe().show(), test_data.describe().show()

+-------+-----------------+
|summary|             crew|
+-------+-----------------+
|  count|              104|
|   mean|8.007211538461549|
| stddev|3.406581424736638|
|    min|             0.59|
|    max|             19.1|
+-------+-----------------+

+-------+------------------+
|summary|              crew|
+-------+------------------+
|  count|                54|
|   mean|7.3838888888888885|
| stddev|3.6805954279712023|
|    min|              0.59|
|    max|              21.0|
+-------+------------------+



(None, None)

In [39]:
# INITIATE A MODEL
lr = LinearRegression(labelCol='crew')

In [40]:
lr_model = lr.fit(train_data)

21/10/12 00:10:00 WARN Instrumentation: [a3ba9dde] regParam is zero, which might cause numerical instability and overfitting.
21/10/12 00:10:00 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
21/10/12 00:10:00 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS
21/10/12 00:10:00 WARN LAPACK: Failed to load implementation from: com.github.fommil.netlib.NativeSystemLAPACK
21/10/12 00:10:00 WARN LAPACK: Failed to load implementation from: com.github.fommil.netlib.NativeRefLAPACK


In [41]:
test_results = lr_model.evaluate(test_data)

In [44]:
test_results.residuals.show()

+--------------------+
|           residuals|
+--------------------+
| 0.45215536983291216|
|-0.01541977003099...|
| -0.9770321501001469|
| -1.2457232181584565|
|  0.4576849161170209|
|-0.08507351417493858|
|  0.3617487528821073|
|-0.29360256071328195|
|  0.6762148069799778|
|  1.1007048459480444|
|   1.608258269989578|
|-0.25092253103310824|
| -0.4677553346500609|
|-0.33764932902613687|
|-0.27485323856289945|
|-0.41710145254640274|
|  0.7785634085441018|
| -1.1072244323296303|
|-0.19655357606666968|
|0.025775992425545624|
+--------------------+
only showing top 20 rows



In [45]:
test_results.rootMeanSquaredError, test_results.r2

(0.8655372595977698, 0.9436552752304381)

In [48]:
unlabeled_data = test_data.select('features')

In [50]:
unlabeled_data.show()

+--------------------+
|            features|
+--------------------+
|[4.0,220.0,54.0,1...|
|[5.0,133.5,39.59,...|
|[6.0,30.276999999...|
|[6.0,90.0,20.0,9....|
|[6.0,110.23899999...|
|[7.0,158.0,43.7,1...|
|[8.0,77.499,19.5,...|
|[9.0,59.058,17.0,...|
|[9.0,88.5,21.24,9...|
|[9.0,113.0,26.74,...|
|[10.0,46.0,7.0,6....|
|[10.0,68.0,10.8,7...|
|[10.0,81.76899999...|
|[10.0,90.09,25.01...|
|[11.0,86.0,21.24,...|
|[11.0,90.09,25.01...|
|[12.0,88.5,21.24,...|
|[12.0,88.5,21.24,...|
|[12.0,90.09,25.01...|
|[13.0,25.0,3.82,5...|
+--------------------+
only showing top 20 rows



In [52]:
predictions = lr_model.transform(unlabeled_data)

In [54]:
predictions.show()

+--------------------+------------------+
|            features|        prediction|
+--------------------+------------------+
|[4.0,220.0,54.0,1...|20.547844630167088|
|[5.0,133.5,39.59,...|   13.145419770031|
|[6.0,30.276999999...| 4.527032150100147|
|[6.0,90.0,20.0,9....|10.245723218158457|
|[6.0,110.23899999...|11.042315083882979|
|[7.0,158.0,43.7,1...|13.685073514174938|
|[8.0,77.499,19.5,...| 8.638251247117893|
|[9.0,59.058,17.0,...| 7.693602560713282|
|[9.0,88.5,21.24,9...| 9.623785193020023|
|[9.0,113.0,26.74,...|11.279295154051956|
|[10.0,46.0,7.0,6....|2.8617417300104218|
|[10.0,68.0,10.8,7...| 6.610922531033109|
|[10.0,81.76899999...|  8.88775533465006|
|[10.0,90.09,25.01...| 8.917649329026137|
|[11.0,86.0,21.24,...|   9.5748532385629|
|[11.0,90.09,25.01...| 8.897101452546403|
|[12.0,88.5,21.24,...| 9.511436591455897|
|[12.0,88.5,21.24,...|10.407224432329631|
|[12.0,90.09,25.01...|  8.87655357606667|
|[13.0,25.0,3.82,5...|2.9242240075744546|
+--------------------+------------