In [None]:
import torch
from matplotlib import pyplot as plt

# 201 logit pairs: z_0 sweeps [-8, 8] while z_1 stays at 0, so the x axis is
# the logit gap z_0 - z_1.
z = torch.stack([
    torch.linspace(-8.0, 8.0, 201),
    torch.zeros(201)
])

# Softmax over dim 0 turns each column (z_0, z_1) into probabilities (p_0, p_1).
p = torch.softmax(z, dim=0)

fig, ax = plt.subplots()
ax.plot(z[0], p[0])
ax.set_xlabel('Difference between logits $z_0 - z_1$')
ax.set_ylabel('$p_0$')
ax.set_title('Softmax probability of class 0')
plt.show()
p.shape

In [None]:
import math

# Losses against the uniform target y = (0.5, 0.5), using z and p from the
# first cell (z_0 swept, z_1 = 0). The widget cell below rebinds z and p with
# the rows swapped, so re-run the first cell before re-running this one.
# KL(y || p) = sum_i y_i (log y_i - log p_i) = log 0.5 - 0.5 * (log p_0 + log p_1)
kl = 0.5 * -torch.log(p).sum(0) + math.log(0.5)
mse = ((p - torch.tensor([0.5, 0.5])[:, None])**2).sum(0)

fig, ax = plt.subplots()
ax.plot(z[0], kl, label='KL')
ax.plot(z[0], mse, label='MSE')
ax.set_xlabel('Difference between logits $z_0 - z_1$')
ax.set_ylabel('Loss')
ax.set_title('Loss against target $y = (0.5, 0.5)$')
ax.legend()
plt.show()

In [None]:
from baukit import PlotWidget, Range, Checkbox, show

# Second sweep, now over z_1: column j holds the logit pair (z_0, z_1) =
# (0, linspace(xmin, xmax)[j]), so z_1 itself is the gap z_1 - z_0 that
# compare_loss / compare_grad put on their x axis.
xmin, xmax = -6.0, 6.0
z = torch.stack(
    (torch.zeros(201), torch.linspace(xmin, xmax, 201)),
    dim=0,
)
p = z.softmax(dim=0)  # per-column class probabilities (p_0, p_1)

def compare_loss(fig, y1=0.5, dokl=True, domse=True, doce=True, dol1=True):
    """Redraw loss-versus-logit-gap curves for a two-class softmax.

    Reads the module-level sweep ``z`` (logit pairs, one per column; row 1 is
    plotted on the x axis), ``p = softmax(z, dim=0)`` and the x-limits
    ``xmin`` / ``xmax``.

    Args:
        fig: matplotlib figure with exactly one axes; it is cleared and redrawn.
        y1: target probability of class 1 (class 0 gets ``1 - y1``). The
            one-hot endpoints 0 and 1 are accepted: ``0 * log 0`` is taken as 0.
        dokl: draw KL(y || p).
        domse: draw the squared error sum_i (p_i - y_i)^2.
        doce: draw the cross-entropy -sum_i y_i log p_i.
        dol1: draw the expected L1 distance to a one-hot label sampled from y.
    """
    [ax1] = fig.axes
    y0 = 1.0 - y1
    target = torch.tensor([y0, y1])[:, None]
    # xlogy(a, b) = a * log(b), defined as 0 when a == 0. This keeps one-hot
    # targets (y1 = 0 or 1) from hitting math.log(0) or producing 0 * -inf.
    ce = -torch.xlogy(target, p).sum(0)
    kl = (torch.xlogy(target, target) - torch.xlogy(target, p)).sum(0)
    mse = ((p - target) ** 2).sum(0)
    # Expected L1 to a sampled one-hot label: label 0 (prob y0) costs
    # |1 - p0| + |p1| = 2 p1, label 1 (prob y1) costs 2 p0.
    sampled_l1 = 2 * y0 * p[1] + 2 * y1 * p[0]

    ax1.clear()
    ax1.set_ylim(0, 3.0)
    ax1.set_xlim(xmin, xmax)
    ax1.set_ylabel('Loss')
    ax1.set_xlabel('Difference between logits $z_1 - z_0$')
    ax1.set_title(f'Loss curve on softmax when target $y_1={y1:.3f}$')

    if dokl:
        ax1.plot(z[1], kl, label='KL', color='b')
    if domse:
        ax1.plot(z[1], mse, label='MSE', color='r')
    if doce:
        ax1.plot(z[1], ce, label='CE', color='g', linestyle='dashed', alpha=0.6)
    if dol1:
        ax1.plot(z[1], sampled_l1, label='L1', color='orange', linestyle='dotted', alpha=0.7)
    if dokl or domse or doce or dol1:
        ax1.legend()

def compare_grad(fig, y1=0.5, dokl=True, domse=True):
    """Redraw d(loss)/dz_1 against the logit gap for a two-class softmax.

    Reads the module-level sweep ``z`` (row 1 is plotted on the x axis),
    ``p = softmax(z, dim=0)`` and the x-limits ``xmin`` / ``xmax``.

    Args:
        fig: matplotlib figure with exactly one axes; it is cleared and redrawn.
        y1: target probability of class 1.
        dokl: draw dKL/dz_1, which equals dCE/dz_1 = p_1 - y_1.
        domse: draw dMSE/dz_1 = 4 (p_1 - y_1) p_1 p_0.
    """
    [ax1] = fig.axes
    # The target's entropy does not depend on z, so KL and CE share a gradient.
    kl = p[1] - y1
    # Chain rule through softmax: dp_1/dz_1 = p_1 p_0, dp_0/dz_1 = -p_0 p_1,
    # and p_0 - y_0 = -(p_1 - y_1), so both MSE terms add up.
    mse = 4 * (p[1] - y1) * p[1] * p[0]
    ax1.clear()
    ax1.set_ylim(-0.7, 0.7)
    ax1.set_xlim(xmin, xmax)
    ax1.set_ylabel('Gradient')
    ax1.set_xlabel('Difference between logits $z_1 - z_0$')
    ax1.set_title(f'Gradient of loss with respect to $z_1$ when $y_1={y1:.3f}$')

    if dokl:
        ax1.plot(z[1], kl, color='b', label=r'$\frac{\partial \mathrm{KL}}{\partial z_1}$' +
            r'=$\frac{\partial \mathrm{CE}}{\partial z_1}$')
    if domse:
        ax1.plot(z[1], mse, color='r', label=r'$\frac{\partial \mathrm{MSE}}{\partial z_1}$')
    ax1.axhline(0, color='gray', linewidth=0.5)
    if dokl or domse:
        ax1.legend(loc='upper left')

# Controls shared by both plots: a slider for the target probability y_1 and
# one checkbox per loss curve.
rw = Range(min=0.001, max=0.999, step=0.001, value=0.5)
bkl, bce, bmse, bl1 = (
    Checkbox('KL', value=True),
    Checkbox('CE', value=False),
    Checkbox('MSE', value=True),
    Checkbox('L1', value=False),
)

# Bind each plotting function's keyword arguments to widget values via .prop('value').
ploss = PlotWidget(
    compare_loss,
    y1=rw.prop('value'),
    dokl=bkl.prop('value'),
    domse=bmse.prop('value'),
    doce=bce.prop('value'),
    dol1=bl1.prop('value'),
)
# The gradient plot has no CE/L1 toggles: its KL curve already stands for CE.
pgrad = PlotWidget(
    compare_grad,
    y1=rw.prop('value'),
    dokl=bkl.prop('value'),
    domse=bmse.prop('value'),
)

# Row 1: label, spacer and controls; row 2: the two plots side by side.
show(show.WRAP, [[
    [show.raw_html('<div>y<sub>1</sub>:</div>'), show.style(flex=12),
     rw, bkl, bce, bmse, bl1],
    [ploss, pgrad],
]])


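Derivation of the MSE gradient plotted by `compare_grad`, with $\mathrm{MSE} = (p_0 - y_0)^2 + (p_1 - y_1)^2$, $y_0 = 1 - y_1$ and $p = \mathrm{softmax}(z)$:

$$\frac{\partial \mathrm{MSE}}{\partial p_0} = 2(p_0 - y_0), \qquad \frac{\partial \mathrm{MSE}}{\partial p_1} = 2(p_1 - y_1)$$

$$\frac{\partial p_1}{\partial z_1} = p_1 - p_1^2, \qquad \frac{\partial p_0}{\partial z_1} = -p_0 p_1$$

Using $p_0 = 1 - p_1$ (so $p_0 p_1 = p_1 - p_1^2$) and $p_0 - y_0 = -(p_1 - y_1)$:

$$\begin{aligned}
\frac{\partial \mathrm{MSE}}{\partial z_1}
  &= 2(p_1 - y_1)(p_1 - p_1^2) - 2(p_0 - y_0)\,p_0 p_1 \\
  &= 2(p_1 - y_1)(p_1 - p_1^2) + 2(p_1 - y_1)(p_1 - p_1^2) \\
  &= 4(p_1 - y_1)\,p_1 (1 - p_1)
\end{aligned}$$

The gradient is positive when $p_1 > y_1$, matching `4 * (p[1] - y1) * p[1] * p[0]` in the code.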



In [None]:
from baukit import Checkbox

# Scratch check of a bare baukit Checkbox; the trailing expression displays it.
cb = Checkbox('hi')
cb

In [None]:
# Set the scratch checkbox's value from code; presumably the displayed widget updates to unchecked (baukit behaviour, not shown here).
cb.value = False