# Quick Example
This notebook reproduces the quick examples from the README, printing each standard ("hard") JAX operation next to its Softjax counterpart.

In [1]:
import jax
import jax.numpy as jnp
import softjax as sj


# Global JAX options for this notebook: 64-bit floats, "high" default matmul
# precision, and CPU as the platform.
JAX_OPTIONS = {
    "jax_enable_x64": True,
    "jax_default_matmul_precision": "high",
    "jax_platforms": "cpu",
}
for option_name, option_value in JAX_OPTIONS.items():
    jax.config.update(option_name, option_value)

In [None]:
x = jnp.array([-0.2, -1.0, 0.3, 1.0])
y = jnp.array([0.2, -0.5, 0.5, -1.0])

# Elementwise functions
print("Hard ReLU:", jax.nn.relu(x))
print("Soft ReLU:", sj.relu(x))
print("Hard Clip:", jnp.clip(x, -0.5, 0.5))
print(
    "Soft Clip:",
    sj.clip(
        x,
        -0.5,
        0.5,
    ),
)
print("Hard Absolute:", jnp.abs(x))
print("Soft Absolute:", sj.abs(x))
print("Hard Sign:", jnp.sign(x))
print("Soft Sign:", sj.sign(x))
print("Hard round:", jnp.round(x))
print("Soft round:", sj.round(x))
print("Hard heaviside:", jnp.heaviside(x, 0.5))
print("Soft heaviside:", sj.heaviside(x))

Hard ReLU: [0.  0.  0.3 1. ]
Soft ReLU: [1.26928011e-02 4.53988992e-06 3.04858735e-01 1.00000454e+00]
Hard Clip: [-0.2 -0.5  0.3  0.5]
Soft Clip: [-0.19523241 -0.4993285   0.28734074  0.4993285 ]
Hard Absolute: [0.2 1.  0.3 1. ]
Soft Absolute: [0.15231883 0.9999092  0.27154448 0.9999092 ]
Hard Sign: [-1. -1.  1.  1.]
Soft Sign: [-0.76159416 -0.9999092   0.90514825  0.9999092 ]
Hard round: [-0. -1.  0.  1.]
Soft round: [-0.04651704 -1.          0.1188737   1.        ]


In [3]:
# Functions on arrays
print("Hard max:", jnp.max(x))
print("Soft max:", sj.max(x))
print("Hard min:", jnp.min(x))
print("Soft min:", sj.min(x))
print("Hard median:", jnp.median(x))
print("Soft median:", sj.median(x))
print("Hard top_k:", jax.lax.top_k(x, k=3)[0])
print(
    "Soft top_k:",
    sj.top_k(
        x,
        k=3,
    )[0],
)
print("Hard sort:", jnp.sort(x))
print("Soft sort:", sj.sort(x))
print("Hard ranking:", jnp.argsort(jnp.argsort(x)))
print("Soft ranking:", sj.ranking(x, descending=False))

Hard max: 1.0
Soft max: 0.9993548976691374
Hard min: -1.0
Soft min: -0.9997287789452775


Hard median: 0.04999999999999999
Soft median: 0.05000033589501627
Hard top_k: [ 1.   0.3 -0.2]
Soft top_k: [ 0.9993549   0.29728716 -0.19691387]
Hard sort: [-1.  -0.2  0.3  1. ]
Soft sort: [-0.99972878 -0.19691387  0.29728716  0.9993549 ]
Hard ranking: [1 0 2 3]
Soft ranking: [1.00636968e+00 3.39874686e-04 1.99421369e+00 2.99907667e+00]


In [4]:
# Functions returning indices
print("Hard argmax:", jnp.argmax(x))
print("Soft argmax:", sj.argmax(x))
print("Hard argmin:", jnp.argmin(x))
print("Soft argmin:", sj.argmin(x))
print("Hard argmedian:", "Not implemented in standard JAX")
print("Soft argmedian:", sj.argmedian(x))
print("Hard argtop_k:", jax.lax.top_k(x, k=3)[1])
print("Soft argtop_k:", sj.top_k(x, k=3)[1])
print("Hard argsort:", jnp.argsort(x))
print("Soft argsort:", sj.argsort(x))

Hard argmax: 3
Soft argmax: [6.13857697e-06 2.05926316e-09 9.11045600e-04 9.99082814e-01]
Hard argmin: 1
Soft argmin: [3.35349372e-04 9.99662389e-01 2.25956629e-06 2.06045775e-09]
Hard argmedian: Not implemented in standard JAX
Soft argmedian: [4.99999764e-01 5.62675608e-08 4.99999764e-01 4.15764163e-07]
Hard argtop_k: [3 2 0]
Soft argtop_k: [[6.13857697e-06 2.05926316e-09 9.11045600e-04 9.99082814e-01]
 [6.68677917e-03 2.24316451e-06 9.92406021e-01 9.04957153e-04]
 [9.92970214e-01 3.33104397e-04 6.69058067e-03 6.10101985e-06]]
Hard argsort: [1 0 2 3]
Soft argsort: [[3.35349372e-04 9.99662389e-01 2.25956629e-06 2.06045775e-09]
 [9.92970214e-01 3.33104397e-04 6.69058067e-03 6.10101985e-06]
 [6.68677917e-03 2.24316451e-06 9.92406021e-01 9.04957153e-04]
 [6.13857697e-06 2.05926316e-09 9.11045600e-04 9.99082814e-01]]


In [None]:
## SoftBool generation
print("Hard greater:", x > y)
print("Soft greater:", sj.greater(x, y))
print("Hard greater equal:", x >= y)
print("Soft greater equal:", sj.greater_equal(x, y))
print("Hard less:", x < y)
print("Soft less:", sj.less(x, y))
print("Hard less equal:", x <= y)
print("Soft less equal:", sj.less_equal(x, y))
print("Hard equal:", x == y)
print("Soft equal:", sj.equal(x, y))
print("Hard not equal:", x != y)
print("Soft not equal:", sj.not_equal(x, y))
print("Hard isclose:", jnp.isclose(x, y))
print("Soft isclose:", sj.isclose(x, y))

Hard heaviside: [0. 0. 1. 1.]
Soft heaviside: [1.19202922e-01 4.53978687e-05 9.52574127e-01 9.99954602e-01]
Hard greater: [False False False  True]
Soft greater: [0.01798621 0.00669285 0.11920292 1.        ]
Hard greater equal: [False False False  True]
Soft greater equal: [0.01798621 0.00669285 0.11920292 1.        ]
Hard less: [ True  True  True False]
Soft less: [9.82013790e-01 9.93307149e-01 8.80797078e-01 2.06115369e-09]
Hard less equal: [ True  True  True False]
Soft less equal: [9.82013790e-01 9.93307149e-01 8.80797078e-01 2.06115369e-09]
Hard equal: [False False False False]
Soft equal: [1.79862100e-02 6.69285093e-03 1.19202922e-01 2.06115369e-09]
Hard not equal: [ True  True  True  True]
Soft not equal: [0.98201379 0.99330715 0.88079708 1.        ]
Hard isclose: [False False False False]
Soft isclose: [1.79865650e-02 6.69318401e-03 1.19208182e-01 2.06135997e-09]


In [6]:
## SoftBool manipulation
# Hand-picked soft truth values in [0, 1] to combine with the fuzzy-logic ops.
fuzzy_a = jnp.array([0.1, 0.2, 0.8, 1.0])
fuzzy_b = jnp.array([0.7, 0.3, 0.1, 0.9])
for label, soft_bool in (
    ("Soft AND:", sj.logical_and(fuzzy_a, fuzzy_b)),
    ("Soft OR:", sj.logical_or(fuzzy_a, fuzzy_b)),
    ("Soft NOT:", sj.logical_not(fuzzy_a)),
    ("Soft XOR:", sj.logical_xor(fuzzy_a, fuzzy_b)),
    # all/any reduce the whole array to a single soft truth value.
    ("Soft ALL:", sj.all(fuzzy_a)),
    ("Soft ANY:", sj.any(fuzzy_a)),
):
    print(label, soft_bool)

## SoftBool selection
# Soft `where`: each output element mixes x and y according to fuzzy_a.
print("Where:", sj.where(fuzzy_a, x, y))

Soft AND: [0.26457513 0.24494897 0.28284271 0.9486833 ]
Soft OR: [0.48038476 0.25166852 0.57573593 0.99999684]
Soft NOT: [0.9 0.8 0.2 0. ]
Soft XOR: [0.58702688 0.43498731 0.63937484 0.17309871]
Soft ALL: 0.35565588200778464
Soft ANY: 0.9980519925071494
Where: [ 0.16 -0.6   0.34  1.  ]


In [7]:
# Straight-through estimation: the forward pass uses the hard function and the
# backward pass uses the soft one, so the printed values equal the hard results.
straight_through_examples = {
    "Straight-through ReLU:": sj.relu_st(x),
    "Straight-through sort:": sj.sort_st(x),
    "Straight-through argtop_k:": sj.top_k_st(x, k=3)[1],
    "Straight-through greater:": sj.greater_st(x, y),
}
for label, value in straight_through_examples.items():
    print(label, value)

Straight-through ReLU: [0.  0.  0.3 1. ]


Straight-through sort: [-1.  -0.2  0.3  1. ]
Straight-through argtop_k: [[0. 0. 0. 1.]
 [0. 0. 1. 0.]
 [1. 0. 0. 0.]]
Straight-through greater: [0. 0. 0. 1.]


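## "Hard" mode of Softjax

Passing `mode="hard"` makes each Softjax function return the exact, non-smoothed result, so every Softjax line below can be checked against the standard JAX line above it. Index- and comparison-style functions still return floats. They give weight vectors over positions (e.g. one-hot rows for `argmax`/`argsort`) instead of integer indices, and `0.`/`1.` instead of `False`/`True`.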

In [8]:
# Re-create the same inputs used above so this section can be run on its own.
x = jnp.array([-0.2, -1.0, 0.3, 1.0])
y = jnp.array([0.2, -0.5, 0.5, -1.0])

# Elementwise functions: each JAX result is followed by its hard-mode Softjax twin.
for label, value in (
    ("Jax ReLU:    ", jax.nn.relu(x)),
    ("Softjax ReLU:", sj.relu(x, mode="hard")),
    ("Jax Clip:    ", jnp.clip(x, -0.5, 0.5)),
    ("Softjax Clip:", sj.clip(x, -0.5, 0.5, mode="hard")),
    ("Jax Absolute:    ", jnp.abs(x)),
    ("Softjax Absolute:", sj.abs(x, mode="hard")),
    ("Jax Sign:    ", jnp.sign(x)),
    ("Softjax Sign:", sj.sign(x, mode="hard")),
    ("Jax round:    ", jnp.round(x)),
    ("Softjax round:", sj.round(x, mode="hard")),
):
    print(label, value)

Jax ReLU:     [0.  0.  0.3 1. ]
Softjax ReLU: [0.  0.  0.3 1. ]
Jax Clip:     [-0.2 -0.5  0.3  0.5]
Softjax Clip: [-0.2 -0.5  0.3  0.5]
Jax Absolute:     [0.2 1.  0.3 1. ]
Softjax Absolute: [0.2 1.  0.3 1. ]
Jax Sign:     [-1. -1.  1.  1.]
Softjax Sign: [-1. -1.  1.  1.]
Jax round:     [-0. -1.  0.  1.]
Softjax round: [-0. -1.  0.  1.]


In [9]:
# Functions on arrays (reductions, top_k, sort, ranking) in hard mode.
# argsort(argsort(x)) gives each element's ascending rank in plain JAX.
for label, value in (
    ("Jax max:    ", jnp.max(x)),
    ("Softjax max:", sj.max(x, mode="hard")),
    ("Jax min:    ", jnp.min(x)),
    ("Softjax min:", sj.min(x, mode="hard")),
    ("Jax median:    ", jnp.median(x)),
    ("Softjax median:", sj.median(x, mode="hard")),
    ("Jax top_k:    ", jax.lax.top_k(x, k=3)[0]),
    ("Softjax top_k:", sj.top_k(x, k=3, mode="hard")[0]),
    ("Jax sort:    ", jnp.sort(x)),
    ("Softjax sort:", sj.sort(x, mode="hard")),
    ("Jax ranking:    ", jnp.argsort(jnp.argsort(x))),
    ("Softjax ranking:", sj.ranking(x, mode="hard", descending=False)),
):
    print(label, value)

Jax max:     1.0
Softjax max: 1.0
Jax min:     -1.0
Softjax min: -1.0
Jax median:     0.04999999999999999
Softjax median: 0.04999999999999999
Jax top_k:     [ 1.   0.3 -0.2]
Softjax top_k: [ 1.   0.3 -0.2]
Jax sort:     [-1.  -0.2  0.3  1. ]
Softjax sort: [-1.  -0.2  0.3  1. ]
Jax ranking:     [1 0 2 3]
Softjax ranking: [1. 0. 2. 3.]


In [10]:
# Functions returning indices. Where JAX returns integer positions, Softjax
# returns weights over positions (one row per selected element for top_k/argsort).
for label, value in (
    ("Jax argmax:    ", jnp.argmax(x)),
    ("Softjax argmax:", sj.argmax(x, mode="hard")),
    ("Jax argmin:    ", jnp.argmin(x)),
    ("Softjax argmin:", sj.argmin(x, mode="hard")),
    ("Jax argmedian:    ", "Not implemented in standard JAX"),
    ("Softjax argmedian:", sj.argmedian(x, mode="hard")),
    ("Jax argtop_k:    ", jax.lax.top_k(x, k=3)[1]),
    ("Softjax argtop_k:", sj.top_k(x, k=3, mode="hard")[1]),
    ("Jax argsort:    ", jnp.argsort(x)),
    ("Softjax argsort:", sj.argsort(x, mode="hard")),
):
    print(label, value)

Jax argmax:     3


Softjax argmax: [0. 0. 0. 1.]
Jax argmin:     1
Softjax argmin: [0. 1. 0. 0.]
Jax argmedian:     Not implemented in standard JAX
Softjax argmedian: [0.5 0.  0.5 0. ]
Jax argtop_k:     [3 2 0]
Softjax argtop_k: [[0. 0. 0. 1.]
 [0. 0. 1. 0.]
 [1. 0. 0. 0.]]
Jax argsort:     [1 0 2 3]
Softjax argsort: [[0. 1. 0. 0.]
 [1. 0. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]]


In [11]:
## SoftBool generation
# The JAX comparisons yield bool arrays; their hard-mode Softjax counterparts
# yield the same truth values as 0./1. floats.
for label, value in (
    ("Jax heaviside:    ", jnp.heaviside(x, 0.5)),
    ("Softjax heaviside:", sj.heaviside(x, mode="hard")),
    ("Jax greater:    ", x > y),
    ("Softjax greater:", sj.greater(x, y, mode="hard")),
    ("Jax greater equal:    ", x >= y),
    ("Softjax greater equal:", sj.greater_equal(x, y, mode="hard")),
    ("Jax less:    ", x < y),
    ("Softjax less:", sj.less(x, y, mode="hard")),
    ("Jax less equal:    ", x <= y),
    ("Softjax less equal:", sj.less_equal(x, y, mode="hard")),
    ("Jax equal:    ", x == y),
    ("Softjax equal:", sj.equal(x, y, mode="hard")),
    ("Jax not equal:    ", x != y),
    ("Softjax not equal:", sj.not_equal(x, y, mode="hard")),
    ("Jax isclose:    ", jnp.isclose(x, y)),
    ("Softjax isclose:", sj.isclose(x, y, mode="hard")),
):
    print(label, value)

Jax heaviside:     [0. 0. 1. 1.]


Softjax heaviside: [0. 0. 1. 1.]
Jax greater:     [False False False  True]
Softjax greater: [0. 0. 0. 1.]
Jax greater equal:     [False False False  True]
Softjax greater equal: [0. 0. 0. 1.]
Jax less:     [ True  True  True False]
Softjax less: [1. 1. 1. 0.]
Jax less equal:     [ True  True  True False]
Softjax less equal: [1. 1. 1. 0.]
Jax equal:     [False False False False]
Softjax equal: [0. 0. 0. 0.]
Jax not equal:     [ True  True  True  True]
Softjax not equal: [1. 1. 1. 1.]
Jax isclose:     [False False False False]
Softjax isclose: [0. 0. 0. 0.]
