In [12]:
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
from tqdm.auto import tqdm

In [2]:
def sigmoid(x):
  return 1 / (1 + np.exp(-x))

def derivate_sigmoid(x):
  return sigmoid(x) * (1 - sigmoid(x))

In [3]:
def SPU(x):
    if type(x) is np.ndarray:
        return np.array([sigmoid(-x_i) - 1 if x_i <= 0 else x_i ** 2 - 0.5 for x_i in x])
    else:
        return sigmoid(-x) - 1 if x <= 0 else x ** 2 - 0.5

def derivate_SPU(x):
    if type(x) is np.ndarray:
        return np.array([derivate_sigmoid(-x_i) if x_i <= 0 else 2 * x_i for x_i in x])
    else:
        return derivate_sigmoid(-x) if x <= 0 else 2 * x


In [37]:
def compute_linear_bounds(l, u):
    if l > u:
        l, u = u, l
    
    # All the points are in the negative half-plane
    if u <= 0:
        # w_u = -derivate_sigmoid(-l)
        # b_u = SPU(l) + l * derivate_sigmoid(-l) + 1e-3
        w_u = 0
        b_u = SPU(l)

        w_l = (SPU(u) - SPU(l)) / (u - l)
        b_l = SPU(l) - w_l * l

    elif l >= 0:
        w_u = u + l
        b_u = -u * l - SPU(0)

        # Optimal point
        a = (u + l) / 2
        
        w_l = derivate_SPU(a)
        b_l = SPU(a) - a * derivate_SPU(a)

    else:
        SPU_equals_to_zero = np.sqrt(abs(SPU(0)))

        if u >= SPU_equals_to_zero:
            w_u = (SPU(u) - SPU(l)) / (u - l)
            b_u = SPU(l) - w_u * l

            # Optimal point
            a = max((u + l) / 2, SPU_equals_to_zero)

            w_l = derivate_SPU(a)
            b_l = SPU(a) - a * derivate_SPU(a)
        else:
            if SPU(l) > SPU(u):
                w_u = 0
                b_u = max(SPU(l), SPU(u))
            
            else:
                w_u = (SPU(u) - SPU(l)) / (u - l)
                b_u = SPU(l) - w_u * l

            w_l = 0.0
            b_l = SPU(0)
        
    return w_l, b_l, w_u, b_u


In [38]:
x_error = -0.24408222313370653
l = -1.4229721107058906
u = 0.9324498846633332

max_ = max(abs(l), abs(u))
k = max_ + 1 if max_ > 2 else 2

x = np.linspace(-k, k, 5000)
y = SPU(x)

fig = go.Figure()

w_l, b_l, w_u, b_u = compute_linear_bounds(l, u)


y_l = w_l * x + b_l
y_u = w_u * x + b_u


fig.add_trace(
    go.Scatter(
        x = x,
        y = y_u,
        name='Upper bound'
    )
)

fig.add_trace(
    go.Scatter(
        x=x,
        y=y,
        name='y = SPU(x)'
    ),
)

fig.add_trace(
    go.Scatter(
        x=x,
        y=y_l,
        name="Lower bound"
    )
)


fig.add_vline(
    x = l,
    line_dash = 'dash',
    name=f'x = {l}',
    line=dict(color='orange')
)


fig.add_vline(
    x = u,
    line_dash = 'dash',
    name=f'x = {u}',
    line=dict(color='orange'),
)

fig.add_vline(
    x=x_error,
    line_dash = 'dash',
    name=f'x = {u}',
    line=dict(color='red'),
)

fig.update_layout(title='Sigmoid-Parabola Unit Activation Function', height=700, hovermode='x unified')
fig.update_xaxes(zerolinecolor='black', showline=False)
fig.update_yaxes(zerolinecolor='black', showline=False)

fig.show()


In [40]:

sound = True

# Test soundness of transformer
for _ in tqdm(range(10000)):
    l, u = np.random.randn(), np.random.randn()

    if l > u:
        l, u = u, l

    w_l, b_l, w_u, b_u = compute_linear_bounds(l, u)

    # Test lower bounds
    for x in np.linspace(l, u, 1000):
        if not SPU(x) > w_l * x + b_l - 1e-5:
            print(f"Not lower-bound sound for \nx_error = {x}\nl = {l}\nu = {u}")
            print(f"SPU(x) = {SPU(x)}")
            print(f"w_l * x + b_l = {w_l * x + b_l - 1e-5}\n")

            sound = False

            assert SPU(x) > w_l * x + b_l - 1e-5


    
    # Test upper bounds
    for x in np.linspace(l, u, 1000):
        if not SPU(x) < w_u * x + b_u + 1e-5:
            print(f"Not upper-bound sound for \nx_error = {x}\nl = {l}\nu = {u}")
            print(f"SPU(x) = {SPU(x)}")
            print(f"w_u * x + b_u = {w_l * x + b_l - 1e-5}\n")

            sound = False

            assert SPU(x) < w_u * x + b_u + 1e-5

if sound:
    print("The transformer is sound!")
else:
    print("Oops. The transformer is not sound")


100%|██████████| 10000/10000 [00:48<00:00, 208.23it/s]

The transformer is sound!



