In [1]:
import contextualized_topic_models
from contextualized_topic_models.models.ctm import CTM
from contextualized_topic_models.utils.data_preparation import TextHandler
from contextualized_topic_models.utils.data_preparation import bert_embeddings_from_file
from contextualized_topic_models.datasets.dataset import CTMDataset

# tqdm allows you to create progress bars to track how long your code is taking to process
from tqdm.notebook import tqdm as tqdm

# pprint is to make our topics formatted a little nicer when we take a look
from pprint import pprint

In [2]:
# Build the bag-of-words side of the corpus from the raw text file.
# prepare() must run before the attributes read further down
# (handler.bow, handler.idx2token, handler.vocab) are used.
corpus_path = "documents.txt"
handler = TextHandler(corpus_path)
handler.prepare()

  self.vocab_dict[x], y.split()))), data)))


In [3]:

# Settings that must stay consistent with each other are named once here.
SBERT_MODEL_NAME = "distiluse-base-multilingual-cased"
N_TOPICS = 50

# Generate the contextual (SBERT) embeddings for the corpus file.
# NOTE(review): assumes the result is a 2-D array shaped
# (n_documents, embedding_dim) -- confirm against the installed library version.
training_bert = bert_embeddings_from_file("documents.txt", SBERT_MODEL_NAME)

# Bundle the bag-of-words matrix, the embeddings and the id -> token mapping.
training_dataset = CTMDataset(handler.bow, training_bert, handler.idx2token)

# The embedding width is read from the data rather than hard-coded (it was 512),
# so switching SBERT_MODEL_NAME to a model with a different output size keeps
# the network input layer in sync with the embeddings.
ctm = CTM(
    input_size=len(handler.vocab),
    bert_input_size=training_bert.shape[1],
    inference_type="combined",
    n_components=N_TOPICS,
)

ctm.fit(training_dataset)  # run the model



Settings: 
               N Components: 50
               Topic Prior Mean: 0.0
               Topic Prior Variance: 0.98
               Model Type: prodLDA
               Hidden Sizes: (100, 100)
               Activation: softplus
               Dropout: 0.2
               Learn Priors: True
               Learning Rate: 0.002
               Momentum: 0.99
               Reduce On Plateau: False
               Save Dir: None
Epoch: [1/100]	Samples: [35/3500]	Train Loss: 212.7273158482143	Time: 0:00:00.163821
Epoch: [2/100]	Samples: [70/3500]	Train Loss: 214.52566964285714	Time: 0:00:00.078270
Epoch: [3/100]	Samples: [105/3500]	Train Loss: 211.35712890625	Time: 0:00:00.133631
Epoch: [4/100]	Samples: [140/3500]	Train Loss: 209.90057198660713	Time: 0:00:00.075811
Epoch: [5/100]	Samples: [175/3500]	Train Loss: 210.4786411830357	Time: 0:00:00.094209
Epoch: [6/100]	Samples: [210/3500]	Train Loss: 211.3962890625	Time: 0:00:00.071848
Epoch: [7/100]	Samples: [245/3500]	Train Loss: 211.3770926

Epoch: [90/100]	Samples: [3150/3500]	Train Loss: 190.3243443080357	Time: 0:00:00.086061
Epoch: [91/100]	Samples: [3185/3500]	Train Loss: 190.6351283482143	Time: 0:00:00.082367
Epoch: [92/100]	Samples: [3220/3500]	Train Loss: 190.91953125	Time: 0:00:00.087674
Epoch: [93/100]	Samples: [3255/3500]	Train Loss: 187.64323381696428	Time: 0:00:00.085840
Epoch: [94/100]	Samples: [3290/3500]	Train Loss: 189.34839564732144	Time: 0:00:00.078068
Epoch: [95/100]	Samples: [3325/3500]	Train Loss: 191.9134486607143	Time: 0:00:00.076335
Epoch: [96/100]	Samples: [3360/3500]	Train Loss: 187.87861328125	Time: 0:00:00.087560
Epoch: [97/100]	Samples: [3395/3500]	Train Loss: 186.5501953125	Time: 0:00:00.073873
Epoch: [98/100]	Samples: [3430/3500]	Train Loss: 188.78759765625	Time: 0:00:00.069172
Epoch: [99/100]	Samples: [3465/3500]	Train Loss: 190.42066127232144	Time: 0:00:00.071778
Epoch: [100/100]	Samples: [3500/3500]	Train Loss: 188.47459542410715	Time: 0:00:00.082685


In [7]:
# Label the output, then display the top-5 words of the first six topics.
print("Get topic list")
top_words_per_topic = ctm.get_topic_lists(5)
top_words_per_topic[:6]

Get topic list


[['fundraiser', 'in', 'relatively', 'aiming', 'will'],
 ['rate', 'The', 'To', 'contract', 'keep'],
 ['couriers', 'televised', 'and', 'rules', 'Already,'],
 ['stepped', 'coronavirus,', 'million', 'far', 'Some'],
 ['under', 'the', 'last', 'economy', 'women'],
 ['New', 'laboratory-confirmed', 'tube,', 'arriving', 'travel.']]

In [10]:
# Shell escape: show the first 10 lines of the raw corpus file used for training.
!head -n 10 documents.txt

The World Health Organization reported a record increase in global coronavirus infections yesterday, with the total rising by 183,020 in a 24-hour period.

The biggest increase, of more than 116,000, was from North and South America, it said in a daily report. Total global cases have passed 8.7 million with more than 461,000 deaths, the WHO added.

Germany's coronavirus reproduction rate jumped to 2.88 yesterday, up from 1.79 a day earlier, health authorities said, a rate showing infections are rising above the level needed to contain the disease over the longer term.

The rise brings with it the possibility of renewed restrictions on activity in Europe's largest economy - a blow to a country that so far had widely been seen as successful in curbing the coronavirus spread and keeping the death toll relatively low.

To keep the pandemic under control, Germany needs the reproduction rate to drop below one. The rate of 2.88, published by the Robert Koch Institute (RKI) for public 

In [9]:
import numpy as np

# Position of the document to inspect. Indexing is 0-based, so 8 selects the
# ninth entry returned by get_thetas, not the first one.
doc_index = 8

distribution = ctm.get_thetas(training_dataset)[doc_index]  # topic distribution for that document

print(distribution)

# The dominant topic is the index holding the largest weight.
topic = np.argmax(distribution)

# Top-5 words of the dominant topic.
ctm.get_topic_lists(5)[topic]

[0.003234884236007929, 0.07325280457735062, 0.007807305548340082, 0.24313412606716156, 0.006643406115472317, 0.002581328619271517, 0.004334053490310907, 0.0006474740221165121, 0.00915495865046978, 0.016229234635829926, 0.004887876100838184, 0.010647004470229149, 0.0018588844686746597, 0.030883969739079475, 0.003372527426108718, 0.03695008158683777, 0.0032030942384153605, 0.0039533269591629505, 0.039643656462430954, 0.002057476667687297, 0.0029870006255805492, 0.005823671817779541, 0.0062237088568508625, 0.0017570609925314784, 0.012800307944417, 0.005516372621059418, 0.004783101845532656, 0.02290784753859043, 0.02422214485704899, 0.008392039686441422, 0.0024587088264524937, 0.017790278419852257, 0.014246385544538498, 0.008215066976845264, 0.09104691445827484, 0.039396271109580994, 0.038242027163505554, 0.00047126191202551126, 0.006299561820924282, 0.0029759102035313845, 0.022552059963345528, 0.005243094637989998, 0.0348799005150795, 0.001906018704175949, 0.03045138530433178, 0.037276644

['stepped', 'coronavirus,', 'million', 'far', 'Some']

In [13]:
#Evaluate the Model

from contextualized_topic_models.evaluation.measures import TopicDiversity, CoherenceNPMI,\
    CoherenceWordEmbeddings, InvertedRBO



In [None]:
# TopicDiversity over the top-25 words of every topic. The same cut-off is
# used to build the topic lists and to score them, so it is bound once.
diversity_topk = 25
topics_top25 = ctm.get_topic_lists(diversity_topk)
td = TopicDiversity(topics_top25)
td.score(topk=diversity_topk)


In [None]:
# InvertedRBO score computed from the top-10 words of every topic.
topics_top10 = ctm.get_topic_lists(10)
rbo = InvertedRBO(topics_top10)
rbo.score()

In [17]:
# Rebinds `td` to a TopicDiversity built from only the top-5 words per topic.
topics_top5 = ctm.get_topic_lists(5)
td = TopicDiversity(topics_top5)

In [24]:
# Coherence measure based on word embeddings:
# evaluates topic coherence in a word-embedding space. If word2vec_file is specified,
# the word embeddings file (in word2vec format) is loaded from it; otherwise
# 'word2vec-google-news-300' is downloaded using gensim's APIs.
# NOTE(review): this cell does not compute that measure; it only displays the
# top-5 words of every topic.

# td.topics
ctm.get_topic_lists(5)

[['fundraiser', 'in', 'relatively', 'aiming', 'will'],
 ['rate', 'The', 'To', 'contract', 'keep'],
 ['couriers', 'televised', 'and', 'rules', 'Already,'],
 ['stepped', 'coronavirus,', 'million', 'far', 'Some'],
 ['under', 'the', 'last', 'economy', 'women'],
 ['New', 'laboratory-confirmed', 'tube,', 'arriving', 'travel.'],
 ['to', 'which', 'said', 'world', 'and'],
 ['reporters.', 'tube,', 'country', 'activity', 'toll'],
 ['that', 'couriers', 'expected', 'was', 'today,'],
 ['1.79', 'testing', 'New', 'following', 'tightened'],
 ['the', 'called', 'allowed', 'around', 'number'],
 ['the', 'to', 'Allgemeine', 'Bundesbank', 'newspaper'],
 ['current', 'and', 'Officials', 'news', 'Beijing'],
 ['and', 'to', 'mitigate', 'linked', 'stepped'],
 ['in', 'ban,', 'days.', 'but', 'nine.'],
 ['contract', 'the', 'health,', 'by', 'keep'],
 ['quarantine', 'As', 'multiple', 'week', 'Some'],
 ['parcel', 'outbreak,', 'Chinese', 'workers', 'virus'],
 ['daily', '461,000', 'South', '2.88', 'from'],
 ['meat', 'comp