In [None]:
# Re-import edited project modules (archaeo_super_prompt.*) before each cell runs,
# so changes to the package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from archaeo_super_prompt.dataset import MagohDataset
import archaeo_super_prompt.modeling.predict as modeling
import mlflow
import pandas as pd
from archaeo_super_prompt.visualization import mlflow_logging as mmlflow
from archaeo_super_prompt.config.env import getenv_or_throw
from sklearn.pipeline import Pipeline

In [None]:
EXP_NAME = "Comune dspy training - working"
# Tracking server address comes from the environment (getenv_or_throw raises if unset).
# Single quotes inside the replacement fields: reusing the outer double quote there
# is only valid syntax from Python 3.12 (PEP 701) and breaks older kernels.
MLFLOW_TRACKING_URI = f"http://{getenv_or_throw('MLFLOW_HOST')}:{getenv_or_throw('MLFLOW_PORT')}"
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
mlflow.set_experiment(EXP_NAME)
# Enable MLflow's DSPy autologging, including compile runs, evaluations and
# traces produced during compilation.
mlflow.dspy.autolog(log_compiles=True, log_evals=True, log_traces_from_compile=True)
# Show every column when a DataFrame is displayed.
pd.set_option('display.max_columns', None)

In [None]:
# Build the project's extraction pipeline; the training cell fits it and pulls
# the "extract-comune" step out of `named_steps`.
# NOTE(review): annotated as an sklearn Pipeline because it exposes named_steps — confirm.
pipe: Pipeline = modeling.get_pipeline()

In [None]:
# Sampling configuration: a fixed seed makes Restart & Run All train on the
# same documents, so MLflow runs are comparable across re-executions.
N_TRAIN_SAMPLES = 10
SAMPLE_SEED = 42

# Hand-curated document IDs, grouped by source type / quality.
_digitally_born_documents = {
    # very good
    33799, 34439, 38005, 36837, 36937, 37614, 37026, 37971, 36846, 36304, 34423, 36052,
    37043, 36554, 989, 37007, 30897, 36351, 36308, 38013, 36011, 33828, 1221,
    38039, 35429, 37065, 37116, 34452, 33441, 33062, 34939, 35918, 33689, 34508, 31035,
    38220, 38092, 36979, 36854, 36207, 34915, 35688, 36359,
    # not that good
    31164, 32600, 33760, 32714, 31208, 30712,
}
ok_scanned_pdfs = {
    32666, 31298, 33548, 35189, 35399, 30925, 37040, 37379, 33589, 34769,
    33858, 34329, 5193, 37706, 30647, 37702, 33540, 36042, 33357, 34959,
    30646, 33547, 32581, 30878, 37302, 33560, 35881, 31031, 37381, 242, 34869,
    33841, 36465, 33499, 36095, 36068, 33594, 33904, 33644, 33553, 35052,
    33630, 34426, 31090, 30716, 31059, 35849, 33813, 34666, 36119, 830,
    36187, 31977, 34787, 33749, 35447, 33555, 33846, 34093, 33508,  # 33710 is commented out (excluded)
}
dirty_pdfs = {
    36648, 32433, 35131, 33383, 30657, 31312, 30399, 33331, 31234, 30548,
    34685, 34237, 35114, 30821, 33708, 33668, 34932, 30697, 38241, 33443,
    37305, 33535, 31815, 35203, 33576, 32053, 33761, 37910, 35983, 31314,
    37400, 36457, 33582, 31903, 32494, 33184, 36070, 31804, 30861,
}

selected_ids = _digitally_born_documents | ok_scanned_pdfs | dirty_pdfs
ds = MagohDataset(selected_ids)
# Seeded sample of the dataset's files used as training inputs.
# NOTE(review): assumes ds.files is a pandas DataFrame (random_state keyword) — confirm.
inputs = ds.files.sample(N_TRAIN_SAMPLES, random_state=SAMPLE_SEED)

In [None]:
# NOTE(review): `logging` is not used in this cell; move it to the imports cell,
# or drop it if no later cell needs it.
import logging

# Fit the pipeline and persist the comune-extraction step's models, all inside
# a single MLflow run.
with mlflow.start_run():
    results = pipe.fit(inputs, ds)
    steps_by_name = pipe.named_steps
    comune_extractor = steps_by_name["extract-comune"]
    mmlflow.save_models(comune_extractor)