In [1]:
from contextualized_topic_models.models.lmavitm import LMAVITM
from contextualized_topic_models.utils.data_preparation import to_bow
import os
import json
import numpy as np
import pickle
import torch
from contextualized_topic_models.datasets.dataset import LMTopicDataset

### Load The Data

In [3]:
# Directory holding the preprocessed GNews files, relative to this notebook.
data_dir = '../contextualized_topic_models/data/gnews'

# NOTE: despite its ".pkl" extension, the vocabulary file is parsed as JSON.
vocab_path = os.path.join(data_dir, 'vocab.pkl')
# Use a context manager so the file handle is closed deterministically
# (the previous `json.load(open(...))` form left it open).
with open(vocab_path, 'r') as vocab_file:
    vocab = json.load(vocab_file)
# Invert the vocabulary mapping (values become keys); given the name, this is
# presumably index -> token -- TODO confirm against the vocab file.
idx2token = {v: k for (k, v) in vocab.items()}

vocab_size = len(vocab)
# The training documents are stored as a pickled object array, hence allow_pickle=True.
train = np.load(os.path.join(data_dir, 'train.txt.pkl'), encoding='latin1', allow_pickle=True)
# Convert the documents into a bag-of-words representation sized by the vocabulary.
train_bow = to_bow(train, vocab_size)

# NOTE(review): pickle.load can execute arbitrary code; only load embedding
# files that come from a trusted source.
with open(os.path.join(data_dir, 'bert_embeddings_gnews'), "rb") as filino:
    train_bert = pickle.load(filino)


In [4]:
# Bundle the bag-of-words matrix, the BERT embeddings and the inverted
# vocabulary into the dataset object that is later passed to `lmavitm.fit`.
training_data = LMTopicDataset(
    train_bow,
    train_bert,
    idx2token,
)


In [8]:
# Seed the global RNGs before building the model so that a fresh
# "Restart Kernel -> Run All" repeats the same training run.
# NOTE(review): bit-for-bit determinism may still depend on the device/backend.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Build the topic model: the bag-of-words size comes from the vocabulary and the
# contextual input size from the length of the first BERT embedding vector.
lmavitm = LMAVITM(input_size=vocab_size, bert_input_size=len(train_bert[0]),  inferencetype="contextual",
                n_components=50, model_type="prodLDA",
              hidden_sizes=(100, ), activation='softplus', dropout=0.2,
              learn_priors=True, batch_size=200, lr=2e-3, momentum=0.99,
              solver='adam', num_epochs=100, reduce_on_plateau=False)

# Train on the prepared dataset; per-epoch loss and timing are printed as it runs.
lmavitm.fit(training_data)


Settings: 
               N Components: 50
               Topic Prior Mean: 0.0
               Topic Prior Variance: 0.98
               Model Type: prodLDA
               Hidden Sizes: (100,)
               Activation: softplus
               Dropout: 0.2
               Learn Priors: True
               Learning Rate: 0.002
               Momentum: 0.99
               Reduce On Plateau: False
               Save Dir: None
Epoch: [1/100]	Samples: [11108/1110800]	Train Loss: 92.68798731418235	Time: 0:00:01.647715
Epoch: [2/100]	Samples: [22216/1110800]	Train Loss: 80.98612715075846	Time: 0:00:01.620539
Epoch: [3/100]	Samples: [33324/1110800]	Train Loss: 75.86123634148812	Time: 0:00:01.620759
Epoch: [4/100]	Samples: [44432/1110800]	Train Loss: 72.60012332750046	Time: 0:00:01.622098
Epoch: [5/100]	Samples: [55540/1110800]	Train Loss: 70.16168988742629	Time: 0:00:01.615312
Epoch: [6/100]	Samples: [66648/1110800]	Train Loss: 68.47249703549807	Time: 0:00:01.712650
Epoch: [7/100]	Samples: [77

Epoch: [85/100]	Samples: [944180/1110800]	Train Loss: 60.555617994913575	Time: 0:00:01.657892
Epoch: [86/100]	Samples: [955288/1110800]	Train Loss: 60.31921703213619	Time: 0:00:01.651721
Epoch: [87/100]	Samples: [966396/1110800]	Train Loss: 60.590828840278405	Time: 0:00:01.632540
Epoch: [88/100]	Samples: [977504/1110800]	Train Loss: 60.28801789110326	Time: 0:00:01.656925
Epoch: [89/100]	Samples: [988612/1110800]	Train Loss: 60.28068448521055	Time: 0:00:01.663602
Epoch: [90/100]	Samples: [999720/1110800]	Train Loss: 60.29858714932369	Time: 0:00:01.667879
Epoch: [91/100]	Samples: [1010828/1110800]	Train Loss: 60.41148183600783	Time: 0:00:01.700559
Epoch: [92/100]	Samples: [1021936/1110800]	Train Loss: 60.532443229612106	Time: 0:00:01.733382
Epoch: [93/100]	Samples: [1033044/1110800]	Train Loss: 60.49505933927124	Time: 0:00:01.668005
Epoch: [94/100]	Samples: [1044152/1110800]	Train Loss: 60.51139240862442	Time: 0:00:01.753200
Epoch: [95/100]	Samples: [1055260/1110800]	Train Loss: 60.39587

In [9]:
# Display the top 5 words of every learned topic (a mapping of topic index -> word list).
lmavitm.get_topics(5)

defaultdict(list,
            {0: ['tax', 'irs', 'group', 'coalition', 'party'],
             1: ['nsa', 'porn', 'habit', 'online', 'drug'],
             2: ['light', 'hanukkah', 'rare', 'ison', 'comet'],
             3: ['independence',
              'scotland',
              'scottish',
              'independent',
              'salmond'],
             4: ['review', 'homefront', 'frozen', 'movie', 'disney'],
             5: ['baldwin', 'alec', 'change', 'msnbc', 'climate'],
             6: ['kobe', 'lakers', 'bryant', 'extension', 'contract'],
             7: ['oldboy', 'spike', 'lee', 'remake', 'icahn'],
             8: ['rise', 'risk', 'hewlett', 'packard', 'price'],
             9: ['nuclear', 'deal', 'security', 'project', 'israel'],
             10: ['china', 'air', 'zone', 'defense', 'sea'],
             11: ['attack', 'kill', 'killed', 'texas', 'nurse'],
             12: ['golden', 'won', 'frozen', 'globe', 'star'],
             13: ['star', 'win', 'duck', 'dancing', 'live'],