In [2]:
import h5py
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Layer, Dense, Input
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import (
    EarlyStopping,
    LearningRateScheduler,
    LambdaCallback,
)
from sklearn.preprocessing import StandardScaler  # for scaling input and output data
from sklearn.preprocessing import RobustScaler  # for scaling input and output data
from sklearn.preprocessing import MinMaxScaler
from scipy.interpolate import interp1d, make_interp_spline
import argparse
from sklearn.decomposition import PCA
from tensorflow.keras.models import load_model
from tqdm import tqdm
import pickle
from classy import Class
from train_pybird_emulators.emu_utils import integrated_model
from train_pybird_emulators.emu_utils import emu_utils
from cosmic_toolbox import logger
from train_pybird_emulators.emu_utils.k_arrays import k_emu, k_pybird

2024-11-25 02:11:45.456350: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-11-25 02:11:47.589551: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-11-25 02:11:48.038813: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-11-25 02:11:48.278782: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-25 02:11:49.335890: I tensorflow/core/platform/cpu_feature_guar

loading matrices!


In [42]:
# --- Run configuration -----------------------------------------------------
# Emulator piece to train and the name used for the saved model files.
piece_name = "IRPsloop"
model_name = "test_ploopl"

# Requested number of training samples (the loader reports how many are
# actually available in the file).
ntrain = 1_000_000

# Multipole / masking switches, forwarded positionally to the data loader.
mono = True
quad_hex = False
mask_high_k = False
quad_alone = False
hex_alone = False

# Length of one k-segment; used to count segments in the flattened outputs.
k_array_length = 77

training_data_file = (
    "/cluster/scratch/areeves/pk_bank_boss_gaussian_cov_fixed_bug2/total_data.h5"
)

In [43]:
# No per-column loss rescaling unless an array is assigned here.
flattened_rescale_factor = None

# Default covariance from the project helpers; only its shape is reported.
# NOTE(review): `cov` is not referenced again in this notebook.
cov = emu_utils.get_default_cov()
print("cov shape", cov.shape)

cov shape (231, 231)


In [44]:
# Load the training set for `piece_name`. Arguments are positional, so the
# order below must match get_training_data_from_hdf5's signature.
loader_args = (
    training_data_file,
    piece_name,
    ntrain,
    mono,
    quad_hex,
    quad_alone,
    hex_alone,
    mask_high_k,
)
x_train, y_train = emu_utils.get_training_data_from_hdf5(*loader_args)

24-11-25 02:21:54 train_pybi INF   total number of available training points: 896400 
24-11-25 02:21:54 train_pybi INF   Available keys in the file: <KeysViewHDF5 ['D', 'IRPs11', 'IRPsct', 'IRPsloop', 'Ploopl', 'bpk_resum_False', 'bpk_resum_True', 'emu_inputs', 'f', 'kk', 'params', 'pk_lin']> 
xtrain shape (896400, 82)
24-11-25 02:23:34 train_pybi INF   Using monopole data for IRPsloop 
where are zeros?
(array([   0,    1,    2,    3,    4,    5,    6,    7,   77,   78,   79,
         80,   81,   82,   83,   84,  154,  155,  156,  157,  158,  159,
        160,  161,  231,  232,  233,  234,  235,  236,  237,  238,  308,
        309,  310,  311,  312,  313,  314,  315,  385,  386,  387,  388,
        389,  390,  391,  392,  462,  463,  464,  465,  466,  467,  468,
        469,  539,  540,  541,  542,  543,  544,  545,  546,  616,  617,
        618,  619,  620,  621,  622,  623,  693,  694,  695,  696,  697,
        698,  699,  700,  770,  771,  772,  773,  774,  775,  776,  777,
        

In [None]:
# Filter unphysical / spiky samples out of the training set.
#
# A row is dropped when any of these holds:
#   condition_1 - some entry of x_train[:, :-2] is positive
#   condition_2 - the last input column is negative
#   condition_3 - the second-to-last input column is negative
#   condition_5 - |dy| at one of `spike_positions` exceeds the
#                 GRADIENT_QUANTILE quantile of all |dy| in y_train
# condition_4 (any decrease among the first 6 input columns) is computed but
# only applied when USE_KNOT_GRADIENT_FILTER is True. It is off by default,
# which matches the mask that was actually used before.
GRADIENT_QUANTILE = 0.9995  # i.e. the top 0.05% of |dy| values count as spikes
USE_KNOT_GRADIENT_FILTER = False

print("orig shape", x_train.shape)

condition_1 = np.any(x_train[:, :-2] > 0, axis=1)
condition_2 = x_train[:, -1] < 0
condition_3 = x_train[:, -2] < 0

# 5 differences between the first 6 input columns.
gradients_first_5 = np.diff(x_train[:, :6], axis=1)  # shape: (num_samples, 5)
condition_4 = np.any(gradients_first_5 < 0, axis=1)

# Absolute steps between neighbouring output columns.
gradients = np.abs(np.diff(y_train, axis=1))
gradient_threshold = np.quantile(gradients, GRADIENT_QUANTILE)

# Diff positions k-1, 2k-1, ... with k = len(k_emu); presumably k equals the
# per-segment length (77) -- confirm.
# NOTE(review): diff index k-1 is y[:, k] - y[:, k-1], which straddles the
# boundary between two consecutive segments rather than the last step inside
# one segment. Confirm this is the intended "high-k" position.
spike_positions = np.arange(k_emu.shape[0] - 1, gradients.shape[1], k_emu.shape[0])
condition_5 = np.any(gradients[:, spike_positions] > gradient_threshold, axis=1)

# `gradients` is as large as y_train; release it once the mask is built.
del gradients

bad_mask = condition_1 | condition_2 | condition_3 | condition_5
if USE_KNOT_GRADIENT_FILTER:
    bad_mask |= condition_4
bad_inds = np.flatnonzero(bad_mask)

print(f"removing {len(bad_inds)} bad indices")
x_train = np.delete(x_train, bad_inds, axis=0)
y_train = np.delete(y_train, bad_inds, axis=0)

orig shape (896400, 82)


In [None]:
# print(np.where(condition_4))

In [None]:
# Shape of the filtered training inputs.
np.shape(x_train)

In [None]:
# Distribution of the second-to-last input column of the training set.
fig, ax = plt.subplots()
ax.hist(x_train[:, -2], bins=100)
ax.set_xlim(-100, 300000)
ax.set_xlabel("x_train[:, -2]")
ax.set_ylabel("count")
plt.show()

In [None]:
# Overlay the first 650 training outputs. A 2-D argument draws one line per
# column, so the transposed slice gives one line per sample in a single call.
# Slicing also tolerates fewer than 650 rows.
plt.plot(y_train[:650].T);

In [None]:
# All (row, column) locations where |y_train| reaches its maximum. np.abs is
# evaluated once (instead of twice over the full array), and the temporary is
# released before the result is displayed.
abs_y_train = np.abs(y_train)
max_abs_locations = np.where(abs_y_train == abs_y_train.max())
del abs_y_train
max_abs_locations

In [None]:
# Overlay the first 900 training outputs, then draw the row holding the
# largest |y_train| entry on top. For finite data, np.argmax returns the first
# maximum in row-major order, which is the same row np.where(...)[0][0] picked.
plt.plot(y_train[:900].T)
max_abs_row = np.unravel_index(np.argmax(np.abs(y_train)), y_train.shape)[0]
plt.plot(y_train[max_abs_row]);

In [None]:
# Shape of the filtered training outputs.
np.shape(y_train)

In [None]:
# Input rows of every sample containing the maximum |y_train| entry (one row
# per occurrence, as returned by np.where). np.abs is evaluated only once.
abs_y_train = np.abs(y_train)
max_params = x_train[np.where(abs_y_train == abs_y_train.max())[0]]
del abs_y_train

In [None]:
# Knot positions, presumably matching the leading input columns plotted below.
# NOTE(review): absolute cluster path -- consider making it configurable.
knots_file = "/cluster/work/refregier/alexree/local_packages/train_pybird_emulators/src/train_pybird_emulators/data/knots_data/final_knots_80.npy"
knots = np.load(knots_file)

In [None]:
# First six knot positions.
print(knots[:6])

In [None]:
# exp of the first len(knots) inputs of the max-|y| sample against the knot
# positions. This was previously hard-coded as 80, presumably len(knots) per the
# "final_knots_80" file name.
n_knots = len(knots)
fig, ax = plt.subplots()
ax.loglog(knots, np.exp(max_params[0, :n_knots]))
ax.set_xlabel("knot position")
ax.set_ylabel("exp(max_params[0, :n_knots])")
plt.show()

In [None]:
# Second-to-last input column of the first max-|y| sample.
print(max_params[0][-2])

In [None]:
# Per-column rescaling factor built from `flattened_rescale_factor`; None
# disables rescaling. The helper is called with a segment length of
# k_array_length and a repeat count, and the result is sliced to keep the
# multipoles selected by `mono` / `quad_hex`.
if flattened_rescale_factor is not None:
    num_patterns = y_train.shape[1] // k_array_length
    # NOTE(review): 35 is presumably the number of k-segments per multipole
    # block -- confirm against the data layout.
    n_segments_per_multipole = 35
    if quad_hex:
        # quad_hex takes precedence when both switches are set, as before
        # (its branch used to run last and overwrite the mono result).
        rescaling_factor = emu_utils.generate_repeating_array(
            flattened_rescale_factor, k_array_length, n_segments_per_multipole
        )[n_segments_per_multipole * k_array_length :]
    elif mono:
        rescaling_factor = emu_utils.generate_repeating_array(
            flattened_rescale_factor, k_array_length, num_patterns
        )[: n_segments_per_multipole * k_array_length]
    else:
        rescaling_factor = emu_utils.generate_repeating_array(
            flattened_rescale_factor, k_array_length, num_patterns // 3
        )
    rescaling_factor = np.array(rescaling_factor)
else:
    rescaling_factor = None

In [None]:
# plt.plot(flattened_rescale_factor)

In [None]:
# plt.plot(1/rescaling_factor)

In [None]:
# Are there places where all the columns in the data are zero?
# Output columns that are zero for every training sample carry no signal.
# Record their indices, save them, and drop them from y_train (and from the
# rescaling factor when one is in use).
zero_columns = np.flatnonzero(~np.any(y_train, axis=0))

if zero_columns.size:
    # NOTE(review): "zero_coumns" spelling kept in case other code loads this name.
    np.save(f"zero_coumns_{piece_name}", zero_columns)
    y_train = np.delete(y_train, zero_columns, axis=1)
    if rescaling_factor is not None:
        rescaling_factor = np.delete(rescaling_factor, zero_columns, axis=0)

In [None]:
# print("using PCA preprocessing")
# npca=512
# pca_scaler = StandardScaler().fit(y_train)
# pca = PCA(n_components=npca)
# # Fit PCA to standard scaled data
# normalized_data = pca_scaler.transform(y_train)
# pca.fit(normalized_data)
# y_train = pca.transform(normalized_data)
# rescaling_factor = np.power(
#     pca.explained_variance_, -1
# )  # default for PCA is to use the explained variance to weight the components
# print(f"explained variance: {np.sum(pca.explained_variance_ratio_)}")
# print("using explained variance to weight the components")
# rescaling_factor = np.array(rescaling_factor)

In [None]:
# print("Log prepocessing")
# offset = np.amin(y_train, axis=0)
# offset[offset > 0] = 0
# y_train = np.log(y_train - 2 * offset)

In [None]:
# Standardise inputs and outputs.
# NOTE(review): scalers are fit on all of x_train/y_train, including the 20%
# later held out via validation_split.
input_scaler = StandardScaler().fit(x_train)
output_scaler = StandardScaler().fit(y_train)

print(f"x_train shape: {x_train.shape}")
print(f"y_train shape: {y_train.shape}")

# Four hidden layers of width 256.
input_dim, output_dim = x_train.shape[1], y_train.shape[1]
keras_model = integrated_model.create_model(
    input_dim=input_dim,
    hidden_layers=[256] * 4,
    output_dim=output_dim,
)

In [None]:
# Initialize model and train
model = integrated_model.IntegratedModel(
    keras_model,
    input_scaler=input_scaler,
    output_scaler=output_scaler,
    offset=None,
    log_preprocess=False,
    temp_file=f"saved_models/{model_name}_temp",
    # pca=pca,
    # pca_scaler=pca_scaler,
    zero_columns=zero_columns,
    rescaling_factor=rescaling_factor,
)
model.train(
    x_train,
    y_train,
    epochs=600,
    batch_size=2048,
    validation_split=0.2,
)

In [None]:
# Number of held-out samples to load for evaluation (passed as the loader's
# sample-count argument despite the `n_train` name).
n_train = 5_000

In [None]:
# Load the held-out set with the same switches as training (same positional
# order as the training loader call).
test_loader_args = (
    training_data_file,
    piece_name,
    n_train,
    mono,
    quad_hex,
    quad_alone,
    hex_alone,
    mask_high_k,
)
x_test, y_test = emu_utils.get_training_data_from_hdf5(*test_loader_args, test_data=True)

In [None]:
# Display the raw held-out inputs (numpy truncates large reprs).
np.asarray(x_test)

In [None]:
# Filter the held-out set.
#
# Base cuts (same formulas as the training filter):
#   condition_1 - some entry of x_test[:, :-2] is positive
#   condition_2 - the last input column is negative
#   condition_3 - the second-to-last input column is negative
# IR pieces (name starts with "I") add a gradient-spike cut (condition_5).
# All other pieces add two input cuts instead:
#   condition_4 - the second-to-last input column exceeds MAX_SECOND_TO_LAST
#   condition_5 - any decrease among the first 11 input columns
# This is the mask the previous code ended up using: its non-IR mask was
# overwritten for IR pieces.
#
# NOTE(review): the spike threshold is the TEST_GRADIENT_QUANTILE quantile of the
# test set's own |dy| (0.80, i.e. the top 20% of steps). The training filter
# uses the 0.9995 quantile of the training set, so the test set is cut much more
# aggressively. Confirm this is intended.
TEST_GRADIENT_QUANTILE = 0.80
MAX_SECOND_TO_LAST = 20000

print(f"filtering out bad indices for piece {piece_name}")

condition_1 = np.any(x_test[:, :-2] > 0, axis=1)
condition_2 = x_test[:, -1] < 0
condition_3 = x_test[:, -2] < 0
bad_mask = condition_1 | condition_2 | condition_3

if piece_name.startswith("I"):
    print("training IR piece... going to filter out large gradients")
    # Absolute steps between neighbouring output columns.
    gradients = np.abs(np.diff(y_test, axis=1))
    gradient_threshold = np.quantile(gradients, TEST_GRADIENT_QUANTILE)
    # Diff positions k-1, 2k-1, ... with k = len(k_emu), as in the training filter.
    spike_positions = np.arange(k_emu.shape[0] - 1, gradients.shape[1], k_emu.shape[0])
    condition_5 = np.any(gradients[:, spike_positions] > gradient_threshold, axis=1)
    bad_mask |= condition_5
else:
    condition_4 = x_test[:, -2] > MAX_SECOND_TO_LAST
    # 10 differences between the first 11 input columns; any single decrease
    # flags the sample. The old "two consecutive decreases" array was never
    # used in the mask and has been removed.
    gradients_first_10 = np.diff(x_test[:, :11], axis=1)  # shape: (num_samples, 10)
    condition_5 = np.any(gradients_first_10 < 0, axis=1)
    bad_mask |= condition_4 | condition_5

bad_inds = np.flatnonzero(bad_mask)

print(f"removing {len(bad_inds)} bad indices")
x_test = np.delete(x_test, bad_inds, axis=0)
y_test = np.delete(y_test, bad_inds, axis=0)

In [None]:
# Persist the trained model under test_models/ so it can be reloaded.
model_save_path = f"test_models/{model_name}"
model.save(model_save_path)


In [None]:
# Placeholder model with all four positional arguments set to None;
# presumably its state is filled in by `restore`.
test_model = integrated_model.IntegratedModel(*([None] * 4))

In [None]:
# Reload the saved model from test_models/.
restore_path = f"test_models/{model_name}"
test_model.restore(restore_path)

In [None]:
# with open("/cluster/work/refregier/alexree/local_packages/train_pybird_emulators/src/train_pybird_emulators/notebooks/test_models/test_ploopl.pkl", "rb") as f:
#     attributes = pickle.load(f)

In [None]:
# Predict on the held-out inputs with log preprocessing explicitly disabled
# (the trained model was also built with log_preprocess=False).
test_model.log_preprocess = False
predicted_testing_spectra = test_model.predict(x_test)

# Ground truth for the comparisons below.
testing_spectra = y_test


In [None]:
# True vs emulated outputs for the first test samples on a 5x5 grid.
# Sample index is row * n_cols + col, so every panel shows a distinct sample
# (the previous `i + j*3` repeated samples across rows). Panels beyond the
# number of available samples are left blank.
n_rows, n_cols = 5, 5
fig, ax = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(16, 20))
for j in range(n_rows):
    for i in range(n_cols):
        sample = i + j * n_cols
        if sample >= len(testing_spectra):
            ax[j, i].axis("off")
            continue
        pred = predicted_testing_spectra[sample]
        true = testing_spectra[sample]
        # x axis is the flattened output column index.
        output_index = np.arange(true.shape[0])
        ax[j, i].plot(output_index, true, "blue", label="Original")
        ax[j, i].plot(output_index, pred, "red", label="NN reconstructed", linestyle="--")
        ax[j, i].set_xlabel("output index", fontsize="x-large")
        ax[j, i].set_ylabel(piece_name, fontsize="x-large")
        ax[j, i].legend(fontsize=15)
fig.tight_layout()


In [None]:
# Accuracy summary: 100 * (1 - median relative error) over the columns that
# are not in `zero_columns`. `zero_columns` holds integer indices, so the kept
# columns are chosen with a boolean mask. Bitwise `~` on an int array yields
# -(i+1), which silently selected columns counted from the end instead.
# NOTE(review): any remaining zeros in testing_spectra give inf relative errors;
# the median tolerates a few of them.
keep_columns = np.ones(testing_spectra.shape[1], dtype=bool)
keep_columns[zero_columns] = False
relative_error = np.abs(
    (testing_spectra[:, keep_columns] - predicted_testing_spectra[:, keep_columns])
    / testing_spectra[:, keep_columns]
)
(1 - np.median(relative_error)) * 100