In [7]:
from contextualized_topic_models.models.cotm import COTM
import os
import pickle
from contextualized_topic_models.utils.data_preparation import TextHandler
from contextualized_topic_models.datasets.dataset import COTMDataset

### Load the Data

In [8]:
# Directory holding the GoogleNews corpus and its precomputed BERT embeddings.
# Defined once so every path below stays in sync.
GNEWS_DIR = os.path.join("..", "contextualized_topic_models", "data", "gnews")
# Size of each precomputed BERT vector passed to COTM as `bert_input_size`.
# NOTE(review): assumed to equal the embedding width stored in `training_bert` -- confirm.
BERT_EMBEDDING_SIZE = 768

handler = TextHandler(os.path.join(GNEWS_DIR, "GoogleNews.txt"))
handler.prepare() # create vocabulary and training data

# load BERT data
# NOTE(review): pickle.load can execute arbitrary code -- only load this file
# if it comes from a trusted source.
with open(os.path.join(GNEWS_DIR, "bert_embeddings_gnews"), "rb") as bert_file:
    training_bert = pickle.load(bert_file)

# Pair the bag-of-words representation with the BERT embeddings for training.
training_dataset = COTMDataset(handler.bow, training_bert, handler.idx2token)

cotm = COTM(input_size=len(handler.vocab), bert_input_size=BERT_EMBEDDING_SIZE,
            num_epochs=5, inference_type="contextual", n_components=50)

cotm.fit(training_dataset) # run the model

Settings: 
               N Components: 50
               Topic Prior Mean: 0.0
               Topic Prior Variance: 0.98
               Model Type: prodLDA
               Hidden Sizes: (100, 100)
               Activation: softplus
               Dropout: 0.2
               Learn Priors: True
               Learning Rate: 0.002
               Momentum: 0.99
               Reduce On Plateau: False
               Save Dir: None
Epoch: [1/5]	Samples: [11108/55540]	Train Loss: 88.4506742307768	Time: 0:00:02.260274
Epoch: [2/5]	Samples: [22216/55540]	Train Loss: 78.05209167062871	Time: 0:00:02.212764
Epoch: [3/5]	Samples: [33324/55540]	Train Loss: 73.7296515338748	Time: 0:00:02.178382
Epoch: [4/5]	Samples: [44432/55540]	Train Loss: 71.4169747143458	Time: 0:00:02.192594
Epoch: [5/5]	Samples: [55540/55540]	Train Loss: 70.13589181924064	Time: 0:00:02.214744


In [9]:
top_words_per_topic = cotm.get_topic_lists(5)  # 5 top words for every topic
top_words_per_topic[:5]  # display only the first five topics

[['hiv', 'test', 'love', 'researcher', 'secret'],
 ['pill', 'morning', 'woman', 'effective', 'hp'],
 ['jellyfish', 'salmond', 'flying', 'blueprint', 'independence'],
 ['prince', 'william', 'taylor', 'swift', 'bon'],
 ['san', 'grand', 'theft', 'andreas', 'auto']]

### Predict the topic distribution for the documents

In [13]:
doc_topic_distributions = cotm.get_thetas(training_dataset)
doc_topic_distributions[0]  # topic distribution for the first document

[0.01770107075572014,
 0.0012614678125828505,
 0.04816910997033119,
 0.1519114375114441,
 0.004698761738836765,
 0.001173828961327672,
 0.0477045401930809,
 0.03476877883076668,
 0.0027616769075393677,
 0.0012197698233649135,
 0.07430833578109741,
 0.0031210030429065228,
 0.0044945646077394485,
 0.003755113808438182,
 0.052336834371089935,
 0.016557950526475906,
 0.007010035216808319,
 0.003168315626680851,
 0.0023144094739109278,
 0.015852492302656174,
 0.015327775850892067,
 0.004507662262767553,
 0.0024233420845121145,
 0.0019444096833467484,
 0.016046645119786263,
 0.0015794036444276571,
 0.0007641096599400043,
 0.001453230157494545,
 0.028085924685001373,
 0.023154620081186295,
 0.08526438474655151,
 0.010250079445540905,
 0.0029451046139001846,
 0.0023117822129279375,
 0.005968856625258923,
 0.01524386741220951,
 0.0017038424266502261,
 0.004526254255324602,
 0.008974992670118809,
 0.00682725990191102,
 0.011294402182102203,
 0.002833897713571787,
 0.0037057276349514723,
 0.00526

## Evaluate the Model

In [10]:
from contextualized_topic_models.evaluation.measures import TopicDiversity, CoherenceNPMI,\
    CoherenceWordEmbeddings, InvertedRBO

In [11]:
# Same number of top words is used to build the topics and to score them.
diversity_topk = 25
td = TopicDiversity(cotm.get_topic_lists(diversity_topk))
td.score(topk=diversity_topk)


0.5592

In [None]:
rbo_topics = cotm.get_topic_lists(10)  # top-10 words of each topic
rbo = InvertedRBO(rbo_topics)
rbo.score()

### Coherence measure based on Word Embeddings
Evaluates topic coherence in a word-embedding space. If `word2vec_path` is specified, the word embeddings are loaded from that file (in word2vec format); otherwise `word2vec-google-news-300` is downloaded using gensim's API.

In [None]:

# Location of the pretrained word2vec model. Override it with the WORD2VEC_PATH
# environment variable instead of editing code. The fallback placeholder is
# built with os.path.join rather than Windows-only backslash escapes.
word2vec_path = os.environ.get("WORD2VEC_PATH",
                               os.path.join("your", "path", "to", "word2vec.bin"))
we_coh = CoherenceWordEmbeddings(word2vec_path=word2vec_path,
                                 topics=cotm.get_topic_lists(10),
                                 # NOTE(review): presumably means the file is in binary word2vec format -- confirm
                                 binary=True)
we_coh.score(topk=10)

In [None]:
# Build the reference corpus for NPMI: one document per line, tokenised on whitespace.
# An explicit encoding avoids depending on the platform's default locale.
# NOTE(review): assumes GoogleNews.txt is UTF-8 encoded -- confirm.
with open(os.path.join('../contextualized_topic_models/data/gnews', 'GoogleNews.txt'),
          "r", encoding="utf-8") as corpus_file:
    texts = [doc.split() for doc in corpus_file.read().splitlines()]
npmi = CoherenceNPMI(texts=texts, topics=cotm.get_topic_lists(10))
npmi.score()

