In [28]:
from pyspark import SparkContext
from pyspark.sql import SparkSession
from pyspark.ml.feature import OneHotEncoder, StringIndexer

In [29]:
# Reuse the active Spark session if one exists, otherwise start a new one.
# `SparkSession.builder` is the documented entry point; instantiating the
# nested `Builder` class directly is not part of the public idiom.
spark = SparkSession.builder.getOrCreate()

In [30]:
# Small in-memory sample: a row id plus one categorical string column.
city_rows = [
    (0, "Moscow"),
    (1, "New York"),
    (2, "Beijing"),
    (3, "New York"),
    (4, "Paris"),
    (5, "Paris"),
    (6, "New York"),
    (7, "Beijing"),
]
city_columns = ["row_id", "city"]

catDF = spark.createDataFrame(city_rows, schema=city_columns)

In [31]:
# Print the raw input rows (display defaults spelled out explicitly).
catDF.show(n=20, truncate=True)

+------+--------+
|row_id|    city|
+------+--------+
|     0|  Moscow|
|     1|New York|
|     2| Beijing|
|     3|New York|
|     4|   Paris|
|     5|   Paris|
|     6|New York|
|     7| Beijing|
+------+--------+



In [32]:
# Map each distinct city string to a numeric label in a new "index" column.
# The estimator and its fitted model get their own names; `indexer` ends up
# bound to the transformed DataFrame, which is what the following cells read.
city_indexer = StringIndexer(inputCol="city", outputCol="index")
city_indexer_model = city_indexer.fit(catDF)
indexer = city_indexer_model.transform(catDF)

In [33]:
# Print the frame including the "index" column produced by StringIndexer.
indexer.show(n=20, truncate=True)

+------+--------+-----+
|row_id|    city|index|
+------+--------+-----+
|     0|  Moscow|  3.0|
|     1|New York|  0.0|
|     2| Beijing|  1.0|
|     3|New York|  0.0|
|     4|   Paris|  2.0|
|     5|   Paris|  2.0|
|     6|New York|  0.0|
|     7| Beijing|  1.0|
+------+--------+-----+



In [34]:
# One-hot encode the numeric "index" column into a sparse "encoding" vector.
# dropLast=False keeps one vector slot per category instead of dropping the last.
onehot = OneHotEncoder(inputCol="index", outputCol="encoding", dropLast=False)

# This cell rebinds `indexer` to its own output. Dropping any "encoding" column
# left over from a previous run (a no-op when the column is absent) lets the
# cell be re-executed without failing on an already-existing output column.
encoder_input = indexer.drop("encoding")

# Final bindings match the original cell: `encoder` is the fitted model and
# `indexer` is the DataFrame with the "encoding" column appended.
encoder = onehot.fit(encoder_input)
indexer = encoder.transform(encoder_input)

In [35]:
# Print the frame including the one-hot "encoding" column.
indexer.show(n=20, truncate=True)

+------+--------+-----+-------------+
|row_id|    city|index|     encoding|
+------+--------+-----+-------------+
|     0|  Moscow|  3.0|(4,[3],[1.0])|
|     1|New York|  0.0|(4,[0],[1.0])|
|     2| Beijing|  1.0|(4,[1],[1.0])|
|     3|New York|  0.0|(4,[0],[1.0])|
|     4|   Paris|  2.0|(4,[2],[1.0])|
|     5|   Paris|  2.0|(4,[2],[1.0])|
|     6|New York|  0.0|(4,[0],[1.0])|
|     7| Beijing|  1.0|(4,[1],[1.0])|
+------+--------+-----+-------------+

