In [None]:
import os
import sys
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns

# Make the repository root and this experiment's script folder importable, so
# the `dataset` / `utils` imports that follow can be resolved (presumably from
# scripts/crosses/ — verify against the repo layout).
root_dir = "../../"
script_dir = os.path.join(root_dir, "scripts", "crosses")
for extra_path in (root_dir, script_dir):
    sys.path.append(extra_path)

from dataset import Crosses
from utils import batch_size, r, s, d, alpha

# Inputs: result CSVs are read from ./data (relative to the working directory).
# Outputs: figures are written under <root_dir>/figures/crosses/{power,examples}.
data_dir = "data"
figure_dir = os.path.join(root_dir, "figures", "crosses")
power_figure_dir = os.path.join(figure_dir, "power")
example_figure_dir = os.path.join(figure_dir, "examples")
for output_dir in (figure_dir, power_figure_dir, example_figure_dir):
    os.makedirs(output_dir, exist_ok=True)

# Shared figure styling: seaborn default theme with the "paper" context,
# fonts scaled up by 1.8x.
sns.set_theme()
sns.set_context("paper", font_scale=1.8)

In [None]:
# code to reproduce Fig. E.1.a
# Render the dataset's `_signal` as a d x d grayscale image with no axes.
# NOTE(review): the sigma argument (1 / d**2) presumably does not affect
# `_signal` — confirm in dataset.Crosses.
dataset = Crosses(batch_size, r, s, d, 1 / d**2)
signal_image = dataset._signal.reshape(d, d)

fig, ax = plt.subplots()
ax.imshow(signal_image, cmap="gray")
for axis in (ax.xaxis, ax.yaxis):
    axis.set_visible(False)
fig.savefig(os.path.join(figure_dir, "signal.pdf"), bbox_inches="tight")
plt.show()

In [None]:
# code to reproduce Fig. E.1.b through E.1.d
# For each noise level sigma, save up to `n_examples_per_class` sample images
# per class (positives first, then negatives) as individual PDFs named
# "<label>_<j>.pdf", where j counts examples within that class.
n_examples_per_class = 2
for sigma in [1 / d**2, 1 / d, 1 / np.sqrt(d)]:
    # NOTE(review): directory names keep 2 decimals; for large d, distinct
    # sigmas could round to the same name (e.g. "0.00") and overwrite each other.
    sigma_fig_dir = os.path.join(example_figure_dir, f"{sigma:.2f}")
    os.makedirs(sigma_fig_dir, exist_ok=True)

    dataset = Crosses(batch_size, r, s, d, sigma)

    X = dataset.data
    Y = dataset.labels
    labels = Y.flatten()

    # Assumes binary {0, 1} labels — TODO confirm. Nonzero labels are treated
    # as positive and zero labels as negative. `nonzero(as_tuple=True)` always
    # yields a 1-D index tensor, even when a class has exactly one sample (where
    # `.squeeze()` would produce an unsliceable 0-d tensor). Comparing against 0
    # instead of computing `1 - Y` also works for bool label tensors.
    for class_mask in (labels != 0, labels == 0):
        class_idx = torch.nonzero(class_mask, as_tuple=True)[0][:n_examples_per_class]

        for j, k in enumerate(class_idx.tolist()):
            x = X[k].squeeze()
            y = labels[k].item()

            fig, ax = plt.subplots()
            # Fixed color range so images share one intensity scale across sigmas.
            im = ax.imshow(x, cmap="gray", vmin=-3, vmax=3)
            ax.set_title(r"$Y = %d$" % y)
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
            fig.colorbar(im, ax=ax)
            fig.savefig(
                os.path.join(sigma_fig_dir, f"{int(y)}_{j}.pdf"), bbox_inches="tight"
            )
            # Close explicitly: many figures are created in this loop.
            plt.close(fig)

In [None]:
# code to reproduce Fig. 2
# Left panel: the "beta" column against m; right panel: "beta" against sigma.
# Both are plotted per model (hue="model_name") on a log x-axis and labelled
# as power at level alpha.
power_df = pd.read_csv(os.path.join(data_dir, "power_m.csv"))
sigma_df = pd.read_csv(os.path.join(data_dir, "power_sigma.csv"))

fig, axes = plt.subplots(1, 2, figsize=(16, 9 / 2))

# NOTE(review): `ci=` is deprecated since seaborn 0.12 in favour of
# `errorbar=("ci", 95)`; kept because the installed seaborn version is not
# visible here.
ax = axes[0]
sns.lineplot(
    data=power_df, x="m", y="beta", hue="model_name", ax=ax, ci=95, legend=False
)
ax.set_xlabel(r"$m$")
ax.set_xscale("log")

ax = axes[1]
sns.lineplot(data=sigma_df, x="sigma", y="beta", hue="model_name", ax=ax, ci=95)
ax.set_xlabel(r"$\sigma$")
ax.set_xscale("log")
# Tick positions 1/d^2, 1/d, 1/sqrt(d) = d^{-1/2}, and 2/sqrt(d); the labels
# must match these values (previously written as 1/d^{-1/2}, which is d^{1/2}).
ax.set_xticks([1 / d**2, 1 / d, 1 / np.sqrt(d), 2 / np.sqrt(d)])
ax.set_xticklabels([r"$1/d^2$", r"$1/d$", r"$1/\sqrt{d}$", r"$2/\sqrt{d}$"])
ax.legend(title="Model", loc="upper left", bbox_to_anchor=(1, 1))

# Both panels share the same y-axis label and limits (power lies in [0, 1]).
for ax in axes:
    ax.set_ylabel(r"Power at $\alpha = %s$" % alpha)
    ax.set_ylim([-0.05, 1.05])

fig.savefig(os.path.join(power_figure_dir, "power.pdf"), bbox_inches="tight")
plt.show()