In [None]:
import ergo
import numpy as np
import seaborn
from ergo import Logistic, LogisticMixture
from ergo.distributions.conditions import HistogramCondition, IntervalCondition, PercentileCondition
from matplotlib import pyplot
from tqdm.autonotebook import tqdm

In [None]:
def normalize(xs):
    """Scale the values in ``xs`` so that they sum to one.

    Args:
        xs: Any iterable of numbers, including one-shot iterators and
            generators.

    Returns:
        A list with each value divided by the total of all values. An empty
        input yields an empty list.
    """
    # The values are needed twice (once for the total, once for the division),
    # so materialize them first; otherwise a generator would be exhausted by
    # ``sum`` and the result would silently be empty.
    values = list(xs)
    total = sum(values)
    return [value / total for value in values]

def sample_component():
    """Draw one random ``Logistic`` component.

    The location comes from ``ergo.uniform(-1, 2)`` and the scale is the
    absolute value of ``ergo.lognormal_from_interval(0.2, 3)``.
    """
    # Location is drawn before scale so the sequence of random draws is fixed.
    location = ergo.uniform(-1, 2)
    spread = abs(ergo.lognormal_from_interval(0.2, 3))
    return Logistic(loc=location, scale=spread)

def sample_condition(dist):
    """Build a random ``IntervalCondition`` whose probability matches ``dist``.

    One of three interval shapes is picked at random: open below
    (``xmin = -inf``), bounded on both sides, or open above (``xmax = +inf``).
    The probability attached to the condition is computed from ``dist`` via
    ``actual_p``.
    """
    shape = ergo.random_choice(["low_open", "bounded", "high_open"])
    if shape == "low_open":
        xmin, xmax = float("-inf"), ergo.uniform(-3, 3)
    elif shape == "bounded":
        xmin = ergo.uniform(-3, 0)
        # The upper end is the lower end plus a width drawn from uniform(0, 3).
        width = ergo.uniform(0, 3)
        xmax = xmin + width
    elif shape == "high_open":
        xmin, xmax = ergo.uniform(-3, 3), float("+inf")
    return IntervalCondition(actual_p(dist, xmin, xmax), xmin, xmax)

def sample_conditions(dist):
    """Return a list of 1, 2, 3, 5 or 7 random interval conditions for ``dist``.

    The count is chosen with ``ergo.random_choice`` and each condition is
    produced by ``sample_condition``.
    """
    count = ergo.random_choice([1, 2, 3, 5, 7])
    sampled = []
    for _ in range(count):
        sampled.append(sample_condition(dist))
    return sampled

def sample_mixture():
    """Sample a random ``LogisticMixture`` with one to three components.

    Components come from ``sample_component``; their weights are independent
    ``ergo.uniform(0, 1)`` draws rescaled by ``normalize`` to sum to one.
    """
    size = ergo.random_choice([1, 2, 3])

    # All components are drawn before any weight is drawn.
    parts = []
    for _ in range(size):
        parts.append(sample_component())

    raw_weights = [ergo.uniform(0, 1) for _ in range(size)]
    return LogisticMixture(parts, normalize(raw_weights))
    
def actual_p(dist, xmin, xmax):
    """Return the probability mass ``dist`` places between ``xmin`` and ``xmax``.

    Args:
        dist: Object exposing a ``cdf(x)`` method.
        xmin: Lower end of the interval; ``-inf`` means unbounded below.
        xmax: Upper end of the interval; ``+inf`` means unbounded above.

    Returns:
        ``cdf(xmax) - cdf(xmin)``, where the CDF is taken to be 0 at ``-inf``
        and 1 at ``+inf`` without calling ``dist.cdf`` on an infinite value.
    """
    # Plain comparisons against infinity keep this helper free of any
    # dependency on a numpy import being present in the file.
    cdf_at_min = 0 if xmin == float("-inf") else dist.cdf(xmin)
    cdf_at_max = 1 if xmax == float("inf") else dist.cdf(xmax)
    return cdf_at_max - cdf_at_min

def plot(dist, ax=None):
    """Draw the density of ``dist`` over [-4, 4] as a line plot.

    Args:
        dist: Object exposing ``pdf1(x)`` for a single point ``x``.
        ax: Optional axes to draw on, forwarded to ``seaborn.lineplot``. Pass
            the value returned by an earlier call to overlay several curves.

    Returns:
        Whatever ``seaborn.lineplot`` returns (the axes that were drawn on).
    """
    xs = np.linspace(-4, 4, 100)
    # Evaluate the distribution that was passed in, one point at a time.
    ys = [float(dist.pdf1(x)) for x in xs]
    # x and y go by keyword: recent seaborn releases reject positional data.
    return seaborn.lineplot(x=xs, y=ys, ax=ax)
    
def model(max_loss=0.000002):
    """Run one randomized check that fitting recovers sampled conditions.

    A random mixture is sampled, interval conditions are derived from it, a
    three-component ``LogisticMixture`` is fitted to those conditions, and
    each condition's fit loss is compared against ``max_loss``.

    Args:
        max_loss: Largest ``fit["loss"]`` accepted for any single condition.

    Raises:
        Exception: If any condition's loss exceeds ``max_loss``. Before
            raising, the true and fitted distributions, the failing fit
            description and every condition are printed, and both
            distributions are plotted.
    """
    # 1. Sample a distribution with 1-3 peaks
    true_dist = sample_mixture()

    # 2. Sample 1-7 conditions
    conditions = sample_conditions(true_dist)

    # 3. Fit a mixture to those conditions
    fit_dist = LogisticMixture.from_conditions(conditions, num_components=3)

    # 4. Check that the conditions are satisfied
    for condition in conditions:
        fit = condition.describe_fit(fit_dist)
        if fit["loss"] > max_loss:
            print(true_dist)
            print(fit_dist)
            print(fit)
            # Print each condition once, under its own name so the outer loop
            # variable is not rebound.
            for sampled_condition in conditions:
                print(sampled_condition)
            ax = plot(true_dist)
            plot(fit_dist, ax=ax)
            raise Exception("Failed to fit")

# Repeat the randomized fit check 1000 times under a progress bar; the first
# failing trial raises out of model() and stops the run.
progress = tqdm(range(1000))
for i in progress:
    model()