In [28]:
import re
import os
import numpy as np
import csv
import random

In [29]:
def extract_snr(filename):
    match = re.search(r"SNR=(\d+\.\d{1,3})", filename)
    if match:
        return float(match.group(1))
    else:
        return 0


# SNR=5.00892423193341_m1=15_m2=11_17464_time=0.999931060125.png


def extract_mass(filename):
    m1 = re.search(r"m1=(.\d{1,2})", filename)
    m2 = re.search(r"m2=(.\d{1,2})", filename)
    if m1:
        return int(m1.group(1)), int(m2.group(1))
    else:
        return 0, 0


def extract_time(filename):
    time = re.search(r"time=(\d+\.\d{1,6})", filename)
    if time:
        return float(time.group(1))
    else:
        return 0


# Function to load labels from filenames
def load_labels_from_directory(directory):
    snr_array = []
    m1_array = []
    m2_array = []
    time_array = []
    filenames = []
    for root, _, files in os.walk(directory):
        for filename in files:
            if filename.endswith(".png"):
                try:
                    snr = extract_snr(filename)
                    m1, m2 = extract_mass(filename)
                    time = extract_time(filename)

                except ValueError:
                    snr = 0  # Label for noise
                    m1 = 0
                    m2 = 0
                    time = 0

                snr_array.append(snr)
                m1_array.append(m1)
                m2_array.append(m2)
                time_array.append(time)

                filenames.append(os.path.join(root, filename))
    return (
        np.array(snr_array),
        np.array(m1_array),
        np.array(m2_array),
        np.array(time_array),
        filenames,
    )


print(
    load_labels_from_directory(
        "/home/arush/GW_Project_1/Data_Generation/Continous_Check/Data"
    )
)

(array([ 5.958, 19.188, 19.242, ..., 19.594, 17.524,  8.25 ]), array([11, 13, 19, ..., 22, 20, 13]), array([18, 12, 26, ..., 10, 26, 25]), array([0.800164, 0.999492, 0.797136, ..., 1.000357, 0.223374, 0.998576]), ['/home/arush/GW_Project_1/Data_Generation/Continous_Check/Data/SNR=5.958721395863849_m1=11_m2=18_798_time=0.8001642333750001.png', '/home/arush/GW_Project_1/Data_Generation/Continous_Check/Data/SNR=19.18818511219659_m1=13_m2=12_22999_time=0.9994921635.png', '/home/arush/GW_Project_1/Data_Generation/Continous_Check/Data/SNR=19.242508177757244_m1=19_m2=26_21683_time=0.797136732125.png', '/home/arush/GW_Project_1/Data_Generation/Continous_Check/Data/SNR=10.76515758354318_m1=28_m2=10_8560_time=0.20934385774999997.png', '/home/arush/GW_Project_1/Data_Generation/Continous_Check/Data/SNR=8.99933228139127_m1=22_m2=21_4167_time=0.999397876625.png', '/home/arush/GW_Project_1/Data_Generation/Continous_Check/Data/SNR=15.40098351942467_m1=21_m2=16_22748_time=0.22239024637500004.png', '/ho

In [30]:
snr_label, m1_label, m2_label, time_label, filenames_data = load_labels_from_directory(
    "/home/arush/GW_Project_1/Data_Generation/Continous_Check/Data"
)

(
    snr_label_noise,
    m1_label_noise,
    m2_label_noise,
    time_label_noise,
    filenames_data_noise,
) = load_labels_from_directory(
    "/home/arush/GW_Project_1/Data_Generation/Continous_Check/noise"
)


SNR_labels = np.concatenate((snr_label, snr_label_noise))
m1_labels = np.concatenate((m1_label, m1_label_noise))
m2_labels = np.concatenate((m2_label, m2_label_noise))
time_labels = np.concatenate((time_label, time_label_noise))
filenames = np.concatenate((filenames_data, filenames_data_noise))
combined = list(zip(SNR_labels, m1_labels, m2_labels, time_labels, filenames))

random.seed(42)
random.shuffle(combined)

SNR_labels, m1_labels, m2_labels, time_labels, filenames = zip(*combined)

SNR_labels = np.array(SNR_labels)
m1_labels = np.array(m1_labels)
m2_labels = np.array(m2_labels)
time_labels = np.array(time_labels)
filenames = np.array(filenames)

print("SNR lengths: ", len(SNR_labels))
print("m1 lengths: ", len(m1_labels))
print("m2 lengths: ", len(m2_labels))
print("time lengths: ", len(time_labels))
print("filenames lengths: ", len(filenames))

SNR lengths:  39100
m1 lengths:  39100
m2 lengths:  39100
time lengths:  39100
filenames lengths:  39100


In [31]:
# function to find the number of files in train, test and validation set
train_len = int(0.8 * len(SNR_labels))
test_len = int(0.1 * len(SNR_labels))
val_len = len(SNR_labels) - train_len - test_len

print("training dataset length: ", train_len)
print("testing dataset length: ", test_len)
print("validation dataset length: ", val_len)

training dataset length:  31280
testing dataset length:  3910
validation dataset length:  3910


In [33]:
i = 0
with open(
    "/home/arush/GW_Project_1/Data_Generation/Continous_Check/cont_data_train.csv",
    "w",
    newline="",
) as file:
    writer = csv.writer(file)
    writer.writerow(["SNR", "M1", "M2", "TIME", "Path"])
    while i < train_len:
        writer.writerow(
            [SNR_labels[i], m1_labels[i], m2_labels[i], time_labels[i], filenames[i]]
        )
        i += 1

with open(
    "/home/arush/GW_Project_1/Data_Generation/Continous_Check/cont_data_test.csv",
    "w",
    newline="",
) as file:
    writer = csv.writer(file)
    writer.writerow(["SNR", "M1", "M2", "TIME", "Path"])
    while i < train_len + test_len:
        writer.writerow(
            [SNR_labels[i], m1_labels[i], m2_labels[i], time_labels[i], filenames[i]]
        )
        i += 1

with open(
    "/home/arush/GW_Project_1/Data_Generation/Continous_Check/cont_data_val.csv",
    "w",
    newline="",
) as file:
    writer = csv.writer(file)
    writer.writerow(["SNR", "M1", "M2", "TIME", "Path"])
    while i < (len(SNR_labels)):
        writer.writerow(
            [SNR_labels[i], m1_labels[i], m2_labels[i], time_labels[i], filenames[i]]
        )
        i += 1