In [1]:
import logging
from logging import config as log_config
from pathlib import Path

import mlflow
import mlflow.sklearn
import optuna

from helpers.cli_options import get_cli_options_hpo
from helpers.mlflow_helpers import create_experiment, mlflow_callback
from helpers.optuna_helpers import objective, logging_callback
from helpers.pipeline import Anonymizer
from helpers.preprocessing import load_data

In [2]:
# --- Notebook configuration --------------------------------------------------
# Values a reader may want to tune are collected here as named constants.
LOG_CONFIG_FILE = r"./log.conf"                  # logging configuration file
PROCESSED_DATA_DIR = Path("../data/processed")   # folder handed to load_data
TRACKING_SERVER_URI = "http://localhost:5000"    # MLflow tracking server

# Configure logging from the file, then grab this notebook's logger.
log_config.fileConfig(LOG_CONFIG_FILE)
logger = logging.getLogger(__name__)

# `base_folder` is the name train_model reads, so keep it as the public alias.
base_folder = PROCESSED_DATA_DIR

# Point MLflow at the tracking server before any experiment is created.
mlflow.tracking.set_tracking_uri(TRACKING_SERVER_URI)


In [3]:
# Experiment handle that is passed to train_model and later used to look up
# the logged runs via `experiment.experiment_id`.
# NOTE(review): create_experiment lives in helpers.mlflow_helpers; presumably it
# registers an MLflow experiment derived from `base_name` -- confirm there.
experiment = create_experiment(base_name="Sentiment")

In [9]:
def train_model(
    sample_size: int,
    workers: int,
    experiment: mlflow.entities.experiment.Experiment,
    n_trials: int = 20,
    max_rows: int = 10,
) -> "optuna.study.Study":
    """Run an Optuna hyper-parameter search over the review classifiers.

    The training reviews are loaded from ``base_folder``, anonymized once up
    front, and then a slice of them is handed to ``objective`` for every trial.

    Args:
        sample_size: Forwarded to ``load_data`` as ``sample_size``.
        workers: Forwarded to ``objective`` as ``workers``.
        experiment: MLflow experiment, stored on the study as the
            ``"experiment"`` user attribute (presumably read back by
            ``mlflow_callback`` -- confirm in helpers.mlflow_helpers).
        n_trials: Number of Optuna trials to run. Defaults to 20, the value
            that used to be hard-coded.
        max_rows: Only the first ``max_rows`` reviews/labels are passed to the
            objective. Defaults to 10, the value that used to be hard-coded;
            pass a value >= ``sample_size`` to search on everything loaded.

    Returns:
        The Optuna study after optimization (maximizing the objective value).
    """
    logger.info("Load IMDB reviews")
    df_train, _ = load_data(folder=base_folder, sample_size=sample_size)

    # Anonymize data before pipeline, since this step is slow and constant
    # NOTE(review): every loaded row is anonymized even though only the first
    # `max_rows` are used below; kept as-is to preserve existing behaviour.
    logger.info("Preprocess reviews with spaCy. This may take a while..")
    anonymized_reviews = Anonymizer().transform(df_train.review)

    logger.info("Explore search space")
    study = optuna.create_study(direction="maximize")
    study.set_user_attr(key="experiment", value=experiment)

    # Slice once instead of on every trial; the subset does not change.
    X_search = anonymized_reviews[:max_rows]
    y_search = df_train.sentiment[:max_rows]

    study.optimize(
        lambda trial: objective(
            trial,
            X=X_search,
            y=y_search,
            workers=workers,
        ),
        n_trials=n_trials,
        callbacks=[logging_callback, mlflow_callback],
    )

    return study

In [10]:
# Run the hyper-parameter search and keep the returned Optuna study in `s`
# for inspection in the cells below.
# NOTE: 10_000 reviews are loaded, but train_model passes only the first 10
# rows to the objective here, which explains the low accuracies in the log.
s = train_model(sample_size=10_000, workers=2, experiment=experiment)

INFO:: Load IMDB reviews
INFO:: Preprocess reviews with spaCy. This may take a while..
INFO:: Explore search space
INFO:: Trial number: 0, ngram_range: 11, clf: MultinomialNB, alpha: 0.026068140154119183
INFO:: Accuracy: 0.3611111111111111
INFO:: Trial number: 1, ngram_range: 12, clf: RandomForestClassifier, max_depth: 2
INFO:: Accuracy: 0.27777777777777773
INFO:: Trial number: 2, ngram_range: 12, clf: RandomForestClassifier, max_depth: 2
INFO:: Accuracy: 0.38888888888888884
INFO:: Trial number: 3, ngram_range: 11, clf: SVC, C: 0.16847788581886447
INFO:: Accuracy: 0.3055555555555555
INFO:: Trial number: 4, ngram_range: 11, clf: MultinomialNB, alpha: 0.02365706569357498
INFO:: Accuracy: 0.3611111111111111
INFO:: Trial number: 5, ngram_range: 11, clf: SVC, C: 0.15135894230972707
INFO:: Accuracy: 0.3055555555555555
INFO:: Trial number: 6, ngram_range: 12, clf: SVC, C: 0.1433212999490459
INFO:: Accuracy: 0.3055555555555555
INFO:: Trial number: 7, ngram_range: 11, clf: SVC, C: 0.17589144970

In [14]:
# Number of the trial with the highest objective value (the study was created
# with direction="maximize").
s.best_trial.number

2

In [17]:
# mlflow.search_runs documents `experiment_ids` as a list of experiment IDs, so
# wrap the single ID explicitly instead of relying on a bare string being
# accepted.
# NOTE: despite its name, `run_id` holds the table of all runs logged under the
# experiment (one row per run), not a single run identifier.
run_id = mlflow.search_runs(experiment_ids=[experiment.experiment_id])

In [18]:
# Display the runs table returned by mlflow.search_runs (one row per run with
# its metrics, params and tags); the name `run_id` is historical.
run_id

Unnamed: 0,run_id,experiment_id,status,artifact_uri,start_time,end_time,metrics.f1,metrics.accuracy,metrics.average_precision,params.max_depth,params.clf,params.ngram_range,params.alpha,params.C,tags.mlflow.log-model.history,tags.mlflow.source.type,tags.mlflow.user,tags.mlflow.source.name
0,753a0e926ed74dc0810386f554fc10ed,54,FINISHED,../data/artifacts/54/753a0e926ed74dc0810386f55...,2020-07-02 16:34:29.903000+00:00,2020-07-02 16:34:30.147000+00:00,0.388889,0.388889,0.444444,3.0,RandomForestClassifier,12,,,"[{""run_id"": ""753a0e926ed74dc0810386f554fc10ed""...",LOCAL,apancham002,/Users/apancham002/Applications/anaconda3/envs...
1,fca40f4f11d34b7a97774e648dc9aa28,54,FINISHED,../data/artifacts/54/fca40f4f11d34b7a97774e648...,2020-07-02 16:34:29.320000+00:00,2020-07-02 16:34:29.559000+00:00,0.222222,0.277778,0.444444,2.0,RandomForestClassifier,12,,,"[{""run_id"": ""fca40f4f11d34b7a97774e648dc9aa28""...",LOCAL,apancham002,/Users/apancham002/Applications/anaconda3/envs...
2,acc5d9d9ac4c4dea89e9d4c1533192c0,54,FINISHED,../data/artifacts/54/acc5d9d9ac4c4dea89e9d4c15...,2020-07-02 16:34:28.728000+00:00,2020-07-02 16:34:28.973000+00:00,0.166667,0.277778,0.555556,,MultinomialNB,12,0.0992045609572619,,"[{""run_id"": ""acc5d9d9ac4c4dea89e9d4c1533192c0""...",LOCAL,apancham002,/Users/apancham002/Applications/anaconda3/envs...
3,91bd001e2a2e4a53a76d897b3a26c634,54,FINISHED,../data/artifacts/54/91bd001e2a2e4a53a76d897b3...,2020-07-02 16:34:28.352000+00:00,2020-07-02 16:34:28.613000+00:00,0.388889,0.388889,0.444444,4.0,RandomForestClassifier,12,,,"[{""run_id"": ""91bd001e2a2e4a53a76d897b3a26c634""...",LOCAL,apancham002,/Users/apancham002/Applications/anaconda3/envs...
4,c9c286ff1d484e6d9fe0573cf07d318c,54,FINISHED,../data/artifacts/54/c9c286ff1d484e6d9fe0573cf...,2020-07-02 16:34:27.735000+00:00,2020-07-02 16:34:27.975000+00:00,0.222222,0.388889,0.5,2.0,RandomForestClassifier,12,,,"[{""run_id"": ""c9c286ff1d484e6d9fe0573cf07d318c""...",LOCAL,apancham002,/Users/apancham002/Applications/anaconda3/envs...
5,900681ccdf0b4750bdbbd96d8279d541,54,FINISHED,../data/artifacts/54/900681ccdf0b4750bdbbd96d8...,2020-07-02 16:34:27.121000+00:00,2020-07-02 16:34:27.387000+00:00,0.222222,0.277778,0.5,3.0,RandomForestClassifier,12,,,"[{""run_id"": ""900681ccdf0b4750bdbbd96d8279d541""...",LOCAL,apancham002,/Users/apancham002/Applications/anaconda3/envs...
6,bc4b00a5e56947728aaa610f30151bee,54,FINISHED,../data/artifacts/54/bc4b00a5e56947728aaa610f3...,2020-07-02 16:34:26.480000+00:00,2020-07-02 16:34:26.751000+00:00,0.133333,0.194444,0.444444,3.0,RandomForestClassifier,12,,,"[{""run_id"": ""bc4b00a5e56947728aaa610f30151bee""...",LOCAL,apancham002,/Users/apancham002/Applications/anaconda3/envs...
7,55b0790849b6469ca04a081a068ded81,54,FINISHED,../data/artifacts/54/55b0790849b6469ca04a081a0...,2020-07-02 16:34:25.905000+00:00,2020-07-02 16:34:26.132000+00:00,0.222222,0.277778,0.472222,4.0,RandomForestClassifier,12,,,"[{""run_id"": ""55b0790849b6469ca04a081a068ded81""...",LOCAL,apancham002,/Users/apancham002/Applications/anaconda3/envs...
8,0144359cdcb94b77b6fa8349b425efd2,54,FINISHED,../data/artifacts/54/0144359cdcb94b77b6fa8349b...,2020-07-02 16:34:25.298000+00:00,2020-07-02 16:34:25.543000+00:00,0.222222,0.277778,0.444444,4.0,RandomForestClassifier,12,,,"[{""run_id"": ""0144359cdcb94b77b6fa8349b425efd2""...",LOCAL,apancham002,/Users/apancham002/Applications/anaconda3/envs...
9,cd5a0960e38d49d6bdd8dca0400b8f60,54,FINISHED,../data/artifacts/54/cd5a0960e38d49d6bdd8dca04...,2020-07-02 16:34:24.709000+00:00,2020-07-02 16:34:24.961000+00:00,0.222222,0.388889,0.444444,4.0,RandomForestClassifier,12,,,"[{""run_id"": ""cd5a0960e38d49d6bdd8dca0400b8f60""...",LOCAL,apancham002,/Users/apancham002/Applications/anaconda3/envs...
