In [1]:
from pyspark.sql import SparkSession

In [2]:
# Start (or reuse) the Spark session that every later cell reads through.
spark = (
    SparkSession.builder
    .appName('test')
    .getOrCreate()
)

In [3]:
# Load the cruise-ship CSV: column names come from the header row and
# column types are inferred by Spark rather than declared here.
data = spark.read.csv(
    'cruise_ship_info.csv',
    header=True,
    inferSchema=True,
)

In [4]:
# Preview the first rows of the loaded table.
data.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|
|    Ecstasy|   Carnival| 22|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Elation|   Carnival| 15|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Fantasy|   Carnival| 23| 

In [5]:
# Check the column types that inferSchema assigned.
data.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)



In [6]:
# Summary statistics (count, mean, stddev, min, ...) for every column.
data.describe().show()

+-------+---------+-----------+------------------+------------------+-----------------+-----------------+------------------+-----------------+-----------------+
|summary|Ship_name|Cruise_line|               Age|           Tonnage|       passengers|           length|            cabins|passenger_density|             crew|
+-------+---------+-----------+------------------+------------------+-----------------+-----------------+------------------+-----------------+-----------------+
|  count|      158|        158|               158|               158|              158|              158|               158|              158|              158|
|   mean| Infinity|       NULL|15.689873417721518| 71.28467088607599|18.45740506329114|8.130632911392404| 8.830000000000005|39.90094936708861|7.794177215189873|
| stddev|     NULL|       NULL| 7.615691058751413|37.229540025907866|9.677094775143416|1.793473548054825|4.4714172221480615| 8.63921711391542|3.503486564627034|
|    min|Adventure|    Azamara|   

### String to Index

In [8]:
from pyspark.ml.feature import StringIndexer

In [9]:
# Single-column frame holding only the categorical Cruise_line values;
# this is what the StringIndexer is fitted on below.
df = data.select('Cruise_line')

In [10]:
# Peek at the first 5 Cruise_line values.
df.show(5)

+-----------+
|Cruise_line|
+-----------+
|    Azamara|
|    Azamara|
|   Carnival|
|   Carnival|
|   Carnival|
+-----------+
only showing top 5 rows



In [11]:
# Indexer that maps each Cruise_line string to a numeric index, written
# to a new Cruise_line_indx column.
indexer = StringIndexer(
    inputCol='Cruise_line',
    outputCol='Cruise_line_indx',
)

In [12]:
# Fit the string -> index mapping on the single-column frame `df`, then
# apply it to the full `data` table so all original columns are kept
# alongside the new Cruise_line_indx column.
indexed = (
    indexer
    .fit(df)
    .transform(data)
)

In [13]:
# Inspect the table with the added Cruise_line_indx column.
indexed.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+----------------+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|Cruise_line_indx|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+----------------+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|            16.0|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|            16.0|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|             1.0|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|             1.0|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|             1.0|
|    Ecstasy|   Carnival| 22|            70.367|     20.52|  8.55|  10.2|            34.

In [14]:
# Rebind `df` to the full indexed table. Before this cell `df` was the
# single-column Cruise_line selection, so the name changes meaning here.
df = indexed

In [15]:
# Confirm `df` now carries every column plus Cruise_line_indx.
df.show(2)

+---------+-----------+---+------------------+----------+------+------+-----------------+----+----------------+
|Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|Cruise_line_indx|
+---------+-----------+---+------------------+----------+------+------+-----------------+----+----------------+
|  Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|            16.0|
|    Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|            16.0|
+---------+-----------+---+------------------+----------+------+------+-----------------+----+----------------+
only showing top 2 rows



In [16]:
# Schema of the indexed table; Cruise_line_indx is a double column.
df.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)
 |-- Cruise_line_indx: double (nullable = false)



In [17]:
# List the column names of the indexed table.
df.columns

['Ship_name',
 'Cruise_line',
 'Age',
 'Tonnage',
 'passengers',
 'length',
 'cabins',
 'passenger_density',
 'crew',
 'Cruise_line_indx']

In [18]:
from pyspark.ml.regression import LinearRegression
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler

In [19]:
# Pack the predictor columns into a single `features` vector column.
# `crew` is left out because it is the regression label; the two string
# columns (Ship_name, Cruise_line) are left out as well.
# NOTE(review): Cruise_line_indx is a category index that is used here as
# an ordinary numeric feature - confirm this is intended, or consider
# one-hot encoding it before assembling.
assembler = VectorAssembler(
    inputCols=[
        'Age',
        'Tonnage',
        'passengers',
        'length',
        'cabins',
        'passenger_density',
        'Cruise_line_indx',
    ],
    outputCol='features',
)

In [20]:
# Append the assembled `features` vector column to the indexed table.
output = assembler.transform(df)

In [21]:
# Inspect the table with the new `features` column.
output.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+----------------+--------------------+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|Cruise_line_indx|            features|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+----------------+--------------------+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|            16.0|[6.0,30.276999999...|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|            16.0|[6.0,30.276999999...|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|             1.0|[26.0,47.262,14.8...|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|             1.0|[11.0,110.0,29.74...|
|    Destiny|   Carnival| 17|           101.353|     26

In [22]:
# Keep only what the regression needs: the feature vector and the label.
# The schema column is named `crew`; selecting it as 'Crew' produces a
# column labelled 'Crew' (visible in the describe() outputs below), and
# that exact spelling is what is later passed as labelCol. This relies on
# Spark resolving the name case-insensitively, which is its default.
final_data = output.select('features', 'Crew')

### Model Building

In [24]:
# Split the rows roughly 70/30 into training and evaluation sets.
# The seed is fixed so that Restart & Run All reproduces the same split,
# and with it the same fitted model and metrics; without a seed every
# run samples differently and the reported numbers cannot be reproduced.
# NOTE: outputs saved in this notebook came from an unseeded run, so the
# counts and metrics will change once when the notebook is re-executed.
train_data, test_data = final_data.randomSplit([0.7, 0.3], seed=42)

In [25]:
# Row count and label statistics for the training split.
train_data.describe().show()

+-------+-----------------+
|summary|             Crew|
+-------+-----------------+
|  count|              112|
|   mean| 8.00125000000001|
| stddev|3.639712176770263|
|    min|             0.59|
|    max|             21.0|
+-------+-----------------+



In [26]:
# Row count and label statistics for the test split.
test_data.describe().show()

+-------+------------------+
|summary|              Crew|
+-------+------------------+
|  count|                46|
|   mean| 7.290000000000001|
| stddev|3.1273531584676797|
|    min|              1.46|
|    max|              13.6|
+-------+------------------+



In [27]:
# Linear regression estimator: predicts the Crew label from the
# assembled `features` vector ('features' is also the library default).
lr = LinearRegression(
    featuresCol='features',
    labelCol='Crew',
)

In [28]:
# Fit the regression on the training split only.
lr_model = lr.fit(train_data)

### Evaluating the model

In [30]:
# Evaluate the fitted model on the held-out split; the returned summary
# exposes the residuals, RMSE and R2 read in the cells below.
test_results = lr_model.evaluate(test_data)

In [31]:
# Per-row residuals of the model on the test split.
test_results.residuals.show()

+--------------------+
|           residuals|
+--------------------+
|  0.2454646457347227|
|   2.049862020971175|
| -1.3769134427142227|
| 0.38564099948240305|
| -0.8859164439555656|
|-0.32889496143448405|
| 0.35416273180043234|
|  0.7453722973787915|
| -0.5057481472598564|
|-0.07992679041674933|
|   2.201822162460612|
|-0.05423217423086...|
|   0.748526974338743|
|0.057215306138388655|
|  0.9678114675880565|
|  0.1946642154117626|
|  0.3356682345450497|
| -0.8043309291061167|
|   0.045006877887797|
|-0.30533771334581106|
+--------------------+
only showing top 20 rows



### Root Mean Squared Error

In [33]:
# Root mean squared error on the test split (same units as the label).
test_results.rootMeanSquaredError

0.7077411174355042

### R2 value

In [35]:
# R2 (coefficient of determination) on the test split.
test_results.r2

0.9476471512272073

In [36]:
# Label statistics over the full dataset - presumably shown so the RMSE
# above can be judged against the label's mean and spread.
final_data.describe().show()

+-------+-----------------+
|summary|             Crew|
+-------+-----------------+
|  count|              158|
|   mean|7.794177215189873|
| stddev|3.503486564627034|
|    min|             0.59|
|    max|             21.0|
+-------+-----------------+



### Predictions

In [38]:
# Drop the label so only the feature vectors are passed to the model.
unlabeled_data = test_data.select('features')

In [39]:
# Score the feature-only rows; transform adds a `prediction` column.
predictions = lr_model.transform(unlabeled_data)

In [40]:
# Show the feature vectors next to their predicted values.
predictions.show()

+--------------------+------------------+
|            features|        prediction|
+--------------------+------------------+
|[5.0,115.0,35.74,...|11.954535354265277|
|[5.0,122.0,28.5,1...| 4.650137979028825|
|[5.0,160.0,36.34,...|14.976913442714222|
|[6.0,110.23899999...|11.114359000517597|
|[7.0,116.0,31.0,9...|12.885916443955566|
|[9.0,59.058,17.0,...| 7.728894961434484|
|[9.0,81.0,21.44,9...| 9.645837268199568|
|[9.0,88.5,21.24,9...|  9.55462770262121|
|[10.0,86.0,21.14,...| 9.705748147259856|
|[10.0,90.09,25.01...|  8.65992679041675|
|[10.0,151.4,26.2,...|10.328177837539387|
|[11.0,91.62700000...| 9.054232174230862|
|[12.0,77.104,20.0...| 8.841473025661257|
|[12.0,90.09,25.01...| 8.622784693861611|
|[12.0,91.0,20.32,...| 9.022188532411944|
|[13.0,101.509,27....|11.305335784588237|
|[14.0,30.27699999...|3.3943317654549503|
|[14.0,77.104,20.0...| 8.804330929106117|
|[15.0,30.27699999...| 3.954993122112203|
|[16.0,19.2,3.2,5....| 2.415337713345811|
+--------------------+------------

### Correlation Check

In [42]:
from pyspark.sql.functions import corr

In [43]:
# Correlation between the `crew` and `passengers` columns, computed over
# the full indexed table (not just one of the splits).
df.select(corr('crew','passengers')).show()

+----------------------+
|corr(crew, passengers)|
+----------------------+
|    0.9152341306065384|
+----------------------+

