In [1]:
from contextualized_topic_models.models.ctm import CTM
import os
import pickle
from contextualized_topic_models.utils.data_preparation import TextHandler
from contextualized_topic_models.datasets.dataset import CTMDataset

### Load The Data

In [2]:
# Dataset directory, defined once so the corpus and its embeddings are read
# from the same place.
gnews_dir = "../contextualized_topic_models/data/gnews"

handler = TextHandler(os.path.join(gnews_dir, "GoogleNews.txt"))
handler.prepare()  # create vocabulary and training data

# Load the BERT data.
# NOTE(security): pickle.load can execute arbitrary code while deserialising;
# only unpickle embedding files that come from a trusted source.
with open(os.path.join(gnews_dir, "bert_embeddings_gnews"), "rb") as embeddings_file:
    training_bert = pickle.load(embeddings_file)

# NOTE(review): presumably training_bert holds one embedding per document, in
# the same order as handler.bow — confirm against how the pickle was produced.
training_dataset = CTMDataset(handler.bow, training_bert, handler.idx2token)

ctm = CTM(
    input_size=len(handler.vocab),
    bert_input_size=768,
    num_epochs=5,
    inference_type="combined",
    n_components=50,
)

ctm.fit(training_dataset)  # run the model

Settings: 
               N Components: 50
               Topic Prior Mean: 0.0
               Topic Prior Variance: 0.98
               Model Type: prodLDA
               Hidden Sizes: (100, 100)
               Activation: softplus
               Dropout: 0.2
               Learn Priors: True
               Learning Rate: 0.002
               Momentum: 0.99
               Reduce On Plateau: False
               Save Dir: None
Epoch: [1/5]	Samples: [11108/55540]	Train Loss: 83.68554634750855	Time: 0:00:04.193274
Epoch: [2/5]	Samples: [22216/55540]	Train Loss: 74.56768099901957	Time: 0:00:04.221488
Epoch: [3/5]	Samples: [33324/55540]	Train Loss: 71.33945797792131	Time: 0:00:04.225701
Epoch: [4/5]	Samples: [44432/55540]	Train Loss: 69.63589208298636	Time: 0:00:04.240847
Epoch: [5/5]	Samples: [55540/55540]	Train Loss: 68.70253333042626	Time: 0:00:04.133649


In [3]:
# Preview the first five topics, each described by its top 5 words.
ctm.get_topic_lists(5)[:5]

[['xbox', 'lumia', 'nokia', 'tablet', 'microsoft'],
 ['jos', 'bank', 'men', 'money', 'wearhouse'],
 ['swift', 'jovi', 'bon', 'taylor', 'william'],
 ['disney', 'free', 'lumia', 'frozen', 'beat'],
 ['kobe', 'basel', 'lakers', 'bryant', 'league']]

### Predict the topic distribution for the documents
What is the topic of the document *nokia lumia launch*?

In [14]:
# Show the first 10 lines of the corpus file (shell escape; needs a `head` binary on PATH).
!head -n 10 ../contextualized_topic_models/data/gnews/GoogleNews.txt

centrepoint winter white gala london
mourinho seek killer instinct
roundup golden globe won seduced johansson voice
travel disruption mount storm cold air sweep south florida
wes welker blame costly turnover
psalm book fetch record ny auction ktvn channel reno
surface review comparison window powered tablet pitted
scientist unreported fish trap space
nokia lumia launch
edward snowden latest leak nsa monitored online porn habit radicalizers


In [32]:
import numpy as np

# "nokia lumia launch" is the 9th line of GoogleNews.txt, i.e. 0-based index 8.
# NOTE(review): assumes training_dataset keeps the documents in file order.
doc_index = 8

distribution = ctm.get_thetas(training_dataset)[doc_index]  # topic distribution for that document

print(distribution)

topic = np.argmax(distribution)  # index of the topic with the largest weight

ctm.get_topic_lists(5)[topic]


[0.04631870985031128, 0.0030940896831452847, 0.012179974466562271, 0.10562609136104584, 0.00029982352862134576, 0.016247086226940155, 0.0026096573565155268, 0.0024016175884753466, 0.017862707376480103, 0.02637273631989956, 0.06362489610910416, 0.0037376857362687588, 0.027107330039143562, 0.05662212520837784, 0.002137327566742897, 0.0013186399592086673, 0.004643352702260017, 0.008647769689559937, 0.0011787913972511888, 0.006648227572441101, 0.007791334297508001, 0.00045980652794241905, 0.275553822517395, 0.004696805961430073, 0.0015463449526578188, 0.0004970636800862849, 0.0042398953810334206, 0.00293867033906281, 0.023034632205963135, 0.09038814902305603, 0.02389369159936905, 0.0037447575014084578, 0.03660003840923309, 0.00037390936631709337, 0.026152431964874268, 0.016770798712968826, 0.00270906207151711, 0.0025828303769230843, 0.008109727874398232, 0.0009836236713454127, 0.00779800396412611, 0.03024321049451828, 0.0005514333606697619, 0.004347378853708506, 0.004108520690351725, 0.000

['moto', 'xbox', 'camera', 'surface', 'review']

## Evaluate the Model

In [33]:
from contextualized_topic_models.evaluation.measures import TopicDiversity, CoherenceNPMI,\
    CoherenceWordEmbeddings, InvertedRBO

In [34]:
# Topic diversity, scored on the top 25 words of every topic.
topic_words = ctm.get_topic_lists(25)
td = TopicDiversity(topic_words)
td.score(topk=25)


0.5176

In [35]:
# InvertedRBO score, computed from the top 10 words of every topic.
topic_words = ctm.get_topic_lists(10)
rbo = InvertedRBO(topic_words)
rbo.score()

0.9802813618807522

### Coherence measure based on Word Embeddings
Evaluation of coherence in a word-embedding space. If `word2vec_path` is specified, the word embeddings are loaded from that file (in word2vec format); otherwise 'word2vec-google-news-300' is downloaded using gensim's API.

In [None]:

# Placeholder path: point this at a local word2vec binary before running the cell.
word2vec_path = "your\\path\\to\\word2vec.bin"
we_coh = CoherenceWordEmbeddings(
    word2vec_path=word2vec_path,
    topics=ctm.get_topic_lists(10),
    binary=True,
)
we_coh.score(topk=10)

In [36]:
# Re-read the corpus as whitespace-tokenised documents (one per line) and
# score the top 10 words of every topic with NPMI coherence.
corpus_path = os.path.join('../contextualized_topic_models/data/gnews', 'GoogleNews.txt')
with open(corpus_path, "r") as fr:
    raw_lines = fr.read().splitlines()
texts = [line.split() for line in raw_lines]
npmi = CoherenceNPMI(texts=texts, topics=ctm.get_topic_lists(10))
npmi.score()



-0.12413081907269717