In [None]:
# Stuff to run before this notebook

# local pc, make sure postgresql is running with the necessary data
# ssh -i ./fair_key -R 5432:localhost:5432 -L 8889:localhost:8889 -L 8887:localhost:8887 -L 8050:localhost:8050 name@host

# for conda env and vllm venv

# conda activate /raid/ggattiglia/magoh_ai/env
# source ./vllm-server/.venv/bin/activate

# LLM server: needs both the conda env and the vllm venv. Set CUDA_VISIBLE_DEVICES to the GPUs to use and replace [hf_token] with your Hugging Face token (access to the model must first be requested on Hugging Face).
# CUDA_VISIBLE_DEVICES=0,1 HF_TOKEN=[hf_token] vllm serve --port 8001 google/gemma-3-27b-it --tensor-parallel-size 2 --gpu-memory-utilization 0.6

# mlflow, needs conda
# mlflow server --host 127.0.0.1 --port 8887

# NER server, needs conda and venv
# cd prompt_enhancing/models/custom-remote-models/src/magoh_ai_sup_server
# just run-server

In [None]:
# TODO: the helper functions and monkey-patches defined in this notebook should be moved into the archaeo_super_prompt package itself.

In [None]:
# Load IPython's autoreload extension; mode 2 re-imports edited modules
# automatically, so changes to the project sources are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import os

# Move the working directory to the project's `src` folder, two levels above
# the directory the kernel was started in.
# The start directory is remembered the first time this cell runs so that
# re-running the cell is idempotent (the relative `..`/`..` hop would otherwise
# climb two more levels on every execution).
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, '..', '..', 'src'))
print("Current working directory:", os.getcwd())

In [None]:
from archaeo_super_prompt.dataset import MagohDataset, SamplingParams
import archaeo_super_prompt.modeling.train as training
import archaeo_super_prompt.modeling.predict as infering
import mlflow
import pandas as pd
from archaeo_super_prompt.visualization import mlflow_logging as mmlflow
from archaeo_super_prompt import visualization as visualizator
from archaeo_super_prompt.config.env import getenv_or_throw
from sklearn.pipeline import Pipeline
from sklearn import set_config

from pathlib import Path
from sklearn.base import BaseEstimator, TransformerMixin
from archaeo_super_prompt.utils.cache import get_cache_dir_for

class LoadScans(BaseEstimator, TransformerMixin):
    """Transformer that joins cached scan results onto the incoming frame.

    The scans are read from ``cache_csv`` on every ``transform`` call, reduced
    to one row per ``id`` and inner-joined with ``X`` on the ``id`` column.

    NOTE(review): a class with the same name is defined again further down in
    this notebook and shadows this one once that cell has run.

    Parameters
    ----------
    cache_csv : str or Path
        Location of the CSV file holding the cached scans.
    """

    def __init__(self, cache_csv: Path):
        # Keep the constructor argument untouched: sklearn.base.clone checks
        # that get_params() returns the very object passed to __init__, and
        # wrapping it in Path(...) here would make every clone() call fail.
        self.cache_csv = cache_csv

    def fit(self, X, y=None):
        """Nothing to learn; returns ``self`` for pipeline compatibility."""
        return self

    def transform(self, X):
        """Return the rows of ``X`` whose ``id`` has a cached scan, with the scan columns appended."""
        scans = pd.read_csv(Path(self.cache_csv))
        scans = scans.drop_duplicates(subset=["id"])
        return X.merge(scans, on="id", how="inner")

# Path of the cached scans CSV inside the project's cache directory, and an
# eagerly loaded copy of its content used by the cells below.
CACHE_CSV = get_cache_dir_for("interim", "miscel").joinpath("scans.csv")
SCANS_DF = pd.read_csv(CACHE_CSV)


In [None]:
# Name of the MLflow experiment that receives the runs logged by this notebook.
EXP_NAME = "Complete training test"
# The tracking server address is taken from the MLFLOW_HOST / MLFLOW_PORT
# environment variables (getenv_or_throw fails early when one is missing).
mlflow.set_tracking_uri(f"http://{getenv_or_throw('MLFLOW_HOST')}:{getenv_or_throw('MLFLOW_PORT')}")
mlflow.set_experiment(EXP_NAME)
mlflow.dspy.autolog(log_compiles=True, log_evals=True, log_traces_from_compile=True)
# Show every column of displayed frames and render sklearn estimators as diagrams.
pd.set_option('display.max_columns', None)
set_config(display="diagram")

from archaeo_super_prompt.utils.cache import get_cache_dir_for
# get_model_store_dir was previously only imported by the optional hard-reset
# cell further down, so this cell raised NameError on a fresh kernel.
from archaeo_super_prompt.utils.result import get_model_store_dir
store_dir = get_model_store_dir()
Path(store_dir).mkdir(parents=True, exist_ok=True)
print("Model store:", store_dir)

In [None]:
# HARD RESET OF CACHES / ARTIFACTS
# Run this cell only when a completely fresh MLflow rerun is needed: it deletes
# the model store and the joblib/skdag cache directories, then switches logging
# to a new, randomly named experiment.
import os, shutil, pathlib, urllib.parse, uuid
import mlflow

# wipe compiled DSPy programs used by extractors
from archaeo_super_prompt.utils.result import get_model_store_dir
shutil.rmtree(get_model_store_dir(), ignore_errors=True)

# wipe joblib / skdag caches inside project cache dirs
from archaeo_super_prompt.utils.cache import get_cache_dir_for
for scope in ["external","internal","interim","miscel","thesaurus","raw"]:
    try:
        base = get_cache_dir_for(scope)
    except Exception:
        # best effort: a scope that cannot be resolved is simply skipped
        continue
    if not base.exists():
        continue
    # Collect the directories to delete BEFORE removing anything: deleting
    # entries while rglob() is still walking the tree makes the traversal
    # depend on the pathlib implementation and can fail on vanished paths.
    stale_dirs = [
        p for p in base.rglob("*")
        if p.is_dir() and any(k in p.name for k in ("joblib","skdag","__joblib_cache__"))
    ]
    for p in stale_dirs:
        # ignore_errors also covers directories nested inside one already removed
        shutil.rmtree(p, ignore_errors=True)

# NOTE(review): fresh_dir is created but never used afterwards (tracking goes
# through the HTTP server below) — confirm whether it can be dropped.
fresh_dir = pathlib.Path.cwd() / f"mlruns_fresh_{uuid.uuid4().hex[:6]}"
fresh_dir.mkdir(parents=True, exist_ok=True)
mlflow.set_tracking_uri(f"http://{getenv_or_throw('MLFLOW_HOST')}:{getenv_or_throw('MLFLOW_PORT')}")
# A random suffix gives every reset its own experiment.
mlflow.set_experiment(f"fresh-{uuid.uuid4().hex[:6]}")
mlflow.dspy.autolog(log_compiles=True, log_evals=True, log_traces_from_compile=True)
pd.set_option('display.max_columns', None)
set_config(display="diagram")

# Re-create the (now deleted) model store directory.
store_dir = get_model_store_dir()
Path(store_dir).mkdir(parents=True, exist_ok=True)
print("Model store:", store_dir)

In [None]:
from archaeo_super_prompt.dataset import MagohDataset

# Restrict the dataset to the documents that have a cached scan.
selected_ids = {int(scan_id) for scan_id in SCANS_DF["id"].dropna().tolist()}
ds = MagohDataset(selected_ids)

# Keep only the files with a cached scan, then split by position:
# the first 10 rows are used for training, the rest for evaluation.
inputs = ds.files.merge(
    SCANS_DF[["id"]].drop_duplicates(),
    on="id",
    how="inner",
)
train_inputs = inputs.iloc[:10]
eval_inputs = inputs.iloc[10:]


In [None]:
import os, importlib

# Remove any OpenAI endpoint override from the environment: the pipeline talks
# to a local model instead.
for k in ("OPENAI_BASE_URL", "OPENAI_API_BASE"):
    os.environ.pop(k, None)

# TODO: move this setting to the environment configuration instead of
# hard-coding it here. The port matches the `vllm serve --port 8001` command
# listed in the first cell.
os.environ["VLLM_SERVER_BASE_URL"] = "http://127.0.0.1:8001/v1"
# dspy seems to fall back to OpenAI in case of errors; this is a dummy key, not a real secret.
os.environ["OPENAI_API_KEY"] = "sk-local"

# Reload the language-model provider and the field extractor after the
# environment has been changed.
# NOTE(review): presumably these modules read the variables at import time — confirm.
from archaeo_super_prompt.modeling.struct_extract import language_model as lm_provider_mod
from archaeo_super_prompt.modeling.struct_extract import field_extractor as fe
lm_provider_mod = importlib.reload(lm_provider_mod)
fe = importlib.reload(fe)

import pandas as pd
from pathlib import Path
from sklearn.base import BaseEstimator, TransformerMixin
from archaeo_super_prompt.utils.cache import get_cache_dir_for
from archaeo_super_prompt.modeling import pdf_to_text
import ast

# Same scans.csv cache path as declared earlier in the notebook; presumably
# repeated so that this cell can be run on its own.
CACHE_CSV = get_cache_dir_for("interim", "miscel") / "scans.csv"

def _as_list(v):
    if isinstance(v, list): return v
    if pd.isna(v): return []
    if isinstance(v, str):
        s=v.strip()
        if s and s[0] in "[{(":
            try:
                x=ast.literal_eval(s)
                return x if isinstance(x, list) else [x]
            except Exception:
                return [v]
        return [v]
    return [v]

def _as_str_list(v):
    """Coerce ``v`` into a list (see ``_as_list``) and stringify every element."""
    return list(map(str, _as_list(v)))

def _as_int_list(v):
    """Coerce ``v`` into a list of ints, silently dropping unconvertible items.

    Each element of ``_as_list(v)`` is converted with ``int(x)`` first and with
    ``int(float(x))`` as a second attempt (e.g. for the string ``"3.0"``).
    Elements for which both attempts fail are left out of the result.
    """
    def _coerce(item):
        # Try the conversions in order; None marks "could not convert".
        for convert in (int, lambda raw: int(float(raw))):
            try:
                return convert(item)
            except Exception:
                continue
        return None

    return [number for number in map(_coerce, _as_list(v)) if number is not None]

class LoadScans(BaseEstimator, TransformerMixin):
    """Transformer that replaces the scan step by a lookup in a cached CSV.

    ``fit`` reads ``cache_csv`` once and normalises its columns; ``transform``
    inner-joins the incoming frame with the cached rows on ``id``.

    Parameters
    ----------
    cache_csv : str or Path
        Location of the CSV file holding the cached scans.
    """

    # Columns stored as stringified Python lists in the CSV, with the helper
    # used to turn each cell back into a list.
    _LIST_COLUMN_PARSERS = {
        "chunk_type": _as_str_list,
        "chunk_page_position": _as_int_list,
        "identified_thesaurus": _as_int_list,
        "named_entities": _as_list,
    }

    def __init__(self, cache_csv: str | Path):
        # Keep the constructor argument untouched: sklearn.base.clone checks
        # that get_params() returns the very object passed to __init__, and
        # wrapping it in Path(...) here would make every clone() call fail.
        self.cache_csv = cache_csv
        self._df = None

    def _load(self):
        """Read the cache file and convert ids and list-like columns."""
        df = pd.read_csv(Path(self.cache_csv))
        if "id" in df.columns:
            df["id"] = df["id"].astype(int)
        for column, parser in self._LIST_COLUMN_PARSERS.items():
            if column in df.columns:
                df[column] = df[column].apply(parser)
        return df

    def fit(self, X, y=None):
        """Load the cached scans; ``X`` and ``y`` are not used."""
        self._df = self._load()
        return self

    def transform(self, X):
        """Return a copy of ``X`` inner-joined with the cached scans on ``id``."""
        if self._df is None:
            # transform() called on an unfitted instance: load the cache now
            # instead of failing inside merge() with an obscure error.
            self._df = self._load()
        X = X.copy()
        if "id" in X.columns:
            # same dtype as the cached ids so the join keys compare equal
            X["id"] = X["id"].astype(int)
        return X.merge(self._df, on="id", how="inner")

# Replace the vision-LM preprocessing step with the preprocessed (cached) scans.
def _cached_scans_preprocessing(**_kwargs):
    """Drop-in factory for ``VLLM_Preprocessing``: ignores its keyword arguments and serves the cached scans."""
    return LoadScans(CACHE_CSV)

pdf_to_text.VLLM_Preprocessing = _cached_scans_preprocessing

# NOTE(review): presumably reloaded so that the training module picks up the
# patched attribute — confirm.
training = importlib.reload(training)

In [None]:
# Get the training DAG parts and assemble the complete inference DAG from them.
_base_parts = training.get_training_dag()
from archaeo_super_prompt.modeling import predict as infering
expected_final_pipeline = infering.build_complete_inference_dag(_base_parts)
# Last expression of the cell: displays the assembled pipeline.
expected_final_pipeline


In [None]:
import re, datetime, pandas as pd
from archaeo_super_prompt.modeling.struct_extract.extractors.archiving_date import ArchivingDateProvider, ArchivingDateOutputSchema

def _predict_safe(self, X):
    def parse(dp):
        if dp is None: return datetime.date(1900,1,1)
        if isinstance(dp, (datetime.date, datetime.datetime, pd.Timestamp)):
            return dp.date() if hasattr(dp, "date") else dp
        s = str(dp).strip()
        if s == "" or s.lower() in ("none", "nan", "nat"): return datetime.date(1900,1,1)
        for fmt in ("%Y-%m-%d","%d-%m-%Y","%d/%m/%Y","%Y/%m/%d","%d.%m.%Y","%Y.%m.%d"):
            try: return datetime.datetime.strptime(s, fmt).date()
            except Exception: pass
        parts = re.split(r"[-/\. ]+", s)
        try:
            if len(parts) >= 3:
                d, m, y = map(int, parts[:3])
                if y < 100: y += 2000
                return datetime.date(y, m, d)
        except Exception: pass
        m = re.search(r"(\d{4})", s)
        if m: return datetime.date(int(m.group(1)), 1, 1)
        return datetime.date(1900,1,1)
    rows = [{"id": a.id, "data_protocollo": parse(getattr(a, "building__Data_Protocollo", None))}
            for a in self._mds.get_answers(set(X["id"].to_list()))]
    return ArchivingDateOutputSchema.validate(pd.DataFrame(rows).set_index("id"))

# Monkey-patch: every ArchivingDateProvider now predicts through _predict_safe.
ArchivingDateProvider.predict = _predict_safe


In [None]:
import re, calendar, datetime
import archaeo_super_prompt.modeling.struct_extract.extractors.intervention_date as ide
from archaeo_super_prompt.modeling.struct_extract.extractors.intervention_date import InterventionStartExtractor, ITALIAN_MONTHS, Data, DataInterventoInputData
from archaeo_super_prompt.modeling.struct_extract.extractors.comune import ComuneExtractor, ComuneInputData, Comune
from archaeo_super_prompt.dataset.thesauri import comune_province as cp

def load_comune_aligned():
    """Return ``(id_com, nome)`` pairs for every comune with both a name and a province.

    The rows are read from the comune file of the thesaurus module and keep
    the order of that file.
    """
    comuni = pd.read_csv(cp._get_comune_file())
    has_name = comuni["nome"].notnull()
    has_province = comuni["provincia"].notnull()
    kept = comuni.loc[has_name & has_province, ["id_com", "nome"]]
    return list(kept.itertuples(index=False, name=None))

# Monkey-patch the thesaurus module so it serves the filtered list above.
cp.load_comune = load_comune_aligned


# Matches, case-insensitively and on word boundaries, either a numeric date
# (1-2 digits, 1-2 digits, 2-4 digits separated by "/", "." or "-") or one of
# the twelve Italian month names.
DATE_RE = re.compile(r"\b(\d{1,2}[\/\.-]\d{1,2}[\/\.-]\d{2,4}|gennaio|febbraio|marzo|aprile|maggio|giugno|luglio|agosto|settembre|ottobre|novembre|dicembre)\b", re.I)

# remove bad ocr output
def _strip_tables_and_noise(s):
    s=re.sub(r"^\s*\d+\s*,\s*\d+\s*=\s*.*$", "", s, flags=re.M|re.S)
    s=re.sub(r"`-+`\s*", "", s)
    return s

# avoids going over max context window for long (and badly read) documents
def _focus_and_truncate(s, max_chars=20000):
    """Clean ``s`` and shrink it to at most ``max_chars`` characters.

    Text that already fits after cleaning is returned unchanged. Longer text
    is rebuilt from three parts joined by newlines: the beginning of the
    document, the lines that mention a date (``DATE_RE``) and the end of the
    document.

    The budget is counted in characters. The previous version took the first
    4000 *lines* as head, which on its own exceeded ``max_chars``, so the date
    lines and the tail were always cut off by the final slice.
    """
    s = _strip_tables_and_noise(s)
    if len(s) <= max_chars:
        return s
    if max_chars <= 0:
        return ""

    # Date lines get up to 40% of the budget (8000 characters at the default).
    # Lines from the head or tail region may therefore appear twice.
    date_lines = [ln for ln in s.splitlines() if DATE_RE.search(ln)]
    middle = "\n".join(date_lines)[: (max_chars * 2) // 5]

    # Whatever is left (minus the two joining newlines) is split 2:1 between
    # the head and the tail of the document.
    remaining = max(0, max_chars - len(middle) - 2)
    head_len = (remaining * 2) // 3
    tail_len = remaining - head_len
    head = s[:head_len]
    tail = s[len(s) - tail_len:]  # empty when tail_len is 0

    return "\n".join([head, middle, tail])[:max_chars]

# Feed a trimmed report text to the model and look the month name up with a
# zero-based index (the month number went out of bounds otherwise).
def _to_dspy_input_patched(self, x):
    """Build the DSPy input of the intervention-start extractor from row ``x``.

    ``x.merged_chunks`` is cleaned and truncated with ``_focus_and_truncate``;
    ``x.data_protocollo`` is converted into a ``Data`` object whose month is
    the entry of ``ITALIAN_MONTHS`` at index ``month - 1``.
    """
    protocol_date = x.data_protocollo
    report_text = _focus_and_truncate(x.merged_chunks)
    day = int(protocol_date.day)
    month_name = ITALIAN_MONTHS[int(protocol_date.month) - 1]
    year = int(protocol_date.year)
    archiving_date = Data(giorno=day, mese=month_name, anno=year)
    return DataInterventoInputData(
        fragmenti_relazione=report_text,
        data_di_archiviazone=archiving_date,
    )

# Replace the intervention start extractor's input builder so that it does not
# crash when the context is too large.
InterventionStartExtractor._to_dspy_input = _to_dspy_input_patched

# fixes for invalid model outputs that broke the pipeline
def _as_int(x,d): 
    try: return int(x)
    except Exception: return d
def _clamp_month(m):
    """Convert ``m`` to an int (1 when not convertible) and clamp it to 1..12."""
    return min(12, max(1, _as_int(m, 1)))
def _clamp_day(y, m, d):
    """Clamp day ``d`` to the valid days of month ``m`` in year ``y``.

    Unconvertible values default to year 1900, month 1 and day 1; the month is
    clamped to 1..12 and years below 1 are treated as year 1 for the
    month-length lookup.
    """
    year = _as_int(y, 1900)
    month = _clamp_month(m)
    day = _as_int(d, 1)
    _, days_in_month = calendar.monthrange(max(1, year), month)
    return min(days_in_month, max(1, day))
def _get_min_date_safe(o):
    """Build the start date of ``o`` without raising on invalid components.

    Reads ``start_year``, ``start_month`` and ``start_day`` from ``o``
    (defaults: 1900, 1, 1 when an attribute is absent or not convertible) and
    clamps each of them into the range accepted by ``datetime.date``.
    """
    y = _as_int(getattr(o, "start_year", 1900), 1900)
    # datetime.date only accepts MINYEAR..MAXYEAR; a year such as 0 or 20230
    # would otherwise raise ValueError.
    y = min(max(y, datetime.MINYEAR), datetime.MAXYEAR)
    m = _clamp_month(getattr(o, "start_month", 1))
    d = _clamp_day(y, m, getattr(o, "start_day", 1))
    return datetime.date(y, m, d)
def _get_max_date_safe(o):
    """Build the end date of ``o`` without raising on invalid components.

    ``end_year`` falls back to ``start_year`` (then 1900) when it is absent or
    ``None``. ``end_month`` and ``end_day`` default to 12 and 28 when they are
    absent or ``None``. All components are clamped into the range accepted by
    ``datetime.date``.
    """
    y = getattr(o, "end_year", None)
    if y is None:
        y = getattr(o, "start_year", 1900)
    y = _as_int(y, 1900)
    # datetime.date only accepts MINYEAR..MAXYEAR
    y = min(max(y, datetime.MINYEAR), datetime.MAXYEAR)

    # An attribute that exists but holds None must get the same default as a
    # missing one; otherwise it was coerced to January 1st, which can place
    # the end of the interval before its start.
    m = getattr(o, "end_month", None)
    m = _clamp_month(12 if m is None else m)
    d = getattr(o, "end_day", None)
    d = _clamp_day(y, m, 28 if d is None else d)
    return datetime.date(y, m, d)
# Monkey-patch the date helpers of the intervention_date module with the
# clamping versions defined in this cell.
# NOTE(review): code that imported the original helpers by name before this
# point would keep using them — confirm nothing does.
ide._get_min_date = _get_min_date_safe
ide._get_max_date = _get_max_date_safe

# comune fixing section

def _read_comuni_tables_strict():
    """Load the comuni and province tables and join them.

    Returns
    -------
    (comuni, merged)
        ``comuni``: one row per ``comune_id`` (index, sorted) with ``name`` and
        ``province_id``; rows without a name or province are dropped and only
        the first occurrence of a duplicated ``comune_id`` is kept.
        ``merged``: the same rows with ``province_name`` and ``sigla`` added
        through a left join on ``province_id``.

    NOTE(review): positions in the sorted, de-duplicated index are later used
    to resolve thesaurus entries; confirm they line up with the order produced
    by ``load_comune_aligned``, which neither sorts nor de-duplicates.
    """
    raw_comuni = pd.read_csv(cp._get_comune_file())
    comuni = (
        raw_comuni[["id_com", "nome", "provincia"]]
        .rename(columns={"id_com":"comune_id","nome":"name","provincia":"province_id"})
        .dropna(subset=["name","province_id"])
        .astype({"comune_id": int, "province_id": int})
    )

    # Keep the first occurrence of every comune_id and report how many rows go.
    dups = comuni["comune_id"].duplicated(keep="first").sum()
    if dups:
        print(f"[warn] dropping {dups} duplicated comuni by comune_id")
        comuni = comuni.drop_duplicates(subset=["comune_id"], keep="first")

    raw_province = pd.read_csv(cp._get_provincie_file())
    province = (
        raw_province[["id_prov","nome","sigla"]]
        .rename(columns={"id_prov":"province_id","nome":"province_name"})
        .astype({"province_id": int})
        .drop_duplicates(subset=["province_id"], keep="first")
    )

    # Each comune refers to exactly one province; validate="m:1" makes pandas
    # raise if a province_id maps to several province rows.
    merged = comuni.merge(
        province[["province_id","province_name","sigla"]],
        on="province_id",
        how="left",
        validate="m:1",
        copy=False,
    )

    return (
        comuni.set_index("comune_id").sort_index(),
        merged.set_index("comune_id").sort_index(),
    )

# Comuni tables indexed by comune_id (plain and joined with their province).
_COMUNI, _MERGED = _read_comuni_tables_strict()
# Position -> comune_id lookup, following the sorted index of _COMUNI.
_POS2ID = _COMUNI.index.to_list()

# Field names declared by the Comune model; used to pass only the keyword
# arguments the model actually defines.
_FIELDS = set(Comune.model_fields.keys())

def _mk_from_pos(pos: int):
    """Build the ``Comune`` found at position ``pos`` of ``_POS2ID``.

    Returns ``None`` when ``pos`` is out of range or the resolved id is not in
    ``_MERGED``. Only the keyword arguments whose names are fields of
    ``Comune`` are passed to the constructor.
    """
    if not (0 <= pos < len(_POS2ID)):
        return None
    comune_id = _POS2ID[pos]
    if comune_id not in _MERGED.index:
        return None

    # scalar lookups in the joined table
    comune_name = str(_MERGED.at[comune_id, "name"])
    candidate_values = {
        "citta_nome": comune_name,
        "provicia_nome": str(_MERGED.at[comune_id, "province_name"]),
        "provincia_sigla": str(_MERGED.at[comune_id, "sigla"]),
        "id": int(comune_id),
    }
    kw = {
        field: value
        for field, value in candidate_values.items()
        if field in _FIELDS
    }

    # The model may call the name field "nome" or "name" instead of
    # "citta_nome"; fill whichever exists when "citta_nome" does not.
    if "citta_nome" not in _FIELDS:
        for alias in ("nome", "name"):
            if alias in _FIELDS:
                kw[alias] = comune_name
    return Comune(**kw)

def _comune_to_dspy_input(self, x):
    """Build the DSPy input of the comune extractor from row ``x``.

    Candidate positions are read from ``x.identified_thesaurus`` (or
    ``x.possibili_comuni`` when the former is missing or empty); entries whose
    string form is not made of digits are ignored, the others are resolved
    with ``_mk_from_pos`` and unresolved ones are dropped.
    """
    raw_positions = (
        getattr(x, "identified_thesaurus", None)
        or getattr(x, "possibili_comuni", None)
        or []
    )
    if not isinstance(raw_positions, list):
        raw_positions = [raw_positions]

    candidates = []
    for raw in raw_positions:
        if not str(raw).isdigit():
            continue
        comune = _mk_from_pos(int(raw))
        if comune is not None:
            candidates.append(comune)

    return ComuneInputData(
        fragmenti_relazione=getattr(x, "merged_chunks", ""),
        possibili_comuni=candidates,
    )

# Replace the comune extractor's input builder because the original did not work.
ComuneExtractor._to_dspy_input = _comune_to_dspy_input

In [None]:
from archaeo_super_prompt.modeling import predict as infering
# Train on the training split and score on the evaluation split inside a single
# MLflow run, so both steps are logged together.
with mlflow.start_run():
    trained_dag_parts = training.train_from_scratch(train_inputs, ds)
    per_field_scores, detailed_results = infering.score_dag(trained_dag_parts, eval_inputs, ds)


In [None]:
# Hand the detailed evaluation results to the visualisation engine.
visualizator.init_complete_vizualisation_engine(detailed_results)

In [None]:
# Start the display server for the results initialised in the previous cell.
# NOTE(review): presumably served on one of the ports forwarded by the ssh
# command in the first cell — confirm which one.
visualizator.run_display_server()