In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import Word2Vec

In [2]:
spark = SparkSession.builder.appName("SparkMLlib - Word2Vec").getOrCreate()

In [3]:
# Input data: Each row is a bag of words from a sentence or document.
documentDF = spark.createDataFrame([
    ("Hi I heard about Spark".split(" "), ),
    ("I wish Java could use case classes".split(" "), ),
    ("Hi I heard about Hadoop".split(" "), ),
    ("Logistic regression models are neat".split(" "), )
], ["text"])

In [9]:
documentDF.collect()

[Row(text=['Hi', 'I', 'heard', 'about', 'Spark']),
 Row(text=['I', 'wish', 'Java', 'could', 'use', 'case', 'classes']),
 Row(text=['Hi', 'I', 'heard', 'about', 'Hadoop']),
 Row(text=['Logistic', 'regression', 'models', 'are', 'neat'])]

In [10]:
# Learn a mapping from words to Vectors.
word2Vec = Word2Vec(vectorSize=3, minCount=0, inputCol="text", outputCol="result")
model = word2Vec.fit(documentDF)

In [11]:
result = model.transform(documentDF)
for row in result.collect():
    text, vector = row
    print("Text: [%s] => \nVector: %s\n" % (", ".join(text), str(vector)))

Text: [Hi, I, heard, about, Spark] => 
Vector: [0.09882356002926827,0.021327708754688503,0.02151499018073082]

Text: [I, wish, Java, could, use, case, classes] => 
Vector: [-0.04433649592101574,0.020363319665193558,-0.0248810313642025]

Text: [Hi, I, heard, about, Hadoop] => 
Vector: [0.052408272773027426,0.028904251847416164,0.04042706340551377]

Text: [Logistic, regression, models, are, neat] => 
Vector: [0.044334216415882116,-0.0414790615439415,-0.061293074395507574]



In [12]:
d1 = result.collect()[0][1]
d1

DenseVector([0.0988, 0.0213, 0.0215])

In [13]:
d2 = result.collect()[1][1]
d2

DenseVector([-0.0443, 0.0204, -0.0249])

In [14]:
d3 = result.collect()[2][1]
d3

DenseVector([0.0524, 0.0289, 0.0404])

In [15]:
d4 = result.collect()[3][1]
d4

DenseVector([0.0443, -0.0415, -0.0613])

In [16]:
from scipy.spatial.distance import cosine


In [17]:
1 - cosine(d1, d2)

-0.7918359386638196

In [18]:
1 - cosine(d1, d3)

0.8928464475797369

In [19]:
# Compare com seus próprios olhos :-)

In [20]:
result.collect()[0][0]

['Hi', 'I', 'heard', 'about', 'Spark']

In [21]:
result.collect()[2][0]

['Hi', 'I', 'heard', 'about', 'Hadoop']

In [22]:
1 - cosine(d1, d4)

0.2442315567729454