In [1]:
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# $example on$
from pyspark.ml.feature import Word2Vec, Word2VecModel
# $example off$
from pyspark.sql import SparkSession
from pyspark.sql.functions import format_number as fmt

In [2]:
# Create (or reuse, via getOrCreate) the Spark session used by every later cell.
spark = (
    SparkSession.builder
    .appName("word2vec-spark")
    .getOrCreate()
)

In [3]:
# Load a previously saved pyspark.ml Word2VecModel from a relative path.
# (The old `Word2VecModel.load(sc, path)` form belonged to the RDD-based
# pyspark.mllib API and does not apply to the pyspark.ml class imported above.)
# NOTE(review): the cell that trained/saved this model is not part of this
# notebook — confirm where "w2v_model2" is produced.
MODEL_PATH = "w2v_model2"
model = Word2VecModel.load(MODEL_PATH)

In [4]:
# Words whose vectors are most similar to "野球" (baseball): top 5 as (word, similarity) pairs.
model.findSynonymsArray(word="野球", num=5)

[('プロ', 0.7207791209220886),
 ('サッカー', 0.7138990163803101),
 ('元', 0.6855593323707581),
 ('ゴルファー', 0.666304349899292),
 ('選手', 0.6264481544494629)]

In [5]:
# Words whose vectors are most similar to "東京" (Tokyo): top 5 as (word, similarity) pairs.
model.findSynonymsArray(word="東京", num=5)

[('目黒', 0.7323737144470215),
 ('都', 0.6739898920059204),
 ('大阪', 0.6382745504379272),
 ('区', 0.6304832100868225),
 ('三鷹', 0.6196695566177368)]

In [6]:
# Words whose vectors are most similar to "日本" (Japan): top 5 as (word, similarity) pairs.
model.findSynonymsArray(word="日本", num=5)

[('協会', 0.4811580777168274),
 ('元', 0.4514187276363373),
 ('家', 0.449057012796402),
 ('ロードレース', 0.44045063853263855),
 ('初', 0.431621253490448)]

In [7]:
# Same query as above, but via the DataFrame-returning API, rendered as a table
# with the similarity column formatted to 5 decimal places.
(
    model.findSynonyms("日本", 5)
    .select("word", fmt("similarity", 5).alias("similarity"))
    .show()
)

+------------+----------+
|        word|similarity|
+------------+----------+
|        協会|   0.48116|
|          元|   0.45142|
|          家|   0.44906|
|ロードレース|   0.44045|
|          初|   0.43162|
+------------+----------+



In [8]:
# Two tiny "documents", each tokenized on single spaces into a list of words,
# stored in a single array column named "text".
documentDF = spark.createDataFrame(
    [(sentence.split(" "),) for sentence in ("日本 海", "王 男")],
    schema=["text"],
)

In [9]:
# Get one vector per document: transform() appends the model's output column.
# NOTE(review): Spark's docs describe this vector as the average of the
# document's word vectors — confirm against the Spark version in use.
# Columns are selected and read by name, not unpacked positionally, so the
# loop keeps working if documentDF gains extra columns.
vector_col = model.getOutputCol()
res = model.transform(documentDF)
for row in res.select("text", vector_col).collect():
    text, vector = row["text"], row[vector_col]
    print("Text: [%s] => \nVector: %s\n" % (", ".join(text), str(vector)))

Text: [日本, 海] => 
Vector: [0.051043495535850525,-0.2080628201365471,0.0913432314991951,0.10421495884656906,0.3484318181872368,0.01342928409576416,0.09427779167890549,-0.07429458387196064,0.07676511723548174,0.17561857402324677,0.004214957356452942,-0.17874804139137268,0.018272310495376587,-0.009507417678833008,0.22221802920103073,0.18262144830077887,0.13121276535093784,0.2770853042602539,0.15446455031633377,0.2568695805966854,-0.042400866746902466,0.11865832563489676,-0.04240161180496216,-0.06498439237475395,0.022332139313220978,0.09936200082302094,-0.24084429815411568,-0.1232358074048534,0.03752938250545412,-0.08315454982221127,-0.11612553149461746,-0.10686345025897026,-0.18035361357033253,-0.12141963839530945,0.004590824246406555,-0.04063940793275833,-0.007288217544555664,-0.11171035841107368,0.1061067245900631,-0.20631732791662216,0.253596194088459,0.27433183044195175,-0.058486893540248275,-0.09637655690312386,-0.0900473203510046,-0.07865951210260391,-0.07807149831205606,-0.06047749

In [10]:
# Stop the Spark session created at the top of the notebook; run this last.
spark.stop()