In [None]:
import sys
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from tqdm import tqdm

root_dir = "../../"
sys.path.append(root_dir)
from dataset import SimpleSigmoidDataset, SigmoidDataset
from shaplit import SHAPLIT

# Output directory for figures of this notebook.
# NOTE(review): figure_dir is created but not used by the visible cells;
# presumably figures are meant to be saved here.
figure_dir = os.path.join(root_dir, "figures", "sigmoid")
os.makedirs(figure_dir, exist_ok=True)

# Globally disable autograd tracking for every subsequent torch operation.
torch.set_grad_enabled(False)

# Seed the RNGs so that Restart & Run All reproduces the figures.
# (Assumes the dataset sampler and SHAPLIT draw from torch/numpy RNGs — TODO confirm.)
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

sns.set_theme(style="white")
sns.set_context("paper", font_scale=1.5)

In [None]:
# Toy setting: m samples from SimpleSigmoidDataset, scored by a sigmoid of the
# feature sum.
m = 1_000

dataset = SimpleSigmoidDataset(m)


def model(x):
    """Return ``sigmoid(sum(x, dim=1))``, i.e. one output per row of ``x``."""
    return torch.sigmoid(torch.sum(x, dim=1))


X = dataset.data
Y = model(X)

fig, (ax_data, ax_out) = plt.subplots(1, 2, figsize=(16 / 2, 9 / 2))
ax_data.scatter(X[:, 0], X[:, 1])
ax_data.axis("equal")
ax_data.set_xlabel(r"$x_1$")
ax_data.set_ylabel(r"$x_2$")
ax_data.set_title("Samples")

ax_out.hist(Y)
ax_out.set_xlabel(r"$f(x)$")
ax_out.set_ylabel("count")
ax_out.set_title("Model output")
plt.show()
plt.close(fig)

In [None]:
# For every sample x, feature j and conditioning set C ⊆ N \ {j} collect:
#   G: mean of model(dataset.cond(x, C ∪ {j}, K)) - model(dataset.cond(x, C, K))
#   P: SHAPLIT output for (x, j, C), averaged over N_REPEATS independent runs
N = set(range(X.size(1)))  # feature indices
K = 50  # forwarded to dataset.cond / SHAPLIT; presumably #conditional samples (TODO confirm)
M = 1  # forwarded to SHAPLIT as M
N_REPEATS = 10  # independent SHAPLIT runs averaged per (x, j, C)

P = []
G = []
for x in tqdm(X):
    for j in range(len(x)):
        S = N - {j}
        # Enumerate every subset C of S, from the empty set up to S itself.
        for c in range(len(S) + 1):
            for C in map(set, combinations(S, c)):
                f = model(dataset.cond(x, C | {j}, K))
                f_null = model(dataset.cond(x, C, K))
                g = torch.mean(f - f_null).item()

                p = np.mean(
                    [
                        SHAPLIT(model, dataset.cond, x, j, C, K=K, M=M).item()
                        for _ in range(N_REPEATS)
                    ]
                )

                P.append(p)
                G.append(g)

fig, ax = plt.subplots(figsize=(16 / 2, 9 / 2))
ax.scatter(G, P, alpha=0.5, marker=".")
# Dashed reference curves: piecewise-linear segments plus g^2 / (0.5^2 + g^2)
# on [-0.5, 0]. NOTE(review): presumably the theoretical envelope of p-hat vs
# gamma-hat — confirm and cite its source.
ax.plot([-1, -0.5], [1, 0.5], "k--")
ax.plot([-1, 0], [1, 1], "k--")
ax.plot([0, 1], [1, 0], "k--")
ax.plot([0, 1], [0, 0], "k--")
gg = np.linspace(-0.5, 0)
ax.plot(gg, gg**2 / (0.5**2 + gg**2), "k--")
ax.set_xticks([-1, -0.5, 0, 0.5, 1])
ax.set_yticks([0, 0.5, 1])
ax.axis("equal")
ax.set_xlabel(r"$\hat{\gamma}_{j,C}$")
ax.set_ylabel(r"$\hat{p}_{j,C}$")
ax.set_title("SimpleSigmoidDataset")
plt.show()
plt.close(fig)

In [None]:
# Second setting: m samples from SigmoidDataset, scored by the weighted model
# defined next.
m = 1_000

dataset = SigmoidDataset(m)


def model(x, alpha=0.95):
    """Sigmoid of a weighted feature sum.

    Columns 0-1 are summed with weight ``alpha`` and all remaining columns with
    weight ``1 - alpha``, so with the default the trailing features contribute
    only weakly.

    Args:
        x: tensor indexed as ``x[:, cols]``; rows are samples, columns features.
        alpha: weight on the first two features (default 0.95).

    Returns:
        Tensor with one sigmoid output per row of ``x``.
    """
    s_head = torch.sum(x[:, :2], dim=1)
    s_tail = torch.sum(x[:, 2:], dim=1)
    return torch.sigmoid(alpha * s_head + (1 - alpha) * s_tail)


X = dataset.data
Y = model(X)

fig, (ax_data, ax_out) = plt.subplots(1, 2, figsize=(16 / 2, 9 / 2))
# Overlay both feature pairs; the legend tells the two point clouds apart.
ax_data.scatter(X[:, 0], X[:, 1], label=r"$(x_1, x_2)$")
ax_data.scatter(X[:, 2], X[:, 3], label=r"$(x_3, x_4)$")
ax_data.axis("equal")
ax_data.legend()
ax_data.set_title("Samples")

ax_out.hist(Y)
ax_out.set_xlabel(r"$f(x)$")
ax_out.set_ylabel("count")
ax_out.set_title("Model output")
plt.show()
plt.close(fig)

In [None]:
# Same experiment as for SimpleSigmoidDataset, now on SigmoidDataset:
#   G: mean of model(dataset.cond(x, C ∪ {j}, K)) - model(dataset.cond(x, C, K))
#   P: SHAPLIT output for (x, j, C), averaged over N_REPEATS independent runs
N = set(range(X.size(1)))  # feature indices
K = 50  # forwarded to dataset.cond / SHAPLIT; presumably #conditional samples (TODO confirm)
M = 1  # forwarded to SHAPLIT as M
N_REPEATS = 10  # independent SHAPLIT runs averaged per (x, j, C)

P = []
G = []
for x in tqdm(X):
    for j in range(len(x)):
        S = N - {j}
        # Enumerate every subset C of S, from the empty set up to S itself.
        for c in range(len(S) + 1):
            for C in map(set, combinations(S, c)):
                f = model(dataset.cond(x, C | {j}, K))
                f_null = model(dataset.cond(x, C, K))
                g = torch.mean(f - f_null).item()

                p = np.mean(
                    [
                        SHAPLIT(model, dataset.cond, x, j, C, K=K, M=M).item()
                        for _ in range(N_REPEATS)
                    ]
                )

                P.append(p)
                G.append(g)

fig, ax = plt.subplots(figsize=(16 / 2, 9 / 2))
ax.scatter(G, P, alpha=0.5, marker=".")
# Dashed reference curves: piecewise-linear segments plus g^2 / (0.5^2 + g^2)
# on [-0.5, 0]. NOTE(review): presumably the theoretical envelope of p-hat vs
# gamma-hat — confirm and cite its source.
ax.plot([-1, -0.5], [1, 0.5], "k--")
ax.plot([-1, 0], [1, 1], "k--")
ax.plot([0, 1], [1, 0], "k--")
ax.plot([0, 1], [0, 0], "k--")
gg = np.linspace(-0.5, 0)
ax.plot(gg, gg**2 / (0.5**2 + gg**2), "k--")
ax.set_xticks([-1, -0.5, 0, 0.5, 1])
ax.set_yticks([0, 0.5, 1])
ax.axis("equal")
ax.set_xlabel(r"$\hat{\gamma}_{j,C}$")
ax.set_ylabel(r"$\hat{p}_{j,C}$")
ax.set_title("SigmoidDataset")
plt.show()
plt.close(fig)

In [None]:
def model(x, alpha=0.99):
    """Sigmoid of a weighted feature sum, with an even smaller tail weight.

    NOTE: this rebinds the notebook-level ``model``; every cell run after this
    one uses alpha = 0.99 unless an earlier model cell is re-executed.

    Columns 0-1 are summed with weight ``alpha`` and all remaining columns with
    weight ``1 - alpha``.

    Args:
        x: tensor indexed as ``x[:, cols]``; rows are samples, columns features.
        alpha: weight on the first two features (default 0.99).

    Returns:
        Tensor with one sigmoid output per row of ``x``.
    """
    s_head = torch.sum(x[:, :2], dim=1)
    s_tail = torch.sum(x[:, 2:], dim=1)
    return torch.sigmoid(alpha * s_head + (1 - alpha) * s_tail)


# Repeatedly run SHAPLIT for feature j = 2 given C = {0, 1} at the first sample
# and compare the ECDF of the outputs with the Uniform(0, 1) CDF (the diagonal).
# With alpha = 0.99 the trailing features carry weight 0.01, so presumably this
# probes behaviour under a (near-)null — TODO confirm the intended hypothesis.
N_DRAWS = 512
x = X[0]
p = [
    SHAPLIT(model, dataset.cond, x, 2, {0, 1}, K=K, M=M).item()
    for _ in range(N_DRAWS)
]
print(f"{len(p)} SHAPLIT draws; mean = {np.mean(p):.3f}")

fig, ax = plt.subplots(figsize=(5, 5))
ax.set_xticks([0, 0.5, 1])
ax.set_yticks([0, 0.5, 1])
ax.plot([0, 1], [0, 1], "k--")  # CDF of Uniform(0, 1)
sns.ecdfplot(p, ax=ax)
ax.set_xlabel(r"$\hat{p}_{j,C}$")
ax.set_ylabel("ECDF")
plt.show()
plt.close(fig)