In [None]:
# pip install cuda-python==12.1.0

In [None]:
# pip install jax

In [None]:
# pip install jaxlib

In [None]:
# pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [None]:
# pip install pennylane --upgrade

In [None]:
# pip install scipy --upgrade

In [None]:
# For Linux 64, Open MPI is built with CUDA awareness but this support is disabled by default.
# To enable it, please set the environment variable OMPI_MCA_opal_cuda_support=true before
# launching your MPI processes. Equivalently, you can set the MCA parameter in the command line:
# mpiexec --mca opal_cuda_support 1 ...

# **Libs**

In [None]:
import time
import os
from math import sqrt

import pennylane as qml
from pennylane import numpy as pnp
from pennylane.operation import Operation, AnyWires

import jax
from jax import numpy as jnp
from jax import random
import optax

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['mathtext.fontset'] = 'cm'
from matplotlib.ticker import (
    AutoLocator, AutoMinorLocator)
import seaborn as sns
# %matplotlib
# import pandas as pd


import sklearn
from sklearn import datasets
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import *
from sklearn.metrics import mean_squared_error,r2_score
from sklearn.decomposition import KernelPCA


In [None]:
# Global JAX configuration; must run before any JAX computation is traced.
jax.config.update('jax_platform_name', 'gpu')       #Hard-set to run on GPU
jax.config.update("jax_enable_x64", True)           # use 64-bit (double-precision) types in JAX

In [None]:
# Show the devices JAX can see (sanity check for the GPU setting above).
jax.devices()

# **Data Generation**

In [None]:
#Initial conditions
Nf = 1024          # number of points in the frequency grid `f`
f0 = 1420.4057     # reference frequency used in the redshift formula below (presumably the 21-cm rest frequency in MHz -- TODO confirm)
N = 10000 #no. of curves
n = 2              # NOTE(review): not referenced by any cell visible in this notebook
f00 = 78.1         # reference frequency of the foreground polynomial (passed to Foreground as `f00`)
T0 = 2039.611      # foreground amplitude; Foreground uses log10(T0) as the zeroth coefficient
a1 = -2.42096      # first-order foreground coefficient
a2 = -0.08062      # second-order foreground coefficient
a3 = 0.02898       # third-order foreground coefficient
f = np.linspace(27,200,Nf)   # frequency grid (plotted as MHz below)
redshift = (f0/f)-1          # redshift of each frequency channel

In [None]:
#Data files
"""Input with fs, fx, Nlw as params"""
# The input file is transposed after loading; later cells index X_21 as [curve, frequency].
X_21 = np.loadtxt("Input_dTb_10k.csv", delimiter =',').T
# One row of target parameters per curve.
Y_21 = np.loadtxt("Output_Params_10k_diff.csv", delimiter =',')

In [None]:
#Signal+foreground+noise data files
# NOTE(review): X_cfgn and X_fgn are re-assigned from freshly generated data in a
# later cell and then written back to files with these same names, so this cell is
# only needed when the generation cells are skipped.
X_cfgn = np.loadtxt("SFGNC_in.csv", delimiter =',')     #constant foreground
Y_cfgn = np.loadtxt("SFGNC_out.csv", delimiter =',')
X_fgn = np.loadtxt("SFGNV_in.csv", delimiter =',')      #varying foreground
Y_fgn = np.loadtxt("SFGNV_out.csv", delimiter =',')

In [None]:
# Class that generates the foreground spectra
class Foreground:
    """Synthetic foreground spectra from a cubic polynomial in log-frequency.

    Each spectrum is ``10 ** sum_j a_j * log10(f / f00) ** j`` for ``j = 0..3``,
    where ``a_0`` is centred on ``log10(T0)``.  The coefficient sets come from
    one of three samplers selected by ``random`` / ``constant_space``
    (see ``get_params``).
    """

    def __init__(self, T0, a1, a2, a3, N, f, f00, variation,random=False, constant_space=False):
        """
        Initialize the Foreground object.

        Parameters:
            T0 (float): Amplitude; ``log10(T0)`` is the central zeroth coefficient.
            a1 (float): Central first-order coefficient.
            a2 (float): Central second-order coefficient.
            a3 (float): Central third-order coefficient.
            N (int): Number of spectra (coefficient sets) to generate.
            f (list or array): Frequencies at which the spectra are evaluated.
            f00 (float): Reference frequency.
            variation (sequence of 4 floats): Fractional half-width applied to
                each central coefficient (a0, a1, a2, a3) by the sampling modes.
            random (bool): Draw coefficients uniformly at random (fixed seed).
            constant_space (bool): Use the central coefficients for every spectrum.
        """
        self.T0 = T0
        self.a1 = a1
        self.a2 = a2
        self.a3 = a3
        self.N = N
        self.f = f
        self.f00 = f00
        self.var = variation
        self.rand = random
        self.constant_space = constant_space

    def param_range(self, n=2):
        """
        Build, for each coefficient, ``n`` evenly spaced values spanning
        ``centre * (1 -/+ variation)``.

        Returns:
            list of arrays: ``[a0_range, a1_range, a2_range, a3_range]``,
            each sorted in ascending order.
        """
        centres = (np.log10(self.T0), self.a1, self.a2, self.a3)
        # Sorting matters for negative centres, where (1 - var) gives the larger bound.
        return [
            np.sort(np.linspace(centre * (1 - var), centre * (1 + var), n))
            for centre, var in zip(centres, self.var)
        ]

    def param_space(self):
        """
        Latin-hypercube sample of the coefficient space.

        Returns:
            array: Sampled coefficient sets, one row per spectrum.
        """
        # NOTE(review): `Space` and `Lhs` are not imported anywhere in the visible
        # notebook (they look like scikit-optimize's skopt.space.Space and
        # skopt.sampler.Lhs -- confirm); without that import this raises NameError.
        param_range = self.param_range()
        space = Space(param_range)
        lhs = Lhs(lhs_type="classic", criterion=None)
        x = lhs.generate(space.dimensions, self.N)
        return np.array(x)

    def random_space(self):
        """
        Draw each coefficient uniformly within ``centre * (1 -/+ variation)``.

        The global NumPy seed is reset to 1234 on every call, so repeated calls
        return the same array (and the global NumPy random state is affected).

        Returns:
            array: Coefficients with shape ``(N, 4)``.
        """
        np.random.seed(1234)
        a0_init = np.log10(self.T0)
        var0,var1,var2,var3 = self.var
        a0_ = np.random.uniform(a0_init*(1-var0),a0_init*(1+var0),size=self.N)
        a1_ = np.random.uniform(self.a1*(1-var1),self.a1*(1+var1),size=self.N)
        a2_ = np.random.uniform(self.a2*(1-var2),self.a2*(1+var2),size=self.N)
        a3_ = np.random.uniform(self.a3*(1-var3),self.a3*(1+var3),size=self.N)
        return np.array([a0_, a1_, a2_, a3_]).T

    def const_space(self):
        """
        Repeat the central coefficients for every spectrum (``variation`` is ignored).

        Returns:
            array: Coefficients with shape ``(N, 4)``, all rows identical.
        """
        centres = np.array([np.log10(self.T0), self.a1, self.a2, self.a3])
        return centres * np.ones((self.N, 1))

    def get_params(self):
        """
        Return the coefficient array for the configured sampling mode.

        Returns:
            array: One row of ``(a0, a1, a2, a3)`` per spectrum.

        Raises:
            ValueError: If both ``random`` and ``constant_space`` are set; the
                modes are mutually exclusive (previously this combination fell
                through every branch and failed with UnboundLocalError).
        """
        if self.rand and self.constant_space:
            raise ValueError("'random' and 'constant_space' cannot both be True")
        if self.constant_space:
            return self.const_space()
        if self.rand:
            return self.random_space()
        return self.param_space()


    def generate_fore(self):
        """
        Evaluate the foreground polynomial for every coefficient set.

        Returns:
            array: Foreground spectra with shape ``(N, len(f))``.
        """
        arr = np.asarray(self.get_params())
        logf = np.log10(np.asarray(self.f, dtype=float) / self.f00)

        # Accumulate the polynomial order by order; broadcasting replaces the
        # per-spectrum / per-frequency Python loops.
        log_T = np.zeros((self.N, logf.size))
        for j in range(4):
            log_T += arr[:self.N, j, None] * logf[None, :] ** j
        return 10 ** log_T


In [None]:
#Thermal Noise generation
def therm_noise(foreground,del_nu,N_t):
    return (foreground)/np.sqrt(del_nu*10**6*3600*N_t)

In [None]:
# fg: coefficients drawn uniformly at random within the given fractional variations (random=True).
fg = Foreground(T0,a1,a2,a3,N,f,f00,variation=[0.1,0,0.05,0.08],random=True,constant_space=False)
# fg1: identical coefficients for every curve (constant_space=True); const_space() does not read `variation`.
fg1 = Foreground(T0,a1,a2,a3,N,f,f00,variation=[0.1,0.01,0.05,0.08],random=False,constant_space=True)

In [None]:
# Foreground spectra and the coefficients that produced them.
# get_params() is called again here; for `fg` the random sampler re-seeds NumPy with a
# fixed seed on every call, so a_arr matches the coefficients used inside generate_fore().
X_fg = fg.generate_fore()
a_arr = fg.get_params()
X_fg1 = fg1.generate_fore()
a_arr1 = fg1.get_params()

In [None]:
# Noise term for each foreground set, with del_nu=1 and N_t=1000.
X_n = therm_noise(X_fg,1,1000)
X_n1 = therm_noise(X_fg1,1,1000)

In [None]:
# Inspect the signal array shape before combining it with foreground and noise.
X_21.shape

In [None]:
# Signal-only inputs/targets.
X_data_nf = X_21
Y_data_nf = Y_21
# Signal + noise, and signal + (unscaled) foreground.
X_21_n = X_21 + X_n
X_data_f = X_21+X_fg
# Signal + foreground + noise, varying (X_fgn) and constant (X_cfgn) foreground.
# These assignments replace any arrays of the same name loaded from CSV earlier.
# NOTE(review): the foreground is multiplied by 10**3 but the noise derived from it
# is not -- confirm that the two terms are meant to be on different scales.
X_fgn = (X_21+X_fg*10**3+X_n)
X_cfgn = (X_21+X_fg1*10**3+X_n1)
# Targets extended with the four foreground coefficients of each curve.
Y_data_f = np.concatenate((Y_21,a_arr),axis=1)
Y_data_fc = np.concatenate((Y_21,a_arr1),axis=1)

In [None]:
import os

# NOTE(review): absolute, machine-specific output directory.
directory_path = '/home/somnath/akash/QNN21cm/Codes/'

# File name -> array to store; written in the listed order as comma-separated text.
csv_outputs = {
    'SFGNV_in.csv': X_fgn,
    'SFGNC_in.csv': X_cfgn,
    'SFGNC_out.csv': Y_data_fc,
    'SFGNV_out.csv': Y_data_f,
}
for csv_name, csv_array in csv_outputs.items():
    np.savetxt(os.path.join(directory_path, csv_name), csv_array, delimiter=',')

In [None]:
fig, ax = plt.subplots()

# One call draws all N curves: each column of the transposed array is one curve.
ax.plot(f, X_21[:N].T)
ax.minorticks_on()
ax.grid()

ax.set_title('Global 21cm Signal', fontsize=15)
ax.set_xlabel(r'$\nu$ (MHz)', fontsize=14)
ax.set_ylabel(r'$\delta T_{b}$ (mK)', fontsize=14)
ax.tick_params(axis='both', labelsize=12)
# plt.savefig('/home/somnath/akash/QNN21cm/Codes/Plots/21cm_signal.pdf',dpi=100,edgecolor='black')


In [None]:
# One redshift value per frequency channel.
redshift.shape

In [None]:
fig, ax = plt.subplots()

# One call draws all N contaminated curves (columns of the transposed array).
ax.plot(f, X_fgn[:N].T)
ax.minorticks_on()
ax.grid()
ax.set_yscale('log')

ax.set_title('Contaminated Signal (Varying)', fontsize=16)
ax.set_xlabel(r'$\nu$ (MHz)', fontsize=14)
ax.set_ylabel(r'$\delta T_{b}=T_{21}+T_{FG} + T_{N}$ (mK)', fontsize=14)
ax.tick_params(axis='both', labelsize=12)
plt.savefig('/home/somnath/akash/QNN21cm/Codes/Plots/Fore+Sig_var.pdf')

In [None]:
fig, ax = plt.subplots()

# One call draws all N contaminated curves (columns of the transposed array).
ax.plot(f, X_cfgn[:N].T)
ax.minorticks_on()
ax.grid()
ax.set_yscale('log')

ax.set_title('Contaminated Signal (Constant)', fontsize=16)
ax.set_xlabel(r'$\nu$ (MHz)', fontsize=14)
ax.set_ylabel(r'$\delta T_{b}=T_{21}+T_{FG} + T_{N}$ (mK)', fontsize=14)
ax.tick_params(axis='both', labelsize=12)
plt.savefig('/home/somnath/akash/QNN21cm/Codes/Plots/Fore+Sig_const.pdf')

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(12, 5))
title = ['Contaminated Signal (Constant)',
         'Contaminated Signal (Varying)']
# fig.text(0.5, 0.02, r'$\nu$ (MHz)', ha='center',fontsize=14)
# Shared y-axis label placed on the figure rather than on either panel.
fig.text(0.07, 0.5, r'$\delta T_{b}=T_{21}+T_{FG} + T_{N}$ (mK)', va='center', rotation='vertical',fontsize=12)

# Left panel: constant foreground; right panel: varying foreground.
for panel, curves, heading in zip(ax, (X_cfgn, X_fgn), title):
    panel.plot(f, curves[:N].T)
    panel.minorticks_on()
    panel.grid()
    panel.set_xlabel(r'$\nu$ (MHz)', fontsize=12)
    panel.set_title(heading, fontsize=14)
    panel.set_yscale('log')
    panel.tick_params(axis='both', labelsize=12)
# plt.savefig('/home/somnath/akash/QNN21cm/Codes/Plots/Fore_cv.pdf')

# **Utility Functions**

In [None]:
#quantum circuit class - Check pennylane source code for more details -
#https://docs.pennylane.ai/en/stable/_modules/pennylane/templates/layers/strongly_entangling.html#StronglyEntanglingLayers
class StronglyEntanglingLayers(Operation):
    """Layered ansatz: a U3 rotation on every wire followed by a ring of two-qubit gates.

    Adapted from PennyLane's template of the same name; this version applies
    ``qml.U3`` as the single-qubit gate in ``compute_decomposition``.

    Args:
        weights: tensor whose last three dimensions are ``(n_layers, n_wires, 3)``.
        wires: wires the template acts on.
        ranges: per-layer offset between the two wires of each two-qubit gate;
            defaults to cycling through ``1 .. n_wires - 1``.
        imprimitive: two-qubit operation class; defaults to ``qml.CNOT``.
        id: optional operation identifier passed to ``Operation``.

    Raises:
        ValueError: if the weight shape does not match the wires, or if
            ``ranges`` has the wrong length or contains a multiple of ``n_wires``.
    """

    num_wires = AnyWires
    grad_method = None

    def __init__(self, weights, wires, ranges=None, imprimitive=None, id=None):
        shape = qml.math.shape(weights)[-3:]

        if shape[1] != len(wires):
            raise ValueError(
                f"Weights tensor must have second dimension of length {len(wires)}; got {shape[1]}"
            )

        if shape[2] != 3:
            raise ValueError(
                f"Weights tensor must have third dimension of length 3; got {shape[2]}"
            )

        if ranges is None:
            if len(wires) > 1:
                # tile ranges with iterations of range(1, n_wires)
                ranges = tuple((l % (len(wires) - 1)) + 1 for l in range(shape[0]))
            else:
                ranges = (0,) * shape[0]
        else:
            ranges = tuple(ranges)
            if len(ranges) != shape[0]:
                raise ValueError(f"Range sequence must be of length {shape[0]}; got {len(ranges)}")
            for r in ranges:
                if r % len(wires) == 0:
                    raise ValueError(
                        f"Ranges must not be zero nor divisible by the number of wires; got {r}"
                    )

        # Fall back to CNOT so the documented default (imprimitive=None) yields a
        # usable gate instead of a None that fails when the decomposition calls it.
        self._hyperparameters = {"ranges": ranges, "imprimitive": imprimitive or qml.CNOT}

        super().__init__(weights, wires=wires, id=id)

    @property
    def num_params(self):
        """The template takes a single trainable tensor (the weights)."""
        return 1

    @staticmethod
    def compute_decomposition(
        weights, wires, ranges, imprimitive
    ):  # pylint: disable=arguments-differ
        """Return the list of operations that implement the template.

        Per layer: one ``qml.U3`` on every wire, then (when there is more than
        one wire) ``imprimitive`` on wire pairs ``(i, i + ranges[layer])`` with
        periodic boundary.
        """
        # Read the layer count from the third-to-last axis so it agrees with the
        # `weights[..., l, i, k]` indexing below when a leading batch axis exists.
        n_layers = qml.math.shape(weights)[-3]
        wires = qml.wires.Wires(wires)
        imprimitive = imprimitive or qml.CNOT
        op_list = []

        for l in range(n_layers):
            for i in range(len(wires)):  # pylint: disable=consider-using-enumerate
                op_list.append(
                    qml.U3(
                        weights[..., l, i, 0],
                        weights[..., l, i, 1],
                        weights[..., l, i, 2],
                        wires=wires[i],
                    )
                )

            if len(wires) > 1:
                for i in range(len(wires)):
                    act_on = wires.subset([i, i + ranges[l]], periodic_boundary=True)
                    op_list.append(imprimitive(wires=act_on))

        return op_list


    @staticmethod
    def shape(n_layers, n_wires):
        """Expected weight-tensor shape for the given number of layers and wires."""
        return n_layers, n_wires, 3

In [None]:
# Defining a function to split the data into training, validation, and test sets
def data_split(X, Y, test_size):
    """
    Split (X, Y) into train, validation and test sets with fixed random seeds.

    Inputs:
    X -> Input
    Y -> Output
    test_size -> Fraction held out from training; the held-out part is then
                 divided equally between validation and test

    Returns:
    (x_train, x_val, x_test, y_train, y_val, y_test) as JAX Numpy arrays
    """
    # First cut: training data versus everything that is held out.
    x_train, x_held, y_train, y_held = train_test_split(X, Y, test_size=test_size, random_state=2)

    # Second cut: halve the held-out data into validation and test.
    x_val, x_test, y_val, y_test = train_test_split(x_held, y_held, test_size=0.5, random_state=42)

    # Reshuffle the training pairs (inputs and targets stay aligned).
    x_train, y_train = shuffle(x_train, y_train, random_state=2)

    pieces = (x_train, x_val, x_test, y_train, y_val, y_test)
    return tuple(jnp.array(piece) for piece in pieces)

In [None]:
# Function to initialize parameters for a quantum neural network
def params_init(n_layer, n_qubits, Y):
    """
    Create the initial trainable parameters of the quantum neural network.

    Inputs:
    n_layer -> 1st dimension of the weights array
    n_qubits -> 2nd dimension of the weights array
    Y -> Target array; its second dimension sets the length of the bias

    Returns:
    Dictionary with "weights" of shape (n_layer, n_qubits, 3) drawn uniformly
    from [0, 1), and "bias" filled with zeros of length Y.shape[1]
    """
    # The JAX key is seeded from NumPy's global generator, so weights change
    # between calls unless NumPy itself has been seeded.
    seed = np.random.randint(0, 1e4)
    rng_key = jax.random.PRNGKey(seed)

    return {
        "weights": jax.random.uniform(rng_key, (n_layer, n_qubits, 3), minval=0, maxval=1),
        "bias": jnp.zeros(Y.shape[1]),
    }

In [None]:
def KPCA(data, comps, kernel, inverse=True):
    """Reduce `data` to `comps` components with scikit-learn's KernelPCA.

    Returns (reduced, reconstructed); `reconstructed` is the inverse transform
    of the reduced data when `inverse` is true, otherwise None.
    """
    model = KernelPCA(n_components=comps, kernel=kernel, fit_inverse_transform=inverse)
    reduced = model.fit_transform(data)

    # The inverse map is only fitted when `inverse` is true.
    reconstructed = model.inverse_transform(reduced) if inverse else None
    return reduced, reconstructed

In [None]:
def pca_fg_filter(data_cube, nfg=5):
    '''
    PCA foreground filter: remove the `nfg` dominant principal components.

    The covariance between the ROWS of the data is eigendecomposed, the `nfg`
    eigenvectors with the largest eigenvalues (the modes assumed to be dominated
    by foreground emission) are projected out, and the residual is returned.

    Parameters
    ----------

    data_cube: 2-D array-like of shape [nrow, ncol]. np.cov treats each row as one
          variable, so for a frequency-covariance filter the rows must be the
          frequency channels ([nfreq, npix]). A 3-D image cube [freq, ra, dec]
          has to be reshaped to [nfreq, npix] by the caller first.
    nfg : int,
          number of eigenmodes to remove from the data. Default: nfg=5

    Returns
    -------
    residual_cube: np.ndarray with the same shape as the input; the data minus
          its projection onto the `nfg` leading eigenmodes.

    Raises
    ------
    ValueError: if nfg is negative.
    '''
    if nfg < 0:
        raise ValueError(f"nfg must be non-negative; got {nfg}")

    data = np.asarray(data_cube)

    # Row-by-row covariance; atleast_2d keeps a single-row input usable by eigh.
    cov = np.atleast_2d(np.cov(data))
    eigvals, eigvecs = np.linalg.eigh(cov)

    # eigh returns ascending eigenvalues; reorder so the largest comes first.
    order = np.argsort(eigvals)[::-1]

    # Design matrix of the removed modes, shape (nrow, nfg).
    A = eigvecs[:, order[:nfg]]

    # Amplitude of each removed mode in every column of the data, shape (nfg, ncol).
    FG_modes = np.dot(A.T, data)

    # Foreground estimate: projection of the data onto the removed modes.
    FG_map = np.dot(A, FG_modes)

    return data - FG_map

# **Parameter Estimation**

## **Training**

In [None]:
# Define a function for quantum model training
def Qmodel(X, Y, n_qubits, n_layer, opt, epoch, batch_size, printing=True, Foreground=False):

    """
    Train the quantum neural network on (X, Y) and evaluate it on a held-out test set.

    Inputs:
    X -> Input data
    Y -> Output data
    n_qubits -> No. of qubits required
    n_layer -> No. of circuit layers required
    opt -> Optimizer used (optax gradient transformation)

    epoch -> Total number of training steps; each step uses one random mini-batch
    batch_size -> Number of training samples drawn (with replacement) per step
    printing (bool) -> if True, print the per-step losses and the timings
    Foreground (bool) -> if True, will predict foreground parameters provided that the input data consists the foreground
                         (one PauliZ expectation value per column of Y instead of one per qubit pair)

    Returns:
    params -> Optimised Weights
    train_loss -> Training Loss function
    val_loss -> Validation Loss function
    test_loss -> Loss between predicted and test data
    Y_pre -> Predicted output
    """

    # Define the quantum neural network
    def QNN(params, inputs):
        weights = params["weights"]
        bias = params["bias"]
        dev = qml.device("default.qubit", wires=n_qubits)

        @qml.qnode(dev, interface="jax", diff_method='backprop')
        def circ():
            # Quantum circuit construction
            qml.AmplitudeEmbedding(features=inputs, wires=range(n_qubits), normalize=True, pad_with=0.5)
            # qml.QFT(wires=range(n_qubits))
            StronglyEntanglingLayers(weights, wires=range(n_qubits), imprimitive=qml.ops.CNOT)
            if Foreground:
                # Expectation value of each of the first Y.shape[1] qubits
                return [qml.expval(qml.PauliZ(i)) for i in range(Y.shape[1])]
            # Expectation value of neighbouring qubit pairs (0,1), (2,3), ...
            # NOTE(review): the number of pairs must match Y.shape[1], because the
            # bias below has one entry per column of Y.
            return [qml.expval(qml.PauliZ(i) + qml.PauliZ(i + 1)) for i in range(0, n_qubits - 1, 2)]

        # Adding bias
        circ_out = circ()
        return jnp.array([circ_out[i] + bias[i] for i in range(len(circ_out))])

    # Define the loss: squared errors summed over all outputs, divided by the number of samples
    def mse(observed, predictions):
        loss = jnp.sum((observed - predictions) ** 2 / len(observed))
        return jnp.mean(loss)

    # Define a function for making predictions
    def predict(params, features):
        preds = QNN(params, features)

        preds_np = jnp.asarray(preds).T
        return preds_np

    # Vectorized prediction function (maps over the leading sample axis)
    batched_predict = jax.vmap(predict, in_axes=(None, 0,))

    # Cost function - the un-jitted version is used once for the final test loss
    def cost(params, features, observed):
        preds = batched_predict(params, features)
        return mse(observed, preds)

    # JIT-compiled version of the same cost - used during training and validation
    jit_cost = jax.jit(cost)

    # Update step function - gradient descent optimization
    @jax.jit
    def update_step(params, opt_state, features, observed):
        """
        Optimization: one optimizer update on a mini-batch.
        """

        train_cost, grads = jax.value_and_grad(jit_cost)(params, features, observed)
        updates, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state, train_cost

    # Function for fitting the model - This is where the training takes place
    def fit(params, opt_state, x_train, y_train, x_val, y_val, epoch, batch_size):
        """
        Model Training

        Returns:
        params -> Optimised Weights and bias
        Train loss -> Training loss function
        Val loss -> Validation loss function
        """
        train_loss = []
        val_loss = []
        num_train = len(x_train)
        key_t = random.PRNGKey(np.random.randint(0, 1e4))

        # The whole validation set is evaluated every step, so convert it once.
        x_val_full = jnp.asarray(x_val)
        y_val_full = jnp.asarray(y_val)

        for i in range(epoch):
            # Advance the key, then draw this step's mini-batch indices.
            key_t, _ = random.split(key_t)
            idx_train = random.choice(key_t, num_train, shape=(batch_size,))

            x_train_batch = jnp.asarray(x_train[idx_train])
            y_train_batch = jnp.asarray(y_train[idx_train])

            start_time = time.time()
            params, opt_state, train_cost = update_step(params, opt_state, x_train_batch, y_train_batch)
            end_time = time.time()

            val_cost = jit_cost(params, x_val_full, y_val_full)
            epoch_time = end_time - start_time
            if printing:
                print("Epoch: {:5d} | Loss: {:0.7f} | Val_Loss: {:0.7f} | Time: {:0.4f} seconds".format(i + 1, train_cost, val_cost, epoch_time))

            train_loss.append(train_cost)
            val_loss.append(val_cost)

        return params, train_loss, val_loss

    # Splitting data into train, validation, and test sets
    x_train, x_val, x_test, y_train, y_val, y_test = data_split(X, Y, 0.2)

    # Initializing parameters
    params = params_init(n_layer, n_qubits, Y)

    # Initializing optimizer state
    opt_state = opt.init(params)

    # Fitting the model
    st = time.time()
    params, train_loss, val_loss = fit(params, opt_state, x_train, y_train, x_val, y_val, epoch, batch_size)
    tr_t = time.time() - st
    if printing:
        print(f" Total Training time = {tr_t // 60} minutes {round(tr_t % 60)} seconds.")

    # Calculating test loss
    test_loss = cost(params, x_test, y_test)
    # Making predictions on test data (the whole test batch is passed in one call)
    st = time.time()
    Y_pre = predict(params, x_test)
    te_t = time.time() - st
    if printing:
        # Report the prediction time itself (the seconds used to come from the training time).
        print(f" Prediction time = {te_t // 60} minutes {round(te_t % 60)} seconds.")
    out = params, train_loss, val_loss, test_loss, Y_pre
    return out

### **Signal only**

In [None]:
# Rescale every input feature to [-pi, pi]; clip=True keeps later transforms inside that range.
scaler_x = MinMaxScaler(feature_range=(-np.pi,np.pi),clip=True)
X = scaler_x.fit_transform(X_data_nf)

In [None]:
#Normalization
# Scale each target column to [0, 1]; scaler_y is reused later to undo the scaling of predictions.
scaler_y = MinMaxScaler()
Y = scaler_y.fit_transform(Y_data_nf)

In [None]:
# Reduce each curve to 128 RBF kernel-PCA components; X_recon is the back-projection used to check the loss of information.
X_red,X_recon= KPCA(X,comps=128,kernel='rbf',inverse=True)

In [None]:
# Plot the reduced representation of every curve (one line per row of X_red).
for reduced_curve in X_red[:N]:
    plt.plot(reduced_curve)

In [None]:
# Compare one scaled input curve (blue) with its kernel-PCA reconstruction (orange).
j = 999
plt.plot(X[j,:],color='blue')
plt.plot(X_recon[j,:],color='orange')

In [None]:
# Mean squared difference between the scaled inputs and their reconstruction.
reconstruction_error_original = np.mean(np.square(X - X_recon))
print(f"Reconstruction Error - Original Data: {reconstruction_error_original}")

In [None]:
#initial params for the circuit and data split
# Amplitude embedding stores 2**n_qubits amplitudes, hence log2 of the feature count.
n_qubits = int(np.log2(X_red.shape[1]))
n_layer = 40
# NOTE(review): lr_sched is created but never passed to the optimizer; adam runs at a constant 0.01.
lr_sched = optax.exponential_decay(0.01,1000,0.96)
opt = optax.adam(learning_rate=0.01)
# Same call (and fixed seeds) as the split performed inside Qmodel, so y_test here
# lines up row-for-row with the predictions Qmodel returns.
x_train,x_val,x_test,y_train,y_val,y_test = data_split(X_red,Y,0.2)

In [None]:
#QNN training
# 800 optimisation steps, each on a random mini-batch of 512 training curves.
params,train_loss,val_loss, test_loss,Y_pre = Qmodel(X_red,Y,n_qubits,n_layer,opt,epoch=800,batch_size=512)

In [None]:
# Loss on the held-out test set (scaled targets).
test_loss

In [None]:
fig, ax = plt.subplots()

# Moving-average window for the loss curves; N_p = 1 leaves them unsmoothed.
N_p = 1
smoothing_window = np.ones(N_p)/N_p
ax.plot(np.convolve(train_loss, smoothing_window), color="red", label= 'Training Loss')
ax.plot(np.convolve(val_loss, smoothing_window), color="green", label= 'Validation loss')
ax.legend()
ax.grid()

ax.set_xlabel("Epoch", fontsize=14)
ax.set_ylabel("MSE Loss", fontsize=14)
ax.tick_params(axis='both', labelsize=12)
# ax.set_yscale('log')
# plt.savefig("/home/somnath/akash/QNN21cm/Codes/Plots/Signal_lossfn_2.pdf")

In [None]:
# Compute the output
# Undo the target scaling so predictions and test targets are in original units.
Y_pred = scaler_y.inverse_transform(Y_pre)
Y_test = scaler_y.inverse_transform(y_test)


# All scores below are computed on the SCALED values (y_test, Y_pre), not on Y_test / Y_pred.
r2score= r2_score(y_test, Y_pre, multioutput='uniform_average')
fs_r2score= r2_score(y_test[:,0], Y_pre[:,0])
print("Star_Formation_efficiency_R2score=", fs_r2score)

fx_r2score= r2_score(y_test[:,1], Y_pre[:,1])
print("X-ray_uncertainty_R2score=", fx_r2score)

Nlw_r2score= r2_score(y_test[:,2], Y_pre[:,2])
print("Ly-photon_num_R2score=", Nlw_r2score)


#RMSE Score
# One mean squared error per target column; the RMSE is its square root.
mse = mean_squared_error(y_test, Y_pre, multioutput='raw_values')
fs_rmse= sqrt(mse[0])
Nlw_rmse= sqrt(mse[2])
fx_rmse= sqrt(mse[1])

print("Star_Formation_efficiency_RMSE=",fs_rmse)
print("X-ray_uncertainty_RMSE=",fx_rmse)
print("Ly-photon_num_RMSE=",Nlw_rmse)



In [None]:
# R2 score on the un-scaled (original-unit) targets and predictions.
r2_score(Y_test,Y_pred)

In [None]:
#Saving weights and biases of the model

# import os
# directory_path = '/home/akash/QNN21cm/Codes'

# # Specify the filename (you can change 'my_array.csv' to your desired filename)
# file_name_w = 'Weight_JAX.csv'
# file_name_b = 'bias_JAX.csv'

# # Concatenate the directory path and the filename to get the full path
# full_path_w = os.path.join(directory_path, file_name_w)
# full_path_b = os.path.join(directory_path, file_name_b)

# # Save the NumPy array as a CSV file
# np.savetxt(full_path_w, params["weights"].reshape(n_layer*n_qubits*3), delimiter=',')
# # Save the NumPy array as a CSV file
# np.savetxt(full_path_b, params["bias"], delimiter=',')

In [None]:
# Saving the prediction output

import os
# NOTE(review): absolute, machine-specific directory (and note the doubled slash).
directory_path = '/home//akash/QNN21cm/Codes'

# Output file names for the predictions and the matching test targets.
file_name1 = 'Prediction_21_JAX.csv'
file_name2 = 'Test_21_JAX.csv'

# Full paths inside the output directory.
full_path1 = os.path.join(directory_path, file_name1)
full_path2 = os.path.join(directory_path, file_name2)

# Write both arrays as comma-separated text, predictions first.
for csv_path, csv_array in ((full_path1, Y_pred), (full_path2, Y_test)):
    np.savetxt(csv_path, csv_array, delimiter=',')

In [None]:
# One column per parameter: predicted (..p) and actual (..t) values.
fsp, fxp, Nlwp = Y_pred.T
fst, fxt, Nlwt = Y_test.T
title = ['Star Formation Efficiency, f$_{star}$',
         'X-Ray Efficiency, f$_{X}$',
         'Lyman-'r'$\alpha$ Photon Number, N$_{lw}$ ($ \times 10^{5}$)']
fig, axe = plt.subplots(1, 3, figsize=(17, 5))
# Shared axis labels placed on the figure.
fig.text(0.5, 0.01, 'Actual Values', ha='center',fontsize=18)
fig.text(0.08, 0.5, 'Predicted Values', va='center', rotation='vertical',fontsize=18)

# (actual, predicted) per panel; the third parameter is shown in units of 1e5.
panel_data = [(fst, fsp), (fxt, fxp), (Nlwt/1e5, Nlwp/1e5)]
for panel, (actual, predicted) in zip(axe, panel_data):
    sns.regplot(x=actual, y=predicted, ax=panel, ci=99, marker=".", color=".3", line_kws=dict(color="r"))

for panel, heading in zip(axe[:Y_test.shape[1]], title):
    panel.set_title(heading, fontsize=16)
    panel.grid()
    panel.tick_params(axis='both', labelsize=13)

plt.savefig('/home/akash/QNN21cm/Codes/Plots/preds_sonly.pdf')

### **Signal+Foreground**

#### **Constant Contamination**

In [None]:
# Use X_cfgn or X_fgn according to your need
# Rescale every feature of the contaminated inputs to [-pi, pi].
scaler_x_m = MinMaxScaler(feature_range=(-np.pi,np.pi))
X_f = scaler_x_m.fit_transform(X_cfgn)


In [None]:
# Scale the signal+foreground targets (signal parameters plus the four foreground
# coefficients) with their OWN scaler.
# A separate name is used on purpose: rebinding `scaler_y` here would replace the
# scaler fitted on the signal-only targets, while the models below are still trained
# on `Y` (not `Y_f`) and their predictions are later passed to
# `scaler_y.inverse_transform`, which would then fail on the column-count mismatch.
scaler_y_f = MinMaxScaler()
Y_f = scaler_y_f.fit_transform(Y_data_f)

In [None]:
# Reduce the scaled contaminated curves to 128 RBF kernel-PCA components (with reconstruction).
X_red, X_recon =KPCA(X_f,comps=128,kernel='rbf')

In [None]:
# Plot the reduced representation of the first 10000 curves (one line per row).
for reduced_curve in X_red[:10000]:
    plt.plot(reduced_curve)

In [None]:
# Compare one scaled contaminated curve with its kernel-PCA reconstruction.
j=999
plt.plot(X_f[j,:])
plt.plot(X_recon[j,:])

In [None]:
# Mean squared difference between the scaled contaminated inputs and their reconstruction.
reconstruction_error_original = np.mean(np.square(X_f - X_recon))
print(f"Reconstruction Error - Original Data: {reconstruction_error_original}")

In [None]:
# Shape of the reduced inputs: (curves, components).
X_red.shape

In [None]:
# Amplitude embedding stores 2**n_qubits amplitudes, hence log2 of the feature count.
n_qubits = int(np.log2(X_red.shape[1]))
n_layer = 40
# NOTE(review): lr_sched is created but never passed to the optimizer; adam runs at a constant 0.01.
lr_sched = optax.exponential_decay(0.01,1000,0.96)
opt = optax.adam(learning_rate=0.01)
# The targets are `Y` from the signal-only section (signal parameters only), not `Y_f`.
x_train,x_val,x_test,y_train,y_val,y_test = data_split(X_red,Y,0.2)

In [None]:
# Train on the constant-foreground inputs: 1000 steps with mini-batches of 512.
params,train_loss,val_loss, test_loss,Y_pre = Qmodel(X_red,Y,n_qubits,n_layer,opt,epoch=1000,batch_size=512)

In [None]:
# Loss on the held-out test set (scaled targets).
test_loss

In [None]:
fig, ax = plt.subplots()

# Moving-average window for the loss curves; N_p = 1 leaves them unsmoothed.
N_p = 1
smoothing_window = np.ones(N_p)/N_p
ax.plot(np.convolve(train_loss, smoothing_window), color="red", label= 'Training Loss')
ax.plot(np.convolve(val_loss, smoothing_window), color="green", label= 'Validation loss')
ax.legend()
ax.grid()
ax.set_yscale('log')

ax.set_xlabel("Epoch", fontsize=14)
ax.set_ylabel("MSE Loss", fontsize=14)
ax.tick_params(axis='both', labelsize=12)
# plt.savefig("/home/somnath/akash/QNN21cm/Codes/Plots/Sfgn_lossfn_2.pdf")

In [None]:
# Compute the output
# Undo the target scaling so predictions and test targets are in original units.
Y_pred = scaler_y.inverse_transform(Y_pre)
Y_test = scaler_y.inverse_transform(y_test)


# All scores below are computed on the SCALED values (y_test, Y_pre), not on Y_test / Y_pred.
r2score= r2_score(y_test, Y_pre, multioutput='uniform_average')
fs_r2score= r2_score(y_test[:,0], Y_pre[:,0])
print("Star_Formation_efficiency_R2score=", fs_r2score)

fx_r2score= r2_score(y_test[:,1], Y_pre[:,1])
print("X-ray_uncertainty_R2score=", fx_r2score)

Nlw_r2score= r2_score(y_test[:,2], Y_pre[:,2])
print("Ly-photon_num_R2score=", Nlw_r2score)

print("tot_r2=",r2_score(y_test,Y_pre))

# a0_r2score= r2_score(y_test[:,3], Y_pre[:,3])
# print("a0_R2score=",a0_r2score )

# a1_r2score= r2_score(y_test[:,4], Y_pre[:,4])
# print("a1_R2score=",a1_r2score)

# a2_r2score= r2_score(y_test[:,5], Y_pre[:,5])
# print("a2_R2score=",a2_r2score )

# a3_r2score= r2_score(y_test[:,6], Y_pre[:,6])
# print("a3_R2score=",a3_r2score)

from math import sqrt
#RMSE Score
# One mean squared error per target column; the RMSE is its square root.
mse = mean_squared_error(y_test, Y_pre, multioutput='raw_values')
fs_rmse= sqrt(mse[0])
Nlw_rmse= sqrt(mse[2])
fx_rmse= sqrt(mse[1])
# "tot_rmse" is the plain average of the three per-parameter RMSE values.
mean = np.mean([fs_rmse,fx_rmse,Nlw_rmse])
print("Star_Formation_efficiency_RMSE=",fs_rmse)
print("X-ray_uncertainty_RMSE=",fx_rmse)
print("Ly-photon_num_RMSE=",Nlw_rmse)
print("tot_rmse=",mean )

In [None]:
# Overall R2 score on the scaled targets and predictions.
r2_score(y_test,Y_pre)

In [None]:
# NOTE(review): same expression as the previous cell (duplicate).
r2_score(y_test,Y_pre)

In [None]:
# # fsp, Nlwp,Nionp,a0p,a1p,a2p,a3p = Y_pred.T
# # fst, Nlwt,Niont,a0t,a1t,a2t,a3t= Y_test.T
# One column per parameter: predicted (..p) and actual (..t) values.
fsp, fxp, Nlwp = Y_pred.T
fst, fxt, Nlwt = Y_test.T
title = ['Star Formation Efficiency, f$_{star}$',
         'X-Ray Efficiency, f$_{X}$',
         'Lyman-'r'$\alpha$ Photon Number, N$_{lw}$ ($ \times 10^{5}$)']
fig, axe = plt.subplots(1, 3, figsize=(17, 5))
# Shared axis labels placed on the figure.
fig.text(0.5, 0.01, 'Actual Values', ha='center',fontsize=18)
fig.text(0.08, 0.5, 'Predicted Values', va='center', rotation='vertical',fontsize=18)

# (actual, predicted) per panel; the third parameter is shown in units of 1e5.
panel_data = [(fst, fsp), (fxt, fxp), (Nlwt/1e5, Nlwp/1e5)]
for panel, (actual, predicted) in zip(axe, panel_data):
    sns.regplot(x=actual, y=predicted, ax=panel, ci=95, marker=".", color=".3", line_kws=dict(color="r"))

for panel, heading in zip(axe[:Y_test.shape[1]], title):
    panel.set_title(heading, fontsize=16)
    panel.grid()
    panel.tick_params(axis='both', labelsize=12)

# plt.savefig('/home/akash/QNN21cm/Codes/Plots/preds_sfgnc.pdf')

In [None]:
import os
# Export the unscaled predictions and the matching test targets as CSV files.
# NOTE(review): the output directory is an absolute, machine-specific path.
directory_path = '/home/akash/QNN21cm/Codes'
file_name1 = 'Prediction_21cfgn_JAX.csv'
file_name2 = 'Test_21cfgn_JAX.csv'

# Build both destinations first, then write the two arrays.
full_path1 = os.path.join(directory_path, file_name1)
full_path2 = os.path.join(directory_path, file_name2)

np.savetxt(full_path1, Y_pred, delimiter=',')
np.savetxt(full_path2, Y_test, delimiter=',')

#### **Varying Contamination**

In [None]:
# Run the notebook's pca_fg_filter helper on X_fgn with nfg=4.
# NOTE(review): nfg is presumably the number of foreground components to
# remove — confirm against the helper's definition.
res = pca_fg_filter(X_fgn,nfg=4)

In [None]:
# Overlay the first N rows of the filtered data against the frequency grid f.
for i in range(N):
    plt.plot(f, res[i])

In [None]:
# Rescale every feature of the filtered data into [-pi, pi]; clip=True keeps
# transformed values inside that interval.
scaler_x_m = MinMaxScaler(feature_range=(-np.pi, np.pi), clip=True)
X_f = scaler_x_m.fit(res).transform(res)

In [None]:
# Reduce the scaled inputs with the notebook's KPCA helper (128 components,
# RBF kernel). Returns the reduced features and a second array.
# NOTE(review): X_recon is presumably the reconstruction from the reduced
# space — confirm against the helper's definition.
X_red,X_recon = KPCA(X_f,comps=128,kernel='rbf')

In [None]:
# Model / optimizer settings for this run.
n_layer = 90
# Qubit count is log2 of the number of reduced feature columns.
n_qubits = int(np.log2(X_red.shape[1]))

# NOTE(review): lr_sched is built but never passed to the optimizer, which
# uses a constant learning rate of 0.01 — confirm whether that is intended.
lr_sched = optax.exponential_decay(init_value=0.01, transition_steps=1000, decay_rate=0.96)
opt = optax.adam(learning_rate=0.01)

x_train, x_val, x_test, y_train, y_val, y_test = data_split(X_red, Y, 0.2)

In [None]:
# Train the model on the reduced features with the settings from the previous
# cell (1000 epochs, batch size 512).
# NOTE(review): Qmodel is given the full X_red / Y rather than the split made
# above; later cells compare Y_pre against that y_test, which presumably relies
# on Qmodel splitting the data the same way — confirm.
params,train_loss,val_loss, test_loss,Y_pre = Qmodel(X_red,Y,n_qubits,n_layer,opt,epoch=1000,batch_size=512)

In [None]:
# Display the test loss returned by Qmodel.
test_loss

In [None]:
# Training and validation loss curves on a logarithmic y-axis.
fig, ax = plt.subplots()
# Moving-average window length; 1 leaves the curves unsmoothed.
N_p = 1
# mode='valid' keeps only fully overlapping windows, so increasing N_p does not
# add zero-padded edge points (the default 'full' mode would, which shows up as
# spurious dips at both ends of a log-scale plot). Identical output for N_p = 1.
ax.plot(np.convolve(train_loss, np.ones(N_p)/N_p, mode='valid'), color="red", label= 'Training Loss')
ax.plot(np.convolve(val_loss, np.ones(N_p)/N_p, mode='valid'), color="green", label= 'Validation loss')
ax.legend()
ax.grid()
ax.set_xlabel("Epoch",fontsize=14)
ax.set_ylabel("MSE Loss",fontsize=14)
ax.set_yscale('log')
ax.tick_params(axis='x', labelsize=12)
ax.tick_params(axis='y', labelsize=12)

In [None]:
# Undo the target scaling so predictions and test targets can be plotted and
# exported in the original target space.
Y_pred = scaler_y.inverse_transform(Y_pre)
Y_test = scaler_y.inverse_transform(y_test)

# R^2 is evaluated on the scaled arrays (y_test vs. Y_pre): once overall and
# once per output column (f_star, f_X, N_lw).
r2score = r2_score(y_test, Y_pre, multioutput='uniform_average')
fs_r2score, fx_r2score, Nlw_r2score = (
    r2_score(y_test[:, k], Y_pre[:, k]) for k in range(3)
)

print("Star_Formation_efficiency_R2score=", fs_r2score)
print("X-ray_uncertainty_R2score=", fx_r2score)
print("Ly-photon_num_R2score=", Nlw_r2score)
# Uniform averaging is r2_score's default, so this equals r2_score(y_test, Y_pre).
print("tot_r2=", r2score)


from math import sqrt
# Per-target RMSE, taken from the column-wise MSE of Y_pre against y_test
# (column order: f_star, f_X, N_lw).
mse = mean_squared_error(y_test, Y_pre, multioutput='raw_values')
fs_rmse, fx_rmse, Nlw_rmse = sqrt(mse[0]), sqrt(mse[1]), sqrt(mse[2])
# "tot_rmse" is the unweighted mean of the three per-target RMSE values.
mean = np.mean([fs_rmse, fx_rmse, Nlw_rmse])
print("Star_Formation_efficiency_RMSE=", fs_rmse)
print("X-ray_uncertainty_RMSE=", fx_rmse)
print("Ly-photon_num_RMSE=", Nlw_rmse)
print("tot_rmse=", mean)

In [None]:
# Predicted-vs-actual regression panels, one per output parameter.
# Columns of Y_pred / Y_test are unpacked in the order f_star, f_X, N_lw.
fsp, fxp, Nlwp = Y_pred.T
fst, fxt, Nlwt = Y_test.T
title = ['Star Formation Efficiency, f$_{star}$',
         'X-Ray Efficiency, f$_{X}$',
         'Lyman-'r'$\alpha$ Photon Number, N$_{lw}$ ($ \times 10^{5}$)']

fig, axe = plt.subplots(1, 3, figsize=(17, 5))
# Figure-level captions shared by all three panels.
fig.text(0.5, 0.01, 'Actual Values', ha='center', fontsize=18)
fig.text(0.08, 0.5, 'Predicted Values', va='center', rotation='vertical', fontsize=18)

# N_lw is divided by 1e5 so the axis matches the "x 10^5" in its panel title.
panel_actual = [fst, fxt, Nlwt/1e5]
panel_predicted = [fsp, fxp, Nlwp/1e5]

for i in range(Y_test.shape[1]):
    sns.regplot(x=panel_actual[i], y=panel_predicted[i], ax=axe[i], ci=95,
                marker=".", color=".3", line_kws=dict(color="r"))
    axe[i].set_title(title[i], fontsize=16)
    axe[i].grid()
    axe[i].tick_params(axis='both', labelsize=12)

# plt.savefig('/home/akash/QNN21cm/Codes/Plots/preds_sfgnv.pdf')

In [None]:
import os
# Export the unscaled predictions and the matching test targets as CSV files.
# NOTE(review): the output directory is an absolute, machine-specific path.
directory_path = '/home/akash/QNN21cm/Codes'
file_name1 = 'Prediction_21vfgn_JAX.csv'
file_name2 = 'Test_21vfgn_JAX.csv'

# Build both destinations first, then write the two arrays.
full_path1 = os.path.join(directory_path, file_name1)
full_path2 = os.path.join(directory_path, file_name2)

np.savetxt(full_path1, Y_pred, delimiter=',')
np.savetxt(full_path2, Y_test, delimiter=',')

### **Generalization with a smaller training set**

In [None]:
# Model / optimizer settings used by the dataset-size study below.
n_layer = 40
# Qubit count is log2 of the number of reduced feature columns.
n_qubits = int(np.log2(X_red.shape[1]))

# NOTE(review): lr_sched is built but never passed to the optimizer, which
# uses a constant learning rate of 0.01 — confirm whether that is intended.
lr_sched = optax.exponential_decay(init_value=0.01, transition_steps=1000, decay_rate=0.96)
opt = optax.adam(learning_rate=0.01)

x_train, x_val, x_test, y_train, y_val, y_test = data_split(X_red, Y, 0.2)

In [None]:
def multi_train(num,epoch,batch):
    """Train the model on a random subset of the data and summarise its quality.

    Uses the notebook-level X_red, Y, n_qubits, n_layer and opt.

    Args:
        num: Number of distinct samples drawn from X_red / Y.
        epoch: Number of epochs passed to Qmodel.
        batch: Batch size passed to Qmodel.

    Returns:
        Tuple (r_score, mean, gen_err):
        r_score -- R^2 of the predictions against the test targets,
                   uniformly averaged over the outputs.
        mean    -- unweighted mean of the three per-target RMSE values.
        gen_err -- mean validation loss minus mean training loss, each taken
                   over the last 20 entries of the loss histories.

    Raises:
        ValueError: if num exceeds the number of rows in X_red.
    """
    # np.random.seed(43)
    # Draw the subset WITHOUT replacement: repeated rows could land on both
    # sides of the train/test split, and the effective dataset would be
    # smaller than `num`. The population size is read from the data rather
    # than hard-coded.
    idx = np.random.choice(X_red.shape[0], size=num, replace=False)
    x_sub, y_sub = X_red[idx], Y[idx]

    # Only the test targets are needed from this split.
    # NOTE(review): y_test lines up with Y_pre only if Qmodel splits the same
    # data in the same way internally — confirm against Qmodel / data_split.
    _, _, _, _, _, y_test = data_split(x_sub, y_sub, test_size=0.2)
    params, train_loss, val_loss, test_loss, Y_pre = Qmodel(
        x_sub, y_sub, n_qubits, n_layer, opt,
        epoch=epoch, batch_size=batch, printing=False)

    # Generalization gap from the tails of the two loss histories.
    final_train_loss = np.array(train_loss)[-20:].mean()
    final_val_loss = np.array(val_loss)[-20:].mean()
    gen_err = final_val_loss - final_train_loss

    # r2_score expects (y_true, y_pred); R^2 is not symmetric in its arguments.
    r_score = r2_score(y_test, Y_pre, multioutput='uniform_average')
    print(r_score)

    mse = mean_squared_error(y_test, Y_pre, multioutput='raw_values')
    fs_rmse = sqrt(mse[0])
    fx_rmse = sqrt(mse[1])
    Nlw_rmse = sqrt(mse[2])
    mean = np.mean([fs_rmse, fx_rmse, Nlw_rmse])
    print(mean)
    return r_score, mean, gen_err

In [None]:
# One training run per dataset size; epochs and batch size grow with the size.
r1, m1, gen1 = multi_train(num=100, epoch=200, batch=32)
r2, m2, gen2 = multi_train(num=500, epoch=200, batch=64)
r3, m3, gen3 = multi_train(num=1000, epoch=200, batch=128)
r4, m4, gen4 = multi_train(num=3000, epoch=200, batch=256)
r5, m5, gen5 = multi_train(num=7000, epoch=500, batch=512)
r6, m6, gen6 = multi_train(num=10000, epoch=500, batch=512)

In [None]:
# Dataset sizes, in the same order as the multi_train runs (keep in sync).
dat_size = np.array([100, 500, 1000, 3000, 7000, 10000])
# One array per metric, each ordered like dat_size.
r_err, m_err, gen_err = (
    np.array(values)
    for values in (
        [r1, r2, r3, r4, r5, r6],
        [m1, m2, m3, m4, m5, m6],
        [gen1, gen2, gen3, gen4, gen5, gen6],
    )
)

In [None]:
# Rows of err: R^2, RMSE, generalization error (one column per dataset size).
err = np.array([r_err, m_err, gen_err])
label = [r'R$^2$ Score','RMSE', 'Generalization Error']
# Same captions for the y-axes, except that the third panel gets none.
y_ax = label[:2] + [None]

In [None]:
# One panel per metric, plotted against the dataset size.
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
for i in range(3):
    ax[i].plot(dat_size, err[i], marker='o', color='green', label=label[i])
    ax[i].set_xlabel('Training Set Size')
    ax[i].set_ylabel(y_ax[i])
    ax[i].legend()
    ax[i].grid()

In [None]:
# Output location for the dataset-size study results.
# NOTE(review): the directory is an absolute, machine-specific path.
directory_path = '/home/akash/QNN21cm/Codes'
file_name1 = 'Gen_Err_3.csv'

# Full destination path = directory + file name.
full_path1 = os.path.join(directory_path, file_name1)

# Write the stacked metrics array `err` as comma-separated text.
np.savetxt(full_path1, err, delimiter=',')

# **Benchmarking**

In [None]:
# Bar chart of hard-coded training times (in minutes) per device.
# NOTE(review): `fruits` holds the device names; a more descriptive name would help.
fruits = ['A100 80GB', 'A30 24GB', 'GTX 1070', 'CPU']
t_all = np.array([3,4.5,6.75,19])
bar_colors = ['tab:green', 'tab:blue', 'tab:red', 'tab:orange']

fig, ax = plt.subplots()
ax.bar(fruits, t_all, color=bar_colors)
ax.set_title('Benchmarking')
ax.set_ylabel('JQNN Model Training Time (in minutes)')
ax.grid()