In [None]:
import numpy as np
import time
import pandas as pd
from typing import Union
import matplotlib.pyplot as plt
from scipy.spatial import Delaunay

In [None]:
from ambientfisher.interpolators import AmbientFisherInterpolator

In [None]:
def lognormal_family(alpha):
    """Return the log-normal density indexed by the 2-vector ``alpha``.

    ``alpha[0]`` is the mean of ``log(x)``. ``alpha[1]`` is an offset from 1,
    so the standard deviation of ``log(x)`` is ``alpha[1] + 1.0``. Therefore
    ``alpha = (0, 0)`` gives the standard log-normal. The formula is only a
    valid (positive) density when ``alpha[1] > -1``.

    Returns a closure ``pdf(x)`` that evaluates
    ``exp(-(log x - mu)^2 / (2 sigma^2)) / (sqrt(2 pi) * x * sigma)``
    elementwise on scalars or NumPy arrays.
    """
    mean_log = alpha[0]
    scale = alpha[1] + 1.0

    def pdf(x):
        # Same operation order as the closed-form expression, so results are
        # bit-for-bit reproducible.
        log_dev = np.log(x) - mean_log
        exponent = -log_dev**2 / (2.0 * scale**2)
        normaliser = np.sqrt(2 * np.pi) * x * scale
        return np.exp(exponent) / normaliser

    return pdf

# Evaluation grid for the log-normal densities. It starts just above 0 because
# the log-normal pdf takes np.log(x) and divides by x, and both are singular at
# x = 0.
xarray = np.linspace(start=1e-9, stop=7.0, num=10_000)

In [None]:
def plot_anchor_interp(
    alpha_target=np.array([0.05, -0.25]),
    alphas=np.array([[0.0, 0.0], [0.1, -0.5], [0.0, -0.3]]),
    compare_predictions_with_truth=False,
):
    """Fit an ``AmbientFisherInterpolator`` on log-normal anchors and plot its prediction.

    Each row of ``alphas`` defines an anchor density ``lognormal_family(alpha)``.
    The interpolator is built from these anchors and the module-level grid
    ``xarray``. Its predictions at ``alpha_target`` are printed and plotted
    against the exact density ``lognormal_family(alpha_target)``.

    Parameters
    ----------
    alpha_target : array-like
        Parameter vector at which to predict. Its first two entries are read
        by ``lognormal_family``.
    alphas : array-like, one anchor parameter vector per row
        Any number of anchors is supported. The first three keep the colours
        black / red / green, and further anchors use matplotlib's default
        colour cycle.
    compare_predictions_with_truth : bool, default False
        If False, plot every anchor density, the intrinsic prediction and the
        truth. If True, plot the extrinsic ("Kyle") and intrinsic predictions
        against the truth.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    # Fixed colours for the first anchors; extra anchors fall back to the cycle.
    anchor_colors = ("black", "r", "g")

    # Copy inputs in case the interpolator modifies its arguments in place.
    # This protects both the shared default arrays and the caller's arrays.
    alpha_target = np.array(alpha_target)
    anchor_alphas = np.array(alphas)
    anchor_pdfs = [lognormal_family(alpha) for alpha in anchor_alphas]

    interp = AmbientFisherInterpolator(anchor_alphas, anchor_pdfs, xarray)

    # Both predictions are plotted against xarray below, so presumably they
    # are evaluated on that grid. TODO confirm against the interpolator.
    p_hat_kyle = interp.predict_extrinsic(alpha_target, followKyle=True)
    print(f"prediction = {p_hat_kyle}")
    p_hat_intrinsic = interp.predict_intrinsic(alpha_target)
    print(f"prediction 2 = {p_hat_intrinsic}")

    p_true = lognormal_family(alpha_target)(xarray)

    fig, ax = plt.subplots(figsize=(7.2, 4.6))

    if compare_predictions_with_truth:
        ax.plot(xarray, p_hat_kyle, label="Method Kyle")
        ax.plot(xarray, p_hat_intrinsic, label="Method Intrinsic")
    else:
        # Loop instead of fixed unpacking, so any number of anchors works.
        for i, pdf in enumerate(anchor_pdfs):
            style = {"c": anchor_colors[i]} if i < len(anchor_colors) else {}
            ax.plot(xarray, pdf(xarray), label=f"anchor {i + 1}", **style)
        ax.plot(xarray, p_hat_intrinsic, c="b", ls="dashed", label="target pred")
    # The exact target density is drawn in both modes.
    ax.plot(xarray, p_true, c="b", ls="dotted", label="target truth")

    # x is the horizontal variable; the density f(x | alpha) is on the y-axis.
    ax.set_xlabel(r"$x$", fontsize=16)
    ax.set_ylabel(r"density $f(x \mid \alpha)$", fontsize=14)

    ax.set_xlim(0.0, 5.0)
    ax.set_ylim(0.0, 0.9)

    # Add the legend before computing the layout so it is taken into account.
    ax.legend()
    fig.tight_layout()
    plt.show()

    return None
