In [1]:
# Setup cell: imports, compute-device selection, global torch / pandas options and RNG seed.
# Statement order matters here: the warnings filter is installed before the heavy imports,
# and "/code" is appended to sys.path before the project-local modules
# (estimators, simulation_utils, models, training_utils, custom_losses) are imported.
import warnings
warnings.filterwarnings("ignore")  # silences ALL warnings for the rest of the kernel session
from copy import deepcopy
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import sys

# NOTE(review): hard-coded absolute path, presumably where the project sources are mounted;
# consider making it configurable.
sys.path.append("/code")

from tqdm import tqdm
import torch
# device = torch.device('cpu')
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# import gym
# import recogym

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim

# cuDNN autotuning only matters (and is only enabled) when CUDA is available.
torch.backends.cudnn.benchmark = torch.cuda.is_available()
if torch.cuda.is_available():
    torch.set_float32_matmul_precision("high")  # TF32 = big speedup on Ada


from sklearn.utils import check_random_state

# implementing OPE of the IPWLearner using synthetic bandit data
from sklearn.linear_model import LogisticRegression

import matplotlib.pyplot as plt

from scipy.special import softmax
import optuna
# from memory_profiler import profile


# Project-local modules (resolved through the sys.path entry added above).
from estimators import (
    DirectMethod as DM
)

from simulation_utils import (
    eval_policy,
    generate_dataset,
    create_simulation_data_from_pi,
    get_train_data,
    get_opl_results_dict,
    CustomCFDataset,
    calc_reward,
    get_weights_info
)

from models import (    
    CFModel,
    NeighborhoodModel,
    BPRModel, 
    RegressionModel
)

from training_utils import (
    train,
    validation_loop, 
    cv_score_model
 )

from custom_losses import (
    SNDRPolicyLoss
    )

# Global seed; `check_random_state` turns the int into a numpy RandomState.
# NOTE(review): `random_` is not referenced in the cells visible here; confirm it is still needed.
random_state=12345
random_ = check_random_state(random_state)

# Display floats with 8 decimals so small reward differences remain visible in result tables.
pd.options.display.float_format = '{:,.8f}'.format

Using device: cpu
Using device: cpu
Using device: cpu


In [2]:
def get_trial_results(
    our_x,
    our_a,
    emb_x,
    emb_a,
    original_x,
    original_a,
    dataset,
    val_data,
    original_policy_prob,
    neighberhoodmodel,
    regression_model,
    dm
):
    """Score one learned policy and package every metric into a results dict.

    The policy is the row-wise softmax of ``our_x @ our_a.T`` with a trailing singleton
    axis appended. It is scored by:

    * its true reward (``calc_reward`` on ``dataset``),
    * the estimator metrics from ``eval_policy`` on ``val_data``,
    * RMSE of the learned action/context embeddings against both the ground-truth
      (``emb_*``) and the original (``original_*``) embeddings,
    * a Direct-Method estimate that uses ``regression_model``'s predictions on the
      validation rows.

    Returns:
        Whatever ``get_opl_results_dict`` builds from the DM result and the metric row.
    """
    def _rmse(reference, candidate):
        return np.sqrt(np.mean((reference - candidate) ** 2))

    logits = our_x @ our_a.T
    policy = softmax(logits, axis=1)[..., np.newaxis]

    true_reward = calc_reward(dataset, policy)
    estimator_metrics = eval_policy(neighberhoodmodel, val_data, original_policy_prob, policy)

    # Order matters: it defines the layout of the row handed to get_opl_results_dict.
    metric_parts = (
        true_reward,
        estimator_metrics,
        _rmse(emb_a, our_a),        # action embeddings vs ground truth
        _rmse(original_a, our_a),   # action embeddings vs starting point
        _rmse(emb_x, our_x),        # context embeddings vs ground truth
        _rmse(original_x, our_x),   # context embeddings vs starting point
    )
    metric_row = np.concatenate([np.atleast_1d(part) for part in metric_parts])

    policy_on_val = policy[val_data['x_idx']]
    predicted_rewards = regression_model.predict(val_data['x'])
    reg_dm = dm.estimate_policy_value(policy_on_val, predicted_rewards)

    return get_opl_results_dict(np.array([reg_dm]), np.array([metric_row]))

## `trainer_trial` Function

This function runs policy learning experiments using offline bandit data and evaluates various estimators.

### Parameters
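- **num_runs** (int): Number of experimental runs per training size
- **num_neighbors** (int): Number of neighbors for the neighborhood model used in the baseline (training size 0) row
- **num_rounds_list** (list): List of training set sizes to evaluate
- **dataset** (dict): Contains dataset information including embeddings, action probabilities, and reward probabilities
- **batch_size** (int): Batch size for training the policy model
- **val_size** (int, default=2000): Number of simulated rounds in each validation split
- **n_trials** (int, default=10): Number of Optuna trials per run
- **prev_best_params** (dict, optional): Hyperparameters enqueued as the first Optuna trial (warm start)

The learning rate, number of epochs, learning-rate decay, batch size and neighborhood size of the trial models are not arguments. Optuna tunes them for each run.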

### Process Flow
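1. Computes a baseline row (training size 0) that scores the initial embeddings
2. For each training size in `num_rounds_list`, and for each of the `num_runs` runs:
   - Simulates data from the logging policy $\pi_0 = \mathrm{softmax}(x\,a^\top)$ induced by the initial embeddings
   - Splits it into training data and a held-out validation set
   - Runs an Optuna study that tunes the policy-learning hyperparameters on the validation set
   - Fits regression and neighborhood models for reward estimation
   - Trains a counterfactual policy model with the best hyperparameters
   - Evaluates policy performance using various estimators
   - Collects metrics on policy reward and embedding quality
3. Averages the per-run metrics for each training size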

### Returns
- **DataFrame**: Results table with rows indexed by training size and columns for various metrics:
  - `policy_rewards`: True expected reward of the learned policy
  - Various estimator errors (`ipw`, `reg_dm`, `conv_dm`, `conv_dr`, `conv_sndr`)
  - Variance metrics for each estimator
  - Embedding quality metrics comparing learned representations to ground truth

### Implementation Notes
- Collects offline data with the softmax logging policy induced by the initial embeddings (`our_x`, `our_a`)
- Employs Self-Normalized Doubly Robust (SNDR) policy learning
- Measures embedding quality via RMSE to original/ground truth embeddings

In [None]:
def _mean_dict(dicts):
    """Element-wise mean over a list of result dicts.

    Keys are taken from the first dict. For each key, the values are converted with
    ``np.asarray`` and stacked along a new axis 0 before averaging, so all values for
    one key must share a shape. Returns ``{}`` for an empty list.
    """
    if not dicts:
        return {}
    out = {}
    for key in dicts[0].keys():
        stacked = np.stack([np.asarray(d[key]) for d in dicts if key in d], axis=0)
        out[key] = np.mean(stacked, axis=0)
    return out


def _fit_cf_policy(
    cf_dataset,
    scores_all,
    init_x,
    init_a,
    n_users,
    n_actions,
    emb_dim,
    lr,
    num_epochs,
    lr_decay,
    batch_size,
    device,
    num_workers=0,
):
    """Train a ``CFModel`` with ``SNDRPolicyLoss`` and return its learned embeddings.

    The model is initialised from ``init_x`` / ``init_a`` and trained for ``num_epochs``
    epochs. The learning rate starts at ``lr`` and is multiplied by ``lr_decay`` before
    every epoch after the first.

    Returns:
        ``(learned_x, learned_a)`` as NumPy arrays on the CPU.
    """
    # torch.tensor always copies. torch.as_tensor shares memory with a float32 CPU NumPy
    # array, so if CFModel wraps the initial tensor as a Parameter, optimizer updates could
    # write back into the caller's (supposedly untouched) embeddings.
    model = CFModel(
        n_users, n_actions, emb_dim,
        initial_user_embeddings=torch.tensor(np.asarray(init_x), dtype=torch.float32, device=device),
        initial_actions_embeddings=torch.tensor(np.asarray(init_a), dtype=torch.float32, device=device),
    ).to(device)
    assert (not torch.cuda.is_available()) or next(model.parameters()).is_cuda

    loader = DataLoader(
        cf_dataset, batch_size=batch_size, shuffle=True,
        pin_memory=torch.cuda.is_available(),
        num_workers=num_workers, persistent_workers=bool(num_workers),
    )

    current_lr = lr
    for epoch in range(num_epochs):
        if epoch > 0:
            current_lr *= lr_decay
        # One epoch per call so the learning rate can be decayed between epochs.
        # NOTE(review): if `train` builds a fresh optimizer on every call, optimizer state
        # (e.g. momentum) is reset each epoch. Confirm this is intended.
        train(
            model, loader, scores_all,
            criterion=SNDRPolicyLoss(), num_epochs=1, lr=current_lr, device=str(device),
        )

    learned_x, learned_a = model.get_params()
    return learned_x.detach().cpu().numpy(), learned_a.detach().cpu().numpy()


def trainer_trial(
    num_runs,
    num_neighbors,
    num_rounds_list,
    dataset,
    batch_size,
    val_size=2000,
    n_trials=10,
    prev_best_params=None
):
    """Tune, train and evaluate an SNDR-trained CF policy for several training-set sizes.

    For every size in ``num_rounds_list`` and every run:

    1. Bandit data is simulated from the logging policy ``softmax(our_x @ our_a.T)``.
    2. A seeded Optuna study tunes the policy-learning hyperparameters on a held-out
       validation split.
    3. A final model is trained with the best trial's parameters and scored with
       ``get_trial_results``.

    Per-run results are averaged per size. Row ``0`` is a baseline that scores the
    initial embeddings.

    Args:
        num_runs: Independent resamples (each with its own Optuna study) per training size.
        num_neighbors: Neighborhood size for the baseline (size 0) neighborhood model.
        num_rounds_list: Training-set sizes to evaluate.
        dataset: Dict with at least ``our_x``, ``our_a``, ``emb_x``, ``emb_a``,
            ``original_x``, ``original_a``, ``n_users``, ``n_actions``, ``emb_dim``
            (plus whatever the simulation helpers read).
        batch_size: Fallback batch size for the final fit. The Optuna-tuned ``batch_size``
            is used when present, so the final model is trained the same way it was tuned.
        val_size: Simulated rounds per validation split (and in the baseline sample).
        n_trials: Optuna trials per run.
        prev_best_params: Optional hyperparameter dict enqueued as the first trial of the
            first study. Each later study is warm-started with the previous study's best.

    Returns:
        ``(results_df, best_hyperparams_by_size)``. ``results_df`` is indexed by training
        size (0 = baseline), with one column per key returned by ``get_trial_results``.
        ``best_hyperparams_by_size`` maps ``train_size -> run -> {"params", "reward"}``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.benchmark = torch.cuda.is_available()
    if torch.cuda.is_available():
        torch.set_float32_matmul_precision("high")
    num_workers = 4 if torch.cuda.is_available() else 0

    # ===== unpack dataset (originals are never modified) =====
    our_x_orig, our_a_orig = dataset["our_x"], dataset["our_a"]
    emb_x, emb_a = dataset["emb_x"], dataset["emb_a"]
    original_x, original_a = dataset["original_x"], dataset["original_a"]
    n_users, n_actions, emb_dim = dataset["n_users"], dataset["n_actions"], dataset["emb_dim"]
    all_user_indices = np.arange(n_users, dtype=np.int64)

    dm = DM()
    results = {}
    best_hyperparams_by_size = {}
    last_best_params = prev_best_params

    # Logging policy induced by the initial embeddings. It does not depend on the training
    # size or run, so it is computed once.
    pi_0 = softmax(our_x_orig @ our_a_orig.T, axis=1)
    original_policy_prob = np.expand_dims(pi_0, -1)

    # ===== baseline (sample size = 0) =====
    baseline_sim = create_simulation_data_from_pi(dataset, pi_0, val_size, random_state=0)
    baseline_idx = np.arange(val_size)
    # The same sample serves as train and val; it only feeds the baseline row.
    baseline_train = get_train_data(n_actions, val_size, baseline_sim, baseline_idx, our_x_orig)
    baseline_val = get_train_data(n_actions, val_size, baseline_sim, baseline_idx, our_x_orig)

    baseline_regression = RegressionModel(
        n_actions=n_actions, action_context=our_x_orig,
        base_model=LogisticRegression(random_state=12345)
    )
    baseline_regression.fit(baseline_train['x'], baseline_train['a'], baseline_train['r'])

    baseline_neighborhood = NeighborhoodModel(
        baseline_train['x_idx'], baseline_train['a'],
        our_a_orig, our_x_orig, baseline_train['r'],
        num_neighbors=num_neighbors
    )

    results[0] = get_trial_results(
        our_x_orig, our_a_orig, emb_x, emb_a, original_x, original_a,
        dataset, baseline_val, original_policy_prob,
        baseline_neighborhood, baseline_regression, dm
    )

    # ===== main loop over training sizes =====
    for train_size in num_rounds_list:
        run_results = []
        best_hyperparams_by_size[train_size] = {}

        for run in range(num_runs):
            # The same seed drives both the data resample and the Optuna sampler,
            # so each (size, run) pair is reproducible.
            run_seed = (run + 1) * (train_size + 17)

            # --- resample for this run: first train_size rows train, next val_size validate ---
            simulation_data = create_simulation_data_from_pi(
                dataset, pi_0, train_size + val_size, random_state=run_seed
            )
            train_data = get_train_data(
                n_actions, train_size, simulation_data, np.arange(train_size), our_x_orig
            )
            val_data = get_train_data(
                n_actions, val_size, simulation_data, np.arange(val_size) + train_size, our_x_orig
            )

            cf_dataset = CustomCFDataset(
                train_data['x_idx'], train_data['a'], train_data['r'], original_policy_prob
            )

            # --- Optuna objective bound to this run's data ---
            def objective(trial):
                lr = trial.suggest_float("lr", 1e-4, 1e-2, log=True)
                epochs = trial.suggest_int("num_epochs", 1, 10)
                trial_batch_size = trial.suggest_categorical("batch_size", [64, 128, 256, 512])
                trial_num_neighbors = trial.suggest_int("num_neighbors", 3, 15)
                lr_decay = trial.suggest_float("lr_decay", 0.8, 1.0)

                trial_neigh_model = NeighborhoodModel(
                    train_data['x_idx'], train_data['a'],
                    our_a_orig, our_x_orig, train_data['r'],
                    num_neighbors=trial_num_neighbors
                )
                trial_scores_all = torch.as_tensor(
                    trial_neigh_model.predict(all_user_indices),
                    device=device, dtype=torch.float32
                )

                trial_x, trial_a = _fit_cf_policy(
                    cf_dataset, trial_scores_all, our_x_orig, our_a_orig,
                    n_users, n_actions, emb_dim,
                    lr=lr, num_epochs=epochs, lr_decay=lr_decay,
                    batch_size=trial_batch_size, device=device, num_workers=num_workers,
                )

                pi_i = softmax(trial_x @ trial_a.T, axis=1)
                print(get_weights_info(pi_i, original_policy_prob))
                # validation score used for model selection
                return cv_score_model(val_data, trial_scores_all, pi_i)

            # --- run Optuna for this run (seeded sampler for reproducibility) ---
            study = optuna.create_study(
                direction="maximize",
                sampler=optuna.samplers.TPESampler(seed=run_seed),
            )
            if last_best_params is not None:
                study.enqueue_trial(last_best_params)
            study.optimize(objective, n_trials=n_trials, show_progress_bar=True)

            best_params = study.best_params
            last_best_params = best_params  # warm-start the next study
            best_hyperparams_by_size[train_size][run] = {
                "params": best_params,
                "reward": study.best_value
            }

            # --- final training with the best params on this run's data ---
            regression_model = RegressionModel(
                n_actions=n_actions, action_context=our_x_orig,
                base_model=LogisticRegression(random_state=12345)
            )
            regression_model.fit(
                train_data['x'], train_data['a'], train_data['r'],
                original_policy_prob[train_data['x_idx'], train_data['a']].squeeze()
            )

            neighborhood_model = NeighborhoodModel(
                train_data['x_idx'], train_data['a'],
                our_a_orig, our_x_orig, train_data['r'],
                num_neighbors=best_params['num_neighbors']
            )
            scores_all = torch.as_tensor(
                neighborhood_model.predict(all_user_indices),
                device=device, dtype=torch.float32
            )

            # Use the tuned batch size: lr / epochs / decay were selected for it, and a
            # different batch size changes the number of gradient steps per epoch.
            learned_x, learned_a = _fit_cf_policy(
                cf_dataset, scores_all, our_x_orig, our_a_orig,
                n_users, n_actions, emb_dim,
                lr=best_params['lr'],
                num_epochs=best_params['num_epochs'],
                lr_decay=best_params['lr_decay'],
                batch_size=best_params.get('batch_size', batch_size),
                device=device, num_workers=num_workers,
            )

            run_results.append(get_trial_results(
                learned_x, learned_a,          # learned (policy) embeddings
                emb_x, emb_a,                  # ground-truth embedding refs
                original_x, original_a,        # original clean refs
                dataset,
                val_data,                      # this run's val split
                original_policy_prob,
                neighborhood_model,
                regression_model,
                dm
            ))

            # memory hygiene
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # === aggregate per-run results (mean) and store under this train_size ===
        results[train_size] = _mean_dict(run_results)

    return pd.DataFrame.from_dict(results, orient='index'), best_hyperparams_by_size

## Learning

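We will run several simulations on a generated dataset. The dataset is generated as follows. We have users $U$ and actions $A$ with embeddings
$$u_i \sim \mathcal{N}(0, I_d), \qquad a_j \sim \mathcal{N}(0, I_d),$$
where $d$ is the embedding dimension (`emb_dim`). The reward probabilities and rewards are
$$p_{ij} = \frac{1}{5 + e^{-u_i^\top a_j}},$$
$$r_{ij} \sim \mathrm{Bernoulli}(p_{ij}).$$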

We have a policy $\pi$, and its ground-truth reward is calculated by
$$R_{gt} = \sum_{i}\sum_{j} \pi_{ij} \, p_{ij}$$

Our parameters for the dataset (as set in `dataset_params` and the experiment configuration below) will be
$$EmbDim = 16$$
$$NumActions = 500$$
$$NumUsers = 500$$
$$NeighborhoodSize = 6$$

To learn a new policy from $\pi$, we will sample from:
$$\pi_{start} = (1-\epsilon)\,\pi + \epsilon\,\pi_{random}$$

In [4]:
# Settings for the synthetic bandit dataset consumed by `generate_dataset`.
dataset_params = {
    "n_actions": 500,
    "n_users": 500,
    "emb_dim": 16,
    "eps": 0.6,  # epsilon for the noise in the ground-truth policy representation
    "ctr": 0.1,
}

train_dataset = generate_dataset(dataset_params)

Random Item CTR: 0.07066414727263938
Optimal greedy CTR: 0.09999926940951757
Optimal Stochastic CTR: 0.09995326955796031
Our Initial CTR: 0.08610747363354625


In [5]:
# ----- experiment configuration -----
num_runs = 1                 # independent resamples per training size
batch_size = 200             # batch size argument passed to trainer_trial
num_neighbors = 6            # neighborhood size for the baseline (size 0) model
n_trials_for_optuna = 50     # Optuna trials per run
num_rounds_list = [500, 1000, 2000, 10000, 20000]  # training-set sizes to evaluate

# Hand-picked starting point, enqueued as the first Optuna trial (warm start).
best_params_to_use = dict(
    lr=0.0095,          # learning rate
    num_epochs=5,       # training epochs
    batch_size=64,      # training batch size
    num_neighbors=8,    # neighbors for the neighborhood model
    lr_decay=0.85,      # per-epoch learning-rate decay factor
)

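### Experiment 1

$$emb = 0.7 \cdot gt + 0.3 \cdot noise$$

The run below tunes the learning rate, number of epochs, batch size, neighborhood size and learning-rate decay with Optuna for each training size, warm-started from `best_params_to_use`. The earlier hand-picked values are kept for reference:
$$lr = 0.005$$
$$n_{epochs} = 1$$
$$BatchSize = 50$$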

In [6]:
print("Value of num_rounds_list:", num_rounds_list)

# Run the optimization: one Optuna study per (training size, run), then a final fit
# with the best trial's hyperparameters.
df4, best_hyperparams_by_size = trainer_trial(
    num_runs,
    num_neighbors,
    num_rounds_list,
    train_dataset,
    batch_size,
    val_size=10000,
    n_trials=n_trials_for_optuna,
    prev_best_params=best_params_to_use,
)

# `best_hyperparams_by_size` is nested as {train_size: {run: {"params": ..., "reward": ...}}}.
# Iterate over the runs before reading "params"; indexing the size level directly raises KeyError.
print("\n=== BEST HYPERPARAMETERS BY TRAINING SIZE ===")
for train_size, runs in best_hyperparams_by_size.items():
    print(f"\nTraining Size: {train_size}")
    for run, run_info in runs.items():
        print(f"Run {run} - best validation score: {run_info['reward']:.6f}")
        print("Parameters:")
        for param_name, value in run_info["params"].items():
            print(f"  {param_name}: {value}")
print("===========================\n")

# Show the performance metrics
df4[['policy_rewards', 'ipw', 'reg_dm', 'conv_dm', 'conv_dr', 'conv_sndr', 'action_diff_to_real', 'action_delta', 'context_diff_to_real', 'context_delta']]

Value of num_rounds_list: [500, 1000, 2000, 10000, 20000]


[I 2025-10-01 15:30:21,228] A new study created in memory with name: no-name-b1c182ce-86bb-4c4a-a476-5a162cf51afe
Best trial: 0. Best value: -0.00865103:   2%|▏         | 1/50 [00:02<01:52,  2.29s/it]

{'gini': np.float64(0.40408135351240027), 'ess': np.float64(165822.8754548929), 'max_wi': np.float64(6.6850931907224105), 'min_wi': np.float64(0.012997558145085754)}
[I 2025-10-01 15:30:23,516] Trial 0 finished with value: -0.008651025705584873 and parameters: {'lr': 0.0095, 'num_epochs': 5, 'batch_size': 64, 'num_neighbors': 8, 'lr_decay': 0.85}. Best is trial 0 with value: -0.008651025705584873.


Best trial: 1. Best value: -0.00830631:   4%|▍         | 2/50 [00:03<01:30,  1.88s/it]

{'gini': np.float64(0.009550868274245028), 'ess': np.float64(249928.04691040178), 'max_wi': np.float64(1.0463413835121627), 'min_wi': np.float64(0.9432234036870707)}
[I 2025-10-01 15:30:25,107] Trial 1 finished with value: -0.008306305473049015 and parameters: {'lr': 0.0003492101076594037, 'num_epochs': 3, 'batch_size': 128, 'num_neighbors': 13, 'lr_decay': 0.9921847299109916}. Best is trial 1 with value: -0.008306305473049015.


Best trial: 1. Best value: -0.00830631:   6%|▌         | 3/50 [00:05<01:22,  1.75s/it]

{'gini': np.float64(0.2810461598458946), 'ess': np.float64(201431.56187775044), 'max_wi': np.float64(3.9347998970584532), 'min_wi': np.float64(0.10695843642185762)}
[I 2025-10-01 15:30:26,701] Trial 2 finished with value: -0.008778997760828673 and parameters: {'lr': 0.0047237651051862724, 'num_epochs': 5, 'batch_size': 64, 'num_neighbors': 6, 'lr_decay': 0.9371375875369643}. Best is trial 1 with value: -0.008306305473049015.


Best trial: 1. Best value: -0.00830631:   8%|▊         | 4/50 [00:07<01:18,  1.71s/it]

{'gini': np.float64(0.06190253595218667), 'ess': np.float64(247069.20919380258), 'max_wi': np.float64(1.370503699137448), 'min_wi': np.float64(0.6860749619442146)}
[I 2025-10-01 15:30:28,365] Trial 3 finished with value: -0.008335498572771723 and parameters: {'lr': 0.0014021547187254787, 'num_epochs': 3, 'batch_size': 64, 'num_neighbors': 13, 'lr_decay': 0.941410494998372}. Best is trial 1 with value: -0.008306305473049015.


Best trial: 1. Best value: -0.00830631:  10%|█         | 5/50 [00:08<01:15,  1.68s/it]

{'gini': np.float64(0.0015992662699963812), 'ess': np.float64(249997.90653245358), 'max_wi': np.float64(1.0065715454666753), 'min_wi': np.float64(0.9908295655828836)}
[I 2025-10-01 15:30:29,976] Trial 4 finished with value: -0.008518659426660673 and parameters: {'lr': 0.00016373991494485015, 'num_epochs': 3, 'batch_size': 512, 'num_neighbors': 5, 'lr_decay': 0.9193414861989612}. Best is trial 1 with value: -0.008306305473049015.


Best trial: 5. Best value: -0.00829023:  12%|█▏        | 6/50 [00:10<01:13,  1.67s/it]

{'gini': np.float64(0.005977617056358519), 'ess': np.float64(249970.78120838874), 'max_wi': np.float64(1.0247484937912767), 'min_wi': np.float64(0.9660362403397448)}
[I 2025-10-01 15:30:31,624] Trial 5 finished with value: -0.008290231227185519 and parameters: {'lr': 0.0003331166809194788, 'num_epochs': 8, 'batch_size': 512, 'num_neighbors': 12, 'lr_decay': 0.8649479272041829}. Best is trial 5 with value: -0.008290231227185519.


Best trial: 5. Best value: -0.00829023:  14%|█▍        | 7/50 [00:12<01:11,  1.65s/it]

{'gini': np.float64(0.02475003206153017), 'ess': np.float64(249509.93225722923), 'max_wi': np.float64(1.127181546895956), 'min_wi': np.float64(0.8596019955465549)}
[I 2025-10-01 15:30:33,247] Trial 6 finished with value: -0.008558195255526941 and parameters: {'lr': 0.0022681893135742498, 'num_epochs': 2, 'batch_size': 256, 'num_neighbors': 9, 'lr_decay': 0.9334097204313898}. Best is trial 5 with value: -0.008290231227185519.


Best trial: 5. Best value: -0.00829023:  16%|█▌        | 8/50 [00:13<01:08,  1.64s/it]

{'gini': np.float64(0.055277421086979674), 'ess': np.float64(247548.1942747799), 'max_wi': np.float64(1.2523942165803776), 'min_wi': np.float64(0.715810741324646)}
[I 2025-10-01 15:30:34,854] Trial 7 finished with value: -0.008856590173805664 and parameters: {'lr': 0.00448839124426945, 'num_epochs': 4, 'batch_size': 512, 'num_neighbors': 4, 'lr_decay': 0.9234970435509054}. Best is trial 5 with value: -0.008290231227185519.


Best trial: 5. Best value: -0.00829023:  18%|█▊        | 9/50 [00:15<01:07,  1.66s/it]

{'gini': np.float64(0.0037538981634174717), 'ess': np.float64(249988.81822557014), 'max_wi': np.float64(1.0174321268677016), 'min_wi': np.float64(0.9762503654451352)}
[I 2025-10-01 15:30:36,553] Trial 8 finished with value: -0.008636533272981608 and parameters: {'lr': 0.00018959871817723908, 'num_epochs': 5, 'batch_size': 256, 'num_neighbors': 7, 'lr_decay': 0.8317327716404882}. Best is trial 5 with value: -0.008290231227185519.


Best trial: 9. Best value: -0.008286:  20%|██        | 10/50 [00:16<01:05,  1.64s/it] 

{'gini': np.float64(0.0010629286237055131), 'ess': np.float64(249999.07505826798), 'max_wi': np.float64(1.0043615961686618), 'min_wi': np.float64(0.9939022972126248)}
[I 2025-10-01 15:30:38,150] Trial 9 finished with value: -0.00828599628365239 and parameters: {'lr': 0.0001668903770927532, 'num_epochs': 2, 'batch_size': 512, 'num_neighbors': 12, 'lr_decay': 0.8012721736260111}. Best is trial 9 with value: -0.00828599628365239.


Best trial: 10. Best value: -0.00816584:  22%|██▏       | 11/50 [00:18<01:03,  1.63s/it]

{'gini': np.float64(0.005594980143018311), 'ess': np.float64(249974.36297044568), 'max_wi': np.float64(1.0277463457779874), 'min_wi': np.float64(0.9672636486212093)}
[I 2025-10-01 15:30:39,747] Trial 10 finished with value: -0.00816584277780242 and parameters: {'lr': 0.0005792398201366767, 'num_epochs': 1, 'batch_size': 128, 'num_neighbors': 15, 'lr_decay': 0.8039639225036387}. Best is trial 10 with value: -0.00816584277780242.


Best trial: 10. Best value: -0.00816584:  24%|██▍       | 12/50 [00:20<01:02,  1.63s/it]

{'gini': np.float64(0.006202463437346538), 'ess': np.float64(249968.2306582709), 'max_wi': np.float64(1.0285165801840424), 'min_wi': np.float64(0.962929551042903)}
[I 2025-10-01 15:30:41,393] Trial 11 finished with value: -0.00817004123053909 and parameters: {'lr': 0.0006398264323558589, 'num_epochs': 1, 'batch_size': 128, 'num_neighbors': 15, 'lr_decay': 0.800664103939745}. Best is trial 10 with value: -0.00816584277780242.


Best trial: 10. Best value: -0.00816584:  26%|██▌       | 13/50 [00:21<00:59,  1.62s/it]

{'gini': np.float64(0.0061298501510206075), 'ess': np.float64(249969.34097515236), 'max_wi': np.float64(1.0331520406663735), 'min_wi': np.float64(0.9602585664106824)}
[I 2025-10-01 15:30:42,982] Trial 12 finished with value: -0.00816618809482683 and parameters: {'lr': 0.0006458520315022833, 'num_epochs': 1, 'batch_size': 128, 'num_neighbors': 15, 'lr_decay': 0.8056473111442846}. Best is trial 10 with value: -0.00816584277780242.


Best trial: 10. Best value: -0.00816584:  28%|██▊       | 14/50 [00:23<00:58,  1.64s/it]

{'gini': np.float64(0.034623805462395546), 'ess': np.float64(249070.78106060752), 'max_wi': np.float64(1.1763716248400016), 'min_wi': np.float64(0.8229149838753685)}
[I 2025-10-01 15:30:44,657] Trial 13 finished with value: -0.00819454581182358 and parameters: {'lr': 0.0006509515823268328, 'num_epochs': 10, 'batch_size': 128, 'num_neighbors': 15, 'lr_decay': 0.8749159012945495}. Best is trial 10 with value: -0.00816584277780242.


Best trial: 10. Best value: -0.00816584:  30%|███       | 15/50 [00:24<00:56,  1.61s/it]

{'gini': np.float64(0.009996888647664928), 'ess': np.float64(249917.77778081907), 'max_wi': np.float64(1.0458305289580163), 'min_wi': np.float64(0.9378057287179167)}
[I 2025-10-01 15:30:46,213] Trial 14 finished with value: -0.00836972894483524 and parameters: {'lr': 0.0010408089986116318, 'num_epochs': 1, 'batch_size': 128, 'num_neighbors': 11, 'lr_decay': 0.8267444366429995}. Best is trial 10 with value: -0.00816584277780242.


Best trial: 10. Best value: -0.00816584:  32%|███▏      | 16/50 [00:26<00:55,  1.64s/it]

{'gini': np.float64(0.01969213105903264), 'ess': np.float64(249698.02769229948), 'max_wi': np.float64(1.1003876076385366), 'min_wi': np.float64(0.8992017796381105)}
[I 2025-10-01 15:30:47,927] Trial 15 finished with value: -0.00818003116536382 and parameters: {'lr': 0.0004354452209675935, 'num_epochs': 7, 'batch_size': 128, 'num_neighbors': 15, 'lr_decay': 0.8880745349794974}. Best is trial 10 with value: -0.00816584277780242.


Best trial: 10. Best value: -0.00816584:  34%|███▍      | 17/50 [00:28<00:53,  1.63s/it]

{'gini': np.float64(0.014209827792922783), 'ess': np.float64(249835.07531066993), 'max_wi': np.float64(1.0735016890895293), 'min_wi': np.float64(0.920294210978682)}
[I 2025-10-01 15:30:49,512] Trial 16 finished with value: -0.008464224458048966 and parameters: {'lr': 0.0014754028668387085, 'num_epochs': 1, 'batch_size': 128, 'num_neighbors': 10, 'lr_decay': 0.8246778163412122}. Best is trial 10 with value: -0.00816584277780242.


Best trial: 10. Best value: -0.00816584:  36%|███▌      | 18/50 [00:29<00:52,  1.64s/it]

{'gini': np.float64(0.02785700382289883), 'ess': np.float64(249399.61994466686), 'max_wi': np.float64(1.1450001988215035), 'min_wi': np.float64(0.8462493485625164)}
[I 2025-10-01 15:30:51,189] Trial 17 finished with value: -0.00823103715380283 and parameters: {'lr': 0.0006896163959523733, 'num_epochs': 7, 'batch_size': 128, 'num_neighbors': 14, 'lr_decay': 0.8475497736685316}. Best is trial 10 with value: -0.00816584277780242.


Best trial: 10. Best value: -0.00816584:  38%|███▊      | 19/50 [00:31<00:50,  1.64s/it]

{'gini': np.float64(0.042131658188601165), 'ess': np.float64(248614.10853933543), 'max_wi': np.float64(1.2266531050411134), 'min_wi': np.float64(0.7711738017109608)}
[I 2025-10-01 15:30:52,810] Trial 18 finished with value: -0.008316508234472414 and parameters: {'lr': 0.002270528661896805, 'num_epochs': 2, 'batch_size': 128, 'num_neighbors': 13, 'lr_decay': 0.9855257156930795}. Best is trial 10 with value: -0.00816584277780242.


Best trial: 10. Best value: -0.00816584:  40%|████      | 20/50 [00:33<00:49,  1.66s/it]

{'gini': np.float64(0.010061339707137181), 'ess': np.float64(249920.32168922917), 'max_wi': np.float64(1.0495823821374328), 'min_wi': np.float64(0.9430310570230466)}
[I 2025-10-01 15:30:54,513] Trial 19 finished with value: -0.008779924165966669 and parameters: {'lr': 0.0003859003043609369, 'num_epochs': 10, 'batch_size': 256, 'num_neighbors': 3, 'lr_decay': 0.814825254802979}. Best is trial 10 with value: -0.00816584277780242.


Best trial: 10. Best value: -0.00816584:  42%|████▏     | 21/50 [00:34<00:47,  1.64s/it]

{'gini': np.float64(0.0029339975805340755), 'ess': np.float64(249993.20783749), 'max_wi': np.float64(1.0155152554581597), 'min_wi': np.float64(0.9817160064843171)}
[I 2025-10-01 15:30:56,119] Trial 20 finished with value: -0.008440352171747259 and parameters: {'lr': 0.0001016703297951095, 'num_epochs': 4, 'batch_size': 128, 'num_neighbors': 10, 'lr_decay': 0.8478713124950682}. Best is trial 10 with value: -0.00816584277780242.


Best trial: 21. Best value: -0.00816457:  44%|████▍     | 22/50 [00:36<00:45,  1.64s/it]

{'gini': np.float64(0.005869595849389141), 'ess': np.float64(249971.71425841484), 'max_wi': np.float64(1.0292353981244264), 'min_wi': np.float64(0.9634218891862482)}
[I 2025-10-01 15:30:57,759] Trial 21 finished with value: -0.008164571851198362 and parameters: {'lr': 0.0006143330982197509, 'num_epochs': 1, 'batch_size': 128, 'num_neighbors': 15, 'lr_decay': 0.8030689230667892}. Best is trial 21 with value: -0.008164571851198362.


Best trial: 21. Best value: -0.00816457:  46%|████▌     | 23/50 [00:38<00:43,  1.63s/it]

{'gini': np.float64(0.0076331538424940695), 'ess': np.float64(249952.07294241), 'max_wi': np.float64(1.0382223155633168), 'min_wi': np.float64(0.9518975833180364)}
[I 2025-10-01 15:30:59,362] Trial 22 finished with value: -0.008225712331771134 and parameters: {'lr': 0.0007998474454784035, 'num_epochs': 1, 'batch_size': 128, 'num_neighbors': 14, 'lr_decay': 0.8124345323974064}. Best is trial 21 with value: -0.008164571851198362.


Best trial: 21. Best value: -0.00816457:  48%|████▊     | 24/50 [00:39<00:42,  1.63s/it]

{'gini': np.float64(0.007631147802518705), 'ess': np.float64(249953.50428509247), 'max_wi': np.float64(1.0404475123247392), 'min_wi': np.float64(0.9526614472634778)}
[I 2025-10-01 15:31:01,008] Trial 23 finished with value: -0.008227319384124512 and parameters: {'lr': 0.00044898394497364986, 'num_epochs': 2, 'batch_size': 128, 'num_neighbors': 14, 'lr_decay': 0.8329684620953282}. Best is trial 21 with value: -0.008164571851198362.


Best trial: 24. Best value: -0.00816434:  50%|█████     | 25/50 [00:41<00:40,  1.62s/it]

{'gini': np.float64(0.0026736786360868608), 'ess': np.float64(249994.23400475175), 'max_wi': np.float64(1.0139387873247805), 'min_wi': np.float64(0.9849337146697912)}
[I 2025-10-01 15:31:02,602] Trial 24 finished with value: -0.008164337950792359 and parameters: {'lr': 0.00027709918970579615, 'num_epochs': 1, 'batch_size': 128, 'num_neighbors': 15, 'lr_decay': 0.8000161005043485}. Best is trial 24 with value: -0.008164337950792359.


Best trial: 24. Best value: -0.00816434:  52%|█████▏    | 26/50 [00:42<00:38,  1.61s/it]

{'gini': np.float64(0.005995031036930361), 'ess': np.float64(249971.680473601), 'max_wi': np.float64(1.0289311970078028), 'min_wi': np.float64(0.9661432746520132)}
[I 2025-10-01 15:31:04,199] Trial 25 finished with value: -0.008285634426757697 and parameters: {'lr': 0.00025021842995115597, 'num_epochs': 3, 'batch_size': 128, 'num_neighbors': 12, 'lr_decay': 0.8646110743415786}. Best is trial 24 with value: -0.008164337950792359.


Best trial: 24. Best value: -0.00816434:  54%|█████▍    | 27/50 [00:44<00:37,  1.61s/it]

{'gini': np.float64(0.004095442599868318), 'ess': np.float64(249986.66026737317), 'max_wi': np.float64(1.021754177528384), 'min_wi': np.float64(0.9742975468588876)}
[I 2025-10-01 15:31:05,814] Trial 26 finished with value: -0.008228850404085527 and parameters: {'lr': 0.00023512036921091421, 'num_epochs': 2, 'batch_size': 128, 'num_neighbors': 14, 'lr_decay': 0.9002112685279765}. Best is trial 24 with value: -0.008164337950792359.


Best trial: 24. Best value: -0.00816434:  56%|█████▌    | 28/50 [00:46<00:35,  1.62s/it]

{'gini': np.float64(0.00506138035858204), 'ess': np.float64(249980.00000681012), 'max_wi': np.float64(1.0255970158530718), 'min_wi': np.float64(0.9696330478783677)}
[I 2025-10-01 15:31:07,438] Trial 27 finished with value: -0.008366375977233408 and parameters: {'lr': 0.00010339796327409263, 'num_epochs': 4, 'batch_size': 64, 'num_neighbors': 11, 'lr_decay': 0.840351743294758}. Best is trial 24 with value: -0.008164337950792359.


Best trial: 24. Best value: -0.00816434:  58%|█████▊    | 29/50 [00:47<00:34,  1.63s/it]

{'gini': np.float64(0.005823242935283665), 'ess': np.float64(249972.27847087276), 'max_wi': np.float64(1.0276399355600159), 'min_wi': np.float64(0.9668908506645371)}
[I 2025-10-01 15:31:09,088] Trial 28 finished with value: -0.008302946075710119 and parameters: {'lr': 0.0010059942742010294, 'num_epochs': 1, 'batch_size': 256, 'num_neighbors': 13, 'lr_decay': 0.8178045777900285}. Best is trial 24 with value: -0.008164337950792359.


Best trial: 24. Best value: -0.00816434:  60%|██████    | 30/50 [00:49<00:32,  1.64s/it]

{'gini': np.float64(0.03175385389776177), 'ess': np.float64(249222.2945567292), 'max_wi': np.float64(1.1654360104586208), 'min_wi': np.float64(0.8168240775681119)}
[I 2025-10-01 15:31:10,746] Trial 29 finished with value: -0.008555820056530803 and parameters: {'lr': 0.00048016511727003684, 'num_epochs': 6, 'batch_size': 64, 'num_neighbors': 8, 'lr_decay': 0.8630264535942166}. Best is trial 24 with value: -0.008164337950792359.


Best trial: 24. Best value: -0.00816434:  62%|██████▏   | 31/50 [00:51<00:31,  1.63s/it]

{'gini': np.float64(0.004640100663551571), 'ess': np.float64(249982.80068978865), 'max_wi': np.float64(1.023946698873982), 'min_wi': np.float64(0.9730650147282662)}
[I 2025-10-01 15:31:12,377] Trial 30 finished with value: -0.008165207588103417 and parameters: {'lr': 0.00027019017136917565, 'num_epochs': 2, 'batch_size': 128, 'num_neighbors': 15, 'lr_decay': 0.8179847682596127}. Best is trial 24 with value: -0.008164337950792359.


Best trial: 24. Best value: -0.00816434:  64%|██████▍   | 32/50 [00:52<00:29,  1.63s/it]

{'gini': np.float64(0.004646215412601606), 'ess': np.float64(249982.83753827924), 'max_wi': np.float64(1.0227435032454135), 'min_wi': np.float64(0.9730633276843679)}
[I 2025-10-01 15:31:13,990] Trial 31 finished with value: -0.008164397599625737 and parameters: {'lr': 0.0002729116364371551, 'num_epochs': 2, 'batch_size': 128, 'num_neighbors': 15, 'lr_decay': 0.8191022692001436}. Best is trial 24 with value: -0.008164337950792359.


Best trial: 24. Best value: -0.00816434:  66%|██████▌   | 33/50 [00:54<00:27,  1.63s/it]

{'gini': np.float64(0.00445836720483108), 'ess': np.float64(249984.18877291176), 'max_wi': np.float64(1.0228772696167119), 'min_wi': np.float64(0.9743489739323542)}
[I 2025-10-01 15:31:15,625] Trial 32 finished with value: -0.008229560711521975 and parameters: {'lr': 0.0002622980812989751, 'num_epochs': 2, 'batch_size': 128, 'num_neighbors': 14, 'lr_decay': 0.8193397687231945}. Best is trial 24 with value: -0.008164337950792359.


Best trial: 24. Best value: -0.00816434:  68%|██████▊   | 34/50 [00:56<00:26,  1.64s/it]

{'gini': np.float64(0.008293241016751248), 'ess': np.float64(249945.527428984), 'max_wi': np.float64(1.0427682391708508), 'min_wi': np.float64(0.9501718170407719)}
[I 2025-10-01 15:31:17,284] Trial 33 finished with value: -0.008300224324200882 and parameters: {'lr': 0.00030821251112823786, 'num_epochs': 3, 'batch_size': 128, 'num_neighbors': 13, 'lr_decay': 0.971158620384831}. Best is trial 24 with value: -0.008164337950792359.


Best trial: 24. Best value: -0.00816434:  70%|███████   | 35/50 [00:57<00:24,  1.64s/it]

{'gini': np.float64(0.004618786409293234), 'ess': np.float64(249983.11565123423), 'max_wi': np.float64(1.0222937560986158), 'min_wi': np.float64(0.9737249161903242)}
[I 2025-10-01 15:31:18,918] Trial 34 finished with value: -0.008167010450457937 and parameters: {'lr': 0.00019931699876548723, 'num_epochs': 3, 'batch_size': 128, 'num_neighbors': 15, 'lr_decay': 0.8412907102556448}. Best is trial 24 with value: -0.008164337950792359.


Best trial: 24. Best value: -0.00816434:  72%|███████▏  | 36/50 [00:59<00:23,  1.64s/it]

{'gini': np.float64(0.0038775105954834414), 'ess': np.float64(249988.0744289153), 'max_wi': np.float64(1.0194210342106358), 'min_wi': np.float64(0.9773922389060659)}
[I 2025-10-01 15:31:20,577] Trial 35 finished with value: -0.00823050934806256 and parameters: {'lr': 0.00013355775878324557, 'num_epochs': 2, 'batch_size': 64, 'num_neighbors': 14, 'lr_decay': 0.8138164046421575}. Best is trial 24 with value: -0.008164337950792359.


Best trial: 24. Best value: -0.00816434:  74%|███████▍  | 37/50 [01:01<00:21,  1.65s/it]

{'gini': np.float64(0.008727242102061742), 'ess': np.float64(249940.2626282969), 'max_wi': np.float64(1.0429399921249338), 'min_wi': np.float64(0.9492184010948006)}
[I 2025-10-01 15:31:22,228] Trial 36 finished with value: -0.008299910232517855 and parameters: {'lr': 0.00030658606559628185, 'num_epochs': 4, 'batch_size': 128, 'num_neighbors': 13, 'lr_decay': 0.8339830040065375}. Best is trial 24 with value: -0.008164337950792359.


Best trial: 24. Best value: -0.00816434:  76%|███████▌  | 38/50 [01:02<00:19,  1.65s/it]

{'gini': np.float64(0.08023364050295664), 'ess': np.float64(244911.13023472382), 'max_wi': np.float64(1.3865715383690926), 'min_wi': np.float64(0.6073666732857843)}
[I 2025-10-01 15:31:23,874] Trial 37 finished with value: -0.008370164099193193 and parameters: {'lr': 0.009076145704900198, 'num_epochs': 3, 'batch_size': 512, 'num_neighbors': 12, 'lr_decay': 0.85756540247065}. Best is trial 24 with value: -0.008164337950792359.


Best trial: 24. Best value: -0.00816434:  78%|███████▊  | 39/50 [01:04<00:18,  1.65s/it]

{'gini': np.float64(0.0027008881840663848), 'ess': np.float64(249994.21956188002), 'max_wi': np.float64(1.0130921557168173), 'min_wi': np.float64(0.985262770063087)}
[I 2025-10-01 15:31:25,524] Trial 38 finished with value: -0.008166052897754227 and parameters: {'lr': 0.00014618622526002724, 'num_epochs': 2, 'batch_size': 128, 'num_neighbors': 15, 'lr_decay': 0.9576408583014726}. Best is trial 24 with value: -0.008164337950792359.


Best trial: 24. Best value: -0.00816434:  80%|████████  | 40/50 [01:05<00:16,  1.64s/it]

{'gini': np.float64(0.03040182655326042), 'ess': np.float64(249272.11783920226), 'max_wi': np.float64(1.1505856378075412), 'min_wi': np.float64(0.8206456137149871)}
[I 2025-10-01 15:31:27,146] Trial 39 finished with value: -0.008642840647382135 and parameters: {'lr': 0.00135098348679697, 'num_epochs': 5, 'batch_size': 256, 'num_neighbors': 6, 'lr_decay': 0.8994321823300094}. Best is trial 24 with value: -0.008164337950792359.


Best trial: 24. Best value: -0.00816434:  82%|████████▏ | 41/50 [01:07<00:14,  1.66s/it]

{'gini': np.float64(0.006567427199942143), 'ess': np.float64(249964.73463439016), 'max_wi': np.float64(1.02722051756333), 'min_wi': np.float64(0.9627232588438793)}
[I 2025-10-01 15:31:28,866] Trial 40 finished with value: -0.008233932056540052 and parameters: {'lr': 0.0003277334949651814, 'num_epochs': 9, 'batch_size': 512, 'num_neighbors': 14, 'lr_decay': 0.8792413267788508}. Best is trial 24 with value: -0.008164337950792359.


Best trial: 24. Best value: -0.00816434:  84%|████████▍ | 42/50 [01:09<00:13,  1.66s/it]

{'gini': np.float64(0.005535740141543364), 'ess': np.float64(249974.85991731213), 'max_wi': np.float64(1.0268330096594478), 'min_wi': np.float64(0.9692088494055054)}
[I 2025-10-01 15:31:30,512] Trial 41 finished with value: -0.008172226027694516 and parameters: {'lr': 0.0005690189828640535, 'num_epochs': 1, 'batch_size': 128, 'num_neighbors': 15, 'lr_decay': 0.8062001692152938}. Best is trial 24 with value: -0.008164337950792359.


Best trial: 24. Best value: -0.00816434:  86%|████████▌ | 43/50 [01:10<00:11,  1.66s/it]

{'gini': np.float64(0.001995700410191967), 'ess': np.float64(249996.740425026), 'max_wi': np.float64(1.0108344781883014), 'min_wi': np.float64(0.9872229957114854)}
[I 2025-10-01 15:31:32,163] Trial 42 finished with value: -0.008164696771501783 and parameters: {'lr': 0.00020693132353647456, 'num_epochs': 1, 'batch_size': 128, 'num_neighbors': 15, 'lr_decay': 0.8001221842123989}. Best is trial 24 with value: -0.008164337950792359.


Best trial: 24. Best value: -0.00816434:  88%|████████▊ | 44/50 [01:12<00:09,  1.64s/it]

{'gini': np.float64(0.001980209615031799), 'ess': np.float64(249996.79382723724), 'max_wi': np.float64(1.009484406209847), 'min_wi': np.float64(0.9870406558060784)}
[I 2025-10-01 15:31:33,781] Trial 43 finished with value: -0.008227849238207867 and parameters: {'lr': 0.00020358971340340314, 'num_epochs': 1, 'batch_size': 128, 'num_neighbors': 14, 'lr_decay': 0.8232719978410347}. Best is trial 24 with value: -0.008164337950792359.


Best trial: 44. Best value: -0.00816396:  90%|█████████ | 45/50 [01:14<00:08,  1.64s/it]

{'gini': np.float64(0.002328613613675946), 'ess': np.float64(249995.68045209345), 'max_wi': np.float64(1.0117779719856053), 'min_wi': np.float64(0.9859739767974958)}
[I 2025-10-01 15:31:35,407] Trial 44 finished with value: -0.008163962755856928 and parameters: {'lr': 0.0001381078986803915, 'num_epochs': 2, 'batch_size': 128, 'num_neighbors': 15, 'lr_decay': 0.8000865035210912}. Best is trial 44 with value: -0.008163962755856928.


Best trial: 44. Best value: -0.00816396:  92%|█████████▏| 46/50 [01:15<00:06,  1.64s/it]

{'gini': np.float64(0.004032656968885356), 'ess': np.float64(249987.04820875294), 'max_wi': np.float64(1.0190700791345406), 'min_wi': np.float64(0.9753098315012226)}
[I 2025-10-01 15:31:37,053] Trial 45 finished with value: -0.008302004997511166 and parameters: {'lr': 0.00018037092336658017, 'num_epochs': 3, 'batch_size': 128, 'num_neighbors': 13, 'lr_decay': 0.8026220729888982}. Best is trial 44 with value: -0.008163962755856928.


Best trial: 44. Best value: -0.00816396:  94%|█████████▍| 47/50 [01:17<00:04,  1.64s/it]

{'gini': np.float64(0.0021053881208348275), 'ess': np.float64(249996.45147268358), 'max_wi': np.float64(1.0106291712764162), 'min_wi': np.float64(0.9885180431864301)}
[I 2025-10-01 15:31:38,703] Trial 46 finished with value: -0.008165444549768947 and parameters: {'lr': 0.00012818206344151014, 'num_epochs': 1, 'batch_size': 64, 'num_neighbors': 15, 'lr_decay': 0.8001527909082651}. Best is trial 44 with value: -0.008163962755856928.


Best trial: 44. Best value: -0.00816396:  96%|█████████▌| 48/50 [01:19<00:03,  1.64s/it]

{'gini': np.float64(0.0036982407184482547), 'ess': np.float64(249989.0734211782), 'max_wi': np.float64(1.0170454290007809), 'min_wi': np.float64(0.9756599704758526)}
[I 2025-10-01 15:31:40,324] Trial 47 finished with value: -0.008305092562343807 and parameters: {'lr': 0.0002186617412429803, 'num_epochs': 2, 'batch_size': 128, 'num_neighbors': 13, 'lr_decay': 0.8121484791968973}. Best is trial 44 with value: -0.008163962755856928.


Best trial: 44. Best value: -0.00816396:  98%|█████████▊| 49/50 [01:20<00:01,  1.63s/it]

{'gini': np.float64(0.0005511801316785044), 'ess': np.float64(249999.75125646926), 'max_wi': np.float64(1.0022596087464426), 'min_wi': np.float64(0.9968354385001351)}
[I 2025-10-01 15:31:41,948] Trial 48 finished with value: -0.008228942765870665 and parameters: {'lr': 0.00015585418754553656, 'num_epochs': 1, 'batch_size': 512, 'num_neighbors': 14, 'lr_decay': 0.8285923176095564}. Best is trial 44 with value: -0.008163962755856928.


Best trial: 44. Best value: -0.00816396: 100%|██████████| 50/50 [01:22<00:00,  1.65s/it]

{'gini': np.float64(0.0033157545906658324), 'ess': np.float64(249991.35517540027), 'max_wi': np.float64(1.0166187774335813), 'min_wi': np.float64(0.9808881408306281)}
[I 2025-10-01 15:31:43,595] Trial 49 finished with value: -0.008286964862933495 and parameters: {'lr': 0.0001203393482579109, 'num_epochs': 4, 'batch_size': 128, 'num_neighbors': 12, 'lr_decay': 0.8110332356751904}. Best is trial 44 with value: -0.008163962755856928.





TypeError: train() got multiple values for argument 'criterion'

In [None]:
# Select, for display, the policy-reward column, the estimator columns
# (ipw / reg_dm / conv_dm / conv_dr / conv_sndr) and the action/context
# drift columns of df4.
df4.loc[:, [
    'policy_rewards',
    'ipw', 'reg_dm', 'conv_dm', 'conv_dr', 'conv_sndr',
    'action_diff_to_real', 'action_delta',
    'context_diff_to_real', 'context_delta',
]]

Unnamed: 0,policy_rewards,ipw,reg_dm,conv_dm,conv_dr,conv_sndr,action_diff_to_real,action_delta,context_diff_to_real,context_delta
0,0.08610747,0.09380742,0.08732337,0.09204356,0.09202231,0.09194136,0.7569287,0.0,0.87627132,0.0
500,0.08611565,0.08938694,0.0955977,0.09864954,0.09419736,0.07762452,0.75680212,0.00321359,0.87629079,0.00154774
1000,0.08702833,0.09197439,0.09029889,0.10181247,0.09961428,0.09519203,0.77645672,0.18381873,0.88400949,0.07933387
2000,0.08611525,0.09211356,0.08769463,0.09043088,0.08906903,0.08397306,0.75688584,0.00360047,0.87628394,0.00160451
10000,0.08642838,0.0767594,0.08338645,0.08754185,0.08388183,0.07629422,0.76114452,0.08074157,0.87917341,0.03694905
20000,0.086647,0.08919219,0.08729591,0.09040812,0.08901549,0.08726893,0.76196226,0.08492833,0.87692124,0.03392471


In [None]:
# Show the performance metrics: policy reward, estimator outputs, and the
# action/context drift columns of df4.
# NOTE(review): this selection is identical to an earlier cell, yet the rendered
# tables differ, so df4 changed between the two runs — confirm which experiment
# this table belongs to, or give the two frames distinct names.
df4.loc[:, [
    'policy_rewards',
    'ipw',
    'reg_dm',
    'conv_dm',
    'conv_dr',
    'conv_sndr',
    'action_diff_to_real',
    'action_delta',
    'context_diff_to_real',
    'context_delta',
]]


Unnamed: 0,policy_rewards,ipw,reg_dm,conv_dm,conv_dr,conv_sndr,action_diff_to_real,action_delta,context_diff_to_real,context_delta
0,0.08610747,0.1069709,0.09051612,0.09112201,0.09452505,0.10672373,0.7569287,0.0,0.87627132,0.0
500,0.08705604,0.09221834,,0.08299331,0.08198609,0.07599146,0.79170973,0.24615559,0.88427728,0.08661758
1000,0.08939145,0.11301958,,0.08679147,0.09052395,0.10668506,1.01853061,0.76340735,0.91464524,0.19321758
2000,0.09251861,0.10603409,,0.09028676,0.15397776,0.10628439,1.73862067,1.70789298,0.99652312,0.34170287
10000,0.09268524,0.09704712,,0.09829317,0.09539229,0.09257621,2.18938809,2.22344507,1.03555944,0.40151858
20000,0.09264639,0.09493701,,0.09216787,0.09134166,0.09056984,2.21242505,2.24879912,1.03520993,0.40065441
