In [1]:
from pyspark.sql import SparkSession

In [2]:
spark = SparkSession.builder.appName("ExampleRegression").getOrCreate()

In [3]:
raw_data = spark.read.csv('cruise_ship_info.csv', inferSchema=True, header=True)

In [4]:
raw_data.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)



In [6]:
raw_data.describe().show(vertical=True)

-RECORD 0-------------------------------
 summary           | count              
 Ship_name         | 158                
 Cruise_line       | 158                
 Age               | 158                
 Tonnage           | 158                
 passengers        | 158                
 length            | 158                
 cabins            | 158                
 passenger_density | 158                
 crew              | 158                
-RECORD 1-------------------------------
 summary           | mean               
 Ship_name         | Infinity           
 Cruise_line       | null               
 Age               | 15.689873417721518 
 Tonnage           | 71.28467088607599  
 passengers        | 18.45740506329114  
 length            | 8.130632911392404  
 cabins            | 8.830000000000005  
 passenger_density | 39.90094936708861  
 crew              | 7.794177215189873  
-RECORD 2-------------------------------
 summary           | stddev             
 Ship_name      

In [7]:
raw_data.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|
|    Ecstasy|   Carnival| 22|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Elation|   Carnival| 15|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Fantasy|   Carnival| 23| 

In [8]:
from pyspark.ml.feature import StringIndexer

  return f(*args, **kwds)


In [9]:
str_indexer = StringIndexer(inputCol='Cruise_line', outputCol='cline_index')























In [10]:
cline_indexed_data = str_indexer.fit(raw_data).transform(raw_data)


























In [13]:
cline_indexed_data.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+-----------+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|cline_index|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+-----------+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|       16.0|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|       16.0|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|        1.0|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|        1.0|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|        1.0|
|    Ecstasy|   Carnival| 22|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|        1.0|
|    Elation|   Car

In [14]:
from pyspark.sql.functions import corr

In [21]:
cline_indexed_data.select(corr('crew','passenger_density')).show()










+-----------------------------+
|corr(crew, passenger_density)|
+-----------------------------+
|         -0.15550928421699717|
+-----------------------------+



In [23]:
cline_indexed_data.columns


['Ship_name',
 'Cruise_line',
 'Age',
 'Tonnage',
 'passengers',
 'length',
 'cabins',
 'passenger_density',
 'crew',
 'cline_index']

In [24]:
required_features = cline_indexed_data.select(['Ship_name',
 'Tonnage',
 'passengers',
 'length',
 'cabins',
 'crew',
 'cline_index'])

In [25]:
required_features.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- crew: double (nullable = true)
 |-- cline_index: double (nullable = false)



In [26]:
from pyspark.ml.feature import VectorAssembler

























In [27]:
vect_assembler = VectorAssembler(inputCols=['Tonnage',
                                           'passengers',
                                           'length',
                                           'cabins',
                                           'cline_index'],
                                outputCol='features')


























In [28]:
feature_assembled_data = vect_assembler.transform(required_features)


























In [31]:
feature_assembled_data.show(vertical=True, truncate=False)

























-RECORD 0-----------------------------------------------
 Ship_name   | Journey                                  
 Tonnage     | 30.276999999999997                       
 passengers  | 6.94                                     
 length      | 5.94                                     
 cabins      | 3.55                                     
 crew        | 3.55                                     
 cline_index | 16.0                                     
 features    | [30.276999999999997,6.94,5.94,3.55,16.0] 
-RECORD 1-----------------------------------------------
 Ship_name   | Quest                                    
 Tonnage     | 30.276999999999997                       
 passengers  | 6.94                                     
 length      | 5.94                                     
 cabins      | 3.55                                     
 crew        | 3.55                                     
 cline_index | 16.0                                     
 features    | [30.276999999999

In [32]:
train_set, test_set = feature_assembled_data.randomSplit([0.7, 0.3])
































In [34]:
test_set.count()





























44

In [35]:
from pyspark.ml.regression import LinearRegression



























In [36]:
lin_regressor = LinearRegression(featuresCol='features', labelCol='crew')


























In [37]:
lin_reg_model = lin_regressor.fit(train_set)



























In [38]:
eval_results = lin_reg_model.evaluate(test_set)


























In [41]:
eval_results.rootMeanSquaredError





























1.2459361682438879

In [42]:
unlabelled_data = test_set.select('Ship_name','features','crew').withColumnRenamed('crew','actualCrew')























In [43]:
unlabelled_data.show()
























+-------------+--------------------+----------+
|    Ship_name|            features|actualCrew|
+-------------+--------------------+----------+
|        Aries|[3.341,0.66,2.8,0...|      0.59|
|      Armonia|[58.6,15.66,8.24,...|       7.0|
|    Atlantica|[85.619,21.14,9.5...|       9.2|
|     Conquest|[110.0,29.74,9.53...|      19.1|
|Constellation|[91.0,20.32,9.65,...|      9.99|
|        Crown|[116.0,31.0,9.51,...|      12.0|
|      Destiny|[101.353,26.42,8....|      10.0|
|      Emerald|[113.0,37.82,9.51...|      12.0|
|       Europa|[53.872,14.94,7.9...|      6.36|
|     Fantasia|[133.5,39.59,10.9...|     13.13|
|       Galaxy|[77.7130000000000...|      9.09|
|      Holiday|[46.052,14.52,7.2...|       6.6|
|  Imagination|[70.367,20.52,8.5...|       9.2|
|  Inspiration|[70.367,20.52,8.5...|       9.2|
|       Island|[91.6270000000000...|       9.0|
|      Liberty|[158.0,43.7,11.25...|      13.6|
|       Lirica|[58.825,15.6,8.23...|       7.0|
|      Majesty|[73.941,27.44,8.8...|    

In [44]:
predictions = lin_reg_model.transform(unlabelled_data)


























In [45]:
predictions.show(n=100)


























+--------------+--------------------+----------+------------------+
|     Ship_name|            features|actualCrew|        prediction|
+--------------+--------------------+----------+------------------+
|         Aries|[3.341,0.66,2.8,0...|      0.59|0.3287301893185721|
|       Armonia|[58.6,15.66,8.24,...|       7.0|7.3003920156484305|
|     Atlantica|[85.619,21.14,9.5...|       9.2|  9.55814150436188|
|      Conquest|[110.0,29.74,9.53...|      19.1|11.929220462089726|
| Constellation|[91.0,20.32,9.65,...|      9.99|  9.17181711200157|
|         Crown|[116.0,31.0,9.51,...|      12.0|12.419151770127177|
|       Destiny|[101.353,26.42,8....|      10.0| 10.67423320508528|
|       Emerald|[113.0,37.82,9.51...|      12.0|11.451887241225315|
|        Europa|[53.872,14.94,7.9...|      6.36| 7.005513364578729|
|      Fantasia|[133.5,39.59,10.9...|     13.13|12.959010739756893|
|        Galaxy|[77.7130000000000...|      9.09| 8.445769546385995|
|       Holiday|[46.052,14.52,7.2...|       6.6|