In [1]:
import warnings
warnings.filterwarnings("ignore")
from copy import deepcopy
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import sys

sys.path.append("/code")

from tqdm import tqdm
import torch
# device = torch.device('cpu')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# import gym
# import recogym

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim

torch.backends.cudnn.benchmark = torch.cuda.is_available()
if torch.cuda.is_available():
    torch.set_float32_matmul_precision("high")  # TF32 = big speedup on Ada


from sklearn.utils import check_random_state

# implementing OPE of the IPWLearner using synthetic bandit data
from sklearn.linear_model import LogisticRegression

import matplotlib.pyplot as plt

from scipy.special import softmax
import optuna
# from memory_profiler import profile


from estimators import (
    DirectMethod as DM
)

from simulation_utils import (
    eval_policy,
    generate_dataset,
    create_simulation_data_from_pi,
    get_train_data,
    get_opl_results_dict,
    CustomCFDataset,
    calc_reward,
    get_weights_info
)

from models import (    
    LinearCFModel,
    NeighborhoodModel,
    BPRModel, 
    RegressionModel
)

from training_utils import (
    train,
    validation_loop, 
    cv_score_model
 )

from custom_losses import (
    SNDRPolicyLoss,
    IPWPolicyLoss
    )

random_state=12345
random_ = check_random_state(random_state)

pd.options.display.float_format = '{:,.8f}'.format

Using device: cpu
Using device: cpu
Using device: cpu


In [2]:
def get_trial_results(
    our_x, 
    our_a, 
    emb_x, 
    emb_a, 
    original_x, 
    original_a, 
    dataset, 
    val_data, 
    original_policy_prob, 
    neighberhoodmodel, 
    regression_model, 
    dm
):
    policy = np.expand_dims(softmax(our_x @ our_a.T, axis=1), -1)
    policy_reward = calc_reward(dataset, policy)
    eval_metrics = eval_policy(neighberhoodmodel, val_data, original_policy_prob, policy)
    action_diff_to_real = np.sqrt(np.mean((emb_a - our_a) ** 2))
    action_delta = np.sqrt(np.mean((original_a - our_a) ** 2))
    context_diff_to_real = np.sqrt(np.mean((emb_x - our_x) ** 2))
    context_delta = np.sqrt(np.mean((original_x - our_x) ** 2))

    row = np.concatenate([
        np.atleast_1d(policy_reward),
        np.atleast_1d(eval_metrics),
        np.atleast_1d(action_diff_to_real),
        np.atleast_1d(action_delta),
        np.atleast_1d(context_diff_to_real),
        np.atleast_1d(context_delta)
    ])
    reg_dm = dm.estimate_policy_value(policy[val_data['x_idx']], regression_model.predict(val_data['x']))
    reg_results = np.array([reg_dm])
    conv_results = np.array([row])
    return get_opl_results_dict(reg_results, conv_results)

## `trainer_trial` Function

This function runs policy learning experiments using offline bandit data and evaluates various estimators.

### Parameters
- **num_runs** (int): Number of experimental runs per training size
- **num_neighbors** (int): Number of neighbors to consider in the neighborhood model
- **num_rounds_list** (list): List of training set sizes to evaluate
- **dataset** (dict): Contains dataset information including embeddings, action probabilities, and reward probabilities
- **batch_size** (int): Batch size for training the policy model
- **num_epochs** (int): Number of training epochs for each experiment
- **lr** (float, default=0.001): Learning rate for the optimizer

### Process Flow
1. Initializes result structures and retrieval models
2. For each training size in `num_rounds_list`:
   - Creates a uniform logging policy and simulates data
   - Generates training data for offline learning
   - Fits regression and neighborhood models for reward estimation
   - Initializes and trains a counterfactual policy model
   - Evaluates policy performance using various estimators
   - Collects metrics on policy reward and embedding quality

### Returns
- **DataFrame**: Results table with rows indexed by training size and columns for various metrics:
  - `policy_rewards`: True expected reward of the learned policy
  - Various estimator errors (`ipw`, `reg_dm`, `conv_dm`, `conv_dr`, `conv_sndr`)
  - Variance metrics for each estimator
  - Embedding quality metrics comparing learned representations to ground truth

### Implementation Notes
- Uses uniform random logging policy for collecting offline data
- Employs Self-Normalized Doubly Robust (SNDR) policy learning
- Measures embedding quality via RMSE to original/ground truth embeddings

In [3]:
def trainer_trial(
    num_runs,
    num_neighbors,
    train_sizes,
    dataset,
    batch_size,
    val_size=2000,
    n_trials=10,    
    prev_best_params=None
):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.benchmark = torch.cuda.is_available()
    if torch.cuda.is_available():
        torch.set_float32_matmul_precision("high")

    dm = DM()
    results = {}

    our_x, our_a = dataset["our_x"], dataset["our_a"]
    emb_x, emb_a = dataset["emb_x"], dataset["emb_a"]

    original_x, original_a = dataset["original_x"], dataset["original_a"]
    n_users, n_actions, emb_dim = dataset["n_users"], dataset["n_actions"], dataset["emb_dim"]

    all_user_indices = np.arange(n_users, dtype=np.int64)

    def T(x):
        return torch.as_tensor(x, device=device, dtype=torch.float32)

    def _mean_dict(dicts):
        """
        Robust mean over a list of dicts with numeric/scalar/1D-array values.
        Returns a single dict with elementwise means.
        """
        if not dicts:
            return {}
        keys = dicts[0].keys()
        out = {}
        for k in keys:
            vals = [d[k] for d in dicts if k in d]
            # try to convert each to np.array and average
            arrs = [np.asarray(v) for v in vals]
            # broadcast to same shape if scalars/1D
            stacked = np.stack(arrs, axis=0)
            out[k] = np.mean(stacked, axis=0)
        return out

    # ===== unpack dataset (keep originals safe) =====
    our_x_orig, our_a_orig = our_x, our_a
    emb_x, emb_a = emb_x, emb_a
    original_x, original_a = original_x, original_a
    n_users, n_actions, emb_dim = n_users, n_actions, emb_dim
    all_user_indices = np.arange(n_users, dtype=np.int64)

    dm = DM()
    results = {}
    best_hyperparams_by_size = {}
    last_best_params = prev_best_params if prev_best_params is not None else None

    # ===== baseline (sample size = 0) using get_trial_results =====
    pi_0 = softmax(our_x_orig @ our_a_orig.T, axis=1)
    original_policy_prob = np.expand_dims(pi_0, -1)

    simulation_data = create_simulation_data_from_pi(
        dataset, pi_0, val_size, random_state=0
    )

    # use same data for train/val just to generate the baseline row
    train_data = get_train_data(n_actions, val_size, simulation_data, np.arange(val_size), our_x_orig)
    val_data   = get_train_data(n_actions, val_size, simulation_data, np.arange(val_size), our_x_orig)

    regression_model = RegressionModel(
        n_actions=n_actions, action_context=our_x_orig,
        base_model=LogisticRegression(random_state=12345)
    )

    regression_model.fit(train_data['x'], train_data['a'], train_data['r'])

    neighberhoodmodel = NeighborhoodModel(
        train_data['x_idx'], train_data['a'],
        our_a_orig, our_x_orig, train_data['r'],
        num_neighbors=num_neighbors
    )

    # baseline row produced via get_trial_results
    results[0] = get_trial_results(
        our_x_orig, our_a_orig, emb_x, emb_a, original_x, original_a,
        dataset, val_data, original_policy_prob,
        neighberhoodmodel, regression_model, dm
    )

    # ===== main loop over training sizes =====
    for train_size in train_sizes:

        # we’ll collect per-run trial dicts generated by get_trial_results
        trial_dicts_this_size = []
        best_hyperparams_by_size[train_size] = {}

        # --- prepare a resampling for Optuna’s objective (shared loaders built per-run inside objective) ---
        # We’ll do Optuna per-run (fresh resample + search), then final fit with best params, then get_trial_results.

        for run in range(num_runs):

            # --- resample for this run ---
            pi_0 = softmax(our_x_orig @ our_a_orig.T, axis=1)
            original_policy_prob = np.expand_dims(pi_0, -1)

            simulation_data = create_simulation_data_from_pi(
                dataset, pi_0, train_size + val_size,
                random_state=(run + 1) * (train_size + 17)
            )

            idx_train = np.arange(train_size)
            train_data = get_train_data(n_actions, train_size, simulation_data, idx_train, our_x_orig)
            val_idx   = np.arange(val_size) + train_size
            val_data  = get_train_data(n_actions, val_size, simulation_data, val_idx, our_x_orig)

            num_workers = 4 if torch.cuda.is_available() else 0

            cf_dataset = CustomCFDataset(
                train_data['x_idx'], train_data['a'], train_data['r'], original_policy_prob
            )

            # val_loader = DataLoader(
            #     val_dataset, batch_size=val_size, shuffle=False,
            #     pin_memory=torch.cuda.is_available(),
            #     num_workers=num_workers, persistent_workers=bool(num_workers)
            # )


            # --- Optuna objective bound to this run's data ---
            def objective(trial):
                lr = trial.suggest_float("lr", 1e-4, 1e-1, log=True)
                epochs = trial.suggest_int("num_epochs", 1, 10)
                trial_batch_size = trial.suggest_categorical("batch_size", [64, 128, 256, 512])
                trial_num_neighbors = trial.suggest_int("num_neighbors", 3, 15)
                lr_decay = trial.suggest_float("lr_decay", 0.8, 1.0)

                trial_neigh_model = NeighborhoodModel(
                    train_data['x_idx'], train_data['a'],
                    our_a_orig, our_x_orig, train_data['r'],
                    num_neighbors=trial_num_neighbors
                )

                trial_scores_all = torch.as_tensor(
                    trial_neigh_model.predict(all_user_indices),
                    device=device, dtype=torch.float32
                )

                trial_model = LinearCFModel(
                    n_users, n_actions, emb_dim,
                    initial_user_embeddings=T(our_x_orig),
                    initial_actions_embeddings=T(our_a_orig)
                ).to(device)

                assert (not torch.cuda.is_available()) or next(trial_model.parameters()).is_cuda

                final_train_loader = DataLoader(
                    cf_dataset, batch_size=trial_batch_size, shuffle=True,
                    pin_memory=torch.cuda.is_available(),
                    num_workers=num_workers, persistent_workers=bool(num_workers)
                )

                current_lr = lr
                for epoch in range(epochs):
                    if epoch > 0:
                        current_lr *= lr_decay
                        
                    train(
                        trial_model, final_train_loader, trial_scores_all,
                        criterion=SNDRPolicyLoss(), num_epochs=1, lr=current_lr, device=str(device)
                    )

                trial_x, trial_a = trial_model.get_params()
                trial_x = trial_x.detach().cpu().numpy()
                trial_a = trial_a.detach().cpu().numpy()

                pi_i = softmax(trial_x @ trial_a.T, axis=1)
                train_actions = train_data['a']
                train_users = train_data['x_idx']

                print("Train wi info: {}".format(get_weights_info(pi_i[train_users, train_actions], original_policy_prob[train_users, train_actions])))
                print(f"actual reward: {calc_reward(dataset, np.expand_dims(pi_i, -1))}")

                # print(get_weights_info(pi_i, original_policy_prob))
                # validation reward for selection
                return cv_score_model(val_data, trial_scores_all, pi_i)


            # --- run Optuna for this run ---
            study = optuna.create_study(direction="maximize")
            
            if last_best_params is not None:
                study.enqueue_trial(last_best_params)

            study.optimize(objective, n_trials=n_trials, show_progress_bar=True)

            best_params = study.best_params
            last_best_params = best_params  # optional warm-start to next run
            best_hyperparams_by_size[train_size][run] = {
                "params": best_params,
                "reward": study.best_value
            }


            # --- final training with best params on this run’s data ---
            regression_model = RegressionModel(
                n_actions=n_actions, action_context=our_x_orig,
                base_model=LogisticRegression(random_state=12345)
            )
            regression_model.fit(
                train_data['x'], train_data['a'], train_data['r'],
                original_policy_prob[train_data['x_idx'], train_data['a']].squeeze()
            )

            neighberhoodmodel = NeighborhoodModel(
                train_data['x_idx'], train_data['a'],
                our_a_orig, our_x_orig, train_data['r'],
                num_neighbors=best_params['num_neighbors']
            )
            scores_all = torch.as_tensor(
                neighberhoodmodel.predict(all_user_indices),
                device=device, dtype=torch.float32
            )

            model = LinearCFModel(
                n_users, n_actions, emb_dim,
                initial_user_embeddings=T(our_x_orig),
                initial_actions_embeddings=T(our_a_orig)
            ).to(device)
            assert (not torch.cuda.is_available()) or next(model.parameters()).is_cuda

            train_loader = DataLoader(
                cf_dataset, batch_size=batch_size, shuffle=True,
                pin_memory=torch.cuda.is_available(),
                num_workers=num_workers, persistent_workers=bool(num_workers)
            )

            current_lr = best_params['lr']
            for epoch in range(best_params['num_epochs']):
                if epoch > 0:
                    current_lr *= best_params['lr_decay']
                train(
                    model, train_loader, scores_all,
                    criterion=SNDRPolicyLoss(), num_epochs=1, lr=current_lr, device=str(device)
                )

            # learned embeddings (do NOT overwrite originals)
            learned_x_t, learned_a_t = model.get_params()
            learned_x = learned_x_t.detach().cpu().numpy()
            learned_a = learned_a_t.detach().cpu().numpy()

            # --- produce the per-run result via get_trial_results ---
            trial_res = get_trial_results(
                learned_x, learned_a,          # learned (policy) embeddings
                emb_x, emb_a,                  # ground-truth embedding refs
                original_x, original_a,        # original clean refs
                dataset,
                val_data,                      # use this run's val split
                original_policy_prob,
                neighberhoodmodel,
                regression_model,
                dm
            )

            trial_dicts_this_size.append(trial_res)

            # memory hygiene
            torch.cuda.empty_cache()

        # === aggregate per-run results (mean) and store under this train_size ===
        results[train_size] = _mean_dict(trial_dicts_this_size)

    return pd.DataFrame.from_dict(results, orient='index'), best_hyperparams_by_size

## Learning

We will run several simulations on a generated dataset, the dataset is generated like this:
$$ \text{We have users U and actions A } u_i \sim N(0, I_{emb_dim}) \ a_i \sim N(0, I_{emb_dim})$$
$$ p_{ij} = 1 / (5 + e^{-(u_i.T a_j)}) $$
$$r_{ij} \sim Bin(p_{ij})$$

We have a policy $\pi$
and it's ground truth reward is calculated by
$$R_{gt} = \sum_{i}{\sum_{j}{\pi_{ij} * p_{ij}}} $$

Our parameters for the dataset will be
$$EmbDim = 5$$
$$NumActions= 150$$
$$NumUsers = 150$$
$$NeighborhoodSize = 6$$

to learn a new policy from $\pi$ we will sample from:
$$\pi_{start} = (1-\epsilon)*\pi + \epsilon * \pi_{random}$$

In [4]:
dataset_params = dict(
                    n_actions= 500,
                    n_users = 500,
                    emb_dim = 16,
                    # sigma = 0.1,
                    eps = 0.6, # this is the epsilon for the noise in the ground truth policy representation
                    ctr = 0.1
                    )

train_dataset = generate_dataset(dataset_params)

Random Item CTR: 0.07066414727263938
Optimal greedy CTR: 0.09999926940951757
Optimal Stochastic CTR: 0.09995326955796031
Our Initial CTR: 0.08610747363354625


In [5]:
num_runs = 1
batch_size = 200
num_neighbors = 6
n_trials_for_optuna = 20
# num_rounds_list = [500, 1000, 2000, 10000, 20000]
# num_rounds_list = [500, 1000, 2000]
num_rounds_list = [15000]


# Manually define your best parameters
best_params_to_use = {
    "lr": 0.096,  # Learning rate
    "num_epochs": 5,  # Number of training epochs
    "batch_size": 64,  # Batch size for training
    "num_neighbors": 8,  # Number of neighbors for neighborhood model
    "lr_decay": 0.85  # Learning rate decay factor
}

### 1

$$emb = 0.7 * gt + 0.3 * noise$$
$$lr = 0.005$$
$$n_{epochs} = 1$$
$$BatchSize=50$$

In [6]:
print("Value of num_rounds_list:", num_rounds_list)

# Run the optimization
df4, best_hyperparams_by_size = trainer_trial(num_runs, num_neighbors, num_rounds_list, train_dataset, batch_size, val_size=10000, n_trials=n_trials_for_optuna, prev_best_params=best_params_to_use)

# # Print best hyperparameters for each training size
# print("\n=== BEST HYPERPARAMETERS BY TRAINING SIZE ===")
# for train_size, params in best_hyperparams_by_size.items():
#     print(f"\nTraining Size: {train_size}")
#     # print(f"Best Reward: {params['reward']:.6f}")
#     print("Parameters:")
#     for param_name, value in params['params'].items():
#         print(f"  {param_name}: {value}")
# print("===========================\n")

# Show the performance metrics
df4[['policy_rewards', 'ipw', 'reg_dm', 'conv_dm', 'conv_dr', 'conv_sndr', 'action_diff_to_real', 'action_delta', 'context_diff_to_real', 'context_delta']]

Value of num_rounds_list: [15000]
{'gini': np.float64(0.47165337067628044), 'ess': np.float64(4299.836390949015), 'max_wi': np.float64(39.02207562454071), 'min_wi': np.float64(0.008541729602724986)}


[I 2025-10-27 23:20:50,001] A new study created in memory with name: no-name-f4656db7-3a4d-4f33-9869-96b29d82a165
Best trial: 0. Best value: 0.0719627:   5%|▌         | 1/20 [00:44<14:11, 44.82s/it]

Train wi info: {'gini': np.float64(0.999646163854872), 'ess': np.float64(5.634283506987809), 'max_wi': np.float64(4346.13005127459), 'min_wi': np.float64(0.0)}
actual reward: [0.07844657]
{'gini': np.float64(0.9998023537765777), 'ess': np.float64(2.012897323932942), 'max_wi': np.float64(2038.673415495716), 'min_wi': np.float64(0.0)}
Cross-validated error: 0.006261881785689511
[I 2025-10-27 23:21:34,824] Trial 0 finished with value: 0.07196269215221371 and parameters: {'lr': 0.096, 'num_epochs': 5, 'batch_size': 64, 'num_neighbors': 8, 'lr_decay': 0.85}. Best is trial 0 with value: 0.07196269215221371.


Best trial: 1. Best value: 0.0736587:  10%|█         | 2/20 [01:25<12:41, 42.30s/it]

Train wi info: {'gini': np.float64(0.028013777796080327), 'ess': np.float64(14942.23837489102), 'max_wi': np.float64(1.489931280462149), 'min_wi': np.float64(0.8713135516788827)}
actual reward: [0.08612014]
{'gini': np.float64(0.02875486793061097), 'ess': np.float64(9960.921018362127), 'max_wi': np.float64(1.497317319184339), 'min_wi': np.float64(0.8880170187116933)}
Cross-validated error: 0.008057314749798879
[I 2025-10-27 23:22:15,352] Trial 1 finished with value: 0.07365868194019859 and parameters: {'lr': 0.0001784325070080417, 'num_epochs': 1, 'batch_size': 64, 'num_neighbors': 10, 'lr_decay': 0.9664030861479147}. Best is trial 1 with value: 0.07365868194019859.


Best trial: 1. Best value: 0.0736587:  15%|█▌        | 3/20 [02:11<12:29, 44.10s/it]

Train wi info: {'gini': np.float64(0.9982718078647653), 'ess': np.float64(16.406924896301945), 'max_wi': np.float64(2043.126811124678), 'min_wi': np.float64(2.624658209574201e-20)}
actual reward: [0.08426919]
{'gini': np.float64(0.9974472388423422), 'ess': np.float64(13.442121639275893), 'max_wi': np.float64(2615.782486156559), 'min_wi': np.float64(2.624658209574201e-20)}
Cross-validated error: 0.0068137255589685326
[I 2025-10-27 23:23:01,599] Trial 2 finished with value: 0.06880207932853438 and parameters: {'lr': 0.0030300466585572376, 'num_epochs': 10, 'batch_size': 64, 'num_neighbors': 11, 'lr_decay': 0.9178755706819265}. Best is trial 1 with value: 0.07365868194019859.


Best trial: 1. Best value: 0.0736587:  20%|██        | 4/20 [02:53<11:30, 43.14s/it]

Train wi info: {'gini': np.float64(0.9516007405871872), 'ess': np.float64(764.2710529577054), 'max_wi': np.float64(67.50171657595634), 'min_wi': np.float64(1.1640488525496517e-06)}
actual reward: [0.08488388]
{'gini': np.float64(0.9457891597526437), 'ess': np.float64(569.0095902279055), 'max_wi': np.float64(65.1363097860034), 'min_wi': np.float64(1.1640488525496517e-06)}
Cross-validated error: 0.0066917742696349595
[I 2025-10-27 23:23:43,273] Trial 3 finished with value: 0.06841183046344562 and parameters: {'lr': 0.0011604593022526127, 'num_epochs': 4, 'batch_size': 64, 'num_neighbors': 5, 'lr_decay': 0.8001066714399241}. Best is trial 1 with value: 0.07365868194019859.


Best trial: 1. Best value: 0.0736587:  25%|██▌       | 5/20 [03:37<10:51, 43.44s/it]

Train wi info: {'gini': np.float64(0.999716757225525), 'ess': np.float64(4.6525141298387345), 'max_wi': np.float64(3112.272076835479), 'min_wi': np.float64(0.0)}
actual reward: [0.07306385]
{'gini': np.float64(0.9994972563128569), 'ess': np.float64(5.5170300969073605), 'max_wi': np.float64(3718.904760495078), 'min_wi': np.float64(0.0)}
Cross-validated error: 0.02091266838024902
[I 2025-10-27 23:24:27,251] Trial 4 finished with value: 0.07307017718618636 and parameters: {'lr': 0.08283322579563636, 'num_epochs': 7, 'batch_size': 64, 'num_neighbors': 4, 'lr_decay': 0.8329708193443458}. Best is trial 1 with value: 0.07365868194019859.


Best trial: 1. Best value: 0.0736587:  30%|███       | 6/20 [04:20<10:08, 43.48s/it]

Train wi info: {'gini': np.float64(0.9988930002133581), 'ess': np.float64(4.6025273392822195), 'max_wi': np.float64(3394.41307443081), 'min_wi': np.float64(1.7586501483582938e-23)}
actual reward: [0.08090076]
{'gini': np.float64(0.9987147128790252), 'ess': np.float64(7.112884605194537), 'max_wi': np.float64(3200.458763877934), 'min_wi': np.float64(1.7586501483582938e-23)}
Cross-validated error: 0.00725089982778735
[I 2025-10-27 23:25:10,792] Trial 5 finished with value: 0.07050256984842869 and parameters: {'lr': 0.004774968667111028, 'num_epochs': 8, 'batch_size': 128, 'num_neighbors': 13, 'lr_decay': 0.8839700642855196}. Best is trial 1 with value: 0.07365868194019859.


Best trial: 1. Best value: 0.0736587:  35%|███▌      | 7/20 [05:01<09:12, 42.50s/it]

Train wi info: {'gini': np.float64(0.008625747589182133), 'ess': np.float64(14996.042262887267), 'max_wi': np.float64(1.0666170708342868), 'min_wi': np.float64(0.9558198127172892)}
actual reward: [0.08610849]
{'gini': np.float64(0.008658193562141299), 'ess': np.float64(9997.394113933127), 'max_wi': np.float64(1.0631425401195993), 'min_wi': np.float64(0.9618070189884889)}
Cross-validated error: 0.008055792011055829
[I 2025-10-27 23:25:51,296] Trial 6 finished with value: 0.07363396745792353 and parameters: {'lr': 0.0002288722987805932, 'num_epochs': 1, 'batch_size': 512, 'num_neighbors': 10, 'lr_decay': 0.9185230051782117}. Best is trial 1 with value: 0.07365868194019859.


Best trial: 1. Best value: 0.0736587:  40%|████      | 8/20 [05:44<08:32, 42.67s/it]

Train wi info: {'gini': np.float64(0.42191704375776184), 'ess': np.float64(5954.405834959825), 'max_wi': np.float64(14.811794247744334), 'min_wi': np.float64(0.10385820006986456)}
actual reward: [0.08621009]
{'gini': np.float64(0.42682005695900527), 'ess': np.float64(4095.2844838021956), 'max_wi': np.float64(14.594650744578928), 'min_wi': np.float64(0.11993940608197849)}
Cross-validated error: 0.007967497873624609
[I 2025-10-27 23:26:34,327] Trial 7 finished with value: 0.0733420237420281 and parameters: {'lr': 0.00039454433834270377, 'num_epochs': 4, 'batch_size': 64, 'num_neighbors': 12, 'lr_decay': 0.8909306427942534}. Best is trial 1 with value: 0.07365868194019859.


Best trial: 1. Best value: 0.0736587:  45%|████▌     | 9/20 [06:25<07:44, 42.20s/it]

Train wi info: {'gini': np.float64(0.8320464294885157), 'ess': np.float64(1454.5074192646216), 'max_wi': np.float64(32.34096746243747), 'min_wi': np.float64(0.0020207409128657366)}
actual reward: [0.08592418]
{'gini': np.float64(0.824332780733891), 'ess': np.float64(1081.7060590049923), 'max_wi': np.float64(31.19681233816564), 'min_wi': np.float64(0.002222469582100586)}
Cross-validated error: 0.00760512203283906
[I 2025-10-27 23:27:15,486] Trial 8 finished with value: 0.07200508627191557 and parameters: {'lr': 0.002020744913068975, 'num_epochs': 2, 'batch_size': 128, 'num_neighbors': 4, 'lr_decay': 0.8681905479087392}. Best is trial 1 with value: 0.07365868194019859.


Best trial: 1. Best value: 0.0736587:  50%|█████     | 10/20 [07:08<07:04, 42.45s/it]

Train wi info: {'gini': np.float64(0.050382877366140756), 'ess': np.float64(14856.049443782935), 'max_wi': np.float64(1.426215308806622), 'min_wi': np.float64(0.7608962249275689)}
actual reward: [0.08611805]
{'gini': np.float64(0.05058027601070978), 'ess': np.float64(9906.116198348738), 'max_wi': np.float64(1.4116994320031295), 'min_wi': np.float64(0.7966706989077212)}
Cross-validated error: 0.008016779263998335
[I 2025-10-27 23:27:58,497] Trial 9 finished with value: 0.07349679620665248 and parameters: {'lr': 0.000287027535270829, 'num_epochs': 7, 'batch_size': 512, 'num_neighbors': 12, 'lr_decay': 0.8437038421879548}. Best is trial 1 with value: 0.07365868194019859.


Best trial: 1. Best value: 0.0736587:  55%|█████▌    | 11/20 [07:49<06:17, 41.94s/it]

Train wi info: {'gini': np.float64(0.9804375321131571), 'ess': np.float64(269.4141167847223), 'max_wi': np.float64(161.27297418120907), 'min_wi': np.float64(4.723632745745147e-22)}
actual reward: [0.07948651]
{'gini': np.float64(0.9792946188704468), 'ess': np.float64(136.18961858197446), 'max_wi': np.float64(534.2573745959935), 'min_wi': np.float64(4.723632745745147e-22)}
Cross-validated error: 0.007267622673981876
[I 2025-10-27 23:28:39,280] Trial 10 finished with value: 0.07072033709818717 and parameters: {'lr': 0.015787716907415966, 'num_epochs': 2, 'batch_size': 256, 'num_neighbors': 15, 'lr_decay': 0.9922384298699587}. Best is trial 1 with value: 0.07365868194019859.


Best trial: 1. Best value: 0.0736587:  60%|██████    | 12/20 [08:28<05:29, 41.18s/it]

Train wi info: {'gini': np.float64(0.003861764146141051), 'ess': np.float64(14999.199446225231), 'max_wi': np.float64(1.0311485790848745), 'min_wi': np.float64(0.9816246192289989)}
actual reward: [0.08610767]
{'gini': np.float64(0.003885033097592153), 'ess': np.float64(9999.470519999039), 'max_wi': np.float64(1.028472049228845), 'min_wi': np.float64(0.9818373550511265)}
Cross-validated error: 0.008034956715162191
[I 2025-10-27 23:29:18,734] Trial 11 finished with value: 0.07356906723141071 and parameters: {'lr': 0.00010413038608694127, 'num_epochs': 1, 'batch_size': 512, 'num_neighbors': 8, 'lr_decay': 0.9432538417740688}. Best is trial 1 with value: 0.07365868194019859.


Best trial: 1. Best value: 0.0736587:  65%|██████▌   | 13/20 [09:08<04:44, 40.68s/it]

Train wi info: {'gini': np.float64(0.003943755823395173), 'ess': np.float64(14999.175671081253), 'max_wi': np.float64(1.0288176543502097), 'min_wi': np.float64(0.9814096247427344)}
actual reward: [0.08610831]
{'gini': np.float64(0.003933924586927), 'ess': np.float64(9999.462113939735), 'max_wi': np.float64(1.0280366559586702), 'min_wi': np.float64(0.9831282577441992)}
Cross-validated error: 0.008050995242648094
[I 2025-10-27 23:29:58,261] Trial 12 finished with value: 0.07361979572589737 and parameters: {'lr': 0.00010443214681943732, 'num_epochs': 1, 'batch_size': 512, 'num_neighbors': 10, 'lr_decay': 0.9662089324614493}. Best is trial 1 with value: 0.07365868194019859.


Best trial: 13. Best value: 0.0738197:  70%|███████   | 14/20 [09:49<04:04, 40.77s/it]

Train wi info: {'gini': np.float64(0.08687797910312388), 'ess': np.float64(14530.057585560227), 'max_wi': np.float64(1.9685525263726993), 'min_wi': np.float64(0.646847316660086)}
actual reward: [0.08613653]
{'gini': np.float64(0.08743652585687996), 'ess': np.float64(9692.375373122963), 'max_wi': np.float64(1.9405222747170632), 'min_wi': np.float64(0.6697625066677111)}
Cross-validated error: 0.008108325177560459
[I 2025-10-27 23:30:39,228] Trial 13 finished with value: 0.07381969635252891 and parameters: {'lr': 0.00044999507782196666, 'num_epochs': 3, 'batch_size': 256, 'num_neighbors': 9, 'lr_decay': 0.9291422050111389}. Best is trial 13 with value: 0.07381969635252891.


Best trial: 13. Best value: 0.0738197:  75%|███████▌  | 15/20 [10:29<03:23, 40.74s/it]

Train wi info: {'gini': np.float64(0.131514635878677), 'ess': np.float64(13895.153475629153), 'max_wi': np.float64(2.508628018096653), 'min_wi': np.float64(0.48064181109723175)}
actual reward: [0.08614504]
{'gini': np.float64(0.13232202671224724), 'ess': np.float64(9280.7984290195), 'max_wi': np.float64(2.508628018096653), 'min_wi': np.float64(0.5334081269120884)}
Cross-validated error: 0.008046147227154559
[I 2025-10-27 23:31:19,899] Trial 14 finished with value: 0.07359114598960631 and parameters: {'lr': 0.0006053695193668276, 'num_epochs': 3, 'batch_size': 256, 'num_neighbors': 7, 'lr_decay': 0.958791402686086}. Best is trial 13 with value: 0.07381969635252891.


Best trial: 15. Best value: 0.074562:  80%|████████  | 16/20 [11:11<02:43, 40.90s/it] 

Train wi info: {'gini': np.float64(0.27260366635406064), 'ess': np.float64(10784.2099582877), 'max_wi': np.float64(4.4115403260528625), 'min_wi': np.float64(0.2422405471789508)}
actual reward: [0.08623391]
{'gini': np.float64(0.27507181871558195), 'ess': np.float64(7267.107401070962), 'max_wi': np.float64(4.4115403260528625), 'min_wi': np.float64(0.2539300337663526)}
Cross-validated error: 0.008323903735741515
[I 2025-10-27 23:32:01,187] Trial 15 finished with value: 0.07456198457555607 and parameters: {'lr': 0.0007524887036981202, 'num_epochs': 4, 'batch_size': 256, 'num_neighbors': 6, 'lr_decay': 0.9931112364834909}. Best is trial 15 with value: 0.07456198457555607.


Best trial: 16. Best value: 0.0752005:  85%|████████▌ | 17/20 [11:53<02:04, 41.36s/it]

Train wi info: {'gini': np.float64(0.5193168785119469), 'ess': np.float64(5420.646841082747), 'max_wi': np.float64(8.916186543051793), 'min_wi': np.float64(0.04364732457792955)}
actual reward: [0.08637723]
{'gini': np.float64(0.5159864481189825), 'ess': np.float64(3830.1461271376056), 'max_wi': np.float64(8.916186543051793), 'min_wi': np.float64(0.052527761771887005)}
Cross-validated error: 0.008506141899339736
[I 2025-10-27 23:32:43,604] Trial 16 finished with value: 0.07520054960867502 and parameters: {'lr': 0.000932010598449493, 'num_epochs': 5, 'batch_size': 256, 'num_neighbors': 6, 'lr_decay': 0.9997528517215528}. Best is trial 16 with value: 0.07520054960867502.


Best trial: 16. Best value: 0.0752005:  90%|█████████ | 18/20 [12:36<01:23, 41.88s/it]

Train wi info: {'gini': np.float64(0.9881483003564567), 'ess': np.float64(71.02398274739765), 'max_wi': np.float64(986.1077784422832), 'min_wi': np.float64(4.92661646304291e-25)}
actual reward: [0.08345529]
{'gini': np.float64(0.9820638694037278), 'ess': np.float64(83.10618708743004), 'max_wi': np.float64(802.6598463546668), 'min_wi': np.float64(4.92661646304291e-25)}
Cross-validated error: 0.007755821795799234
[I 2025-10-27 23:33:26,704] Trial 17 finished with value: 0.07242416093471817 and parameters: {'lr': 0.008068231278109579, 'num_epochs': 6, 'batch_size': 256, 'num_neighbors': 6, 'lr_decay': 0.9998847477183437}. Best is trial 16 with value: 0.07520054960867502.


Best trial: 18. Best value: 0.0752474:  95%|█████████▌| 19/20 [13:17<00:41, 41.66s/it]

Train wi info: {'gini': np.float64(0.5318939189491736), 'ess': np.float64(5139.410859868155), 'max_wi': np.float64(9.686144553157186), 'min_wi': np.float64(0.03906758043679743)}
actual reward: [0.08639666]
{'gini': np.float64(0.5291076997182486), 'ess': np.float64(3665.6947415644804), 'max_wi': np.float64(9.49802626945328), 'min_wi': np.float64(0.05051829570017257)}
Cross-validated error: 0.008508793611584187
[I 2025-10-27 23:34:07,858] Trial 18 finished with value: 0.0752473977356801 and parameters: {'lr': 0.0009746252643148638, 'num_epochs': 5, 'batch_size': 256, 'num_neighbors': 6, 'lr_decay': 0.9855017085376683}. Best is trial 18 with value: 0.0752473977356801.


Best trial: 19. Best value: 0.0775719: 100%|██████████| 20/20 [13:59<00:00, 41.97s/it]

Train wi info: {'gini': np.float64(0.8921142883907963), 'ess': np.float64(1376.812637891251), 'max_wi': np.float64(24.63087778020697), 'min_wi': np.float64(0.0002894263923630555)}
actual reward: [0.08591988]
{'gini': np.float64(0.882510728704491), 'ess': np.float64(1019.9244806439935), 'max_wi': np.float64(24.63087778020697), 'min_wi': np.float64(0.00028341453096947624)}
Cross-validated error: 0.009247601709633726
[I 2025-10-27 23:34:49,382] Trial 19 finished with value: 0.07757193181161719 and parameters: {'lr': 0.001462337473591599, 'num_epochs': 6, 'batch_size': 256, 'num_neighbors': 6, 'lr_decay': 0.9802496749954691}. Best is trial 19 with value: 0.07757193181161719.





{'gini': np.float64(0.9512420704631571), 'ess': np.float64(331.3872075421497), 'max_wi': np.float64(221.85069782106882), 'min_wi': np.float64(4.279189116368622e-06)}


Unnamed: 0,policy_rewards,ipw,reg_dm,conv_dm,conv_dr,conv_sndr,action_diff_to_real,action_delta,context_diff_to_real,context_delta
0,0.08610747,0.0879,0.08802531,0.09188019,0.08858942,0.08858942,0.7569287,0.0,0.87627132,0.0
15000,0.08516441,0.08017263,0.08906158,0.09463238,0.0795643,0.07993675,0.82602002,0.24147136,0.93809648,0.1280313


### Policy with delta function

In [7]:
dataset_params = dict(
                    n_actions= 500,
                    n_users = 500,
                    emb_dim = 16,
                    # sigma = 0.1,
                    eps = 0.6, # this is the epsilon for the noise in the ground truth policy representation
                    ctr = 0.1
                    )

train_dataset = generate_dataset(dataset_params, seed=10000)

Random Item CTR: 0.07083863592474163
Optimal greedy CTR: 0.09999916436977967
Optimal Stochastic CTR: 0.0999493542444427
Our Initial CTR: 0.08557719469284641


In [8]:
# Run the optimization
df5, best_hyperparams_by_size = trainer_trial(num_runs, num_neighbors, num_rounds_list, train_dataset, batch_size, val_size=10000, n_trials=n_trials_for_optuna, prev_best_params=best_params_to_use)

# Show the performance metrics
df5[['policy_rewards', 'ipw', 'reg_dm', 'conv_dm', 'conv_dr', 'conv_sndr', 'action_diff_to_real', 'action_delta', 'context_diff_to_real', 'context_delta']]

{'gini': np.float64(0.44923948239866207), 'ess': np.float64(4832.412848990686), 'max_wi': np.float64(21.27427255219379), 'min_wi': np.float64(0.014521519317996934)}


[I 2025-10-27 23:36:00,165] A new study created in memory with name: no-name-f0d13181-70b5-4258-b7cb-4aa8f1d8a78a
Best trial: 0. Best value: 0.0880204:   5%|▌         | 1/20 [00:41<13:00, 41.08s/it]

Train wi info: {'gini': np.float64(0.9996097069643727), 'ess': np.float64(5.830156890001412), 'max_wi': np.float64(3630.058135095379), 'min_wi': np.float64(0.0)}
actual reward: [0.07890449]
{'gini': np.float64(0.9994245666139087), 'ess': np.float64(6.375448159220158), 'max_wi': np.float64(1614.386753291461), 'min_wi': np.float64(0.0)}
Cross-validated error: 0.014929009492012396
[I 2025-10-27 23:36:41,239] Trial 0 finished with value: 0.08802041750421653 and parameters: {'lr': 0.096, 'num_epochs': 5, 'batch_size': 64, 'num_neighbors': 8, 'lr_decay': 0.85}. Best is trial 0 with value: 0.08802041750421653.


Best trial: 0. Best value: 0.0880204:  10%|█         | 2/20 [01:25<12:51, 42.86s/it]

Train wi info: {'gini': np.float64(0.9998221742259609), 'ess': np.float64(2.0554121861970804), 'max_wi': np.float64(8030.121123700266), 'min_wi': np.float64(1.935631442710909e-34)}
actual reward: [0.07658656]
{'gini': np.float64(0.9995960337073696), 'ess': np.float64(4.427802787389191), 'max_wi': np.float64(2116.049297312971), 'min_wi': np.float64(1.935631442710909e-34)}
Cross-validated error: 0.007664841123861136
[I 2025-10-27 23:37:25,351] Trial 1 finished with value: 0.07221835978638602 and parameters: {'lr': 0.01101759311492387, 'num_epochs': 10, 'batch_size': 64, 'num_neighbors': 8, 'lr_decay': 0.9727821458933623}. Best is trial 0 with value: 0.08802041750421653.


Best trial: 0. Best value: 0.0880204:  15%|█▌        | 3/20 [02:06<11:56, 42.17s/it]

Train wi info: {'gini': np.float64(0.9997789227114159), 'ess': np.float64(3.0725850325272037), 'max_wi': np.float64(6271.406574953424), 'min_wi': np.float64(9.153146478277448e-30)}
actual reward: [0.08406508]
{'gini': np.float64(0.9995891981730308), 'ess': np.float64(3.639147899284846), 'max_wi': np.float64(1850.6168897024766), 'min_wi': np.float64(9.153146478277448e-30)}
Cross-validated error: 0.0035633574131534156
[I 2025-10-27 23:38:06,690] Trial 2 finished with value: 0.05249527124745243 and parameters: {'lr': 0.01775729460301181, 'num_epochs': 8, 'batch_size': 256, 'num_neighbors': 14, 'lr_decay': 0.9183107984557436}. Best is trial 0 with value: 0.08802041750421653.


Best trial: 0. Best value: 0.0880204:  20%|██        | 4/20 [02:48<11:13, 42.07s/it]

Train wi info: {'gini': np.float64(0.9997981757582663), 'ess': np.float64(3.2572956630973233), 'max_wi': np.float64(7429.480689325076), 'min_wi': np.float64(1.3132053705355725e-27)}
actual reward: [0.07092805]
{'gini': np.float64(0.9998256015025703), 'ess': np.float64(1.8574713162099612), 'max_wi': np.float64(5734.713833880798), 'min_wi': np.float64(1.3132053705355725e-27)}
Cross-validated error: 0.004352277173409262
[I 2025-10-27 23:38:48,617] Trial 3 finished with value: 0.05720583119560163 and parameters: {'lr': 0.01585000095438304, 'num_epochs': 9, 'batch_size': 128, 'num_neighbors': 7, 'lr_decay': 0.9330634550689125}. Best is trial 0 with value: 0.08802041750421653.


Best trial: 0. Best value: 0.0880204:  25%|██▌       | 5/20 [03:27<10:16, 41.11s/it]

Train wi info: {'gini': np.float64(0.9998061903268869), 'ess': np.float64(2.452260421085143), 'max_wi': np.float64(8048.086841086151), 'min_wi': np.float64(0.0)}
actual reward: [0.07381496]
{'gini': np.float64(0.9995749176942778), 'ess': np.float64(4.318108320572713), 'max_wi': np.float64(2116.3323633481064), 'min_wi': np.float64(0.0)}
Cross-validated error: 0.006302089308507625
[I 2025-10-27 23:39:28,029] Trial 4 finished with value: 0.0671561833127756 and parameters: {'lr': 0.05977229735685543, 'num_epochs': 1, 'batch_size': 128, 'num_neighbors': 13, 'lr_decay': 0.9779748714424004}. Best is trial 0 with value: 0.08802041750421653.


Best trial: 0. Best value: 0.0880204:  30%|███       | 6/20 [04:08<09:34, 41.07s/it]

Train wi info: {'gini': np.float64(0.9997673202549525), 'ess': np.float64(3.7395928240634007), 'max_wi': np.float64(7465.851959773306), 'min_wi': np.float64(6.288761929120283e-39)}
actual reward: [0.06878844]
{'gini': np.float64(0.9998423603386505), 'ess': np.float64(1.6956905279022847), 'max_wi': np.float64(5982.93917962597), 'min_wi': np.float64(6.288761929120283e-39)}
Cross-validated error: 0.005754404940973547
[I 2025-10-27 23:40:09,018] Trial 5 finished with value: 0.06431713836081229 and parameters: {'lr': 0.02069766588674227, 'num_epochs': 4, 'batch_size': 64, 'num_neighbors': 11, 'lr_decay': 0.85749520391534}. Best is trial 0 with value: 0.08802041750421653.


Best trial: 0. Best value: 0.0880204:  35%|███▌      | 7/20 [04:48<08:48, 40.69s/it]

Train wi info: {'gini': np.float64(0.03873758610406583), 'ess': np.float64(14913.96076519477), 'max_wi': np.float64(1.284862896017546), 'min_wi': np.float64(0.8326211012094943)}
actual reward: [0.0855554]
{'gini': np.float64(0.039183420056355174), 'ess': np.float64(9943.751199318018), 'max_wi': np.float64(1.272258734590905), 'min_wi': np.float64(0.834552014236116)}
Cross-validated error: 0.00874183997902258
[I 2025-10-27 23:40:48,925] Trial 6 finished with value: 0.07600103571259906 and parameters: {'lr': 0.00037915720322300253, 'num_epochs': 2, 'batch_size': 512, 'num_neighbors': 8, 'lr_decay': 0.817520201847658}. Best is trial 0 with value: 0.08802041750421653.


Best trial: 0. Best value: 0.0880204:  40%|████      | 8/20 [05:29<08:08, 40.68s/it]

Train wi info: {'gini': np.float64(0.7247455852478126), 'ess': np.float64(3348.5575704395196), 'max_wi': np.float64(9.36648914283479), 'min_wi': np.float64(0.005153897639446858)}
actual reward: [0.08454799]
{'gini': np.float64(0.7060554608188887), 'ess': np.float64(2537.945598765795), 'max_wi': np.float64(9.959890071372758), 'min_wi': np.float64(0.005153897639446858)}
Cross-validated error: 0.008087418252725458
[I 2025-10-27 23:41:29,593] Trial 7 finished with value: 0.07375801680655454 and parameters: {'lr': 0.0021917862323487837, 'num_epochs': 4, 'batch_size': 512, 'num_neighbors': 7, 'lr_decay': 0.9750703736983615}. Best is trial 0 with value: 0.08802041750421653.


Best trial: 0. Best value: 0.0880204:  45%|████▌     | 9/20 [06:10<07:27, 40.67s/it]

Train wi info: {'gini': np.float64(0.9997885896366039), 'ess': np.float64(3.4779037635869576), 'max_wi': np.float64(6105.136662999787), 'min_wi': np.float64(2.5713355035456896e-22)}
actual reward: [0.08261077]
{'gini': np.float64(0.9998316551056639), 'ess': np.float64(1.8169423060075776), 'max_wi': np.float64(3052.6969699757096), 'min_wi': np.float64(4.868996511109384e-22)}
Cross-validated error: 0.0037497256542983167
[I 2025-10-27 23:42:10,227] Trial 8 finished with value: 0.053750366411064024 and parameters: {'lr': 0.006385506647510808, 'num_epochs': 4, 'batch_size': 64, 'num_neighbors': 15, 'lr_decay': 0.9801545344527838}. Best is trial 0 with value: 0.08802041750421653.


Best trial: 0. Best value: 0.0880204:  50%|█████     | 10/20 [06:49<06:43, 40.40s/it]

Train wi info: {'gini': np.float64(0.05261443771611639), 'ess': np.float64(14823.010738964289), 'max_wi': np.float64(1.4642686054685368), 'min_wi': np.float64(0.7689821991979146)}
actual reward: [0.08554556]
{'gini': np.float64(0.0538319301506782), 'ess': np.float64(9882.997588774599), 'max_wi': np.float64(1.4234424152433447), 'min_wi': np.float64(0.7867379671912955)}
Cross-validated error: 0.008486120205847495
[I 2025-10-27 23:42:50,018] Trial 9 finished with value: 0.07515376202377755 and parameters: {'lr': 0.00011350959687385877, 'num_epochs': 5, 'batch_size': 256, 'num_neighbors': 4, 'lr_decay': 0.9875671245545623}. Best is trial 0 with value: 0.08802041750421653.


Best trial: 0. Best value: 0.0880204:  55%|█████▌    | 11/20 [07:31<06:06, 40.69s/it]

Train wi info: {'gini': np.float64(0.9996894960934605), 'ess': np.float64(5.178527330550152), 'max_wi': np.float64(2450.672305660804), 'min_wi': np.float64(0.0)}
actual reward: [0.07614208]
{'gini': np.float64(0.9996097689387308), 'ess': np.float64(4.294471983567222), 'max_wi': np.float64(2500.9523165520104), 'min_wi': np.float64(0.0)}
Cross-validated error: 0.0027783953064609477
[I 2025-10-27 23:43:31,367] Trial 10 finished with value: 0.043901549692790955 and parameters: {'lr': 0.09172581510554555, 'num_epochs': 7, 'batch_size': 64, 'num_neighbors': 3, 'lr_decay': 0.8665171592778175}. Best is trial 0 with value: 0.08802041750421653.


Best trial: 0. Best value: 0.0880204:  60%|██████    | 12/20 [08:09<05:19, 39.90s/it]

Train wi info: {'gini': np.float64(0.05704472291800429), 'ess': np.float64(14814.861004597562), 'max_wi': np.float64(1.4022091086426696), 'min_wi': np.float64(0.7455716899357528)}
actual reward: [0.085543]
{'gini': np.float64(0.05789002215649096), 'ess': np.float64(9878.665995633575), 'max_wi': np.float64(1.3921201741356428), 'min_wi': np.float64(0.7455716899357528)}
Cross-validated error: 0.008641416240486935
[I 2025-10-27 23:44:09,473] Trial 11 finished with value: 0.07566946523124692 and parameters: {'lr': 0.0009361932470617366, 'num_epochs': 1, 'batch_size': 512, 'num_neighbors': 10, 'lr_decay': 0.8075558082281605}. Best is trial 0 with value: 0.08802041750421653.


Best trial: 0. Best value: 0.0880204:  65%|██████▌   | 13/20 [08:47<04:35, 39.37s/it]

Train wi info: {'gini': np.float64(0.028196249938825396), 'ess': np.float64(14954.533664930635), 'max_wi': np.float64(1.2057877735326674), 'min_wi': np.float64(0.8710695835499722)}
actual reward: [0.08556373]
{'gini': np.float64(0.028734656311319884), 'ess': np.float64(9969.938179387425), 'max_wi': np.float64(1.196003214679918), 'min_wi': np.float64(0.8889555510299361)}
Cross-validated error: 0.008439206473134127
[I 2025-10-27 23:44:47,631] Trial 12 finished with value: 0.07499377429277274 and parameters: {'lr': 0.000283284567065278, 'num_epochs': 2, 'batch_size': 512, 'num_neighbors': 5, 'lr_decay': 0.8005118061757704}. Best is trial 0 with value: 0.08802041750421653.


Best trial: 0. Best value: 0.0880204:  70%|███████   | 14/20 [09:27<03:57, 39.55s/it]

Train wi info: {'gini': np.float64(0.23946528830607874), 'ess': np.float64(12004.058434125738), 'max_wi': np.float64(3.135413632519627), 'min_wi': np.float64(0.29403539657464745)}
actual reward: [0.08543704]
{'gini': np.float64(0.239607865084295), 'ess': np.float64(8115.574015057899), 'max_wi': np.float64(2.953828325307407), 'min_wi': np.float64(0.29843985636030884)}
Cross-validated error: 0.00853092972192915
[I 2025-10-27 23:45:27,580] Trial 13 finished with value: 0.07529502266932368 and parameters: {'lr': 0.0007926021217960742, 'num_epochs': 6, 'batch_size': 512, 'num_neighbors': 11, 'lr_decay': 0.8398856336619419}. Best is trial 0 with value: 0.08802041750421653.


Best trial: 0. Best value: 0.0880204:  75%|███████▌  | 15/20 [10:07<03:17, 39.59s/it]

Train wi info: {'gini': np.float64(0.9578629621498209), 'ess': np.float64(580.8117489172511), 'max_wi': np.float64(199.70371818394543), 'min_wi': np.float64(9.912183762454372e-14)}
actual reward: [0.07709441]
{'gini': np.float64(0.9551114061596994), 'ess': np.float64(480.85985219797936), 'max_wi': np.float64(96.50221800902499), 'min_wi': np.float64(9.912183762454372e-14)}
Cross-validated error: 0.007274827483736512
[I 2025-10-27 23:46:07,255] Trial 14 finished with value: 0.07073273484761589 and parameters: {'lr': 0.0025648516344611767, 'num_epochs': 3, 'batch_size': 64, 'num_neighbors': 6, 'lr_decay': 0.8346031769761114}. Best is trial 0 with value: 0.08802041750421653.


Best trial: 0. Best value: 0.0880204:  80%|████████  | 16/20 [10:47<02:39, 39.87s/it]

Train wi info: {'gini': np.float64(0.12015631798584171), 'ess': np.float64(14167.194929351595), 'max_wi': np.float64(1.972309390631449), 'min_wi': np.float64(0.5469914849751779)}
actual reward: [0.08550308]
{'gini': np.float64(0.1215056717219775), 'ess': np.float64(9464.608714531183), 'max_wi': np.float64(1.8753515539863703), 'min_wi': np.float64(0.573490616720202)}
Cross-validated error: 0.008632155190170998
[I 2025-10-27 23:46:47,797] Trial 15 finished with value: 0.07564603778080586 and parameters: {'lr': 0.0004086088930482397, 'num_epochs': 6, 'batch_size': 512, 'num_neighbors': 9, 'lr_decay': 0.8877672012403806}. Best is trial 0 with value: 0.08802041750421653.


Best trial: 0. Best value: 0.0880204:  85%|████████▌ | 17/20 [11:26<01:58, 39.66s/it]

Train wi info: {'gini': np.float64(0.017460788073472004), 'ess': np.float64(14980.768657322667), 'max_wi': np.float64(1.1475887860201517), 'min_wi': np.float64(0.9082488058489442)}
actual reward: [0.08556565]
{'gini': np.float64(0.017932111447683672), 'ess': np.float64(9987.132620243545), 'max_wi': np.float64(1.135400465651514), 'min_wi': np.float64(0.9201269759291669)}
Cross-validated error: 0.00866964629243465
[I 2025-10-27 23:47:26,962] Trial 16 finished with value: 0.07576904456450605 and parameters: {'lr': 0.00010956490665697717, 'num_epochs': 2, 'batch_size': 256, 'num_neighbors': 9, 'lr_decay': 0.8295690962867562}. Best is trial 0 with value: 0.08802041750421653.


Best trial: 0. Best value: 0.0880204:  90%|█████████ | 18/20 [12:06<01:19, 39.56s/it]

Train wi info: {'gini': np.float64(0.9997600384730521), 'ess': np.float64(3.8738355791970456), 'max_wi': np.float64(2750.585804093999), 'min_wi': np.float64(0.0)}
actual reward: [0.07399289]
{'gini': np.float64(0.9997873517180412), 'ess': np.float64(2.271322604868116), 'max_wi': np.float64(2986.880633359096), 'min_wi': np.float64(0.0)}
Cross-validated error: 0.004453766543925651
[I 2025-10-27 23:48:06,290] Trial 17 finished with value: 0.058040625717275154 and parameters: {'lr': 0.04105245418168806, 'num_epochs': 3, 'batch_size': 128, 'num_neighbors': 12, 'lr_decay': 0.8821993832410265}. Best is trial 0 with value: 0.08802041750421653.


Best trial: 0. Best value: 0.0880204:  95%|█████████▌| 19/20 [12:47<00:40, 40.22s/it]

Train wi info: {'gini': np.float64(0.99979825330962), 'ess': np.float64(3.254814397571667), 'max_wi': np.float64(7462.3827316229335), 'min_wi': np.float64(5.825886907571239e-22)}
actual reward: [0.06863549]
{'gini': np.float64(0.9998413283055675), 'ess': np.float64(1.6998012443673802), 'max_wi': np.float64(5981.249174477704), 'min_wi': np.float64(6.470961990118268e-22)}
Cross-validated error: 0.0037177673982042455
[I 2025-10-27 23:48:48,029] Trial 18 finished with value: 0.05353168241433952 and parameters: {'lr': 0.004753717233942495, 'num_epochs': 7, 'batch_size': 64, 'num_neighbors': 6, 'lr_decay': 0.8172081970678516}. Best is trial 0 with value: 0.08802041750421653.


Best trial: 0. Best value: 0.0880204: 100%|██████████| 20/20 [13:27<00:00, 40.38s/it]

Train wi info: {'gini': np.float64(0.4280820350415511), 'ess': np.float64(7907.601050502945), 'max_wi': np.float64(5.327344295792663), 'min_wi': np.float64(0.08180966994520852)}
actual reward: [0.08524725]
{'gini': np.float64(0.42059953857286125), 'ess': np.float64(5569.96297740022), 'max_wi': np.float64(4.974736160629101), 'min_wi': np.float64(0.08180966994520852)}
Cross-validated error: 0.008630729749377779
[I 2025-10-27 23:49:27,734] Trial 19 finished with value: 0.07562763067245258 and parameters: {'lr': 0.0013445150172097717, 'num_epochs': 5, 'batch_size': 512, 'num_neighbors': 8, 'lr_decay': 0.8559553230590546}. Best is trial 0 with value: 0.08802041750421653.





{'gini': np.float64(0.9997772353395038), 'ess': np.float64(2.2784141545297385), 'max_wi': np.float64(5667.301413335771), 'min_wi': np.float64(0.0)}


Unnamed: 0,policy_rewards,ipw,reg_dm,conv_dm,conv_dr,conv_sndr,action_diff_to_real,action_delta,context_diff_to_real,context_delta
0,0.08557719,0.0856,0.08549812,0.08946431,0.08491724,0.08491724,0.82618217,0.0,0.99950468,0.0
15000,0.07726057,0.0,0.09523702,0.05551238,0.02801009,-0.04978863,1.84970676,1.68611378,1.44748668,0.75767119


In [9]:
dataset_params = dict(
                    n_actions= 500,
                    n_users = 500,
                    emb_dim = 16,
                    # sigma = 0.1,
                    eps = 0.6, # this is the epsilon for the noise in the ground truth policy representation
                    ctr = 0.1
                    )

train_dataset = generate_dataset(dataset_params, seed=20000)

Random Item CTR: 0.07042251854546815
Optimal greedy CTR: 0.09999934264692525
Optimal Stochastic CTR: 0.09996075464321043
Our Initial CTR: 0.08647580588501355


In [10]:
# Run the optimization
df6, best_hyperparams_by_size = trainer_trial(num_runs, num_neighbors, num_rounds_list, train_dataset, batch_size, val_size=10000, n_trials=n_trials_for_optuna, prev_best_params=best_params_to_use)

# Show the performance metrics
df6[['policy_rewards', 'ipw', 'reg_dm', 'conv_dm', 'conv_dr', 'conv_sndr', 'action_diff_to_real', 'action_delta', 'context_diff_to_real', 'context_delta']]

{'gini': np.float64(0.46252337713872943), 'ess': np.float64(4199.2626855415365), 'max_wi': np.float64(35.00371895533248), 'min_wi': np.float64(0.015165627151326161)}


[I 2025-10-27 23:50:35,917] A new study created in memory with name: no-name-e0614b86-a2cd-4ea9-aed5-a9cdae968b27
Best trial: 0. Best value: 0.0712219:   5%|▌         | 1/20 [00:41<13:08, 41.50s/it]

Train wi info: {'gini': np.float64(0.9994312030295652), 'ess': np.float64(9.401063447168319), 'max_wi': np.float64(2411.9660013041316), 'min_wi': np.float64(0.0)}
actual reward: [0.08270539]
{'gini': np.float64(0.998834130777689), 'ess': np.float64(12.449745870924417), 'max_wi': np.float64(2563.59741867418), 'min_wi': np.float64(0.0)}
Cross-validated error: 0.008157016065257488
[I 2025-10-27 23:51:17,412] Trial 0 finished with value: 0.07122186095641528 and parameters: {'lr': 0.096, 'num_epochs': 5, 'batch_size': 64, 'num_neighbors': 8, 'lr_decay': 0.85}. Best is trial 0 with value: 0.07122186095641528.


Best trial: 1. Best value: 0.0714719:  10%|█         | 2/20 [01:21<12:11, 40.64s/it]

Train wi info: {'gini': np.float64(0.1257473653317675), 'ess': np.float64(14086.590198829628), 'max_wi': np.float64(2.56468539917588), 'min_wi': np.float64(0.4508305572470686)}
actual reward: [0.0865818]
{'gini': np.float64(0.12299934376143441), 'ess': np.float64(9451.312340616245), 'max_wi': np.float64(2.7402104156322737), 'min_wi': np.float64(0.49551307094742053)}
Cross-validated error: 0.007471196584155263
[I 2025-10-27 23:51:57,449] Trial 1 finished with value: 0.07147194404105459 and parameters: {'lr': 0.0012321998583674654, 'num_epochs': 5, 'batch_size': 512, 'num_neighbors': 8, 'lr_decay': 0.8045851102528513}. Best is trial 1 with value: 0.07147194404105459.


Best trial: 2. Best value: 0.0721559:  15%|█▌        | 3/20 [02:00<11:17, 39.88s/it]

Train wi info: {'gini': np.float64(0.9998897255310978), 'ess': np.float64(1.784996942936559), 'max_wi': np.float64(4713.1922051970805), 'min_wi': np.float64(0.0)}
actual reward: [0.07134034]
{'gini': np.float64(0.9998531058375044), 'ess': np.float64(1.5431857802015325), 'max_wi': np.float64(2151.0355095567597), 'min_wi': np.float64(0.0)}
Cross-validated error: 0.007367156212944903
[I 2025-10-27 23:52:36,430] Trial 2 finished with value: 0.07215585635299072 and parameters: {'lr': 0.03186263736915529, 'num_epochs': 2, 'batch_size': 128, 'num_neighbors': 5, 'lr_decay': 0.980537629282467}. Best is trial 2 with value: 0.07215585635299072.


Best trial: 3. Best value: 0.0799763:  20%|██        | 4/20 [02:40<10:39, 40.00s/it]

Train wi info: {'gini': np.float64(0.9723199765259745), 'ess': np.float64(436.81931650599546), 'max_wi': np.float64(89.9910013675083), 'min_wi': np.float64(3.7801791159491655e-11)}
actual reward: [0.08554836]
{'gini': np.float64(0.971823944939169), 'ess': np.float64(288.9381598532702), 'max_wi': np.float64(91.05508387997153), 'min_wi': np.float64(9.652565908187843e-11)}
Cross-validated error: 0.010024432730146665
[I 2025-10-27 23:53:16,601] Trial 3 finished with value: 0.07997627282423043 and parameters: {'lr': 0.004180320407798399, 'num_epochs': 5, 'batch_size': 256, 'num_neighbors': 9, 'lr_decay': 0.9520486268792834}. Best is trial 3 with value: 0.07997627282423043.


Best trial: 3. Best value: 0.0799763:  25%|██▌       | 5/20 [03:20<09:56, 39.75s/it]

Train wi info: {'gini': np.float64(0.9998840705381117), 'ess': np.float64(1.8229741670460327), 'max_wi': np.float64(4713.1922051970805), 'min_wi': np.float64(0.0)}
actual reward: [0.06899468]
{'gini': np.float64(0.9998339259639992), 'ess': np.float64(1.570710725925184), 'max_wi': np.float64(2153.1000065135268), 'min_wi': np.float64(0.0)}
Cross-validated error: 0.008956169606269094
[I 2025-10-27 23:53:55,927] Trial 4 finished with value: 0.07124049315972639 and parameters: {'lr': 0.07376657360828262, 'num_epochs': 2, 'batch_size': 256, 'num_neighbors': 13, 'lr_decay': 0.8742722537585211}. Best is trial 3 with value: 0.07997627282423043.


Best trial: 3. Best value: 0.0799763:  30%|███       | 6/20 [04:00<09:22, 40.16s/it]

Train wi info: {'gini': np.float64(0.9996682581712187), 'ess': np.float64(4.280353317602677), 'max_wi': np.float64(5398.927372068734), 'min_wi': np.float64(4.192665647013834e-24)}
actual reward: [0.07448656]
{'gini': np.float64(0.9991328402916303), 'ess': np.float64(8.658601688925252), 'max_wi': np.float64(4085.7126623403), 'min_wi': np.float64(5.8014288890421896e-24)}
Cross-validated error: 0.0071840609182296025
[I 2025-10-27 23:54:36,871] Trial 5 finished with value: 0.07035279338479433 and parameters: {'lr': 0.006896778798893424, 'num_epochs': 4, 'batch_size': 64, 'num_neighbors': 3, 'lr_decay': 0.8734280711497683}. Best is trial 3 with value: 0.07997627282423043.


Best trial: 3. Best value: 0.0799763:  35%|███▌      | 7/20 [04:42<08:48, 40.68s/it]

Train wi info: {'gini': np.float64(0.9998873249208453), 'ess': np.float64(1.8185727333925004), 'max_wi': np.float64(4712.615740639131), 'min_wi': np.float64(6.337903934544197e-26)}
actual reward: [0.06942375]
{'gini': np.float64(0.9998166969409036), 'ess': np.float64(1.636537457605633), 'max_wi': np.float64(2152.148040462651), 'min_wi': np.float64(6.337903934544197e-26)}
Cross-validated error: 0.005097850261697264
[I 2025-10-27 23:55:18,627] Trial 6 finished with value: 0.061119789146213424 and parameters: {'lr': 0.00576296502976337, 'num_epochs': 6, 'batch_size': 64, 'num_neighbors': 3, 'lr_decay': 0.9245479486988929}. Best is trial 3 with value: 0.07997627282423043.


Best trial: 3. Best value: 0.0799763:  40%|████      | 8/20 [05:24<08:11, 40.99s/it]

Train wi info: {'gini': np.float64(0.9782151619589452), 'ess': np.float64(249.0119139387142), 'max_wi': np.float64(317.39458150308775), 'min_wi': np.float64(7.3558501543718506e-09)}
actual reward: [0.0859805]
{'gini': np.float64(0.9772040314630567), 'ess': np.float64(168.09321484106277), 'max_wi': np.float64(320.7897940807504), 'min_wi': np.float64(9.429362577500794e-09)}
Cross-validated error: 0.009103153279232166
[I 2025-10-27 23:56:00,288] Trial 7 finished with value: 0.07720215637691151 and parameters: {'lr': 0.0017505128934260118, 'num_epochs': 3, 'batch_size': 64, 'num_neighbors': 8, 'lr_decay': 0.9895626384522236}. Best is trial 3 with value: 0.07997627282423043.


Best trial: 3. Best value: 0.0799763:  45%|████▌     | 9/20 [06:05<07:30, 40.94s/it]

Train wi info: {'gini': np.float64(0.4341191349913053), 'ess': np.float64(7047.228666110702), 'max_wi': np.float64(8.914198014875167), 'min_wi': np.float64(0.042166971920432436)}
actual reward: [0.0871369]
{'gini': np.float64(0.41548404130501115), 'ess': np.float64(5264.315616449051), 'max_wi': np.float64(8.914198014875167), 'min_wi': np.float64(0.04846650307637128)}
Cross-validated error: 0.007743433216059856
[I 2025-10-27 23:56:41,113] Trial 8 finished with value: 0.07249307422023661 and parameters: {'lr': 0.0015182818154452432, 'num_epochs': 9, 'batch_size': 512, 'num_neighbors': 9, 'lr_decay': 0.9224753460761134}. Best is trial 3 with value: 0.07997627282423043.


Best trial: 3. Best value: 0.0799763:  50%|█████     | 10/20 [06:46<06:49, 40.94s/it]

Train wi info: {'gini': np.float64(0.9949961501724518), 'ess': np.float64(51.998950360881565), 'max_wi': np.float64(590.4434852049756), 'min_wi': np.float64(2.126770287840466e-24)}
actual reward: [0.08299494]
{'gini': np.float64(0.9961919181949402), 'ess': np.float64(15.475549044987037), 'max_wi': np.float64(1801.4966000530208), 'min_wi': np.float64(2.126770287840466e-24)}
Cross-validated error: 0.006916033514540874
[I 2025-10-27 23:57:22,056] Trial 9 finished with value: 0.06932449638026764 and parameters: {'lr': 0.008866060949015455, 'num_epochs': 7, 'batch_size': 256, 'num_neighbors': 13, 'lr_decay': 0.9441152032282301}. Best is trial 3 with value: 0.07997627282423043.


Best trial: 3. Best value: 0.0799763:  55%|█████▌    | 11/20 [07:28<06:12, 41.41s/it]

Train wi info: {'gini': np.float64(0.06045407483517182), 'ess': np.float64(14773.033484700745), 'max_wi': np.float64(1.8089062558695121), 'min_wi': np.float64(0.6670071950840364)}
actual reward: [0.08652699]
{'gini': np.float64(0.06033892483976576), 'ess': np.float64(9856.716538028833), 'max_wi': np.float64(2.0353713262296247), 'min_wi': np.float64(0.6926181510744893)}
Cross-validated error: 0.007438610198748873
[I 2025-10-27 23:58:04,516] Trial 10 finished with value: 0.07133559296484679 and parameters: {'lr': 0.00018238310446421963, 'num_epochs': 10, 'batch_size': 256, 'num_neighbors': 11, 'lr_decay': 0.959563054622239}. Best is trial 3 with value: 0.07997627282423043.


Best trial: 3. Best value: 0.0799763:  60%|██████    | 12/20 [08:08<05:27, 40.98s/it]

Train wi info: {'gini': np.float64(0.09869441534023364), 'ess': np.float64(14319.490809894438), 'max_wi': np.float64(2.747827052121196), 'min_wi': np.float64(0.5208272037204034)}
actual reward: [0.08656602]
{'gini': np.float64(0.09951345215425232), 'ess': np.float64(9552.25484068056), 'max_wi': np.float64(3.8801428282907096), 'min_wi': np.float64(0.5735624960376801)}
Cross-validated error: 0.007469054051861297
[I 2025-10-27 23:58:44,508] Trial 11 finished with value: 0.07145447062999762 and parameters: {'lr': 0.0004776672868830939, 'num_epochs': 3, 'batch_size': 128, 'num_neighbors': 6, 'lr_decay': 0.9955137758812541}. Best is trial 3 with value: 0.07997627282423043.


Best trial: 3. Best value: 0.0799763:  65%|██████▌   | 13/20 [08:47<04:43, 40.44s/it]

Train wi info: {'gini': np.float64(0.0940541901635803), 'ess': np.float64(14463.812764054715), 'max_wi': np.float64(2.365638389040993), 'min_wi': np.float64(0.5882662104688227)}
actual reward: [0.08654444]
{'gini': np.float64(0.09423679307003224), 'ess': np.float64(9653.094616072975), 'max_wi': np.float64(2.8193306623930248), 'min_wi': np.float64(0.5337236080078261)}
Cross-validated error: 0.007482764496077845
[I 2025-10-27 23:59:23,708] Trial 12 finished with value: 0.07150951948255968 and parameters: {'lr': 0.0020801733802212454, 'num_epochs': 1, 'batch_size': 256, 'num_neighbors': 10, 'lr_decay': 0.9689832255326974}. Best is trial 3 with value: 0.07997627282423043.


Best trial: 3. Best value: 0.0799763:  70%|███████   | 14/20 [09:29<04:04, 40.78s/it]

Train wi info: {'gini': np.float64(0.9219113727355062), 'ess': np.float64(714.141956959728), 'max_wi': np.float64(64.5903447222483), 'min_wi': np.float64(8.931249317539597e-06)}
actual reward: [0.0870457]
{'gini': np.float64(0.9180389400369363), 'ess': np.float64(491.53343816590564), 'max_wi': np.float64(94.63718236786679), 'min_wi': np.float64(1.4435878583848543e-05)}
Cross-validated error: 0.009534079943280066
[I 2025-10-28 00:00:05,272] Trial 13 finished with value: 0.07859119515671811 and parameters: {'lr': 0.0005530241613627311, 'num_epochs': 7, 'batch_size': 64, 'num_neighbors': 15, 'lr_decay': 0.9997661999141829}. Best is trial 3 with value: 0.07997627282423043.


Best trial: 3. Best value: 0.0799763:  75%|███████▌  | 15/20 [10:12<03:26, 41.39s/it]

Train wi info: {'gini': np.float64(0.07036992451368007), 'ess': np.float64(14644.915018165124), 'max_wi': np.float64(2.510392786706382), 'min_wi': np.float64(0.6286191802798744)}
actual reward: [0.08653635]
{'gini': np.float64(0.07194203134020247), 'ess': np.float64(9751.175815363535), 'max_wi': np.float64(3.6299771099245284), 'min_wi': np.float64(0.6562707775882048)}
Cross-validated error: 0.007490666436347652
[I 2025-10-28 00:00:48,075] Trial 14 finished with value: 0.07155229508289289 and parameters: {'lr': 0.00010932465846522318, 'num_epochs': 8, 'batch_size': 64, 'num_neighbors': 15, 'lr_decay': 0.9420661106193113}. Best is trial 3 with value: 0.07997627282423043.


Best trial: 3. Best value: 0.0799763:  80%|████████  | 16/20 [10:53<02:45, 41.36s/it]

Train wi info: {'gini': np.float64(0.1280748679898056), 'ess': np.float64(13943.88316489547), 'max_wi': np.float64(3.0272529799469465), 'min_wi': np.float64(0.435958828196972)}
actual reward: [0.08659879]
{'gini': np.float64(0.12638763459128594), 'ess': np.float64(9352.183554798403), 'max_wi': np.float64(3.472206138105926), 'min_wi': np.float64(0.46272204682456697)}
Cross-validated error: 0.007494202349019805
[I 2025-10-28 00:01:29,367] Trial 15 finished with value: 0.07158195248394045 and parameters: {'lr': 0.0005007149443198234, 'num_epochs': 7, 'batch_size': 256, 'num_neighbors': 12, 'lr_decay': 0.9163806632879883}. Best is trial 3 with value: 0.07997627282423043.


Best trial: 3. Best value: 0.0799763:  85%|████████▌ | 17/20 [11:35<02:04, 41.43s/it]

Train wi info: {'gini': np.float64(0.08717199964663595), 'ess': np.float64(14553.024212897313), 'max_wi': np.float64(2.0720230764024334), 'min_wi': np.float64(0.5673201228396793)}
actual reward: [0.08654087]
{'gini': np.float64(0.0855788619123784), 'ess': np.float64(9728.121302037114), 'max_wi': np.float64(2.1880728944133447), 'min_wi': np.float64(0.5979451261029534)}
Cross-validated error: 0.007478507102752922
[I 2025-10-28 00:02:10,975] Trial 16 finished with value: 0.07151649828212726 and parameters: {'lr': 0.0005315547965164072, 'num_epochs': 6, 'batch_size': 512, 'num_neighbors': 15, 'lr_decay': 0.9985840816831022}. Best is trial 3 with value: 0.07997627282423043.


Best trial: 3. Best value: 0.0799763:  90%|█████████ | 18/20 [12:16<01:22, 41.29s/it]

Train wi info: {'gini': np.float64(0.9996994476318798), 'ess': np.float64(4.621538324525221), 'max_wi': np.float64(4103.574754746176), 'min_wi': np.float64(8.638699267976359e-37)}
actual reward: [0.07171578]
{'gini': np.float64(0.9996223464381472), 'ess': np.float64(3.676154840837321), 'max_wi': np.float64(2021.7679663492509), 'min_wi': np.float64(5.448562431611212e-35)}
Cross-validated error: 0.008434144382532708
[I 2025-10-28 00:02:51,930] Trial 17 finished with value: 0.07496664289159372 and parameters: {'lr': 0.020240645874223934, 'num_epochs': 7, 'batch_size': 128, 'num_neighbors': 6, 'lr_decay': 0.9641179964607101}. Best is trial 3 with value: 0.07997627282423043.


Best trial: 3. Best value: 0.0799763:  95%|█████████▌| 19/20 [12:58<00:41, 41.76s/it]

Train wi info: {'gini': np.float64(0.9993640529649561), 'ess': np.float64(10.537726252969954), 'max_wi': np.float64(1168.645500627307), 'min_wi': np.float64(6.016430614358391e-21)}
actual reward: [0.08051569]
{'gini': np.float64(0.9988934204763997), 'ess': np.float64(11.046908740763518), 'max_wi': np.float64(1424.1011766179781), 'min_wi': np.float64(7.081151830313337e-21)}
Cross-validated error: 0.007531752584274878
[I 2025-10-28 00:03:34,776] Trial 18 finished with value: 0.07174998021052512 and parameters: {'lr': 0.003444957483594971, 'num_epochs': 9, 'batch_size': 64, 'num_neighbors': 10, 'lr_decay': 0.8968488577467927}. Best is trial 3 with value: 0.07997627282423043.


Best trial: 3. Best value: 0.0799763: 100%|██████████| 20/20 [13:38<00:00, 40.93s/it]

Train wi info: {'gini': np.float64(0.13375982388872948), 'ess': np.float64(13827.172150936609), 'max_wi': np.float64(3.2117052021458963), 'min_wi': np.float64(0.40741735250198163)}
actual reward: [0.08660467]
{'gini': np.float64(0.132022604738014), 'ess': np.float64(9276.472909661627), 'max_wi': np.float64(3.5697834348430355), 'min_wi': np.float64(0.447558851582022)}
Cross-validated error: 0.007522171641492355
[I 2025-10-28 00:04:14,469] Trial 19 finished with value: 0.07167480223314454 and parameters: {'lr': 0.0007491708504922513, 'num_epochs': 4, 'batch_size': 256, 'num_neighbors': 13, 'lr_decay': 0.950573235034753}. Best is trial 3 with value: 0.07997627282423043.





{'gini': np.float64(0.9878041884816501), 'ess': np.float64(93.75247677853959), 'max_wi': np.float64(262.31639900636765), 'min_wi': np.float64(6.034154546512548e-14)}


Unnamed: 0,policy_rewards,ipw,reg_dm,conv_dm,conv_dr,conv_sndr,action_diff_to_real,action_delta,context_diff_to_real,context_delta
0,0.08647581,0.0883,0.08831187,0.08828296,0.08685068,0.08685068,0.88083979,0.0,0.74725465,0.0
15000,0.0834197,0.09461783,0.09015125,0.10100693,0.09713463,0.0968437,1.02117958,0.37188423,0.92463774,0.32452709


In [11]:
dataset_params = dict(
                    n_actions= 500,
                    n_users = 500,
                    emb_dim = 16,
                    # sigma = 0.1,
                    eps = 0.6, # this is the epsilon for the noise in the ground truth policy representation
                    ctr = 0.1
                    )

train_dataset = generate_dataset(dataset_params, seed=30000)

Random Item CTR: 0.07069350185865088
Optimal greedy CTR: 0.09999918303816259
Optimal Stochastic CTR: 0.0999509448932121
Our Initial CTR: 0.08653966603258505


In [12]:
# Run the optimization
df7, best_hyperparams_by_size = trainer_trial(num_runs, num_neighbors, num_rounds_list, train_dataset, batch_size, val_size=10000, n_trials=n_trials_for_optuna, prev_best_params=best_params_to_use)

# Show the performance metrics
df7[['policy_rewards', 'ipw', 'reg_dm', 'conv_dm', 'conv_dr', 'conv_sndr', 'action_diff_to_real', 'action_delta', 'context_diff_to_real', 'context_delta']]

{'gini': np.float64(0.46691230806930045), 'ess': np.float64(4435.710165485847), 'max_wi': np.float64(25.67323418791716), 'min_wi': np.float64(0.015833285225929968)}


[I 2025-10-28 00:05:22,516] A new study created in memory with name: no-name-5585ab91-e63c-4883-b5c9-c58aa40f79fe
Best trial: 0. Best value: 0.0684667:   5%|▌         | 1/20 [00:41<13:05, 41.36s/it]

Train wi info: {'gini': np.float64(0.9992690163389645), 'ess': np.float64(11.54661790260458), 'max_wi': np.float64(2110.4958967186294), 'min_wi': np.float64(0.0)}
actual reward: [0.0765924]
{'gini': np.float64(0.9992193085635824), 'ess': np.float64(8.583189658067637), 'max_wi': np.float64(2520.606918191326), 'min_wi': np.float64(0.0)}
Cross-validated error: 0.007327511236574311
[I 2025-10-28 00:06:03,871] Trial 0 finished with value: 0.06846671806879495 and parameters: {'lr': 0.096, 'num_epochs': 5, 'batch_size': 64, 'num_neighbors': 8, 'lr_decay': 0.85}. Best is trial 0 with value: 0.06846671806879495.


Best trial: 1. Best value: 0.073837:  10%|█         | 2/20 [01:20<12:02, 40.16s/it] 

Train wi info: {'gini': np.float64(0.027310923686200736), 'ess': np.float64(14954.46466633772), 'max_wi': np.float64(1.3390055342022473), 'min_wi': np.float64(0.8402084858569301)}
actual reward: [0.0865726]
{'gini': np.float64(0.026764246775215227), 'ess': np.float64(9971.683839404637), 'max_wi': np.float64(1.3299329635273056), 'min_wi': np.float64(0.8579692861930686)}
Cross-validated error: 0.008115159540973046
[I 2025-10-28 00:06:43,195] Trial 1 finished with value: 0.07383702561624221 and parameters: {'lr': 0.00045238085386773347, 'num_epochs': 2, 'batch_size': 256, 'num_neighbors': 15, 'lr_decay': 0.8670904426471295}. Best is trial 1 with value: 0.07383702561624221.


Best trial: 1. Best value: 0.073837:  15%|█▌        | 3/20 [01:59<11:15, 39.75s/it]

Train wi info: {'gini': np.float64(0.9474460754362192), 'ess': np.float64(528.4531321418755), 'max_wi': np.float64(91.85959842165985), 'min_wi': np.float64(6.039190111254048e-08)}
actual reward: [0.08799793]
{'gini': np.float64(0.94021626257596), 'ess': np.float64(405.554825416619), 'max_wi': np.float64(74.69058564386741), 'min_wi': np.float64(9.495784414906563e-08)}
Cross-validated error: 0.005262281019290073
[I 2025-10-28 00:07:22,463] Trial 2 finished with value: 0.06197332163619088 and parameters: {'lr': 0.0077515682641924955, 'num_epochs': 2, 'batch_size': 256, 'num_neighbors': 6, 'lr_decay': 0.864169820348822}. Best is trial 1 with value: 0.07383702561624221.


Best trial: 1. Best value: 0.073837:  20%|██        | 4/20 [02:38<10:31, 39.45s/it]

Train wi info: {'gini': np.float64(0.11222290000669159), 'ess': np.float64(13809.260717894582), 'max_wi': np.float64(6.173982296697973), 'min_wi': np.float64(0.4813821530445688)}
actual reward: [0.08673424]
{'gini': np.float64(0.11128012904714825), 'ess': np.float64(9272.845889941982), 'max_wi': np.float64(3.9832543093685215), 'min_wi': np.float64(0.5101921075565431)}
Cross-validated error: 0.008102189177194977
[I 2025-10-28 00:08:01,452] Trial 3 finished with value: 0.0738213251404952 and parameters: {'lr': 0.0010721229319430666, 'num_epochs': 1, 'batch_size': 64, 'num_neighbors': 6, 'lr_decay': 0.8927638786319713}. Best is trial 1 with value: 0.07383702561624221.


Best trial: 4. Best value: 0.0782062:  25%|██▌       | 5/20 [03:19<09:57, 39.81s/it]

Train wi info: {'gini': np.float64(0.9997320198721401), 'ess': np.float64(3.496316664872391), 'max_wi': np.float64(15529.094480478867), 'min_wi': np.float64(5.8212397599592684e-39)}
actual reward: [0.07840653]
{'gini': np.float64(0.9996878819877187), 'ess': np.float64(3.2119782691758894), 'max_wi': np.float64(2559.9834309000344), 'min_wi': np.float64(5.8212397599592684e-39)}
Cross-validated error: 0.009426375299844313
[I 2025-10-28 00:08:41,893] Trial 4 finished with value: 0.07820624229404924 and parameters: {'lr': 0.029924897556198544, 'num_epochs': 6, 'batch_size': 256, 'num_neighbors': 14, 'lr_decay': 0.9208132031616365}. Best is trial 4 with value: 0.07820624229404924.


Best trial: 4. Best value: 0.0782062:  30%|███       | 6/20 [04:01<09:30, 40.74s/it]

Train wi info: {'gini': np.float64(0.02543093103837844), 'ess': np.float64(14959.855694055059), 'max_wi': np.float64(1.3265614755669657), 'min_wi': np.float64(0.8507571906720639)}
actual reward: [0.08657282]
{'gini': np.float64(0.024983553401649573), 'ess': np.float64(9974.89955123999), 'max_wi': np.float64(1.316608878679879), 'min_wi': np.float64(0.8631208591964237)}
Cross-validated error: 0.008095121458796425
[I 2025-10-28 00:09:24,438] Trial 5 finished with value: 0.0737645871370998 and parameters: {'lr': 0.00016311569206806345, 'num_epochs': 10, 'batch_size': 256, 'num_neighbors': 5, 'lr_decay': 0.8342550361712783}. Best is trial 4 with value: 0.07820624229404924.


Best trial: 4. Best value: 0.0782062:  35%|███▌      | 7/20 [04:43<08:54, 41.11s/it]

Train wi info: {'gini': np.float64(0.9944095670450906), 'ess': np.float64(58.44747054591682), 'max_wi': np.float64(1035.446987043114), 'min_wi': np.float64(1.3468893102844023e-16)}
actual reward: [0.07355942]
{'gini': np.float64(0.9928390791124279), 'ess': np.float64(56.419378812617026), 'max_wi': np.float64(657.511376037165), 'min_wi': np.float64(1.3948096587215273e-16)}
Cross-validated error: 0.004235254490456356
[I 2025-10-28 00:10:06,319] Trial 6 finished with value: 0.056505296437310934 and parameters: {'lr': 0.00391158924413331, 'num_epochs': 8, 'batch_size': 128, 'num_neighbors': 8, 'lr_decay': 0.8295691374098305}. Best is trial 4 with value: 0.07820624229404924.


Best trial: 4. Best value: 0.0782062:  40%|████      | 8/20 [05:24<08:09, 40.83s/it]

Train wi info: {'gini': np.float64(0.7799637685146116), 'ess': np.float64(1361.302328799074), 'max_wi': np.float64(41.51601430584438), 'min_wi': np.float64(0.0002558068748296458)}
actual reward: [0.08852757]
{'gini': np.float64(0.7534647040664563), 'ess': np.float64(1144.450306292141), 'max_wi': np.float64(41.51601430584438), 'min_wi': np.float64(0.0002558068748296458)}
Cross-validated error: 0.006797435361663404
[I 2025-10-28 00:10:46,547] Trial 7 finished with value: 0.06883930996898531 and parameters: {'lr': 0.002464645207526292, 'num_epochs': 5, 'batch_size': 256, 'num_neighbors': 4, 'lr_decay': 0.8995589122763292}. Best is trial 4 with value: 0.07820624229404924.


Best trial: 4. Best value: 0.0782062:  45%|████▌     | 9/20 [06:07<07:39, 41.73s/it]

Train wi info: {'gini': np.float64(0.5650737596663375), 'ess': np.float64(2565.2499935005712), 'max_wi': np.float64(29.19791683359657), 'min_wi': np.float64(0.00599906055889564)}
actual reward: [0.08782621]
{'gini': np.float64(0.5445526842878923), 'ess': np.float64(1984.0833031166212), 'max_wi': np.float64(29.094395304694633), 'min_wi': np.float64(0.006579774166346388)}
Cross-validated error: 0.007499427987284007
[I 2025-10-28 00:11:30,246] Trial 8 finished with value: 0.07159819601806491 and parameters: {'lr': 0.0003300125067608769, 'num_epochs': 9, 'batch_size': 64, 'num_neighbors': 7, 'lr_decay': 0.9839200287312162}. Best is trial 4 with value: 0.07820624229404924.


Best trial: 4. Best value: 0.0782062:  50%|█████     | 10/20 [06:48<06:55, 41.54s/it]

Train wi info: {'gini': np.float64(0.0586862899360742), 'ess': np.float64(14795.119020472506), 'max_wi': np.float64(1.7302949811575652), 'min_wi': np.float64(0.6911764677881116)}
actual reward: [0.08660945]
{'gini': np.float64(0.05722343671938731), 'ess': np.float64(9874.710151011324), 'max_wi': np.float64(1.7302949811575652), 'min_wi': np.float64(0.697304260535411)}
Cross-validated error: 0.008074250246931486
[I 2025-10-28 00:12:11,363] Trial 9 finished with value: 0.07371392619379041 and parameters: {'lr': 0.0008346598603800893, 'num_epochs': 4, 'batch_size': 512, 'num_neighbors': 5, 'lr_decay': 0.8345595150733487}. Best is trial 4 with value: 0.07820624229404924.


Best trial: 10. Best value: 0.0799275:  55%|█████▌    | 11/20 [07:29<06:11, 41.28s/it]

Train wi info: {'gini': np.float64(0.9992761342901892), 'ess': np.float64(3.942888649533661), 'max_wi': np.float64(4178.2650319121285), 'min_wi': np.float64(1.8763531260310622e-39)}
actual reward: [0.07207853]
{'gini': np.float64(0.9990450321801098), 'ess': np.float64(4.91786996271081), 'max_wi': np.float64(1115.7506640461731), 'min_wi': np.float64(1.8763531260310622e-39)}
Cross-validated error: 0.010004083794711174
[I 2025-10-28 00:12:52,041] Trial 10 finished with value: 0.07992748936708931 and parameters: {'lr': 0.036288809853074526, 'num_epochs': 7, 'batch_size': 512, 'num_neighbors': 13, 'lr_decay': 0.9591244003386906}. Best is trial 10 with value: 0.07992748936708931.


Best trial: 10. Best value: 0.0799275:  60%|██████    | 12/20 [08:10<05:28, 41.10s/it]

Train wi info: {'gini': np.float64(0.9997257670369128), 'ess': np.float64(4.194039731472434), 'max_wi': np.float64(15535.616281146782), 'min_wi': np.float64(1.7869352533911768e-43)}
actual reward: [0.06818725]
{'gini': np.float64(0.9998403772451552), 'ess': np.float64(1.531587806290671), 'max_wi': np.float64(9447.898543612175), 'min_wi': np.float64(1.888685315384629e-43)}
Cross-validated error: 0.007701183770952358
[I 2025-10-28 00:13:32,752] Trial 11 finished with value: 0.07230401995180344 and parameters: {'lr': 0.03800737242051641, 'num_epochs': 7, 'batch_size': 512, 'num_neighbors': 13, 'lr_decay': 0.9589875531467104}. Best is trial 10 with value: 0.07992748936708931.


Best trial: 10. Best value: 0.0799275:  65%|██████▌   | 13/20 [08:51<04:48, 41.18s/it]

Train wi info: {'gini': np.float64(0.9965783390433004), 'ess': np.float64(22.77880374800158), 'max_wi': np.float64(2202.411211893296), 'min_wi': np.float64(2.1050973517174294e-34)}
actual reward: [0.07403761]
{'gini': np.float64(0.9943270424648278), 'ess': np.float64(50.71396791635804), 'max_wi': np.float64(643.4347885544047), 'min_wi': np.float64(2.1050973517174294e-34)}
Cross-validated error: 0.004440870617698927
[I 2025-10-28 00:14:14,107] Trial 12 finished with value: 0.05764004271291516 and parameters: {'lr': 0.02727720314144876, 'num_epochs': 7, 'batch_size': 512, 'num_neighbors': 12, 'lr_decay': 0.9417297618473212}. Best is trial 10 with value: 0.07992748936708931.


Best trial: 10. Best value: 0.0799275:  70%|███████   | 14/20 [09:33<04:07, 41.31s/it]

Train wi info: {'gini': np.float64(0.9993132516991144), 'ess': np.float64(9.037293290123811), 'max_wi': np.float64(1520.740015263155), 'min_wi': np.float64(3.276049061880296e-34)}
actual reward: [0.06953641]
{'gini': np.float64(0.999012793936479), 'ess': np.float64(7.083579835977368), 'max_wi': np.float64(2117.318306700297), 'min_wi': np.float64(1.4873498637755967e-33)}
Cross-validated error: 0.00837947913663741
[I 2025-10-28 00:14:55,734] Trial 13 finished with value: 0.07476046743268103 and parameters: {'lr': 0.016615270370834934, 'num_epochs': 7, 'batch_size': 128, 'num_neighbors': 11, 'lr_decay': 0.9323286478195265}. Best is trial 10 with value: 0.07992748936708931.


Best trial: 10. Best value: 0.0799275:  75%|███████▌  | 15/20 [10:13<03:24, 40.89s/it]

Train wi info: {'gini': np.float64(0.9994821549406316), 'ess': np.float64(7.600705805029452), 'max_wi': np.float64(2477.7395493599283), 'min_wi': np.float64(0.0)}
actual reward: [0.07391446]
{'gini': np.float64(0.9988374410651654), 'ess': np.float64(8.262840609045366), 'max_wi': np.float64(2494.77001812824), 'min_wi': np.float64(0.0)}
Cross-validated error: 0.008470011844035855
[I 2025-10-28 00:15:35,627] Trial 14 finished with value: 0.0766831482741004 and parameters: {'lr': 0.08789465107010729, 'num_epochs': 4, 'batch_size': 512, 'num_neighbors': 15, 'lr_decay': 0.9996246610203103}. Best is trial 10 with value: 0.07992748936708931.


Best trial: 10. Best value: 0.0799275:  80%|████████  | 16/20 [10:54<02:44, 41.13s/it]

Train wi info: {'gini': np.float64(0.9962413396600024), 'ess': np.float64(46.65828814630374), 'max_wi': np.float64(1039.0800374447429), 'min_wi': np.float64(7.759367546418227e-25)}
actual reward: [0.0648827]
{'gini': np.float64(0.9948450605571616), 'ess': np.float64(45.90164223503791), 'max_wi': np.float64(676.7250085551485), 'min_wi': np.float64(7.759367546418227e-25)}
Cross-validated error: 0.0034359043809684435
[I 2025-10-28 00:16:17,337] Trial 15 finished with value: 0.05169150426105719 and parameters: {'lr': 0.01115938313419852, 'num_epochs': 6, 'batch_size': 256, 'num_neighbors': 10, 'lr_decay': 0.9289697219359209}. Best is trial 10 with value: 0.07992748936708931.


Best trial: 10. Best value: 0.0799275:  85%|████████▌ | 17/20 [11:38<02:05, 41.89s/it]

Train wi info: {'gini': np.float64(0.9982027490471936), 'ess': np.float64(5.718983833465686), 'max_wi': np.float64(4291.561162429713), 'min_wi': np.float64(6.14769729891255e-41)}
actual reward: [0.07637492]
{'gini': np.float64(0.9965762891022667), 'ess': np.float64(32.01538338278205), 'max_wi': np.float64(359.90708007592053), 'min_wi': np.float64(6.14769729891255e-41)}
Cross-validated error: 0.006782589839005816
[I 2025-10-28 00:17:00,973] Trial 16 finished with value: 0.06871034604761105 and parameters: {'lr': 0.04218759998759707, 'num_epochs': 9, 'batch_size': 512, 'num_neighbors': 13, 'lr_decay': 0.9591290028352266}. Best is trial 10 with value: 0.07992748936708931.


Best trial: 10. Best value: 0.0799275:  90%|█████████ | 18/20 [12:21<01:24, 42.19s/it]

Train wi info: {'gini': np.float64(0.9165753338485441), 'ess': np.float64(756.861967274632), 'max_wi': np.float64(49.65180623886492), 'min_wi': np.float64(1.0443135384374695e-06)}
actual reward: [0.08900447]
{'gini': np.float64(0.9006933790850102), 'ess': np.float64(622.9869581516742), 'max_wi': np.float64(47.52296144141636), 'min_wi': np.float64(1.1909631155811792e-06)}
Cross-validated error: 0.006960188138021279
[I 2025-10-28 00:17:43,868] Trial 17 finished with value: 0.06943649530840798 and parameters: {'lr': 0.006190061536108574, 'num_epochs': 6, 'batch_size': 512, 'num_neighbors': 14, 'lr_decay': 0.8041436856399324}. Best is trial 10 with value: 0.07992748936708931.


Best trial: 18. Best value: 0.0805536:  95%|█████████▌| 19/20 [13:04<00:42, 42.34s/it]

Train wi info: {'gini': np.float64(0.9995841280888977), 'ess': np.float64(3.9353100648258694), 'max_wi': np.float64(8284.861406059881), 'min_wi': np.float64(1.6654705220828945e-33)}
actual reward: [0.07983486]
{'gini': np.float64(0.9989287115028471), 'ess': np.float64(11.120027323340231), 'max_wi': np.float64(1335.3525780386265), 'min_wi': np.float64(1.6654705220828945e-33)}
Cross-validated error: 0.010218239179396859
[I 2025-10-28 00:18:26,574] Trial 18 finished with value: 0.08055358904798596 and parameters: {'lr': 0.021481255407433664, 'num_epochs': 4, 'batch_size': 128, 'num_neighbors': 10, 'lr_decay': 0.9133768116739862}. Best is trial 18 with value: 0.08055358904798596.


Best trial: 18. Best value: 0.0805536: 100%|██████████| 20/20 [13:44<00:00, 41.23s/it]

Train wi info: {'gini': np.float64(0.9943218377615222), 'ess': np.float64(53.892139439341456), 'max_wi': np.float64(1514.6693686189371), 'min_wi': np.float64(5.601599163940932e-29)}
actual reward: [0.07374614]
{'gini': np.float64(0.9955133023881012), 'ess': np.float64(26.46244506776238), 'max_wi': np.float64(1294.5221231175915), 'min_wi': np.float64(5.601599163940932e-29)}
Cross-validated error: 0.006241128218861144
[I 2025-10-28 00:19:07,026] Trial 19 finished with value: 0.06655811771799952 and parameters: {'lr': 0.017472763331314176, 'num_epochs': 3, 'batch_size': 128, 'num_neighbors': 10, 'lr_decay': 0.9699543382075912}. Best is trial 18 with value: 0.08055358904798596.





{'gini': np.float64(0.9973610174551267), 'ess': np.float64(15.827750037516562), 'max_wi': np.float64(1978.5128751704929), 'min_wi': np.float64(1.2503375789455513e-33)}


Unnamed: 0,policy_rewards,ipw,reg_dm,conv_dm,conv_dr,conv_sndr,action_diff_to_real,action_delta,context_diff_to_real,context_delta
0,0.08653967,0.0873,0.08727474,0.08935798,0.08737408,0.08737408,0.82469903,0.0,0.72168239,0.0
15000,0.07620169,0.02471816,0.08501919,0.08134475,0.01775477,0.02246849,1.02339782,0.55808987,1.055258,0.57137736


In [13]:
dataset_params = dict(
                    n_actions= 500,
                    n_users = 500,
                    emb_dim = 16,
                    # sigma = 0.1,
                    eps = 0.6, # this is the epsilon for the noise in the ground truth policy representation
                    ctr = 0.1
                    )

train_dataset = generate_dataset(dataset_params, seed=40000)

Random Item CTR: 0.07053370144999074
Optimal greedy CTR: 0.09999936716169436
Optimal Stochastic CTR: 0.09995563088920843
Our Initial CTR: 0.08622184481781218


In [14]:
# Run the optimization
df8, best_hyperparams_by_size = trainer_trial(num_runs, num_neighbors, num_rounds_list, train_dataset, batch_size, val_size=10000, n_trials=n_trials_for_optuna, prev_best_params=best_params_to_use)

# Show the performance metrics
df8[['policy_rewards', 'ipw', 'reg_dm', 'conv_dm', 'conv_dr', 'conv_sndr', 'action_diff_to_real', 'action_delta', 'context_diff_to_real', 'context_delta']]

{'gini': np.float64(0.4765880277297409), 'ess': np.float64(4297.400342742042), 'max_wi': np.float64(23.98424839857256), 'min_wi': np.float64(0.008677272580760257)}


[I 2025-10-28 00:20:15,279] A new study created in memory with name: no-name-cd3a4713-a607-496b-8326-60b559875358
Best trial: 0. Best value: 0.071175:   5%|▌         | 1/20 [00:41<13:16, 41.94s/it]

Train wi info: {'gini': np.float64(0.9994425534628644), 'ess': np.float64(8.39211110682246), 'max_wi': np.float64(6455.883621370989), 'min_wi': np.float64(0.0)}
actual reward: [0.0781076]
{'gini': np.float64(0.9995481682247254), 'ess': np.float64(4.851856347715597), 'max_wi': np.float64(3550.1441402641103), 'min_wi': np.float64(0.0)}
Cross-validated error: 0.006006053180833118
[I 2025-10-28 00:20:57,218] Trial 0 finished with value: 0.07117500293424119 and parameters: {'lr': 0.096, 'num_epochs': 5, 'batch_size': 64, 'num_neighbors': 8, 'lr_decay': 0.85}. Best is trial 0 with value: 0.07117500293424119.


Best trial: 0. Best value: 0.071175:  10%|█         | 2/20 [01:22<12:20, 41.12s/it]

Train wi info: {'gini': np.float64(0.9955753189931694), 'ess': np.float64(20.342612678521512), 'max_wi': np.float64(626.6510327448123), 'min_wi': np.float64(3.424673251662064e-33)}
actual reward: [0.07547815]
{'gini': np.float64(0.9978620604201811), 'ess': np.float64(9.298761577717519), 'max_wi': np.float64(1952.6189195954032), 'min_wi': np.float64(3.0333617777215167e-33)}
Cross-validated error: 0.006764925460367622
[I 2025-10-28 00:21:37,769] Trial 1 finished with value: 0.0686914786145342 and parameters: {'lr': 0.03051639043783386, 'num_epochs': 3, 'batch_size': 256, 'num_neighbors': 11, 'lr_decay': 0.887472700875917}. Best is trial 0 with value: 0.07117500293424119.


Best trial: 2. Best value: 0.0760413:  15%|█▌        | 3/20 [02:03<11:34, 40.85s/it]

Train wi info: {'gini': np.float64(0.9871695223903892), 'ess': np.float64(35.856202354495366), 'max_wi': np.float64(1782.8301374330197), 'min_wi': np.float64(2.5399950527293326e-22)}
actual reward: [0.08493199]
{'gini': np.float64(0.9862827674934029), 'ess': np.float64(23.778456994889908), 'max_wi': np.float64(1907.9912273723523), 'min_wi': np.float64(2.5399950527293326e-22)}
Cross-validated error: 0.008763525556187222
[I 2025-10-28 00:22:18,293] Trial 2 finished with value: 0.07604127541798236 and parameters: {'lr': 0.014974336061328282, 'num_epochs': 7, 'batch_size': 512, 'num_neighbors': 10, 'lr_decay': 0.9594686178032676}. Best is trial 2 with value: 0.07604127541798236.


Best trial: 2. Best value: 0.0760413:  20%|██        | 4/20 [02:44<10:58, 41.15s/it]

Train wi info: {'gini': np.float64(0.21741910774338888), 'ess': np.float64(12288.399160354662), 'max_wi': np.float64(3.624952306671342), 'min_wi': np.float64(0.22677112843123468)}
actual reward: [0.08654338]
{'gini': np.float64(0.21519009703554656), 'ess': np.float64(8315.176270405775), 'max_wi': np.float64(3.624952306671342), 'min_wi': np.float64(0.22677112843123468)}
Cross-validated error: 0.007890715293343372
[I 2025-10-28 00:22:59,892] Trial 3 finished with value: 0.07303365058563507 and parameters: {'lr': 0.0010530891258288843, 'num_epochs': 8, 'batch_size': 512, 'num_neighbors': 8, 'lr_decay': 0.8352623280428114}. Best is trial 2 with value: 0.07604127541798236.


Best trial: 2. Best value: 0.0760413:  25%|██▌       | 5/20 [03:24<10:08, 40.58s/it]

Train wi info: {'gini': np.float64(0.033448313223498054), 'ess': np.float64(14936.644013458357), 'max_wi': np.float64(1.3047719748253894), 'min_wi': np.float64(0.803865888786562)}
actual reward: [0.08625035]
{'gini': np.float64(0.03338281108558325), 'ess': np.float64(9959.222498956364), 'max_wi': np.float64(1.3047719748253894), 'min_wi': np.float64(0.803865888786562)}
Cross-validated error: 0.007795237636985456
[I 2025-10-28 00:23:39,474] Trial 4 finished with value: 0.07269956124701002 and parameters: {'lr': 0.0005137241565248642, 'num_epochs': 2, 'batch_size': 512, 'num_neighbors': 6, 'lr_decay': 0.9233498047899026}. Best is trial 2 with value: 0.07604127541798236.


Best trial: 2. Best value: 0.0760413:  30%|███       | 6/20 [04:04<09:28, 40.59s/it]

Train wi info: {'gini': np.float64(0.975666356606742), 'ess': np.float64(292.2055773119329), 'max_wi': np.float64(453.47800743005064), 'min_wi': np.float64(3.19601010052063e-12)}
actual reward: [0.0802507]
{'gini': np.float64(0.9752181875333669), 'ess': np.float64(245.55353072312974), 'max_wi': np.float64(163.54502962688156), 'min_wi': np.float64(6.605152956246455e-12)}
Cross-validated error: 0.00808926898441874
[I 2025-10-28 00:24:20,087] Trial 5 finished with value: 0.0735622049923843 and parameters: {'lr': 0.005772220478586617, 'num_epochs': 4, 'batch_size': 256, 'num_neighbors': 4, 'lr_decay': 0.8702137258761482}. Best is trial 2 with value: 0.07604127541798236.


Best trial: 6. Best value: 0.0774287:  35%|███▌      | 7/20 [04:45<08:47, 40.59s/it]

Train wi info: {'gini': np.float64(0.9825139935546585), 'ess': np.float64(136.24021653868155), 'max_wi': np.float64(947.9695538963481), 'min_wi': np.float64(3.7800492988630633e-10)}
actual reward: [0.0769999]
{'gini': np.float64(0.9815664116984298), 'ess': np.float64(124.78728313404994), 'max_wi': np.float64(668.1082520219301), 'min_wi': np.float64(6.746346715238784e-10)}
Cross-validated error: 0.009183412800275952
[I 2025-10-28 00:25:00,674] Trial 6 finished with value: 0.07742868970585519 and parameters: {'lr': 0.0018864664953513158, 'num_epochs': 3, 'batch_size': 64, 'num_neighbors': 8, 'lr_decay': 0.8221397902117843}. Best is trial 6 with value: 0.07742868970585519.


Best trial: 6. Best value: 0.0774287:  40%|████      | 8/20 [05:26<08:09, 40.82s/it]

Train wi info: {'gini': np.float64(0.9982512938696502), 'ess': np.float64(8.130389606619913), 'max_wi': np.float64(3003.7103938381338), 'min_wi': np.float64(5.8977869200724205e-27)}
actual reward: [0.08594521]
{'gini': np.float64(0.9970742531184991), 'ess': np.float64(10.460926432067302), 'max_wi': np.float64(902.910652165782), 'min_wi': np.float64(5.8977869200724205e-27)}
Cross-validated error: 0.0073046788202481205
[I 2025-10-28 00:25:41,981] Trial 7 finished with value: 0.07085546921686042 and parameters: {'lr': 0.017998010351319642, 'num_epochs': 6, 'batch_size': 256, 'num_neighbors': 5, 'lr_decay': 0.8965563614200232}. Best is trial 6 with value: 0.07742868970585519.


Best trial: 6. Best value: 0.0774287:  45%|████▌     | 9/20 [06:10<07:38, 41.72s/it]

Train wi info: {'gini': np.float64(0.9840400946641389), 'ess': np.float64(109.93688736719507), 'max_wi': np.float64(1101.7365285653684), 'min_wi': np.float64(1.0646139783010482e-15)}
actual reward: [0.07037061]
{'gini': np.float64(0.9830528872094991), 'ess': np.float64(116.74601906931122), 'max_wi': np.float64(679.2863204421514), 'min_wi': np.float64(7.803287461980857e-16)}
Cross-validated error: 0.009166041820004227
[I 2025-10-28 00:26:25,687] Trial 8 finished with value: 0.07738473545646835 and parameters: {'lr': 0.0009287890522710303, 'num_epochs': 10, 'batch_size': 64, 'num_neighbors': 12, 'lr_decay': 0.990032907559895}. Best is trial 6 with value: 0.07742868970585519.


Best trial: 6. Best value: 0.0774287:  50%|█████     | 10/20 [06:49<06:50, 41.04s/it]

Train wi info: {'gini': np.float64(0.45332985582045066), 'ess': np.float64(6822.2825156643685), 'max_wi': np.float64(8.091962167882542), 'min_wi': np.float64(0.03934782694329021)}
actual reward: [0.08709364]
{'gini': np.float64(0.44377928557988383), 'ess': np.float64(4904.365457635314), 'max_wi': np.float64(8.091962167882542), 'min_wi': np.float64(0.04024241594771544)}
Cross-validated error: 0.00783676765751243
[I 2025-10-28 00:27:05,209] Trial 9 finished with value: 0.07285199420142685 and parameters: {'lr': 0.0023184591311556004, 'num_epochs': 4, 'batch_size': 512, 'num_neighbors': 10, 'lr_decay': 0.9121542607576805}. Best is trial 6 with value: 0.07742868970585519.


Best trial: 6. Best value: 0.0774287:  55%|█████▌    | 11/20 [07:30<06:07, 40.84s/it]

Train wi info: {'gini': np.float64(0.021784638781765498), 'ess': np.float64(14965.541872172038), 'max_wi': np.float64(1.3111475621107727), 'min_wi': np.float64(0.8727075899380028)}
actual reward: [0.08624736]
{'gini': np.float64(0.022244028667267927), 'ess': np.float64(9976.773837289296), 'max_wi': np.float64(1.3111475621107727), 'min_wi': np.float64(0.8880543857985002)}
Cross-validated error: 0.007887999752697475
[I 2025-10-28 00:27:45,577] Trial 10 finished with value: 0.07302669151807153 and parameters: {'lr': 0.00015073946725868579, 'num_epochs': 2, 'batch_size': 128, 'num_neighbors': 15, 'lr_decay': 0.8007847317408237}. Best is trial 6 with value: 0.07742868970585519.


Best trial: 6. Best value: 0.0774287:  60%|██████    | 12/20 [08:13<05:32, 41.53s/it]

Train wi info: {'gini': np.float64(0.9380419437160525), 'ess': np.float64(625.6512990974966), 'max_wi': np.float64(106.44650399020908), 'min_wi': np.float64(5.3632109331226075e-06)}
actual reward: [0.0841741]
{'gini': np.float64(0.934116211069278), 'ess': np.float64(434.90258971044693), 'max_wi': np.float64(114.98638250822034), 'min_wi': np.float64(6.929959278215726e-06)}
Cross-validated error: 0.008731296879739284
[I 2025-10-28 00:28:28,695] Trial 11 finished with value: 0.07596877149566506 and parameters: {'lr': 0.0003363262078545383, 'num_epochs': 10, 'batch_size': 64, 'num_neighbors': 13, 'lr_decay': 0.9973363595813158}. Best is trial 6 with value: 0.07742868970585519.


Best trial: 6. Best value: 0.0774287:  65%|██████▌   | 13/20 [08:56<04:54, 42.04s/it]

Train wi info: {'gini': np.float64(0.9863958940471246), 'ess': np.float64(55.448007807243776), 'max_wi': np.float64(1270.752661776107), 'min_wi': np.float64(4.1578044883265e-17)}
actual reward: [0.07925253]
{'gini': np.float64(0.9853530381205897), 'ess': np.float64(66.22336466599945), 'max_wi': np.float64(728.5643274898047), 'min_wi': np.float64(9.970104997503642e-17)}
Cross-validated error: 0.009107771408791362
[I 2025-10-28 00:29:11,893] Trial 12 finished with value: 0.07713064306567254 and parameters: {'lr': 0.0020854457620178716, 'num_epochs': 10, 'batch_size': 64, 'num_neighbors': 13, 'lr_decay': 0.9530771879119968}. Best is trial 6 with value: 0.07742868970585519.


Best trial: 6. Best value: 0.0774287:  70%|███████   | 14/20 [09:35<04:05, 40.97s/it]

Train wi info: {'gini': np.float64(0.9899228161528969), 'ess': np.float64(31.264148255729822), 'max_wi': np.float64(2072.359043999418), 'min_wi': np.float64(1.6960842618502037e-11)}
actual reward: [0.09115361]
{'gini': np.float64(0.9909905002510123), 'ess': np.float64(14.970648290656886), 'max_wi': np.float64(2634.63057851196), 'min_wi': np.float64(1.6960842618502037e-11)}
Cross-validated error: 0.008151221888131879
[I 2025-10-28 00:29:50,405] Trial 13 finished with value: 0.0739927214898265 and parameters: {'lr': 0.005853340890965105, 'num_epochs': 1, 'batch_size': 64, 'num_neighbors': 7, 'lr_decay': 0.8021313156946275}. Best is trial 6 with value: 0.07742868970585519.


Best trial: 14. Best value: 0.0788566:  75%|███████▌  | 15/20 [10:17<03:26, 41.26s/it]

Train wi info: {'gini': np.float64(0.9822490676914172), 'ess': np.float64(151.65365188031504), 'max_wi': np.float64(850.789365854257), 'min_wi': np.float64(1.715343754881661e-11)}
actual reward: [0.07548321]
{'gini': np.float64(0.9806399188868394), 'ess': np.float64(139.11648706711154), 'max_wi': np.float64(579.3327679184392), 'min_wi': np.float64(1.715343754881661e-11)}
Cross-validated error: 0.009638572112980722
[I 2025-10-28 00:30:32,323] Trial 14 finished with value: 0.07885655538590086 and parameters: {'lr': 0.0007970021421639396, 'num_epochs': 8, 'batch_size': 64, 'num_neighbors': 12, 'lr_decay': 0.9689550977791144}. Best is trial 14 with value: 0.07885655538590086.


Best trial: 14. Best value: 0.0788566:  80%|████████  | 16/20 [10:58<02:44, 41.21s/it]

Train wi info: {'gini': np.float64(0.06164298041629719), 'ess': np.float64(14709.84416801597), 'max_wi': np.float64(1.9730747409416263), 'min_wi': np.float64(0.6658536037343763)}
actual reward: [0.08629691]
{'gini': np.float64(0.06294046404941396), 'ess': np.float64(9806.930567404952), 'max_wi': np.float64(1.9730747409416263), 'min_wi': np.float64(0.7004529773845232)}
Cross-validated error: 0.007892249449280143
[I 2025-10-28 00:31:13,437] Trial 15 finished with value: 0.0730635733496829 and parameters: {'lr': 0.00010304975682759053, 'num_epochs': 8, 'batch_size': 128, 'num_neighbors': 15, 'lr_decay': 0.9435667788864572}. Best is trial 14 with value: 0.07885655538590086.


Best trial: 14. Best value: 0.0788566:  85%|████████▌ | 17/20 [11:40<02:04, 41.59s/it]

Train wi info: {'gini': np.float64(0.42909662164327744), 'ess': np.float64(5034.863880551787), 'max_wi': np.float64(13.915501774005696), 'min_wi': np.float64(0.04262779398241533)}
actual reward: [0.08676491]
{'gini': np.float64(0.4294596345060618), 'ess': np.float64(3551.810319863416), 'max_wi': np.float64(13.915501774005696), 'min_wi': np.float64(0.04854554928142906)}
Cross-validated error: 0.008050997280615514
[I 2025-10-28 00:31:55,903] Trial 16 finished with value: 0.07358891834158178 and parameters: {'lr': 0.00031919211584449696, 'num_epochs': 8, 'batch_size': 64, 'num_neighbors': 9, 'lr_decay': 0.8501025851581533}. Best is trial 14 with value: 0.07885655538590086.


Best trial: 17. Best value: 0.0805166:  90%|█████████ | 18/20 [12:22<01:23, 41.53s/it]

Train wi info: {'gini': np.float64(0.9785621579937278), 'ess': np.float64(299.57061352652903), 'max_wi': np.float64(197.22478512546363), 'min_wi': np.float64(1.5673116256421677e-08)}
actual reward: [0.07823777]
{'gini': np.float64(0.9776411689727987), 'ess': np.float64(194.05410215638526), 'max_wi': np.float64(322.52416032020477), 'min_wi': np.float64(2.14092926723229e-08)}
Cross-validated error: 0.010211172076958965
[I 2025-10-28 00:32:37,284] Trial 17 finished with value: 0.08051659731920405 and parameters: {'lr': 0.0010626919949003994, 'num_epochs': 6, 'batch_size': 64, 'num_neighbors': 3, 'lr_decay': 0.8260328370610548}. Best is trial 17 with value: 0.08051659731920405.


Best trial: 17. Best value: 0.0805166:  95%|█████████▌| 19/20 [13:03<00:41, 41.57s/it]

Train wi info: {'gini': np.float64(0.9760187543450577), 'ess': np.float64(286.2071625090533), 'max_wi': np.float64(401.27750549169195), 'min_wi': np.float64(1.108212178095789e-08)}
actual reward: [0.07910612]
{'gini': np.float64(0.9753833718344049), 'ess': np.float64(184.1559652007926), 'max_wi': np.float64(416.1593594126596), 'min_wi': np.float64(2.005611101825991e-08)}
Cross-validated error: 0.008771668768240292
[I 2025-10-28 00:33:18,965] Trial 18 finished with value: 0.07605118699761479 and parameters: {'lr': 0.0007838291241515988, 'num_epochs': 6, 'batch_size': 64, 'num_neighbors': 5, 'lr_decay': 0.969057498988737}. Best is trial 17 with value: 0.08051659731920405.


Best trial: 17. Best value: 0.0805166: 100%|██████████| 20/20 [13:45<00:00, 41.26s/it]

Train wi info: {'gini': np.float64(0.13459718692174996), 'ess': np.float64(13562.396871452263), 'max_wi': np.float64(3.4197225701756357), 'min_wi': np.float64(0.4012525586365764)}
actual reward: [0.08640222]
{'gini': np.float64(0.13669370559506883), 'ess': np.float64(9062.284363197758), 'max_wi': np.float64(3.4197225701756357), 'min_wi': np.float64(0.4503985062471784)}
Cross-validated error: 0.00810147130095046
[I 2025-10-28 00:34:00,413] Trial 19 finished with value: 0.07376464151615597 and parameters: {'lr': 0.00022140417097077505, 'num_epochs': 7, 'batch_size': 128, 'num_neighbors': 3, 'lr_decay': 0.9296482058187899}. Best is trial 17 with value: 0.08051659731920405.





{'gini': np.float64(0.6875407660566112), 'ess': np.float64(866.7547908092289), 'max_wi': np.float64(125.17520581288665), 'min_wi': np.float64(0.0033300205141579753)}


Unnamed: 0,policy_rewards,ipw,reg_dm,conv_dm,conv_dr,conv_sndr,action_diff_to_real,action_delta,context_diff_to_real,context_delta
0,0.08622184,0.0884,0.08844068,0.09405586,0.08962762,0.08962762,0.92210476,0.0,0.83772226,0.0
15000,0.08710871,0.07826932,0.08687713,0.09178597,0.07893424,0.07893299,0.92695644,0.09435642,0.8433789,0.04041727


In [15]:
dataset_params = dict(
                    n_actions= 500,
                    n_users = 500,
                    emb_dim = 16,
                    # sigma = 0.1,
                    eps = 0.6, # this is the epsilon for the noise in the ground truth policy representation
                    ctr = 0.1
                    )

train_dataset = generate_dataset(dataset_params, seed=50000)

Random Item CTR: 0.0705882181025533
Optimal greedy CTR: 0.09999934164533562
Optimal Stochastic CTR: 0.09995498601895662
Our Initial CTR: 0.08647501952799874


In [16]:
# Run the optimization
df9, best_hyperparams_by_size = trainer_trial(num_runs, num_neighbors, num_rounds_list, train_dataset, batch_size, val_size=10000, n_trials=n_trials_for_optuna, prev_best_params=best_params_to_use)

# Show the performance metrics
df9[['policy_rewards', 'ipw', 'reg_dm', 'conv_dm', 'conv_dr', 'conv_sndr', 'action_diff_to_real', 'action_delta', 'context_diff_to_real', 'context_delta']]

{'gini': np.float64(0.48346307169006664), 'ess': np.float64(4282.643856377782), 'max_wi': np.float64(21.56110254536132), 'min_wi': np.float64(0.010593760814322993)}


[I 2025-10-28 00:35:08,682] A new study created in memory with name: no-name-a60aae64-b943-4cb3-b524-e468919fbb37
Best trial: 0. Best value: 0.0757378:   5%|▌         | 1/20 [00:41<13:10, 41.60s/it]

Train wi info: {'gini': np.float64(0.9992959879929818), 'ess': np.float64(11.584828777106043), 'max_wi': np.float64(4139.250556300044), 'min_wi': np.float64(0.0)}
actual reward: [0.07880265]
{'gini': np.float64(0.9994971649188609), 'ess': np.float64(5.433174965276248), 'max_wi': np.float64(3439.698557222427), 'min_wi': np.float64(0.0)}
Cross-validated error: 0.006809658573854237
[I 2025-10-28 00:35:50,283] Trial 0 finished with value: 0.07573778353316493 and parameters: {'lr': 0.096, 'num_epochs': 5, 'batch_size': 64, 'num_neighbors': 8, 'lr_decay': 0.85}. Best is trial 0 with value: 0.07573778353316493.


Best trial: 0. Best value: 0.0757378:  10%|█         | 2/20 [01:20<12:03, 40.18s/it]

Train wi info: {'gini': np.float64(0.3669867948930161), 'ess': np.float64(8058.104836130282), 'max_wi': np.float64(9.36129950207573), 'min_wi': np.float64(0.08691758279234292)}
actual reward: [0.08721718]
{'gini': np.float64(0.3639490799022447), 'ess': np.float64(5581.674117391761), 'max_wi': np.float64(9.36129950207573), 'min_wi': np.float64(0.09693829087066094)}
Cross-validated error: 0.008202672137022157
[I 2025-10-28 00:36:29,465] Trial 1 finished with value: 0.07416106894605753 and parameters: {'lr': 0.004284887885089372, 'num_epochs': 1, 'batch_size': 256, 'num_neighbors': 6, 'lr_decay': 0.9384921659739343}. Best is trial 0 with value: 0.07573778353316493.


Best trial: 0. Best value: 0.0757378:  15%|█▌        | 3/20 [02:01<11:30, 40.60s/it]

Train wi info: {'gini': np.float64(0.057881124062737474), 'ess': np.float64(14763.631190089971), 'max_wi': np.float64(1.9711921683646618), 'min_wi': np.float64(0.7417091487274633)}
actual reward: [0.08657274]
{'gini': np.float64(0.05844398560961041), 'ess': np.float64(9843.971697741981), 'max_wi': np.float64(1.9018827615476683), 'min_wi': np.float64(0.7417091487274633)}
Cross-validated error: 0.008093481852633796
[I 2025-10-28 00:37:10,565] Trial 2 finished with value: 0.07376526370687941 and parameters: {'lr': 0.00015468216511673133, 'num_epochs': 5, 'batch_size': 128, 'num_neighbors': 11, 'lr_decay': 0.94864987895247}. Best is trial 0 with value: 0.07573778353316493.


Best trial: 0. Best value: 0.0757378:  20%|██        | 4/20 [02:42<10:50, 40.65s/it]

Train wi info: {'gini': np.float64(0.9994895198901006), 'ess': np.float64(8.376442247795127), 'max_wi': np.float64(3480.522595434665), 'min_wi': np.float64(0.0)}
actual reward: [0.07412193]
{'gini': np.float64(0.9997005211033401), 'ess': np.float64(3.1099192199035244), 'max_wi': np.float64(3439.6977371343846), 'min_wi': np.float64(0.0)}
Cross-validated error: 0.006322969723965644
[I 2025-10-28 00:37:51,299] Trial 3 finished with value: 0.06861101935526338 and parameters: {'lr': 0.03971788848717153, 'num_epochs': 4, 'batch_size': 64, 'num_neighbors': 6, 'lr_decay': 0.805593444007512}. Best is trial 0 with value: 0.07573778353316493.


Best trial: 0. Best value: 0.0757378:  25%|██▌       | 5/20 [03:24<10:16, 41.10s/it]

Train wi info: {'gini': np.float64(0.10573997633388886), 'ess': np.float64(14251.983099400013), 'max_wi': np.float64(2.6012870440838833), 'min_wi': np.float64(0.5622386912207865)}
actual reward: [0.08664865]
{'gini': np.float64(0.10585734718197871), 'ess': np.float64(9517.591631325635), 'max_wi': np.float64(2.6012870440838833), 'min_wi': np.float64(0.5778050270708418)}
Cross-validated error: 0.00812590739215284
[I 2025-10-28 00:38:33,189] Trial 4 finished with value: 0.0738777767354435 and parameters: {'lr': 0.00039095653296688463, 'num_epochs': 8, 'batch_size': 256, 'num_neighbors': 12, 'lr_decay': 0.8365917663487674}. Best is trial 0 with value: 0.07573778353316493.


Best trial: 5. Best value: 0.0790723:  30%|███       | 6/20 [04:06<09:41, 41.52s/it]

Train wi info: {'gini': np.float64(0.9895552070364592), 'ess': np.float64(47.09923023722296), 'max_wi': np.float64(1289.9957176441565), 'min_wi': np.float64(1.5945526645987732e-17)}
actual reward: [0.08897865]
{'gini': np.float64(0.988338884633349), 'ess': np.float64(38.050781098725096), 'max_wi': np.float64(1023.3010885036053), 'min_wi': np.float64(2.7855656115473464e-17)}
Cross-validated error: 0.009704631177033213
[I 2025-10-28 00:39:15,519] Trial 5 finished with value: 0.07907226475134486 and parameters: {'lr': 0.004526409646968052, 'num_epochs': 9, 'batch_size': 128, 'num_neighbors': 8, 'lr_decay': 0.8102790164031104}. Best is trial 5 with value: 0.07907226475134486.


Best trial: 5. Best value: 0.0790723:  35%|███▌      | 7/20 [04:49<09:04, 41.87s/it]

Train wi info: {'gini': np.float64(0.935811066738094), 'ess': np.float64(949.6046218781102), 'max_wi': np.float64(33.648859482285026), 'min_wi': np.float64(8.164581593534547e-07)}
actual reward: [0.09068954]
{'gini': np.float64(0.9278559014294333), 'ess': np.float64(699.0844592709083), 'max_wi': np.float64(59.59821515859241), 'min_wi': np.float64(8.164581593534547e-07)}
Cross-validated error: 0.008930432181168325
[I 2025-10-28 00:39:58,122] Trial 6 finished with value: 0.0765934892876214 and parameters: {'lr': 0.0036973037939826347, 'num_epochs': 10, 'batch_size': 512, 'num_neighbors': 11, 'lr_decay': 0.8677139052428515}. Best is trial 5 with value: 0.07907226475134486.


Best trial: 5. Best value: 0.0790723:  40%|████      | 8/20 [05:30<08:19, 41.66s/it]

Train wi info: {'gini': np.float64(0.9993539018865217), 'ess': np.float64(10.612403596666821), 'max_wi': np.float64(2698.3528126238584), 'min_wi': np.float64(3.354567770514372e-28)}
actual reward: [0.08373366]
{'gini': np.float64(0.9995823986199438), 'ess': np.float64(3.4833226404689377), 'max_wi': np.float64(1767.748740270187), 'min_wi': np.float64(3.354567770514372e-28)}
Cross-validated error: 0.009561908718234478
[I 2025-10-28 00:40:39,327] Trial 7 finished with value: 0.07858893069166609 and parameters: {'lr': 0.011334737794862614, 'num_epochs': 7, 'batch_size': 128, 'num_neighbors': 11, 'lr_decay': 0.882639487750813}. Best is trial 5 with value: 0.07907226475134486.


Best trial: 8. Best value: 0.125542:  45%|████▌     | 9/20 [06:12<07:40, 41.87s/it] 

Train wi info: {'gini': np.float64(0.9998240040222142), 'ess': np.float64(2.7476043938280785), 'max_wi': np.float64(3338.670601015153), 'min_wi': np.float64(0.0)}
actual reward: [0.07601398]
{'gini': np.float64(0.9996624740951431), 'ess': np.float64(3.6425613596776856), 'max_wi': np.float64(3393.0099983302457), 'min_wi': np.float64(0.0)}
Cross-validated error: 0.05935819838078455
[I 2025-10-28 00:41:21,667] Trial 8 finished with value: 0.1255423126235649 and parameters: {'lr': 0.019962792339869546, 'num_epochs': 9, 'batch_size': 128, 'num_neighbors': 3, 'lr_decay': 0.9981727019298849}. Best is trial 8 with value: 0.1255423126235649.


Best trial: 8. Best value: 0.125542:  50%|█████     | 10/20 [06:51<06:48, 40.85s/it]

Train wi info: {'gini': np.float64(0.02988263337376717), 'ess': np.float64(14941.057466891712), 'max_wi': np.float64(1.4577717524603682), 'min_wi': np.float64(0.8587533823921971)}
actual reward: [0.08652329]
{'gini': np.float64(0.030067640892557992), 'ess': np.float64(9961.213175590576), 'max_wi': np.float64(1.4286360676885788), 'min_wi': np.float64(0.8587533823921971)}
Cross-validated error: 0.008084587308026445
[I 2025-10-28 00:42:00,232] Trial 9 finished with value: 0.07375384107847219 and parameters: {'lr': 0.00037235474486625486, 'num_epochs': 1, 'batch_size': 128, 'num_neighbors': 10, 'lr_decay': 0.9682559971128273}. Best is trial 8 with value: 0.1255423126235649.


Best trial: 8. Best value: 0.125542:  55%|█████▌    | 11/20 [07:32<06:08, 40.90s/it]

Train wi info: {'gini': np.float64(0.9995617200162231), 'ess': np.float64(6.403061557932268), 'max_wi': np.float64(2396.325856090774), 'min_wi': np.float64(2.405178316465982e-38)}
actual reward: [0.08493112]
{'gini': np.float64(0.9994374661681499), 'ess': np.float64(4.715046508963767), 'max_wi': np.float64(7125.848422911994), 'min_wi': np.float64(5.003719230384735e-38)}
Cross-validated error: 0.010271556761021773
[I 2025-10-28 00:42:41,240] Trial 10 finished with value: 0.08068411101527984 and parameters: {'lr': 0.021005390578878062, 'num_epochs': 7, 'batch_size': 512, 'num_neighbors': 3, 'lr_decay': 0.9968529762863917}. Best is trial 8 with value: 0.1255423126235649.


Best trial: 8. Best value: 0.125542:  60%|██████    | 12/20 [08:13<05:28, 41.03s/it]

Train wi info: {'gini': np.float64(0.9996278965868041), 'ess': np.float64(5.65595690791545), 'max_wi': np.float64(7166.459401093727), 'min_wi': np.float64(3.527147255712749e-33)}
actual reward: [0.08257459]
{'gini': np.float64(0.9997599261800706), 'ess': np.float64(1.890561178816901), 'max_wi': np.float64(26254.15425403589), 'min_wi': np.float64(3.0261921154734124e-33)}
Cross-validated error: 0.008167051883942806
[I 2025-10-28 00:43:22,576] Trial 11 finished with value: 0.07398861438573537 and parameters: {'lr': 0.02191653562937588, 'num_epochs': 7, 'batch_size': 512, 'num_neighbors': 3, 'lr_decay': 0.9990584858557914}. Best is trial 8 with value: 0.1255423126235649.


Best trial: 8. Best value: 0.125542:  65%|██████▌   | 13/20 [08:56<04:50, 41.55s/it]

Train wi info: {'gini': np.float64(0.9995344823266296), 'ess': np.float64(7.203840540587447), 'max_wi': np.float64(1581.3075527287274), 'min_wi': np.float64(3.9509116487048906e-30)}
actual reward: [0.08359005]
{'gini': np.float64(0.9997271950309167), 'ess': np.float64(1.8207047727499222), 'max_wi': np.float64(28385.96721130266), 'min_wi': np.float64(3.9509116487048906e-30)}
Cross-validated error: 0.01028351426959734
[I 2025-10-28 00:44:05,320] Trial 12 finished with value: 0.08083330968944513 and parameters: {'lr': 0.013917018370178315, 'num_epochs': 10, 'batch_size': 512, 'num_neighbors': 15, 'lr_decay': 0.9975359665569817}. Best is trial 8 with value: 0.1255423126235649.


Best trial: 8. Best value: 0.125542:  70%|███████   | 14/20 [09:39<04:11, 41.87s/it]

Train wi info: {'gini': np.float64(0.49929606451964686), 'ess': np.float64(5801.504438443597), 'max_wi': np.float64(9.64844894062166), 'min_wi': np.float64(0.041974754994951666)}
actual reward: [0.08770099]
{'gini': np.float64(0.4900429169442718), 'ess': np.float64(4169.638044191465), 'max_wi': np.float64(9.64844894062166), 'min_wi': np.float64(0.03926833703715242)}
Cross-validated error: 0.008239301097658516
[I 2025-10-28 00:44:47,931] Trial 13 finished with value: 0.07430261473708547 and parameters: {'lr': 0.0012812972637397516, 'num_epochs': 10, 'batch_size': 512, 'num_neighbors': 15, 'lr_decay': 0.9210931645523677}. Best is trial 8 with value: 0.1255423126235649.


Best trial: 8. Best value: 0.125542:  75%|███████▌  | 15/20 [10:22<03:31, 42.27s/it]

Train wi info: {'gini': np.float64(0.9997898129019025), 'ess': np.float64(3.48573325798536), 'max_wi': np.float64(3081.0691696338904), 'min_wi': np.float64(0.0)}
actual reward: [0.07524929]
{'gini': np.float64(0.9996885889850775), 'ess': np.float64(3.4125666071396568), 'max_wi': np.float64(3393.020110333836), 'min_wi': np.float64(0.0)}
Cross-validated error: 0.019874582589839363
[I 2025-10-28 00:45:31,141] Trial 14 finished with value: 0.10085845445598315 and parameters: {'lr': 0.0765954515628084, 'num_epochs': 9, 'batch_size': 128, 'num_neighbors': 15, 'lr_decay': 0.9763743475305927}. Best is trial 8 with value: 0.1255423126235649.


Best trial: 8. Best value: 0.125542:  80%|████████  | 16/20 [11:02<02:46, 41.66s/it]

Train wi info: {'gini': np.float64(0.9995937143512388), 'ess': np.float64(6.031358045066096), 'max_wi': np.float64(2844.888008629834), 'min_wi': np.float64(0.0)}
actual reward: [0.07563382]
{'gini': np.float64(0.9992795578351352), 'ess': np.float64(7.519167376924642), 'max_wi': np.float64(3666.937366922634), 'min_wi': np.float64(0.0)}
Cross-validated error: 0.007355601521696173
[I 2025-10-28 00:46:11,360] Trial 15 finished with value: 0.07123999835295777 and parameters: {'lr': 0.08475770585365226, 'num_epochs': 3, 'batch_size': 128, 'num_neighbors': 5, 'lr_decay': 0.9655251804166723}. Best is trial 8 with value: 0.1255423126235649.


Best trial: 8. Best value: 0.125542:  85%|████████▌ | 17/20 [11:45<02:05, 41.90s/it]

Train wi info: {'gini': np.float64(0.9998088045591834), 'ess': np.float64(2.930080001942311), 'max_wi': np.float64(1931.8148279826214), 'min_wi': np.float64(0.0)}
actual reward: [0.07073557]
{'gini': np.float64(0.9997656039320935), 'ess': np.float64(2.448777875403432), 'max_wi': np.float64(5095.286080560839), 'min_wi': np.float64(0.0)}
Cross-validated error: 0.005197713136560884
[I 2025-10-28 00:46:53,836] Trial 16 finished with value: 0.06406419520558879 and parameters: {'lr': 0.05678279053202348, 'num_epochs': 8, 'batch_size': 128, 'num_neighbors': 13, 'lr_decay': 0.9078387175818443}. Best is trial 8 with value: 0.1255423126235649.


Best trial: 8. Best value: 0.125542:  90%|█████████ | 18/20 [12:27<01:24, 42.04s/it]

Train wi info: {'gini': np.float64(0.9994553612710136), 'ess': np.float64(8.894481725586088), 'max_wi': np.float64(4139.239700683446), 'min_wi': np.float64(0.0)}
actual reward: [0.07771075]
{'gini': np.float64(0.9996214922185778), 'ess': np.float64(4.172393906397507), 'max_wi': np.float64(3301.042376450184), 'min_wi': np.float64(0.0)}
Cross-validated error: 0.007150004642948726
[I 2025-10-28 00:47:36,195] Trial 17 finished with value: 0.06842290634096171 and parameters: {'lr': 0.035334531196929496, 'num_epochs': 8, 'batch_size': 128, 'num_neighbors': 13, 'lr_decay': 0.9727386232170139}. Best is trial 8 with value: 0.1255423126235649.


Best trial: 8. Best value: 0.125542:  95%|█████████▌| 19/20 [13:08<00:41, 41.78s/it]

Train wi info: {'gini': np.float64(0.9995665905650811), 'ess': np.float64(7.2599986968445105), 'max_wi': np.float64(1953.9606705833335), 'min_wi': np.float64(6.915787852514309e-29)}
actual reward: [0.08130133]
{'gini': np.float64(0.9995333412293267), 'ess': np.float64(5.0904239773481175), 'max_wi': np.float64(2921.146967368231), 'min_wi': np.float64(9.237573833884467e-29)}
Cross-validated error: 0.005944900535702796
[I 2025-10-28 00:48:17,383] Trial 18 finished with value: 0.06520743906960992 and parameters: {'lr': 0.010237900630086916, 'num_epochs': 6, 'batch_size': 128, 'num_neighbors': 8, 'lr_decay': 0.9389372936027914}. Best is trial 8 with value: 0.1255423126235649.


Best trial: 8. Best value: 0.125542: 100%|██████████| 20/20 [13:50<00:00, 41.50s/it]

Train wi info: {'gini': np.float64(0.9555948736661486), 'ess': np.float64(679.4494520854925), 'max_wi': np.float64(66.74480736233829), 'min_wi': np.float64(6.305202236406696e-08)}
actual reward: [0.08929322]
{'gini': np.float64(0.9508456492529058), 'ess': np.float64(484.1616772890529), 'max_wi': np.float64(83.04955803682469), 'min_wi': np.float64(6.305202236406696e-08)}
Cross-validated error: 0.00720877371123147
[I 2025-10-28 00:48:58,761] Trial 19 finished with value: 0.07055214562694888 and parameters: {'lr': 0.00158959421333582, 'num_epochs': 9, 'batch_size': 256, 'num_neighbors': 5, 'lr_decay': 0.9816251741348216}. Best is trial 8 with value: 0.1255423126235649.





{'gini': np.float64(0.9996274160207251), 'ess': np.float64(4.076337164453286), 'max_wi': np.float64(2820.2154654927654), 'min_wi': np.float64(3.673717188363759e-35)}


Unnamed: 0,policy_rewards,ipw,reg_dm,conv_dm,conv_dr,conv_sndr,action_diff_to_real,action_delta,context_diff_to_real,context_delta
0,0.08647502,0.0853,0.08531822,0.08859722,0.08369196,0.08369196,0.80232812,0.0,0.84032376,0.0
15000,0.07913758,0.5044776,0.09883096,0.08630665,0.54980022,0.50045389,1.12744744,0.7589929,1.15723686,0.53516479


In [17]:
df4[['policy_rewards', 'ipw', 'reg_dm', 'conv_dm', 'conv_dr', 'conv_sndr', 'action_diff_to_real', 'action_delta', 'context_diff_to_real', 'context_delta']]

Unnamed: 0,policy_rewards,ipw,reg_dm,conv_dm,conv_dr,conv_sndr,action_diff_to_real,action_delta,context_diff_to_real,context_delta
0,0.08610747,0.0879,0.08802531,0.09188019,0.08858942,0.08858942,0.7569287,0.0,0.87627132,0.0
15000,0.08516441,0.08017263,0.08906158,0.09463238,0.0795643,0.07993675,0.82602002,0.24147136,0.93809648,0.1280313


### Poicy Via argmax(r_hat - error_hat) through cross validation

In [18]:
df4[['policy_rewards', 'ipw', 'reg_dm', 'conv_dm', 'conv_dr', 'conv_sndr', 'action_diff_to_real', 'action_delta', 'context_diff_to_real', 'context_delta']]

Unnamed: 0,policy_rewards,ipw,reg_dm,conv_dm,conv_dr,conv_sndr,action_diff_to_real,action_delta,context_diff_to_real,context_delta
0,0.08610747,0.0879,0.08802531,0.09188019,0.08858942,0.08858942,0.7569287,0.0,0.87627132,0.0
15000,0.08516441,0.08017263,0.08906158,0.09463238,0.0795643,0.07993675,0.82602002,0.24147136,0.93809648,0.1280313


### Policy Via using actual policy value

In [19]:
# Show the performance metrics
df4[['policy_rewards', 'ipw', 'reg_dm', 'conv_dm', 'conv_dr', 'conv_sndr', 'action_diff_to_real', 'action_delta', 'context_diff_to_real', 'context_delta']]


Unnamed: 0,policy_rewards,ipw,reg_dm,conv_dm,conv_dr,conv_sndr,action_diff_to_real,action_delta,context_diff_to_real,context_delta
0,0.08610747,0.0879,0.08802531,0.09188019,0.08858942,0.08858942,0.7569287,0.0,0.87627132,0.0
15000,0.08516441,0.08017263,0.08906158,0.09463238,0.0795643,0.07993675,0.82602002,0.24147136,0.93809648,0.1280313
