In [1]:
from pyspark.sql import SparkSession

In [2]:
spark = SparkSession.builder.appName("lr_example").getOrCreate()

23/12/13 20:55:46 WARN Utils: Your hostname, MyPySpark resolves to a loopback address: 127.0.1.1; using 10.0.2.15 instead (on interface enp0s3)
23/12/13 20:55:46 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/12/13 20:55:47 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
23/12/13 20:55:48 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [3]:
from pyspark.ml.regression import LinearRegression

In [4]:
data = spark.read.csv("cruise_ship_info(2).csv",inferSchema=True,header=True)

In [5]:
data.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)



In [6]:
data.show(2)

+---------+-----------+---+------------------+----------+------+------+-----------------+----+
|Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|
+---------+-----------+---+------------------+----------+------+------+-----------------+----+
|  Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|    Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
+---------+-----------+---+------------------+----------+------+------+-----------------+----+
only showing top 2 rows



In [7]:
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import StringIndexer, VectorAssembler


In [8]:
index = StringIndexer(
    inputCol = "Cruise_line",
    outputCol = "CruiseLine")

indexer = index.fit(data).transform(data)


In [9]:
indexer.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)
 |-- CruiseLine: double (nullable = false)



In [10]:
indexer.columns


['Ship_name',
 'Cruise_line',
 'Age',
 'Tonnage',
 'passengers',
 'length',
 'cabins',
 'passenger_density',
 'crew',
 'CruiseLine']

In [11]:
assembler = VectorAssembler(
    inputCols = ['passengers', 'length',
                 'cabins', 'passenger_density', 'CruiseLine'],
    outputCol = "features")

output = assembler.transform(indexer)
output.printSchema()


root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)
 |-- CruiseLine: double (nullable = false)
 |-- features: vector (nullable = true)



In [14]:


final_data = output.select('features','crew')
train_data,test_data = final_data.randomSplit([0.7,0.3])
unlabelled = test_data.select('features')



In [15]:
final_data.show()

+--------------------+----+
|            features|crew|
+--------------------+----+
|[6.94,5.94,3.55,4...|3.55|
|[6.94,5.94,3.55,4...|3.55|
|[14.86,7.22,7.43,...| 6.7|
|[29.74,9.53,14.88...|19.1|
|[26.42,8.92,13.21...|10.0|
|[20.52,8.55,10.2,...| 9.2|
|[20.52,8.55,10.2,...| 9.2|
|[20.56,8.55,10.22...| 9.2|
|[20.52,8.55,10.2,...| 9.2|
|[37.0,9.51,14.87,...|11.5|
|[29.74,9.51,14.87...|11.6|
|[14.52,7.27,7.26,...| 6.6|
|[20.52,8.55,10.2,...| 9.2|
|[20.52,8.55,10.2,...| 9.2|
|[21.24,9.63,10.62...| 9.3|
|[29.74,9.51,14.87...|11.6|
|[21.24,9.63,10.62...|10.3|
|[20.52,8.55,10.2,...| 9.2|
|[21.24,9.63,11.62...| 9.3|
|[20.52,8.55,10.2,...| 9.2|
+--------------------+----+
only showing top 20 rows



In [16]:
train_data.describe().show()

+-------+------------------+
|summary|              crew|
+-------+------------------+
|  count|               110|
|   mean| 7.678727272727275|
| stddev|3.4067766004936977|
|    min|              0.59|
|    max|              21.0|
+-------+------------------+



In [17]:
lr = LinearRegression(labelCol='crew', featuresCol='features')
lrModel = lr.fit(train_data)

23/12/13 20:57:17 WARN Instrumentation: [4fad02f3] regParam is zero, which might cause numerical instability and overfitting.
23/12/13 20:57:18 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
23/12/13 20:57:18 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.lapack.JNILAPACK


In [18]:
print("Intercept :{}".format(lrModel.intercept))
print("Coefficients :{}".format(lrModel.coefficients))

Intercept :-2.3804720584197128
Coefficients :[-0.08146748931037798,0.49006169858563,0.7611589615677744,0.01666182329094658,0.040270008997525045]


In [19]:
test_results = lrModel.evaluate(test_data)

In [20]:
print("R-Squared :{}".format(test_results.r2))
print("RMSE :{}".format(test_results.rootMeanSquaredError))

R-Squared :0.8773845952788832
RMSE :1.2956785875671035


In [21]:
deployresult = lrModel.transform(unlabelled)


In [22]:
deployresult.show()

+--------------------+--------------------+
|            features|          prediction|
+--------------------+--------------------+
|[0.94,2.96,0.45,2...|-0.00944730271769...|
|[2.08,4.4,1.04,48...|  1.7628329474160207|
|[2.08,4.4,1.04,48...|  1.7628329474160207|
|[3.2,5.13,1.6,60....|  2.4931123154718433|
|[3.88,5.97,1.94,6...|  3.2222421826621774|
|[3.94,4.36,0.88,3...|  1.0364146687073719|
|[6.84,5.94,3.42,4...|  3.8373828686828535|
|[6.86,5.93,3.44,4...|   3.401106563374456|
|[6.94,5.94,3.55,4...|   4.022004658016868|
|[7.49,6.74,3.96,5...|   4.292606105463383|
|[7.76,6.22,3.86,3...|   3.711791572807906|
|[8.0,5.37,4.0,23....|  3.4042029977058146|
|[13.02,7.18,6.54,...|   5.934096997021838|
|[14.4,7.77,7.2,43...|   6.584284812780654|
|[15.04,7.08,7.52,...|   6.214253879834365|
|[15.66,8.23,7.83,...|   7.242204997945075|
|[17.5,9.64,8.75,4...|   9.096478753802497|
|[18.0,8.67,9.0,38...|   7.900357022625778|
|[18.48,9.51,9.24,...|   8.694858195936197|
|[18.48,9.59,9.24,...|    8.7049