In [None]:
from typing import Callable

import hvplot.pandas
import jax.numpy as jnp
import pandas as pd

from jax_toolkit.losses.classification import log_loss, sigmoid_focal_crossentropy, squared_hinge

# functions_0to1: comparing losses for predictions in [0, 1]

In [None]:
# Loss name -> loss function, for the losses compared on predictions in [0, 1].
functions_0to1 = dict(
    log_loss=log_loss,
    sigmoid_focal_crossentropy=sigmoid_focal_crossentropy,
)

In [None]:
# Grid of predictions in [0, 1]. `.tolist()` yields plain Python floats so the "y_pred"
# DataFrame columns built from this list are numeric; a list of 0-d JAX arrays would
# otherwise be stored by pandas as an object-dtype column.
y_preds = jnp.linspace(start=0, stop=1, num=1000).tolist()

# loss name -> list of loss values aligned with y_preds, for targets 0 and 1 respectively.
y_true0 = {}
y_true1 = {}
for function_name, function in functions_0to1.items():
    # Each call is made on a single-element batch, so float() extracts one loss per y_pred.
    y_true0[function_name] = [
        float(function(y_true=jnp.array([0.]), y_pred=jnp.array([y_pred]))) for y_pred in y_preds
    ]
    y_true1[function_name] = [
        float(function(y_true=jnp.array([1.]), y_pred=jnp.array([y_pred]))) for y_pred in y_preds
    ]

In [None]:
# One column per loss, taken from the keys of y_true0 (which come from functions_0to1),
# so adding a loss to that dict does not require editing this cell.
df_y_true0 = pd.DataFrame({"y_pred": y_preds, **y_true0})

df_y_true0.hvplot.line(x="y_pred", y=list(y_true0),
                       ylabel="loss value").opts(title="Comparing loss functions when y_true=0")

In [None]:
# One column per loss, taken from the keys of y_true1 (which come from functions_0to1),
# so adding a loss to that dict does not require editing this cell.
df_y_true1 = pd.DataFrame({"y_pred": y_preds, **y_true1})

df_y_true1.hvplot.line(x="y_pred", y=list(y_true1),
                       ylabel="loss value").opts(title="Comparing loss functions when y_true=1")

# Single loss functions

## log_loss

In [None]:
# log_loss is already imported in the imports cell at the top.
# Cell-specific names keep this cell from rebinding y_preds / y_true0 / y_true1, which the
# comparison cells rely on (y_true0 must stay a dict there, or re-running them fails).
log_loss_preds = jnp.linspace(start=0, stop=1, num=1000).tolist()  # Python floats -> numeric column

log_loss_true0 = [float(log_loss(y_true=jnp.array([0.]), y_pred=jnp.array([p]))) for p in log_loss_preds]
log_loss_true1 = [float(log_loss(y_true=jnp.array([1.]), y_pred=jnp.array([p]))) for p in log_loss_preds]

df_log_loss = pd.DataFrame({
    "y_pred": log_loss_preds,
    "y_true=0": log_loss_true0,
    "y_true=1": log_loss_true1,
})

df_log_loss.hvplot.line(x="y_pred", y=["y_true=0", "y_true=1"],
                        ylabel="loss value").opts(title="log_loss")

## sigmoid_focal_crossentropy

In [None]:
# sigmoid_focal_crossentropy is already imported in the imports cell at the top.
# Cell-specific names keep this cell from rebinding y_preds / y_true0 / y_true1, which the
# comparison cells rely on (y_true0 must stay a dict there, or re-running them fails).
focal_preds = jnp.linspace(start=0, stop=1, num=1000).tolist()  # Python floats -> numeric column

focal_true0 = [
    float(sigmoid_focal_crossentropy(y_true=jnp.array([0.]), y_pred=jnp.array([p]))) for p in focal_preds
]
focal_true1 = [
    float(sigmoid_focal_crossentropy(y_true=jnp.array([1.]), y_pred=jnp.array([p]))) for p in focal_preds
]

df_focal = pd.DataFrame({
    "y_pred": focal_preds,
    "y_true=0": focal_true0,
    "y_true=1": focal_true1,
})

df_focal.hvplot.line(x="y_pred", y=["y_true=0", "y_true=1"],
                     ylabel="loss value").opts(title="sigmoid_focal_crossentropy")

## squared_hinge

In [None]:
# squared_hinge is already imported in the imports cell at the top.
# Unlike the losses above, it is evaluated with targets -1 / +1 over predictions in [-2, 2]
# (presumably because hinge losses take signed labels and unbounded scores).
# Cell-specific names keep this cell from rebinding y_preds / y_true0 / y_true1, which the
# comparison cells rely on; y_preds in particular would otherwise become the [-2, 2] grid.
hinge_preds = jnp.linspace(start=-2, stop=2, num=1000).tolist()  # Python floats -> numeric column

hinge_true_neg1 = [float(squared_hinge(y_true=jnp.array([-1.]), y_pred=jnp.array([p]))) for p in hinge_preds]
hinge_true_pos1 = [float(squared_hinge(y_true=jnp.array([1.]), y_pred=jnp.array([p]))) for p in hinge_preds]

df_squared_hinge = pd.DataFrame({
    "y_pred": hinge_preds,
    "y_true=-1": hinge_true_neg1,
    "y_true=1": hinge_true_pos1,
})

df_squared_hinge.hvplot.line(x="y_pred", y=["y_true=-1", "y_true=1"],
                             ylabel="loss value").opts(title="squared_hinge")