In [1]:
import os
# Move the working directory one level above the directory the kernel started in.
# Presumably that is the project root, so the project-local imports and the relative
# dataset path used further down resolve -- confirm against the repository layout.
#
# The starting directory is captured only the first time this cell runs. A plain
# `os.chdir(os.path.pardir)` is relative to the *current* directory, so re-running the
# cell in a live kernel would climb one more level on every execution. Anchoring the
# move to the remembered start directory makes the cell safe to re-run.
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_START_DIR, os.path.pardir))
from dataset.dataset import Dataset
from evaluation_metrics.diversity_metrics import Topic_diversity
from evaluation_metrics.topic_significance_metrics import KL_uniform
from skopt import gp_minimize, forest_minimize, dummy_minimize
from optimization.optimizer import Optimizer
from skopt.space.space import Real, Integer
import multiprocessing as mp
from models import TorchETM
import torch
import numpy as np

In [2]:
# Location of the corpus, relative to the current working directory.
NEWSGROUP_PATH = "preprocessed_datasets/newsgroup/newsgroup_lemmatized_10"

# Build an empty Dataset and fill it from disk. The call is left as the cell's last
# expression so its return value is displayed (the recorded run shows `True`).
dataset = Dataset()
dataset.load(NEWSGROUP_PATH)

True

In [3]:
# Instantiate the ETM wrapper with no constructor arguments.
# NOTE(review): no checkpoint or path is passed here, so this presumably builds a fresh,
# untrained model rather than loading one -- confirm against TorchETM.ETM_Wrapper.
model = TorchETM.ETM_Wrapper()

In [4]:
# Settings that differ from the wrapper's defaults for this run, applied in order
# through item assignment on the wrapper's hyperparameter mapping.
HYPERPARAMETER_OVERRIDES = {
    'epochs': 20,
    'enc_drop': 0.1,
}
for _name, _value in HYPERPARAMETER_OVERRIDES.items():
    model.hyperparameters[_name] = _value

In [5]:
# NOTE(review): presumably switches the wrapper to use the dataset's test partition
# for the following training/inference calls -- confirm against ETM_Wrapper.test_set.
model.test_set(True)

In [6]:
# Train on the loaded dataset with the wrapper's current hyperparameters (including the
# values overridden above). In the recorded run this prints the model architecture and
# one loss line per epoch, and the returned dict -- displayed as the cell output --
# contains at least the keys 'topics' (5 words per topic, matching top_words) and
# 'topic-word-matrix'.
model.train_model(
    dataset,
    model.hyperparameters,
    top_words=5,
)

model: ETM(
  (t_drop): Dropout(p=0.1, inplace=False)
  (theta_act): ReLU()
  (rho): Linear(in_features=300, out_features=1268, bias=False)
  (alphas): Linear(in_features=300, out_features=10, bias=False)
  (q_theta): Sequential(
    (0): Linear(in_features=1268, out_features=800, bias=True)
    (1): ReLU()
    (2): Linear(in_features=800, out_features=800, bias=True)
    (3): ReLU()
  )
  (mu_q_theta): Linear(in_features=800, out_features=10, bias=True)
  (logsigma_q_theta): Linear(in_features=800, out_features=10, bias=True)
)
****************************************************************************************************
Epoch----->1 .. LR: 0.005 .. KL_theta: 0.06 .. Rec_loss: 515.16 .. NELBO: 515.22
****************************************************************************************************
****************************************************************************************************
Epoch----->2 .. LR: 0.005 .. KL_theta: 0.02 .. Rec_loss: 507.0 .. NELBO: 507.02
*

{'topics': [['be', 'moment', 'playoff', 'contain', 'represent'],
  ['volume', 'great', 'convince', 'better', 'encryption'],
  ['security', 'private', 'shipping', 'basically', 'role'],
  ['victim', 'encryption', 'thread', 'turn', 'be'],
  ['encryption', 'thread', 'be', 'turn', 'victim'],
  ['mostly', 'screw', 'son', 'compute', 'attempt'],
  ['victim', 'card', 'generation', 'method', 'past'],
  ['victim', 'turn', 'playoff', 'thread', 'be'],
  ['playoff', 'victim', 'be', 'license', 'school'],
  ['victim', 'license', 'card', 'playoff', 'school']],
 'topic-word-matrix': array([[3.5531379e-04, 9.3680067e-04, 1.5779298e-03, ..., 1.6453606e-03,
         9.1919978e-04, 1.3440547e-03],
        [4.6962657e-04, 8.3435618e-04, 4.8076722e-04, ..., 5.5757654e-04,
         6.7299034e-04, 6.8086328e-04],
        [4.7388798e-04, 1.2528567e-03, 1.8376140e-04, ..., 1.4141962e-04,
         8.2850049e-04, 4.2553875e-04],
        ...,
        [1.3802102e-04, 1.3298292e-03, 1.3895563e-04, ..., 8.9416128e-05,


In [8]:
# Run inference with the trained wrapper and display the result. In the recorded run the
# visible part of the output is a dict with 'topics' and 'topic-word-matrix' whose values
# match those shown after training.
# NOTE(review): the recorded execution count jumps from 6 to 8, so a cell was run and
# removed in between -- re-check with Restart Kernel -> Run All.
model.inference()

{'topics': [['be', 'moment', 'playoff', 'contain', 'represent'],
  ['volume', 'great', 'convince', 'better', 'encryption'],
  ['security', 'private', 'shipping', 'basically', 'role'],
  ['victim', 'encryption', 'thread', 'turn', 'be'],
  ['encryption', 'thread', 'be', 'turn', 'victim'],
  ['mostly', 'screw', 'son', 'compute', 'attempt'],
  ['victim', 'card', 'generation', 'method', 'past'],
  ['victim', 'turn', 'playoff', 'thread', 'be'],
  ['playoff', 'victim', 'be', 'license', 'school'],
  ['victim', 'license', 'card', 'playoff', 'school']],
 'topic-word-matrix': array([[3.5531379e-04, 9.3680067e-04, 1.5779298e-03, ..., 1.6453606e-03,
         9.1919978e-04, 1.3440547e-03],
        [4.6962657e-04, 8.3435618e-04, 4.8076722e-04, ..., 5.5757654e-04,
         6.7299034e-04, 6.8086328e-04],
        [4.7388798e-04, 1.2528567e-03, 1.8376140e-04, ..., 1.4141962e-04,
         8.2850049e-04, 4.2553875e-04],
        ...,
        [1.3802102e-04, 1.3298292e-03, 1.3895563e-04, ..., 8.9416128e-05,
