In [1]:
from pyspark.sql import SparkSession

In [2]:
spark = SparkSession.builder.appName('consult_ship').getOrCreate()

## Import cruise ship data

In [3]:
ship_info = spark.read.csv('cruise_ship_info.csv',inferSchema=True,header=True)

In [4]:
ship_info.count()

158

In [5]:
ship_info.show(5)

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
only showing top 5 rows



## Create data set with necessary features

In [6]:
from pyspark.ml.feature import StringIndexer

indexer = StringIndexer(inputCol='Cruise_line',outputCol='CruiseLineInd')
ship_info_CruiseLineInd = indexer.fit(ship_info).transform(ship_info)

In [7]:
ship_info_CruiseLineInd.show(5)

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+-------------+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|CruiseLineInd|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+-------------+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|         16.0|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|         16.0|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|          1.0|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|          1.0|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|          1.0|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+-------------+
o

In [8]:
ship_info_CruiseLineInd.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)
 |-- CruiseLineInd: double (nullable = true)



In [9]:
ship_info.select('Cruise_line').distinct().count()

20

In [10]:
ship_info.groupBy('Cruise_line').count().show()

+-----------------+-----+
|      Cruise_line|count|
+-----------------+-----+
|            Costa|   11|
|              P&O|    6|
|           Cunard|    3|
|Regent_Seven_Seas|    5|
|              MSC|    8|
|         Carnival|   22|
|          Crystal|    2|
|           Orient|    1|
|         Princess|   17|
|        Silversea|    4|
|         Seabourn|    3|
| Holland_American|   14|
|         Windstar|    3|
|           Disney|    2|
|        Norwegian|   13|
|          Oceania|    3|
|          Azamara|    2|
|        Celebrity|   10|
|             Star|    6|
|  Royal_Caribbean|   23|
+-----------------+-----+



## Group necessary features into single column 

In [11]:
from pyspark.ml.feature import VectorAssembler

cols = ['CruiseLineInd','Age','Tonnage','passengers','length','cabins','passenger_density']
assembler = VectorAssembler(inputCols=cols,outputCol='features')

In [12]:
data = assembler.transform(ship_info_CruiseLineInd)
data = data.select('crew','features')

In [13]:
data.dtypes

[('crew', 'double'), ('features', 'vector')]

In [14]:
data.show(5)

+----+--------------------+
|crew|            features|
+----+--------------------+
|3.55|[16.0,6.0,30.2769...|
|3.55|[16.0,6.0,30.2769...|
| 6.7|[1.0,26.0,47.262,...|
|19.1|[1.0,11.0,110.0,2...|
|10.0|[1.0,17.0,101.353...|
+----+--------------------+
only showing top 5 rows



In [15]:
data.take(1)[0][1]

DenseVector([16.0, 6.0, 30.277, 6.94, 5.94, 3.55, 42.64])

In [16]:
# Test StringType
# VectorAssembler doesn't support StringType data type
cols2 = ['CruiseLineInd','Cruise_line']
assembler2 = VectorAssembler(inputCols=cols2,outputCol='features')

In [17]:
data2 = assembler2.transform(ship_info_CruiseLineInd)

IllegalArgumentException: 'Data type StringType is not supported.'

## Split the data

In [18]:
train_data, test_data = data.randomSplit([0.7, 0.3])

## Build a model

In [19]:
from pyspark.ml.regression import LinearRegression

In [20]:
lr = LinearRegression(featuresCol='features',labelCol='crew')
lr_model = lr.fit(train_data)

In [21]:
print('Train RMSE: {}'.format(lr_model.summary.rootMeanSquaredError))
print('Train R2: {}'.format(lr_model.summary.r2))
print('\n')
lr_model.summary.predictions.show(5)
lr_model.summary.predictions.describe().show()

Train RMSE: 1.0762603001585764
Train R2: 0.8958631617670496


+----+--------------------+-------------------+
|crew|            features|         prediction|
+----+--------------------+-------------------+
|0.59|[8.0,22.0,3.341,0...|0.36877132140861457|
| 0.6|[6.0,12.0,2.329,0...| 0.5252581828003182|
|0.88|[14.0,25.0,5.35,1...|  1.612392327371942|
|1.46|[10.0,27.0,12.5,3...| 1.2217495882479186|
| 1.6|[13.0,24.0,10.0,2...|  1.767720084280832|
+----+--------------------+-------------------+
only showing top 5 rows

+-------+-----------------+-------------------+
|summary|             crew|         prediction|
+-------+-----------------+-------------------+
|  count|              104|                104|
|   mean|7.928173076923077| 7.9281730769230725|
| stddev|3.351299056755722| 3.1720061713784484|
|    min|             0.59|0.36877132140861457|
|    max|             19.1|  15.10673559510311|
+-------+-----------------+-------------------+



## Evaluate the model

In [22]:
test_results = lr_model.evaluate(test_data)

In [23]:
print('Test RMSE: {}'.format(test_results.rootMeanSquaredError))
print('Test R2: {}'.format(test_results.r2))
print('\n')
test_results.predictions.show(5)
test_results.predictions.describe().show()

Test RMSE: 0.6215505166427965
Test R2: 0.9727230753966235


+----+--------------------+-------------------+
|crew|            features|         prediction|
+----+--------------------+-------------------+
|0.59|[8.0,22.0,3.341,0...|0.37288304764400193|
|0.88|[14.0,27.0,5.35,1...| 1.5833068389995801|
| 1.6|[13.0,21.0,10.0,2...| 1.7982627023914424|
| 1.6|[13.0,27.0,10.0,2...| 1.7371774661702217|
| 2.1|[11.0,19.0,16.8,2...|  2.308006116550767|
+----+--------------------+-------------------+
only showing top 5 rows

+-------+-----------------+-------------------+
|summary|             crew|         prediction|
+-------+-----------------+-------------------+
|  count|               54|                 54|
|   mean|7.536111111111113|  7.623583826451914|
| stddev|3.798723428407837|  3.779893692648583|
|    min|             0.59|0.37288304764400193|
|    max|             21.0| 20.964017728060593|
+-------+-----------------+-------------------+



## Note: train_data and test_data split did not reserve data stratification

## The model is ready to operate on unlabeled data!

## The model is doing quite well (good RMSE/mean and high R2), let's check for some correlation

In [24]:
# getting 'corr' via DataFrame (retuen a number)
ship_info.corr('crew','passengers')

0.9152341306065384

In [26]:
from pyspark.sql.functions import corr

In [27]:
# getting 'corr' via sql.functions (return a column)
ship_info.select(corr('crew','passengers')).show()

+----------------------+
|corr(crew, passengers)|
+----------------------+
|    0.9152341306065384|
+----------------------+



In [29]:
ship_info.select(corr('crew','cabins')).show()

+------------------+
|corr(crew, cabins)|
+------------------+
|0.9508226063578497|
+------------------+

