Demo notebook for CTM (hyperparameter grid search / random search)

Combined TM

In [1]:
import pandas as pd
import numpy as np


from contextualized_topic_models.models.ctm import CombinedTM
from contextualized_topic_models.utils.data_preparation import TopicModelDataPreparation
# from contextualized_topic_models.utils.preprocessing import WhiteSpacePreprocessingStopwords

import nltk

from pathlib import Path
import json
from datetime import datetime

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Location of the pickled review DataFrame, relative to this notebook's directory.
dataset_path = Path('../../dataset/topic_modelling/top_10_games/00_Terraria.pkl')

# NOTE: unpickling can execute arbitrary code -- only load pickle files from a trusted source.
dataset = pd.read_pickle(dataset_path)

# Quick sanity check: column names, non-null counts and dtypes.
dataset.info(verbose=True)

<class 'pandas.core.frame.DataFrame'>
Index: 75499 entries, 57735 to 133233
Data columns (total 6 columns):
 #   Column        Non-Null Count  Dtype 
---  ------        --------------  ----- 
 0   index         75499 non-null  int64 
 1   app_id        75499 non-null  int64 
 2   app_name      75499 non-null  object
 3   review_text   75499 non-null  object
 4   review_score  75499 non-null  int64 
 5   review_votes  75499 non-null  int64 
dtypes: int64(4), object(2)
memory usage: 4.0+ MB


In [3]:
%load_ext autoreload

In [4]:
# data preprocessing

import sys
sys.path.append('../../sa/')

%autoreload 2
import str_cleaning_functions

# copied from lda_demo_gridsearch.ipynb
def cleaning(df, review):
    """Run the full text-cleaning pipeline on column ``review`` of ``df``.

    Every step is applied to the whole column before the next step starts and
    the column is overwritten after each step, so ``df`` is modified in place.
    Nothing is returned.
    """
    pipeline = (
        str_cleaning_functions.remove_links,
        str_cleaning_functions.remove_links2,
        str_cleaning_functions.clean,
        str_cleaning_functions.deEmojify,
        str_cleaning_functions.remove_non_letters,
        lambda text: text.lower(),
        str_cleaning_functions.unify_whitespaces,
        str_cleaning_functions.remove_stopword,
        str_cleaning_functions.unify_whitespaces,
    )
    for step in pipeline:
        df[review] = df[review].apply(step)

# def cleaning_strlist(str_list):
#     str_list = list(map(lambda x: clean(x), str_list))
#     str_list = list(map(lambda x: deEmojify(x), str_list))

#     str_list = list(map(lambda x: x.lower(), str_list))
#     str_list = list(map(lambda x: remove_num(x), str_list))
#     str_list = list(map(lambda x: unify_whitespaces(x), str_list))

#     str_list = list(map(lambda x: _deaccent(x), str_list))
#     str_list = list(map(lambda x: remove_non_alphabets(x), str_list))
#     str_list = list(map(lambda x: remove_stopword(x), str_list))
#     return str_list

# copied from bert_demo_gridsearch.ipynb
def cleaning_little(df, review):
    """Run the light cleaning pipeline on column ``review`` of ``df``.

    Only link removal, ``clean``, emoji removal and whitespace unification are
    applied (no lower-casing, letter filtering or stop-word removal).  The
    column is overwritten after each step, so ``df`` is modified in place and
    nothing is returned.
    """
    pipeline = (
        str_cleaning_functions.remove_links,
        str_cleaning_functions.remove_links2,
        str_cleaning_functions.clean,
        str_cleaning_functions.deEmojify,
        str_cleaning_functions.unify_whitespaces,
    )
    for step in pipeline:
        df[review] = df[review].apply(step)


In [5]:
# Work on a copy: `dataset_preprocessed` will receive the full cleaning pipeline,
# while `dataset` itself is only lightly cleaned, so both versions of the text stay available.

dataset_preprocessed = dataset.copy()

In [6]:
# Full cleaning on the copy; the 'review_text' column is overwritten in place.
cleaning(dataset_preprocessed, 'review_text')


# Light cleaning on the original frame; its 'review_text' column is overwritten in place as well.
cleaning_little(dataset, 'review_text')

In [7]:
# Fully cleaned review texts (lemmatized further below and used as the bag-of-words input).
X_preprocessed = dataset_preprocessed['review_text'].values
# Lightly cleaned review texts (passed to model_search as text_for_contextual).
X = dataset['review_text'].values

Apply lemmatization to the preprocessed dataset as well (for BoW)

In [8]:
# do lemmatization, but not stemming (as part of speech is important in topic modelling)
# use nltk wordnet for lemmatization

from nltk.stem import WordNetLemmatizer
from nltk.corpus import wordnet

# Single shared lemmatizer instance, used by lemmatization() below.
lemma = WordNetLemmatizer()

# from https://stackoverflow.com/questions/25534214/nltk-wordnet-lemmatizer-shouldnt-it-lemmatize-all-inflections-of-a-word

# from: https://www.cnblogs.com/jclian91/p/9898511.html
def get_wordnet_pos(tag):
    """Translate a PoS tag into the matching ``wordnet`` PoS constant.

    The decision is made on the tag's leading letter: ``J`` -> adjective,
    ``V`` -> verb, ``N`` -> noun, ``R`` -> adverb.  Any other tag yields
    ``None`` and the caller decides the fallback.
    """
    # (leading letter of the tag, attribute name on ``wordnet``); checked in order.
    prefix_to_attr = (('J', 'ADJ'), ('V', 'VERB'), ('N', 'NOUN'), ('R', 'ADV'))
    for prefix, attr in prefix_to_attr:
        if tag.startswith(prefix):
            return getattr(wordnet, attr)
    return None
    
def lemmatization(text):
    """Tokenize ``text`` with nltk and return the list of PoS-aware lemmas.

    Each token is tagged with ``nltk.pos_tag``; the tag is mapped to a WordNet
    PoS via ``get_wordnet_pos`` and the token is lemmatized with that PoS.
    Tokens whose tag has no WordNet equivalent are lemmatized as nouns.
    """
    tagged_tokens = nltk.pos_tag(nltk.word_tokenize(text))

    lemmas = []
    for token, treebank_tag in tagged_tokens:
        # fall back to NOUN when the tag is not adj / verb / noun / adv
        wn_pos = get_wordnet_pos(treebank_tag) or wordnet.NOUN
        lemmas.append(lemma.lemmatize(token, pos=wn_pos))

    return lemmas

In [9]:
# Lemmatize every cleaned review and join its lemmas back into one space-separated string.
X_preprocessed = [' '.join(lemmatization(review)) for review in X_preprocessed]

Training

In [24]:
# copy from: https://github.com/MilaNLProc/contextualized-topic-models/blob/master/contextualized_topic_models/utils/data_preparation.py#L44
# call bert_embeddings_from_list() to produce embeddings by ourself

import warnings
from sentence_transformers import SentenceTransformer
import torch
import platform


# Pick the compute device: CUDA when available, otherwise Apple's MPS backend
# when it is actually available, otherwise the CPU.  Checking availability
# (instead of inferring it from the OS name) keeps machines without an MPS
# backend from being handed an unusable 'mps' device.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = torch.device('mps')        # m-series mac machine
else:
    device = torch.device('cpu')

print(device)

def bert_embeddings_from_list(
    texts, 
    model_name_or_path, 
    batch_size=32, 
    max_seq_length=None,            # 128 is the default value in TopicModelDataPreparation() init. Pass None to use the default value of each model
    device='cpu'):
    """
    Create SBERT embeddings from a list of texts.

    A ``SentenceTransformer`` is built from ``model_name_or_path`` on ``device``.
    When ``max_seq_length`` is given it overrides the model's own limit;
    otherwise the model's limit is kept.  Either way the effective limit is
    passed to ``check_max_local_length`` together with ``texts`` before encoding.

    Returns the result of ``model.encode`` wrapped in ``np.array``.
    """
    encoder = SentenceTransformer(model_name_or_path, device=device)

    if max_seq_length is None:
        # keep the model's own limit and use it for the length check
        max_seq_length = encoder.max_seq_length
    else:
        encoder.max_seq_length = max_seq_length

    check_max_local_length(max_seq_length, texts)

    encoded = encoder.encode(texts, batch_size=batch_size, show_progress_bar=True)
    return np.array(encoded)


def check_max_local_length(max_seq_length, texts):
    """Warn when the longest document has more words than ``max_seq_length``.

    Parameters
    ----------
    max_seq_length : int
        Length limit the documents are compared against.
    texts : iterable of str
        Documents to inspect; an empty collection is accepted and never warns.

    Notes
    -----
    Document length is measured in whitespace-separated words.  The warning is
    issued as a ``UserWarning`` and the process-wide warning filters are left
    untouched (the previous implementation installed a global
    ``DeprecationWarning`` filter that had no effect on this warning).
    """
    # ``default=0`` keeps an empty collection from raising.
    max_local_length = max((len(t.split()) for t in texts), default=0)
    if max_local_length > max_seq_length:
        warnings.warn(
            f"the longest document in your collection has {max_local_length} words, the model instead "
            f"truncates to {max_seq_length} tokens.",
            UserWarning,
            stacklevel=2,
        )

mps


In [11]:
from gensim.models import CoherenceModel
from copy import deepcopy

from sklearn.model_selection import ParameterGrid, ParameterSampler

sys.path.append('../')

from eval_metrics import compute_inverted_rbo, compute_topic_diversity, compute_pairwise_jaccard_similarity, \
                        METRICS, SEARCH_BEHAVIOUR, COHERENCE_MODEL_METRICS

In [12]:
# init params

def _init_count_vectorizer_params(
        max_features=2000,
        ngram_range=(1,1)
):
    params_dict = {}
    params_dict['max_features'] = max_features
    params_dict['ngram_range'] = ngram_range

    return params_dict

def _init_sbert_params(
    model_name_or_path='all-mpnet-base-v2'
):
    params_dict = {}
    params_dict['model_name_or_path'] = model_name_or_path

    return params_dict

# params are copied from source code of CTM: https://github.com/MilaNLProc/contextualized-topic-models/blob/master/contextualized_topic_models/models/ctm.py#L131
# commented-out params are ones we do not plan to tune (not significant to our project)
def _init_ctm_params(
        # bow_size,
        # contextual_size,
        # inference_type="combined",
        n_components=10,
        # model_type="prodLDA",
        hidden_sizes=(100, 100),
        # activation="softplus",
        dropout=0.2,
        # learn_priors=True,
        # batch_size=64,
        lr=2e-3,
        momentum=0.99,
        solver="adam",
        # num_epochs=100,
        # reduce_on_plateau=False,
        # num_data_loader_workers=mp.cpu_count(),
        # label_size=0,
        # loss_weights=None
):
    params_dict = {}
    # params_dict['bow_size'] = bow_size                        # decided by the count vectorizer params (max_features)
    # params_dict['contextual_size'] = contextual_size          # decided by the sbert model
    # params_dict['inference_type'] = inference_type
    params_dict['n_components'] = n_components
    # params_dict['model_type'] = model_type
    params_dict['hidden_sizes'] = hidden_sizes
    # params_dict['activation'] = activation
    params_dict['dropout'] = dropout
    # params_dict['learn_priors'] = learn_priors
    # params_dict['batch_size'] = batch_size
    params_dict['lr'] = lr
    params_dict['momentum'] = momentum
    params_dict['solver'] = solver

    return params_dict

In [13]:
def _init_config_dict(config_path:Path, model_name:str, hyperparameters:dict, search_space_dict:dict, 
                      metrics:list[METRICS], monitor:METRICS,
                      search_behaviour:SEARCH_BEHAVIOUR, search_rs:int, search_n_iter:int):
    """Create the search config file, or load an existing one and validate it.

    Parameters
    ----------
    config_path : Path
        Location of the JSON config file.
    model_name : str
        Stored under ``config['model']``.
    hyperparameters : dict
        Must contain the groups ``'sbert_params'``, ``'countvect_params'`` and
        ``'ctm_params'``; missing entries inside a group are filled with the
        defaults of the matching ``_init_*_params`` helper.
    search_space_dict : dict
        ``{group: {param: candidate values}}``.  Parameters listed here are
        removed from the fixed hyperparameters stored in the config.
    metrics, monitor, search_behaviour, search_rs, search_n_iter
        Stored in the config (the last two only for random search).

    Returns
    -------
    dict
        The config.  In both branches the hyperparameter groups and the search
        space hold the caller's in-memory values (e.g. tuples stay tuples).

    Raises
    ------
    AssertionError
        When ``config_path`` exists and its content does not match the inputs.
    """
    param_groups = ('sbert_params', 'countvect_params', 'ctm_params')

    def _as_json(value):
        # What ``value`` looks like after a write/read cycle through JSON
        # (tuples become lists), so stored and input values compare like-for-like.
        return json.loads(json.dumps(value))

    # Fixed hyperparameters with defaults filled in for anything the caller omitted.
    full_params = {
        'sbert_params': _init_sbert_params(**hyperparameters['sbert_params']),
        'countvect_params': _init_count_vectorizer_params(**hyperparameters['countvect_params']),
        'ctm_params': _init_ctm_params(**hyperparameters['ctm_params']),
    }
    metric_names = [m.value for m in metrics]
    is_random_search = search_behaviour == SEARCH_BEHAVIOUR.RANDOM_SEARCH

    if not config_path.exists():
        config = {'model': model_name}

        # Parameters that are searched over are not part of the fixed config.
        for group in param_groups:
            searched = search_space_dict.get(group, {})
            config[group] = {k: v for k, v in full_params[group].items() if k not in searched}

        config['search_space'] = search_space_dict
        config['metrics'] = metric_names
        config['monitor'] = monitor.value

        config['search_behaviour'] = search_behaviour.value
        if is_random_search:
            config['search_rs'] = search_rs
            config['search_n_iter'] = search_n_iter

        with open(config_path, 'w') as f:
            json.dump(config, f, indent=2)

        print('Created config file at {}'.format(config_path))
        return config

    with open(config_path, 'r') as f:
        config = json.load(f)

    # check whether the input params are consistent with the config file
    assert config['model'] == model_name, 'input model_name is not consistent with the config["model"]'
    assert config['metrics'] == metric_names, 'input metrics is not consistent with config["metrics"]'
    assert config['monitor'] == monitor.value, 'input monitor is not consistent with config["monitor"]'
    assert config['search_behaviour'] == search_behaviour.value, 'input search_behaviour is not consistent with config["search_behaviour"]'
    if is_random_search:
        assert config['search_rs'] == search_rs, 'input search_rs is not consistent with config["search_rs"]'
        assert config['search_n_iter'] == search_n_iter, 'input search_n_iter is not consistent with config["search_n_iter"]'

    # check whether the hyperparameters are consistent with the config file
    for group in param_groups:
        stored = config[group]
        given = full_params[group]
        assert stored.keys() <= given.keys(), f'existing config["{group}"] contains additional hyperparameters'
        for key in stored:
            # compare in JSON form: (100, 100) was stored as [100, 100]
            assert _as_json(given[key]) == stored[key], f'existing config["{group}"] contains different hyperparameters'

    # check whether the search_space is consistent with the config file
    stored_space = config['search_space']
    assert stored_space.keys() == search_space_dict.keys(), \
        'input search_space_dict contains different parameter groups than existing config["search_space"]'
    for group, stored_group in stored_space.items():
        given_group = search_space_dict[group]
        assert stored_group.keys() == given_group.keys(), \
            f'input search_space_dict["{group}"] contains different hyperparameter keys than existing config["search_space"]["{group}"]'
        for k in given_group:
            assert _as_json(given_group[k]) == stored_group[k], \
                f'input search_space_dict["{group}"]["{k}"] contains different value than existing config["search_space"]["{group}"]["{k}"]'

    # Everything matched modulo JSON encoding, so hand back the caller's
    # in-memory values.  This keeps tuples as tuples, exactly like the config
    # returned when the file is first created.
    for group in param_groups:
        config[group] = {key: full_params[group][key] for key in config[group]}
    config['search_space'] = search_space_dict

    print('Loaded existing config file from {}'.format(config_path))
    print('Hyperparameters and search space are consistent with the input parameters')

    return config


In [14]:
def _init_result_dict(result_path:Path, monitor_type:str):
    if not result_path.exists():
        result = {}

        result['best_metric'] = -float('inf')
        result['best_model_checkpoint'] = ""
        result['best_hyperparameters'] = dict()
        result["monitor_type"] = monitor_type
        result["log_history"] = list()

    else:
        with open(result_path, 'r') as f:
            result = json.load(f)

        assert result['monitor_type'] == monitor_type

        print('Loaded existing result file from {}'.format(result_path))
    
    return result

In [15]:
def _load_ctm_model(model_path:Path, ctm_params:dict):
    """Build a ``CombinedTM`` from ``ctm_params`` and load saved state from ``model_path`` into it.

    NOTE(review): ``load`` is called with the path only -- confirm against the
    installed contextualized-topic-models version whether it expects further
    arguments, and whether ``ctm_params`` must also carry the input sizes.
    """
    restored = CombinedTM(**ctm_params)
    restored.load(model_path)
    return restored

In [16]:
def _get_topics(ctm, k=10):
    return ctm.get_topic_lists(k)

def _get_topic_word_metrix(ctm):
    return ctm.get_topic_word_distribution()

# ref: https://contextualized-topic-models.readthedocs.io/en/latest/readme.html (go to the section: Mono-Lingual Topic Modeling)
# testing_dataset = qt.transform(text_for_contextual=testing_text_for_contextual, text_for_bow=testing_text_for_bow)
# # n_sample how many times to sample the distribution (see the doc)
# ctm.get_doc_topic_distribution(testing_dataset, n_samples=20) # returns a (n_documents, n_topics) matrix with the topic distribution of each document
def _get_topic_document_metrix(ctm, dataset, n_samples=20):
    return ctm.get_doc_topic_distribution(dataset, n_samples=n_samples).T

In [26]:
from gensim import corpora
from sklearn.feature_extraction.text import CountVectorizer

def _resolve_torch_device():
    """Return the best available torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device('mps')        # m-series machine
    return torch.device('cpu')


def _split_search_point(search_point:dict):
    """Split one flat ``'<group>__<param>'`` search point into per-group dicts and a folder-name suffix.

    Returns ``(sbert_params, countvect_params, ctm_params, model_name)``.
    """
    name_prefix = {'sbert_params': 'sb_', 'countvect_params': 'cvect_', 'ctm_params': 'ctm_'}
    grouped = {group: {} for group in name_prefix}
    name_parts = []

    for flat_key, value in search_point.items():
        for group, prefix in name_prefix.items():
            if flat_key.startswith(group):
                param = flat_key.split('__')[1]
                grouped[group][param] = value
                name_parts.append(prefix + param + '_' + str(value))
                break

    return (grouped['sbert_params'], grouped['countvect_params'], grouped['ctm_params'],
            '_'.join(name_parts))


def model_search(text_for_contextual, text_for_bow, hyperparameters:dict, search_space:dict, save_folder:Path,
                 metrics:list[METRICS]=[METRICS.C_NPMI], monitor:METRICS=METRICS.C_NPMI, 
                 save_each_models=True, run_from_checkpoints=False,
                 search_behaviour=SEARCH_BEHAVIOUR.GRID_SEARCH, search_rs=42, search_n_iter=10):
    """Grid / random search over CTM hyperparameters, with checkpointing in ``save_folder``.

    For every point of the search space a CombinedTM is trained, scored with
    every metric in ``metrics`` and logged to ``save_folder/result.json``.  The
    model with the highest ``monitor`` score is tracked as the best one.

    Parameters
    ----------
    text_for_contextual : sequence of str
        Documents fed to the sentence-embedding model; also used to build the
        per-topic reference texts for the coherence metrics.
    text_for_bow : sequence of str
        Preprocessed documents for the bag-of-words side.  The sequence is
        never modified; each search point filters its own copy.
    hyperparameters : dict
        Fixed values per group (``sbert_params``, ``countvect_params``, ``ctm_params``).
    search_space : dict
        ``{group: {param: candidate values}}``.
    save_folder : Path
        Holds ``config.json``, ``result.json``, cached embeddings and one
        sub-folder per evaluated search point.
    metrics, monitor
        Metrics to compute and the one (must be in ``metrics``) used for ranking.
    save_each_models : bool
        Save every trained model into its sub-folder.
    run_from_checkpoints : bool
        Resume in an existing ``save_folder``; search points whose sub-folder
        already exists are skipped.
    search_behaviour, search_rs, search_n_iter
        Grid vs. random search; seed and number of draws for random search.

    Returns
    -------
    tuple
        ``(best_model, best_model_path, best_hyperparameters)``.  ``best_model``
        is ``None`` when no model was trained in this call and the stored best
        model could not be reloaded.

    Raises
    ------
    Exception
        On inconsistent arguments or unexpected folder state (same messages as before).
    """
    top_k = 10          # number of top words per topic used by every metric
    n_samples = 20      # sampling rounds for the document-topic distribution

    config_json_path = save_folder.joinpath('config.json')
    result_json_path = save_folder.joinpath('result.json')

    if monitor not in metrics:
        raise Exception('monitor is not in metrics. Please modify the metrics passed in.')

    if run_from_checkpoints:
        if not save_folder.exists():
            print('Save folder:' + str(save_folder.resolve()) + ' does not exist. Function terminates.')
            raise Exception('No checkpoints found. Function terminates.')
        
        # check for existing configs
        if not config_json_path.exists():
            raise Exception('No config.json found. Function terminates.')
        
        # check for existing results
        if not result_json_path.exists():
            print('no result.json is found. Assuming no existing checkpoints.')
    else:
        if save_folder.exists():
            raise Exception('Checkpoints found. Please delete the checkpoints or set run_from_checkpoints=True. Function terminates.')

    save_folder.mkdir(parents=True, exist_ok=True)

    config = _init_config_dict(config_json_path, 'ctm', hyperparameters, search_space,
                               metrics, monitor, search_behaviour, search_rs, search_n_iter)
    result = _init_result_dict(result_json_path, monitor.value)

    print('Search folder: {}'.format(save_folder))

    # init -- carry over the best run recorded so far (if any)
    best_model_path = result['best_model_checkpoint']
    best_metric_score = result['best_metric']
    # Must be bound up front: a resumed search may never beat the stored best.
    best_hyperparameters = result['best_hyperparameters']
    best_model = None
    if best_model_path != "":
        stored_ctm_params = dict(config['ctm_params'])
        stored_ctm_params.update(best_hyperparameters.get('ctm_params', {}))
        # NOTE(review): the model constructor may additionally need the input
        # sizes, which are not persisted in result.json -- confirm.  A failed
        # reload must not prevent the search from resuming.
        try:
            best_model = _load_ctm_model(Path(best_model_path), stored_ctm_params)
        except Exception as err:
            print(f'Could not reload the stored best model from {best_model_path}: {err!r}. Continuing without it.')

    print(f'Best model checkpoint: {best_model_path}')
    print(f'Best metric score: {best_metric_score}')
    print(f'Best model: {best_model}')

    # search
    # like bertopic, we create a temp dict for initiating the search space
    # then we apply sklearn parameter sampler / parameter grid to get the search space
    temp_search_space = {
        group + '__' + param: candidates
        for group, params in search_space.items()
        for param, candidates in params.items()
    }

    if search_behaviour == SEARCH_BEHAVIOUR.RANDOM_SEARCH:
        search_iterator = ParameterSampler(temp_search_space, search_n_iter, random_state=search_rs)
    elif search_behaviour == SEARCH_BEHAVIOUR.GRID_SEARCH:
        search_iterator = ParameterGrid(temp_search_space)
    else:
        raise Exception('Unknown search behaviour: {}'.format(search_behaviour))

    # the device does not depend on the search point
    device = _resolve_torch_device()

    print('\n')

    for search_space_dict in search_iterator:

        # unwrap the search space dict
        _sbert_params, _countvect_params, _ctm_params, model_name = _split_search_point(search_space_dict)

        model_path = save_folder.joinpath(config['model'] + model_name)

        # check whether the model exists
        if model_path.exists():
            print('Skipping current search space: {}'.format(search_space_dict))
            continue
    
        ##########
        # Training starts
        ##########

        print('Current search space: {}'.format(search_space_dict))

        sbert_params = deepcopy(config['sbert_params'])     # deepcopy just for safety (not messing up with the original config)
        countvect_params = deepcopy(config['countvect_params'])
        ctm_params = deepcopy(config['ctm_params'])

        sbert_params.update(_sbert_params)
        countvect_params.update(_countvect_params)
        ctm_params.update(_ctm_params)

        # create bow
        # Always fit on the caller's original documents and keep the filtered
        # copy local, so one search point's vocabulary cannot leak into the next.
        vectorizer = CountVectorizer(**countvect_params)
        vectorizer.fit(text_for_bow)
        temp_vocabulary = set(vectorizer.get_feature_names_out())

        bow_docs = [' '.join(w for w in doc.split() if w in temp_vocabulary)
                    for doc in text_for_bow]

        tp = TopicModelDataPreparation()

        # check existing embeddings
        # reuse them if found
        # (path separators in a model id would otherwise point into a sub-folder)
        model_tag = str(sbert_params["model_name_or_path"]).replace('/', '_').replace('\\', '_')
        embeddings_path = save_folder.joinpath(f'embeddings_{model_tag}.pkl')
        if embeddings_path.exists():
            with open(embeddings_path, 'rb') as f:
                embeddings = np.load(f)

            print(f'Found existing sbert embeddings at {embeddings_path}. Reusing them.')
        else:
            embeddings = bert_embeddings_from_list(text_for_contextual, **sbert_params, device=device)

            with open(embeddings_path, 'wb') as f:
                np.save(f, embeddings)
         
        training_dataset = tp.fit(text_for_contextual=text_for_contextual, text_for_bow=bow_docs, custom_embeddings=embeddings)

        # ctm

        ctm_params['bow_size'] = len(tp.vocab)
        ctm_params['contextual_size'] = embeddings.shape[1]

        ctm = CombinedTM(**ctm_params)
        ctm.device = device
        ctm.fit(training_dataset, verbose=True)

        ##########
        # Training ends
        ##########

        ##########
        # Evaluation starts
        ##########

        # init data for gensim coherence model
        topic_words = _get_topics(ctm, k=top_k)
        topics = ctm.get_predicted_topics(training_dataset, n_samples=n_samples)

        # Use the function argument rather than a notebook-level global.
        # NOTE(review): these are the lightly cleaned texts while the topic
        # words come from the preprocessed vocabulary -- confirm this is the
        # intended reference corpus for the coherence metrics.
        documents = pd.DataFrame({"Document": text_for_contextual,
                                "ID": range(len(text_for_contextual)),
                                "Topic": topics})
        
        docs_per_topic = documents.groupby(['Topic'], as_index=False).agg({'Document': ' '.join})
        texts = [doc.split() for doc in docs_per_topic.Document.values]
        
        dictionary = corpora.Dictionary(texts)
        corpus = [dictionary.doc2bow(text) for text in texts]

        # init octis format result for convenience
        result_octis = {}
        result_octis['topics'] = topic_words
        result_octis['topic-word-matrix'] = _get_topic_word_metrix(ctm)
        result_octis['topic-document-matrix'] = _get_topic_document_metrix(ctm, training_dataset, n_samples=n_samples)

        print('Compute evaluation metrics')

        metrics_score = dict()

        for metric in metrics:
            if metric in COHERENCE_MODEL_METRICS:
                coherencemodel = CoherenceModel(topics=topic_words, texts=texts, corpus=corpus, dictionary=dictionary, topn=top_k, coherence=metric.value)
                score = coherencemodel.get_coherence()
            elif metric == METRICS.TOPIC_DIVERSITY:
                score = compute_topic_diversity(result_octis, topk=top_k)
            elif metric == METRICS.INVERTED_RBO:
                score = compute_inverted_rbo(result_octis, topk=top_k)
            elif metric == METRICS.PAIRWISE_JACCARD_SIMILARITY:
                score = compute_pairwise_jaccard_similarity(result_octis, topk=top_k)
            else:
                raise Exception('Unknown metric: {}'.format(metric.value))

            metrics_score[metric.value] = score

            print(f'Evaluation metric ({metric.value}): {score}')

        monitor_score = metrics_score[monitor.value]

        ##########
        # Evaluation ends
        ##########

        ##########
        # Save models
        ##########

        # The folder doubles as the "already evaluated" marker checked above,
        # so it is created even when the model itself is not saved.
        model_path.mkdir(parents=True, exist_ok=True)
        
        if save_each_models:
            ctm.save(models_dir=model_path)

        ##########
        # Save models ends
        ##########

        ###########
        # Update result dict and json file
        ###########
            
        model_hyperparameters = {
            'sbert_params': _sbert_params,
            'countvect_params': _countvect_params,
            'ctm_params': _ctm_params
        }

        if monitor_score > best_metric_score:
            best_metric_score = monitor_score
            best_model_path = model_path
            best_model = ctm
            best_hyperparameters = model_hyperparameters

        model_log_history = dict()
        model_log_history.update(metrics_score)
        model_log_history['model_name'] = model_name
        model_log_history['hyperparameters']  = model_hyperparameters

        result['best_metric'] = best_metric_score
        result['best_model_checkpoint'] = str(best_model_path)
        result['best_hyperparameters'] = best_hyperparameters
        result['log_history'].append(model_log_history)

        # save result after every search point so an interrupted run can resume
        with open(result_json_path, 'w') as f:
            json.dump(result, f, indent=2)
        
        print('Saved result.json at:', result_json_path)
        print('\n\n')
    
    print('Search ends')
    return best_model, best_model_path, best_hyperparameters


In [27]:
# grid search / random search

# hyperparameters
# Fixed values; any parameter that also appears in `search_space_dict` below is
# taken from the search space instead.
sbert_params = _init_sbert_params(model_name_or_path='all-mpnet-base-v2')
countvect_params = _init_count_vectorizer_params(max_features=2000, ngram_range=(1,1))
ctm_params = _init_ctm_params(n_components=10, hidden_sizes=(100, 100), dropout=0.2, lr=2e-3, momentum=0.99, solver="adam")

# Candidate values per parameter group (2 x 2 x 2 = 8 combinations for grid search).
search_space_dict = {
    'sbert_params': {
        'model_name_or_path': ['all-mpnet-base-v2', 'all-roberta-large-v1']
    },
    'countvect_params': {
        'ngram_range': [(1, 1), (1,2)]
    },
    'ctm_params':{
        'n_components': [10, 20],
    }
}

search_behaviour = SEARCH_BEHAVIOUR.GRID_SEARCH
# search_behaviour = SEARCH_BEHAVIOUR.RANDOM_SEARCH

# The output folder name embeds the start time, so every fresh run gets its own folder.
training_datetime = datetime.now()
# To resume an earlier run, set the timestamp of that run instead (and pass
# run_from_checkpoints=True below), e.g.:
# training_datetime = datetime(2024, 1, 22, 13, 49, 39)
training_folder = Path(f'ctm_{search_behaviour.value}_{training_datetime.strftime("%Y%m%d_%H%M%S")}')

# X: lightly cleaned texts for the contextual embeddings;
# X_preprocessed: cleaned + lemmatized texts for the bag-of-words.
best_model, best_model_path, best_hyperparameters = model_search(
    X,
    X_preprocessed,
    hyperparameters={
        'sbert_params': sbert_params,
        'countvect_params': countvect_params,
        'ctm_params': ctm_params
    },
    search_space=search_space_dict,
    save_folder=training_folder,
    metrics=[METRICS.C_NPMI, METRICS.C_V, METRICS.UMASS, METRICS.C_UCI, METRICS.TOPIC_DIVERSITY, METRICS.INVERTED_RBO, METRICS.PAIRWISE_JACCARD_SIMILARITY],
    monitor=METRICS.C_NPMI,
    save_each_models=True,
    run_from_checkpoints=False,
    search_behaviour=search_behaviour
)

Created config file at ctm_grid_search_20240122_134939/config.json
Search folder: ctm_grid_search_20240122_134939
Best model checkpoint: 
Best metric score: -inf
Best model: None


Current search space: {'countvect_params__ngram_range': (1, 1), 'ctm_params__n_components': 10, 'sbert_params__model_name_or_path': 'all-mpnet-base-v2'}


Batches: 100%|██████████| 2360/2360 [03:51<00:00, 10.18it/s]


Settings: 
                   N Components: 10
                   Topic Prior Mean: 0.0
                   Topic Prior Variance: 0.9
                   Model Type: prodLDA
                   Hidden Sizes: (100, 100)
                   Activation: softplus
                   Dropout: 0.2
                   Learn Priors: True
                   Learning Rate: 0.002
                   Momentum: 0.99
                   Reduce On Plateau: False
                   Save Dir: None


0it [00:00, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid 

KeyboardInterrupt: 

In [None]:
# load the best model from the checkpoints

# Rebuild the folder name of the run to inspect from its start timestamp.
search_behaviour = SEARCH_BEHAVIOUR.GRID_SEARCH
training_datetime = datetime(2024, 1, 22, 13, 49, 39)
training_folder = Path(f'ctm_{search_behaviour.value}_{training_datetime.strftime("%Y%m%d_%H%M%S")}')

training_result_json_path = training_folder.joinpath('result.json')
with open(training_result_json_path, 'r') as f:
    training_result = json.load(f)

best_model_path = training_result['best_model_checkpoint']
# NOTE(review): 'best_hyperparameters' only holds the values that were searched over;
# fixed hyperparameters and the model input sizes are not stored there -- confirm that
# this is enough to rebuild the model.
ctm_hyperparameters = training_result['best_hyperparameters']['ctm_params']

# NOTE(review): 'best_model_checkpoint' is the per-model folder passed to ctm.save();
# verify that this is the path the CTM loader expects.
best_model = _load_ctm_model(Path(best_model_path), ctm_hyperparameters)
best_model.get_topic_lists(k=10)