CTM

Notebook to load the result.json from a random search / grid search hyperparameter selection section

In [1]:
import pandas as pd
import numpy as np
import seaborn as sns

import json
from pathlib import Path
import sys
from datetime import datetime

sys.path.append('../')
from eval_metrics import compute_inverted_rbo, compute_topic_diversity, compute_pairwise_jaccard_similarity, \
                        METRICS, SEARCH_BEHAVIOUR, COHERENCE_MODEL_METRICS

In [2]:
import platform
import torch

if platform.system() == 'Linux' or platform.system() == 'Windows':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
    device = torch.device('mps')        # m-series machine

print(device)

cuda


In [3]:
# first load the best model

search_behaviour = SEARCH_BEHAVIOUR.RANDOM_SEARCH
# search_behaviour = SEARCH_BEHAVIOUR.RANDOM_SEARCH

training_datetime = datetime(2024, 1, 23, 0, 21, 11)
training_folder = Path(f'ctm_{search_behaviour.value}_{training_datetime.strftime("%Y%m%d_%H%M%S")}')

run_result_json_path = training_folder.joinpath('result.json')
run_config_json_path = training_folder.joinpath('config.json')

metrics_names = [m.value for m in METRICS]

with open(run_result_json_path) as f:
    run_result = json.load(f)

with open(run_config_json_path) as f:
    run_config = json.load(f)

run_result

{'best_metric': 0.014847039855361576,
 'best_model_checkpoint': 'ctm_random_search_20240123_002111/ctmsb_model_name_or_path_all-mpnet-base-v2_ctm_num_epochs_50_ctm_n_components_20_ctm_hidden_sizes_(200, 200, 200)_cvect_ngram_range_[1, 1]_cvect_max_features_2000',
 'best_hyperparameters': {'sbert_params': {'model_name_or_path': 'all-mpnet-base-v2'},
  'countvect_params': {'ngram_range': [1, 1], 'max_features': 2000},
  'ctm_params': {'dropout': 0.2,
   'lr': 0.002,
   'momentum': 0.99,
   'solver': 'adam',
   'num_epochs': 50,
   'n_components': 20,
   'hidden_sizes': [200, 200, 200],
   'bow_size': 967,
   'contextual_size': 768}},
 'monitor_type': 'c_npmi',
 'log_history': [{'c_npmi': -0.018172467372008744,
   'c_v': 0.45579005185159394,
   'u_mass': -0.05990959260352226,
   'c_uci': -1.3314675636930327,
   'topic_diversity': 0.675,
   'inverted_rbo': 0.9513249844553008,
   'pairwise_jaccard_similarity': 0.034047684703542594,
   'model_name': 'sb_model_name_or_path_all-roberta-large-v

In [6]:
# focus on the log history
# create a dataframe from the log history

log_history = run_result['log_history']

# from: https://www.freecodecamp.org/news/how-to-flatten-a-dictionary-in-python-in-4-different-ways/
from collections.abc import MutableMapping

def _flatten_dict_gen(d, parent_key, sep):
    for k, v in d.items():
        new_key = parent_key + sep + k if parent_key else k
        if isinstance(v, MutableMapping):
            yield from flatten_dict(v, new_key, sep=sep).items()
        else:
            yield new_key, v


def flatten_dict(d: MutableMapping, parent_key: str = '', sep: str = '.'):
    return dict(_flatten_dict_gen(d, parent_key, sep))

# for each dictionary in the log_history list
# convert them to a flattened dictionary
# then append to a list
log_history_flattened = [flatten_dict(log, sep='.') for log in log_history]
log_history_flattened


[{'c_npmi': -0.018172467372008744,
  'c_v': 0.45579005185159394,
  'u_mass': -0.05990959260352226,
  'c_uci': -1.3314675636930327,
  'topic_diversity': 0.675,
  'inverted_rbo': 0.9513249844553008,
  'pairwise_jaccard_similarity': 0.034047684703542594,
  'model_name': 'sb_model_name_or_path_all-roberta-large-v1_ctm_num_epochs_50_ctm_n_components_20_ctm_hidden_sizes_(200, 200, 200)_cvect_ngram_range_[1, 1]_cvect_max_features_1500',
  'hyperparameters.sbert_params.model_name_or_path': 'all-roberta-large-v1',
  'hyperparameters.countvect_params.ngram_range': [1, 1],
  'hyperparameters.countvect_params.max_features': 1500,
  'hyperparameters.ctm_params.dropout': 0.2,
  'hyperparameters.ctm_params.lr': 0.002,
  'hyperparameters.ctm_params.momentum': 0.99,
  'hyperparameters.ctm_params.solver': 'adam',
  'hyperparameters.ctm_params.num_epochs': 50,
  'hyperparameters.ctm_params.n_components': 20,
  'hyperparameters.ctm_params.hidden_sizes': [200, 200, 200],
  'hyperparameters.ctm_params.bow_s

In [7]:
# create a dataframe from the list of flattened dictionaries
log_history_df = pd.DataFrame(log_history_flattened)
log_history_df

Unnamed: 0,c_npmi,c_v,u_mass,c_uci,topic_diversity,inverted_rbo,pairwise_jaccard_similarity,model_name,hyperparameters.sbert_params.model_name_or_path,hyperparameters.countvect_params.ngram_range,hyperparameters.countvect_params.max_features,hyperparameters.ctm_params.dropout,hyperparameters.ctm_params.lr,hyperparameters.ctm_params.momentum,hyperparameters.ctm_params.solver,hyperparameters.ctm_params.num_epochs,hyperparameters.ctm_params.n_components,hyperparameters.ctm_params.hidden_sizes,hyperparameters.ctm_params.bow_size,hyperparameters.ctm_params.contextual_size
0,-0.018172,0.45579,-0.05991,-1.331468,0.675,0.951325,0.034048,sb_model_name_or_path_all-roberta-large-v1_ctm...,all-roberta-large-v1,"[1, 1]",1500,0.2,0.002,0.99,adam,50,20,"[200, 200, 200]",1500,1024
1,-0.031282,0.427796,-0.026496,-1.584586,0.86,0.970501,0.020393,sb_model_name_or_path_all-roberta-large-v1_ctm...,all-roberta-large-v1,"[1, 2]",2000,0.2,0.002,0.99,adam,20,10,"[200, 200]",1227,1024
2,-0.016341,0.443902,-0.019198,-1.180072,0.82,0.965105,0.026472,sb_model_name_or_path_all-mpnet-base-v2_ctm_nu...,all-mpnet-base-v2,"[1, 2]",2500,0.2,0.002,0.99,adam,100,10,"[200, 200]",1227,768
3,-0.022405,0.441994,-0.061465,-1.471422,0.65,0.943072,0.040139,sb_model_name_or_path_all-mpnet-base-v2_ctm_nu...,all-mpnet-base-v2,"[1, 1]",1500,0.2,0.002,0.99,adam,20,20,"[200, 200, 200]",1500,768
4,-0.043527,0.424232,-0.019026,-1.71528,0.78,0.948067,0.033509,sb_model_name_or_path_all-roberta-large-v1_ctm...,all-roberta-large-v1,"[1, 2]",2000,0.2,0.002,0.99,adam,20,10,"[200, 200, 200]",1207,1024
5,0.003896,0.459263,-0.051796,-0.829793,0.66,0.941869,0.041938,sb_model_name_or_path_all-roberta-large-v1_ctm...,all-roberta-large-v1,"[1, 1]",1500,0.2,0.002,0.99,adam,50,20,"[100, 100]",1207,1024
6,-0.001347,0.452369,-0.059107,-0.977359,0.695,0.95399,0.030547,sb_model_name_or_path_all-mpnet-base-v2_ctm_nu...,all-mpnet-base-v2,"[1, 1]",1500,0.2,0.002,0.99,adam,100,20,"[200, 200, 200]",1207,768
7,-0.019223,0.445077,-0.013658,-1.331428,0.78,0.956287,0.036043,sb_model_name_or_path_all-mpnet-base-v2_ctm_nu...,all-mpnet-base-v2,"[1, 2]",2000,0.2,0.002,0.99,adam,100,10,"[100, 100, 100]",1198,768
8,-0.001215,0.450839,-0.018812,-0.748345,0.84,0.966511,0.02232,sb_model_name_or_path_all-roberta-large-v1_ctm...,all-roberta-large-v1,"[1, 2]",1500,0.2,0.002,0.99,adam,100,10,"[100, 100, 100]",974,1024
9,0.006913,0.445368,-0.004576,-0.571444,0.87,0.972233,0.015724,sb_model_name_or_path_all-roberta-large-v1_ctm...,all-roberta-large-v1,"[1, 1]",2500,0.2,0.002,0.99,adam,50,10,"[100, 100]",974,1024


In [8]:
# sort by c_npmi
aaa = log_history_df.sort_values(by='c_npmi', ascending=False)
aaa

Unnamed: 0,c_npmi,c_v,u_mass,c_uci,topic_diversity,inverted_rbo,pairwise_jaccard_similarity,model_name,hyperparameters.sbert_params.model_name_or_path,hyperparameters.countvect_params.ngram_range,hyperparameters.countvect_params.max_features,hyperparameters.ctm_params.dropout,hyperparameters.ctm_params.lr,hyperparameters.ctm_params.momentum,hyperparameters.ctm_params.solver,hyperparameters.ctm_params.num_epochs,hyperparameters.ctm_params.n_components,hyperparameters.ctm_params.hidden_sizes,hyperparameters.ctm_params.bow_size,hyperparameters.ctm_params.contextual_size
26,0.014847,0.456032,-0.050658,-0.470859,0.67,0.941805,0.035003,sb_model_name_or_path_all-mpnet-base-v2_ctm_nu...,all-mpnet-base-v2,"[1, 1]",2000,0.2,0.002,0.99,adam,50,20,"[200, 200, 200]",967,768
14,0.011211,0.461966,-0.03668,-0.628559,0.73,0.949519,0.027837,sb_model_name_or_path_all-mpnet-base-v2_ctm_nu...,all-mpnet-base-v2,"[1, 2]",1500,0.2,0.002,0.99,adam,50,20,"[100, 100]",968,768
40,0.008884,0.459961,-0.046304,-0.671891,0.73,0.954127,0.027605,sb_model_name_or_path_all-roberta-large-v1_ctm...,all-roberta-large-v1,"[1, 2]",2000,0.2,0.002,0.99,adam,50,20,"[200, 200]",967,1024
9,0.006913,0.445368,-0.004576,-0.571444,0.87,0.972233,0.015724,sb_model_name_or_path_all-roberta-large-v1_ctm...,all-roberta-large-v1,"[1, 1]",2500,0.2,0.002,0.99,adam,50,10,"[100, 100]",974,1024
41,0.00658,0.446121,-0.01438,-0.664225,0.87,0.973393,0.017047,sb_model_name_or_path_all-mpnet-base-v2_ctm_nu...,all-mpnet-base-v2,"[1, 2]",2500,0.2,0.002,0.99,adam,50,10,"[200, 200]",967,768
18,0.005789,0.442994,-0.028211,-0.653505,0.695,0.951834,0.03114,sb_model_name_or_path_all-roberta-large-v1_ctm...,all-roberta-large-v1,"[1, 1]",2500,0.2,0.002,0.99,adam,50,20,"[200, 200, 200]",968,1024
38,0.004025,0.451972,-0.052269,-0.758101,0.695,0.950979,0.03025,sb_model_name_or_path_all-mpnet-base-v2_ctm_nu...,all-mpnet-base-v2,"[1, 1]",2500,0.2,0.002,0.99,adam,100,20,"[100, 100]",967,768
5,0.003896,0.459263,-0.051796,-0.829793,0.66,0.941869,0.041938,sb_model_name_or_path_all-roberta-large-v1_ctm...,all-roberta-large-v1,"[1, 1]",1500,0.2,0.002,0.99,adam,50,20,"[100, 100]",1207,1024
49,0.003363,0.464137,-0.046184,-0.926783,0.76,0.965914,0.020351,sb_model_name_or_path_all-roberta-large-v1_ctm...,all-roberta-large-v1,"[1, 2]",2000,0.2,0.002,0.99,adam,100,20,"[200, 200]",967,1024
17,0.003188,0.455758,-0.050997,-0.779757,0.695,0.947664,0.031798,sb_model_name_or_path_all-mpnet-base-v2_ctm_nu...,all-mpnet-base-v2,"[1, 1]",2500,0.2,0.002,0.99,adam,50,20,"[100, 100, 100]",968,768


In [12]:
aaa[aaa['hyperparameters.sbert_params.model_name_or_path'] == 'all-MiniLM-L12-v1']

Unnamed: 0,c_npmi,c_v,u_mass,c_uci,topic_diversity,inverted_rbo,pairwise_jaccard_similarity,model_name,hyperparameters.sbert_params.model_name_or_path,hyperparameters.vocab_tokenizer_params.ngram_range,...,hyperparameters.umap_params.low_memory,hyperparameters.umap_params.random_state,hyperparameters.hdbscan_params.metric,hyperparameters.hdbscan_params.prediction_data,hyperparameters.hdbscan_params.min_samples,hyperparameters.hdbscan_params.min_cluster_size,hyperparameters.bertopic_params.language,hyperparameters.bertopic_params.calculate_probabilities,hyperparameters.bertopic_params.top_n_words,hyperparameters.bertopic_params.nr_topics
4,0.137412,0.713508,-0.226645,1.23854,0.718421,0.961122,0.035787,sb_model_name_or_path_all-MiniLM-L12-v1_hs_min...,all-MiniLM-L12-v1,"[1, 2]",...,False,,euclidean,True,10,90,english,True,10,41
1,0.111151,0.744181,-0.152728,0.891481,0.605085,0.945372,0.060758,sb_model_name_or_path_all-MiniLM-L12-v1_hs_min...,all-MiniLM-L12-v1,"[1, 2]",...,False,,euclidean,True,40,180,english,True,20,61
10,0.110444,0.719416,-0.227431,0.941981,0.680769,0.977346,0.024298,sb_model_name_or_path_all-MiniLM-L12-v1_hs_min...,all-MiniLM-L12-v1,"[1, 2]",...,False,,euclidean,True,30,120,english,True,20,81
9,0.100998,0.761304,-0.16013,0.824367,0.674359,0.971981,0.031359,sb_model_name_or_path_all-MiniLM-L12-v1_hs_min...,all-MiniLM-L12-v1,"[1, 2]",...,False,,euclidean,True,50,180,english,True,30,81
23,0.100161,0.719013,-0.165695,0.755407,0.636735,0.953011,0.04874,sb_model_name_or_path_all-MiniLM-L12-v1_hs_min...,all-MiniLM-L12-v1,"[1, 2]",...,False,,euclidean,True,50,120,english,True,20,51
7,0.099775,0.740856,-0.177363,0.68602,0.575862,0.913845,0.091492,sb_model_name_or_path_all-MiniLM-L12-v1_hs_min...,all-MiniLM-L12-v1,"[1, 2]",...,False,,euclidean,True,20,200,english,True,20,31
22,0.092475,0.629878,-0.465147,0.468985,0.84,0.984671,0.013068,sb_model_name_or_path_all-MiniLM-L12-v1_hs_min...,all-MiniLM-L12-v1,"[1, 2]",...,False,,euclidean,True,10,30,english,True,10,31
0,0.092217,0.753588,-0.150548,0.645419,0.60625,0.951025,0.053354,sb_model_name_or_path_all-MiniLM-L12-v1_hs_min...,all-MiniLM-L12-v1,"[1, 2]",...,False,,euclidean,True,30,180,english,True,30,51
5,0.091976,0.753699,-0.143123,0.690079,0.657303,0.976477,0.025502,sb_model_name_or_path_all-MiniLM-L12-v1_hs_min...,all-MiniLM-L12-v1,"[1, 2]",...,False,,euclidean,True,50,150,english,True,30,91
20,0.091074,0.745403,-0.15265,0.678093,0.668116,0.968834,0.033474,sb_model_name_or_path_all-MiniLM-L12-v1_hs_min...,all-MiniLM-L12-v1,"[1, 2]",...,False,,euclidean,True,50,150,english,True,30,71


In [7]:
# only show the columns within the search_space
search_space = run_config['search_space']

# flatten the dictionary
search_space_flattened = flatten_dict({'hyperparameters': search_space}, sep='.')
search_space_flattened

{'hyperparameters.sbert_params.model_name_or_path': ['all-MiniLM-L6-v2',
  'all-mpnet-base-v2'],
 'hyperparameters.ctm_params.n_components': [200,
  190,
  180,
  170,
  160,
  150,
  140,
  130,
  120,
  110,
  100,
  90,
  80,
  70,
  60,
  50,
  40,
  30,
  20,
  10]}

In [9]:
bbb = log_history_df[['model_name'] + metrics_names + list(search_space_flattened.keys()) + ['hyperparameters.ctm_params.contextual_size']]
bbb

Unnamed: 0,model_name,u_mass,c_v,c_uci,c_npmi,topic_diversity,inverted_rbo,pairwise_jaccard_similarity,hyperparameters.sbert_params.model_name_or_path,hyperparameters.ctm_params.n_components,hyperparameters.ctm_params.contextual_size
0,ctm_n_components_200_sb_model_name_or_path_all...,-0.210015,0.441269,-0.593927,0.003249,0.224,0.92807,0.047736,all-MiniLM-L6-v2,200,384
1,ctm_n_components_200_sb_model_name_or_path_all...,-0.177767,0.44234,-0.551263,0.005855,0.2295,0.929084,0.047279,all-mpnet-base-v2,200,768
2,ctm_n_components_190_sb_model_name_or_path_all...,-0.227364,0.438791,-0.600537,0.002815,0.238947,0.927291,0.047078,all-MiniLM-L6-v2,190,384
3,ctm_n_components_190_sb_model_name_or_path_all...,-0.246302,0.446224,-0.64808,0.002696,0.235789,0.927232,0.049586,all-mpnet-base-v2,190,768
4,ctm_n_components_180_sb_model_name_or_path_all...,-0.190994,0.442974,-0.687146,-0.000269,0.244444,0.927682,0.048774,all-MiniLM-L6-v2,180,384
5,ctm_n_components_180_sb_model_name_or_path_all...,-0.172035,0.4517,-0.399257,0.012611,0.25,0.932978,0.043026,all-mpnet-base-v2,180,768
6,ctm_n_components_170_sb_model_name_or_path_all...,-0.277832,0.435596,-0.593923,0.003138,0.257059,0.931344,0.044774,all-MiniLM-L6-v2,170,384
7,ctm_n_components_170_sb_model_name_or_path_all...,-0.199715,0.445991,-0.48415,0.007661,0.242941,0.923395,0.050846,all-mpnet-base-v2,170,768
8,ctm_n_components_160_sb_model_name_or_path_all...,-0.146649,0.444822,-0.497336,0.006657,0.27125,0.931698,0.045666,all-MiniLM-L6-v2,160,384
9,ctm_n_components_160_sb_model_name_or_path_all...,-0.13662,0.451732,-0.579712,0.004409,0.265625,0.932727,0.045437,all-mpnet-base-v2,160,768


In [10]:
bbb.sort_values(by='c_npmi', ascending=False)

Unnamed: 0,model_name,u_mass,c_v,c_uci,c_npmi,topic_diversity,inverted_rbo,pairwise_jaccard_similarity,hyperparameters.sbert_params.model_name_or_path,hyperparameters.ctm_params.n_components,hyperparameters.ctm_params.contextual_size
33,ctm_n_components_40_sb_model_name_or_path_all-...,-0.04796,0.488311,-0.366861,0.021815,0.735,0.978712,0.014988,all-mpnet-base-v2,40,768
28,ctm_n_components_60_sb_model_name_or_path_all-...,-0.044538,0.479842,-0.265353,0.021325,0.558333,0.969323,0.021913,all-MiniLM-L6-v2,60,384
30,ctm_n_components_50_sb_model_name_or_path_all-...,-0.061448,0.488163,-0.345286,0.021288,0.65,0.97774,0.017008,all-MiniLM-L6-v2,50,384
31,ctm_n_components_50_sb_model_name_or_path_all-...,-0.060634,0.486787,-0.360129,0.020206,0.62,0.975749,0.018579,all-mpnet-base-v2,50,768
32,ctm_n_components_40_sb_model_name_or_path_all-...,-0.066008,0.492589,-0.416398,0.019232,0.7725,0.984964,0.011265,all-MiniLM-L6-v2,40,384
29,ctm_n_components_60_sb_model_name_or_path_all-...,-0.066133,0.478115,-0.359311,0.017579,0.568333,0.968922,0.023519,all-mpnet-base-v2,60,768
25,ctm_n_components_80_sb_model_name_or_path_all-...,-0.079816,0.464806,-0.348854,0.016602,0.4625,0.958413,0.029147,all-mpnet-base-v2,80,768
24,ctm_n_components_80_sb_model_name_or_path_all-...,-0.062991,0.47228,-0.358971,0.01569,0.4675,0.960692,0.02665,all-MiniLM-L6-v2,80,384
22,ctm_n_components_90_sb_model_name_or_path_all-...,-0.075025,0.462988,-0.369974,0.014788,0.422222,0.949589,0.033196,all-MiniLM-L6-v2,90,384
23,ctm_n_components_90_sb_model_name_or_path_all-...,-0.089386,0.462538,-0.408161,0.013249,0.412222,0.951559,0.032249,all-mpnet-base-v2,90,768


---

Read the topic keywords and representative sentences

Need to read the corpus b4 hand

In [8]:
import pandas as pd
import numpy as np


from contextualized_topic_models.models.ctm import CombinedTM
from contextualized_topic_models.utils.data_preparation import TopicModelDataPreparation
# from contextualized_topic_models.utils.preprocessing import WhiteSpacePreprocessingStopwords

import nltk
import os

from pathlib import Path
import json
import sys
from datetime import datetime

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"          # disable huggingface warning

dataset_path = Path('../../dataset/topic_modelling/top_11_genres/01_Indie.pkl')

dataset = pd.read_pickle(dataset_path)

sys.path.append('../../sa/')
import str_cleaning_functions

# copied from lda_demo_gridsearch.ipynb
def cleaning(df, review):
    df[review] = df[review].apply(lambda x: str_cleaning_functions.remove_links(x))
    df[review] = df[review].apply(lambda x: str_cleaning_functions.remove_links2(x))
    df[review] = df[review].apply(lambda x: str_cleaning_functions.clean(x))
    df[review] = df[review].apply(lambda x: str_cleaning_functions.deEmojify(x))
    df[review] = df[review].apply(lambda x: str_cleaning_functions.remove_non_letters(x))
    df[review] = df[review].apply(lambda x: x.lower())
    df[review] = df[review].apply(lambda x: str_cleaning_functions.unify_whitespaces(x))
    df[review] = df[review].apply(lambda x: str_cleaning_functions.remove_stopword(x))
    df[review] = df[review].apply(lambda x: str_cleaning_functions.unify_whitespaces(x))

def cleaning_little(df, review):
    df[review] = df[review].apply(lambda x: str_cleaning_functions.remove_links(x))
    df[review] = df[review].apply(lambda x: str_cleaning_functions.remove_links2(x))
    df[review] = df[review].apply(lambda x: str_cleaning_functions.clean(x))
    df[review] = df[review].apply(lambda x: str_cleaning_functions.deEmojify(x))
    df[review] = df[review].apply(lambda x: str_cleaning_functions.unify_whitespaces(x))



dataset_preprocessed = dataset.copy()
cleaning(dataset_preprocessed, 'review_text')
cleaning_little(dataset, 'review_text')

X_preprocessed = dataset_preprocessed['review_text'].values
X = dataset['review_text'].values

from nltk.stem import WordNetLemmatizer
from nltk.corpus import wordnet

lemma = WordNetLemmatizer()

# from https://stackoverflow.com/questions/25534214/nltk-wordnet-lemmatizer-shouldnt-it-lemmatize-all-inflections-of-a-word

# from: https://www.cnblogs.com/jclian91/p/9898511.html
def get_wordnet_pos(tag):
    if tag.startswith('J'):
        return wordnet.ADJ
    elif tag.startswith('V'):
        return wordnet.VERB
    elif tag.startswith('N'):
        return wordnet.NOUN
    elif tag.startswith('R'):
        return wordnet.ADV
    else:
        return None     # if none -> created as noun by wordnet
    
def lemmatization(text):
   # use nltk to get PoS tag
    tagged = nltk.pos_tag(nltk.word_tokenize(text))

    # then we only need adj, adv, verb, noun
    # convert from nltk Penn Treebank tag to wordnet tag
    wn_tagged = list(map(lambda x: (x[0], get_wordnet_pos(x[1])), tagged))

    # lemmatize by the PoS
    lemmatized = list(map(lambda x: lemma.lemmatize(x[0], pos=x[1] if x[1] else wordnet.NOUN), wn_tagged))
    # lemma.lemmatize(wn_tagged[0], pos=wordnet.NOUN)

    return lemmatized

X_preprocessed = list(map(lambda x: lemmatization(x), X_preprocessed))
X_preprocessed = list(map(lambda x: ' '.join(x), X_preprocessed))

# prepare dataset for topic modeling

In [4]:
def _load_ctm_model(model_checkpoint:Path, ctm_params:dict):

    model_path = [p for p in model_checkpoint.iterdir() if p.is_dir()][-1]        # get the last dir (since there 's only one dir inside) -> get the only dir

    # get the first file in the dir
    epoch_file = [p for p in model_path.iterdir() if p.is_file()][0]
    epoch_num = int(epoch_file.stem.split('_')[-1])

    if 'hidden_sizes' in ctm_params:
        ctm_params['hidden_sizes'] = tuple(ctm_params['hidden_sizes'])

    ctm = CombinedTM(**ctm_params)

    ctm.load(model_path, epoch_num)

    return ctm

In [5]:
best_model_path = run_result['best_model_checkpoint']
best_model_hyperparameters = run_result['best_hyperparameters']

# load the best model
from contextualized_topic_models.models.ctm import CombinedTM

best_model_hyperparameters['ctm_params']['hidden_sizes'] = tuple(best_model_hyperparameters['ctm_params']['hidden_sizes'])

ctm = _load_ctm_model(Path(best_model_path), best_model_hyperparameters['ctm_params'])

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
ctm.get_topics()

defaultdict(list,
            {0: ['love',
              'great',
              'recommend',
              'everyone',
              'anyone',
              'highly',
              'recomend',
              'reccomend',
              'amazing',
              'music'],
             1: ['fix',
              'please',
              'computer',
              'work',
              'mac',
              'play',
              'love',
              'im',
              'cant',
              'help'],
             2: ['addictive',
              'awesome',
              'sort',
              'pick',
              'maybe',
              'either',
              'simple',
              'nice',
              'begin',
              'whole'],
             3: ['game',
              'terrarium',
              'content',
              'update',
              'hour',
              'release',
              'developer',
              'one',
              'new',
              'time'],
             4: ['pretty',

In [7]:
# get representative docs

for i in range(ctm.num_topics):
    print(f'Topic {i}')
    print(ctm.get_topic_lists(i, 10))

TypeError: get_top_documents_per_topic_id() missing 3 required positional arguments: 'unpreprocessed_corpus', 'document_topic_distributions', and 'topic_id'