In [1]:
# Установить pyspark
!pip install pyspark



In [2]:
# После установки мы можем создать сеанс Spark и проверить его информацию.
# Импорт SparkSession
from pyspark.sql import SparkSession

In [3]:
# Start (or reuse) a Spark session running locally.
local_builder = SparkSession.builder.master("local[*]")
spark = local_builder.getOrCreate()
# Handle on the underlying SparkContext, used by the RDD examples.
sc = spark.sparkContext

In [4]:
# Display the Spark session summary to confirm it is running.
spark

In [5]:
# Мы также можем протестировать установку, импортировав библиотеку Spark.
from pyspark.sql.functions import col

In [None]:
# Sample words for the word-count (map/reduce) demo; 'Функция' appears twice.
list_ = [
    'Функция',
    'документ',
    'входной',
    'пар',
    'набор',
    'превращает',
    'в',
    'map',
    'Функция',
]

In [None]:
# Turn the Python word list into an RDD via the SparkContext.
rdd_list = sc.parallelize(list_)

def mapper(world):
    """Map step of the word count: pair the given word with a count of 1.

    Returns the tuple ``(world, 1)``.
    """
    pair = (world, 1)
    return pair

# Apply the map step to every word and collect the resulting (word, 1) pairs.
transform = rdd_list.map(mapper)
transform.collect()

[('Функция', 1),
 ('документ', 1),
 ('входной', 1),
 ('пар', 1),
 ('набор', 1),
 ('превращает', 1),
 ('в', 1),
 ('map', 1),
 ('Функция', 1)]

In [None]:
def reducer(a, b):
    """Reduce step of the word count: combine two values of one key by adding them."""
    total = a + b
    return total

# Sum the 1s per word; note that `transform` is rebound to the reduced RDD.
transform = transform.reduceByKey(reducer)
transform.collect()

[('Функция', 2),
 ('документ', 1),
 ('входной', 1),
 ('пар', 1),
 ('набор', 1),
 ('превращает', 1),
 ('в', 1),
 ('map', 1)]

In [None]:
# Load the income CSV.
#   header=True      - the first line of the file holds the column names.
#   inferSchema=True - column types are detected from the data; without it
#                      every column is read as a string.
salary = spark.read.csv(
    'Анализ ДОХОДА.csv',
    header=True,
    inferSchema=True,
)
# First five values of the YOKB column.
salary.select('YOKB').take(5)

[Row(YOKB=107277.47),
 Row(YOKB=189254.9),
 Row(YOKB=105225.84),
 Row(YOKB=154250.0),
 Row(YOKB=142204.32)]

In [None]:
# Indexing a DataFrame by name yields a Column expression, not the data itself.
salary['YOKB']

Column<'YOKB'>

In [None]:
# Number of rows in the DataFrame.
salary.count()

60

In [None]:
# Column names paired with their Spark data types.
salary.dtypes

[('YOKB', 'double'), ('RJD', 'double'), ('Full', 'double')]

In [None]:
# First five rows, returned as a list of Row objects.
salary.take(5)

[Row(YOKB=107277.47, RJD=35354.26, Full=142631.73),
 Row(YOKB=189254.9, RJD=35733.88, Full=224988.78),
 Row(YOKB=105225.84, RJD=36584.07, Full=141809.91),
 Row(YOKB=154250.0, RJD=39406.7, Full=193656.7),
 Row(YOKB=142204.32, RJD=36380.26, Full=178584.58)]

In [None]:
# Print the DataFrame as a text table (the output shows the top 20 rows).
salary.show()

+---------+--------+---------+
|     YOKB|     RJD|     Full|
+---------+--------+---------+
|107277.47|35354.26|142631.73|
| 189254.9|35733.88|224988.78|
|105225.84|36584.07|141809.91|
| 154250.0| 39406.7| 193656.7|
|142204.32|36380.26|178584.58|
|141192.39|39277.41| 180469.8|
| 140745.4|37824.04|178569.44|
|190600.94|47710.37|238311.31|
| 105799.1|24965.21|130764.31|
|125798.49|38290.84|164089.33|
|123906.52|37822.22|161728.74|
|197554.29|38315.87|235870.16|
| 83178.24|31147.68|114325.92|
|156549.72|58075.57|214625.29|
| 68411.13|18513.63| 86924.76|
|116971.95|25765.97|142737.92|
|212480.58| 42796.6|255277.18|
| 51320.48|14814.26| 66134.74|
|134311.09|47192.62|181503.71|
|161593.17|38935.73| 200528.9|
+---------+--------+---------+
only showing top 20 rows



In [None]:
# Summary statistics (count, mean, stddev, min, max) for each column.
salary.describe().show()

+-------+------------------+-----------------+------------------+
|summary|              YOKB|              RJD|              Full|
+-------+------------------+-----------------+------------------+
|  count|                60|               60|                60|
|   mean|141049.01016666667|40308.85016666667|181357.86033333334|
| stddev| 43854.88962988396|14372.12343064859| 54399.23002668247|
|    min|          51320.48|         14515.03|          66134.74|
|    max|         241690.96|         89284.94|         313444.38|
+-------+------------------+-----------------+------------------+



In [None]:
# Save the DataFrame to disk as CSV with a header row.
# - Spark writes a directory named 'salaty_spark.csv' containing part files,
#   not a single file.
# - mode('overwrite') keeps the cell re-runnable: the default save mode raises
#   an error when the target path already exists.
# - The built-in 'csv' source replaces the legacy 'com.databricks.spark.csv'
#   package name.
salary.write.mode('overwrite').option('header', 'true').csv('salaty_spark.csv')

In [None]:
# Calculations: add a derived 'payment' column equal to 87% of 'Full'.
NET_SHARE = 0.87  # presumably the share left after a 13% tax — TODO confirm
full_income = salary['Full']
salary = salary.withColumn('payment', full_income * NET_SHARE)
salary.show()

In [None]:
# Load the power dataset, taking column names from the header row and
# letting Spark infer the column types.
power = spark.read.csv(
    'power.csv',
    inferSchema=True,
    header=True,
)
power.show()

+-------+----+--------+--------+
|country|year|quantity|category|
+-------+----+--------+--------+
|Austria|1996|     5.0|       1|
|Austria|1995|    17.0|       1|
|Belgium|2014|     0.0|       1|
|Belgium|2013|     0.0|       1|
|Belgium|2012|    35.0|       1|
|Belgium|2011|    25.0|       1|
|Belgium|2010|    22.0|       1|
|Belgium|2009|    45.0|       1|
|Czechia|1998|     1.0|       1|
|Czechia|1995|     7.0|       1|
|Finland|2010|     9.0|       1|
|Finland|2009|    13.0|       1|
|Finland|2008|    39.0|       1|
|Finland|2007|    21.0|       1|
|Finland|2006|     0.0|       1|
|Finland|2005|     0.0|       1|
|Finland|2004|     0.0|       1|
|Finland|2003|     0.0|       1|
|Finland|2002|     0.0|       1|
|Finland|2001|     0.0|       1|
+-------+----+--------+--------+
only showing top 20 rows



In [None]:
# Total quantity per country.
by_country_total = power.groupBy('country')
power_country = by_country_total.sum('quantity')
power_country.show()


+------------------+--------------------+
|           country|       sum(quantity)|
+------------------+--------------------+
|     Côte d'Ivoire| 2.815485732456253E7|
|              Chad|  3796498.7491319943|
|          Paraguay|     1.23209483765E7|
|          Anguilla|   20529.34999999997|
|             Yemen|1.8178937740390217E8|
|State of Palestine|  1318668.0123446316|
|           Senegal|   6944395.348079733|
|            Sweden|1.3456236759933385E8|
|        Cabo Verde|   88130.27080000004|
|          Kiribati|   6450.091429000002|
|            Guyana|   772150.6722661877|
|       Philippines|  8.45277094530091E7|
|           Eritrea|   918454.1476713057|
|            Jersey|  142744.73085845588|
|             Tonga|  16350.450516472933|
|          Djibouti|  130946.11799999996|
|         Singapore| 4.701454062703839E7|
|          Malaysia| 8.356959770425016E8|
|              Fiji|  400739.80509911076|
|            Turkey| 3.500108256564667E8|
+------------------+--------------

In [None]:
# Pivot table: one row per country, one column per year,
# each cell holding the summed quantity.
by_country = power.groupBy('country')
power_pivot = by_country.pivot('year').sum('quantity')
power_pivot.show()

+------------------+--------------------+------------------+------------------+--------------------+------------------+--------------------+------------------+--------------------+------------------+------------------+------------------+------------------+------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+--------------------+--------------------+--------------------+
|           country|                1990|              1991|              1992|                1993|              1994|                1995|              1996|                1997|              1998|              1999|              2000|              2001|              2002|                2003|                2004|                2005|                2006|                2007|                2008|                2009|                2010|               2011|          

In [None]:
# Largest quantity per country (use .min('quantity') for the smallest).
power_max = power.groupBy('country').max('quantity')
power_max.show()

+------------------+-------------+
|           country|max(quantity)|
+------------------+-------------+
|     Côte d'Ivoire|    1638882.0|
|              Chad|     222000.0|
|          Paraguay|     200000.0|
|          Anguilla|        520.0|
|             Yemen|  2.1656655E7|
|State of Palestine|       5370.4|
|           Senegal|     429231.0|
|            Sweden|     875000.0|
|        Cabo Verde|       408.81|
|          Kiribati|         24.3|
|            Guyana|      81000.0|
|       Philippines|    3902100.0|
|           Eritrea|       4969.0|
|            Jersey|        950.0|
|             Tonga|        54.66|
|          Djibouti|        402.0|
|         Singapore|     431960.5|
|          Malaysia|   9.677208E7|
|              Fiji|       3000.0|
|            Turkey|     1.2828E7|
+------------------+-------------+
only showing top 20 rows



In [None]:
# Filter condition: comparing a column with a value builds a boolean Column
# expression; nothing is evaluated yet.
salary['Full'] >= 250000

Column<'(Full >= 250000)'>

In [None]:
# Keep rows whose total income lies in the half-open range [250000, 300000).
full_col = salary['Full']
in_range = (full_col >= 250000) & (full_col < 300000)
salary_filter = salary.where(in_range)
salary_filter.show()

+---------+--------+---------+
|     YOKB|     RJD|     Full|
+---------+--------+---------+
|212480.58| 42796.6|255277.18|
|217355.29|39058.34|256413.63|
+---------+--------+---------+



In [None]:
# Register the DataFrame as a temporary view (creating it or replacing an
# existing one) so that it can be queried with SQL.
power.createOrReplaceTempView('power') # makes 'power' usable as a table name in spark.sql


In [None]:
# Query the temporary view with plain SQL.
spark.sql('select * from power').show()

+-------+----+--------+--------+
|country|year|quantity|category|
+-------+----+--------+--------+
|Austria|1996|     5.0|       1|
|Austria|1995|    17.0|       1|
|Belgium|2014|     0.0|       1|
|Belgium|2013|     0.0|       1|
|Belgium|2012|    35.0|       1|
|Belgium|2011|    25.0|       1|
|Belgium|2010|    22.0|       1|
|Belgium|2009|    45.0|       1|
|Czechia|1998|     1.0|       1|
|Czechia|1995|     7.0|       1|
|Finland|2010|     9.0|       1|
|Finland|2009|    13.0|       1|
|Finland|2008|    39.0|       1|
|Finland|2007|    21.0|       1|
|Finland|2006|     0.0|       1|
|Finland|2005|     0.0|       1|
|Finland|2004|     0.0|       1|
|Finland|2003|     0.0|       1|
|Finland|2002|     0.0|       1|
|Finland|2001|     0.0|       1|
+-------+----+--------+--------+
only showing top 20 rows



In [None]:
# Ten countries with the largest total quantity, highest first.
# The adjacent string pieces join into one single-line SQL statement.
top_countries_query = (
    'select country, sum(quantity) '
    'from power '
    'group by country '
    'order by sum(quantity) desc '
    'limit 10'
)
df_spark = spark.sql(top_countries_query)
df_spark.show()

+--------------------+--------------------+
|             country|       sum(quantity)|
+--------------------+--------------------+
|  Russian Federation|4.335060328741451E10|
|       United States|4.323338745244536E10|
|               China|1.901974183705489E10|
|       USSR (former)|  1.4141328256288E10|
|           Australia|1.319571762662552...|
|Iran (Islamic Rep...|1.092931727870884...|
|               Qatar| 6.078649233527217E9|
|              Canada| 5.571032256333699E9|
|               India| 4.310527452210909E9|
|        Saudi Arabia|3.7398750382031584E9|
+--------------------+--------------------+



In [None]:
# Convert the Spark DataFrame to a pandas DataFrame.
df_pandas_power = df_spark.toPandas()
df_pandas_power

Unnamed: 0,country,sum(quantity)
0,Russian Federation,43350600000.0
1,United States,43233390000.0
2,China,19019740000.0
3,USSR (former),14141330000.0
4,Australia,13195720000.0
5,Iran (Islamic Rep. of),10929320000.0
6,Qatar,6078649000.0
7,Canada,5571032000.0
8,India,4310527000.0
9,Saudi Arabia,3739875000.0


**ТИТАНИК**

In [5]:
# Load the Titanic passenger data with a header row and inferred column types.
titanic = spark.read.csv(
    'titanic.csv',
    header=True,
    inferSchema=True,
)
titanic.show()

+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+
|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+
|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|       A/5 21171|   7.25| NULL|       S|
|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0|        PC 17599|71.2833|  C85|       C|
|          3|       1|     3|Heikkinen, Miss. ...|female|26.0|    0|    0|STON/O2. 3101282|  7.925| NULL|       S|
|          4|       1|     1|Futrelle, Mrs. Ja...|female|35.0|    1|    0|          113803|   53.1| C123|       S|
|          5|       0|     3|Allen, Mr. Willia...|  male|35.0|    0|    0|          373450|   8.05| NULL|       S|
|          6|       0|     3|    Moran, Mr. James|  male|NULL|    0|    0|      

In [7]:
# List of column names in the Titanic DataFrame.
# Stored as `titanic_columns` rather than `col`, so that the `col` function
# imported from pyspark.sql.functions is not overwritten.
titanic_columns = titanic.columns
titanic_columns

['PassengerId',
 'Survived',
 'Pclass',
 'Name',
 'Sex',
 'Age',
 'SibSp',
 'Parch',
 'Ticket',
 'Fare',
 'Cabin',
 'Embarked']

In [8]:
# Passenger count per port of embarkation (the output includes a NULL group).
titanic.groupBy('Embarked').count().show()

+--------+-----+
|Embarked|count|
+--------+-----+
|       Q|   77|
|    NULL|    2|
|       C|  168|
|       S|  644|
+--------+-----+



In [9]:
# Average of the Age and Fare columns.
titanic.agg({'Age': 'mean', 'Fare': 'mean'}).show()

+-----------------+----------------+
|         avg(Age)|       avg(Fare)|
+-----------------+----------------+
|29.69911764705882|32.2042079685746|
+-----------------+----------------+



In [10]:
# Keep only the label and the columns used as model features.
kept_columns = ['Survived', 'Pclass', 'Name', 'Sex', 'Age', 'Fare', 'Embarked']
titanic_filter = titanic.select(*kept_columns)
titanic_filter.show(3)

+--------+------+--------------------+------+----+-------+--------+
|Survived|Pclass|                Name|   Sex| Age|   Fare|Embarked|
+--------+------+--------------------+------+----+-------+--------+
|       0|     3|Braund, Mr. Owen ...|  male|22.0|   7.25|       S|
|       1|     1|Cumings, Mrs. Joh...|female|38.0|71.2833|       C|
|       1|     3|Heikkinen, Miss. ...|female|26.0|  7.925|       S|
+--------+------+--------------------+------+----+-------+--------+
only showing top 3 rows



In [11]:
# Summary statistics of the selected columns; the lower counts for Age and
# Embarked in the output reveal missing values.
titanic_filter.describe().show()

+-------+-------------------+------------------+--------------------+------+------------------+-----------------+--------+
|summary|           Survived|            Pclass|                Name|   Sex|               Age|             Fare|Embarked|
+-------+-------------------+------------------+--------------------+------+------------------+-----------------+--------+
|  count|                891|               891|                 891|   891|               714|              891|     889|
|   mean| 0.3838383838383838| 2.308641975308642|                NULL|  NULL| 29.69911764705882| 32.2042079685746|    NULL|
| stddev|0.48659245426485753|0.8360712409770491|                NULL|  NULL|14.526497332334035|49.69342859718089|    NULL|
|    min|                  0|                 1|"Andersson, Mr. A...|female|              0.42|              0.0|       C|
|    max|                  1|                 3|van Melkebeke, Mr...|  male|              80.0|         512.3292|       S|
+-------+-------

**Заполнение пустот**

In [12]:
# Fill the gaps: missing Age becomes 29.7 (the mean age reported by describe(),
# rounded) and missing Embarked becomes 'S' (the most frequent port).
fill_values = {'Age': 29.7, 'Embarked': 'S'}
titanic_filter = titanic_filter.na.fill(fill_values)
titanic_filter.describe().show()

+-------+-------------------+------------------+--------------------+------+------------------+-----------------+--------+
|summary|           Survived|            Pclass|                Name|   Sex|               Age|             Fare|Embarked|
+-------+-------------------+------------------+--------------------+------+------------------+-----------------+--------+
|  count|                891|               891|                 891|   891|               891|              891|     891|
|   mean| 0.3838383838383838| 2.308641975308642|                NULL|  NULL| 29.69929292929302| 32.2042079685746|    NULL|
| stddev|0.48659245426485753|0.8360712409770491|                NULL|  NULL|13.002015230774303|49.69342859718089|    NULL|
|    min|                  0|                 1|"Andersson, Mr. A...|female|              0.42|              0.0|       C|
|    max|                  1|                 3|van Melkebeke, Mr...|  male|              80.0|         512.3292|       S|
+-------+-------

**Библиотеки для замены строковых значений на числовые данные**

In [13]:
# Библиотеки для замены строковых значений на числовые данные
from pyspark.ml.feature import StringIndexer, OneHotEncoder

In [68]:
# StringIndexer: encode the string column 'Sex' as the numeric column 'Sex_id'.
# In the output below male becomes 0.0 and female becomes 1.0.

stringIndex = StringIndexer(inputCol='Sex', outputCol='Sex_id')
stringTrained = stringIndex.fit(titanic_filter)
titanic_new = stringTrained.transform(titanic_filter)
titanic_new.show(5)

+--------+------+--------------------+------+----+-------+--------+------+
|Survived|Pclass|                Name|   Sex| Age|   Fare|Embarked|Sex_id|
+--------+------+--------------------+------+----+-------+--------+------+
|       0|     3|Braund, Mr. Owen ...|  male|22.0|   7.25|       S|   0.0|
|       1|     1|Cumings, Mrs. Joh...|female|38.0|71.2833|       C|   1.0|
|       1|     3|Heikkinen, Miss. ...|female|26.0|  7.925|       S|   1.0|
|       1|     1|Futrelle, Mrs. Ja...|female|35.0|   53.1|       S|   1.0|
|       0|     3|Allen, Mr. Willia...|  male|35.0|   8.05|       S|   0.0|
+--------+------+--------------------+------+----+-------+--------+------+
only showing top 5 rows



In [69]:
# Same encoding for 'Embarked'; in the notebook outputs S maps to 0.0,
# C to 1.0 and Q to 2.0. Note that `titanic_new` is rebound to the result.
stringIndex_embarked = StringIndexer(inputCol='Embarked', outputCol='Embarked_id')
stringTrained_embarked = stringIndex_embarked.fit(titanic_new)
titanic_new = stringTrained_embarked.transform(titanic_new)
titanic_new.show(5)

+--------+------+--------------------+------+----+-------+--------+------+-----------+
|Survived|Pclass|                Name|   Sex| Age|   Fare|Embarked|Sex_id|Embarked_id|
+--------+------+--------------------+------+----+-------+--------+------+-----------+
|       0|     3|Braund, Mr. Owen ...|  male|22.0|   7.25|       S|   0.0|        0.0|
|       1|     1|Cumings, Mrs. Joh...|female|38.0|71.2833|       C|   1.0|        1.0|
|       1|     3|Heikkinen, Miss. ...|female|26.0|  7.925|       S|   1.0|        0.0|
|       1|     1|Futrelle, Mrs. Ja...|female|35.0|   53.1|       S|   1.0|        0.0|
|       0|     3|Allen, Mr. Willia...|  male|35.0|   8.05|       S|   0.0|        0.0|
+--------+------+--------------------+------+----+-------+--------+------+-----------+
only showing top 5 rows



In [70]:
# OneHotEncoder needs a numeric input column, hence the indexed 'Embarked_id'.
onehot = OneHotEncoder(
    inputCol='Embarked_id',
    outputCol='Embarked_OneHot',
)
# The encoder has to be fitted before it can transform the data.
onehot_Trained = onehot.fit(titanic_new)
titanic_new = onehot_Trained.transform(titanic_new)

In [54]:
# Reading the sparse vector (2,[0],[1.0]): 2 is the vector length, [0] is the
# position of the 1, and [1.0] is the stored value. (2,[],[]) holds no 1 at
# all, i.e. every element is zero.
titanic_new.show(10)

+--------+------+--------------------+------+----+-------+--------+------+-----------+---------------+
|Survived|Pclass|                Name|   Sex| Age|   Fare|Embarked|Sex_id|Embarked_id|Embarked_OneHot|
+--------+------+--------------------+------+----+-------+--------+------+-----------+---------------+
|       0|     3|Braund, Mr. Owen ...|  male|22.0|   7.25|       S|   0.0|        0.0|  (2,[0],[1.0])|
|       1|     1|Cumings, Mrs. Joh...|female|38.0|71.2833|       C|   1.0|        1.0|  (2,[1],[1.0])|
|       1|     3|Heikkinen, Miss. ...|female|26.0|  7.925|       S|   1.0|        0.0|  (2,[0],[1.0])|
|       1|     1|Futrelle, Mrs. Ja...|female|35.0|   53.1|       S|   1.0|        0.0|  (2,[0],[1.0])|
|       0|     3|Allen, Mr. Willia...|  male|35.0|   8.05|       S|   0.0|        0.0|  (2,[0],[1.0])|
|       0|     3|    Moran, Mr. James|  male|29.7| 8.4583|       Q|   0.0|        2.0|      (2,[],[])|
|       0|     1|McCarthy, Mr. Tim...|  male|54.0|51.8625|       S|   0.0

In [55]:
# СОБИРАЕМ ВСЕ ПРИЗНАКИ И СОЕДИНЯЕМ В ОДИН ВЕКТОР ДЛЯ МАШИННОГО ОБУЧЕНИЯ!
from pyspark.ml.feature import VectorAssembler

In [57]:
# Feature columns that are assembled into the model's input vector.
x_list = [
    'Pclass',
    'Age',
    'Fare',
    'Sex_id',
    'Embarked_OneHot',
]

In [58]:
# Assembler that combines the feature columns into the single vector column 'vec_attribute'.
vec_Assembler = VectorAssembler(inputCols=x_list, outputCol='vec_attribute')

In [59]:
# Add the 'vec_attribute' column; note that `titanic_new` is rebound to the result.
titanic_new = vec_Assembler.transform(titanic_new)

In [65]:
# Preview the DataFrame with the new feature vector column.
titanic_new.show(5)

+--------+------+--------------------+------+----+-------+--------+------+-----------+---------------+--------------------+
|Survived|Pclass|                Name|   Sex| Age|   Fare|Embarked|Sex_id|Embarked_id|Embarked_OneHot|       vec_attribute|
+--------+------+--------------------+------+----+-------+--------+------+-----------+---------------+--------------------+
|       0|     3|Braund, Mr. Owen ...|  male|22.0|   7.25|       S|   0.0|        0.0|  (2,[0],[1.0])|[3.0,22.0,7.25,0....|
|       1|     1|Cumings, Mrs. Joh...|female|38.0|71.2833|       C|   1.0|        1.0|  (2,[1],[1.0])|[1.0,38.0,71.2833...|
|       1|     3|Heikkinen, Miss. ...|female|26.0|  7.925|       S|   1.0|        0.0|  (2,[0],[1.0])|[3.0,26.0,7.925,1...|
|       1|     1|Futrelle, Mrs. Ja...|female|35.0|   53.1|       S|   1.0|        0.0|  (2,[0],[1.0])|[1.0,35.0,53.1,1....|
|       0|     3|Allen, Mr. Willia...|  male|35.0|   8.05|       S|   0.0|        0.0|  (2,[0],[1.0])|[3.0,35.0,8.05,0....|
+-------

In [64]:
# First five feature vectors, shown in full rather than truncated.
titanic_new.select('vec_attribute').take(5)

[Row(vec_attribute=DenseVector([3.0, 22.0, 7.25, 0.0, 1.0, 0.0])),
 Row(vec_attribute=DenseVector([1.0, 38.0, 71.2833, 1.0, 0.0, 1.0])),
 Row(vec_attribute=DenseVector([3.0, 26.0, 7.925, 1.0, 1.0, 0.0])),
 Row(vec_attribute=DenseVector([1.0, 35.0, 53.1, 1.0, 1.0, 0.0])),
 Row(vec_attribute=DenseVector([3.0, 35.0, 8.05, 0.0, 1.0, 0.0]))]

**ПОВТОРЯЕМОСТЬ**

In [79]:
# ПОВТОРЯЕМОСТЬ
# Можно выполнить все операции через библиотеку Pipeline
from pyspark.ml import Pipeline

In [75]:
# The same preprocessing steps, declared as ordered pipeline stages.
preprocessing_stages = [
    StringIndexer(inputCol='Sex', outputCol='Sex_id'),
    StringIndexer(inputCol='Embarked', outputCol='Embarked_id'),
    OneHotEncoder(inputCol='Embarked_id', outputCol='Embarked_OneHot'),
    VectorAssembler(inputCols=x_list, outputCol='vec_attribute'),
]
pipeline = Pipeline(stages=preprocessing_stages)


In [77]:
# Fit all pipeline stages on the cleaned data, then apply them to the same data.
pipeline_Trained = pipeline.fit(titanic_filter)
titanic_new_1 = pipeline_Trained.transform(titanic_filter)

In [78]:
# Preview the pipeline output: the original columns plus the encoded and vector columns.
titanic_new_1.show(5)

+--------+------+--------------------+------+----+-------+--------+------+-----------+---------------+--------------------+
|Survived|Pclass|                Name|   Sex| Age|   Fare|Embarked|Sex_id|Embarked_id|Embarked_OneHot|       vec_attribute|
+--------+------+--------------------+------+----+-------+--------+------+-----------+---------------+--------------------+
|       0|     3|Braund, Mr. Owen ...|  male|22.0|   7.25|       S|   0.0|        0.0|  (2,[0],[1.0])|[3.0,22.0,7.25,0....|
|       1|     1|Cumings, Mrs. Joh...|female|38.0|71.2833|       C|   1.0|        1.0|  (2,[1],[1.0])|[1.0,38.0,71.2833...|
|       1|     3|Heikkinen, Miss. ...|female|26.0|  7.925|       S|   1.0|        0.0|  (2,[0],[1.0])|[3.0,26.0,7.925,1...|
|       1|     1|Futrelle, Mrs. Ja...|female|35.0|   53.1|       S|   1.0|        0.0|  (2,[0],[1.0])|[1.0,35.0,53.1,1....|
|       0|     3|Allen, Mr. Willia...|  male|35.0|   8.05|       S|   0.0|        0.0|  (2,[0],[1.0])|[3.0,35.0,8.05,0....|
+-------

**ОБУЧЕНИЕ. ML**

In [None]:
# Split the data into a training set and a test set.
# Analogue of sklearn: from sklearn.model_selection import train_test_split
# X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

In [82]:
# 80% of the rows go to training and 20% to testing; the fixed seed makes
# the split reproducible.
SPLIT_WEIGHTS = [0.8, 0.2]
SPLIT_SEED = 12345
train, test = titanic_new_1.randomSplit(SPLIT_WEIGHTS, seed=SPLIT_SEED)

In [None]:
# Create the model.
# Analogue of sklearn: from sklearn.linear_model import LogisticRegression
# model = LogisticRegression()

In [97]:
from pyspark.ml.classification import LogisticRegression
# featuresCol = column holding the assembled feature vectors
# labelCol = column holding the target label the model learns to predict
model_LogRegression = LogisticRegression(featuresCol='vec_attribute', labelCol='Survived')

In [None]:
# Train the model.
# Analogue of sklearn: model.fit(X_train, y_train)

In [98]:
# Fit the logistic regression on the training split.
model_titanic = model_LogRegression.fit(train)

In [101]:
# Prediction: score both the test and the training split with the trained model.
test_res = model_titanic.transform(test)
train_res = model_titanic.transform(train)

In [104]:
# Test-set rows with the added rawPrediction, probability and prediction columns.
test_res.show(5)

+--------+------+--------------------+----+----+----+--------+------+-----------+---------------+--------------------+--------------------+--------------------+----------+
|Survived|Pclass|                Name| Sex| Age|Fare|Embarked|Sex_id|Embarked_id|Embarked_OneHot|       vec_attribute|       rawPrediction|         probability|prediction|
+--------+------+--------------------+----+----+----+--------+------+-----------+---------------+--------------------+--------------------+--------------------+----------+
|       0|     1|Blackwell, Mr. St...|male|45.0|35.5|       S|   0.0|        0.0|  (2,[0],[1.0])|[1.0,45.0,35.5,0....|[0.74264022126565...|[0.67757292992927...|       0.0|
|       0|     1|Futrelle, Mr. Jac...|male|37.0|53.1|       S|   0.0|        0.0|  (2,[0],[1.0])|[1.0,37.0,53.1,0....|[0.58467742281615...|[0.64214296961147...|       0.0|
|       0|     1|   Gee, Mr. Arthur H|male|47.0|38.5|       S|   0.0|        0.0|  (2,[0],[1.0])|[1.0,47.0,38.5,0....|[0.81607622570566...|[

In [105]:
# Training-set rows with the same added prediction columns.
train_res.show(5)

+--------+------+--------------------+------+----+-------+--------+------+-----------+---------------+--------------------+--------------------+--------------------+----------+
|Survived|Pclass|                Name|   Sex| Age|   Fare|Embarked|Sex_id|Embarked_id|Embarked_OneHot|       vec_attribute|       rawPrediction|         probability|prediction|
+--------+------+--------------------+------+----+-------+--------+------+-----------+---------------+--------------------+--------------------+--------------------+----------+
|       0|     1|Allison, Miss. He...|female| 2.0| 151.55|       S|   1.0|        0.0|  (2,[0],[1.0])|[1.0,2.0,151.55,1...|[-2.7701781445074...|[0.05895712901966...|       1.0|
|       0|     1|Allison, Mrs. Hud...|female|25.0| 151.55|       S|   1.0|        0.0|  (2,[0],[1.0])|[1.0,25.0,151.55,...|[-2.0839226091977...|[0.11066930718476...|       1.0|
|       0|     1|Andrews, Mr. Thom...|  male|39.0|    0.0|       S|   0.0|        0.0|  (2,[0],[1.0])|[1.0,39.0,0.0

**Оценка качества модели**

In [106]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator # evaluator for a label with exactly two classes
from pyspark.ml.evaluation import MulticlassClassificationEvaluator # evaluator for a label with more than two classes

In [111]:
# Score the training-set predictions against the true 'Survived' label.
# No metricName is passed, so the evaluator's default metric is used
# (areaUnderROC according to the Spark documentation).
bin_eva = BinaryClassificationEvaluator(labelCol='Survived')
bin_eva.evaluate(train_res)

0.858296048952198

In [112]:
# The same metric computed on the held-out test set.
bin_eva.evaluate(test_res)

0.8241773760768252

In [114]:
# Модель дерева решений
from pyspark.ml.classification import DecisionTreeClassifier # Классификатор решений