In [None]:
import os
import sys
import pickle
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Make the project root importable. Presumably this is needed so the pickled
# test objects (SelectiveRiskMartingale from test_sgr) can be resolved when
# unpickling — TODO confirm. Guarded so re-running the cell does not keep
# appending duplicate entries to sys.path.
root_dir = "../"
if root_dir not in sys.path:
    sys.path.append(root_dir)
from test_sgr import SelectiveRiskMartingale

results_dir = os.path.join(root_dir, "results")
# NOTE: pickle.load can execute arbitrary code stored in the file; only load
# results produced by this project's own experiment scripts.
# The `with` block guarantees the file handle is closed after loading.
with open(os.path.join(results_dir, "sgr_results.pkl"), "rb") as results_file:
    results = pickle.load(results_file)

sns.set_style("white")
sns.set_context("paper")

In [None]:
figure_dir = os.path.join(root_dir, "figures")
os.makedirs(figure_dir, exist_ok=True)

# --- configuration -----------------------------------------------------------
significance = 0.05  # test level alpha; a test rejects once its wealth >= 1/alpha
rejection_threshold = 1 / significance
seed = 0  # makes the example trajectories in the left panel reproducible
t_change = 128  # changepoint index; rejection times are reported relative to it
time_scale = 512  # mean rejection times are divided by this before plotting
s = 0.2  # shift level used for the "with shift" example trajectory
shift = ["none", 0.0, 0.2, 0.4, 0.6, 0.8, 1.0]  # keys into `results`

# --- pick one example test per condition for the wealth panel ----------------
# A seeded generator keeps the saved figure identical across re-runs.
# Assumes results[...]["tests"] is an indexable sequence — TODO confirm.
rng = np.random.default_rng(seed)
no_shift_tests = results["none"]["tests"]
shifted_tests = results[s]["tests"]
k0 = no_shift_tests[rng.integers(len(no_shift_tests))]
k1 = shifted_tests[rng.integers(len(shifted_tests))]


def summarize_rejections(tests, threshold, t_change):
    """Return ``(rejection_rate, mean_rejection_time)`` over a collection of tests.

    A test counts as rejected if its ``wealth`` ever reaches ``threshold``.
    Its rejection time is the first such index minus ``t_change``, so a
    rejection before the changepoint gives a negative time. A test that never
    rejects contributes ``len(wealth)`` to the mean time.

    Args:
        tests: iterable of objects exposing a ``wealth`` array.
        threshold: wealth level that triggers rejection (1 / alpha).
        t_change: changepoint index subtracted from rejection indices.

    Returns:
        Tuple of (mean of 0/1 rejection indicators, mean rejection time).
    """
    rejected_flags, times = [], []
    for test in tests:
        wealth = test.wealth
        crossings = np.where(wealth >= threshold)[0]
        if len(crossings) > 0:
            rejected_flags.append(1)
            times.append(crossings[0] - t_change)
        else:
            rejected_flags.append(0)
            # NOTE(review): non-rejections are charged len(wealth), which is
            # measured from t=0, whereas rejections are measured from t_change.
            # Confirm this censoring value is intended.
            times.append(len(wealth))
    return np.mean(rejected_flags), np.mean(times)


# --- per-shift summary statistics -------------------------------------------
test_risk, test_coverage = [], []
rejection_rate, rejection_time = [], []
for shift_key in shift:
    shift_results = results[shift_key]
    test_risk.append(shift_results["test_risk"])
    test_coverage.append(shift_results["test_coverage"])
    rate, mean_time = summarize_rejections(
        shift_results["tests"], rejection_threshold, t_change
    )
    rejection_rate.append(rate)
    rejection_time.append(mean_time / time_scale)

print(test_risk)
print(test_coverage)

# --- figure ------------------------------------------------------------------
fig, axes = plt.subplots(1, 2, figsize=(16 / 2, 9 / 4))

# Left panel: wealth trajectories of one no-shift test and one shifted test.
ax = axes[0]
ax.plot(k0.wealth, label="No shift")
ax.plot(k1.wealth, label=f"With shift (t = {s})")
ax.set_xlabel("Time")
ax.set_ylabel("Wealth")
ax.axhline(rejection_threshold, color="black", linestyle="--")
ax.axvline(t_change, color="red", linestyle="-", label="Changepoint")
ax.text(0, 2 * rejection_threshold, r"$1/\alpha$")
ax.legend()
ax.set_yscale("log")
ax.set_ylim(None, 1e05)

# Right panel: rejection rate and normalized rejection time per shift level.
# "none" is not numeric, so it is drawn at x = -0.2, one grid step left of 0.0.
# A separate list is used so `shift` stays valid as keys into `results`.
shift_positions = [-0.2] + shift[1:]
ax = axes[1]
ax.plot(
    shift_positions,
    rejection_rate,
    "--*",
    alpha=0.8,
    color="#a6cee3",
    label="Rejection rate",
)
ax.plot(shift_positions, rejection_time, "--*", color="#1f78b4", label="Rejection time")
ax.axhline(significance, color="black", linestyle="--", label="Significance level")
ax.set_xlabel("Solarization threshold")
ax.set_xticks(shift_positions)
ax.set_xticklabels(["No shift"] + shift[1:])
ax.set_ylim(0, 1)
ax.legend()
fig.savefig(os.path.join(figure_dir, "sgr.pdf"), bbox_inches="tight")
plt.show()