In [None]:
import numpy as np

def iterative_math(condition=1, starting_param_list=None, model=2, gens=15, print_gens=None):
    """
    Calculates the expected variance and covariance iteratively based on a given condition,
    printing each matrix at each generation.

    Args:
        condition (int): The condition number (1-5) to use for parameter selection.
        starting_param_list (dict): A dictionary of starting parameters. If None, uses default values.
        model (int): The model number to use (only model 2 is implemented here).
        gens (int): The number of generations to simulate.
        print_gens (list): List of generation numbers to print. If None, prints all generations.
    """
    if starting_param_list is None:
        starting_param_list = {
            'vg1': .75,
            'vg2': 1,
            'am11': 0,
            'am12': 0,
            'am21': 0,
            'am22': 0.43,
            'f11': 0,
            'f12': 0.0,
            'f21': 0.0,
            'f22': 0,
            'Nfam': 1.6e5,
            'rg': 0.5,
            're': 0,
            'prop.h2.latent1': .56/0.75,
            'prop.h2.latent2': 0.65/0.8
        }

    # --- Parameter setup for the given condition ---
    idx = condition - 1
    vg1 = starting_param_list['vg1'][idx] if isinstance(starting_param_list['vg1'], np.ndarray) else starting_param_list['vg1']
    vg2 = starting_param_list['vg2'][idx] if isinstance(starting_param_list['vg2'], np.ndarray) else starting_param_list['vg2']
    rg = starting_param_list['rg'][idx] if isinstance(starting_param_list['rg'], np.ndarray) else starting_param_list['rg']
    prop_h2_latent1 = starting_param_list['prop.h2.latent1'][idx] if isinstance(starting_param_list['prop.h2.latent1'], np.ndarray) else starting_param_list['prop.h2.latent1']
    prop_h2_latent2 = starting_param_list['prop.h2.latent2'][idx] if isinstance(starting_param_list['prop.h2.latent2'], np.ndarray) else starting_param_list['prop.h2.latent2']
    am11 = starting_param_list['am11'][idx] if isinstance(starting_param_list['am11'], np.ndarray) else starting_param_list['am11']
    am12 = starting_param_list['am12'][idx] if isinstance(starting_param_list['am12'], np.ndarray) else starting_param_list['am12']
    am21 = starting_param_list['am21'][idx] if isinstance(starting_param_list['am21'], np.ndarray) else starting_param_list['am21']
    am22 = starting_param_list['am22'][idx] if isinstance(starting_param_list['am22'], np.ndarray) else starting_param_list['am22']
    f11 = starting_param_list['f11'][idx] if isinstance(starting_param_list['f11'], np.ndarray) else starting_param_list['f11']
    f12 = starting_param_list['f12'][idx] if isinstance(starting_param_list['f12'], np.ndarray) else starting_param_list['f12']
    f21 = starting_param_list['f21'][idx] if isinstance(starting_param_list['f21'], np.ndarray) else starting_param_list['f21']
    f22 = starting_param_list['f22'][idx] if isinstance(starting_param_list['f22'], np.ndarray) else starting_param_list['f22']
    re = starting_param_list['re'][idx] if isinstance(starting_param_list['re'], np.ndarray) else starting_param_list['re']

    # --- Implied variables (t0) ---
    k2_matrix = np.array([[1, rg], [rg, 1]])

    vg_obs1 = vg1 * (1 - prop_h2_latent1)
    vg_obs2 = vg2 * (1 - prop_h2_latent2)
    d11 = np.sqrt(vg_obs1)
    d21 = 0
    d22 = np.sqrt(vg_obs2 - d21**2)
    delta_mat = np.array([[d11, 0], [d21, d22]])

    vg_lat1 = vg1 * prop_h2_latent1
    vg_lat2 = vg2 * prop_h2_latent2
    a11 = np.sqrt(vg_lat1)
    a21 = 0
    a22 = np.sqrt(vg_lat2 - a21**2)
    a_mat = np.array([[a11, 0], [a21, a22]])

    covg_mat = (delta_mat @ k2_matrix @ delta_mat.T) + (a_mat @ k2_matrix @ a_mat.T)

    ve1 = 1 - vg1
    ve2 = 1 - vg2
    cove = re * np.sqrt(ve1 * ve2)
    cove_mat = np.array([[ve1, cove], [cove, ve2]])

    COVY = covg_mat + cove_mat

    mate_cor_mat = np.array([[am11, am12], [am21, am22]])
    f_mat = np.array([[f11, f12], [f21, f22]])
    covf_mat = 2 * (f_mat @ COVY @ f_mat.T)

    # --- Initialize parameters for iteration ---
    a_t0 = a_mat
    delta_t0 = delta_mat
    j_t0 = k2_matrix * 0.5
    k_t0 = k2_matrix * 0.5
    f_t0 = f_mat
    rmate_t0 = mate_cor_mat
    covE_t0 = cove_mat

    # Lists to store matrices at each generation
    exp_VY, exp_VF, exp_mu, mate_cov = ([None] * gens for _ in range(4))
    exp_gc, exp_hc, exp_ic, exp_w, exp_v = ([None] * gens for _ in range(5))
    exp_Omega, exp_Gamma, exp_itlo, exp_itol = ([None] * gens for _ in range(4))
    exp_VGO, exp_heritability, exp_cor_matpgs, exp_VGL, exp_COVLO = ([None] * gens for _ in range(5))

    # Initialize matrices at t=0 (index 0)
    exp_gc[0] = exp_hc[0] = exp_ic[0] = exp_w[0] = exp_v[0] = np.zeros((2, 2))
    exp_VY[0] = COVY
    exp_VF[0] = covf_mat
    exp_mu[0] = np.linalg.inv(COVY) @ rmate_t0 @ np.linalg.inv(COVY.T)
    mate_cov[0] = COVY @ exp_mu[0] @ COVY

    print(f"=========================================")
    print(f"STARTING SIMULATION")
    print(f"=========================================\n")
    
    if model == 2:
        for it in range(1, gens):
            it_prev = it - 1
            
            # Determine if we should print this generation
            should_print = print_gens is None or it in print_gens
            
            if should_print & False:
                print(f"\n---------- Generation {it} ----------")

            exp_Omega[it] = (2 * delta_t0 @ exp_gc[it_prev] + delta_t0 @ k_t0 +
                             0.5 * exp_w[it_prev] + 2 * a_t0 @ exp_ic[it_prev])
            if should_print & False:
                print(f"Omega ({it}):\n{exp_Omega[it]}")

            exp_Gamma[it] = (2 * a_t0 @ exp_hc[it_prev] + 2 * delta_t0 @ exp_ic[it_prev].T +
                             a_t0 @ j_t0 + 0.5 * exp_v[it_prev])
            if should_print & False:
                print(f"Gamma ({it}):\n{exp_Gamma[it]}")

            exp_VY[it] = (2 * delta_t0 @ exp_Omega[it].T + 2 * a_t0 @ exp_Gamma[it].T +
                          exp_w[it_prev] @ delta_t0.T + exp_v[it_prev] @ a_t0.T +
                          exp_VF[it_prev] + covE_t0)
            if should_print & False:
                print(f"VY ({it}):\n{exp_VY[it]}")
            
            vy_sqrt_diag = np.sqrt(np.diag(np.diag(exp_VY[it])))
            mate_cov[it] = vy_sqrt_diag @ rmate_t0 @ vy_sqrt_diag
            exp_mu[it] = np.linalg.inv(exp_VY[it]) @ mate_cov[it] @ np.linalg.inv(exp_VY[it].T)
            if should_print & False:
                print(f"mu ({it}):\n{exp_mu[it]}")
                print(f"mate_cov ({it}):\n{mate_cov[it]}")
            exp_gt = exp_Omega[it].T @ exp_mu[it] @ exp_Omega[it]
            if should_print & False:
                print(f"gt ({it}):\n{exp_gt}")
            
            exp_gc[it] = 0.5 * (exp_gt + exp_gt.T)
            if should_print & False:
                print(f"gc ({it}):\n{exp_gc[it]}")

            exp_ht = exp_Gamma[it].T @ exp_mu[it] @ exp_Gamma[it]
            if should_print & False:
                print(f"ht ({it}):\n{exp_ht}")
            
            exp_hc[it] = 0.5 * (exp_ht + exp_ht.T)
            if should_print & False:
                print(f"hc ({it}):\n{exp_hc[it]}")

            exp_w[it] = (2 * f_t0 @ exp_Omega[it] +
                         f_t0 @ exp_VY[it] @ exp_mu[it] @ exp_Omega[it] +
                         f_t0 @ exp_VY[it] @ exp_mu[it].T @ exp_Omega[it])
            if should_print & False:
                print(f"w ({it}):\n{exp_w[it]}")

            exp_v[it] = (2 * f_t0 @ exp_Gamma[it] +
                         f_t0 @ exp_VY[it] @ exp_mu[it] @ exp_Gamma[it] +
                         f_t0 @ exp_VY[it] @ exp_mu[it].T @ exp_Gamma[it])
            if should_print & False:
                print(f"v ({it}):\n{exp_v[it]}")

            exp_VF[it] = (2 * f_t0 @ exp_VY[it] @ f_t0.T +
                          f_t0 @ exp_VY[it] @ exp_mu[it] @ exp_VY[it] @ f_t0.T +
                          f_t0 @ exp_VY[it] @ exp_mu[it].T @ exp_VY[it] @ f_t0.T)
            if should_print:
                print(f"VF ({it}):\n{exp_VF[it]}")

            exp_itlo[it] = exp_Gamma[it].T @ exp_mu[it] @ exp_Omega[it]
            exp_itol[it] = exp_Omega[it].T @ exp_mu[it] @ exp_Gamma[it]
            exp_ic[it] = 0.25 * (exp_itol[it] + exp_itol[it].T + exp_itlo[it] + exp_itlo[it].T)
            if should_print:
                print(f"ic ({it}):\n{exp_ic[it]}")

            exp_VGO[it] = (2 * delta_t0 @ k_t0 @ delta_t0.T + 4 * delta_t0 @ exp_gc[it] @ delta_t0.T)
            exp_VGL[it] = (2 * a_t0 @ j_t0 @ a_t0.T + 4 * a_t0 @ exp_hc[it] @ a_t0.T)
            exp_heritability[it] = (exp_VGL[it] + exp_VGO[it] + 8* a_t0 @ exp_ic[it]@delta_t0) / exp_VY[it]
            if should_print:
                print(f"Heritability ({it}):\n{exp_heritability[it]}")
            
            exp_COVLO[it] = (4 * delta_t0 @ exp_ic[it].T @ a_t0.T + 4 * a_t0 @ exp_ic[it] @ delta_t0.T)
            
            exp_cor_matpgs[it] = (exp_Omega[it].T @ exp_mu[it] @ exp_Omega[it] *2 + exp_Omega[it].T @ exp_mu[it].T@ exp_Omega[it]*2)/ (2*k_t0 + 4*exp_gc[it])
            if should_print:
                print(f"Mate Correlation PGS ({it}):\n{exp_cor_matpgs[it]}")
            
            r2pgs_final = exp_VGO[it]/exp_VY[it]
            if should_print:
                print(f"R2 PGS ({it}):\n{r2pgs_final}")
    else:
        print(f"Model {model} is not implemented in this script.")
        return

    # # --- Print final results ---
    # print(f"\n--- Final Results after {gens-1} generations ---")
    # print("True a:\n", a_mat)
    # print("\nTrue delta:\n", delta_mat)
    # print(f"\nr2pgs1: {delta_mat[0, 0]**2:.4f}")
    # print(f"r2pgs2: {delta_mat[1, 1]**2:.4f}")
    # print("-" * 45, "\n")


In [130]:
# --- Run the simulation for each condition ---
if __name__ == '__main__':
    # Run with default parameters, print only generations 1, 5, 10, and 19
    iterative_math(model=2, print_gens=[14])


STARTING SIMULATION

VF (14):
[[0.14406452 0.        ]
 [0.         0.77277366]]
ic (14):
[[0.06424085 0.        ]
 [0.         0.06638797]]
Heritability (14):
[[0.53293747        nan]
 [       nan 0.38868285]]
Mate Correlation PGS (14):
[[0.13993344        nan]
 [       nan 0.12035086]]
R2 PGS (14):
[[0.09757102        nan]
 [       nan 0.04890321]]


  exp_heritability[it] = (exp_VGL[it] + exp_VGO[it] + 8* a_t0 @ exp_ic[it]@delta_t0) / exp_VY[it]
  exp_cor_matpgs[it] = exp_Omega[it].T @ exp_mu[it] @ exp_Omega[it] *2 + exp_Omega[it].T @ exp_mu[it].T@ exp_Omega[it]*2/ (2*k_t0 + 4*exp_gc[it])
  r2pgs_final = exp_VGO[it]/exp_VY[it]


In [136]:
def find_optimal_parameters(target_mate_cor=0.135, target_heritability=0.5, target_r2pgs=0.12, 
                           f11_range=None, prop_h2_latent1_range=None, vg1_range=None,
                           base_params=None, model=2, gens=15):
    """
    Finds optimal values of f11, prop.h2.latent1, and vg1 to match target values.
    
    Args:
        target_mate_cor (float): Target mate correlation PGS value
        target_heritability (float): Target heritability value
        target_r2pgs (float): Target R2 PGS value
        f11_range (tuple): (min, max, step) for f11 search
        prop_h2_latent1_range (tuple): (min, max, step) for prop.h2.latent1 search
        vg1_range (tuple): (min, max, step) for vg1 search
        base_params (dict): Base parameters to use (will modify f11, prop.h2.latent1, and vg1)
        model (int): Model number
        gens (int): Number of generations to simulate
    
    Returns:
        dict: Best parameters and their resulting values
    """
    if f11_range is None:
        f11_range = (0.1, 0.3, 0.025)
    
    if prop_h2_latent1_range is None:
        prop_h2_latent1_range = (0.5, 0.9, 0.05)
    
    if vg1_range is None:
        vg1_range = (0.6, 0.95, 0.01)
    
    if base_params is None:
        base_params = {
            'vg1': 0.75,  # Will be overridden
            'vg2': 1,
            'am11': 0.43,
            'am12': 0,
            'am21': 0,
            'am22': 0.43,
            'f11': 0.15,  # Will be overridden
            'f12': 0.0,
            'f21': 0.0,
            'f22': 0.25,
            'Nfam': 1.6e5,
            'rg': 0,
            're': 0,
            'prop.h2.latent1': 0.56/0.75,  # Will be overridden
            'prop.h2.latent2': 0.65/0.8
        }
    
    # Generate search grid
    f11_values = np.arange(f11_range[0], f11_range[1] + f11_range[2], f11_range[2])
    prop_h2_latent1_values = np.arange(prop_h2_latent1_range[0], prop_h2_latent1_range[1] + prop_h2_latent1_range[2], prop_h2_latent1_range[2])
    vg1_values = np.arange(vg1_range[0], vg1_range[1] + vg1_range[2], vg1_range[2])
    
    best_score = float('inf')
    best_params = None
    best_results = None
    
    results_list = []
    
    total_combinations = len(f11_values) * len(prop_h2_latent1_values) * len(vg1_values)
    print(f"Searching {len(f11_values)} x {len(prop_h2_latent1_values)} x {len(vg1_values)} = {total_combinations} combinations...")
    print(f"Target: Mate Cor PGS={target_mate_cor:.3f}, Heritability={target_heritability:.3f}, R2 PGS={target_r2pgs:.3f}\n")
    
    for i, f11_val in enumerate(f11_values):
        for j, prop_h2_latent1_val in enumerate(prop_h2_latent1_values):
            for k, vg1_val in enumerate(vg1_values):
                # Create test parameters
                test_params = base_params.copy()
                test_params['f11'] = f11_val
                test_params['prop.h2.latent1'] = prop_h2_latent1_val
                test_params['vg1'] = vg1_val
            
            # Run simulation silently
            try:
                # Temporarily suppress print statements
                import sys
                import io
                old_stdout = sys.stdout
                sys.stdout = io.StringIO()
                
                # Extract final generation values by running the simulation
                condition = 1
                idx = condition - 1
                vg1 = test_params['vg1']
                vg2 = test_params['vg2']
                rg = test_params['rg']
                prop_h2_latent1 = test_params['prop.h2.latent1']
                prop_h2_latent2 = test_params['prop.h2.latent2']
                am11 = test_params['am11']
                am12 = test_params['am12']
                am21 = test_params['am21']
                am22 = test_params['am22']
                f11 = test_params['f11']
                f12 = test_params['f12']
                f21 = test_params['f21']
                f22 = test_params['f22']
                re = test_params['re']
                
                # Implied variables (t0)
                k2_matrix = np.array([[1, rg], [rg, 1]])
                
                vg_obs1 = vg1 * (1 - prop_h2_latent1)
                vg_obs2 = vg2 * (1 - prop_h2_latent2)
                d11 = np.sqrt(vg_obs1)
                d21 = 0
                d22 = np.sqrt(vg_obs2 - d21**2)
                delta_mat = np.array([[d11, 0], [d21, d22]])
                
                vg_lat1 = vg1 * prop_h2_latent1
                vg_lat2 = vg2 * prop_h2_latent2
                a11 = np.sqrt(vg_lat1)
                a21 = 0
                a22 = np.sqrt(vg_lat2 - a21**2)
                a_mat = np.array([[a11, 0], [a21, a22]])
                
                covg_mat = (delta_mat @ k2_matrix @ delta_mat.T) + (a_mat @ k2_matrix @ a_mat.T)
                
                ve1 = 1 - vg1
                ve2 = 1 - vg2
                cove = re * np.sqrt(ve1 * ve2)
                cove_mat = np.array([[ve1, cove], [cove, ve2]])
                
                COVY = covg_mat + cove_mat
                
                mate_cor_mat = np.array([[am11, am12], [am21, am22]])
                f_mat = np.array([[f11, f12], [f21, f22]])
                covf_mat = 2 * (f_mat @ COVY @ f_mat.T)
                
                # Initialize parameters for iteration
                a_t0 = a_mat
                delta_t0 = delta_mat
                j_t0 = k2_matrix * 0.5
                k_t0 = k2_matrix * 0.5
                f_t0 = f_mat
                rmate_t0 = mate_cor_mat
                covE_t0 = cove_mat
                
                # Lists to store matrices at each generation
                exp_VY, exp_VF, exp_mu, mate_cov = ([None] * gens for _ in range(4))
                exp_gc, exp_hc, exp_ic, exp_w, exp_v = ([None] * gens for _ in range(5))
                exp_Omega, exp_Gamma, exp_itlo, exp_itol = ([None] * gens for _ in range(4))
                exp_VGO, exp_heritability, exp_cor_matpgs, exp_VGL, exp_COVLO = ([None] * gens for _ in range(5))
                
                # Initialize matrices at t=0
                exp_gc[0] = exp_hc[0] = exp_ic[0] = exp_w[0] = exp_v[0] = np.zeros((2, 2))
                exp_VY[0] = COVY
                exp_VF[0] = covf_mat
                exp_mu[0] = np.linalg.inv(COVY) @ rmate_t0 @ np.linalg.inv(COVY.T)
                mate_cov[0] = COVY @ exp_mu[0] @ COVY
                
                # Run iterations
                for it in range(1, gens):
                    it_prev = it - 1
                    
                    exp_Omega[it] = (2 * delta_t0 @ exp_gc[it_prev] + delta_t0 @ k_t0 +
                                     0.5 * exp_w[it_prev] + 2 * a_t0 @ exp_ic[it_prev])
                    
                    exp_Gamma[it] = (2 * a_t0 @ exp_hc[it_prev] + 2 * delta_t0 @ exp_ic[it_prev].T +
                                     a_t0 @ j_t0 + 0.5 * exp_v[it_prev])
                    
                    exp_VY[it] = (2 * delta_t0 @ exp_Omega[it].T + 2 * a_t0 @ exp_Gamma[it].T +
                                  exp_w[it_prev] @ delta_t0.T + exp_v[it_prev] @ a_t0.T +
                                  exp_VF[it_prev] + covE_t0)
                    
                    vy_sqrt_diag = np.sqrt(np.diag(np.diag(exp_VY[it])))
                    mate_cov[it] = vy_sqrt_diag @ rmate_t0 @ vy_sqrt_diag
                    exp_mu[it] = np.linalg.inv(exp_VY[it]) @ mate_cov[it] @ np.linalg.inv(exp_VY[it].T)
                    
                    exp_gt = exp_Omega[it].T @ exp_mu[it] @ exp_Omega[it]
                    exp_gc[it] = 0.5 * (exp_gt + exp_gt.T)
                    
                    exp_ht = exp_Gamma[it].T @ exp_mu[it] @ exp_Gamma[it]
                    exp_hc[it] = 0.5 * (exp_ht + exp_ht.T)
                    
                    exp_w[it] = (2 * f_t0 @ exp_Omega[it] +
                                 f_t0 @ exp_VY[it] @ exp_mu[it] @ exp_Omega[it] +
                                 f_t0 @ exp_VY[it] @ exp_mu[it].T @ exp_Omega[it])
                    
                    exp_v[it] = (2 * f_t0 @ exp_Gamma[it] +
                                 f_t0 @ exp_VY[it] @ exp_mu[it] @ exp_Gamma[it] +
                                 f_t0 @ exp_VY[it] @ exp_mu[it].T @ exp_Gamma[it])
                    
                    exp_VF[it] = (2 * f_t0 @ exp_VY[it] @ f_t0.T +
                                  f_t0 @ exp_VY[it] @ exp_mu[it] @ exp_VY[it] @ f_t0.T +
                                  f_t0 @ exp_VY[it] @ exp_mu[it].T @ exp_VY[it] @ f_t0.T)
                    
                    exp_itlo[it] = exp_Gamma[it].T @ exp_mu[it] @ exp_Omega[it]
                    exp_itol[it] = exp_Omega[it].T @ exp_mu[it] @ exp_Gamma[it]
                    exp_ic[it] = 0.25 * (exp_itol[it] + exp_itol[it].T + exp_itlo[it] + exp_itlo[it].T)
                    
                    exp_VGO[it] = (2 * delta_t0 @ k_t0 @ delta_t0.T + 4 * delta_t0 @ exp_gc[it] @ delta_t0.T)
                    exp_VGL[it] = (2 * a_t0 @ j_t0 @ a_t0.T + 4 * a_t0 @ exp_hc[it] @ a_t0.T)
                    exp_heritability[it] = (exp_VGL[it] + exp_VGO[it] + 8* a_t0 @ exp_ic[it]@delta_t0) / exp_VY[it]
                    
                    exp_COVLO[it] = (4 * delta_t0 @ exp_ic[it].T @ a_t0.T + 4 * a_t0 @ exp_ic[it] @ delta_t0.T)
                    
                    exp_cor_matpgs[it] = (exp_Omega[it].T @ exp_mu[it] @ exp_Omega[it] *2 + exp_Omega[it].T @ exp_mu[it].T@ exp_Omega[it]*2)/ (2*k_t0 + 4*exp_gc[it])
                
                # Restore stdout
                sys.stdout = old_stdout
                
                # Extract final generation values (trait 1 only, diagonal elements)
                final_gen = gens - 1
                mate_cor_pgs = exp_cor_matpgs[final_gen][0, 0]
                heritability = exp_heritability[final_gen][0, 0]
                r2_pgs = (exp_VGO[final_gen] / exp_VY[final_gen])[0, 0]
                
                # Calculate weighted error score
                mate_cor_error = 1.5*abs(mate_cor_pgs - target_mate_cor) / target_mate_cor
                heritability_error = 1*abs(heritability - target_heritability) / target_heritability
                r2pgs_error = 1.5*abs(r2_pgs - target_r2pgs) / target_r2pgs
                
                # Combined score (equal weights)
                score = mate_cor_error + heritability_error + r2pgs_error
                
                results_list.append({
                    'f11': f11_val,
                    'prop.h2.latent1': prop_h2_latent1_val,
                    'vg1': vg1_val,
                    'mate_cor_pgs': mate_cor_pgs,
                    'heritability': heritability,
                    'r2_pgs': r2_pgs,
                    'score': score
                })
                
                if score < best_score:
                    best_score = score
                    best_params = {'f11': f11_val, 'prop.h2.latent1': prop_h2_latent1_val, 'vg1': vg1_val}
                    best_results = {
                        'mate_cor_pgs': mate_cor_pgs,
                        'heritability': heritability,
                        'r2_pgs': r2_pgs,
                        'score': score
                    }
                    
            except Exception as e:
                sys.stdout = old_stdout
                print(f"Error with f11={f11_val:.3f}, prop.h2.latent1={prop_h2_latent1_val:.3f}, vg1={vg1_val:.3f}: {e}")
                continue
    
    # Print results
    print("\n" + "="*80)
    print("OPTIMIZATION RESULTS")
    print("="*80)
    print(f"\nBest Parameters:")
    print(f"  f11 = {best_params['f11']:.4f}")
    print(f"  prop.h2.latent1 = {best_params['prop.h2.latent1']:.4f}")
    print(f"  vg1 = {best_params['vg1']:.4f}")
    print(f"\nResulting Values (Trait 1):")
    print(f"  Mate Correlation PGS = {best_results['mate_cor_pgs']:.4f} (target: {target_mate_cor:.4f})")
    print(f"  Heritability = {best_results['heritability']:.4f} (target: {target_heritability:.4f})")
    print(f"  R2 PGS = {best_results['r2_pgs']:.4f} (target: {target_r2pgs:.4f})")
    print(f"\nCombined Error Score = {best_results['score']:.4f}")
    print("="*80)
    
    # Show top 5 candidates
    print("\nTop 5 Candidates:")
    sorted_results = sorted(results_list, key=lambda x: x['score'])[:5]
    for i, result in enumerate(sorted_results, 1):
        print(f"\n{i}. f11={result['f11']:.4f}, prop.h2.latent1={result['prop.h2.latent1']:.4f}, vg1={result['vg1']:.4f}")
        print(f"   Mate Cor PGS={result['mate_cor_pgs']:.4f}, H2={result['heritability']:.4f}, R2 PGS={result['r2_pgs']:.4f}, Score={result['score']:.4f}")
    
    return {
        'best_params': best_params,
        'best_results': best_results,
        'all_results': results_list
    }

# Run the optimization
optimization_results = find_optimal_parameters(
    target_mate_cor=0.135,
    target_heritability=0.45,
    target_r2pgs=0.11,
    f11_range=(0.10, 0.5, 0.01),
    prop_h2_latent1_range=(0.5, 0.9, 0.025),
    vg1_range=(0.6, 0.8, 0.01),
    gens=15
)


Searching 41 x 17 x 22 = 15334 combinations...
Target: Mate Cor PGS=0.135, Heritability=0.450, R2 PGS=0.110



  exp_heritability[it] = (exp_VGL[it] + exp_VGO[it] + 8* a_t0 @ exp_ic[it]@delta_t0) / exp_VY[it]
  exp_cor_matpgs[it] = (exp_Omega[it].T @ exp_mu[it] @ exp_Omega[it] *2 + exp_Omega[it].T @ exp_mu[it].T@ exp_Omega[it]*2)/ (2*k_t0 + 4*exp_gc[it])
  r2_pgs = (exp_VGO[final_gen] / exp_VY[final_gen])[0, 0]



OPTIMIZATION RESULTS

Best Parameters:
  f11 = 0.1500
  prop.h2.latent1 = 0.7250
  vg1 = 0.8100

Resulting Values (Trait 1):
  Mate Correlation PGS = 0.1473 (target: 0.1350)
  Heritability = 0.5536 (target: 0.4500)
  R2 PGS = 0.1097 (target: 0.1100)

Combined Error Score = 0.3714

Top 5 Candidates:

1. f11=0.1500, prop.h2.latent1=0.7250, vg1=0.8100
   Mate Cor PGS=0.1473, H2=0.5536, R2 PGS=0.1097, Score=0.3714

2. f11=0.1800, prop.h2.latent1=0.7000, vg1=0.8100
   Mate Cor PGS=0.1581, H2=0.4938, R2 PGS=0.1082, Score=0.3787

3. f11=0.1600, prop.h2.latent1=0.7250, vg1=0.8100
   Mate Cor PGS=0.1472, H2=0.5334, R2 PGS=0.1057, Score=0.3800

4. f11=0.1200, prop.h2.latent1=0.7500, vg1=0.8100
   Mate Cor PGS=0.1356, H2=0.6153, R2 PGS=0.1093, Score=0.3826

5. f11=0.1700, prop.h2.latent1=0.7250, vg1=0.8100
   Mate Cor PGS=0.1471, H2=0.5135, R2 PGS=0.1018, Score=0.3878


In [None]:
def find_optimal_parameters(target_mate_cor=0.135, target_heritability=0.5, target_r2pgs=0.12, 
                           f11_range=None, prop_h2_latent1_range=None, vg1_range=None,
                           vg2_range=None, f22_range=None, am22_range=None, rg_range=None,
                           base_params=None, model=2, gens=20):
    """
    Finds optimal values of multiple parameters to match target values.
    
    Args:
        target_mate_cor (float): Target mate correlation PGS value
        target_heritability (float): Target heritability value
        target_r2pgs (float): Target R2 PGS value
        f11_range (tuple): (min, max, step) for f11 search
        prop_h2_latent1_range (tuple): (min, max, step) for prop.h2.latent1 search
        vg1_range (tuple): (min, max, step) for vg1 search
        vg2_range (tuple): (min, max, step) for vg2 search
        f22_range (tuple): (min, max, step) for f22 search
        am22_range (tuple): (min, max, step) for am22 search
        rg_range (tuple): (min, max, step) for rg search
        base_params (dict): Base parameters to use
        model (int): Model number
        gens (int): Number of generations to simulate
    
    Returns:
        dict: Best parameters and their resulting values
    """
    if f11_range is None:
        f11_range = (0.1, 0.3, 0.05)
    
    if prop_h2_latent1_range is None:
        prop_h2_latent1_range = (0.5, 0.9, 0.1)
    
    if vg1_range is None:
        vg1_range = (0.6, 0.95, 0.05)
    
    if vg2_range is None:
        vg2_range = (0.5, 0.8, 0.25)
    
    if f22_range is None:
        f22_range = (0.1, 0.3, 0.05)
    
    if am22_range is None:
        am22_range = (0.40, 0.80, 0.05)
    
    if rg_range is None:
        rg_range = (0.5, 0.9, 0.05)
    
    if base_params is None:
        base_params = {
            'vg1': 0.75,  # Will be overridden
            'vg2': .65,  # Will be overridden
            'am11': 0,
            'am12': 0,
            'am21': 0,
            'am22': 0.65,  # Will be overridden
            'f11': 0.15,  # Will be overridden
            'f12': 0.0,
            'f21': 0.0,
            'f22': 0.15,  # Will be overridden
            'Nfam': 1.6e5,
            'rg': 0.65,  # Will be overridden
            're': 0,
            'prop.h2.latent1': 0.56/0.75,  # Will be overridden
            'prop.h2.latent2': .8/0.8
        }
    
    # Generate search grid
    f11_values = np.arange(f11_range[0], f11_range[1] + f11_range[2], f11_range[2])
    prop_h2_latent1_values = np.arange(prop_h2_latent1_range[0], prop_h2_latent1_range[1] + prop_h2_latent1_range[2], prop_h2_latent1_range[2])
    vg1_values = np.arange(vg1_range[0], vg1_range[1] + vg1_range[2], vg1_range[2])
    vg2_values = np.arange(vg2_range[0], vg2_range[1] + vg2_range[2], vg2_range[2])
    f22_values = np.arange(f22_range[0], f22_range[1] + f22_range[2], f22_range[2])
    am22_values = np.arange(am22_range[0], am22_range[1] + am22_range[2], am22_range[2])
    rg_values = np.arange(rg_range[0], rg_range[1] + rg_range[2], rg_range[2])
    
    best_score = float('inf')
    best_params = None
    best_results = None
    
    results_list = []
    
    total_combinations = (len(f11_values) * len(prop_h2_latent1_values) * len(vg1_values) * 
                         len(vg2_values) * len(f22_values) * len(am22_values) * len(rg_values))
    print(f"Searching {len(f11_values)} x {len(prop_h2_latent1_values)} x {len(vg1_values)} x {len(vg2_values)} x {len(f22_values)} x {len(am22_values)} x {len(rg_values)} = {total_combinations} combinations...")
    print(f"Target: Mate Cor PGS={target_mate_cor:.3f}, Heritability={target_heritability:.3f}, R2 PGS={target_r2pgs:.3f}\n")
    
    for f11_val in f11_values:
        for prop_h2_latent1_val in prop_h2_latent1_values:
            for vg1_val in vg1_values:
                for vg2_val in vg2_values:
                    for f22_val in f22_values:
                        for am22_val in am22_values:
                            for rg_val in rg_values:
                                # Create test parameters
                                test_params = base_params.copy()
                                test_params['f11'] = f11_val
                                test_params['prop.h2.latent1'] = prop_h2_latent1_val
                                test_params['vg1'] = vg1_val
                                test_params['vg2'] = vg2_val
                                test_params['f22'] = f22_val
                                test_params['am22'] = am22_val
                                test_params['rg'] = rg_val
            
                                # Run simulation silently
                                try:
                                    # Temporarily suppress print statements
                                    import sys
                                    import io
                                    old_stdout = sys.stdout
                                    sys.stdout = io.StringIO()
                                    
                                    # Extract final generation values by running the simulation
                                    condition = 1
                                    idx = condition - 1
                                    vg1 = test_params['vg1']
                                    vg2 = test_params['vg2']
                                    rg = test_params['rg']
                                    prop_h2_latent1 = test_params['prop.h2.latent1']
                                    prop_h2_latent2 = test_params['prop.h2.latent2']
                                    am11 = test_params['am11']
                                    am12 = test_params['am12']
                                    am21 = test_params['am21']
                                    am22 = test_params['am22']
                                    f11 = test_params['f11']
                                    f12 = test_params['f12']
                                    f21 = test_params['f21']
                                    f22 = test_params['f22']
                                    re = test_params['re']
                                    
                                    # Implied variables (t0)
                                    k2_matrix = np.array([[1, rg], [rg, 1]])
                                    
                                    vg_obs1 = vg1 * (1 - prop_h2_latent1)
                                    vg_obs2 = vg2 * (1 - prop_h2_latent2)
                                    d11 = np.sqrt(vg_obs1)
                                    d21 = 0
                                    d22 = np.sqrt(vg_obs2 - d21**2)
                                    delta_mat = np.array([[d11, 0], [d21, d22]])
                                    
                                    vg_lat1 = vg1 * prop_h2_latent1
                                    vg_lat2 = vg2 * prop_h2_latent2
                                    a11 = np.sqrt(vg_lat1)
                                    a21 = 0
                                    a22 = np.sqrt(vg_lat2 - a21**2)
                                    a_mat = np.array([[a11, 0], [a21, a22]])
                                    
                                    covg_mat = (delta_mat @ k2_matrix @ delta_mat.T) + (a_mat @ k2_matrix @ a_mat.T)
                                    
                                    ve1 = 1 - vg1
                                    ve2 = 1 - vg2
                                    cove = re * np.sqrt(ve1 * ve2)
                                    cove_mat = np.array([[ve1, cove], [cove, ve2]])
                                    
                                    COVY = covg_mat + cove_mat
                                    
                                    mate_cor_mat = np.array([[am11, am12], [am21, am22]])
                                    f_mat = np.array([[f11, f12], [f21, f22]])
                                    covf_mat = 2 * (f_mat @ COVY @ f_mat.T)
                                    
                                    # Initialize parameters for iteration
                                    a_t0 = a_mat
                                    delta_t0 = delta_mat
                                    j_t0 = k2_matrix * 0.5
                                    k_t0 = k2_matrix * 0.5
                                    f_t0 = f_mat
                                    rmate_t0 = mate_cor_mat
                                    covE_t0 = cove_mat
                                    
                                    # Lists to store matrices at each generation
                                    exp_VY, exp_VF, exp_mu, mate_cov = ([None] * gens for _ in range(4))
                                    exp_gc, exp_hc, exp_ic, exp_w, exp_v = ([None] * gens for _ in range(5))
                                    exp_Omega, exp_Gamma, exp_itlo, exp_itol = ([None] * gens for _ in range(4))
                                    exp_VGO, exp_heritability, exp_cor_matpgs, exp_VGL, exp_COVLO = ([None] * gens for _ in range(5))
                                    
                                    # Initialize matrices at t=0
                                    exp_gc[0] = exp_hc[0] = exp_ic[0] = exp_w[0] = exp_v[0] = np.zeros((2, 2))
                                    exp_VY[0] = COVY
                                    exp_VF[0] = covf_mat
                                    exp_mu[0] = np.linalg.inv(COVY) @ rmate_t0 @ np.linalg.inv(COVY.T)
                                    mate_cov[0] = COVY @ exp_mu[0] @ COVY
                                    
                                    # Run iterations
                                    for it in range(1, gens):
                                        it_prev = it - 1
                                        
                                        exp_Omega[it] = (2 * delta_t0 @ exp_gc[it_prev] + delta_t0 @ k_t0 +
                                                         0.5 * exp_w[it_prev] + 2 * a_t0 @ exp_ic[it_prev])
                                        
                                        exp_Gamma[it] = (2 * a_t0 @ exp_hc[it_prev] + 2 * delta_t0 @ exp_ic[it_prev].T +
                                                         a_t0 @ j_t0 + 0.5 * exp_v[it_prev])
                                        
                                        exp_VY[it] = (2 * delta_t0 @ exp_Omega[it].T + 2 * a_t0 @ exp_Gamma[it].T +
                                                      exp_w[it_prev] @ delta_t0.T + exp_v[it_prev] @ a_t0.T +
                                                      exp_VF[it_prev] + covE_t0)
                                        
                                        vy_sqrt_diag = np.sqrt(np.diag(np.diag(exp_VY[it])))
                                        mate_cov[it] = vy_sqrt_diag @ rmate_t0 @ vy_sqrt_diag
                                        exp_mu[it] = np.linalg.inv(exp_VY[it]) @ mate_cov[it] @ np.linalg.inv(exp_VY[it].T)
                                        
                                        exp_gt = exp_Omega[it].T @ exp_mu[it] @ exp_Omega[it]
                                        exp_gc[it] = 0.5 * (exp_gt + exp_gt.T)
                                        
                                        exp_ht = exp_Gamma[it].T @ exp_mu[it] @ exp_Gamma[it]
                                        exp_hc[it] = 0.5 * (exp_ht + exp_ht.T)
                                        
                                        exp_w[it] = (2 * f_t0 @ exp_Omega[it] +
                                                     f_t0 @ exp_VY[it] @ exp_mu[it] @ exp_Omega[it] +
                                                     f_t0 @ exp_VY[it] @ exp_mu[it].T @ exp_Omega[it])
                                        
                                        exp_v[it] = (2 * f_t0 @ exp_Gamma[it] +
                                                     f_t0 @ exp_VY[it] @ exp_mu[it] @ exp_Gamma[it] +
                                                     f_t0 @ exp_VY[it] @ exp_mu[it].T @ exp_Gamma[it])
                                        
                                        exp_VF[it] = (2 * f_t0 @ exp_VY[it] @ f_t0.T +
                                                      f_t0 @ exp_VY[it] @ exp_mu[it] @ exp_VY[it] @ f_t0.T +
                                                      f_t0 @ exp_VY[it] @ exp_mu[it].T @ exp_VY[it] @ f_t0.T)
                                        
                                        exp_itlo[it] = exp_Gamma[it].T @ exp_mu[it] @ exp_Omega[it]
                                        exp_itol[it] = exp_Omega[it].T @ exp_mu[it] @ exp_Gamma[it]
                                        exp_ic[it] = 0.25 * (exp_itol[it] + exp_itol[it].T + exp_itlo[it] + exp_itlo[it].T)
                                        
                                        exp_VGO[it] = (2 * delta_t0 @ k_t0 @ delta_t0.T + 4 * delta_t0 @ exp_gc[it] @ delta_t0.T)
                                        exp_VGL[it] = (2 * a_t0 @ j_t0 @ a_t0.T + 4 * a_t0 @ exp_hc[it] @ a_t0.T)
                                        exp_heritability[it] = (exp_VGL[it] + exp_VGO[it] + 8* a_t0 @ exp_ic[it]@delta_t0) / exp_VY[it]
                                        
                                        exp_COVLO[it] = (4 * delta_t0 @ exp_ic[it].T @ a_t0.T + 4 * a_t0 @ exp_ic[it] @ delta_t0.T)
                                        
                                        exp_cor_matpgs[it] = (exp_Omega[it].T @ exp_mu[it] @ exp_Omega[it] *2 + exp_Omega[it].T @ exp_mu[it].T@ exp_Omega[it]*2)/ (2*k_t0 + 4*exp_gc[it])
                                    
                                    # Restore stdout
                                    sys.stdout = old_stdout
                                    
                                    # Extract final generation values (trait 1 only, diagonal elements)
                                    final_gen = gens - 1
                                    mate_cor_pgs = exp_cor_matpgs[final_gen][0, 0]
                                    heritability = exp_heritability[final_gen][0, 0]
                                    r2_pgs = (exp_VGO[final_gen] / exp_VY[final_gen])[0, 0]
                                    
                                    # Calculate weighted error score
                                    mate_cor_error = 1*abs(mate_cor_pgs - target_mate_cor) / target_mate_cor
                                    heritability_error = 1*abs(heritability - target_heritability) / target_heritability
                                    r2pgs_error = 1*abs(r2_pgs - target_r2pgs) / target_r2pgs
                                    
                                    # Combined score (equal weights)
                                    score = mate_cor_error + heritability_error + r2pgs_error
                                    
                                    results_list.append({
                                        'f11': f11_val,
                                        'prop.h2.latent1': prop_h2_latent1_val,
                                        'vg1': vg1_val,
                                        'vg2': vg2_val,
                                        'f22': f22_val,
                                        'am22': am22_val,
                                        'rg': rg_val,
                                        'mate_cor_pgs': mate_cor_pgs,
                                        'heritability': heritability,
                                        'r2_pgs': r2_pgs,
                                        'score': score
                                    })
                                    
                                    if score < best_score:
                                        best_score = score
                                        best_params = {
                                            'f11': f11_val, 
                                            'prop.h2.latent1': prop_h2_latent1_val, 
                                            'vg1': vg1_val,
                                            'vg2': vg2_val,
                                            'f22': f22_val,
                                            'am22': am22_val,
                                            'rg': rg_val
                                        }
                                        best_results = {
                                            'mate_cor_pgs': mate_cor_pgs,
                                            'heritability': heritability,
                                            'r2_pgs': r2_pgs,
                                            'score': score
                                        }
                                        
                                except Exception as e:
                                    sys.stdout = old_stdout
                                    print(f"Error: {e}")
                                    continue
    
    # Print results
    print("\n" + "="*80)
    print("OPTIMIZATION RESULTS")
    print("="*80)
    print(f"\nBest Parameters:")
    print(f"  f11 = {best_params['f11']:.4f}")
    print(f"  prop.h2.latent1 = {best_params['prop.h2.latent1']:.4f}")
    print(f"  vg1 = {best_params['vg1']:.4f}")
    print(f"  vg2 = {best_params['vg2']:.4f}")
    print(f"  f22 = {best_params['f22']:.4f}")
    print(f"  am22 = {best_params['am22']:.4f}")
    print(f"  rg = {best_params['rg']:.4f}")
    print(f"\nResulting Values (Trait 1):")
    print(f"  Mate Correlation PGS = {best_results['mate_cor_pgs']:.4f} (target: {target_mate_cor:.4f})")
    print(f"  Heritability = {best_results['heritability']:.4f} (target: {target_heritability:.4f})")
    print(f"  R2 PGS = {best_results['r2_pgs']:.4f} (target: {target_r2pgs:.4f})")
    print(f"\nCombined Error Score = {best_results['score']:.4f}")
    print("="*80)
    
    # Show top 5 candidates
    print("\nTop 5 Candidates:")
    sorted_results = sorted(results_list, key=lambda x: x['score'])[:5]
    for i, result in enumerate(sorted_results, 1):
        print(f"\n{i}. f11={result['f11']:.4f}, prop.h2.latent1={result['prop.h2.latent1']:.4f}, vg1={result['vg1']:.4f}, vg2={result['vg2']:.4f}, f22={result['f22']:.4f}, am22={result['am22']:.4f}, rg={result['rg']:.4f}")
        print(f"   Mate Cor PGS={result['mate_cor_pgs']:.4f}, H2={result['heritability']:.4f}, R2 PGS={result['r2_pgs']:.4f}, Score={result['score']:.4f}")
    
    return {
        'best_params': best_params,
        'best_results': best_results,
        'all_results': results_list
    }

# Run the optimization
optimization_results = find_optimal_parameters(
    target_mate_cor=0.135,
    target_heritability=0.45,
    target_r2pgs=0.11,
    f11_range=(0.10, 0.5, 0.05),
    prop_h2_latent1_range=(0.5, 0.9, 0.1),
    vg1_range=(0.6, 0.8, 0.05),
    vg2_range=(0.5, 0.8, 0.25),
    f22_range=(0.1, 0.3, 0.05),
    am22_range=(0.40, 0.80, 0.05),
    rg_range=(0.5, 0.9, 0.05),
    gens=20
)


Searching 9 x 5 x 6 x 3 x 5 x 9 x 9 = 328050 combinations...
Target: Mate Cor PGS=0.135, Heritability=0.450, R2 PGS=0.110


OPTIMIZATION RESULTS

Best Parameters:
  f11 = 0.1000
  prop.h2.latent1 = 0.8000
  vg1 = 0.6000
  vg2 = 1.0000
  f22 = 0.2000
  am22 = 0.6500
  rg = 0.7500

Resulting Values (Trait 1):
  Mate Correlation PGS = 0.1352 (target: 0.1350)
  Heritability = 0.4530 (target: 0.4500)
  R2 PGS = 0.1048 (target: 0.1100)

Combined Error Score = 0.0556

Top 5 Candidates:

1. f11=0.1000, prop.h2.latent1=0.8000, vg1=0.6000, vg2=1.0000, f22=0.2000, am22=0.6500, rg=0.7500
   Mate Cor PGS=0.1352, H2=0.4530, R2 PGS=0.1048, Score=0.0556

2. f11=0.1500, prop.h2.latent1=0.8000, vg1=0.6500, vg2=0.7500, f22=0.1500, am22=0.7000, rg=0.7500
   Mate Cor PGS=0.1348, H2=0.4499, R2 PGS=0.1040, Score=0.0565

3. f11=0.1000, prop.h2.latent1=0.8000, vg1=0.6000, vg2=0.7500, f22=0.1500, am22=0.8000, rg=0.7500
   Mate Cor PGS=0.1346, H2=0.4612, R2 PGS=0.1066, Score=0.0588

4. f11=0.1500, prop.h2.latent

In [138]:
# show the top 20 results

print("\nTop 20 Candidates Overall:")

sorted_results = sorted(optimization_results['all_results'], key=lambda x: x['score'])[:20]
for i, result in enumerate(sorted_results, 1):
    print(f"\n{i}. f11={result['f11']:.4f}, prop.h2.latent1={result['prop.h2.latent1']:.4f}, vg1={result['vg1']:.4f}, vg2={result['vg2']:.4f}, f22={result['f22']:.4f}, am22={result['am22']:.4f}, rg={result['rg']:.4f}")
    print(f"   Mate Cor PGS={result['mate_cor_pgs']:.4f}, H2={result['heritability']:.4f}, R2 PGS={result['r2_pgs']:.4f}, Score={result['score']:.4f}")
    
    


Top 20 Candidates Overall:

1. f11=0.1000, prop.h2.latent1=0.8000, vg1=0.6000, vg2=1.0000, f22=0.2000, am22=0.6500, rg=0.7500
   Mate Cor PGS=0.1352, H2=0.4530, R2 PGS=0.1048, Score=0.0556

2. f11=0.1000, prop.h2.latent1=0.8000, vg1=0.6000, vg2=0.7500, f22=0.1500, am22=0.7500, rg=0.7500
   Mate Cor PGS=0.1356, H2=0.4628, R2 PGS=0.1071, Score=0.0592

3. f11=0.1000, prop.h2.latent1=0.8000, vg1=0.6000, vg2=1.0000, f22=0.3000, am22=0.6000, rg=0.7500
   Mate Cor PGS=0.1351, H2=0.4494, R2 PGS=0.1037, Score=0.0594

4. f11=0.2000, prop.h2.latent1=0.8000, vg1=0.8000, vg2=0.5000, f22=0.3000, am22=0.4500, rg=0.8500
   Mate Cor PGS=0.1353, H2=0.4712, R2 PGS=0.1088, Score=0.0605

5. f11=0.1500, prop.h2.latent1=0.8000, vg1=0.6500, vg2=0.5000, f22=0.3000, am22=0.6000, rg=0.8000
   Mate Cor PGS=0.1360, H2=0.4524, R2 PGS=0.1047, Score=0.0607

6. f11=0.1500, prop.h2.latent1=0.8000, vg1=0.6500, vg2=0.5000, f22=0.2500, am22=0.5500, rg=0.8500
   Mate Cor PGS=0.1346, H2=0.4531, R2 PGS=0.1044, Score=0.0608
