In [None]:
# grid_search.py
import itertools, time, csv, os
from database     import generate_database
from preferences  import generate_preferences
from training     import train_dpo
from comparison   import make_comparison

import random
# Fix the seed of the global `random` generator so that the random.sample()
# draw used further down to pick the run list is reproducible across executions.
random.seed(42)

# ------------------------------ #
# 1. hyper-parameter spaces      #
# ------------------------------ #
# (1) database size -> breadth of the field to explore
POINTS = [1_000, 2_000, 4_000, 8_000]          # 4 values

# (2) number of preference pairs -> quality of the signal
COUPLES = [25_000, 50_000, 100_000]            # 3 values

# (3) share of "hard" pairs -> diversity of the preferences
HARD_RATIO = [0.5, 0.7, 0.9]                   # 3 values

# Candidate values for each component of the preference weight vector.
W_PATH = [0.05, 0.10, 0.20]      # long-path penalty
W_WALL = [0.1, 0.2, 0.3]         # bonus for distance from the wall
W_GOAL = [0.2]                   # usually kept fixed
W_DEG = [1.0, 1.5, 2.0]          # dead ends / degree
W_GATE = [0.0, 0.2, 0.4]         # distance from the gate

# Every combination of the weight candidates, as tuples ordered
# (w_path, w_wall, w_goal, w_deg, w_gate): 3 * 3 * 1 * 3 * 3 = 81 entries.
# `itertools` is already imported at the top of the file, so it is not
# re-imported here.
WEIGHTS_SETS = list(
    itertools.product(W_PATH, W_WALL, W_GOAL, W_DEG, W_GATE)
)

# ---------- run list (random search) ----------
# Full Cartesian grid: one (points, couples, hard_ratio, weights) tuple
# per candidate configuration.
POP = list(itertools.product(POINTS, COUPLES, HARD_RATIO, WEIGHTS_SETS))

# Evaluate at most 100 configurations, drawn without replacement from the grid.
k = min(100, len(POP))
RUNS = random.sample(POP, k)

# Results file and its column order: run identifier, the sampled
# hyper-parameters, the four comparison metrics, and the run duration.
CSV_PATH = "grid_results.csv"
FIELDNAMES = [
    "run_id", "points", "couples", "hard_ratio",
    "w_path", "w_wall", "w_goal", "w_deg", "w_gate",
    "iter_dpo", "iter_base", "dist_end_dpo", "dist_end_base",
    "elapsed_sec",
]

# Create the CSV with its header row only when it does not exist yet, so that
# results already stored by a previous execution are kept and appended to.
if not os.path.exists(CSV_PATH):
    with open(CSV_PATH, "w", newline="") as csv_file:
        header_writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES)
        header_writer.writeheader()

# --------------------------------------------- #
# 2. the actual search loop                     #
# --------------------------------------------- #
# Iterates over RUNS; each entry is a (points, couples, hard_ratio, weights)
# tuple. One CSV row is appended per completed run.
def _to_builtin(value, cast):
    """Return ``cast(value)`` if *value* exposes ``.item``, else *value* as is.

    Objects with an ``item`` attribute (e.g. NumPy scalars such as
    ``np.float64``) are converted to the built-in type given by *cast* so the
    CSV row holds plain Python numbers; anything else is passed through
    untouched.
    """
    return cast(value) if hasattr(value, "item") else value


for run_id, (pts, cpl, hr, w) in enumerate(RUNS, start=1):

    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments during a long run (time.time() can be).
    t0 = time.perf_counter()
    print(f"\n▶ run {run_id}: points={pts}, couples={cpl}, hard_ratio={hr}, "
          f"weights={w}")

    # 2.1 database & preferences
    # Return values are discarded and nothing is passed on to the next step
    # from here. NOTE(review): presumably these helpers hand data over through
    # files or module state -- confirm in database.py / preferences.py.
    generate_database(points_to_generate=pts)
    generate_preferences(total_couples=cpl,
                         weights_vec=list(w),   # the weights tuple as a list
                         hard_ratio=hr)

    # 2.2 training
    train_dpo()

    # 2.3 comparison: unpacked as (iterations, iterations, distance, distance)
    # for the DPO and the baseline variants respectively.
    iter_dpo, iter_base, dist_dpo, dist_base = make_comparison()

    # 2.4 store the results, converting everything to built-in types
    row = {
        "run_id": run_id,
        "points": pts,
        "couples": cpl,
        "hard_ratio": hr,
        # w is ordered (w_path, w_wall, w_goal, w_deg, w_gate)
        "w_path": w[0],
        "w_wall": w[1],
        "w_goal": w[2],
        "w_deg": w[3],
        "w_gate": w[4],
        "iter_dpo": _to_builtin(iter_dpo, int),
        "iter_base": _to_builtin(iter_base, int),
        "dist_end_dpo": _to_builtin(dist_dpo, float),
        "dist_end_base": _to_builtin(dist_base, float),
        "elapsed_sec": round(time.perf_counter() - t0, 1),
    }

    # Re-open in append mode for every run so each finished run is written to
    # disk before the next one starts.
    with open(CSV_PATH, "a", newline="") as f:
        csv.DictWriter(f, FIELDNAMES).writerow(row)

    print("  ↳ saved:", row)

print("\n✓ grid‑search terminato – risultati in", CSV_PATH)


Creating Path with step 3


Creating Path with step 3
Resetting environment. Previous state: [0.05 0.05], Counter: 0
Step called. Counter: 0, Horizon: 100
Step called. Counter: 1, Horizon: 100
Step called. Counter: 2, Horizon: 100
Step called. Counter: 3, Horizon: 100
Step called. Counter: 4, Horizon: 100
Step called. Counter: 5, Horizon: 100
Step called. Counter: 6, Horizon: 100
Step called. Counter: 7, Horizon: 100
Step called. Counter: 8, Horizon: 100
Step called. Counter: 9, Horizon: 100
Step called. Counter: 10, Horizon: 100
Step called. Counter: 11, Horizon: 100
Step called. Counter: 12, Horizon: 100
Step called. Counter: 13, Horizon: 100
Step called. Counter: 14, Horizon: 100
Step called. Counter: 15, Horizon: 100
Step called. Counter: 16, Horizon: 100
Step called. Counter: 17, Horizon: 100
Step called. Counter: 18, Horizon: 100
Step called. Counter: 19, Horizon: 100
Step called. Counter: 20, Horizon: 100
Step called. Counter: 21, Horizon: 100
Step called. Counter: 22, Horizon: 100
Step called. Counter: 23

In [2]:
# Notebook cell: echo the comparison metrics as a tuple.
# NOTE(review): the grid-search cell binds `dist_dpo` / `dist_base`, not
# `dist_end_dpo` / `dist_end_base` -- presumably these two names were assigned
# in another cell; confirm before relying on this output.
iter_dpo, iter_base, dist_end_dpo, dist_end_base 

(-1, -1, np.float64(0.15526219015675266), np.float64(0.9027285751862661))