Skip to content

Commit

Permalink
Add base_en example, set default num_iterations to 50, and remove tests/data
Browse files · Browse the repository at this point in the history
  • Loading branch information
fonhorst committed Jun 8, 2024
1 parent ba72491 commit 5e5cc51
Show file tree
Hide file tree
Showing 16 changed files with 1,046 additions and 35,736 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def run_algorithm(
exp_id: Union[int, str],
topic_count: int,
num_individuals: int = 11,
num_iterations: int = 400,
num_iterations: int = 50,
num_fitness_evaluations: int = None,
mutation_type: str = "psm",
crossover_type: str = "blend_crossover",
Expand Down
4 changes: 2 additions & 2 deletions autotm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def fit(self, dataset: Union[pd.DataFrame, pd.Series], processed_dataset_path: O
exp_id=self.exp_id or "0",
topic_count=self.topic_count,
log_file=self.log_file_path,
**self.alg_params
**(self.alg_params or dict())
)
else:
# TODO: refactor this function
Expand All @@ -172,7 +172,7 @@ def fit(self, dataset: Union[pd.DataFrame, pd.Series], processed_dataset_path: O
topic_count=self.topic_count,
log_file=self.log_file_path,
exp_id=self.exp_id or "0",
**self.alg_params
**(self.alg_params or dict())
)

self._model = best_topic_model.model
Expand Down
1,001 changes: 1,001 additions & 0 deletions data/sample_corpora/imdb_1000.csv

Large diffs are not rendered by default.

37 changes: 29 additions & 8 deletions examples/examples_autotm_fit_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
import uuid
from typing import Dict, Any
from typing import Dict, Any, Optional

import pandas as pd
from sklearn.model_selection import train_test_split
Expand All @@ -22,6 +22,17 @@
"num_individuals": 2,
"use_pipeline": True
},
"base_en": {
"alg_name": "ga",
"dataset": {
"lang": "en",
"dataset_path": "data/sample_corpora/imdb_1000.csv",
"dataset_name": "imdb_1000"
},
"num_iterations": 2,
"num_individuals": 2,
"use_pipeline": True
},
"static_chromosome": {
"alg_name": "ga",
"num_iterations": 2,
Expand Down Expand Up @@ -49,10 +60,15 @@
}


def run(alg_name: str, alg_params: Dict[str, Any]):
path_to_dataset = "data/sample_corpora/sample_dataset_lenta.csv"
def run(alg_name: str, alg_params: Dict[str, Any], dataset: Optional[Dict[str, Any]] = None):
if not dataset:
dataset = {
"lang": "ru",
"dataset_path": "data/sample_corpora/sample_dataset_lenta.csv",
"dataset_name": "lenta_ru"
}

df = pd.read_csv(path_to_dataset)
df = pd.read_csv(dataset['dataset_path'])
train_df, test_df = train_test_split(df, test_size=0.1)

working_dir_path = f"./autotm_workdir_{uuid.uuid4()}"
Expand All @@ -61,13 +77,13 @@ def run(alg_name: str, alg_params: Dict[str, Any]):
autotm = AutoTM(
topic_count=20,
preprocessing_params={
"lang": "ru",
"lang": dataset['lang'],
"min_tokens_count": 3
},
alg_name=alg_name,
alg_params=alg_params,
working_dir_path=working_dir_path,
exp_dataset_name="lenta_ru"
exp_dataset_name=dataset["dataset_name"]
)
mixtures = autotm.fit_predict(train_df)

Expand All @@ -93,8 +109,13 @@ def main(conf_name: str = "base"):
alg_name = conf['alg_name']
del conf['alg_name']

run(alg_name=alg_name, alg_params=conf)
dataset = None
if 'dataset' in conf:
dataset = conf['dataset']
del conf['dataset']

run(alg_name=alg_name, alg_params=conf, dataset=dataset)


if __name__ == "__main__":
main()
main(conf_name="base_en")
Loading

0 comments on commit 5e5cc51

Please sign in to comment.