In [1]:
# findspark.init() is called before any pyspark import in this notebook;
# presumably it makes the local Spark installation importable — confirm for your environment.
import findspark
findspark.init()

In [2]:
from pyspark.context import SparkContext
from pyspark.sql import SparkSession

In [3]:
# Obtain the SparkSession used by every cell in this notebook
# (getOrCreate reuses an already-running session if one exists).
spark = (
    SparkSession.builder
    .appName('LR_CurzLiner')
    .getOrCreate()
)

In [4]:
# Load the cruise-ship CSV: the first row is the header and column types are inferred.
df_curz = spark.read.csv(
    path='data/cruise_ship_info.csv',
    header=True,
    inferSchema=True,
)

# Confirm the inferred types before any modelling.
df_curz.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)



In [5]:
# Preview the first rows of the loaded data.
df_curz.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|
|    Ecstasy|   Carnival| 22|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Elation|   Carnival| 15|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Fantasy|   Carnival| 23| 

### Categorical Data Analysis

In [6]:
# Count how many ships belong to each cruise line.
ships_per_line = df_curz.groupBy('Cruise_line').count()
ships_per_line.show()

+-----------------+-----+
|      Cruise_line|count|
+-----------------+-----+
|            Costa|   11|
|              P&O|    6|
|           Cunard|    3|
|Regent_Seven_Seas|    5|
|              MSC|    8|
|         Carnival|   22|
|          Crystal|    2|
|           Orient|    1|
|         Princess|   17|
|        Silversea|    4|
|         Seabourn|    3|
| Holland_American|   14|
|         Windstar|    3|
|           Disney|    2|
|        Norwegian|   13|
|          Oceania|    3|
|          Azamara|    2|
|        Celebrity|   10|
|             Star|    6|
|  Royal_Caribbean|   23|
+-----------------+-----+



In [7]:
from pyspark.ml.feature import StringIndexer

In [8]:
# Encode the Cruise_line strings as a numeric index column.
indexer = StringIndexer(
    inputCol='Cruise_line',
    outputCol='Cruise_line_indexed',
)

In [9]:
# Learn the label -> index mapping from the data, then append the index column.
indexer_model = indexer.fit(df_curz)
df_curz_indexed = indexer_model.transform(df_curz)
df_curz_indexed.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+-------------------+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|Cruise_line_indexed|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+-------------------+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|               16.0|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|               16.0|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|                1.0|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|                1.0|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|                1.0|
|    Ecstasy|   Carnival| 22|            70.367|     20.52|  8.5

### Vectorizing Data

In [10]:
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler

In [11]:
# List the available columns to choose the model inputs from.
df_curz_indexed.columns

['Ship_name',
 'Cruise_line',
 'Age',
 'Tonnage',
 'passengers',
 'length',
 'cabins',
 'passenger_density',
 'crew',
 'Cruise_line_indexed']

In [12]:
# Predictor columns packed into the single 'features' vector; 'crew' is left out
# because it is the regression label.
# NOTE(review): 'Cruise_line_indexed' is a StringIndexer label index, so the linear
# model will treat an arbitrary category ordering as a numeric magnitude. Consider
# one-hot encoding it (pyspark.ml.feature.OneHotEncoder) before assembling.
feature_columns = [
    'Age',
    'Tonnage',
    'passengers',
    'length',
    'cabins',
    'passenger_density',
    'Cruise_line_indexed',
]
assembler = VectorAssembler(inputCols=feature_columns, outputCol='features')

In [13]:
# Append the assembled 'features' vector column to the indexed frame and preview it.
df_curz_vec = assembler.transform(df_curz_indexed)
df_curz_vec.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+-------------------+--------------------+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|Cruise_line_indexed|            features|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+-------------------+--------------------+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|               16.0|[6.0,30.276999999...|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|               16.0|[6.0,30.276999999...|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|                1.0|[26.0,47.262,14.8...|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|                1.0|[11.0,110.0,29.74...|
|    Destiny|   Carnival| 17|     

In [14]:
# Print the feature vectors without column truncation so every component is visible.
df_curz_vec.select('features').show(truncate=False)

+--------------------------------------------------+
|features                                          |
+--------------------------------------------------+
|[6.0,30.276999999999997,6.94,5.94,3.55,42.64,16.0]|
|[6.0,30.276999999999997,6.94,5.94,3.55,42.64,16.0]|
|[26.0,47.262,14.86,7.22,7.43,31.8,1.0]            |
|[11.0,110.0,29.74,9.53,14.88,36.99,1.0]           |
|[17.0,101.353,26.42,8.92,13.21,38.36,1.0]         |
|[22.0,70.367,20.52,8.55,10.2,34.29,1.0]           |
|[15.0,70.367,20.52,8.55,10.2,34.29,1.0]           |
|[23.0,70.367,20.56,8.55,10.22,34.23,1.0]          |
|[19.0,70.367,20.52,8.55,10.2,34.29,1.0]           |
|[6.0,110.23899999999999,37.0,9.51,14.87,29.79,1.0]|
|[10.0,110.0,29.74,9.51,14.87,36.99,1.0]           |
|[28.0,46.052,14.52,7.27,7.26,31.72,1.0]           |
|[18.0,70.367,20.52,8.55,10.2,34.29,1.0]           |
|[17.0,70.367,20.52,8.55,10.2,34.29,1.0]           |
|[11.0,86.0,21.24,9.63,10.62,40.49,1.0]            |
|[8.0,110.0,29.74,9.51,14.87,36.99,1.0]       

In [16]:
# Keep only what the regressor needs: the feature vector and the 'crew' label.
df_curz_vec_final = df_curz_vec.select('features', 'crew')
df_curz_vec_final.show(truncate=False)

+--------------------------------------------------+----+
|features                                          |crew|
+--------------------------------------------------+----+
|[6.0,30.276999999999997,6.94,5.94,3.55,42.64,16.0]|3.55|
|[6.0,30.276999999999997,6.94,5.94,3.55,42.64,16.0]|3.55|
|[26.0,47.262,14.86,7.22,7.43,31.8,1.0]            |6.7 |
|[11.0,110.0,29.74,9.53,14.88,36.99,1.0]           |19.1|
|[17.0,101.353,26.42,8.92,13.21,38.36,1.0]         |10.0|
|[22.0,70.367,20.52,8.55,10.2,34.29,1.0]           |9.2 |
|[15.0,70.367,20.52,8.55,10.2,34.29,1.0]           |9.2 |
|[23.0,70.367,20.56,8.55,10.22,34.23,1.0]          |9.2 |
|[19.0,70.367,20.52,8.55,10.2,34.29,1.0]           |9.2 |
|[6.0,110.23899999999999,37.0,9.51,14.87,29.79,1.0]|11.5|
|[10.0,110.0,29.74,9.51,14.87,36.99,1.0]           |11.6|
|[28.0,46.052,14.52,7.27,7.26,31.72,1.0]           |6.6 |
|[18.0,70.367,20.52,8.55,10.2,34.29,1.0]           |9.2 |
|[17.0,70.367,20.52,8.55,10.2,34.29,1.0]           |9.2 |
|[11.0,86.0,21

### Model building

In [18]:
# Hold out roughly 30% of the rows for evaluation.
# A fixed seed makes the split — and therefore the train/test sizes and every
# metric computed from them — reproducible across kernel restarts; without it
# randomSplit draws a new random partition on each run.
SPLIT_SEED = 42
df_curz_train, df_curz_test = df_curz_vec_final.randomSplit([0.7, 0.3], seed=SPLIT_SEED)

# Sanity check: train and test sizes should add up to the full row count.
df_curz_train.count(), df_curz_test.count(), df_curz_vec_final.count()

(100, 58, 158)

In [19]:
from pyspark.ml.regression import LinearRegression

In [20]:
# Linear regression estimator: predict 'crew' from the assembled 'features' vector.
lr_regressor = LinearRegression(
    featuresCol='features',
    labelCol='crew',
)

In [21]:
# Fit the regression on the training split only.
mod_lr = lr_regressor.fit(df_curz_train)

In [22]:
# Score the fitted model on the held-out test split; the metrics are read from this result.
mod_lr_results = mod_lr.evaluate(df_curz_test)

In [23]:
# Root mean squared error on the test split (same units as the 'crew' label).
mod_lr_results.rootMeanSquaredError

0.7841056116044544

In [24]:
# R² (coefficient of determination) on the test split.
mod_lr_results.r2

0.9437609091202274

In [25]:
# Mean absolute error on the test split (same units as the 'crew' label).
mod_lr_results.meanAbsoluteError

0.6256891413603044

In [26]:
from pyspark.sql.functions import corr

In [27]:
# Correlation between crew size and passenger count over the full dataset.
crew_passenger_corr = df_curz.select(corr('crew', 'passengers'))
crew_passenger_corr.show()

+----------------------+
|corr(crew, passengers)|
+----------------------+
|    0.9152341306065384|
+----------------------+

