In [None]:
import warnings
warnings.filterwarnings("ignore")
from copy import deepcopy
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import sys

sys.path.append("/code")

from tqdm import tqdm
import torch
# device = torch.device('cpu')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# import gym
# import recogym

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim

torch.backends.cudnn.benchmark = torch.cuda.is_available()
if torch.cuda.is_available():
    torch.set_float32_matmul_precision("high")  # TF32 = big speedup on Ada


from sklearn.utils import check_random_state

# implementing OPE of the IPWLearner using synthetic bandit data
from sklearn.linear_model import LogisticRegression

import matplotlib.pyplot as plt

from scipy.special import softmax
import optuna
# from memory_profiler import profile


from estimators import (
    DirectMethod as DM
)

from simulation_utils import (
    eval_policy,
    generate_dataset,
    create_simulation_data_from_pi,
    get_train_data,
    get_opl_results_dict,
    CustomCFDataset,
    calc_reward,
)

from models import (    
    LinearCFModel,
    NeighborhoodModel,
    BPRModel, 
    RegressionModel
)

from training_utils import (
    train,
    validation_loop, 
    cv_score_model
 )

from custom_losses import (
    SNDRPolicyLoss,
    IPWPolicyLoss
    )

random_state=12345
random_ = check_random_state(random_state)

pd.options.display.float_format = '{:,.8f}'.format

Using device: cpu
Using device: cpu
Using device: cpu


In [2]:
def get_trial_results(
    our_x, 
    our_a, 
    emb_x, 
    emb_a, 
    original_x, 
    original_a, 
    dataset, 
    val_data, 
    original_policy_prob, 
    neighberhoodmodel, 
    regression_model, 
    dm
):
    policy = np.expand_dims(softmax(our_x @ our_a.T, axis=1), -1)
    policy_reward = calc_reward(dataset, policy)
    eval_metrics = eval_policy(neighberhoodmodel, val_data, original_policy_prob, policy)
    action_diff_to_real = np.sqrt(np.mean((emb_a - our_a) ** 2))
    action_delta = np.sqrt(np.mean((original_a - our_a) ** 2))
    context_diff_to_real = np.sqrt(np.mean((emb_x - our_x) ** 2))
    context_delta = np.sqrt(np.mean((original_x - our_x) ** 2))

    row = np.concatenate([
        np.atleast_1d(policy_reward),
        np.atleast_1d(eval_metrics),
        np.atleast_1d(action_diff_to_real),
        np.atleast_1d(action_delta),
        np.atleast_1d(context_diff_to_real),
        np.atleast_1d(context_delta)
    ])
    reg_dm = dm.estimate_policy_value(policy[val_data['x_idx']], regression_model.predict(val_data['x']))
    reg_results = np.array([reg_dm])
    conv_results = np.array([row])
    return get_opl_results_dict(reg_results, conv_results)

## `trainer_trial` Function

This function runs policy learning experiments using offline bandit data and evaluates various estimators.

### Parameters
- **num_runs** (int): Number of experimental runs per training size
- **num_neighbors** (int): Number of neighbors to consider in the neighborhood model
- **num_rounds_list** (list): List of training set sizes to evaluate
- **dataset** (dict): Contains dataset information including embeddings, action probabilities, and reward probabilities
- **batch_size** (int): Batch size for training the policy model
- **num_epochs** (int): Number of training epochs for each experiment
- **lr** (float, default=0.001): Learning rate for the optimizer

### Process Flow
1. Initializes result structures and retrieval models
2. For each training size in `num_rounds_list`:
   - Creates a uniform logging policy and simulates data
   - Generates training data for offline learning
   - Fits regression and neighborhood models for reward estimation
   - Initializes and trains a counterfactual policy model
   - Evaluates policy performance using various estimators
   - Collects metrics on policy reward and embedding quality

### Returns
- **DataFrame**: Results table with rows indexed by training size and columns for various metrics:
  - `policy_rewards`: True expected reward of the learned policy
  - Various estimator errors (`ipw`, `reg_dm`, `conv_dm`, `conv_dr`, `conv_sndr`)
  - Variance metrics for each estimator
  - Embedding quality metrics comparing learned representations to ground truth

### Implementation Notes
- Uses uniform random logging policy for collecting offline data
- Employs Self-Normalized Doubly Robust (SNDR) policy learning
- Measures embedding quality via RMSE to original/ground truth embeddings

In [None]:
def trainer_trial(
    num_runs,
    num_neighbors,
    train_sizes,
    dataset,
    batch_size,
    val_size=2000,
    n_trials=10,    
    prev_best_params=None
):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.benchmark = torch.cuda.is_available()
    if torch.cuda.is_available():
        torch.set_float32_matmul_precision("high")

    dm = DM()
    results = {}

    our_x, our_a = dataset["our_x"], dataset["our_a"]
    emb_x, emb_a = dataset["emb_x"], dataset["emb_a"]

    original_x, original_a = dataset["original_x"], dataset["original_a"]
    n_users, n_actions, emb_dim = dataset["n_users"], dataset["n_actions"], dataset["emb_dim"]

    all_user_indices = np.arange(n_users, dtype=np.int64)

    def T(x):
        return torch.as_tensor(x, device=device, dtype=torch.float32)

    def _mean_dict(dicts):
        """
        Robust mean over a list of dicts with numeric/scalar/1D-array values.
        Returns a single dict with elementwise means.
        """
        if not dicts:
            return {}
        keys = dicts[0].keys()
        out = {}
        for k in keys:
            vals = [d[k] for d in dicts if k in d]
            # try to convert each to np.array and average
            arrs = [np.asarray(v) for v in vals]
            # broadcast to same shape if scalars/1D
            stacked = np.stack(arrs, axis=0)
            out[k] = np.mean(stacked, axis=0)
        return out

    # ===== unpack dataset (keep originals safe) =====
    our_x_orig, our_a_orig = our_x, our_a
    emb_x, emb_a = emb_x, emb_a
    original_x, original_a = original_x, original_a
    n_users, n_actions, emb_dim = n_users, n_actions, emb_dim
    all_user_indices = np.arange(n_users, dtype=np.int64)

    dm = DM()
    results = {}
    best_hyperparams_by_size = {}
    last_best_params = prev_best_params if prev_best_params is not None else None

    # ===== baseline (sample size = 0) using get_trial_results =====
    pi_0 = softmax(our_x_orig @ our_a_orig.T, axis=1)
    original_policy_prob = np.expand_dims(pi_0, -1)

    simulation_data = create_simulation_data_from_pi(
        dataset, pi_0, val_size, random_state=0
    )

    # use same data for train/val just to generate the baseline row
    train_data = get_train_data(n_actions, val_size, simulation_data, np.arange(val_size), our_x_orig)
    val_data   = get_train_data(n_actions, val_size, simulation_data, np.arange(val_size), our_x_orig)

    regression_model = RegressionModel(
        n_actions=n_actions, action_context=our_x_orig,
        base_model=LogisticRegression(random_state=12345)
    )

    regression_model.fit(train_data['x'], train_data['a'], train_data['r'])

    neighberhoodmodel = NeighborhoodModel(
        train_data['x_idx'], train_data['a'],
        our_a_orig, our_x_orig, train_data['r'],
        num_neighbors=num_neighbors
    )

    # baseline row produced via get_trial_results
    results[0] = get_trial_results(
        our_x_orig, our_a_orig, emb_x, emb_a, original_x, original_a,
        dataset, val_data, original_policy_prob,
        neighberhoodmodel, regression_model, dm
    )

    # ===== main loop over training sizes =====
    for train_size in train_sizes:

        # we’ll collect per-run trial dicts generated by get_trial_results
        trial_dicts_this_size = []
        best_hyperparams_by_size[train_size] = {}

        # --- prepare a resampling for Optuna’s objective (shared loaders built per-run inside objective) ---
        # We’ll do Optuna per-run (fresh resample + search), then final fit with best params, then get_trial_results.

        for run in range(num_runs):

            # --- resample for this run ---
            pi_0 = softmax(our_x_orig @ our_a_orig.T, axis=1)
            original_policy_prob = np.expand_dims(pi_0, -1)

            simulation_data = create_simulation_data_from_pi(
                dataset, pi_0, train_size + val_size,
                random_state=(run + 1) * (train_size + 17)
            )

            idx_train = np.arange(train_size)
            train_data = get_train_data(n_actions, train_size, simulation_data, idx_train, our_x_orig)
            val_idx   = np.arange(val_size) + train_size
            val_data  = get_train_data(n_actions, val_size, simulation_data, val_idx, our_x_orig)

            num_workers = 4 if torch.cuda.is_available() else 0

            cf_dataset = CustomCFDataset(
                train_data['x_idx'], train_data['a'], train_data['r'], original_policy_prob
            )

            val_dataset = CustomCFDataset(
                val_data['x_idx'], val_data['a'], val_data['r'], original_policy_prob
            )

            # val_loader = DataLoader(
            #     val_dataset, batch_size=val_size, shuffle=False,
            #     pin_memory=torch.cuda.is_available(),
            #     num_workers=num_workers, persistent_workers=bool(num_workers)
            # )


            # --- Optuna objective bound to this run's data ---
            def objective(trial):
                lr = trial.suggest_float("lr", 1e-4, 1e-2, log=True)
                epochs = trial.suggest_int("num_epochs", 1, 10)
                trial_batch_size = trial.suggest_categorical("batch_size", [64, 128, 256, 512])
                trial_num_neighbors = trial.suggest_int("num_neighbors", 3, 15)
                lr_decay = trial.suggest_float("lr_decay", 0.8, 1.0)

                trial_neigh_model = NeighborhoodModel(
                    train_data['x_idx'], train_data['a'],
                    our_a_orig, our_x_orig, train_data['r'],
                    num_neighbors=trial_num_neighbors
                )

                trial_scores_all = torch.as_tensor(
                    trial_neigh_model.predict(all_user_indices),
                    device=device, dtype=torch.float32
                )

                trial_model = LinearCFModel(
                    n_users, n_actions, emb_dim,
                    initial_user_embeddings=T(our_x_orig),
                    initial_actions_embeddings=T(our_a_orig)
                ).to(device)

                assert (not torch.cuda.is_available()) or next(trial_model.parameters()).is_cuda

                final_train_loader = DataLoader(
                    cf_dataset, batch_size=trial_batch_size, shuffle=True,
                    pin_memory=torch.cuda.is_available(),
                    num_workers=num_workers, persistent_workers=bool(num_workers)
                )

                current_lr = lr
                for epoch in range(epochs):
                    if epoch > 0:
                        current_lr *= lr_decay
                        
                    train(
                        trial_model, final_train_loader, trial_scores_all,
                        criterion=IPWPolicyLoss(), num_epochs=1, lr=current_lr, device=str(device)
                    )

                trial_x, trial_a = trial_model.get_params()
                trial_x = trial_x.detach().cpu().numpy()
                trial_a = trial_a.detach().cpu().numpy()

                pi_i = softmax(trial_x @ trial_a.T, axis=1)

                # print(get_weights_info(pi_i, original_policy_prob))
                # validation reward for selection
                return cv_score_model(val_data, trial_scores_all, pi_i)


            # --- run Optuna for this run ---
            study = optuna.create_study(direction="maximize")
            
            if last_best_params is not None:
                study.enqueue_trial(last_best_params)

            study.optimize(objective, n_trials=n_trials, show_progress_bar=True)

            best_params = study.best_params
            last_best_params = best_params  # optional warm-start to next run
            best_hyperparams_by_size[train_size][run] = {
                "params": best_params,
                "reward": study.best_value
            }


            # --- final training with best params on this run’s data ---
            regression_model = RegressionModel(
                n_actions=n_actions, action_context=our_x_orig,
                base_model=LogisticRegression(random_state=12345)
            )
            regression_model.fit(
                train_data['x'], train_data['a'], train_data['r'],
                original_policy_prob[train_data['x_idx'], train_data['a']].squeeze()
            )

            neighberhoodmodel = NeighborhoodModel(
                train_data['x_idx'], train_data['a'],
                our_a_orig, our_x_orig, train_data['r'],
                num_neighbors=best_params['num_neighbors']
            )
            scores_all = torch.as_tensor(
                neighberhoodmodel.predict(all_user_indices),
                device=device, dtype=torch.float32
            )

            model = LinearCFModel(
                n_users, n_actions, emb_dim,
                initial_user_embeddings=T(our_x_orig),
                initial_actions_embeddings=T(our_a_orig)
            ).to(device)
            assert (not torch.cuda.is_available()) or next(model.parameters()).is_cuda

            train_loader = DataLoader(
                cf_dataset, batch_size=batch_size, shuffle=True,
                pin_memory=torch.cuda.is_available(),
                num_workers=num_workers, persistent_workers=bool(num_workers)
            )

            current_lr = best_params['lr']
            for epoch in range(best_params['num_epochs']):
                if epoch > 0:
                    current_lr *= best_params['lr_decay']
                train(
                    model, train_loader, scores_all,
                    criterion=IPWPolicyLoss(), num_epochs=1, lr=current_lr, device=str(device)
                )

            # learned embeddings (do NOT overwrite originals)
            learned_x_t, learned_a_t = model.get_params()
            learned_x = learned_x_t.detach().cpu().numpy()
            learned_a = learned_a_t.detach().cpu().numpy()

            # --- produce the per-run result via get_trial_results ---
            trial_res = get_trial_results(
                learned_x, learned_a,          # learned (policy) embeddings
                emb_x, emb_a,                  # ground-truth embedding refs
                original_x, original_a,        # original clean refs
                dataset,
                val_data,                      # use this run's val split
                original_policy_prob,
                neighberhoodmodel,
                regression_model,
                dm
            )

            trial_dicts_this_size.append(trial_res)

            # memory hygiene
            torch.cuda.empty_cache()

        # === aggregate per-run results (mean) and store under this train_size ===
        results[train_size] = _mean_dict(trial_dicts_this_size)

    return pd.DataFrame.from_dict(results, orient='index'), best_hyperparams_by_size

## Learning

We will run several simulations on a generated dataset, the dataset is generated like this:
$$ \text{We have users U and actions A } u_i \sim N(0, I_{emb_dim}) \ a_i \sim N(0, I_{emb_dim})$$
$$ p_{ij} = 1 / (5 + e^{-(u_i.T a_j)}) $$
$$r_{ij} \sim Bin(p_{ij})$$

We have a policy $\pi$
and it's ground truth reward is calculated by
$$R_{gt} = \sum_{i}{\sum_{j}{\pi_{ij} * p_{ij}}} $$

Our parameters for the dataset will be
$$EmbDim = 5$$
$$NumActions= 150$$
$$NumUsers = 150$$
$$NeighborhoodSize = 6$$

to learn a new policy from $\pi$ we will sample from:
$$\pi_{start} = (1-\epsilon)*\pi + \epsilon * \pi_{random}$$

In [4]:
dataset_params = dict(
                    n_actions= 500,
                    n_users = 500,
                    emb_dim = 16,
                    # sigma = 0.1,
                    eps = 0.6, # this is the epsilon for the noise in the ground truth policy representation
                    ctr = 0.1
                    )

train_dataset = generate_dataset(dataset_params)

Random Item CTR: 0.07066414727263938
Optimal greedy CTR: 0.09999926940951757
Optimal Stochastic CTR: 0.09995326955796031
Our Initial CTR: 0.08610747363354625


In [5]:
num_runs = 1
batch_size = 200
num_neighbors = 6
n_trials_for_optuna = 20
num_rounds_list = [500, 1000, 2000, 10000, 20000]
# num_rounds_list = [500, 1000, 2000]
# num_rounds_list = [40000]


# Manually define your best parameters
best_params_to_use = {
    "lr": 0.0095,  # Learning rate
    "num_epochs": 5,  # Number of training epochs
    "batch_size": 64,  # Batch size for training
    "num_neighbors": 8,  # Number of neighbors for neighborhood model
    "lr_decay": 0.85  # Learning rate decay factor
}

### 1

$$emb = 0.7 * gt + 0.3 * noise$$
$$lr = 0.005$$
$$n_{epochs} = 1$$
$$BatchSize=50$$

In [6]:
print("Value of num_rounds_list:", num_rounds_list)

# Run the optimization
df4, best_hyperparams_by_size = trainer_trial(num_runs, num_neighbors, num_rounds_list, train_dataset, batch_size, val_size=10000, n_trials=n_trials_for_optuna, prev_best_params=best_params_to_use)

# # Print best hyperparameters for each training size
# print("\n=== BEST HYPERPARAMETERS BY TRAINING SIZE ===")
# for train_size, params in best_hyperparams_by_size.items():
#     print(f"\nTraining Size: {train_size}")
#     # print(f"Best Reward: {params['reward']:.6f}")
#     print("Parameters:")
#     for param_name, value in params['params'].items():
#         print(f"  {param_name}: {value}")
# print("===========================\n")

# Show the performance metrics
df4[['policy_rewards', 'ipw', 'reg_dm', 'conv_dm', 'conv_dr', 'conv_sndr', 'action_diff_to_real', 'action_delta', 'context_diff_to_real', 'context_delta']]

Value of num_rounds_list: [500, 1000, 2000, 10000, 20000]
{'gini': np.float64(0.46533227392070686), 'ess': np.float64(4313.650641077747), 'max_wi': np.float64(22.173174058915244), 'min_wi': np.float64(0.007342366144116349)}


[I 2025-10-16 17:04:32,513] A new study created in memory with name: no-name-57858b99-1c8f-4338-b5fb-3d339da78ac8
Best trial: 0. Best value: 0.0879964:   5%|▌         | 1/20 [00:02<00:43,  2.31s/it]

Cross-validated error: 0.009510311148654926
[I 2025-10-16 17:04:34,825] Trial 0 finished with value: 0.08799642209798801 and parameters: {'lr': 0.0095, 'num_epochs': 5, 'batch_size': 64, 'num_neighbors': 8, 'lr_decay': 0.85}. Best is trial 0 with value: 0.08799642209798801.


Best trial: 1. Best value: 0.0929651:  10%|█         | 2/20 [00:03<00:34,  1.89s/it]

Cross-validated error: 0.010768681093061845
[I 2025-10-16 17:04:36,418] Trial 1 finished with value: 0.09296512125984926 and parameters: {'lr': 0.0005995912930806353, 'num_epochs': 4, 'batch_size': 64, 'num_neighbors': 8, 'lr_decay': 0.8192591901212108}. Best is trial 1 with value: 0.09296512125984926.


Best trial: 2. Best value: 0.0955179:  15%|█▌        | 3/20 [00:05<00:30,  1.77s/it]

Cross-validated error: 0.011439012523728207
[I 2025-10-16 17:04:38,038] Trial 2 finished with value: 0.09551793180622016 and parameters: {'lr': 0.000998175652165881, 'num_epochs': 7, 'batch_size': 256, 'num_neighbors': 6, 'lr_decay': 0.8354512698988809}. Best is trial 2 with value: 0.09551793180622016.


Best trial: 2. Best value: 0.0955179:  20%|██        | 4/20 [00:07<00:27,  1.72s/it]

Cross-validated error: 0.010253281021046893
[I 2025-10-16 17:04:39,683] Trial 3 finished with value: 0.09098333489003863 and parameters: {'lr': 0.001496590005134734, 'num_epochs': 9, 'batch_size': 512, 'num_neighbors': 11, 'lr_decay': 0.9307885505252045}. Best is trial 2 with value: 0.09551793180622016.


Best trial: 2. Best value: 0.0955179:  25%|██▌       | 5/20 [00:08<00:25,  1.72s/it]

Cross-validated error: 0.009720537617242939
[I 2025-10-16 17:04:41,412] Trial 4 finished with value: 0.08889175624936957 and parameters: {'lr': 0.0031672877389910966, 'num_epochs': 9, 'batch_size': 64, 'num_neighbors': 11, 'lr_decay': 0.9122827737015912}. Best is trial 2 with value: 0.09551793180622016.


Best trial: 2. Best value: 0.0955179:  30%|███       | 6/20 [00:10<00:24,  1.72s/it]

Cross-validated error: 0.010902879466087199
[I 2025-10-16 17:04:43,130] Trial 5 finished with value: 0.09348887141761844 and parameters: {'lr': 0.0020116685373161285, 'num_epochs': 7, 'batch_size': 128, 'num_neighbors': 7, 'lr_decay': 0.9633130325732253}. Best is trial 2 with value: 0.09551793180622016.


Best trial: 2. Best value: 0.0955179:  35%|███▌      | 7/20 [00:12<00:22,  1.72s/it]

Cross-validated error: 0.010388169053668576
[I 2025-10-16 17:04:44,839] Trial 6 finished with value: 0.09151260521824384 and parameters: {'lr': 0.000649413791850496, 'num_epochs': 4, 'batch_size': 128, 'num_neighbors': 10, 'lr_decay': 0.8392313250026459}. Best is trial 2 with value: 0.09551793180622016.


Best trial: 2. Best value: 0.0955179:  40%|████      | 8/20 [00:14<00:20,  1.72s/it]

Cross-validated error: 0.010097085455586206
[I 2025-10-16 17:04:46,577] Trial 7 finished with value: 0.09036936935064221 and parameters: {'lr': 0.00014018686368492215, 'num_epochs': 8, 'batch_size': 256, 'num_neighbors': 12, 'lr_decay': 0.8784656357274324}. Best is trial 2 with value: 0.09551793180622016.


Best trial: 2. Best value: 0.0955179:  45%|████▌     | 9/20 [00:15<00:18,  1.72s/it]

Cross-validated error: 0.010094858619085779
[I 2025-10-16 17:04:48,304] Trial 8 finished with value: 0.09036233784561847 and parameters: {'lr': 0.0009977976227738795, 'num_epochs': 6, 'batch_size': 64, 'num_neighbors': 11, 'lr_decay': 0.9121103506821937}. Best is trial 2 with value: 0.09551793180622016.


Best trial: 2. Best value: 0.0955179:  50%|█████     | 10/20 [00:17<00:17,  1.72s/it]

Cross-validated error: 0.009938402072972434
[I 2025-10-16 17:04:50,000] Trial 9 finished with value: 0.08973106702143344 and parameters: {'lr': 0.0005145320426193893, 'num_epochs': 5, 'batch_size': 128, 'num_neighbors': 13, 'lr_decay': 0.8623361726239775}. Best is trial 2 with value: 0.09551793180622016.


Best trial: 10. Best value: 0.101517:  55%|█████▌    | 11/20 [00:19<00:15,  1.69s/it]

Cross-validated error: 0.013161661609252378
[I 2025-10-16 17:04:51,641] Trial 10 finished with value: 0.10151694304912517 and parameters: {'lr': 0.00013295012499051913, 'num_epochs': 1, 'batch_size': 256, 'num_neighbors': 3, 'lr_decay': 0.8115173109082136}. Best is trial 10 with value: 0.10151694304912517.


Best trial: 11. Best value: 0.101536:  60%|██████    | 12/20 [00:20<00:13,  1.69s/it]

Cross-validated error: 0.013162256065375312
[I 2025-10-16 17:04:53,320] Trial 11 finished with value: 0.10153643747995664 and parameters: {'lr': 0.00011073705126541502, 'num_epochs': 1, 'batch_size': 256, 'num_neighbors': 3, 'lr_decay': 0.8025261031904265}. Best is trial 11 with value: 0.10153643747995664.


Best trial: 11. Best value: 0.101536:  65%|██████▌   | 13/20 [00:22<00:11,  1.66s/it]

Cross-validated error: 0.013163483461559456
[I 2025-10-16 17:04:54,926] Trial 12 finished with value: 0.10152667925193096 and parameters: {'lr': 0.00010966705960366183, 'num_epochs': 1, 'batch_size': 256, 'num_neighbors': 3, 'lr_decay': 0.8062463696815886}. Best is trial 11 with value: 0.10153643747995664.


Best trial: 13. Best value: 0.101537:  70%|███████   | 14/20 [00:24<00:09,  1.65s/it]

Cross-validated error: 0.013152658323577553
[I 2025-10-16 17:04:56,529] Trial 13 finished with value: 0.10153707011273463 and parameters: {'lr': 0.00025976955978605694, 'num_epochs': 1, 'batch_size': 256, 'num_neighbors': 3, 'lr_decay': 0.8057803501951822}. Best is trial 13 with value: 0.10153707011273463.


Best trial: 13. Best value: 0.101537:  75%|███████▌  | 15/20 [00:25<00:08,  1.66s/it]

Cross-validated error: 0.011988543663041432
[I 2025-10-16 17:04:58,227] Trial 14 finished with value: 0.09747820395350681 and parameters: {'lr': 0.0002663449189269235, 'num_epochs': 2, 'batch_size': 256, 'num_neighbors': 5, 'lr_decay': 0.9916206981260058}. Best is trial 13 with value: 0.10153707011273463.


Best trial: 13. Best value: 0.101537:  80%|████████  | 16/20 [00:27<00:06,  1.68s/it]

Cross-validated error: 0.009804656175740195
[I 2025-10-16 17:04:59,936] Trial 15 finished with value: 0.08921321697785925 and parameters: {'lr': 0.00026946647737301286, 'num_epochs': 3, 'batch_size': 512, 'num_neighbors': 15, 'lr_decay': 0.8006589784852907}. Best is trial 13 with value: 0.10153707011273463.


Best trial: 13. Best value: 0.101537:  85%|████████▌ | 17/20 [00:29<00:04,  1.66s/it]

Cross-validated error: 0.011989264881590696
[I 2025-10-16 17:05:01,563] Trial 16 finished with value: 0.0974901987221716 and parameters: {'lr': 0.0002511722710824395, 'num_epochs': 2, 'batch_size': 256, 'num_neighbors': 5, 'lr_decay': 0.8805553666580285}. Best is trial 13 with value: 0.10153707011273463.


Best trial: 13. Best value: 0.101537:  90%|█████████ | 18/20 [00:30<00:03,  1.66s/it]

Cross-validated error: 0.012744674617740422
[I 2025-10-16 17:05:03,215] Trial 17 finished with value: 0.1001446059758383 and parameters: {'lr': 0.00020074026826942694, 'num_epochs': 3, 'batch_size': 256, 'num_neighbors': 4, 'lr_decay': 0.8292313335530229}. Best is trial 13 with value: 0.10153707011273463.


Best trial: 13. Best value: 0.101537:  95%|█████████▌| 19/20 [00:32<00:01,  1.64s/it]

Cross-validated error: 0.013155506058793924
[I 2025-10-16 17:05:04,825] Trial 18 finished with value: 0.10153387391665998 and parameters: {'lr': 0.0004894888598534817, 'num_epochs': 1, 'batch_size': 256, 'num_neighbors': 3, 'lr_decay': 0.8639663742076548}. Best is trial 13 with value: 0.10153707011273463.


Best trial: 13. Best value: 0.101537: 100%|██████████| 20/20 [00:33<00:00,  1.70s/it]

Cross-validated error: 0.011988062402715941
[I 2025-10-16 17:05:06,470] Trial 19 finished with value: 0.09749421459847735 and parameters: {'lr': 0.0001758184052092875, 'num_epochs': 2, 'batch_size': 512, 'num_neighbors': 5, 'lr_decay': 0.9385427545852538}. Best is trial 13 with value: 0.10153707011273463.





{'gini': np.float64(0.46088596120278164), 'ess': np.float64(4449.640756608241), 'max_wi': np.float64(33.15015508403525), 'min_wi': np.float64(0.02655396316711536)}


[I 2025-10-16 17:05:08,875] A new study created in memory with name: no-name-b6350b47-5731-45b0-b97c-097fa4649657
Best trial: 0. Best value: 0.0749437:   5%|▌         | 1/20 [00:02<00:47,  2.49s/it]

Cross-validated error: 0.006660832877500933
[I 2025-10-16 17:05:11,362] Trial 0 finished with value: 0.07494366219272666 and parameters: {'lr': 0.00025976955978605694, 'num_epochs': 1, 'batch_size': 256, 'num_neighbors': 3, 'lr_decay': 0.8057803501951822}. Best is trial 0 with value: 0.07494366219272666.


Best trial: 1. Best value: 0.076151:  10%|█         | 2/20 [00:05<00:49,  2.73s/it] 

Cross-validated error: 0.006895260942475331
[I 2025-10-16 17:05:14,262] Trial 1 finished with value: 0.07615097033164392 and parameters: {'lr': 0.001530810213257148, 'num_epochs': 5, 'batch_size': 64, 'num_neighbors': 11, 'lr_decay': 0.8586215198058325}. Best is trial 1 with value: 0.07615097033164392.


Best trial: 1. Best value: 0.076151:  15%|█▌        | 3/20 [00:08<00:47,  2.77s/it]

Cross-validated error: 0.0068646202000036535
[I 2025-10-16 17:05:17,088] Trial 2 finished with value: 0.07597443747733179 and parameters: {'lr': 0.0003076774925680022, 'num_epochs': 3, 'batch_size': 64, 'num_neighbors': 14, 'lr_decay': 0.9719571571262107}. Best is trial 1 with value: 0.07615097033164392.


Best trial: 1. Best value: 0.076151:  20%|██        | 4/20 [00:11<00:45,  2.82s/it]

Cross-validated error: 0.006412889004976751
[I 2025-10-16 17:05:19,983] Trial 3 finished with value: 0.07362886826893965 and parameters: {'lr': 0.00012224144085258904, 'num_epochs': 6, 'batch_size': 128, 'num_neighbors': 6, 'lr_decay': 0.9079673779103115}. Best is trial 1 with value: 0.07615097033164392.


Best trial: 1. Best value: 0.076151:  25%|██▌       | 5/20 [00:14<00:42,  2.86s/it]

Cross-validated error: 0.006596068366521396
[I 2025-10-16 17:05:22,907] Trial 4 finished with value: 0.07458404456195805 and parameters: {'lr': 0.0016363625863142896, 'num_epochs': 9, 'batch_size': 256, 'num_neighbors': 4, 'lr_decay': 0.9015445049596268}. Best is trial 1 with value: 0.07615097033164392.


Best trial: 1. Best value: 0.076151:  30%|███       | 6/20 [00:17<00:40,  2.90s/it]

Cross-validated error: 0.0068496618839197256
[I 2025-10-16 17:05:25,888] Trial 5 finished with value: 0.07590485776185855 and parameters: {'lr': 0.00019283985763070817, 'num_epochs': 7, 'batch_size': 64, 'num_neighbors': 15, 'lr_decay': 0.8857434571993493}. Best is trial 1 with value: 0.07615097033164392.


Best trial: 1. Best value: 0.076151:  35%|███▌      | 7/20 [00:19<00:37,  2.86s/it]

Cross-validated error: 0.006573756834207933
[I 2025-10-16 17:05:28,650] Trial 6 finished with value: 0.07451585843115993 and parameters: {'lr': 0.004538419133745602, 'num_epochs': 2, 'batch_size': 128, 'num_neighbors': 3, 'lr_decay': 0.889466473373138}. Best is trial 1 with value: 0.07615097033164392.


Best trial: 1. Best value: 0.076151:  40%|████      | 8/20 [00:22<00:35,  2.93s/it]

Cross-validated error: 0.006769220599065302
[I 2025-10-16 17:05:31,736] Trial 7 finished with value: 0.07551024418252382 and parameters: {'lr': 0.0006602598660542721, 'num_epochs': 10, 'batch_size': 256, 'num_neighbors': 11, 'lr_decay': 0.8927176965442517}. Best is trial 1 with value: 0.07615097033164392.


Best trial: 1. Best value: 0.076151:  45%|████▌     | 9/20 [00:25<00:32,  2.93s/it]

Cross-validated error: 0.006658262171508681
[I 2025-10-16 17:05:34,656] Trial 8 finished with value: 0.07491080558168037 and parameters: {'lr': 0.00015382624263048003, 'num_epochs': 9, 'batch_size': 256, 'num_neighbors': 3, 'lr_decay': 0.8181411245325578}. Best is trial 1 with value: 0.07615097033164392.


Best trial: 9. Best value: 0.0771746:  50%|█████     | 10/20 [00:28<00:29,  2.96s/it]

Cross-validated error: 0.007104105137918493
[I 2025-10-16 17:05:37,679] Trial 9 finished with value: 0.07717455080346448 and parameters: {'lr': 0.00645035763141304, 'num_epochs': 7, 'batch_size': 128, 'num_neighbors': 12, 'lr_decay': 0.8753865246344537}. Best is trial 9 with value: 0.07717455080346448.


Best trial: 9. Best value: 0.0771746:  55%|█████▌    | 11/20 [00:31<00:26,  2.94s/it]

Cross-validated error: 0.006456554419242222
[I 2025-10-16 17:05:40,569] Trial 10 finished with value: 0.07389146081974327 and parameters: {'lr': 0.009022055480726479, 'num_epochs': 4, 'batch_size': 512, 'num_neighbors': 8, 'lr_decay': 0.9734249302121355}. Best is trial 9 with value: 0.07717455080346448.


Best trial: 9. Best value: 0.0771746:  60%|██████    | 12/20 [00:34<00:23,  2.94s/it]

Cross-validated error: 0.006780297362268142
[I 2025-10-16 17:05:43,515] Trial 11 finished with value: 0.07556619874894836 and parameters: {'lr': 0.0021002756036394856, 'num_epochs': 5, 'batch_size': 128, 'num_neighbors': 11, 'lr_decay': 0.8480756750746413}. Best is trial 9 with value: 0.07717455080346448.


Best trial: 12. Best value: 0.0786479:  65%|██████▌   | 13/20 [00:37<00:20,  2.97s/it]

Cross-validated error: 0.007404480568070046
[I 2025-10-16 17:05:46,566] Trial 12 finished with value: 0.07864786631632549 and parameters: {'lr': 0.0039459077200945224, 'num_epochs': 7, 'batch_size': 64, 'num_neighbors': 12, 'lr_decay': 0.850874323422532}. Best is trial 12 with value: 0.07864786631632549.


Best trial: 12. Best value: 0.0786479:  70%|███████   | 14/20 [00:40<00:17,  2.98s/it]

Cross-validated error: 0.006697615846818993
[I 2025-10-16 17:05:49,572] Trial 13 finished with value: 0.07511343192517397 and parameters: {'lr': 0.009877088966523114, 'num_epochs': 7, 'batch_size': 512, 'num_neighbors': 13, 'lr_decay': 0.943737953996505}. Best is trial 12 with value: 0.07864786631632549.


Best trial: 12. Best value: 0.0786479:  75%|███████▌  | 15/20 [00:43<00:14,  2.98s/it]

Cross-validated error: 0.006641725476218371
[I 2025-10-16 17:05:52,542] Trial 14 finished with value: 0.07484733264976592 and parameters: {'lr': 0.0033891543343062412, 'num_epochs': 8, 'batch_size': 128, 'num_neighbors': 9, 'lr_decay': 0.8440052319284073}. Best is trial 12 with value: 0.07864786631632549.


Best trial: 15. Best value: 0.0787028:  80%|████████  | 16/20 [00:46<00:11,  2.98s/it]

Cross-validated error: 0.007417532711707614
[I 2025-10-16 17:05:55,540] Trial 15 finished with value: 0.07870283331620737 and parameters: {'lr': 0.004802661027545419, 'num_epochs': 7, 'batch_size': 64, 'num_neighbors': 12, 'lr_decay': 0.9307668267288444}. Best is trial 15 with value: 0.07870283331620737.


Best trial: 15. Best value: 0.0787028:  85%|████████▌ | 17/20 [00:49<00:08,  2.98s/it]

Cross-validated error: 0.007241183391739006
[I 2025-10-16 17:05:58,497] Trial 16 finished with value: 0.0778586351300149 and parameters: {'lr': 0.0032846425902398947, 'num_epochs': 6, 'batch_size': 64, 'num_neighbors': 9, 'lr_decay': 0.9315689955625682}. Best is trial 15 with value: 0.07870283331620737.


Best trial: 15. Best value: 0.0787028:  90%|█████████ | 18/20 [00:52<00:05,  2.99s/it]

Cross-validated error: 0.007005063275300157
[I 2025-10-16 17:06:01,534] Trial 17 finished with value: 0.07668312303340609 and parameters: {'lr': 0.0009036746067634838, 'num_epochs': 8, 'batch_size': 64, 'num_neighbors': 13, 'lr_decay': 0.9938017084280583}. Best is trial 15 with value: 0.07870283331620737.


Best trial: 15. Best value: 0.0787028:  95%|█████████▌| 19/20 [00:55<00:02,  2.95s/it]

Cross-validated error: 0.007069530365579816
[I 2025-10-16 17:06:04,370] Trial 18 finished with value: 0.07698513741743684 and parameters: {'lr': 0.004849397377782079, 'num_epochs': 4, 'batch_size': 64, 'num_neighbors': 7, 'lr_decay': 0.9315842350247223}. Best is trial 15 with value: 0.07870283331620737.


Best trial: 15. Best value: 0.0787028: 100%|██████████| 20/20 [00:58<00:00,  2.93s/it]

Cross-validated error: 0.007131931242761227
[I 2025-10-16 17:06:07,467] Trial 19 finished with value: 0.07730044049407693 and parameters: {'lr': 0.00272669800504152, 'num_epochs': 10, 'batch_size': 64, 'num_neighbors': 15, 'lr_decay': 0.8238240252018805}. Best is trial 15 with value: 0.07870283331620737.





{'gini': np.float64(0.5432471254412999), 'ess': np.float64(3112.585846918791), 'max_wi': np.float64(56.264163337141575), 'min_wi': np.float64(0.009158557338997134)}


[I 2025-10-16 17:06:11,176] A new study created in memory with name: no-name-f91c84a2-c3ab-4b4e-bd64-6d62eda89aa0
Best trial: 0. Best value: 0.0845076:   5%|▌         | 1/20 [00:06<02:02,  6.44s/it]

Cross-validated error: 0.00867720239495928
[I 2025-10-16 17:06:17,610] Trial 0 finished with value: 0.08450757079954678 and parameters: {'lr': 0.004802661027545419, 'num_epochs': 7, 'batch_size': 64, 'num_neighbors': 12, 'lr_decay': 0.9307668267288444}. Best is trial 0 with value: 0.08450757079954678.


Best trial: 1. Best value: 0.0939595:  10%|█         | 2/20 [00:12<01:57,  6.51s/it]

Cross-validated error: 0.01102853134116372
[I 2025-10-16 17:06:24,171] Trial 1 finished with value: 0.09395949879742266 and parameters: {'lr': 0.003875683816825103, 'num_epochs': 9, 'batch_size': 256, 'num_neighbors': 6, 'lr_decay': 0.9248082223936485}. Best is trial 1 with value: 0.09395949879742266.


Best trial: 1. Best value: 0.0939595:  15%|█▌        | 3/20 [00:18<01:45,  6.22s/it]

Cross-validated error: 0.010799526007503803
[I 2025-10-16 17:06:30,040] Trial 2 finished with value: 0.09313629289273094 and parameters: {'lr': 0.0021855030856577583, 'num_epochs': 3, 'batch_size': 256, 'num_neighbors': 5, 'lr_decay': 0.9150524748614448}. Best is trial 1 with value: 0.09395949879742266.


Best trial: 1. Best value: 0.0939595:  20%|██        | 4/20 [00:25<01:41,  6.32s/it]

Cross-validated error: 0.010585216217405494
[I 2025-10-16 17:06:36,515] Trial 3 finished with value: 0.09227768192646636 and parameters: {'lr': 0.00016849674669940746, 'num_epochs': 9, 'batch_size': 64, 'num_neighbors': 13, 'lr_decay': 0.9434256602068696}. Best is trial 1 with value: 0.09395949879742266.


Best trial: 1. Best value: 0.0939595:  25%|██▌       | 5/20 [00:31<01:34,  6.28s/it]

Cross-validated error: 0.010654090585898494
[I 2025-10-16 17:06:42,722] Trial 4 finished with value: 0.09256937518812156 and parameters: {'lr': 0.003397928301396086, 'num_epochs': 3, 'batch_size': 256, 'num_neighbors': 10, 'lr_decay': 0.9537859883796826}. Best is trial 1 with value: 0.09395949879742266.


Best trial: 1. Best value: 0.0939595:  30%|███       | 6/20 [00:36<01:23,  5.99s/it]

Cross-validated error: 0.010927432976723507
[I 2025-10-16 17:06:48,156] Trial 5 finished with value: 0.09358020873722957 and parameters: {'lr': 0.0027837786843711147, 'num_epochs': 1, 'batch_size': 512, 'num_neighbors': 8, 'lr_decay': 0.8948892660899135}. Best is trial 1 with value: 0.09395949879742266.


Best trial: 1. Best value: 0.0939595:  35%|███▌      | 7/20 [00:42<01:16,  5.89s/it]

Cross-validated error: 0.010783272845199768
[I 2025-10-16 17:06:53,837] Trial 6 finished with value: 0.0930344994049413 and parameters: {'lr': 0.0012911311311009216, 'num_epochs': 9, 'batch_size': 256, 'num_neighbors': 8, 'lr_decay': 0.8508598390720506}. Best is trial 1 with value: 0.09395949879742266.


Best trial: 1. Best value: 0.0939595:  40%|████      | 8/20 [00:48<01:10,  5.84s/it]

Cross-validated error: 0.009962233852027148
[I 2025-10-16 17:06:59,579] Trial 7 finished with value: 0.08982619288911885 and parameters: {'lr': 0.005206694903545054, 'num_epochs': 10, 'batch_size': 256, 'num_neighbors': 12, 'lr_decay': 0.9711564718839419}. Best is trial 1 with value: 0.09395949879742266.


Best trial: 1. Best value: 0.0939595:  45%|████▌     | 9/20 [00:54<01:05,  5.98s/it]

Cross-validated error: 0.010888943357687898
[I 2025-10-16 17:07:05,860] Trial 8 finished with value: 0.0934593785018975 and parameters: {'lr': 0.00021627705551389734, 'num_epochs': 7, 'batch_size': 256, 'num_neighbors': 6, 'lr_decay': 0.8976274458664016}. Best is trial 1 with value: 0.09395949879742266.


Best trial: 1. Best value: 0.0939595:  50%|█████     | 10/20 [01:00<00:59,  5.96s/it]

Cross-validated error: 0.01091208847706827
[I 2025-10-16 17:07:11,770] Trial 9 finished with value: 0.09355320420144217 and parameters: {'lr': 0.0005659260945697457, 'num_epochs': 3, 'batch_size': 256, 'num_neighbors': 8, 'lr_decay': 0.8068749566742087}. Best is trial 1 with value: 0.09395949879742266.


Best trial: 1. Best value: 0.0939595:  55%|█████▌    | 11/20 [01:06<00:53,  6.00s/it]

Cross-validated error: 0.009505637258464945
[I 2025-10-16 17:07:17,863] Trial 10 finished with value: 0.08798581456130682 and parameters: {'lr': 0.008190430438806753, 'num_epochs': 6, 'batch_size': 128, 'num_neighbors': 15, 'lr_decay': 0.9996743390505459}. Best is trial 1 with value: 0.09395949879742266.


Best trial: 1. Best value: 0.0939595:  60%|██████    | 12/20 [01:12<00:47,  5.93s/it]

Cross-validated error: 0.010863629603771345
[I 2025-10-16 17:07:23,630] Trial 11 finished with value: 0.09335426884895193 and parameters: {'lr': 0.0012328467051262184, 'num_epochs': 1, 'batch_size': 512, 'num_neighbors': 5, 'lr_decay': 0.8763655593131847}. Best is trial 1 with value: 0.09395949879742266.


Best trial: 1. Best value: 0.0939595:  65%|██████▌   | 13/20 [01:18<00:41,  5.86s/it]

Cross-validated error: 0.010609840324794065
[I 2025-10-16 17:07:29,334] Trial 12 finished with value: 0.09238305810521633 and parameters: {'lr': 0.002332997884040731, 'num_epochs': 1, 'batch_size': 512, 'num_neighbors': 3, 'lr_decay': 0.875403661221258}. Best is trial 1 with value: 0.09395949879742266.


Best trial: 1. Best value: 0.0939595:  70%|███████   | 14/20 [01:24<00:35,  5.87s/it]

Cross-validated error: 0.010763672190570353
[I 2025-10-16 17:07:35,230] Trial 13 finished with value: 0.09294060401320588 and parameters: {'lr': 0.009670715768200175, 'num_epochs': 5, 'batch_size': 512, 'num_neighbors': 7, 'lr_decay': 0.8396882621972929}. Best is trial 1 with value: 0.09395949879742266.


Best trial: 1. Best value: 0.0939595:  75%|███████▌  | 15/20 [01:30<00:29,  5.90s/it]

Cross-validated error: 0.010813077541262399
[I 2025-10-16 17:07:41,184] Trial 14 finished with value: 0.0931506976909659 and parameters: {'lr': 0.0004761523097434946, 'num_epochs': 5, 'batch_size': 128, 'num_neighbors': 10, 'lr_decay': 0.8956614024889251}. Best is trial 1 with value: 0.09395949879742266.


Best trial: 1. Best value: 0.0939595:  80%|████████  | 16/20 [01:36<00:23,  5.93s/it]

Cross-validated error: 0.010452067348046198
[I 2025-10-16 17:07:47,194] Trial 15 finished with value: 0.09176955431319211 and parameters: {'lr': 0.002249476816990836, 'num_epochs': 8, 'batch_size': 512, 'num_neighbors': 3, 'lr_decay': 0.9208502277941439}. Best is trial 1 with value: 0.09395949879742266.


Best trial: 1. Best value: 0.0939595:  85%|████████▌ | 17/20 [01:41<00:17,  5.93s/it]

Cross-validated error: 0.010943715364529047
[I 2025-10-16 17:07:53,139] Trial 16 finished with value: 0.09363926528752421 and parameters: {'lr': 0.0007044775507599226, 'num_epochs': 4, 'batch_size': 512, 'num_neighbors': 9, 'lr_decay': 0.9632321509451643}. Best is trial 1 with value: 0.09395949879742266.


Best trial: 1. Best value: 0.0939595:  90%|█████████ | 18/20 [01:48<00:11,  5.97s/it]

Cross-validated error: 0.010745320900562245
[I 2025-10-16 17:07:59,189] Trial 17 finished with value: 0.09289823407597186 and parameters: {'lr': 0.0006241708023400658, 'num_epochs': 4, 'batch_size': 128, 'num_neighbors': 10, 'lr_decay': 0.9799413544529334}. Best is trial 1 with value: 0.09395949879742266.


Best trial: 1. Best value: 0.0939595:  95%|█████████▌| 19/20 [01:54<00:06,  6.04s/it]

Cross-validated error: 0.010677348784391407
[I 2025-10-16 17:08:05,382] Trial 18 finished with value: 0.09267838394170018 and parameters: {'lr': 0.0003097195262357522, 'num_epochs': 6, 'batch_size': 64, 'num_neighbors': 5, 'lr_decay': 0.9676514983757013}. Best is trial 1 with value: 0.09395949879742266.


Best trial: 1. Best value: 0.0939595: 100%|██████████| 20/20 [02:00<00:00,  6.02s/it]

Cross-validated error: 0.01083483255138326
[I 2025-10-16 17:08:11,640] Trial 19 finished with value: 0.09324683745850322 and parameters: {'lr': 0.0009138723824906459, 'num_epochs': 10, 'batch_size': 512, 'num_neighbors': 6, 'lr_decay': 0.9931585626036461}. Best is trial 1 with value: 0.09395949879742266.





{'gini': np.float64(0.6526563208778949), 'ess': np.float64(1909.7452846800886), 'max_wi': np.float64(70.86693977832296), 'min_wi': np.float64(0.005222348795357526)}


[I 2025-10-16 17:08:18,828] A new study created in memory with name: no-name-8f7b5498-30ef-47bd-ae1d-9eca89b9ee9f
Best trial: 0. Best value: 0.0996762:   5%|▌         | 1/20 [00:30<09:40, 30.54s/it]

Cross-validated error: 0.012632933641719595
[I 2025-10-16 17:08:49,370] Trial 0 finished with value: 0.09967624187799504 and parameters: {'lr': 0.003875683816825103, 'num_epochs': 9, 'batch_size': 256, 'num_neighbors': 6, 'lr_decay': 0.9248082223936485}. Best is trial 0 with value: 0.09967624187799504.


Best trial: 0. Best value: 0.0996762:  10%|█         | 2/20 [00:59<08:51, 29.54s/it]

Cross-validated error: 0.007527955525383548
[I 2025-10-16 17:09:18,202] Trial 1 finished with value: 0.0792368290642977 and parameters: {'lr': 0.0009509239081206644, 'num_epochs': 2, 'batch_size': 256, 'num_neighbors': 14, 'lr_decay': 0.9996485895058593}. Best is trial 0 with value: 0.09967624187799504.


Best trial: 0. Best value: 0.0996762:  15%|█▌        | 3/20 [01:28<08:21, 29.50s/it]

Cross-validated error: 0.006712556371092115
[I 2025-10-16 17:09:47,659] Trial 2 finished with value: 0.07519139316584827 and parameters: {'lr': 0.006821780205659142, 'num_epochs': 6, 'batch_size': 64, 'num_neighbors': 12, 'lr_decay': 0.8443935859142442}. Best is trial 0 with value: 0.09967624187799504.


Best trial: 0. Best value: 0.0996762:  20%|██        | 4/20 [02:00<08:05, 30.32s/it]

Cross-validated error: 0.008368368165779972
[I 2025-10-16 17:10:19,230] Trial 3 finished with value: 0.08308227404155503 and parameters: {'lr': 0.001826663913806818, 'num_epochs': 9, 'batch_size': 64, 'num_neighbors': 11, 'lr_decay': 0.9844240227584448}. Best is trial 0 with value: 0.09967624187799504.


Best trial: 0. Best value: 0.0996762:  25%|██▌       | 5/20 [02:29<07:30, 30.04s/it]

Cross-validated error: 0.010512766967316962
[I 2025-10-16 17:10:48,789] Trial 4 finished with value: 0.09202951770679843 and parameters: {'lr': 0.007966060544349863, 'num_epochs': 6, 'batch_size': 512, 'num_neighbors': 7, 'lr_decay': 0.8394869419831281}. Best is trial 0 with value: 0.09967624187799504.


Best trial: 0. Best value: 0.0996762:  30%|███       | 6/20 [02:59<07:00, 30.01s/it]

Cross-validated error: 0.007544737745963051
[I 2025-10-16 17:11:18,743] Trial 5 finished with value: 0.07930393028795989 and parameters: {'lr': 0.00017048261663612978, 'num_epochs': 8, 'batch_size': 128, 'num_neighbors': 9, 'lr_decay': 0.970484776644635}. Best is trial 0 with value: 0.09967624187799504.


Best trial: 0. Best value: 0.0996762:  35%|███▌      | 7/20 [03:28<06:22, 29.45s/it]

Cross-validated error: 0.007505093564921077
[I 2025-10-16 17:11:47,043] Trial 6 finished with value: 0.0791364255818348 and parameters: {'lr': 0.00020005711078602904, 'num_epochs': 4, 'batch_size': 256, 'num_neighbors': 5, 'lr_decay': 0.9312429643151606}. Best is trial 0 with value: 0.09967624187799504.


Best trial: 0. Best value: 0.0996762:  40%|████      | 8/20 [03:56<05:47, 29.00s/it]

Cross-validated error: 0.008589959938695779
[I 2025-10-16 17:12:15,061] Trial 7 finished with value: 0.08406614169613996 and parameters: {'lr': 0.009116009967335173, 'num_epochs': 4, 'batch_size': 512, 'num_neighbors': 7, 'lr_decay': 0.9985245943082409}. Best is trial 0 with value: 0.09967624187799504.


Best trial: 0. Best value: 0.0996762:  45%|████▌     | 9/20 [04:25<05:19, 29.03s/it]

Cross-validated error: 0.008187831296400879
[I 2025-10-16 17:12:44,177] Trial 8 finished with value: 0.08226020567843188 and parameters: {'lr': 0.007546533178143631, 'num_epochs': 8, 'batch_size': 256, 'num_neighbors': 9, 'lr_decay': 0.8423364020156331}. Best is trial 0 with value: 0.09967624187799504.


Best trial: 0. Best value: 0.0996762:  50%|█████     | 10/20 [04:54<04:51, 29.13s/it]

Cross-validated error: 0.00799676262216814
[I 2025-10-16 17:13:13,518] Trial 9 finished with value: 0.08145562339372343 and parameters: {'lr': 0.0018982723980748284, 'num_epochs': 5, 'batch_size': 64, 'num_neighbors': 14, 'lr_decay': 0.8092361904264403}. Best is trial 0 with value: 0.09967624187799504.


Best trial: 0. Best value: 0.0996762:  55%|█████▌    | 11/20 [05:24<04:25, 29.48s/it]

Cross-validated error: 0.008162626590416597
[I 2025-10-16 17:13:43,804] Trial 10 finished with value: 0.08217402448522942 and parameters: {'lr': 0.0005676073392218347, 'num_epochs': 10, 'batch_size': 128, 'num_neighbors': 3, 'lr_decay': 0.9035205660026691}. Best is trial 0 with value: 0.09967624187799504.


Best trial: 0. Best value: 0.0996762:  60%|██████    | 12/20 [05:54<03:55, 29.48s/it]

Cross-validated error: 0.008432154927201507
[I 2025-10-16 17:14:13,263] Trial 11 finished with value: 0.0833903886170925 and parameters: {'lr': 0.003324436474686724, 'num_epochs': 7, 'batch_size': 512, 'num_neighbors': 6, 'lr_decay': 0.8802633132438644}. Best is trial 0 with value: 0.09967624187799504.


Best trial: 0. Best value: 0.0996762:  65%|██████▌   | 13/20 [06:24<03:27, 29.62s/it]

Cross-validated error: 0.008647918700489908
[I 2025-10-16 17:14:43,218] Trial 12 finished with value: 0.0842848471636279 and parameters: {'lr': 0.003619483164475682, 'num_epochs': 10, 'batch_size': 512, 'num_neighbors': 4, 'lr_decay': 0.9222172884247815}. Best is trial 0 with value: 0.09967624187799504.


Best trial: 0. Best value: 0.0996762:  70%|███████   | 14/20 [06:52<02:54, 29.02s/it]

Cross-validated error: 0.007668173112740878
[I 2025-10-16 17:15:10,847] Trial 13 finished with value: 0.07991328582243003 and parameters: {'lr': 0.0038862173999790707, 'num_epochs': 1, 'batch_size': 512, 'num_neighbors': 7, 'lr_decay': 0.8761656391482968}. Best is trial 0 with value: 0.09967624187799504.


Best trial: 0. Best value: 0.0996762:  75%|███████▌  | 15/20 [07:21<02:26, 29.20s/it]

Cross-validated error: 0.007663581440423834
[I 2025-10-16 17:15:40,473] Trial 14 finished with value: 0.0798927180162575 and parameters: {'lr': 0.0004762459345350074, 'num_epochs': 6, 'batch_size': 256, 'num_neighbors': 8, 'lr_decay': 0.9502371737354407}. Best is trial 0 with value: 0.09967624187799504.


Best trial: 0. Best value: 0.0996762:  80%|████████  | 16/20 [07:52<01:58, 29.71s/it]

Cross-validated error: 0.007975817497490746
[I 2025-10-16 17:16:11,347] Trial 15 finished with value: 0.08133071567468056 and parameters: {'lr': 0.0017328453270741072, 'num_epochs': 8, 'batch_size': 512, 'num_neighbors': 11, 'lr_decay': 0.8107793711525745}. Best is trial 0 with value: 0.09967624187799504.


Best trial: 0. Best value: 0.0996762:  85%|████████▌ | 17/20 [08:20<01:27, 29.24s/it]

Cross-validated error: 0.01024256703951846
[I 2025-10-16 17:16:39,513] Trial 16 finished with value: 0.09091966781974208 and parameters: {'lr': 0.004889684505800363, 'num_epochs': 3, 'batch_size': 256, 'num_neighbors': 5, 'lr_decay': 0.8582406631649373}. Best is trial 0 with value: 0.09967624187799504.


Best trial: 0. Best value: 0.0996762:  90%|█████████ | 18/20 [08:50<00:58, 29.36s/it]

Cross-validated error: 0.00731201329697008
[I 2025-10-16 17:17:09,132] Trial 17 finished with value: 0.07813494514706532 and parameters: {'lr': 0.009948470894561983, 'num_epochs': 7, 'batch_size': 128, 'num_neighbors': 3, 'lr_decay': 0.9049422944812956}. Best is trial 0 with value: 0.09967624187799504.


Best trial: 0. Best value: 0.0996762:  95%|█████████▌| 19/20 [09:20<00:29, 29.46s/it]

Cross-validated error: 0.009329398718995276
[I 2025-10-16 17:17:38,837] Trial 18 finished with value: 0.08728159712257799 and parameters: {'lr': 0.0024876503538617716, 'num_epochs': 9, 'batch_size': 512, 'num_neighbors': 7, 'lr_decay': 0.9477811258696589}. Best is trial 0 with value: 0.09967624187799504.


Best trial: 0. Best value: 0.0996762: 100%|██████████| 20/20 [09:49<00:00, 29.49s/it]

Cross-validated error: 0.00800593878718246
[I 2025-10-16 17:18:08,623] Trial 19 finished with value: 0.08147830000328403 and parameters: {'lr': 0.0011241262385970154, 'num_epochs': 5, 'batch_size': 256, 'num_neighbors': 5, 'lr_decay': 0.880146177851786}. Best is trial 0 with value: 0.09967624187799504.





{'gini': np.float64(0.9897484519610177), 'ess': np.float64(91.92746578546355), 'max_wi': np.float64(414.91493458804086), 'min_wi': np.float64(1.9253317247576076e-19)}


[I 2025-10-16 17:18:41,577] A new study created in memory with name: no-name-4ab782a1-3603-41b8-aa65-c9a6dec5a73c
Best trial: 0. Best value: 0.0982203:   5%|▌         | 1/20 [01:02<19:41, 62.20s/it]

Cross-validated error: 0.012205247860988424
[I 2025-10-16 17:19:43,773] Trial 0 finished with value: 0.09822026625400156 and parameters: {'lr': 0.003875683816825103, 'num_epochs': 9, 'batch_size': 256, 'num_neighbors': 6, 'lr_decay': 0.9248082223936485}. Best is trial 0 with value: 0.09822026625400156.


Best trial: 0. Best value: 0.0982203:  10%|█         | 2/20 [01:58<17:35, 58.63s/it]

Cross-validated error: 0.007275528202115583
[I 2025-10-16 17:20:39,900] Trial 1 finished with value: 0.0780084851866848 and parameters: {'lr': 0.003988203035153607, 'num_epochs': 3, 'batch_size': 512, 'num_neighbors': 11, 'lr_decay': 0.8911707391431788}. Best is trial 0 with value: 0.09822026625400156.


Best trial: 0. Best value: 0.0982203:  15%|█▌        | 3/20 [02:56<16:33, 58.46s/it]

Cross-validated error: 0.00737371438120997
[I 2025-10-16 17:21:38,158] Trial 2 finished with value: 0.07843021339387583 and parameters: {'lr': 0.004154131873993655, 'num_epochs': 2, 'batch_size': 256, 'num_neighbors': 5, 'lr_decay': 0.8465708932264654}. Best is trial 0 with value: 0.09822026625400156.


Best trial: 0. Best value: 0.0982203:  20%|██        | 4/20 [03:54<15:32, 58.28s/it]

Cross-validated error: 0.008008567437752127
[I 2025-10-16 17:22:36,159] Trial 3 finished with value: 0.08147587822189922 and parameters: {'lr': 0.000101547403038134, 'num_epochs': 7, 'batch_size': 256, 'num_neighbors': 4, 'lr_decay': 0.8476336543263958}. Best is trial 0 with value: 0.09822026625400156.


Best trial: 0. Best value: 0.0982203:  25%|██▌       | 5/20 [04:52<14:33, 58.22s/it]

Cross-validated error: 0.008081883987705063
[I 2025-10-16 17:23:34,269] Trial 4 finished with value: 0.08181351542935927 and parameters: {'lr': 0.00022412242161083155, 'num_epochs': 5, 'batch_size': 64, 'num_neighbors': 3, 'lr_decay': 0.996310611330101}. Best is trial 0 with value: 0.09822026625400156.


Best trial: 0. Best value: 0.0982203:  30%|███       | 6/20 [05:50<13:31, 57.94s/it]

Cross-validated error: 0.00758284542278853
[I 2025-10-16 17:24:31,661] Trial 5 finished with value: 0.07950131658364695 and parameters: {'lr': 0.0008457940524388892, 'num_epochs': 6, 'batch_size': 256, 'num_neighbors': 7, 'lr_decay': 0.9797225477802021}. Best is trial 0 with value: 0.09822026625400156.


Best trial: 0. Best value: 0.0982203:  35%|███▌      | 7/20 [06:45<12:23, 57.22s/it]

Cross-validated error: 0.008072421118690274
[I 2025-10-16 17:25:27,392] Trial 6 finished with value: 0.08178024827089249 and parameters: {'lr': 0.0007094884807033015, 'num_epochs': 1, 'batch_size': 128, 'num_neighbors': 6, 'lr_decay': 0.9587672640223927}. Best is trial 0 with value: 0.09822026625400156.


Best trial: 0. Best value: 0.0982203:  40%|████      | 8/20 [07:46<11:41, 58.42s/it]

Cross-validated error: 0.0055764698208383345
[I 2025-10-16 17:26:28,392] Trial 7 finished with value: 0.06902558347268986 and parameters: {'lr': 0.0006885537224663438, 'num_epochs': 8, 'batch_size': 64, 'num_neighbors': 7, 'lr_decay': 0.9658852638219069}. Best is trial 0 with value: 0.09822026625400156.


Best trial: 0. Best value: 0.0982203:  45%|████▌     | 9/20 [08:41<10:29, 57.24s/it]

Cross-validated error: 0.005869079903171023
[I 2025-10-16 17:27:23,047] Trial 8 finished with value: 0.07064855053574398 and parameters: {'lr': 0.007131238040326657, 'num_epochs': 2, 'batch_size': 512, 'num_neighbors': 3, 'lr_decay': 0.8115609170823834}. Best is trial 0 with value: 0.09822026625400156.


Best trial: 0. Best value: 0.0982203:  50%|█████     | 10/20 [09:35<09:22, 56.23s/it]

Cross-validated error: 0.008006044007694871
[I 2025-10-16 17:28:17,017] Trial 9 finished with value: 0.08147598184233444 and parameters: {'lr': 0.001444035032346878, 'num_epochs': 2, 'batch_size': 512, 'num_neighbors': 11, 'lr_decay': 0.9207146901518635}. Best is trial 0 with value: 0.09822026625400156.


Best trial: 0. Best value: 0.0982203:  55%|█████▌    | 11/20 [10:35<08:36, 57.42s/it]

Cross-validated error: 0.0034421718366943995
[I 2025-10-16 17:29:17,138] Trial 10 finished with value: 0.05516033956712753 and parameters: {'lr': 0.002442432471597029, 'num_epochs': 10, 'batch_size': 128, 'num_neighbors': 14, 'lr_decay': 0.9172802350527366}. Best is trial 0 with value: 0.09822026625400156.


Best trial: 0. Best value: 0.0982203:  60%|██████    | 12/20 [11:31<07:35, 56.98s/it]

Cross-validated error: 0.008075807437255268
[I 2025-10-16 17:30:13,121] Trial 11 finished with value: 0.08177888150007857 and parameters: {'lr': 0.00020695308068370297, 'num_epochs': 4, 'batch_size': 64, 'num_neighbors': 3, 'lr_decay': 0.9890886392871586}. Best is trial 0 with value: 0.09822026625400156.


Best trial: 0. Best value: 0.0982203:  65%|██████▌   | 13/20 [12:34<06:51, 58.82s/it]

Cross-validated error: 0.006780893737039556
[I 2025-10-16 17:31:16,151] Trial 12 finished with value: 0.07559640517800946 and parameters: {'lr': 0.00033986387713497867, 'num_epochs': 9, 'batch_size': 64, 'num_neighbors': 9, 'lr_decay': 0.9418486754134497}. Best is trial 0 with value: 0.09822026625400156.


Best trial: 0. Best value: 0.0982203:  70%|███████   | 14/20 [13:31<05:50, 58.39s/it]

Cross-validated error: 0.0026516131825224747
[I 2025-10-16 17:32:13,573] Trial 13 finished with value: 0.04879196732103506 and parameters: {'lr': 0.009167169796048508, 'num_epochs': 5, 'batch_size': 64, 'num_neighbors': 8, 'lr_decay': 0.8833730125632304}. Best is trial 0 with value: 0.09822026625400156.


Best trial: 0. Best value: 0.0982203:  75%|███████▌  | 15/20 [14:30<04:52, 58.43s/it]

Cross-validated error: 0.008002888506138492
[I 2025-10-16 17:33:12,099] Trial 14 finished with value: 0.08144977544493913 and parameters: {'lr': 0.00036439999860197564, 'num_epochs': 6, 'batch_size': 256, 'num_neighbors': 5, 'lr_decay': 0.99774022605034}. Best is trial 0 with value: 0.09822026625400156.


Best trial: 0. Best value: 0.0982203:  80%|████████  | 16/20 [15:32<03:57, 59.39s/it]

Cross-validated error: 0.008004351908104178
[I 2025-10-16 17:34:13,696] Trial 15 finished with value: 0.08146146487175185 and parameters: {'lr': 0.0001168087065683276, 'num_epochs': 8, 'batch_size': 64, 'num_neighbors': 5, 'lr_decay': 0.9412776469215947}. Best is trial 0 with value: 0.09822026625400156.


Best trial: 0. Best value: 0.0982203:  85%|████████▌ | 17/20 [16:33<02:59, 59.94s/it]

Cross-validated error: 0.005505982452271491
[I 2025-10-16 17:35:14,908] Trial 16 finished with value: 0.0686875817698573 and parameters: {'lr': 0.0016106060436547748, 'num_epochs': 10, 'batch_size': 256, 'num_neighbors': 10, 'lr_decay': 0.8579265306085876}. Best is trial 0 with value: 0.09822026625400156.


Best trial: 0. Best value: 0.0982203:  90%|█████████ | 18/20 [17:31<01:58, 59.32s/it]

Cross-validated error: 0.008099456625972766
[I 2025-10-16 17:36:12,797] Trial 17 finished with value: 0.08188513555356025 and parameters: {'lr': 0.00039670437691385153, 'num_epochs': 5, 'batch_size': 128, 'num_neighbors': 3, 'lr_decay': 0.9177429717913402}. Best is trial 0 with value: 0.09822026625400156.


Best trial: 0. Best value: 0.0982203:  95%|█████████▌| 19/20 [18:30<00:59, 59.25s/it]

Cross-validated error: 0.004019543274036053
[I 2025-10-16 17:37:11,896] Trial 18 finished with value: 0.05940811293530281 and parameters: {'lr': 0.0029723822333769267, 'num_epochs': 7, 'batch_size': 128, 'num_neighbors': 15, 'lr_decay': 0.9213278346525696}. Best is trial 0 with value: 0.09822026625400156.


Best trial: 0. Best value: 0.0982203: 100%|██████████| 20/20 [19:27<00:00, 58.38s/it]

Cross-validated error: 0.008033223167107561
[I 2025-10-16 17:38:09,233] Trial 19 finished with value: 0.08159551450993047 and parameters: {'lr': 0.0004342580251017464, 'num_epochs': 4, 'batch_size': 128, 'num_neighbors': 6, 'lr_decay': 0.8776398389590361}. Best is trial 0 with value: 0.09822026625400156.





{'gini': np.float64(0.9941550718974186), 'ess': np.float64(62.94008576478465), 'max_wi': np.float64(390.83934708898147), 'min_wi': np.float64(2.120327915211682e-24)}


Unnamed: 0,policy_rewards,ipw,reg_dm,conv_dm,conv_dr,conv_sndr,action_diff_to_real,action_delta,context_diff_to_real,context_delta
0,0.08610747,0.0841,0.08415214,0.08978912,0.08653484,0.08653484,0.7569287,0.0,0.87627132,0.0
500,0.08610881,0.08769545,0.10020986,0.11468845,0.08882124,0.08882153,0.75691521,0.00063361,0.87626985,0.00031709
1000,0.08644733,0.08604138,0.08545258,0.08254102,0.08668264,0.08669657,0.75982833,0.09204378,0.87816584,0.04237359
2000,0.08694974,0.08625748,0.09133763,0.1030264,0.08723947,0.08708389,0.76900623,0.14652035,0.88062576,0.06360372
10000,0.07158604,0.04788352,0.06588676,0.05610879,0.05627977,0.05629828,0.88200747,0.38703708,1.06824968,0.32475498
20000,0.07569903,0.09758643,0.09521935,0.13363498,0.10557451,0.10549701,0.8831732,0.38378032,1.14827889,0.41905272


### Policy with delta function

In [7]:
df4[['policy_rewards', 'ipw', 'reg_dm', 'conv_dm', 'conv_dr', 'conv_sndr', 'action_diff_to_real', 'action_delta', 'context_diff_to_real', 'context_delta']]

Unnamed: 0,policy_rewards,ipw,reg_dm,conv_dm,conv_dr,conv_sndr,action_diff_to_real,action_delta,context_diff_to_real,context_delta
0,0.08610747,0.0841,0.08415214,0.08978912,0.08653484,0.08653484,0.7569287,0.0,0.87627132,0.0
500,0.08610881,0.08769545,0.10020986,0.11468845,0.08882124,0.08882153,0.75691521,0.00063361,0.87626985,0.00031709
1000,0.08644733,0.08604138,0.08545258,0.08254102,0.08668264,0.08669657,0.75982833,0.09204378,0.87816584,0.04237359
2000,0.08694974,0.08625748,0.09133763,0.1030264,0.08723947,0.08708389,0.76900623,0.14652035,0.88062576,0.06360372
10000,0.07158604,0.04788352,0.06588676,0.05610879,0.05627977,0.05629828,0.88200747,0.38703708,1.06824968,0.32475498
20000,0.07569903,0.09758643,0.09521935,0.13363498,0.10557451,0.10549701,0.8831732,0.38378032,1.14827889,0.41905272


### Poicy Via argmax(r_hat - error_hat) through cross validation

In [8]:
df4[['policy_rewards', 'ipw', 'reg_dm', 'conv_dm', 'conv_dr', 'conv_sndr', 'action_diff_to_real', 'action_delta', 'context_diff_to_real', 'context_delta']]

Unnamed: 0,policy_rewards,ipw,reg_dm,conv_dm,conv_dr,conv_sndr,action_diff_to_real,action_delta,context_diff_to_real,context_delta
0,0.08610747,0.0841,0.08415214,0.08978912,0.08653484,0.08653484,0.7569287,0.0,0.87627132,0.0
500,0.08610881,0.08769545,0.10020986,0.11468845,0.08882124,0.08882153,0.75691521,0.00063361,0.87626985,0.00031709
1000,0.08644733,0.08604138,0.08545258,0.08254102,0.08668264,0.08669657,0.75982833,0.09204378,0.87816584,0.04237359
2000,0.08694974,0.08625748,0.09133763,0.1030264,0.08723947,0.08708389,0.76900623,0.14652035,0.88062576,0.06360372
10000,0.07158604,0.04788352,0.06588676,0.05610879,0.05627977,0.05629828,0.88200747,0.38703708,1.06824968,0.32475498
20000,0.07569903,0.09758643,0.09521935,0.13363498,0.10557451,0.10549701,0.8831732,0.38378032,1.14827889,0.41905272


### Policy Via using actual policy value

In [9]:
# Show the performance metrics
df4[['policy_rewards', 'ipw', 'reg_dm', 'conv_dm', 'conv_dr', 'conv_sndr', 'action_diff_to_real', 'action_delta', 'context_diff_to_real', 'context_delta']]


Unnamed: 0,policy_rewards,ipw,reg_dm,conv_dm,conv_dr,conv_sndr,action_diff_to_real,action_delta,context_diff_to_real,context_delta
0,0.08610747,0.0841,0.08415214,0.08978912,0.08653484,0.08653484,0.7569287,0.0,0.87627132,0.0
500,0.08610881,0.08769545,0.10020986,0.11468845,0.08882124,0.08882153,0.75691521,0.00063361,0.87626985,0.00031709
1000,0.08644733,0.08604138,0.08545258,0.08254102,0.08668264,0.08669657,0.75982833,0.09204378,0.87816584,0.04237359
2000,0.08694974,0.08625748,0.09133763,0.1030264,0.08723947,0.08708389,0.76900623,0.14652035,0.88062576,0.06360372
10000,0.07158604,0.04788352,0.06588676,0.05610879,0.05627977,0.05629828,0.88200747,0.38703708,1.06824968,0.32475498
20000,0.07569903,0.09758643,0.09521935,0.13363498,0.10557451,0.10549701,0.8831732,0.38378032,1.14827889,0.41905272
