In [1]:
from contextualized_topic_models.models.cotm import COTM
from contextualized_topic_models.utils.data_preparation import to_bow
import os
import json
import numpy as np
import pickle
import torch
from contextualized_topic_models.datasets.dataset import LMTopicDataset

### Load the Data

In [2]:
# Directory holding the preprocessed GNews artifacts read by this cell.
data_dir = '../contextualized_topic_models/data/gnews'

# Vocabulary mapping (presumably token -> index, since it is inverted into
# idx2token below). Despite the .pkl extension the file is parsed as JSON.
# The context manager guarantees the handle is closed after reading.
with open(os.path.join(data_dir, 'vocab.pkl'), 'r') as vocab_file:
    vocab = json.load(vocab_file)
idx2token = {v: k for (k, v) in vocab.items()}

vocab_size = len(vocab)

# NOTE(review): allow_pickle=True and pickle.load below execute arbitrary code
# embedded in the files -- only load artifacts from a trusted source.
train = np.load(os.path.join(data_dir, 'train.txt.pkl'), encoding='latin1', allow_pickle=True)
# Convert the training documents to bag-of-words form over the vocabulary.
train_bow = to_bow(train, vocab_size)

# Precomputed BERT embeddings for the training documents
# (assumed to be aligned row-by-row with `train` -- TODO confirm).
with open(os.path.join(data_dir, 'bert_embeddings_gnews'), "rb") as embeddings_file:
    train_bert = pickle.load(embeddings_file)


In [3]:
# Bundle the bag-of-words matrix, the BERT embeddings and the index -> token
# map into the dataset object that is passed to `cotm.fit` below.
training_data = LMTopicDataset(train_bow, train_bert, idx2token)


In [4]:
# Seed the global RNGs before building the model: weight initialisation and
# batch shuffling are stochastic, so without this every run of the notebook
# yields different topics and scores.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Contextual topic model: the BERT embedding size is read from the first
# embedding vector, the bag-of-words size from the vocabulary.
cotm = COTM(
    input_size=vocab_size,
    bert_input_size=len(train_bert[0]),
    inferencetype="contextual",
    n_components=50,
    model_type="prodLDA",
    hidden_sizes=(100, ),
    activation='softplus',
    dropout=0.2,
    learn_priors=True,
    batch_size=200,
    lr=2e-3,
    momentum=0.99,
    solver='adam',
    num_epochs=2,  # short demo run; NOTE(review): raise for a real experiment
    reduce_on_plateau=False,
    num_data_loader_workers=0,
)

cotm.fit(training_data)


Settings: 
               N Components: 50
               Topic Prior Mean: 0.0
               Topic Prior Variance: 0.98
               Model Type: prodLDA
               Hidden Sizes: (100,)
               Activation: softplus
               Dropout: 0.2
               Learn Priors: True
               Learning Rate: 0.002
               Momentum: 0.99
               Reduce On Plateau: False
               Save Dir: None
Epoch: [1/2]	Samples: [11108/22216]	Train Loss: 92.30986263067722	Time: 0:00:03.105151
Epoch: [2/2]	Samples: [22216/22216]	Train Loss: 81.17217109853259	Time: 0:00:02.737545


In [5]:
# Display the topics as lists of words, requesting 5 words per topic.
cotm.get_topic_lists(5)

[['andreas', 'san', 'flying', 'station', 'mobile'],
 ['black', 'friday', 'deal', 'lumia', 'best'],
 ['seth', 'kanye', 'west', 'james', 'franco'],
 ['men', 'wearhouse', 'jos', 'bank', 'bid'],
 ['murder', 'child', 'daughter', 'guilty', 'hospital'],
 ['beltran', 'group', 'steubenville', 'rape', 'mourinho'],
 ['xbox', 'seek', 'held', 'console', 'spacex'],
 ['storm', 'winter', 'travel', 'lumia', 'weather'],
 ['jennifer', 'change', 'pope', 'rate', 'warns'],
 ['troop', 'central', 'republic', 'france', 'pope'],
 ['heart', 'woman', 'secretly', 'fed', 'treasury'],
 ['lumia', 'nokia', 'ice', 'microsoft', 'motorola'],
 ['december', 'coming', 'andreas', 'grand', 'theft'],
 ['heart', 'envoy', 'author', 'joseph', 'pain'],
 ['salmond', 'seahorse', 'independence', 'arrested', 'whopping'],
 ['week', 'bronco', 'peyton', 'raider', 'loss'],
 ['xbox', 'auto', 'theft', 'microsoft', 'android'],
 ['kardashian', 'german', 'kim', 'reaction', 'spoof'],
 ['kim', 'kanye', 'west', 'kardashian', 'bon'],
 ['china', 'z

### Evaluate the Model

In [6]:
from contextualized_topic_models.evaluation.measures import TopicDiversity, CoherenceNPMI,\
    CoherenceWordEmbeddings,RBO

In [7]:
# Topic diversity over the top-25 words of every topic; the same cut-off is
# used both when extracting the word lists and when scoring them.
diversity_topk = 25
diversity_topics = cotm.get_topic_lists(diversity_topk)
td = TopicDiversity(diversity_topics)
td.score(topk=diversity_topk)


0.5968

In [8]:
# RBO score computed on the top-10 words of every topic, using the
# metric's default scoring arguments.
rbo_topics = cotm.get_topic_lists(10)
rbo = RBO(rbo_topics)
rbo.score()

0.015902484681825074

[['google', 'nokia', 'glass', 'lumia', 'chromebooks'],
 ['theft', 'andreas', 'grand', 'mobile', 'san'],
 ['white', 'scotland', 'aid', 'church', 'raf'],
 ['nokia', 'window', 'google', 'microsoft', 'lumia'],
 ['birth', 'girl', 'hospital', 'texas', 'nurse'],
 ['andreas', 'san', 'mobile', 'theft', 'grand'],
 ['kanye', 'kim', 'west', 'james', 'kardashian'],
 ['friday', 'black', 'jennifer', 'free', 'launch'],
 ['bronco', 'welker', 'wes', 'cowboy', 'week'],
 ['nigella', 'irs', 'woman', 'child', 'guilty'],
 ['nokia', 'lumia', 'stock', 'att', 'app'],
 ['skyline', 'bridging', 'cosmetic', 'bribe', 'curve'],
 ['storm', 'air', 'typhoon', 'seahawks', 'east'],
 ['berlusconi', 'originally', 'merkel', 'scottish', 'patent'],
 ['jovi', 'independence', 'republic', 'unveils', 'plan'],
 ['minute', 'seahawks', 'storm', 'arizona', 'syria'],
 ['microsoft', 'morning', 'woman', 'nokia', 'effective'],
 ['decimate', 'discontinue', 'seized', 'expletive', 'bridging'],
 ['xbox', 'microsoft', 'disc', 'nokia', 'amber']