In [1]:
import heapq
import os
from itertools import islice

import numpy as np
import pyspark.sql.functions as F
import pyspark.sql.types as T
import scipy
import smart_open
from pyspark.sql import SparkSession
from scipy.sparse import csr_matrix

from simple_lda import SimpleLDA

In [2]:
# Local Spark session using every available core; getOrCreate() hands back the
# already-running session if this cell is re-executed.
spark = SparkSession.builder.master("local[*]").appName("SparkTest").getOrCreate()

In [3]:
# S3 locations of the company2vec artifacts: the vocabulary lives under
# `<word2id_path>/bow/word2id.csv`, the description-only features under
# `<data_path>/raw_company_features`.
word2id_path, data_path = (
    "s3://onai-ml-dev-eu-west-1/company2vec/common",
    "s3://onai-ml-dev-eu-west-1/company2vec/data_desc_only",
)

In [4]:
# Hand-picked company_ids (matched against companies_raw.company_id below);
# their SIC codes define the target industries.
# The raw literal repeats several ids (e.g. 380011, 645782, 357076, 377732), so
# it is de-duplicated here while preserving first-seen order.
public_labels = list(dict.fromkeys([
    2444485, 9429145, 93196, 9870522, 7914436,
    645782, 380011, 392154, 5523392,
    875173, 237655379, 931146, 418171,
    813607, 100231, 357076,
    380011, 357076, 324490, 93339,
    882889, 127202, 5478606, 1025390, 645782,
    46329052, 189096, 915379, 46895276, 877008, 325290136, 20524024, 271958947, 21852987, 26363560, 110104150,
    106335, 319676, 377732, 61206100,
    877769, 874042, 780678, 953488, 883809, 875295, 874186, 874119,
    34049884, 882155, 30428758, 315394, 23037669, 27561,
    254287477, 883300, 30614595,
    5600108, 285880, 5433540, 878697, 35650, 688262, 226852452,
    876031, 410182, 874470, 874191, 879732, 5395336,
    883752, 880697, 65340486, 26320074, 883327, 1034090, 257501324,
    5920885, 1494039, 268074105, 34534627, 20385800, 23000545, 124640,
    628413, 272054403, 91192, 309779, 140283, 138644,
    364040, 381388, 184945, 874170, 42751952, 874183, 314896, 5126590, 841504,
    257501324, 35000, 47320264, 253748612, 85076655, 32053, 12144785, 8186273, 9934160, 557267859,
    695204, 35303, 274561, 683719, 370857, 561001, 874022, 387473, 394038, 8274485, 12188205,
    30614595, 883300, 254287477, 9956099, 380011, 27868703, 2386697, 126857,
    28224119, 26824144, 35023689, 386639, 393661,
    32449506, 875260, 27169270, 5629762, 26014489, 286119,
    233324810, 874864, 159230, 27860587, 35806, 876981,
    879554, 5487000, 236715563, 412090459, 875192, 278679, 180871, 22516334, 30274893, 5478907,
    5580060, 118474533, 1779941, 265154, 10405454,
    23335317, 7885406, 277444, 278933, 8983678, 874143, 409119,
    381865672, 874842, 410366, 873649, 275789, 882473,
    937352, 876758, 879422, 128861678, 6461781, 1859063,
    874119, 881803, 875849, 231533, 877769, 780678, 953488, 875295, 874042, 775001, 874186,
    680934, 135398, 882299, 668578, 4481676, 32012,
    2248076, 141249, 4975204, 98876, 21828553,
    3606442, 882547, 4509042, 20703565, 7435035, 94799, 288033, 359868,
    877235, 295170, 175265, 874520, 410366, 873649, 874977, 167945, 8090046,
    84148802, 275789, 30339992, 5533238, 5718736,
    5523392, 645782, 11809880, 1353107, 962864,
    413744, 409932, 875491, 109303666, 91638,
    314896, 330589, 34768, 184945,
    5126590, 874855, 631781, 364040, 831357, 874170,
    377732, 319676, 106772, 106335, 704634, 320105, 874828, 873861, 1519242, 533853947,
    874652, 377732, 319676, 106772, 704634, 312375, 278933, 874828, 4863668
]))

In [5]:
# Denormalized capiq company table, read with the session's default data
# source format (spark.sql.sources.default).
companies_raw_path = "s3://ai-data-lake-dev-eu-west-1/business/capiq/company_denormalized"
companies_raw = spark.read.load(companies_raw_path)

In [6]:
# Distinct SIC codes of the hand-picked public companies. The `!= ""` filter
# drops empty codes and, under SQL null semantics, null codes too.
target_industries = [
    sic_code
    for (sic_code,) in (
        companies_raw
        .filter(F.col("company_id").isin(public_labels))
        .select("sic_code")
        .filter(F.col("sic_code") != "")
        .distinct()
        .collect()
    )
]

In [7]:
# company_id -> SIC description for every company whose SIC code is one of
# the target industries. Only the two fields that are used are projected
# before collect(), so the remaining columns of the denormalized table are
# not shipped to the driver.
companies_subset = {
    row.company_id: row.sic_code_desc
    for row in (
        companies_raw
        .filter(F.col("sic_code").isin(target_industries))
        .select("company_id", "sic_code_desc")
        .collect()
    )
}

In [8]:
# Sorted distinct industry names (SIC descriptions) and the inverse mapping
# name -> position; the position is the integer label stored per document.
# NOTE(review): sorted() would raise TypeError if a None description were
# mixed with strings — confirm sic_code_desc has no nulls for these codes.
industries = sorted(set(companies_subset.values()))
industry2index = {name: idx for idx, name in enumerate(industries)}

In [9]:
# Vocabulary: word -> integer column id, read from a headerless "word,id" file.
word2id = {}
with smart_open.open(f"{word2id_path}/bow/word2id.csv", "r") as f:
    for line in f:
        line = line.strip()
        if not line:
            continue  # tolerate blank / trailing lines instead of crashing
        # The id is always the last field: splitting on the last comma only
        # keeps a token that itself contains a comma from breaking unpacking.
        word, idd = line.rsplit(",", 1)
        word2id[word] = int(idd)

In [10]:
# String-valued UDF mapping a company_id to its industry; the lambda reads the
# driver-side `companies_subset` dict, which is serialized with the function.
# Raises KeyError for ids that are not in the dict.
industry_udf = F.udf(
    f=lambda company_id: companies_subset[company_id],
    returnType=T.StringType(),
)

In [11]:
# Description features of the selected companies, each tagged with its
# industry label. Cached because the frame feeds more than one action
# (count() and collect() in the next cells); without it each action re-reads
# the source from S3 and re-runs the Python UDF, and the two results could
# even disagree.
# NOTE(review): isin() with a very large id list inflates the query plan; a
# broadcast join against a (company_id, industry) frame may scale better.
companies_df = (
    spark.read.load(f"{data_path}/raw_company_features")
    .filter(F.col("company_id").isin(list(companies_subset.keys())))
    .select("merged_description", industry_udf("company_id").alias("industry"))
    .cache()
)

In [12]:
# Document-term matrix dimensions: vocabulary size (columns) and number of
# documents (rows).
# NOTE(review): assumes word ids are 0..len(word2id)-1; a larger id would make
# the csr_matrix construction reject the column index.
n_words = len(word2id)
n_docs = companies_df.count()

In [13]:
# COO triples (document index, word id, value) for the document-term matrix,
# plus the industry label index of every document. Words absent from the
# vocabulary are dropped.
rows = []
cols = []
data = []
industry = []
for i, row in enumerate(companies_df.select("merged_description", "industry").collect()):
    industry.append(industry2index[row.industry])
    # A null merged_description arrives as None; treat it as an empty document
    # (an all-zero row of X) instead of failing on `.items()`.
    for word, value in (row.merged_description or {}).items():
        word_id = word2id.get(word)
        if word_id is None:
            continue  # out-of-vocabulary word
        rows.append(i)
        cols.append(word_id)
        data.append(value)

In [14]:
# Sparse document-term matrix: row i is document i, column j is word id j.
# scipy sums duplicate (row, col) pairs when building from COO triples.
X = csr_matrix(
    (data, (rows, cols)),
    shape=(n_docs, n_words),
)

In [15]:
# One topic per industry plus one extra topic (plot_topics labels it
# "background"). NOTE(review): not referenced by the cells shown here.
n_topics = 1 + len(industries)

In [16]:
# Pair every sparse document row of X with its industry label index and
# distribute the pairs as an RDD for model.train_distributed.
N_PARTITIONS = 10
# zip() silently stops at the shorter input; fail loudly if the labels and the
# matrix rows ever drift apart (e.g. count() and collect() disagreeing).
assert X.shape[0] == len(industry), (X.shape[0], len(industry))
X_metadata_rdd = spark.sparkContext.parallelize(list(zip(X, industry)), N_PARTITIONS)

In [17]:
# Topic model over the vocabulary whose first topics are named after
# `industries` (see plot_topics); presumably a labelled/supervised LDA —
# confirm in simple_lda. The meaning of print_every is defined there too.
model = SimpleLDA(word2id, industries, print_every=25)

In [18]:
# Train for 50 iterations on the (document row, industry index) RDD.
# NOTE(review): the captured log prints every iteration number despite
# print_every=25 — check how SimpleLDA uses that setting.
model.train_distributed(X_metadata_rdd, n_iter=50)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49


In [19]:
def plot_topics(model, top_n=20):
    """Print the ``top_n`` highest-probability words of every topic.

    Despite its name this function only prints to stdout; it draws nothing.

    Parameters
    ----------
    model
        Trained topic model exposing ``phi`` (iterable of per-topic rows of
        word probabilities indexed by word id), ``id2word`` (word id -> word)
        and ``industries`` (names of the first ``len(industries)`` topics).
        Any topic past those is labelled "background".
    top_n : int, default 20
        Number of words printed per topic.
    """
    phi = model.phi
    id2word = model.id2word
    n_named = len(model.industries)

    for topic_idx, topic in enumerate(phi):
        topic_name = model.industries[topic_idx] if topic_idx < n_named else "background"
        print("#"*20, topic_name, "#"*20)
        # nlargest(k, ...) equals sorted(..., reverse=True)[:k] without sorting
        # the whole vocabulary; ties on probability go to the larger word id.
        top_words = heapq.nlargest(
            top_n, ((prob, word_id) for word_id, prob in enumerate(topic))
        )
        for prob, word_id in top_words:
            print(f"{id2word[word_id]}:{prob} ", end="")
        print()
        print("#"*80)

In [20]:
# Show the top words of every learned topic (industry topics, then background).
plot_topics(model=model)

#################### Communication services ####################
servic:0.019999180741328507 mobil:0.015352320206042436 data:0.015138444583581767 service:0.014727373838233412 network:0.013648499963047744 internet:0.012061404879951606 commun:0.011678684115441353 satellit:0.00805323891381802 provid:0.0072223487820527975 oper:0.0069992414865325675 telecommun:0.00625196397146434 broadband:0.006112035958372362 platform:0.005869617799356605 inc:0.00586123445328233 wireless:0.005669523496679094 voic:0.005236730668915918 compani:0.005080463666779935 nth:0.004982129858406043 phone:0.0046627950666368 limit:0.004601520523599553 
################################################################################
#################### Engineering services ####################
engin:0.0292305871928453 project:0.018670523107140707 design:0.01667645970875136 system:0.016287203659791652 construct:0.015816715307217 service:0.014589622754559464 servic:0.011507315519148995 australia:0.010558116071737539 ltd:0