In [None]:
import os
import sys
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Make the repository root importable so the local `krcps` package resolves.
# The path must be appended before the import below.
root_dir = "../../"
sys.path.append(root_dir)
from krcps.utils import get_loss, get_bound

# Seed torch's global RNG once, up front: every later cell samples its data
# with `torch.randn`, so without a seed the figures and printed statistics
# differ on each Restart & Run All.
seed = 0
torch.manual_seed(seed)

sns.set_theme(style="white")
sns.set_context("paper", font_scale=1.5)

# Figures are written to <experiment_dir>/figures/motivation, which is
# created on demand.
experiment_dir = "../"
fig_dir = os.path.join(experiment_dir, "figures", "motivation")
os.makedirs(fig_dir, exist_ok=True)

In [None]:
def crc_bound_fn(n, delta, loss, B=1):
    """Return ``n / (n + 1) * loss + B / (n + 1)``.

    Args:
        n: number of calibration samples.
        delta: unused; accepted so the signature matches the other bound
            functions, which are called as ``bound_fn(n, delta, loss)``.
        loss: loss value to adjust.
        B: constant added in the numerator of the second term (default 1).

    Returns:
        The adjusted loss, of the same kind (float or tensor) as ``loss``.
    """
    denominator = n + 1
    shrunk_loss = n / denominator * loss
    correction = B / denominator
    return shrunk_loss + correction


# Loss and bound functions, looked up by name in krcps.
rcps_loss_fn = get_loss("01")
hb_bound_fn = get_bound("hoeffding_bentkus")

# Calibration set: n standard-normal draws of dimension d, shifted by mu.
n = 128
mu = torch.tensor([-1, 1])
d = mu.shape[0]
x = mu + torch.randn(n, 1, d)

# Constant lower/upper interval bounds with the same shape and dtype as x.
_l, _u = -1, 1
l = torch.full_like(x, _l)
u = torch.full_like(x, _u)

# epsilon is the threshold the UCB is compared against; delta is forwarded to
# the bound functions.
epsilon, delta = 0.1, 0.1
# Largest lambda considered, and the decrement used by the scalar search.
lambda_max = 4
stepsize = 1e-3

In [None]:
# Shared axes for the UCB-vs-lambda curves (16:9 aspect ratio at half scale).
_, ax = plt.subplots(figsize=(8.0, 4.5))


def _rcps(bound_fn, bound_name):
    """Calibrate a scalar lambda with ``bound_fn`` and plot its UCB curve.

    Starting at ``lambda_max``, lambda is lowered by ``stepsize`` for as long
    as the upper confidence bound (UCB) on the calibration loss stays at or
    below ``epsilon``. The reported "lambda hat" is the smallest visited
    lambda whose UCB is still at or below ``epsilon``. The first lambda that
    violates the constraint is kept in the plotted curve but is not selected.

    The function reads the notebook globals ``x``, ``l``, ``u``, ``n``,
    ``delta``, ``epsilon``, ``lambda_max``, ``stepsize``, ``mu``, ``d``,
    ``_l``, ``_u``, ``rcps_loss_fn`` and draws on the global ``ax``.

    Args:
        bound_fn: callable invoked as ``bound_fn(n, delta, loss)`` that
            returns the UCB for a given calibration loss.
        bound_name: label used for the legend entry and the printed summary.

    Returns:
        None. Adds one line to ``ax`` and prints lambda hat, the loss on a
        freshly sampled test set of 100 points, and the mean interval length.
    """
    _lambda, _ucb = [], []

    def _evaluate(candidate):
        """Compute the UCB of the calibration loss at `candidate` and log it."""
        loss = rcps_loss_fn(x, l - candidate, u + candidate)
        ucb = bound_fn(n, delta, loss)
        _lambda.append(candidate.unique().item())
        _ucb.append(ucb)
        return ucb

    vector_lambda = lambda_max * torch.ones_like(x)
    ucb = _evaluate(vector_lambda)

    # Track the last lambda whose UCB satisfied the constraint. The loop only
    # stops once a lambda violates it, so the final `vector_lambda` itself
    # must not be reported. If even lambda_max violates the constraint, the
    # loop never runs and lambda_hat falls back to lambda_max.
    lambda_hat = vector_lambda
    while ucb <= epsilon:
        lambda_hat = vector_lambda
        # Out-of-place subtraction: `lambda_hat` keeps the previous value.
        vector_lambda = vector_lambda - stepsize
        ucb = _evaluate(vector_lambda)

    # Fresh test data drawn from the same distribution as the calibration set.
    x_test = mu + torch.randn(100, 1, d)
    l_test, u_test = _l * torch.ones_like(x_test), _u * torch.ones_like(x_test)
    loss_test = rcps_loss_fn(
        x_test, l_test - lambda_hat.unique(), u_test + lambda_hat.unique()
    )

    # Length of the calibrated intervals [l - lambda_hat, u + lambda_hat].
    interval_length = u - l + 2 * lambda_hat

    ax.plot(_lambda, _ucb, label=f"{bound_name}")
    print(
        f"Bound: {bound_name}, lambda hat: {lambda_hat.unique().item()}, test loss: {loss_test.item()} mean interval length: {torch.mean(interval_length)}"
    )


# Calibrate with each bound in turn; every call adds one curve to `ax` and
# prints a one-line summary.
for bound_fn, bound_name in ((hb_bound_fn, "RCPS"), (crc_bound_fn, "CRC")):
    _rcps(bound_fn, bound_name)

ax.set(xlabel=r"$\lambda$", ylabel=r"UCB$(\lambda)$", xscale="log")
ax.legend()
plt.show()

In [None]:
# Sweep an m x m grid of (lambda_1, lambda_2) pairs and mark, for each pair,
# whether the Hoeffding-Bentkus UCB of the calibration loss is <= epsilon.
vector_lambda = torch.zeros_like(x)

_, ax = plt.subplots(figsize=(5, 5))

m = 50
ll = torch.linspace(lambda_max, 0, m)
# Smallest lambda on the diagonal (lambda_1 == lambda_2) that is controlled.
min_lambda = lambda_max
for l1 in tqdm(ll):
    # The first coordinate stays fixed during the whole inner sweep.
    vector_lambda[..., 0] = l1
    for l2 in ll:
        vector_lambda[..., 1] = l2

        loss = rcps_loss_fn(x, l - vector_lambda, u + vector_lambda)
        ucb = hb_bound_fn(n, delta, loss)

        controlled = ucb <= epsilon
        if controlled:
            ax.scatter(l1, l2, marker="x", color="#55a868", alpha=0.80)
            if l1 == l2 and l1 < min_lambda:
                min_lambda = l1
        else:
            ax.scatter(l1, l2, marker="x", color="#c44e52", alpha=0.20)

ax.set(xlabel=r"$\lambda_1$", ylabel=r"$\lambda_2$")
# Dashed diagonal from the smallest controlled scalar lambda up to lambda_max,
# with a star at its lower end.
ax.plot(
    [min_lambda, lambda_max],
    [min_lambda, lambda_max],
    color="#1f77b4",
    linestyle="--",
    linewidth=2,
)
ax.scatter(min_lambda, min_lambda, marker="*", color="#4c72b0", s=150)
ax.set(
    xlim=(0, lambda_max),
    ylim=(0, lambda_max),
    xticks=[0, 2, 4],
    yticks=[0, 2, 4],
)
for fig_name in ("balanced.jpg", "balanced.pdf"):
    plt.savefig(os.path.join(fig_dir, fig_name), bbox_inches="tight")
plt.show()

In [None]:
# Second calibration set: n standard-normal draws of dimension d, shifted by
# a mean whose two coordinates differ in magnitude.
n = 128
mu = torch.tensor([-2, 0.75])
d = mu.shape[0]
x = mu + torch.randn(n, 1, d)

# Constant lower/upper interval bounds with the same shape and dtype as x.
_l, _u = -1, 1
l = torch.full_like(x, _l)
u = torch.full_like(x, _u)

# epsilon is the threshold the UCB is compared against; delta is forwarded to
# the bound function.
epsilon, delta = 0.1, 0.1
lambda_max = 4

# Sweep an m x m grid of (lambda_1, lambda_2) pairs and mark, for each pair,
# whether the Hoeffding-Bentkus UCB of the calibration loss is <= epsilon.
vector_lambda = torch.zeros_like(x)

_, ax = plt.subplots(figsize=(5, 5))

m = 50
ll = torch.linspace(lambda_max, 0, m)
# Smallest lambda on the diagonal (lambda_1 == lambda_2) that is controlled.
min_lambda = lambda_max
# Controlled pair with the smallest lambda_1 + lambda_2 seen so far.
min_l1 = min_l2 = lambda_max
min_sum = d * lambda_max
for l1 in tqdm(ll):
    # The first coordinate stays fixed during the whole inner sweep.
    vector_lambda[..., 0] = l1
    for l2 in ll:
        vector_lambda[..., 1] = l2

        loss = rcps_loss_fn(x, l - vector_lambda, u + vector_lambda)
        ucb = hb_bound_fn(n, delta, loss)

        controlled = ucb <= epsilon
        if controlled:
            ax.scatter(l1, l2, marker="x", color="#55a868", alpha=0.80)
            if l1 == l2 and l1 < min_lambda:
                min_lambda = l1
            if l1 + l2 < min_sum:
                min_sum, min_l1, min_l2 = l1 + l2, l1, l2
        else:
            ax.scatter(l1, l2, marker="x", color="#c44e52", alpha=0.20)

ax.set(xlabel=r"$\lambda_1$", ylabel=r"$\lambda_2$")
# Dashed diagonal from the smallest controlled scalar lambda up to lambda_max.
ax.plot(
    [min_lambda, lambda_max],
    [min_lambda, lambda_max],
    color="#1f77b4",
    linestyle="--",
    linewidth=2,
)
# Dashed line of slope one starting at the controlled pair with minimal sum.
ax.plot(
    [min_l1, min_l1 + lambda_max],
    [min_l2, min_l2 + lambda_max],
    color="#ff7f0e",
    linestyle="--",
    linewidth=2,
)
# Stars mark the two optima: scalar (diagonal) and minimal-sum.
ax.scatter(min_lambda, min_lambda, marker="*", color="#4c72b0", s=150)
ax.scatter(min_l1, min_l2, marker="*", color="#dd8452", s=150)
ax.set(
    xlim=(0, lambda_max),
    ylim=(0, lambda_max),
    xticks=[0, 2, 4],
    yticks=[0, 2, 4],
)
for fig_name in ("unbalanced.jpg", "unbalanced.pdf"):
    plt.savefig(os.path.join(fig_dir, fig_name), bbox_inches="tight")
plt.show()