In [1]:
from pyspark.sql import SparkSession

In [2]:
spark = SparkSession.builder.appName("ML_spark_reg2").getOrCreate()

In [3]:
spark

In [4]:
df_cruise = spark.read.csv("cruise_ship_info.csv", header=True, inferSchema=True)

In [5]:
df_cruise.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|
|    Ecstasy|   Carnival| 22|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Elation|   Carnival| 15|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Fantasy|   Carnival| 23| 

In [6]:
df_cruise.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)



In [7]:
from pyspark.ml.feature import StringIndexer

# Création de l'indexeur
indexer = StringIndexer(inputCol="Cruise_line", outputCol="Cruise_line_index")

# Appliquer la transformation
df_indexed = indexer.fit(df_cruise).transform(df_cruise)

# Afficher le résultat
df_indexed.select("Cruise_line", "Cruise_line_index").show()

+-----------+-----------------+
|Cruise_line|Cruise_line_index|
+-----------+-----------------+
|    Azamara|             16.0|
|    Azamara|             16.0|
|   Carnival|              1.0|
|   Carnival|              1.0|
|   Carnival|              1.0|
|   Carnival|              1.0|
|   Carnival|              1.0|
|   Carnival|              1.0|
|   Carnival|              1.0|
|   Carnival|              1.0|
|   Carnival|              1.0|
|   Carnival|              1.0|
|   Carnival|              1.0|
|   Carnival|              1.0|
|   Carnival|              1.0|
|   Carnival|              1.0|
|   Carnival|              1.0|
|   Carnival|              1.0|
|   Carnival|              1.0|
|   Carnival|              1.0|
+-----------+-----------------+
only showing top 20 rows



In [8]:
df_cruise = df_cruise.join(df_indexed.select("Cruise_line", "Cruise_line_index"), on="Cruise_line", how="left")

In [None]:
#df_cruise = indexer.fit(df_cruise).transform(df_cruise)

In [9]:
df_cruise.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+-----------------+
|Cruise_line|  Ship_name|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|Cruise_line_index|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+-----------------+
|    Azamara|    Journey|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|             16.0|
|    Azamara|    Journey|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|             16.0|
|    Azamara|      Quest|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|             16.0|
|    Azamara|      Quest|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|             16.0|
|   Carnival|Celebration| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|              1.0|
|   Carnival|Celebration| 26|            47.262|     14.86|  7.22|  7.43|       

In [10]:
from pyspark.ml.regression import LinearRegression

In [11]:
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler

In [12]:
df_cruise.columns

['Cruise_line',
 'Ship_name',
 'Age',
 'Tonnage',
 'passengers',
 'length',
 'cabins',
 'passenger_density',
 'crew',
 'Cruise_line_index']

In [13]:
assembler = VectorAssembler(
    inputCols=['Cruise_line_index','Age','Tonnage','passengers','length','cabins','passenger_density'],
    outputCol= "features")

In [14]:
output = assembler.transform(df_cruise)

In [15]:
output.select("features").show()

+--------------------+
|            features|
+--------------------+
|[16.0,6.0,30.2769...|
|[16.0,6.0,30.2769...|
|[16.0,6.0,30.2769...|
|[16.0,6.0,30.2769...|
|[1.0,26.0,47.262,...|
|[1.0,26.0,47.262,...|
|[1.0,26.0,47.262,...|
|[1.0,26.0,47.262,...|
|[1.0,26.0,47.262,...|
|[1.0,26.0,47.262,...|
|[1.0,26.0,47.262,...|
|[1.0,26.0,47.262,...|
|[1.0,26.0,47.262,...|
|[1.0,26.0,47.262,...|
|[1.0,26.0,47.262,...|
|[1.0,26.0,47.262,...|
|[1.0,26.0,47.262,...|
|[1.0,26.0,47.262,...|
|[1.0,26.0,47.262,...|
|[1.0,26.0,47.262,...|
+--------------------+
only showing top 20 rows



In [16]:
final_df_df_cruise = output[['features', "crew"]]

In [17]:
final_df_df_cruise.show()

+--------------------+----+
|            features|crew|
+--------------------+----+
|[16.0,6.0,30.2769...|3.55|
|[16.0,6.0,30.2769...|3.55|
|[16.0,6.0,30.2769...|3.55|
|[16.0,6.0,30.2769...|3.55|
|[1.0,26.0,47.262,...| 6.7|
|[1.0,26.0,47.262,...| 6.7|
|[1.0,26.0,47.262,...| 6.7|
|[1.0,26.0,47.262,...| 6.7|
|[1.0,26.0,47.262,...| 6.7|
|[1.0,26.0,47.262,...| 6.7|
|[1.0,26.0,47.262,...| 6.7|
|[1.0,26.0,47.262,...| 6.7|
|[1.0,26.0,47.262,...| 6.7|
|[1.0,26.0,47.262,...| 6.7|
|[1.0,26.0,47.262,...| 6.7|
|[1.0,26.0,47.262,...| 6.7|
|[1.0,26.0,47.262,...| 6.7|
|[1.0,26.0,47.262,...| 6.7|
|[1.0,26.0,47.262,...| 6.7|
|[1.0,26.0,47.262,...| 6.7|
+--------------------+----+
only showing top 20 rows



In [18]:
train_data, test_data = final_df_df_cruise.randomSplit([0.7,0.3])

In [19]:
train_data.describe().show()

+-------+-----------------+
|summary|             crew|
+-------+-----------------+
|  count|             1473|
|   mean|8.780353021045476|
| stddev| 3.28032205475319|
|    min|             0.59|
|    max|             21.0|
+-------+-----------------+



In [20]:
test_data.describe().show()

+-------+------------------+
|summary|              crew|
+-------+------------------+
|  count|               641|
|   mean| 8.774617784711367|
| stddev|3.0760263864677273|
|    min|              0.59|
|    max|              21.0|
+-------+------------------+



In [21]:
lr = LinearRegression(labelCol='crew')

In [22]:
lrModel = lr.fit(train_data)

In [23]:
# les coefs
print('Coefs: {} Intercept: {}'.format(lrModel.coefficients, lrModel.intercept))

Coefs: [0.08780808215474968,-0.02235925288298329,-0.001245429109476595,-0.1649951511108162,0.4077838984909384,0.9738295604284268,-0.00990801410882119] Intercept: -0.638356055310359


In [24]:
# Evaluation du modele
test_result = lrModel.evaluate(test_data)

In [25]:
test_result.residuals.show()

+--------------------+
|           residuals|
+--------------------+
| 0.20177431397418033|
| 0.20177431397418033|
| 0.20177431397418033|
| 0.20177431397418033|
| -1.2474454343837635|
| -1.2474454343837635|
| -1.2474454343837635|
| -1.2474454343837635|
| -1.2474454343837635|
| -1.2474454343837635|
| -1.2474454343837635|
| -1.2474454343837635|
| -1.2474454343837635|
| -1.2474454343837635|
|0.021350319888464142|
|0.021350319888464142|
|0.021350319888464142|
|0.021350319888464142|
|0.021350319888464142|
| 0.09672147957526889|
+--------------------+
only showing top 20 rows



In [26]:
unlabeled_data = test_data.select('features')

In [27]:
predictions = lrModel.transform(unlabeled_data)

In [28]:
predictions.show()

+--------------------+------------------+
|            features|        prediction|
+--------------------+------------------+
|[0.0,4.0,220.0,54...| 20.79822568602582|
|[0.0,4.0,220.0,54...| 20.79822568602582|
|[0.0,4.0,220.0,54...| 20.79822568602582|
|[0.0,4.0,220.0,54...| 20.79822568602582|
|[0.0,5.0,160.0,36...|14.847445434383763|
|[0.0,5.0,160.0,36...|14.847445434383763|
|[0.0,5.0,160.0,36...|14.847445434383763|
|[0.0,5.0,160.0,36...|14.847445434383763|
|[0.0,5.0,160.0,36...|14.847445434383763|
|[0.0,5.0,160.0,36...|14.847445434383763|
|[0.0,5.0,160.0,36...|14.847445434383763|
|[0.0,5.0,160.0,36...|14.847445434383763|
|[0.0,5.0,160.0,36...|14.847445434383763|
|[0.0,5.0,160.0,36...|14.847445434383763|
|[0.0,6.0,158.0,43...|13.578649680111536|
|[0.0,6.0,158.0,43...|13.578649680111536|
|[0.0,6.0,158.0,43...|13.578649680111536|
|[0.0,6.0,158.0,43...|13.578649680111536|
|[0.0,6.0,158.0,43...|13.578649680111536|
|[0.0,7.0,158.0,43...| 13.50327852042473|
+--------------------+------------

In [29]:
print('RMSE: {}'.format(test_result.rootMeanSquaredError))

RMSE: 1.017708380417418


In [30]:
print('MSE: {}'.format(test_result.meanSquaredError))

MSE: 1.035730347571844


In [31]:
test_result.r2

0.8903661554288541