In [85]:
import pyspark

In [86]:
from pyspark.ml.feature import  StringIndexer

In [87]:
from pyspark import SparkContext
from pyspark.sql import SparkSession
from pyspark import SparkConf
from pyspark.sql.functions import desc
from pyspark.sql.functions import *
from pyspark.sql.functions import max as sparkMax
from pyspark.sql.functions import col, lit
from pyspark.sql.types import StringType
import pyspark.sql.functions as F

In [88]:
mySpark = SparkSession.builder.getOrCreate()
spark = SparkSession(mySpark)

In [89]:
SparkSession \
  .builder \
  .master("local[*]")\
  .appName("Pyspark") \
  .config("spark.memory.fraction", 0.8) \
  .config("spark.executor.memory", "8g") \
  .config("spark.driver.memory", "8g")\
  .config("spark.sql.shuffle.partitions" , "800") \
  .config("spark.memory.offHeap.enabled",'true')\
  .config("spark.memory.offHeap.size","8g")\
  .getOrCreate()

In [90]:
file = "C:/Users/pavel/data/hw_22/titanic_train.csv"

In [91]:
data = spark.read.csv(file,header=True,inferSchema=True)

In [92]:
data.printSchema()

root
 |-- PassengerId: integer (nullable = true)
 |-- Survived: integer (nullable = true)
 |-- Pclass: integer (nullable = true)
 |-- Name: string (nullable = true)
 |-- Sex: string (nullable = true)
 |-- Age: double (nullable = true)
 |-- SibSp: integer (nullable = true)
 |-- Parch: integer (nullable = true)
 |-- Ticket: string (nullable = true)
 |-- Fare: double (nullable = true)
 |-- Cabin: string (nullable = true)
 |-- Embarked: string (nullable = true)



In [93]:
data.show()

+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+
|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+
|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|       A/5 21171|   7.25| NULL|       S|
|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0|        PC 17599|71.2833|  C85|       C|
|          3|       1|     3|Heikkinen, Miss. ...|female|26.0|    0|    0|STON/O2. 3101282|  7.925| NULL|       S|
|          4|       1|     1|Futrelle, Mrs. Ja...|female|35.0|    1|    0|          113803|   53.1| C123|       S|
|          5|       0|     3|Allen, Mr. Willia...|  male|35.0|    0|    0|          373450|   8.05| NULL|       S|
|          6|       0|     3|    Moran, Mr. James|  male|NULL|    0|    0|      

In [94]:
data = data.drop('Name','Ticket')

In [95]:
data = data.withColumn('Fare',round('Fare'))

In [96]:
data.show(5)

+-----------+--------+------+------+----+-----+-----+----+-----+--------+
|PassengerId|Survived|Pclass|   Sex| Age|SibSp|Parch|Fare|Cabin|Embarked|
+-----------+--------+------+------+----+-----+-----+----+-----+--------+
|          1|       0|     3|  male|22.0|    1|    0| 7.0| NULL|       S|
|          2|       1|     1|female|38.0|    1|    0|71.0|  C85|       C|
|          3|       1|     3|female|26.0|    0|    0| 8.0| NULL|       S|
|          4|       1|     1|female|35.0|    1|    0|53.0| C123|       S|
|          5|       0|     3|  male|35.0|    0|    0| 8.0| NULL|       S|
+-----------+--------+------+------+----+-----+-----+----+-----+--------+
only showing top 5 rows



In [97]:
null_counts = data.agg(*[sum(col(c).isNull().cast("int")).alias(c) for c in data.columns])

null_counts.show()

+-----------+--------+------+---+---+-----+-----+----+-----+--------+
|PassengerId|Survived|Pclass|Sex|Age|SibSp|Parch|Fare|Cabin|Embarked|
+-----------+--------+------+---+---+-----+-----+----+-----+--------+
|          0|       0|     0|  0|177|    0|    0|   0|  687|       2|
+-----------+--------+------+---+---+-----+-----+----+-----+--------+



In [98]:
data.select('Age').show()

+----+
| Age|
+----+
|22.0|
|38.0|
|26.0|
|35.0|
|35.0|
|NULL|
|54.0|
| 2.0|
|27.0|
|14.0|
| 4.0|
|58.0|
|20.0|
|39.0|
|14.0|
|55.0|
| 2.0|
|NULL|
|31.0|
|NULL|
+----+
only showing top 20 rows



In [99]:
data.select('Age').describe().show()

+-------+------------------+
|summary|               Age|
+-------+------------------+
|  count|               714|
|   mean| 29.69911764705882|
| stddev|14.526497332334035|
|    min|              0.42|
|    max|              80.0|
+-------+------------------+



In [100]:
data = data.fillna(29,subset=['Age'])

In [101]:
count_embark = data.groupby('Embarked').count()

In [102]:
count_embark.show()

+--------+-----+
|Embarked|count|
+--------+-----+
|    NULL|    2|
|       S|  644|
|       C|  168|
|       Q|   77|
+--------+-----+



In [103]:
data = data.fillna("S",subset=['Embarked'])

In [104]:
count_embark = data.groupby('Embarked').count()
count_embark.show()

+--------+-----+
|Embarked|count|
+--------+-----+
|       S|  646|
|       C|  168|
|       Q|   77|
+--------+-----+



In [105]:
count_fare = data.groupby('Fare').count()
count_fare.show()

+-----+-----+
| Fare|count|
+-----+-----+
|147.0|    2|
| 67.0|    2|
| 69.0|    2|
|  7.0|   65|
| 29.0|    9|
| 18.0|    5|
|120.0|    4|
| 25.0|    4|
| 77.0|    5|
| 50.0|    5|
| 83.0|    5|
| 11.0|   30|
| 58.0|    2|
| 21.0|   15|
| 63.0|    1|
|111.0|    4|
| 22.0|    4|
| 82.0|    3|
| 74.0|    5|
| 19.0|    9|
+-----+-----+
only showing top 20 rows



In [106]:
def age_group(age):
    if age < 18:
        return '<18'
    elif 18 <= age <= 60:
        return '18-60'
    else:
        return '60+'

In [107]:
def fare_group(fare):
    if fare < 8:
        return '<8'
    elif 8 <= fare <= 14:
        return '8-14'
    elif 14 <= fare <= 31:
        return '14-31'
    else:
        return '31+'

In [108]:
age_group_udf = udf(age_group, StringType())

In [109]:
fare_group_udf = udf(fare_group, StringType())

In [110]:
data = data.withColumn("AgeGroup", age_group_udf(col("Age")))

In [111]:
data = data.withColumn("FareGroup", fare_group_udf(col("Fare")))

In [112]:
data.show(20)

+-----------+--------+------+------+----+-----+-----+----+-----+--------+--------+---------+
|PassengerId|Survived|Pclass|   Sex| Age|SibSp|Parch|Fare|Cabin|Embarked|AgeGroup|FareGroup|
+-----------+--------+------+------+----+-----+-----+----+-----+--------+--------+---------+
|          1|       0|     3|  male|22.0|    1|    0| 7.0| NULL|       S|   18-60|       <8|
|          2|       1|     1|female|38.0|    1|    0|71.0|  C85|       C|   18-60|      31+|
|          3|       1|     3|female|26.0|    0|    0| 8.0| NULL|       S|   18-60|     8-14|
|          4|       1|     1|female|35.0|    1|    0|53.0| C123|       S|   18-60|      31+|
|          5|       0|     3|  male|35.0|    0|    0| 8.0| NULL|       S|   18-60|     8-14|
|          6|       0|     3|  male|29.0|    0|    0| 8.0| NULL|       Q|   18-60|     8-14|
|          7|       0|     1|  male|54.0|    0|    0|52.0|  E46|       S|   18-60|      31+|
|          8|       0|     3|  male| 2.0|    3|    1|21.0| NULL|      

In [113]:
data = data.withColumn("Cabin", split(col("Cabin"), "").getItem(0))

In [114]:
count_Cabin = data.groupby('Cabin').count()
count_Cabin.show()

+-----+-----+
|Cabin|count|
+-----+-----+
|    F|   13|
| NULL|  687|
|    T|    1|
|    A|   15|
|    B|   47|
|    C|   59|
|    E|   32|
|    D|   33|
|    G|    4|
+-----+-----+



In [115]:
cabin_mot_null = data.dropna()

In [116]:
cabin_mot_null.show()

+-----------+--------+------+------+----+-----+-----+-----+-----+--------+--------+---------+
|PassengerId|Survived|Pclass|   Sex| Age|SibSp|Parch| Fare|Cabin|Embarked|AgeGroup|FareGroup|
+-----------+--------+------+------+----+-----+-----+-----+-----+--------+--------+---------+
|          2|       1|     1|female|38.0|    1|    0| 71.0|    C|       C|   18-60|      31+|
|          4|       1|     1|female|35.0|    1|    0| 53.0|    C|       S|   18-60|      31+|
|          7|       0|     1|  male|54.0|    0|    0| 52.0|    E|       S|   18-60|      31+|
|         11|       1|     3|female| 4.0|    1|    1| 17.0|    G|       S|     <18|    14-31|
|         12|       1|     1|female|58.0|    0|    0| 27.0|    C|       S|   18-60|    14-31|
|         22|       1|     2|  male|34.0|    0|    0| 13.0|    D|       S|   18-60|     8-14|
|         24|       1|     1|  male|28.0|    0|    0| 36.0|    A|       S|   18-60|      31+|
|         28|       0|     1|  male|19.0|    3|    2|263.0| 

In [117]:
count_Cabin_ = cabin_mot_null.groupby('Cabin').count()
count_Cabin_.show()

+-----+-----+
|Cabin|count|
+-----+-----+
|    F|   13|
|    T|    1|
|    A|   15|
|    B|   47|
|    C|   59|
|    E|   32|
|    D|   33|
|    G|    4|
+-----+-----+



In [118]:
count_Cabin_ = cabin_mot_null.groupby('FareGroup').count()
count_Cabin_.show()

+---------+-----+
|FareGroup|count|
+---------+-----+
|       <8|    4|
|      31+|  139|
|     8-14|   20|
|    14-31|   41|
+---------+-----+



In [119]:
cabin_mot_null.groupby('Cabin').mean('Fare').show()

+-----+------------------+
|Cabin|         avg(Fare)|
+-----+------------------+
|    F|18.846153846153847|
|    T|              36.0|
|    A|39.733333333333334|
|    B|113.57446808510639|
|    C|100.27118644067797|
|    E|          46.09375|
|    D| 57.21212121212121|
|    G|              13.5|
+-----+------------------+



In [120]:
grouped_df = cabin_mot_null.groupBy('Cabin', 'FareGroup').count()

In [121]:
grouped_df.orderBy(['FareGroup','Cabin']).show()

+-----+---------+-----+
|Cabin|FareGroup|count|
+-----+---------+-----+
|    A|    14-31|    4|
|    B|    14-31|    5|
|    C|    14-31|   13|
|    D|    14-31|    5|
|    E|    14-31|    8|
|    F|    14-31|    4|
|    G|    14-31|    2|
|    A|      31+|   10|
|    B|      31+|   39|
|    C|      31+|   46|
|    D|      31+|   24|
|    E|      31+|   17|
|    F|      31+|    2|
|    T|      31+|    1|
|    D|     8-14|    4|
|    E|     8-14|    7|
|    F|     8-14|    7|
|    G|     8-14|    2|
|    A|       <8|    1|
|    B|       <8|    3|
+-----+---------+-----+



In [122]:
data = data.fillna(value='C', subset=['Cabin']).\
    withColumn('Cabin', when(col('FareGroup') == '<8', 'D')
    .when(col('FareGroup') == '8-14', 'F')
    .when(col('FareGroup') == '14-31', 'E')
    .otherwise(col('Cabin')))

In [123]:
data.groupBy('Cabin','FareGroup').count().show()

+-----+---------+-----+
|Cabin|FareGroup|count|
+-----+---------+-----+
|    A|      31+|   10|
|    E|      31+|   17|
|    F|      31+|    2|
|    D|      31+|   24|
|    B|      31+|   39|
|    F|     8-14|  363|
|    T|      31+|    1|
|    E|    14-31|  230|
|    D|       <8|   87|
|    C|      31+|  118|
+-----+---------+-----+



In [128]:
data.show(5)

+-----------+--------+------+------+----+-----+-----+----+-----+--------+--------+---------+
|PassengerId|Survived|Pclass|   Sex| Age|SibSp|Parch|Fare|Cabin|Embarked|AgeGroup|FareGroup|
+-----------+--------+------+------+----+-----+-----+----+-----+--------+--------+---------+
|          1|       0|     3|  male|22.0|    1|    0| 7.0|    D|       S|   18-60|       <8|
|          2|       1|     1|female|38.0|    1|    0|71.0|    C|       C|   18-60|      31+|
|          3|       1|     3|female|26.0|    0|    0| 8.0|    F|       S|   18-60|     8-14|
|          4|       1|     1|female|35.0|    1|    0|53.0|    C|       S|   18-60|      31+|
|          5|       0|     3|  male|35.0|    0|    0| 8.0|    F|       S|   18-60|     8-14|
+-----------+--------+------+------+----+-----+-----+----+-----+--------+--------+---------+
only showing top 5 rows



In [131]:
categorical_cols = ['Sex', 'Cabin', 'Embarked', 'AgeGroup', 'FareGroup']

In [133]:
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml import Pipeline
string_indexers = [StringIndexer(inputCol=col, outputCol=f"{col}_in") for col in categorical_cols]

# Создаем объект StringIndexer для каждой категориальной колонки
string_indexers = [StringIndexer(inputCol=col, outputCol=f"{col}_in") for col in categorical_cols]

# Создаем объект VectorAssembler для объединения всех категориальных колонок в один вектор
assembler = VectorAssembler(inputCols=[f"{col}_in" for col in categorical_cols], outputCol="features")

# Объединяем StringIndexers и VectorAssembler в одну последовательность
indexer_and_assembler = string_indexers + [assembler]

# Обучаем последовательность на данных
model = Pipeline(stages=indexer_and_assembler).fit(data)

# Применяем последовательность преобразований для преобразования категориальных столбцов в числовые
indexed_data = model.transform(data)

# Выводим результат
indexed_data.show()

+-----------+--------+------+------+----+-----+-----+----+-----+--------+--------+---------+------+--------+-----------+-----------+------------+--------------------+
|PassengerId|Survived|Pclass|   Sex| Age|SibSp|Parch|Fare|Cabin|Embarked|AgeGroup|FareGroup|Sex_in|Cabin_in|Embarked_in|AgeGroup_in|FareGroup_in|            features|
+-----------+--------+------+------+----+-----+-----+----+-----+--------+--------+---------+------+--------+-----------+-----------+------------+--------------------+
|          1|       0|     3|  male|22.0|    1|    0| 7.0|    D|       S|   18-60|       <8|   0.0|     3.0|        0.0|        0.0|         3.0| (5,[1,4],[3.0,3.0])|
|          2|       1|     1|female|38.0|    1|    0|71.0|    C|       C|   18-60|      31+|   1.0|     2.0|        1.0|        0.0|         2.0|[1.0,2.0,1.0,0.0,...|
|          3|       1|     3|female|26.0|    0|    0| 8.0|    F|       S|   18-60|     8-14|   1.0|     0.0|        0.0|        0.0|         0.0|       (5,[0],[1.0])

In [134]:
from pyspark.ml.classification import RandomForestClassifier

In [137]:
(trainingData, testData) = indexed_data.randomSplit([0.7, 0.3])

In [139]:
trainingData.show(5)

+-----------+--------+------+------+----+-----+-----+----+-----+--------+--------+---------+------+--------+-----------+-----------+------------+--------------------+
|PassengerId|Survived|Pclass|   Sex| Age|SibSp|Parch|Fare|Cabin|Embarked|AgeGroup|FareGroup|Sex_in|Cabin_in|Embarked_in|AgeGroup_in|FareGroup_in|            features|
+-----------+--------+------+------+----+-----+-----+----+-----+--------+--------+---------+------+--------+-----------+-----------+------------+--------------------+
|          3|       1|     3|female|26.0|    0|    0| 8.0|    F|       S|   18-60|     8-14|   1.0|     0.0|        0.0|        0.0|         0.0|       (5,[0],[1.0])|
|          5|       0|     3|  male|35.0|    0|    0| 8.0|    F|       S|   18-60|     8-14|   0.0|     0.0|        0.0|        0.0|         0.0|           (5,[],[])|
|          7|       0|     1|  male|54.0|    0|    0|52.0|    E|       S|   18-60|      31+|   0.0|     1.0|        0.0|        0.0|         2.0| (5,[1,4],[1.0,2.0])

In [140]:
rf = RandomForestClassifier(labelCol="Survived", featuresCol="features", numTrees=10)

In [141]:
rfModel = rf.fit(trainingData)


In [142]:
predictions = rfModel.transform(testData)


In [146]:
predictions.show(5)

+-----------+--------+------+------+----+-----+-----+----+-----+--------+--------+---------+------+--------+-----------+-----------+------------+--------------------+--------------------+--------------------+----------+
|PassengerId|Survived|Pclass|   Sex| Age|SibSp|Parch|Fare|Cabin|Embarked|AgeGroup|FareGroup|Sex_in|Cabin_in|Embarked_in|AgeGroup_in|FareGroup_in|            features|       rawPrediction|         probability|prediction|
+-----------+--------+------+------+----+-----+-----+----+-----+--------+--------+---------+------+--------+-----------+-----------+------------+--------------------+--------------------+--------------------+----------+
|          1|       0|     3|  male|22.0|    1|    0| 7.0|    D|       S|   18-60|       <8|   0.0|     3.0|        0.0|        0.0|         3.0| (5,[1,4],[3.0,3.0])|[8.90466001958896...|[0.89046600195889...|       0.0|
|          2|       1|     1|female|38.0|    1|    0|71.0|    C|       C|   18-60|      31+|   1.0|     2.0|        1.0|

In [149]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

evaluator = BinaryClassificationEvaluator(labelCol="Survived", rawPredictionCol="rawPrediction", metricName="areaUnderROC")

areaUnderROC = evaluator.evaluate(predictions)
print(areaUnderROC)


0.8363221016561964
