# Symbolic calculation

In [None]:
import sympy
import sympy.stats

In [None]:
# Symbols of the model. x is the point at which the density is evaluated;
# (m1, s1) and (m2, s2) are the mean and standard deviation of the two normal
# random variables defined below. The scales are declared positive so SymPy
# can simplify expressions that involve them.
x = sympy.Symbol("x", real=True)
t = sympy.Symbol("t", real=True)  # not used anywhere in this cell
m1 = sympy.Symbol("m1", real=True)
m2 = sympy.Symbol("m2", real=True)
s1 = sympy.Symbol("s1", positive=True)
s2 = sympy.Symbol("s2", positive=True)

pdf = sympy.stats.Normal("pdf", m1, s1)  # contributes its density to p
cdf = sympy.stats.Normal("cdf", m2, s2)  # contributes its CDF to p
# Normal(m2, sqrt(s1**2 + s2**2)): its CDF evaluated at m1 equals the integral
# over x of density(pdf)(x) * cdf(cdf)(x), i.e. the normalising constant of p.
norm = sympy.stats.Normal("norm", sympy.stats.E(cdf), sympy.stats.std(pdf + cdf))

# Normalised product of a Gaussian density and a Gaussian CDF.
p = sympy.simplify(
    sympy.stats.density(pdf)(x) * sympy.stats.cdf(cdf)(x) / sympy.stats.cdf(norm)(m1)
)

# First and second derivative of log p with respect to x.
d_log_p = sympy.simplify(sympy.diff(sympy.log(p), x))
dd_log_p = sympy.simplify(sympy.diff(d_log_p, x))

# One Newton step on log p starting from x; moves x towards the mode of p.
mean_update = sympy.simplify(x - d_log_p / dd_log_p)

# Laplace approximation: the variance of the Gaussian fitted at x is the
# inverse of the negative curvature of log p, i.e. -1 / (d^2/dx^2 log p).
# (Raising the negative curvature to the power -1/2 would give the standard
# deviation instead, which is not what the name and the plot label promise.)
approximate_variance = sympy.simplify(-1 / dd_log_p)

# Plain-Python callables with signature f(x, m1, s1, m2, s2), backed by the
# standard `math` module (scalar inputs only).
lambdas = {
    name: sympy.lambdify([x, m1, s1, m2, s2], expression, "math")
    for name, expression in (
        ("p", p),
        ("mean_update", mean_update),
        ("approximate_variance", approximate_variance),
    )
}

# Visualization

In [None]:
import ipywidgets
import numpy
import scipy, scipy.stats
import math
import matplotlib.pyplot as plt
%matplotlib ipympl

In [None]:
# Shared by all sliders below; with True the widgets report value changes
# while the handle is being dragged rather than only on release.
continuous_update = True

# 200 evenly spaced x values in [-1, 1] on which the curves are evaluated.
sample_points = numpy.linspace(-1, 1, num=200)

# Close the current figure (if any) and start a fresh 10x6 inch one that the
# interactive callback draws into.
plt.close()
plt.figure(figsize=(10, 6))
@ipywidgets.interact(
    sample_points=ipywidgets.fixed(sample_points),
    lambdas=ipywidgets.fixed(lambdas),
    m1=ipywidgets.FloatSlider(min=-1.0, max=1.0, value=0.0, step=0.01, continuous_update=continuous_update),
    s1=ipywidgets.FloatSlider(min=0.1, max=2, value=1.0, step=0.01, continuous_update=continuous_update),
    m2=ipywidgets.FloatSlider(min=-1.0, max=1.0, value=0.0, step=0.01, continuous_update=continuous_update),
    s2=ipywidgets.FloatSlider(min=0.1, max=2, value=1.0, step=0.01, continuous_update=continuous_update),
)
def interactive_plot(sample_points, lambdas, m1, s1, m2, s2):
    """Redraw the PDF, the CDF and their normalised product for the slider values.

    Args:
        sample_points: x values at which the three curves are evaluated.
        lambdas: mapping with the callables "p", "mean_update" and
            "approximate_variance", each called as f(x, m1, s1, m2, s2).
        m1, s1: mean and standard deviation of the Gaussian whose PDF is drawn.
        m2, s2: mean and standard deviation of the Gaussian whose CDF is drawn.

    Besides the curves, a dashed vertical line marks the better of two
    one-step estimates (started at m1 and at m2), annotated with the value
    returned by lambdas["approximate_variance"] at that point.
    """
    params = (m1, s1, m2, s2)

    # Curves are evaluated numerically with SciPy on the whole grid at once.
    density_curve = scipy.stats.norm(m1, s1).pdf(sample_points)
    cumulative_curve = scipy.stats.norm(m2, s2).cdf(sample_points)
    normaliser = scipy.stats.norm(m2, math.sqrt(s1 * s1 + s2 * s2)).cdf(m1)
    product_curve = density_curve * cumulative_curve / normaliser

    # Wipe the current axes and draw the three curves again.
    axes = plt.gca()
    axes.cla()
    for curve, fmt, width, label in (
        (density_curve, "r", 1, "PDF"),
        (cumulative_curve, "b", 1, "CDF"),
        (product_curve, "k", 2, "Product"),
    ):
        axes.plot(sample_points, curve, fmt, linewidth=width, label=label)
    plt.tight_layout()
    axes.set_xlim([-1, 1])
    axes.set_ylim(bottom=0)
    axes.legend(loc="upper left")

    def candidate(start, color):
        """Apply lambdas["mean_update"] once from `start`; return (x, p(x), color)."""
        location = lambdas["mean_update"](start, *params)
        return location, lambdas["p"](location, *params), color

    # Mean update: try both starting points and keep the one where p is
    # strictly larger; on a tie the CDF-started candidate wins.
    from_pdf = candidate(m1, "r")
    from_cdf = candidate(m2, "b")
    best = from_pdf if from_pdf[1] > from_cdf[1] else from_cdf
    line_mean, line_mean_value, line_mean_color = best

    axes.axvline(x=line_mean, color=line_mean_color, linestyle="--")
    variance = lambdas["approximate_variance"](line_mean, *params)
    axes.text(
        line_mean,
        line_mean_value,
        f"Mean: {line_mean:.4f}\nVariance: {variance:.4f}",
        color=line_mean_color,
        fontsize=14,
        horizontalalignment="left",
        verticalalignment="bottom",
    )

    plt.show()