In [27]:
import sys
import os
# Add the parent directory to the system path
sys.path.append(os.path.abspath('../'))

import numpy as np
from matter_multi_fidelity_emu.gpemulator_singlebin import SingleBindGMGP 
from matter_multi_fidelity_emu.data_loader import PowerSpecs
from matter_multi_fidelity_emu.data_loader_dgmgp import interpolate

import glob
import random
import contextlib
import io
import argparse
import time

In [28]:
def generate_data(
    folder_1: str = "../data/narrow/matter_power_297_Box100_Part75_27_Box100_Part300_z0",
    folder_2: str = "../data/narrow/matter_power_297_Box100_Part75_27_Box100_Part300_z0",
    n_fidelities: int = 2,
):
    """Load two ``PowerSpecs`` datasets from text folders.

    Each folder is read, in order, into its own
    ``PowerSpecs(n_fidelities=n_fidelities)`` instance via ``read_from_txt``.
    Note that both default folders are the same path, so with the defaults
    the two returned datasets are read from identical files.

    Returns
    -------
    tuple
        ``(dataset from folder_1, dataset from folder_2)``.
    """
    loaded = []
    for folder in (folder_1, folder_2):
        specs = PowerSpecs(n_fidelities=n_fidelities)
        specs.read_from_txt(folder=folder)
        loaded.append(specs)
    first, second = loaded
    return first, second

In [29]:
data_1, data_2 = generate_data()

# Keep only the log10(k) bins of data_1 that lie inside data_2's k range,
# and bring data_2's fidelity-0 spectra onto exactly those bins.
log10_k_target = data_1.kf
log10_k_train = data_2.kf
ind_min = (log10_k_target >= log10_k_train.min()) & (
    log10_k_target <= log10_k_train.max()
)
kf = log10_k_target[ind_min]

# interpolate: interp(log10_k, Y_lf)(log10_k[ind_min])
data_2.Y_train_norm[0] = interpolate(log10_k_train, data_2.Y_train_norm[0], kf)
data_2.Y_train[0] = interpolate(log10_k_train, data_2.Y_train[0], kf)
# Both interpolated arrays must now have one column per retained k bin.
assert data_2.Y_train_norm[0].shape[1] == kf.shape[0]
assert data_2.Y_train[0].shape[1] == kf.shape[0]

# Trim data_1's spectra (fidelities 0 and 1) to the retained k bins.
for fid in (0, 1):
    data_1.Y_train_norm[fid] = data_1.Y_train_norm[fid][:, ind_min]
    data_1.Y_train[fid] = data_1.Y_train[fid][:, ind_min]

data_1.Y_test[0] = data_1.Y_test[0][:, ind_min]

64 64 64
[Info] rebin powerspecs from 64 k bins to 64 k bins.
64 64 64
[Info] rebin powerspecs from 64 k bins to 64 k bins.


In [30]:
# data_1's log10(k) bins that fall within data_2's k range (the training k grid).
kf

array([-0.59976085, -0.57448928, -0.5492177 , -0.52394612, -0.49867454,
       -0.47340297, -0.44813139, -0.42285981, -0.39758823, -0.37231666,
       -0.34704508, -0.3217735 , -0.29650192, -0.27123035, -0.24595877,
       -0.22068719, -0.19541562, -0.17014404, -0.14487246, -0.11960088,
       -0.09432931, -0.06905773, -0.04378615, -0.01851457,  0.006757  ,
        0.03202858,  0.05730016,  0.08257174,  0.10784331,  0.13311489,
        0.15838647,  0.18365804,  0.20892962,  0.2342012 ,  0.25947278,
        0.28474435,  0.31001593,  0.33528751,  0.36055909,  0.38583066,
        0.41110224,  0.43637382,  0.4616454 ,  0.48691697,  0.51218855,
        0.53746013,  0.5627317 ,  0.58800328,  0.61327486,  0.63854644,
        0.66381801,  0.68908959,  0.71436117,  0.73963275,  0.76490432,
        0.7901759 ,  0.81544748,  0.84071906,  0.86599063,  0.89126221,
        0.91653379,  0.94180536,  0.96707694,  0.99234852])

In [31]:
# data_1's fidelity-0 test spectra after trimming; columns are the retained k bins.
data_1.Y_test[0]

array([[3.10786406, 3.04142428, 2.98888339, 3.03859416, 2.98557624,
        2.95719006, 2.87097404, 2.81589504, 2.82821697, 2.81693388,
        2.78822046, 2.72079326, 2.67418068, 2.68780069, 2.67979582,
        2.63078706, 2.60253339, 2.57982333, 2.55758746, 2.50223063,
        2.51015509, 2.46912788, 2.47467111, 2.42549254, 2.39723264,
        2.36899209, 2.33830431, 2.32462184, 2.31210791, 2.27846698,
        2.2638257 , 2.24295885, 2.19847622, 2.17564189, 2.14611387,
        2.11803319, 2.08869244, 2.05294939, 2.0192375 , 1.98896262,
        1.95225724, 1.91600968, 1.88122148, 1.84635171, 1.80132772,
        1.76084476, 1.713499  , 1.67518847, 1.62858013, 1.58503166,
        1.54714005, 1.50207218, 1.45651089, 1.40976158, 1.36405009,
        1.31767761, 1.27008851, 1.22253181, 1.17376883, 1.12610811,
        1.07610185, 1.02537923, 0.97628943, 0.92629411],
       [3.22438911, 3.15828964, 3.10704758, 3.1540178 , 3.10754589,
        3.07771705, 2.98849459, 2.94259088, 2.96063411, 2.9

In [43]:
import random

def search_LF(data, num_LF, HF_selected_ind, len_slice):
    """Find the LF row indices whose cosmologies match the selected HF slices.

    HF indices are grouped into consecutive slices of ``len_slice`` rows. For
    each slice, the first HF row is located in the LF inputs by an exact
    match on every parameter. The ``len_slice`` consecutive LF rows starting
    there are then selected.

    Parameters
    ----------
    data : object
        Must expose ``X_train[0]`` (LF inputs) and ``X_train[1]`` (HF inputs),
        one row per simulation.
    num_LF : int
        Requested number of LF simulations. At most ``num_LF // len_slice``
        slices are matched.
    HF_selected_ind : sequence of int
        Selected HF row indices, in slices of ``len_slice``.
    len_slice : int
        Number of consecutive rows that form one slice.

    Returns
    -------
    LF_selected_ind : list of int
        LF row indices covering every matched slice.
    ind_slc_LF : list of int
        LF slice index (``row // len_slice``) of each matched slice.

    Raises
    ------
    ValueError
        If a slice's first HF row has no identical row among the LF inputs.
    """
    X_0 = np.asarray(data.X_train[0])  # LF cosmologies
    X_1 = np.asarray(data.X_train[1])  # HF cosmologies

    n_LF_slc = num_LF // len_slice
    n_HF_slc = len(HF_selected_ind) // len_slice

    LF_selected_ind = []
    ind_slc_LF = []
    for i in range(min(n_HF_slc, n_LF_slc)):
        first_ind = HF_selected_ind[i * len_slice]
        # Require *all* parameters of the row to agree. A bare
        # `X_1[first_ind] == X_0` is True wherever any single parameter
        # coincides, which can pick the wrong LF slice and duplicate indices.
        matches = np.flatnonzero(np.all(X_0 == X_1[first_ind], axis=1))
        if matches.size == 0:
            raise ValueError(
                f"HF row {first_ind} has no matching cosmology in the LF inputs"
            )
        ind_LF_HF = int(matches[0])
        ind_slc_LF.append(ind_LF_HF // len_slice)
        LF_selected_ind.extend(range(ind_LF_HF, ind_LF_HF + len_slice))
    return LF_selected_ind, ind_slc_LF

# Simulations come in slices of `len_slice` consecutive rows; selection is
# done slice-by-slice.
len_slice = 3
n_sample_HF = 27   # HF simulations available
n_sample_LF = 297  # LF simulations available

num_HF = 27   # HF simulations to select for training
num_LF = 297  # LF simulations to select for training

# Seed the global `random` module so the slice selection is reproducible.
SEED = 0
random.seed(SEED)

# Pick `num_HF // len_slice` HF slices at random and expand them to row indices.
ind_slc_list = random.sample(range(n_sample_HF // len_slice), k=num_HF // len_slice)
HF_selected_ind = [
    len_slice * ind_slc + i for ind_slc in ind_slc_list for i in range(len_slice)
]

# LF rows/slices that share cosmologies with the chosen HF slices.
LF_selected_ind, LF_selected_ind_slc = search_LF(data_1, num_LF, HF_selected_ind, len_slice)
print("LF_HF_common_ind:", LF_selected_ind, "\n")

# Fill the remaining LF budget with random slices not already selected.
if num_LF > num_HF:
    used_slc = set(LF_selected_ind_slc)
    slc_list_remain = [i for i in range(n_sample_LF // len_slice) if i not in used_slc]
    ind_slc_list = random.sample(slc_list_remain, k=(num_LF - num_HF) // len_slice)
    for ind_slc in ind_slc_list:
        LF_selected_ind.extend(len_slice * ind_slc + i for i in range(len_slice))

LF_selected_ind.sort()
HF_selected_ind.sort()

LF_HF_common_ind: [81, 82, 83, 72, 73, 74, 6, 7, 8, 168, 169, 170, 60, 61, 62, 189, 190, 191, 192, 193, 194, 54, 55, 56, 282, 283, 284] 



In [44]:
# LF_selected_ind.append(6)

In [45]:
def assert_no_duplicates(lst):
    """Raise ``AssertionError`` if ``lst`` contains any repeated element.

    The message lists the repeated values, in first-occurrence order, so a
    faulty index selection can be diagnosed without dumping the whole list.
    Elements must be hashable.
    """
    from collections import Counter

    repeated = [item for item, count in Counter(lst).items() if count > 1]
    assert not repeated, f"Array has repeated elements: {repeated}"

# Sanity check: every selected LF row index must appear only once.
assert_no_duplicates(LF_selected_ind)

In [46]:
# Sorted LF training indices: HF-matched slices plus randomly added slices.
LF_selected_ind

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 77,
 78,
 79,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99,
 100,
 101,
 102,
 103,
 104,
 105,
 106,
 107,
 108,
 109,
 110,
 111,
 112,
 113,
 114,
 115,
 116,
 117,
 118,
 119,
 120,
 121,
 122,
 123,
 124,
 125,
 126,
 127,
 128,
 129,
 130,
 131,
 132,
 133,
 134,
 135,
 136,
 137,
 138,
 139,
 140,
 141,
 142,
 143,
 144,
 145,
 146,
 147,
 148,
 149,
 150,
 151,
 152,
 153,
 154,
 155,
 156,
 157,
 158,
 159,
 160,
 161,
 162,
 163,
 164,
 165,
 166,
 167,
 168,
 169,
 170,
 171,
 172,
 173,
 174,
 175,
 176,
 177,
 178,
 179,
 180,
 181,
 182,
 183,
 184,


In [47]:
# Sorted HF training indices.
HF_selected_ind

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26]

In [48]:
# NOTE(review): hard-coded index list, presumably pasted from an earlier run of
# the LF selection. It contains repeated entries: 72, 73, 74 and 189, 190, 191
# each appear twice.
[72, 73, 74, 189, 190, 191, 18, 19, 20, 162, 163, 164, 78, 79, 80, 132, 133, 134, 54, 55, 56, 114, 115, 116, 99, 100, 101, 216, 217, 218, 24, 25, 26, 153, 154, 155, 117, 118, 119, 102, 103, 104, 147, 148, 149, 75, 76, 77, 15, 16, 17, 261, 262, 263, 84, 85, 86, 42, 43, 44, 168, 169, 170, 138, 139, 140, 93, 94, 95, 285, 286, 287, 156, 157, 158, 210, 211, 212, 66, 67, 68, 294, 295, 296, 27, 28, 29, 90, 91, 92, 123, 124, 125, 165, 166, 167, 213, 214, 215, 51, 52, 53, 249, 250, 251, 57, 58, 59, 237, 238, 239, 39, 40, 41, 189, 190, 191, 273, 274, 275, 63, 64, 65, 252, 253, 254, 258, 259, 260, 144, 145, 146, 201, 202, 203, 111, 112, 113, 150, 151, 152, 291, 292, 293, 141, 142, 143, 195, 196, 197, 159, 160, 161, 225, 226, 227, 219, 220, 221, 0, 1, 2, 231, 232, 233, 270, 271, 272, 72, 73, 74, 3, 4, 5, 243, 244, 245, 192, 193, 194, 267, 268, 269, 6, 7, 8, 30, 31, 32, 45, 46, 47, 228, 229, 230, 36, 37, 38, 69, 70, 71, 246, 247, 248, 240, 241, 242, 255, 256, 257, 48, 49, 50, 33, 34, 35, 207, 208, 209, 186, 187, 188, 21, 22, 23, 177, 178, 179, 276, 277, 278, 264, 265, 266, 9, 10, 11, 222, 223, 224, 174, 175, 176, 198, 199, 200, 96, 97, 98, 12, 13, 14, 288, 289, 290, 81, 82, 83, 234, 235, 236, 60, 61, 62, 135, 136, 137, 204, 205, 206, 108, 109, 110, 180, 181, 182, 87, 88, 89, 282, 283, 284, 183, 184, 185, 171, 172, 173, 129, 130, 131]

[72,
 73,
 74,
 189,
 190,
 191,
 18,
 19,
 20,
 162,
 163,
 164,
 78,
 79,
 80,
 132,
 133,
 134,
 54,
 55,
 56,
 114,
 115,
 116,
 99,
 100,
 101,
 216,
 217,
 218,
 24,
 25,
 26,
 153,
 154,
 155,
 117,
 118,
 119,
 102,
 103,
 104,
 147,
 148,
 149,
 75,
 76,
 77,
 15,
 16,
 17,
 261,
 262,
 263,
 84,
 85,
 86,
 42,
 43,
 44,
 168,
 169,
 170,
 138,
 139,
 140,
 93,
 94,
 95,
 285,
 286,
 287,
 156,
 157,
 158,
 210,
 211,
 212,
 66,
 67,
 68,
 294,
 295,
 296,
 27,
 28,
 29,
 90,
 91,
 92,
 123,
 124,
 125,
 165,
 166,
 167,
 213,
 214,
 215,
 51,
 52,
 53,
 249,
 250,
 251,
 57,
 58,
 59,
 237,
 238,
 239,
 39,
 40,
 41,
 189,
 190,
 191,
 273,
 274,
 275,
 63,
 64,
 65,
 252,
 253,
 254,
 258,
 259,
 260,
 144,
 145,
 146,
 201,
 202,
 203,
 111,
 112,
 113,
 150,
 151,
 152,
 291,
 292,
 293,
 141,
 142,
 143,
 195,
 196,
 197,
 159,
 160,
 161,
 225,
 226,
 227,
 219,
 220,
 221,
 0,
 1,
 2,
 231,
 232,
 233,
 270,
 271,
 272,
 72,
 73,
 74,
 3,
 4,
 5,
 243,
 244,
 245,
 192,