In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import optuna
import pickle
# Report which accelerator the agents will see. torch.cuda.get_device_name
# raises when CUDA is unavailable, so guard it to keep the notebook runnable
# (e.g. for the CPU-friendly LinUCB baseline) on machines without a GPU.
device_name = (
    torch.cuda.get_device_name(0)
    if torch.cuda.is_available()
    else "CUDA not available (CPU only)"
)
device_name

In [None]:
from environment import ImageBanditEnv, Experiment
from algorithms import LinUCB, CNN_UCB, ViT_UCB
from pathlib import Path

In [None]:
# Build the bandit environment once and cache it on disk; later runs reload the
# cached copy instead of rebuilding it from data.jsonl.
# NOTE(review): the cache is never invalidated — delete anime_env.pkl after
# changing data.jsonl or the constructor arguments. If ImageBanditEnv.load
# unpickles the file (the .pkl suffix suggests so), only load caches you created.
env_cache = "anime_env.pkl"
if not Path(env_cache).is_file():
    environment = ImageBanditEnv("data.jsonl", 3, 1)
    environment.save(env_cache)
else:
    environment = ImageBanditEnv.load(env_cache)

In [None]:
# Number of bandit rounds simulated per tuning trial. Every trial uses the same
# horizon so the final regrets of different hyperparameter settings are comparable.
N_ROUNDS = 1500


def _final_regret(agent, n_rounds: int) -> float:
    """Run ``agent`` for ``n_rounds`` rounds and return its ``regret`` as a float.

    Shared by every objective below so all agents are scored the same way.
    """
    agent.run(n_rounds)
    # Optuna expects a float objective value; float() also unwraps 0-d tensors
    # and numpy scalars. NOTE(review): assumes ``regret`` is a scalar (final
    # cumulative regret) once ``run`` returns — confirm in algorithms.py.
    return float(agent.regret)


def vit_objective(trial, n_rounds: int = N_ROUNDS) -> float:
    """Optuna objective for ``ViT_UCB``; returns the regret to minimize.

    Samples ``alpha``, ``lambda_reg``, ``lora_r`` and ``lora_alpha`` from
    ``trial``, builds the agent on the global ``environment`` and runs it for
    ``n_rounds`` rounds (default ``N_ROUNDS``).
    """
    alpha = trial.suggest_float("alpha", 1.0, 100.0)
    lambda_reg = trial.suggest_float("lambda_reg", 0.1, 10.0)
    lora_r = trial.suggest_int("lora_r", 4, 20)
    lora_alpha = trial.suggest_float("lora_alpha", 3.0, 32.0)

    agent = ViT_UCB(
        environment,
        alpha=alpha,
        lambda_reg=lambda_reg,
        lora_r=lora_r,
        lora_alpha=lora_alpha,
    )
    return _final_regret(agent, n_rounds)


def lin_objective(trial, n_rounds: int = N_ROUNDS) -> float:
    """Optuna objective for ``LinUCB``; returns the regret to minimize.

    Samples ``alpha`` from ``trial``, builds the agent on the global
    ``environment`` and runs it for ``n_rounds`` rounds (default ``N_ROUNDS``).
    """
    alpha = trial.suggest_float("alpha", 1.0, 100.0)

    agent = LinUCB(environment, alpha=alpha)
    return _final_regret(agent, n_rounds)


def cnn_objective(trial, n_rounds: int = N_ROUNDS) -> float:
    """Optuna objective for ``CNN_UCB``; returns the regret to minimize.

    Samples ``alpha`` and ``lambda_reg`` from ``trial``, builds the agent on the
    global ``environment`` and runs it for ``n_rounds`` rounds (default ``N_ROUNDS``).
    """
    alpha = trial.suggest_float("alpha", 1.0, 100.0)
    lambda_reg = trial.suggest_float("lambda_reg", 0.1, 10.0)

    agent = CNN_UCB(environment, alpha=alpha, lambda_reg=lambda_reg)
    return _final_regret(agent, n_rounds)


In [None]:
# Tune each agent in its own named study. Studies persist in a shared SQLite
# file and are resumed (load_if_exists=True), so only the trials still missing
# from the N_TRIALS budget are run: re-running this cell after a finished sweep
# does not launch another N_TRIALS expensive trials per agent.
STORAGE = "sqlite:///db.sqlite3"
N_TRIALS = 50

for name, objective in [("vit", vit_objective), ("cnn", cnn_objective), ("lin", lin_objective)]:
    study = optuna.create_study(
        study_name=f"{name}ucb_tuning",
        storage=STORAGE,
        load_if_exists=True,
        direction="minimize",
    )
    # Only completed trials count toward the budget, so failed/pruned trials
    # (e.g. an out-of-memory run) are retried on the next execution.
    n_done = len(study.get_trials(deepcopy=False, states=(optuna.trial.TrialState.COMPLETE,)))
    n_remaining = max(0, N_TRIALS - n_done)
    if n_remaining:
        study.optimize(objective, n_trials=n_remaining)
    print(f"{name}:", study.best_params)

In [None]:
# Summarize the best hyperparameters of every tuning study. Each study is
# reloaded from the shared SQLite storage, because the leftover loop variable
# `study` from the tuning cell only refers to the last study created ("lin").
{
    name: optuna.load_study(
        study_name=f"{name}ucb_tuning",
        storage="sqlite:///db.sqlite3",
    ).best_params
    for name in ("vit", "cnn", "lin")
}