In [1]:
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# $example on$
from pyspark.ml.feature import Word2Vec, Word2VecModel
# $example off$
from pyspark.sql import SparkSession
from pyspark.sql.functions import format_number as fmt

In [2]:
# Obtain a SparkSession for this notebook. getOrCreate() returns an existing
# session if one is already active, otherwise it builds a new one named
# "word2vec-spark".
spark = (
    SparkSession.builder
    .appName("word2vec-spark")
    .getOrCreate()
)

In [3]:
# Location of the trained model, resolved relative to the kernel's working
# directory. Presumably it was written earlier with Word2VecModel.save(...);
# TODO confirm where it is produced.
MODEL_PATH = 'w2v_model'

# pyspark.ml's Word2VecModel.load takes only a path. The older
# pyspark.mllib form `load(sc, path)` does not apply here.
model = Word2VecModel.load(MODEL_PATH)

In [4]:
# Top-5 words whose vectors are closest to '野球', as (word, similarity) pairs.
# NOTE(review): the recorded neighbours look noisy (single kana such as 'り'
# and 'ん'); the training corpus / tokenization may be worth checking.
model.findSynonymsArray(word="野球", num=5)

[('姉妹', 0.8228563666343689),
 ('り', 0.8056354522705078),
 ('ん', 0.7986965775489807),
 ('男', 0.7821294069290161),
 ('秋', 0.7731783986091614)]

In [5]:
# Top-5 words whose vectors are closest to "東京", as (word, similarity) pairs.
model.findSynonymsArray(word="東京", num=5)

[('バス', 0.9257544875144958),
 ('上野', 0.9251026511192322),
 ('都立', 0.9039059281349182),
 ('京成', 0.9008254408836365),
 ('清', 0.8926427364349365)]

In [6]:
# Top-5 words whose vectors are closest to '日本', as (word, similarity) pairs.
model.findSynonymsArray(word="日本", num=5)

[('政府', 0.7264666557312012),
 ('ウィーン', 0.6764424443244934),
 ('クラブ', 0.6741369366645813),
 ('トルコ', 0.6721512079238892),
 ('都市', 0.6598173379898071)]

In [7]:
# Same query as the previous cell, returned as a DataFrame with "word" and
# "similarity" columns. The first 5 is the number of neighbours; the second
# 5 is the number of decimal places format_number keeps when rendering.
(
    model.findSynonyms("日本", 5)
    .select("word", fmt("similarity", 5).alias("similarity"))
    .show()
)

+--------+----------+
|    word|similarity|
+--------+----------+
|    政府|   0.72647|
|ウィーン|   0.67644|
|  クラブ|   0.67414|
|  トルコ|   0.67215|
|    都市|   0.65982|
+--------+----------+



In [8]:
# Two tiny pre-tokenized documents. Each row holds one list of tokens in the
# "text" column, which is the column the model is applied to below.
documentDF = spark.createDataFrame(
    [
        (["日本", "海"],),
        (["王", "男"],),
    ],
    ["text"],
)

In [9]:
# Get the vectors: apply the model to each document and print the tokens
# alongside the resulting vector.
# Columns are looked up by the model's own input/output column params rather
# than unpacked by position, so extra columns in documentDF cannot break the loop.
input_col = model.getInputCol()
output_col = model.getOutputCol()
res = model.transform(documentDF)
for row in res.select(input_col, output_col).collect():
    text, vector = row[input_col], row[output_col]
    print("Text: [%s] => \nVector: %s\n" % (", ".join(text), str(vector)))

Text: [日本, 海] => 
Vector: [-0.02831833530217409,0.04168903827667236,-0.022362001705914736,0.05916345398873091,0.11567089334130287,0.17672017007134855,-0.049605967476964,-0.02379211224615574,0.03758530877530575,0.1253017894923687,0.011594744399189949,0.01506025344133377,0.09179229289293289,0.0012611672282218933,0.00017327070236206055,-0.07945050485432148,0.050401024520397186,0.01208539493381977,0.0649831835180521,-0.03143859840929508,0.041660357266664505,0.14464688394218683,-0.05969642288982868,0.014144607819616795,-0.09923805948346853,0.015803582966327667,0.006146464496850967,0.054516881704330444,0.03930703899823129,0.015448017977178097,-0.05325884371995926,-0.04041178897023201,0.041766826529055834,0.004083635285496712,-0.10820183902978897,0.08505752868950367,-0.1263538971543312,-0.0252073984593153,0.1145687960088253,0.14198239520192146,-0.020240144804120064,-0.09661264345049858,-0.10040120407938957,0.01628229022026062,-0.1539141982793808,-0.04868007451295853,0.032655315939337015,-0.01

In [10]:
# Shut down the Spark session created in the setup cell; this should be the
# last cell, since no later cell can use `spark` after it runs.
spark.stop()