In [52]:
import numpy as np
from scipy.special import logit, expit
from scipy.special import digamma, polygamma
from scipy.optimize import fsolve

from scipy import stats

def find_logbeta_params(mean_logprob, var_logprob):
    def equations(vars):
        alpha, beta = vars
        # Mean equation
        eq1 = digamma(alpha) - digamma(alpha + beta) - mean_logprob
        # Variance equation
        eq2 = polygamma(1, alpha) - polygamma(1, alpha + beta) - var_logprob
        return [eq1, eq2]
    
    # Initial guess for α,β
    alpha0, beta0 = 1.0, 1.0
    alpha, beta = fsolve(equations, [alpha0, beta0])
    
    return alpha, beta

def logbeta_logprobs(mean_logprob, var_logprob, n_samples=10000):
    alpha, beta = find_logbeta_params(mean_logprob, var_logprob)
    probs = stats.beta.rvs(alpha, beta, size=n_samples)
    logprobs = np.log(probs)
    
    print(f"Fitted alpha, beta: {alpha:.3f}, {beta:.3f}")
    print(f"Target mean, var: {mean_logprob:.3f}, {var_logprob:.3f}")
    print(f"Achieved mean, var: {np.mean(logprobs):.3f}, {np.var(logprobs):.3f}")
    
    return logprobs
def compute_p2(p1, var_logprob, n, num_samples=10000):
    # Validate p1 is between 0 and 1
    if not 0 <= p1 <= 1:
        raise ValueError("p1 must be between 0 and 1")
    

    p11 = logbeta_logprobs(p1, var_logprob, num_samples)

    
    # Compute p11 and p12
    p11 = expit(logit(p1) + epsilon1)
    p12 = expit(logit(p1) + epsilon2)
    
    # Compute p2
    p2 = p11 * p12 + (1 - p11) / n
    
    return p2, p11, p12

In [2]:
def old_loss_two_hop(p1, n=100):
    return p1**2 + ((1-p1)**2) / (n-1)


def better_loss_two_hop(p1, n=100):
    return p1**2 + (1-p1) / n




In [55]:
p_1 = 0.0625

old_loss_two_hop(p_1, 10000), better_loss_two_hop(p_1, 10000), compute_p2(p_1, np.abs(logit(p_1))/4, 10000)[0].mean(), compute_p2(p_1, np.abs(logit(p_1))/4, 10000)[1].mean(), compute_p2(p_1, np.abs(logit(p_1))/4, 10000)[2].mean()


(0.003994149414941494,
 0.004,
 0.005664439517671612,
 0.0748176954313564,
 0.07419483546284834)

In [38]:
from scipy import optimize

def find_p1(target_p2, sigma, n, num_samples=10000):
    def objective(p1):
        np.random.seed(42)  # Fix seed for consistent optimization
        p2_samples = compute_p2(p1[0], sigma, n, num_samples)
        return np.abs(np.mean(p2_samples) - target_p2)
    
    # Try multiple starting points
    best_result = None
    best_score = float('inf')
    
    for start in [0.1, 0.3, 0.5, 0.7, 0.9]:
        result = optimize.minimize(
            objective, 
            x0=[start],
            bounds=[(0, 1)],
            method='L-BFGS-B',
            options={'ftol': 1e-8}
        )
        
        if result.fun < best_score:
            best_score = result.fun
            best_result = result
    
    return best_result.x[0]

# Example usage
target_p2 = 0.15
sigma = 1.0
n = 10000
estimated_p1 = find_p1(target_p2, sigma, n)
print(f"Estimated p1: {estimated_p1:.4f}, sqrt p2: {np.sqrt(target_p2):.4f}")

Estimated p1: 0.3640, sqrt p2: 0.3873


In [41]:
logit(0.25)

-1.0986122886681098

In [58]:
def explore_logbeta_bounds(mean_logprob):
    # Try a range of alpha values
    alphas = np.logspace(-2, 4, 100)
    vars = []
    for alpha in alphas:
        # For each alpha, find beta that gives our target mean
        def mean_eq(beta):
            return digamma(alpha) - digamma(alpha + beta) - mean_logprob
        
        try:
            beta = fsolve(mean_eq, [1.0])[0]
            if beta > 0:  # Valid beta only
                var = polygamma(1, alpha) - polygamma(1, alpha + beta)
                vars.append(var)
        except:
            continue
    
    if vars:
        print(f"For mean_logprob = {mean_logprob:.3f}")
        print(f"Min variance: {min(vars):.3f}")
        print(f"Max variance: {max(vars):.3f}")
        return min(vars), max(vars)
    return None

# Test some means
for mean in [-0.01,-0.5, -1.0, -2.0, -3.0, -4.0]:
    explore_logbeta_bounds(mean)

For mean_logprob = -0.010
Min variance: 0.000
Max variance: 1.739
For mean_logprob = -0.500
Min variance: 0.000
Max variance: 86.706
For mean_logprob = -1.000
Min variance: 0.000
Max variance: 172.913
For mean_logprob = -2.000
Min variance: 0.000
Max variance: 343.826
For mean_logprob = -3.000
Min variance: 0.000
Max variance: 334.176
For mean_logprob = -4.000
Min variance: 0.000
Max variance: 441.569
