In [None]:
# Environment bootstrap: installs Java, downloads a Spark binary distribution,
# points JAVA_HOME/SPARK_HOME at them and creates the SparkSession (`spark`)
# used by the rest of the notebook.

# Application name passed to the SparkSession builder below.
appname = "crew_members assignment"

# Look into https://spark.apache.org/downloads.html for the latest version.
# The three values below are interpolated into the download URL and into the
# name of the extracted directory.
spark_mirror = "https://mirrors.sonic.net/apache/spark"
spark_version = "3.3.1"
hadoop_version = "3"

# Install Java 8.
# NOTE(review): the original note here said Spark does not work with newer Java
# versions; the Spark 3.3.x documentation appears to list Java 8/11/17 as
# supported - confirm before changing the JDK.
# If the openjdk install fails, run `! apt-get update` first and retry.
! apt-get install openjdk-8-jdk-headless -qq > /dev/null

# Download and extract Spark binary distribution.
# Any previous archive/directory is removed first so the cell can be re-run.
# wget runs with -q, so a failed download prints nothing; the following tar
# step is where such a failure would surface.
! rm -rf spark-{spark_version}-bin-hadoop{hadoop_version}.tgz spark-{spark_version}-bin-hadoop{hadoop_version}
! wget -q {spark_mirror}/spark-{spark_version}/spark-{spark_version}-bin-hadoop{hadoop_version}.tgz
! tar xzf spark-{spark_version}-bin-hadoop{hadoop_version}.tgz

# The only 2 environment variables needed to set up Java and Spark.
# SPARK_HOME assumes the archive was extracted under /content (presumably the
# notebook's working directory on Colab) - TODO confirm when running elsewhere.
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["SPARK_HOME"] = f"/content/spark-{spark_version}-bin-hadoop{hadoop_version}"

# Set up the Spark environment based on the environment variable SPARK_HOME
! pip install -q findspark
import findspark
findspark.init()

# Get the Spark session object (basic entry point for every operation).
# master("local[*]") runs Spark locally inside this machine.
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName(appname).master("local[*]").getOrCreate()

In [None]:
from pyspark import SparkFiles

# Path of the cruise-ship dataset; used both for registering the file with the
# Spark context and for reading it.
csv_path = '/content/drive/MyDrive/Colab Notebooks/cruise_ship_info.csv'
spark.sparkContext.addFile(csv_path)

# Treat the first line of the file as column names.
csv_reader = spark.read.option("header", "true")
crew = csv_reader.csv(csv_path)

# Preview the first 10 rows.
crew.show(10)

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|
|    Ecstasy|   Carnival| 22|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Elation|   Carnival| 15|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Fantasy|   Carnival| 23| 

In [None]:
# Print the column names and types Spark assigned to the loaded DataFrame.
crew.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: string (nullable = true)
 |-- Tonnage: string (nullable = true)
 |-- passengers: string (nullable = true)
 |-- length: string (nullable = true)
 |-- cabins: string (nullable = true)
 |-- passenger_density: string (nullable = true)
 |-- crew: string (nullable = true)



In [None]:
# Summary statistics for every column, converted to a pandas DataFrame for display.
# The columns are still string-typed at this point, so the min/max rows are
# ordered as text rather than numerically (e.g. Age min "10.0", max "9.0").
crew.describe().toPandas()

Unnamed: 0,summary,Ship_name,Cruise_line,Age,Tonnage,passengers,length,cabins,passenger_density,crew
0,count,158,158,158.0,158.0,158.0,158.0,158.0,158.0,158.0
1,mean,Infinity,,15.689873417721518,71.28467088607599,18.45740506329114,8.130632911392404,8.830000000000005,39.90094936708861,7.794177215189873
2,stddev,,,7.615691058751413,37.229540025907866,9.677094775143416,1.793473548054825,4.4714172221480615,8.63921711391542,3.503486564627034
3,min,Adventure,Azamara,10.0,10.0,0.66,10.2,0.33,17.7,0.59
4,max,Zuiderdam,Windstar,9.0,93.0,9.52,9.65,9.87,71.43,9.99


In [None]:
# Count, per column, how many entries are NaN or null.
from pyspark.sql.functions import *

# One aggregate expression per column: `when` yields a non-null value only for
# missing entries, and `count` tallies the non-null results.
missing_counts = [
    count(when(isnan(name) | col(name).isNull(), name)).alias(name)
    for name in crew.columns
]
crew.select(missing_counts).show()

+---------+-----------+---+-------+----------+------+------+-----------------+----+
|Ship_name|Cruise_line|Age|Tonnage|passengers|length|cabins|passenger_density|crew|
+---------+-----------+---+-------+----------+------+------+-----------------+----+
|        0|          0|  0|      0|         0|     0|     0|                0|   0|
+---------+-----------+---+-------+----------+------+------+-----------------+----+



In [None]:
# Drop the Ship_name column, then cast every column except the first one
# (left untouched) to float in a single projection.
data2 = crew.drop('Ship_name')

untouched_cols = data2.columns[:1]
float_cols = [col(name).cast("float").alias(name) for name in data2.columns[1:]]
data2 = data2.select(*untouched_cols, *float_cols)

In [None]:
# Confirm the casts: every column except Cruise_line should now be float.
data2.printSchema()

# Split the rows into training and evaluation sets with weights 0.67 / 0.33
# (the resulting sizes are approximate, not exact).
# A fixed seed keeps the split - and therefore the fitted model and its
# predictions - the same across Restart & Run All.
train_data, test_data = data2.randomSplit([0.67, 0.33], seed=42)

root
 |-- Cruise_line: string (nullable = true)
 |-- Age: float (nullable = true)
 |-- Tonnage: float (nullable = true)
 |-- passengers: float (nullable = true)
 |-- length: float (nullable = true)
 |-- cabins: float (nullable = true)
 |-- passenger_density: float (nullable = true)
 |-- crew: float (nullable = true)



In [None]:
# Preview the first five rows of the training split.
train_data.show(5)


+-----------+---+-------+----------+------+------+-----------------+----+
|Cruise_line|Age|Tonnage|passengers|length|cabins|passenger_density|crew|
+-----------+---+-------+----------+------+------+-----------------+----+
|    Azamara|  6| 30.277|      6.94|  5.94|  3.55|            42.64|3.55|
|   Carnival| 10|  110.0|     29.74|  9.51| 14.87|            36.99|11.6|
|   Carnival| 11|  110.0|     29.74|  9.53| 14.88|            36.99|19.1|
|   Carnival| 12|   88.5|     21.24|  9.63| 11.62|            41.67| 9.3|
|   Carnival| 15| 70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
+-----------+---+-------+----------+------+------+-----------------+----+
only showing top 5 rows



In [None]:
from pyspark.ml.feature import (OneHotEncoder, StringIndexer)

# Map each Cruise_line string to a numeric index.
# handleInvalid='keep' routes labels that were not present when the indexer was
# fitted into one extra index instead of raising during transform. With a
# random split of a small dataset, a cruise line can end up only in the test
# split, which would otherwise make model.transform(test_data) fail.
cruise_indexer = StringIndexer(
    inputCol='Cruise_line',
    outputCol='Cruise_index',
    handleInvalid='keep',
)

# One-hot encode the numeric index into the sparse vector column 'Cruise'.
cruise_encoder = OneHotEncoder(inputCol='Cruise_index', outputCol='Cruise')

In [None]:
from pyspark.ml.feature import VectorAssembler

# Model inputs: the encoded cruise line followed by the numeric ship columns.
# 'crew' is deliberately absent - it is the regression label.
feature_columns = [
    'Cruise',
    'Age',
    'Tonnage',
    'passengers',
    'length',
    'cabins',
    'passenger_density',
]

# Combine the inputs into a single vector column named 'features'.
assembler = VectorAssembler(inputCols=feature_columns, outputCol='features')


In [None]:
from pyspark.ml.regression import LinearRegression

# Linear regression estimator: reads the assembled 'features' vector column
# and uses the 'crew' column as the label to predict.
lr = LinearRegression(featuresCol='features', labelCol='crew')

In [None]:
from pyspark.ml import Pipeline

# Stages run in this order: index the cruise line, one-hot encode it,
# assemble the feature vector, then fit the regression.
pipeline_stages = [cruise_indexer, cruise_encoder, assembler, lr]
pipeline = Pipeline(stages=pipeline_stages)

# Fit every stage on the training split only.
model = pipeline.fit(train_data)

In [None]:
# Run the fitted pipeline on the held-out split. The result keeps the input
# columns and adds the intermediate pipeline columns plus 'prediction'.
output = model.transform(test_data)
# Display up to 50 rows of the predictions next to the actual 'crew' values.
output.show(50)

+-----------------+----+-------+----------+------+------+-----------------+-----+------------+---------------+--------------------+------------------+
|      Cruise_line| Age|Tonnage|passengers|length|cabins|passenger_density| crew|Cruise_index|         Cruise|            features|        prediction|
+-----------------+----+-------+----------+------+------+-----------------+-----+------------+---------------+--------------------+------------------+
|         Carnival| 6.0|110.239|      37.0|  9.51| 14.87|            29.79| 11.5|         1.0| (19,[1],[1.0])|(25,[1,19,20,21,2...|11.503542605606622|
|         Carnival| 9.0|  110.0|     29.74|  9.52| 14.87|            36.99| 11.6|         1.0| (19,[1],[1.0])|(25,[1,19,20,21,2...| 12.24390930882047|
|         Carnival|11.0|   86.0|     21.24|  9.63| 10.62|            40.49|  9.3|         1.0| (19,[1],[1.0])|(25,[1,19,20,21,2...| 9.793082181844884|
|         Carnival|11.0|  110.0|     29.74|  9.53| 14.88|            36.99| 19.1|         1.0|