In [None]:
%matplotlib notebook
import numpy as np
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

# Parameters (in cm, matching the "Height / cm" axis labels) of the main
# normal component sampled by takeSample(): mean and standard deviation.
popμ, popσ = 178.49, 7.544
# Parameters of the second ("c") normal component that takeSample() mixes in
# when adults=False (presumably children — confirm with the notebook text).
popμc, popσc = 105, 10

# Written by takeSample(); the plotting helpers read it to choose axis limits
# and parameter ranges that suit the most recent sample.
adultsGlobal = False

def takeSample (n, adults=False) :
    """Draw a synthetic sample of heights.

    With ``adults=True`` the result is ``n`` draws from N(popμ, popσ).
    Otherwise it is ``2*n//5`` draws from N(popμ, popσ), then ``3*n//5``
    draws from N(popμc, popσc), then the single fixed value 6. In that case
    the total length is ``2*n//5 + 3*n//5 + 1``, which can differ from ``n``.

    Side effect: stores ``adults`` in the module-level ``adultsGlobal`` so
    the plotting helpers pick matching axis limits afterwards.
    """
    global adultsGlobal
    adultsGlobal = adults
    if not adults:
        # Draw order (first component, then second) matches the original so
        # the same random seed yields the same sample.
        first = popμ + popσ * np.random.randn(2 * n // 5)
        second = popμc + popσc * np.random.randn(3 * n // 5)
        return np.concatenate([first, second, np.array([6])])
    return popμ + popσ * np.random.randn(n)

def plotSample (s) :
    """Plot every height in ``s`` as a vertical segment from (h, -h) to (h, h).

    The y-axis starts at 0, so each value appears as a bar from the x-axis up
    to the diagonal point (h, h). Axis limits depend on ``adultsGlobal``.
    """
    _, ax = plt.subplots()
    if s.size <= 100:
        # Small sample: one plot call (and one colour) per value.
        for h in np.nditer(s):
            ax.plot([h, h], [-h, h], "o-")
    else:
        # Large sample: a single polyline visiting (h,-h), (h,h), (h,-h) for
        # each value. The legs joining consecutive values run at negative y
        # (for positive heights) and are hidden by the y-limits set below.
        xs = np.array([s, s, s]).T.ravel()
        ys = np.array([-s, s, -s]).T.ravel()
        ax.plot(xs, ys, "o-", linewidth=0.5)

    top = popμ + 5 * popσ
    left = popμ - 7 * popσ if adultsGlobal else 0
    ax.set_title('A sample of heights of individuals')
    ax.set_xlim(left, top)
    ax.set_ylim(0, top)
    ax.set_xlabel("Height / cm")
    ax.set_ylabel("Height / cm")

def plotHist (histogram) :
    """Draw a bar chart of a precomputed histogram of heights.

    Parameters
    ----------
    histogram : tuple
        ``(hist, bins)`` pair with ``len(bins) == len(hist) + 1``. Given the
        "Frequency / cm^-1" label, this is presumably
        ``np.histogram(..., density=True)`` output; verify against callers.
        Bins are assumed equal-width, since only the first width is used.

    Returns
    -------
    fig, ax, center
        The new figure and axes, and the array of bin centres.
    """
    hist, bins = histogram
    fig, ax = plt.subplots()
    ax.set_title('Histogram of that height sample')
    width = bins[1] - bins[0]  # assumes equal-width bins
    center = (bins[:-1] + bins[1:]) / 2
    # Same x-window as plotSample: adult-only samples are centred on popμ.
    if adultsGlobal :
        ax.set_xlim(popμ - 7 * popσ, popμ + 5 * popσ)
    else :
        ax.set_xlim(0, popμ + 5 * popσ)
    ax.set_ylim(0,0.1)
    ax.set_xlabel("Height / cm")
    # Raw string: "\m" is an invalid escape sequence in a normal literal
    # (DeprecationWarning since 3.6, SyntaxWarning since 3.12). The resulting
    # text is unchanged.
    ax.set_ylabel(r"Frequency / $\mathrm{cm}^{-1}$")
    ax.bar(center, hist, align='center', width=width, edgecolor="#222222")
    return fig, ax, center

def plotHistNormal (histogram, μ, σ, spline=False) :
    """Draw the histogram (via plotHist) and overlay a smooth curve.

    With ``spline=False`` the overlay is the normal pdf N(μ, σ). The pdf is
    drawn as translucent bars at the bin centres plus a curve on 150–210 cm.
    With ``spline=True``, μ and σ are unused. The overlay is the histogram
    itself as bars, plus a cubic spline through the bar heights.

    Parameters
    ----------
    histogram : tuple
        ``(hist, bins)`` pair, as accepted by plotHist.
    μ, σ : float
        Mean and standard deviation of the normal model (ignored if spline).
    spline : bool
        Interpolate the histogram instead of plotting the normal model.
    """
    hist, bins = histogram
    fig, ax, center = plotHist(histogram)
    width = bins[1] - bins[0]

    xs = np.arange(150,210,0.2)
    if spline:
        # interp1d raises ValueError for points outside [min(center),
        # max(center)] (bounds_error defaults to True). Only evaluate the
        # spline inside the range covered by the bin centres.
        xs = xs[(xs >= center.min()) & (xs <= center.max())]
        f = interp1d(center, hist, kind='cubic')
        fs = f(xs)
        gs = hist
    else :
        fs = normal(μ,σ, xs)
        gs = normal(μ, σ, center)
    ax.bar(center, gs, align='center', width=width, edgecolor="#222222", color="#ff888888")
    ax.plot(xs, fs, color="#ff8844")
    
def normal(μ, σ, x) :
    """Return the normal (Gaussian) probability density N(μ, σ) at ``x``.

    ``x`` may be a scalar or a NumPy array; the operations broadcast
    element-wise.
    """
    # Same operation order as the closed form 1/sqrt(2π)/σ · exp(-(x-μ)²/(2σ²)),
    # so results are bit-for-bit identical.
    coeff = 1 / np.sqrt(2 * np.pi) / σ
    exponent = -(x - μ) ** 2 / (2 * σ ** 2)
    return coeff * np.exp(exponent)

def plotParameterSpace(histogram) :
    """Show a 5x5 grid of candidate normal models over the histogram.

    Rows step σ through 19, 15, 11, 7, 3 (top to bottom). Columns step μ
    through 145, 160, 175, 190, 205 (left to right). Each panel redraws the
    histogram bars and overlays the filled pdf of N(μ, σ).

    Parameters
    ----------
    histogram : tuple
        ``(hist, bins)`` pair, as accepted by plotHist.
    """
    hist, bins = histogram
    width = bins[1] - bins[0]  # assumes equal-width bins
    # Bar positions come from this histogram's own bins (same formula as
    # plotHist). The previous code used the global ``binCentres``, which is
    # not defined in this code and need not match the histogram passed in.
    centres = (bins[:-1] + bins[1:]) / 2
    fig, axtable = plt.subplots(5,5)
    xs = np.arange(popμ - 6 * popσ, popμ + 5 * popσ,0.2)
    fig.suptitle("Parameter space of models")

    σ = 19
    for axrow in axtable :
        μ = 145
        for ax in axrow :
            ax.set_xlim(popμ - 6 * popσ, popμ + 5 * popσ)
            ax.set_ylim(0,0.1)
            ax.set_yticklabels([])
            ax.set_xticklabels([])
            fs = normal(μ,σ, xs)
            ax.bar(centres, hist, align='center', width=width)
            ax.fill_between(xs, 0, fs, color="#ff888888")
            ax.plot(xs, fs, color="#ff8844")
            μ += 15
        σ -= 4

def genR2grid(histogram) :
    """Tabulate the sum of squared residuals of normal fits over (μ, σ).

    For every (μ, σ) grid point, the residual vector is ``hist - N(μ, σ)``
    evaluated at the bin centres, and the value stored is its squared norm.
    The parameter ranges are wider for mixed samples than for adult-only
    ones (chosen by ``adultsGlobal``).

    Parameters
    ----------
    histogram : tuple
        ``(hist, bins)`` pair, as accepted by plotHist.

    Returns
    -------
    μgrid, σgrid, r2grid
        Arrays from ``np.meshgrid(μs, σs)`` (default 'xy' indexing, so shape
        ``(len(σs), len(μs))``) and the matching residual grid.
    """
    hist, bins = histogram
    # Derive centres from this histogram's own bins (same formula as plotHist)
    # instead of the global ``binCentres``, which is not defined in this code.
    centres = (bins[:-1] + bins[1:]) / 2

    μs = np.arange(150, 210,1) if adultsGlobal else np.arange(50, 210,1)
    σs = np.arange(2, 50,0.1) if adultsGlobal else np.arange(2, 100,0.1)

    def r2(μ, σ) :
        r = hist - normal(μ, σ, centres)
        return r @ r

    μgrid, σgrid = np.meshgrid(μs, σs)
    return μgrid, σgrid, np.vectorize(r2)(μgrid, σgrid)

def drawSurface (histogram) :
    """Plot the sum-of-squared-residuals surface from genR2grid in 3D.

    Values are clipped at ``vmax`` (0.03 for adult-only samples, 0.01
    otherwise), presumably so the low-residual basin stays readable. The
    same cap is used as the colour-map maximum.
    """
    vmax = 0.03 if adultsGlobal else 0.01
    μgrid, σgrid, r2grid = genR2grid(histogram)
    r2grid = np.fmin(r2grid, vmax)
    fig =  plt.figure()
    # Figure.gca() no longer accepts keyword arguments (deprecated in
    # Matplotlib 3.4, removed in 3.6); add_subplot is the supported way to
    # request 3D axes.
    ax = fig.add_subplot(111, projection='3d')
    # set_aspect('equal') is deliberately not called here. On 3D axes it
    # raised NotImplementedError in Matplotlib 3.1-3.5. In 3.6+ it scales
    # z (<= vmax) like μ/σ (tens of cm), which flattens the surface.
    ax.view_init(30, 60)
    ax.set_xlabel("Mean: μ / cm")
    ax.set_ylabel("Standard Deviation: σ / cm")
    ax.set_zlabel("Sum of squared residuals")
    ax.plot_surface(μgrid, σgrid, r2grid, vmax=vmax,
        rstride=1, cstride=1, cmap=cm.coolwarm, linewidth=0, antialiased=False)

def drawContour (histogram):
    """Show the clipped sum-of-squared-residuals grid over (μ, σ) in 2D.

    The grid from genR2grid is capped at 0.03 (adult-only samples) or 0.01
    (otherwise). It is drawn as a Gouraud-shaded colour map with 25 black
    contour lines from the grid minimum up to the cap.
    """
    cap = 0.01
    if adultsGlobal:
        cap = 0.03
    μgrid, σgrid, raw = genR2grid(histogram)
    clipped = np.fmin(raw, cap)
    levels = np.linspace(np.min(clipped), cap, 25)

    _, ax = plt.subplots()
    ax.set_aspect('equal')
    ax.set_xlabel("Mean: μ / cm")
    ax.set_ylabel("Standard Deviation: σ / cm")
    ax.pcolormesh(μgrid, σgrid, clipped, vmax=cap, cmap=cm.coolwarm,
                  shading='gouraud')
    ax.contour(μgrid, σgrid, clipped, vmax=cap, colors="black",
               levels=levels, linewidths=0.8)
