In [3]:
import os
import subprocess
import sys
from queue import Queue
from threading import Thread

from deepclean.couplings import subtraction_problems

### Environment variables and config input

In [16]:
# --- Site-specific locations -------------------------------------------------
HOME = "/home/chiajui.chou"
DEEPCLEAN_IFO = "H1"
DEEPCLEAN_PROBLEM = "180Hz"
DATA_DIR = f"{HOME}/deepclean/data/CDC_test-{DEEPCLEAN_PROBLEM}"
RESULTS_DIR = f"{HOME}/deepclean/results"

# HDF5 file holding the strain and witness timeseries used for training.
data_fname = f"{DATA_DIR}/deepclean-1378402219-3072.hdf5"

# --- Subtraction problem -----------------------------------------------------
# Each entry of `problem` is looked up in `subtraction_problems` for the chosen
# interferometer; the resulting coupling objects provide the witness channels
# and the frequency band limits.
ifo = DEEPCLEAN_IFO
strain_channel = f"{ifo}:GDS-CALIB_STRAIN"
problem = [DEEPCLEAN_PROBLEM]
couplings = [subtraction_problems[name][ifo] for name in problem]

witnesses = []
freq_low = []
freq_high = []
for coupling in couplings:
    witnesses.extend(coupling.channels)
    freq_low.append(coupling.freq_low)
    freq_high.append(coupling.freq_high)

# The strain channel always comes first, followed by every witness channel.
channels = [strain_channel, *witnesses]

# --- Training ----------------------------------------------------------------
config = f"{HOME}/deepcleanv2/projects/train/config.yaml"
output_dir = f"{RESULTS_DIR}/O4-CDC_{DEEPCLEAN_PROBLEM}_test"
GPU_INDEX = 0
# Index of the lightning_logs/version_<n> directory read by the cleaning cells.
version = 0

### stream_command

In [11]:
def read_stream(stream, process, q):
    """Forward every line of one pipe of ``process`` onto ``q`` as text.

    Parameters
    ----------
    stream : str
        Name of the attribute on ``process`` that holds the pipe to read
        (``stream_process`` passes ``"stdout"`` and ``"stderr"``).
    process : object
        Object exposing the pipe as a binary file-like attribute with a
        ``readline()`` method.
    q : queue.Queue
        Destination queue. Each line is put as a ``str``; a single ``None``
        sentinel is always put last, even if reading fails, so that the
        consumer never blocks forever.
    """
    pipe = getattr(process, stream)
    try:
        # readline() returns b"" only at end-of-file, which ends the loop.
        for line in iter(pipe.readline, b""):
            # errors="replace": a stray non-UTF-8 byte must not kill this
            # reader. If it did, the pipe would stop being drained and the
            # child could block on a full pipe buffer.
            q.put(line.decode(errors="replace"))
    finally:
        q.put(None)

def stream_process(process):
    """Echo the stdout and stderr of ``process`` to ``sys.stdout`` as they arrive.

    One reader thread per pipe runs ``read_stream`` and pushes decoded lines
    onto a shared queue; this function drains the queue and writes every line
    (from either pipe) to ``sys.stdout``. It returns once both readers have
    signalled end-of-stream and have been joined.

    Parameters
    ----------
    process : object
        Object exposing ``stdout`` and ``stderr`` binary pipes, e.g. a
        ``subprocess.Popen`` created with ``stdout=PIPE, stderr=PIPE``.
    """
    q = Queue()
    streams = ["stdout", "stderr"]
    threads = [
        Thread(target=read_stream, args=(name, process, q)) for name in streams
    ]
    for t in threads:
        t.start()

    # Every reader puts exactly one None sentinel when its pipe is exhausted,
    # so drain the queue once per reader instead of a hard-coded count.
    for _ in threads:
        for line in iter(q.get, None):
            sys.stdout.write(line)

    # All sentinels have been seen; wait for the readers to finish so no
    # thread outlives this call.
    for t in threads:
        t.join()

def stream_command(command: list[str], check: bool = False) -> int:
    """Run ``command`` as a subprocess, echoing its output live.

    The child inherits ``os.environ``; its stdout and stderr are piped and
    forwarded by ``stream_process``.

    Parameters
    ----------
    command : list[str]
        Program and arguments, passed to ``subprocess.Popen`` unchanged.
    check : bool, optional
        If True, raise ``subprocess.CalledProcessError`` when the child exits
        with a non-zero status. Defaults to False, in which case the status
        is only returned.

    Returns
    -------
    int
        Exit status of the child process.

    Raises
    ------
    subprocess.CalledProcessError
        Only when ``check`` is True and the exit status is non-zero.
    """
    # The context manager closes both pipes and waits for the child on exit,
    # so the process is always reaped and `returncode` is populated.
    with subprocess.Popen(
        command, stdout=subprocess.PIPE,
        stderr=subprocess.PIPE, env=os.environ
    ) as process:
        stream_process(process)

    returncode = process.returncode
    if check and returncode != 0:
        raise subprocess.CalledProcessError(returncode, command)
    return returncode

### Run training

In [12]:
# Restrict CUDA to the selected GPU. The training subprocess is started with
# env=os.environ by stream_command, so it sees this setting as well.
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU_INDEX)

# Command line for the training entry point. List-valued options are passed
# as bracketed strings, e.g. "[176]" or "[H1:...,H1:...]".
command = [
    "python",
    "-m",
    "train",
    "--config",
    config,
    "--data.fname",
    data_fname,
    "--data.channels",
    "[" + ",".join(channels) + "]",
    "--data.freq_low",
    str(freq_low),
    "--data.freq_high",
    str(freq_high),
]
command.append(f"--trainer.logger.save_dir={output_dir}")
print(command)

# makedirs(exist_ok=True) also creates missing parent directories (e.g. a
# not-yet-existing RESULTS_DIR) and has no check-then-create race.
os.makedirs(output_dir, exist_ok=True)

stream_command(command)

['python', '-m', 'train', '--config', '/home/chiajui.chou/deepcleanv2/projects/train/config.yaml', '--data.fname', '/home/chiajui.chou/deepclean/data/CDC_test-180Hz/deepclean-1378402219-3072.hdf5', '--data.channels', '[H1:GDS-CALIB_STRAIN,H1:PEM-CS_MAG_LVEA_OUTPUTOPTICS_Z_DQ,H1:LSC-REFL_A_LF_OUT_DQ,H1:IMC-F_OUT_DQ,H1:IMC-WFS_B_Q_YAW_OUT_DQ,H1:IMC-WFS_A_Q_YAW_OUT_DQ,H1:PEM-EX_MAG_VEA_FLOOR_X_DQ,H1:PEM-CS_ACC_PSL_PERISCOPE_Y_DQ,H1:ISI-ITMY_ST2_BLND_RZ_GS13_CUR_IN1_DQ,H1:IMC-WFS_A_DC_YAW_OUT_DQ,H1:IMC-WFS_B_DC_YAW_OUT_DQ,H1:ISI-HAM6_BLND_GS13Z_IN1_DQ,H1:PEM-EY_VMON_ETMY_ESDPOWERMINUS18_DQ,H1:ISI-HAM2_BLND_GS13RZ_IN1_DQ,H1:IMC-L_OUT_DQ,H1:PEM-CS_ACC_PSL_PERISCOPE_X_DQ,H1:PEM-CS_MAG_EBAY_LSCRACK_Z_DQ,H1:IMC-DOF_4_Y_IN1_DQ,H1:IMC-WFS_B_I_YAW_OUT_DQ]', '--data.freq_low', '[176]', '--data.freq_high', '[184]', '--trainer.logger.save_dir=/home/chiajui.chou/deepclean/results/O4-CDC_180Hz_test']
Seed set to 101588
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint 

### Clean

In [27]:
import glob
import yaml
import numpy as np
import scipy.signal as sig
from gwpy.timeseries import TimeSeriesDict

import torch
from ml4gw.transforms import ChannelWiseScaler
from ml4gw.dataloading import InMemoryDataset
from utils.filt import BandpassFilter

### Config input

In [33]:
# Model loading
sample_rate = 4096
# `version` (set in the training configuration cell) selects which
# lightning_logs/version_<n> run directory is loaded.
train_dir = f"{output_dir}/lightning_logs/version_{version}"
with open(f"{train_dir}/config.yaml") as file:
    train_config = yaml.safe_load(file)

data_dir = DATA_DIR
# glob returns matches in arbitrary order; sort for a deterministic file list.
data_source = sorted(glob.glob(f"{data_dir}/*.hdf5"))
# NOTE(review): the training cell sets CUDA_VISIBLE_DEVICES=str(GPU_INDEX) in
# this process. If CUDA is initialised after that, the selected GPU is
# presumably visible as device 0, so GPU_INDEX != 0 may be an invalid ordinal
# here. Confirm before changing GPU_INDEX.
device = GPU_INDEX

# Data loading and preprocessing (start and duration are in seconds)
test_start = 1024
test_duration = 2048
clean_kernel_size = 8 * sample_rate
clean_stride = 4 * sample_rate
clean_batch_size = 128

# Inference
window = "hann"

# Clean and write
channel = f"{strain_channel}_DC"
# NOTE: `format` shadows the builtin of the same name; the name is kept
# because the writing cell reads it.
format = "hdf5"
# Same location as before, expressed through HOME instead of a second
# hard-coded copy of the home directory.
out_dir = f"{HOME}/deepclean/dcprod-test"
out_label = f"CDC_test-{DEEPCLEAN_PROBLEM}_dcprod_v{version}"

### Load model

In [22]:
# Load trained model and scalers
model_path = f"{train_dir}/model.pt"
model = torch.jit.load(model_path).to(device)
# Put dropout / batch-norm style layers (if any) into inference mode.
model.eval()

# Channel names are taken from the training run's config so that inference
# uses exactly the channels the model was trained with. This overwrites the
# `strain_channel` / `witnesses` defined in the training configuration cell.
strain_channel = train_config['data']['channels'][0]
# NOTE(review): witnesses are sorted here; presumably this matches the channel
# order used by the training data module. TODO confirm.
witnesses = sorted(train_config['data']['channels'][1:])
num_witnesses = len(witnesses)

# The scalers are applied to CPU tensors in the following cells, so their
# state is mapped to the CPU regardless of the device it was saved from.
X_scaler = ChannelWiseScaler(num_channels=num_witnesses)
X_scaler_path = f"{train_dir}/X_scaler.pt"
X_scaler.load_state_dict(torch.load(X_scaler_path, map_location="cpu"))

y_scaler = ChannelWiseScaler()
y_scaler_path = f"{train_dir}/y_scaler.pt"
y_scaler.load_state_dict(torch.load(y_scaler_path, map_location="cpu"))

# bandpass filter, rebuilt from the band limits and order stored in the
# training config (overwrites the earlier `freq_low` / `freq_high` lists)
freq_low = train_config['data']['freq_low']
freq_high = train_config['data']['freq_high']
filt_order = int(train_config['data']['filt_order'])
bandpass = BandpassFilter(freq_low, freq_high, sample_rate, filt_order)

### Load data

In [26]:
# Load data
source = data_source
data = TimeSeriesDict.read(source)

# Preprocess: select the inference segment as a sample-index slice
infer_start = int(test_start * sample_rate)
infer_size = int(test_duration * sample_rate)
idx = slice(infer_start, infer_start + infer_size)

infer_y = torch.Tensor(data[strain_channel][idx])
infer_X = torch.zeros((num_witnesses, infer_size))
# The loop variable is deliberately not called `channel`: that global holds
# the output channel name used when the cleaned strain is written, and
# reusing it here would silently replace it with the last witness name.
for i, witness in enumerate(witnesses):
    infer_X[i] = torch.Tensor(data[witness][idx])

# Normalise the witnesses with the scaler loaded from the training run.
infer_X = X_scaler(infer_X)

# Inference dataloaders: identical kernel / stride / batch settings and no
# shuffling, so X and y batches stay aligned.
X_inference = InMemoryDataset(
    infer_X,
    kernel_size=clean_kernel_size,
    stride=clean_stride,
    batch_size=clean_batch_size,
    coincident=True,
    shuffle=False,
    device=device,
)
y_inference = InMemoryDataset(
    infer_y,
    kernel_size=clean_kernel_size,
    stride=clean_stride,
    batch_size=clean_batch_size,
    coincident=True,
    shuffle=False,
    device=device,
)

### Inference

In [29]:
# Inference
# Run the model over every batch of witness kernels. Gradients are not needed,
# so autograd is disabled to avoid building a graph for each batch. Only the
# witness batches are consumed; the strain batches are not used by the model.
prediction = []
with torch.no_grad():
    for X in X_inference:
        pred = model(X)
        prediction.append(pred.cpu().double().numpy())

prediction = np.concatenate(prediction)

# Aggregating predictions (deepclean-prod)
N = prediction.shape[0]

# hanning window function, scaled by stride / kernel size
# NOTE(review): with the default periodic Hann window and kernel = 2 * stride,
# the scaled windows sum to 0.5 (not 1) where two kernels overlap. Confirm
# this is the scaling intended by the reference implementation.
window_fn = sig.get_window(window, clean_kernel_size) * clean_stride / clean_kernel_size

# Concatenate timeseries: overlap-add each windowed kernel at its stride offset
nsamp = int((N - 1) * clean_stride + clean_kernel_size)
y_pred = np.zeros(nsamp)
for i in range(N):
    idx = slice(i*clean_stride, i*clean_stride + clean_kernel_size)
    y_pred[idx] += prediction[i]*window_fn

# Postprocess: reversed normalize, bandpass
noise = torch.Tensor(y_pred, device="cpu")
noise = y_scaler(noise.double(), reverse=True)
noise = bandpass(noise.detach().numpy())

In [30]:
# Overlap-add aggregation of the per-kernel predictions (deepclean-prod scheme).
# NOTE(review): this cell repeats the aggregation and post-processing done at
# the end of the preceding inference cell; presumably it is kept so that this
# step can be re-run without repeating the model forward passes.
N = prediction.shape[0]

# Hann taper for one kernel, scaled by stride / kernel size.
window_fn = sig.get_window(window, clean_kernel_size) * clean_stride / clean_kernel_size

# Output buffer long enough to hold N kernels placed `clean_stride` apart.
nsamp = int(clean_kernel_size + clean_stride * (N - 1))
y_pred = np.zeros(nsamp)
for i, segment in enumerate(prediction):
    start = i * clean_stride
    idx = slice(start, start + clean_kernel_size)
    y_pred[idx] += segment * window_fn

# Undo the target normalisation, then band-pass the predicted noise.
noise = y_scaler(torch.Tensor(y_pred, device="cpu").double(), reverse=True)
noise = bandpass(noise.detach().numpy())

### Clean and write cleaned data

In [34]:
# Clean
from gwpy.timeseries import TimeSeries

# Subtract the predicted noise from the strain of the inference segment.
raw = infer_y.cpu().double().detach().numpy()
cleaned = raw - noise

# Re-derive the output channel name here so that it cannot be affected by any
# earlier reuse of the name `channel`.
channel = f"{strain_channel}_DC"

# The cleaned samples cover only the inference segment, which starts
# `test_start` seconds after the beginning of the loaded data. Label both
# series with that segment's start time and duration so they stay aligned.
t0 = data[strain_channel].t0.value + test_start
duration = len(cleaned) / sample_rate

cleaned_ts = TimeSeries(
    cleaned,
    t0=t0,
    sample_rate=sample_rate,
    channel=channel,
)
# Original strain over the same segment, stored alongside for comparison.
raw_ts = TimeSeries(
    data[strain_channel].crop(t0, t0 + duration),
    t0=t0,
    sample_rate=sample_rate,
    channel=strain_channel,
)

# Save cleaned and raw strain to a single file in the configured `format`.
os.makedirs(out_dir, exist_ok=True)
output_file = f"{out_dir}/{out_label}-{int(t0)}-{int(duration)}.{format}"
ts_dict = TimeSeriesDict()
ts_dict[channel] = cleaned_ts
ts_dict[strain_channel] = raw_ts
ts_dict.write(output_file, format=format)