In [1]:
import findspark
findspark.init()

In [3]:
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression
from pyspark.sql.functions import corr
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark import SparkContext
from pyspark.sql import SparkSession

In [4]:
sc = SparkContext()
spark = SparkSession(sc)

In [5]:
#Đọc dữ liệu:
data = spark.read.csv('./cruise_ship_info.csv',inferSchema=True,header=True)

In [6]:
data.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)



In [8]:
data.show(3)

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
only showing top 3 rows



In [9]:
#Pre-Processing

In [10]:
from pyspark.ml.feature import StringIndexer
indexer = StringIndexer(inputCol='Cruise_line',outputCol='Cruise_Cate')
indexed = indexer.fit(data).transform(data)
indexed.show(3)

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+-----------+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|Cruise_Cate|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+-----------+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|       16.0|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|       16.0|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|        1.0|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+-----------+
only showing top 3 rows



In [11]:
#Assembler data

In [12]:
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler

In [14]:
indexed.columns

['Ship_name',
 'Cruise_line',
 'Age',
 'Tonnage',
 'passengers',
 'length',
 'cabins',
 'passenger_density',
 'crew',
 'Cruise_Cate']

In [18]:
assembler = VectorAssembler(inputCols=['Age',
 'Tonnage',
 'passengers',
 'length',
 'cabins',
 'passenger_density',
 'Cruise_Cate'],outputCol='feature')

In [19]:
output = assembler.transform(indexed)

In [21]:
final_data = output.select('feature','crew')

In [22]:
#train-test split
train_data, test_data = final_data.randomSplit([0.7,0.3])

In [26]:
#Linear Regression Algorithm

In [24]:
from pyspark.ml.regression import LinearRegression

In [29]:
model = LinearRegression(featuresCol='feature',labelCol='crew',predictionCol='prediction')

In [32]:
Lr_model = model.fit(train_data) #Train model

In [34]:
print('Coefficients:',Lr_model.coefficients)
print('Slope: ',Lr_model.intercept)

Coefficients: [-0.007702230658693383,0.012013313041087059,-0.14868448587361707,0.5115285631344361,0.7998195766529851,-0.0035456397742450628,0.04948571734318666]
Slope:  -1.603432428580548


In [41]:
#Check result

In [35]:
test_result = Lr_model.evaluate(test_data)

In [39]:
test_result.meanSquaredError #MSE

1.347326782135822

In [38]:
test_result.meanAbsoluteError #MSA

0.6861809977397163

In [40]:
test_result.r2 #R2

0.887771223840246

##### Nhận xét:
R2 ~ 0.89 là khá tốt