In [1]:
import os
import sys
# Put the directory two levels above the current working directory at the
# front of sys.path so project packages can be imported from this notebook.
BASE_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
sys.path.insert(0, BASE_DIR)

# When several Python versions are installed, leaving the interpreter
# unspecified is likely to cause errors, so pin the same one for both
# PySpark environment variables.
PYSPARK_PYTHON = "/opt/anaconda3/envs/rec_sys/bin/python"
os.environ.update(
    PYSPARK_PYTHON=PYSPARK_PYTHON,
    PYSPARK_DRIVER_PYTHON=PYSPARK_PYTHON,
)


from offline import SparkSessionBase
from offline import SparkSessionBase
from pyspark.ml.feature import Word2Vec



class TrainWord2VecModel(SparkSessionBase):
    """Notebook-local holder of the Spark session used by the Word2Vec steps.

    The class-level settings below are declared for ``SparkSessionBase``;
    how they are consumed is not visible in this notebook (presumably by
    ``_create_spark_session`` -- confirm against the ``offline`` package).
    """

    # Session settings declared for the base class.
    ENABLE_HIVE_SUPPORT = True
    SPARK_URL = "yarn"
    SPARK_APP_NAME = "Word2Vec"

    def __init__(self):
        """Build the session with the inherited factory and keep it on ``self.spark``."""
        session = self._create_spark_session()
        self.spark = session


# Instantiate the helper; its constructor creates the Spark session that the
# rest of the notebook reaches through `w2v.spark`.
w2v = TrainWord2VecModel()

In [2]:
# Select the `article` database for the queries that follow.
w2v.spark.sql("use article")

DataFrame[]

In [55]:
# Set spark.kryoserializer.buffer.max on the already-created session (value 4, no unit suffix).
# NOTE(review): this looks like a serializer setting that normally has to be
# supplied when the session is built -- confirm it has any effect when set
# here, and that 4 (with no unit) is the intended size.
w2v.spark.conf.set("spark.kryoserializer.buffer.max",4)

In [57]:
# Read the setting back to check the value stored on the session.
w2v.spark.conf.get("spark.kryoserializer.buffer.max")

'4'

In [9]:
# Take a 10-row sample of article_data; this small frame is what gets tokenised below.
article = w2v.spark.sql("select * from article_data  limit 10")

In [10]:
import codecs
# Load the stop-word list, one entry per line, into the notebook-level
# `stopwords_list` read by `segmentation`.
# The file is opened with an explicit encoding (instead of the locale default)
# and inside a `with` block so the handle is closed once the list is built.
# The former `global` statement was dropped: at module level it has no effect.
# NOTE(review): assumes the stop-word file is UTF-8 encoded -- confirm.
stopwords_path = "/home/data/etc/stop_words.txt"
with open(stopwords_path, encoding="utf-8") as stopwords_file:
    stopwords_list = [line.strip() for line in stopwords_file]

In [11]:
def segmentation(partition):
    """Tokenise one partition of article rows into filtered word lists.

    Written for ``rdd.mapPartitions``: the imports live inside the function
    so they are resolved wherever the partition is processed.

    Args:
        partition: iterable of rows exposing ``article_id``, ``channel_name``
            and ``sentence`` attributes.

    Yields:
        ``(article_id, channel_name, words)`` tuples, where ``words`` is the
        filtered token list for the row's sentence with ``<...>`` tags removed.

    Relies on the notebook-level ``stopwords_list`` being defined.
    """
    import re
    import jieba.posseg as pseg

    # Built once per partition: set membership instead of scanning the list
    # for every token, and a single compiled pattern for tag stripping.
    stopwords = set(stopwords_list)
    tag_pattern = re.compile("<.*?>")

    def cut_sentence(sentence):
        """Segment ``sentence`` and keep only the tokens worth training on.

        A token is kept when it is not a stop word, is longer than one
        character, and is one of:
          * an English token (flag ``"eng"``) longer than two characters;
          * a token whose flag starts with ``"n"``;
          * a token flagged ``"x"`` (per the original author's note these are
            custom-dictionary words -- confirm against jieba's tag set).
        """
        # pseg.lcut returns pairs, e.g.
        # [pair('今天', 't'), pair('有', 'd'), pair('雾', 'n'), pair('霾', 'g')]
        kept_words = []
        for seg in pseg.lcut(sentence):
            word, flag = seg.word, seg.flag
            # Stop words are matched on the word itself. Matching on the POS
            # flag (as the previous code did) never removed real stop words.
            if word in stopwords or len(word) <= 1:
                continue
            if flag == "eng":
                if len(word) > 2:
                    kept_words.append(word)
            elif flag.startswith("n") or flag == "x":
                kept_words.append(word)
        return kept_words

    for row in partition:
        # Strip markup tags before segmentation.
        sentence = tag_pattern.sub("", row.sentence)
        yield row.article_id, row.channel_name, cut_sentence(sentence)

In [12]:
# Tokenise every sampled article and name the resulting columns.
# The second column was previously misspelled "channle_name".
word_df = article.rdd.mapPartitions(segmentation).toDF(["article_id", "channel_name", "words"])

In [13]:
# Word2Vec estimator: reads the token lists from "words" and writes to "model".
w2v_model = Word2Vec(
    inputCol="words",
    outputCol="model",
    vectorSize=100,
    windowSize=5,
    minCount=3,
)

In [14]:
# Fit the estimator on the tokenised articles.
new_w2v_model = w2v_model.fit(word_df)
# Persist the fitted model. write().overwrite() lets this cell be re-run:
# a plain save() fails once the target path already exists.
new_w2v_model.write().overwrite().save('hdfs://192.168.0.52:8020/models/w2v_test2.model')

In [74]:
from pyspark.ml.feature import Word2VecModel

# Reload the persisted model from HDFS.
wv_model = Word2VecModel.load("hdfs://192.168.0.52:8020/models/w2v_test2.model")
# Vector table of the model; joined below on its `word` column and read
# through its `vector` column.
vectors = wv_model.getVectors()

In [102]:
# Incremental update: select article profiles whose article_time falls in the
# two days before today's midnight (start inclusive, end exclusive).
from datetime import timedelta,datetime

# Midnight at the start of today.
_yester = datetime.today().replace(hour=0, minute=0, second=0, microsecond=0)
start = (_yester - timedelta(days=2)).strftime("%Y-%m-%d %H:%M:%S")
end = _yester.strftime("%Y-%m-%d %H:%M:%S")

# Explode the keyword map into one (keywords, weights) row per entry.
Incremental_article_sql = "select article_id, channel_name, keywords, weights from article_profile LATERAL VIEW explode(keyword) AS keywords,weights where article_time >= '{}' and article_time < '{}' ".format(start, end)


In [103]:
# Display the generated SQL text for inspection.
Incremental_article_sql

"select article_id, channel_name, keywords, weights from article_profile LATERAL VIEW explode(keyword) AS keywords,weights where article_time >= '2019-12-10 00:00:00' and article_time < '2019-12-12 00:00:00' "

In [107]:
# Run the incremental query; the select list is article_id, channel_name, keywords, weights.
articleKeywordsWeights = w2v.spark.sql(Incremental_article_sql)

In [90]:
# Fetch a single article_profile row; the displayed DataFrame repr shows the table's columns and types.
w2v.spark.sql("select * from article_profile limit 1")

DataFrame[article_id: string, channel_name: string, keyword: map<string,double>, topics: array<string>, article_time: string]

In [108]:
# Attach a word vector to each keyword; keywords with no matching word in
# the vector table are dropped by the inner join.
_article_profile = articleKeywordsWeights.join(
    vectors,
    on=(vectors.word == articleKeywordsWeights.keywords),
    how="inner",
)

In [109]:
# Scale each keyword's vector by that keyword's weight.
articleKeywordVectors = _article_profile.rdd.map(
    lambda record: (
        record.article_id,
        record.channel_name,
        record.keywords,
        record.weights * record.vector,
    )
).toDF(["article_id", "channel_name", "keyword", "weightingVector"])

In [110]:
def avg(row):
    """Average the vectors collected for one article.

    Args:
        row: object exposing ``article_id``, ``channel_name`` and ``vectors``
            (a sized iterable of values that support ``+`` and ``/``).

    Returns:
        ``(article_id, channel_name, mean)`` where ``mean`` is the sum of
        ``row.vectors`` divided by their count; used as the article's vector.
    """
    keyword_vectors = row.vectors
    # sum() starts from the integer 0 and adds each vector in order.
    total = sum(keyword_vectors, 0)
    return row.article_id, row.channel_name, total / len(keyword_vectors)

# Expose the weighted keyword vectors to SQL. createOrReplaceTempView is the
# replacement for registerTempTable, which is deprecated since Spark 2.0.
articleKeywordVectors.createOrReplaceTempView("tempTable")
# Collect the weighted vectors per article, then average them with `avg`.
# NOTE(review): collect_set drops duplicate vectors within an article, so two
# identical weighted vectors count once in the average -- confirm whether
# collect_list was intended.
articleVector = w2v.spark.sql(
    "select article_id, min(channel_name) channel_name, collect_set(weightingVector) vectors from tempTable group by article_id").rdd.map(
    avg).toDF(["article_id", "channel_name", "articleVector"])

In [111]:
def toArray(row):
    """Turn a row's article vector into a plain list of Python floats.

    Args:
        row: object exposing ``article_id``, ``channel_name`` and an
            ``articleVector`` that provides ``toArray()``.

    Returns:
        ``(article_id, channel_name, values)`` with ``values`` a list of floats.
    """
    components = row.articleVector.toArray()
    return row.article_id, row.channel_name, list(map(float, components))

# Convert the vector column to a plain list of floats.
# NOTE(review): this rebinds `articleVector` from its own previous value, so
# the cell is not idempotent -- a second run would call toArray() on a column
# that has already been converted.
articleVector = (
    articleVector.rdd
    .map(toArray)
    .toDF(['article_id', 'channel_name', 'articleVector'])
)

In [112]:
# Insert the computed article vectors into the existing `article_vector` table.
# NOTE(review): insertInto is understood to match columns by position rather
# than by name, and re-running this cell would presumably insert the same rows
# again -- confirm both against the table definition and the PySpark docs.
articleVector.write.insertInto("article_vector")

In [113]:
# Show one stored row as a quick check of what the table now contains.
w2v.spark.sql("select * from article_vector limit 1").show()

+--------------------+------------+--------------------+
|          article_id|channel_name|       articlevector|
+--------------------+------------+--------------------+
|4ffe1e9b40e001aa1...|      砍柴网|[0.29118859694804...|
+--------------------+------------+--------------------+

