In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import norm
from scipy.special import expit
from tqdm import tqdm

# Paths are relative to the notebook's working directory.
root_dir = "../../"

figure_dir = os.path.join(root_dir, "figures", "sigmoid")
os.makedirs(figure_dir, exist_ok=True)

# The samplers in this notebook call scipy.stats.norm.rvs without a
# random_state, so they draw from NumPy's global RandomState. Seeding it here
# makes "Restart & Run All" regenerate the same data, p-values and figures.
SEED = 0
np.random.seed(SEED)

sns.set_theme()
sns.set_context("paper", font_scale=1.8)

In [None]:
def x1_dist(n):
    """Draw ``n`` samples of x1 from a normal with mean 1 and std 1."""
    return norm.rvs(1, 1, size=n)


def pos_dist(n):
    """Draw ``n`` samples from a normal with mean 3 and std 0.5 (x2 when x1 >= 3)."""
    return norm.rvs(3, 0.5, size=n)


def neg_dist(n):
    """Draw ``n`` samples from a normal with mean -1 and std 1 (x2 when x1 < 3)."""
    return norm.rvs(-1, 1, size=n)


def x2_cond(x1):
    """Sample x2 given x1, elementwise.

    Entries with ``x1 >= 3`` come from ``pos_dist``; all others come from
    ``neg_dist``. Both branches are drawn for every element (``pos_dist``
    first, then ``neg_dist``) and the matching one is selected, so the
    sequence of random draws is fixed regardless of the values of ``x1``.

    Args:
        x1: 1-D array of conditioning values.

    Returns:
        1-D array with the same length as ``x1``.
    """
    n = x1.shape[0]
    pos = pos_dist(n)
    neg = neg_dist(n)
    return np.where(x1 >= 3, pos, neg)


def f(x, theta):
    """Logistic model: elementwise sigmoid of ``x @ theta``."""
    return expit(x @ theta)


def make_data(N, d):
    """Build an (N, 2*d) design matrix from d independently drawn (x1, x2) pairs.

    For pair g, column 2*g holds a fresh draw from ``x1_dist`` and column
    2*g + 1 holds ``x2_cond`` evaluated on that same column.
    """
    out = np.empty((N, 2 * d))
    # View the flat columns as (sample, pair, member) so each pair is one slot.
    pairs = out.reshape(N, d, 2)
    for g in range(d):
        first = x1_dist(N)
        second = x2_cond(first)
        pairs[:, g, 0] = first
        pairs[:, g, 1] = second
    return out

In [None]:
N, d = 5000, 3

# Model coefficients, one per column of X (2*d columns). Every entry is 1
# except theta_2 (the weight on column 1, i.e. x2 of the first pair).
# The original note says theta_2 = -2 reproduces figure 2.b.
# TODO(review): confirm which theta_2 value reproduces figure 2.a.
theta_2 = -2.0
theta = np.ones((2 * d, 1))
theta[1, 0] = theta_2

# Use d (not a literal) so X and theta stay consistent if d is changed.
X = make_data(N, d)
assert X.shape == (N, theta.shape[0])

In [None]:
# Monte Carlo estimate of the S-XRT p-value for every row of X.
# For each row x:
#   * M test statistics t: x1 redrawn from x1_dist (the same sampler that
#     generated column 0 in make_data), every other column kept as observed;
#   * for each t, K null statistics: x1 redrawn and x2 drawn from x2_cond
#     given that x1, every other column kept as observed;
#   * p_hat = (#{t_null >= t} + 1) / (K + 1), averaged over the M draws.
# All M*K draws for a row are made in one batch instead of a per-draw loop;
# the estimator (and its distribution) is the same as the per-draw version.
M, K = 100, 100

mean_p = np.empty(N)
for i in tqdm(range(N)):
    x = X[[i], :]  # keep 2-D: shape (1, n_features)

    # M test statistics at once, shape (M, 1).
    X_t = np.tile(x, (M, 1))
    X_t[:, 0] = x1_dist(M)
    t = f(X_t, theta)

    # M*K null statistics; row j of the reshaped array holds the K nulls
    # paired with t[j].
    X_null = np.tile(x, (M * K, 1))
    x1_null = x1_dist(M * K)
    X_null[:, 0] = x1_null
    X_null[:, 1] = x2_cond(x1_null)
    t_null = f(X_null, theta).reshape(M, K)

    # t (M, 1) broadcasts across the K columns; the +1 terms keep each
    # p_hat in [1/(K+1), 1].
    p_hat = (np.sum(t_null >= t, axis=1) + 1) / (K + 1)
    mean_p[i] = np.mean(p_hat)

In [None]:
# gamma[row]: mean difference, over K draws each, between the model output
# with only x1 redrawn (observed x2 kept) and with x1 redrawn plus x2 drawn
# from x2_cond given that x1. All other columns stay at their observed values.
gamma = np.empty(N)
for row in tqdm(range(N)):
    base = X[row : row + 1]  # shape (1, n_features); only read, never written

    keep_x2 = np.repeat(base, K, axis=0)
    keep_x2[:, 0] = x1_dist(K)

    resample_x2 = np.repeat(base, K, axis=0)
    x1_draws = x1_dist(K)
    resample_x2[:, 0] = x1_draws
    resample_x2[:, 1] = x2_cond(x1_draws)

    gamma[row] = np.mean(f(keep_x2, theta) - f(resample_x2, theta))

In [None]:
# Samples with gamma >= 0: each point is one row's (gamma, mean_p); the
# dashed line is the reference curve ub = 1 - K/(K+1) * gamma.
fig, ax = plt.subplots(figsize=(16 / 2, 9 / 2))
xx = np.linspace(0, 1, 200)
ub = 1 - (K / (K + 1)) * xx
ax.plot(xx, ub, "--", label=r"$p^{S-XRT}_{j,C} = 1 - \frac{K}{K+1} \gamma_{j,C}$")
nonneg = gamma >= 0
ax.plot(gamma[nonneg], mean_p[nonneg], "o")
ax.set_xlabel(r"$\gamma_{j,C}$", fontsize=15)
ax.set_ylabel(r"$p^{S-XRT}_{j,C}$", fontsize=15)
ax.set_xlim([-0.05, 1])
ax.set_ylim([0, 1.1])
# Only the labelled reference curve gets a legend entry.
ax.legend()

# Save through the figure handle rather than pyplot's "current figure".
fig.savefig(os.path.join(figure_dir, "ub.pdf"), bbox_inches="tight")
plt.show()

In [None]:
# Samples with gamma < 0: each point is one row's (gamma, mean_p); the
# dashed line is the reference curve lb = (1 + K * gamma^2) / (K + 1).
fig, ax = plt.subplots(figsize=(16 / 2, 9 / 2))
xx = np.linspace(-1, 0, 200)
lb = (1 / (K + 1)) * (1 + K * (xx**2))
ax.plot(xx, lb, "--", label=r"$p^{S-XRT}_{j,C} = \frac {1}{K+1}(1 + K(\gamma_{j,C})^2)$")
neg = gamma < 0
ax.plot(gamma[neg], mean_p[neg], "o")
ax.set_xlabel(r"$\gamma_{j,C}$", fontsize=15)
ax.set_ylabel(r"$p^{S-XRT}_{j,C}$", fontsize=15)
ax.set_xlim([-1, 0.05])
ax.set_ylim([0, 1.1])
# Only the labelled reference curve gets a legend entry.
ax.legend()

# Save through the figure handle rather than pyplot's "current figure".
fig.savefig(os.path.join(figure_dir, "lb.pdf"), bbox_inches="tight")
plt.show()