In [1]:
import torch
from ml4gw.transforms import ChannelWiseScaler

from train.architectures import Autoencoder
from train.data import DeepCleanDataset
from train.metrics import PsdRatio, OfflinePsdRatio
from train.model import DeepClean
from utils.filt import BandpassFilter

### DeepCleanDataset

In [2]:
# data args
# Every name below is a module-level setting that is forwarded, under the
# same keyword, to the DeepCleanDataset constructor call that follows.

# --- input file ---
# NOTE(review): absolute, user-specific path; consider making it configurable.
fname = "/home/chiajui.chou/deepclean/data/CDC_test-120Hz/deepclean-1378402219-3072.hdf5"

# --- channel names ---
# NOTE(review): 35 names are listed while the X_scaler statistics printed in
# this notebook have 34 rows; presumably one entry is the target channel.
# TODO confirm which one against DeepCleanDataset.
channels = [
    "H1:PEM-CS_ACC_PSL_TABLE1_Z_DQ",
    "H1:HPI-HAM4_BLND_L4C_Y_IN1_DQ",
    "H1:IMC-WFS_A_Q_PIT_OUT_DQ",
    "H1:IMC-WFS_B_Q_PIT_OUT_DQ",
    "H1:IMC-DOF_1_Y_IN1_DQ",
    "H1:PEM-CS_ACC_PSL_PERISCOPE_X_DQ",
    "H1:IMC-DOF_4_Y_IN1_DQ",
    "H1:HPI-HAM6_BLND_L4C_RX_IN1_DQ",
    "H1:IMC-WFS_A_DC_YAW_OUT_DQ",
    "H1:IMC-WFS_B_I_PIT_OUT_DQ",
    "H1:IMC-WFS_B_Q_YAW_OUT_DQ",
    "H1:IMC-WFS_A_Q_YAW_OUT_DQ",
    "H1:PSL-PMC_HV_MON_OUT_DQ",
    "H1:LSC-REFL_SERVO_ERR_OUT_DQ",
    "H1:IMC-WFS_B_I_YAW_OUT_DQ",
    "H1:PEM-CS_ACC_HAM2_PRM_Z_DQ",
    "H1:PEM-CS_ACC_PSL_TABLE1_X_DQ",
    "H1:LSC-MCL_IN1_DQ",
    "H1:IMC-WFS_A_DC_PIT_OUT_DQ",
    "H1:IMC-F_OUT_DQ",
    "H1:LSC-REFL_SERVO_CTRL_OUT_DQ",
    "H1:PSL-PMC_MIXER_OUT_DQ",
    "H1:IMC-WFS_B_DC_YAW_OUT_DQ",
    "H1:PEM-CS_ACC_PSL_PERISCOPE_Y_DQ",
    "H1:IMC-DOF_2_Y_IN1_DQ",
    "H1:LSC-MCL_OUT_DQ",
    "H1:HPI-HAM1_BLND_L4C_RX_IN1_DQ",
    "H1:HPI-HAM6_BLND_L4C_VP_IN1_DQ",
    "H1:IMC-DOF_2_P_IN1_DQ",
    "H1:IMC-L_OUT_DQ",
    "H1:PEM-CS_ACC_LVEAFLOOR_BS_Z_DQ",
    "H1:PEM-CS_ACC_HAM2_PRM_Y_DQ",
    "H1:PEM-CS_ACC_LVEAFLOOR_HAM1_X_DQ",
    "H1:HPI-HAM2_BLND_L4C_RY_IN1_DQ",
    "H1:HPI-HAM1_BLND_L4C_VP_IN1_DQ",
]

# --- frequency band and filter order (one-element lists for the band edges) ---
freq_low, freq_high = [110], [130]
filt_order = 8

# --- durations and split fraction ---
train_duration = 1024
test_duration = 2048
valid_frac = 0.1
start_offset = 0

# --- kernel / stride settings ---
kernel_length = 8
train_stride = 0.25
clean_kernel_length = 8
clean_stride = 4
inference_sampling_rate = 2

# --- batching ---
batch_size = 32

# Gather the constructor arguments in one mapping so the full configuration
# can be inspected before the dataset object is built; each entry is passed
# to DeepCleanDataset under the same keyword name.
dataset_kwargs = dict(
    # input file and channels
    fname=fname,
    channels=channels,
    # band edges and filter order
    freq_low=freq_low,
    freq_high=freq_high,
    filt_order=filt_order,
    # durations / split
    train_duration=train_duration,
    test_duration=test_duration,
    valid_frac=valid_frac,
    start_offset=start_offset,
    # kernels and strides
    kernel_length=kernel_length,
    train_stride=train_stride,
    clean_kernel_length=clean_kernel_length,
    clean_stride=clean_stride,
    inference_sampling_rate=inference_sampling_rate,
    # batching
    batch_size=batch_size,
)
dc_dataset = DeepCleanDataset(**dataset_kwargs)

In [3]:
# Display the dataset's `hparams`; the output below lists entries such as
# "batch_size" and "channels".
dc_dataset.hparams

"batch_size":              32
"channels":                ['H1:PEM-CS_ACC_PSL_TABLE1_Z_DQ', 'H1:HPI-HAM4_BLND_L4C_Y_IN1_DQ', 'H1:IMC-WFS_A_Q_PIT_OUT_DQ', 'H1:IMC-WFS_B_Q_PIT_OUT_DQ', 'H1:IMC-DOF_1_Y_IN1_DQ', 'H1:PEM-CS_ACC_PSL_PERISCOPE_X_DQ', 'H1:IMC-DOF_4_Y_IN1_DQ', 'H1:HPI-HAM6_BLND_L4C_RX_IN1_DQ', 'H1:IMC-WFS_A_DC_YAW_OUT_DQ', 'H1:IMC-WFS_B_I_PIT_OUT_DQ', 'H1:IMC-WFS_B_Q_YAW_OUT_DQ', 'H1:IMC-WFS_A_Q_YAW_OUT_DQ', 'H1:PSL-PMC_HV_MON_OUT_DQ', 'H1:LSC-REFL_SERVO_ERR_OUT_DQ', 'H1:IMC-WFS_B_I_YAW_OUT_DQ', 'H1:PEM-CS_ACC_HAM2_PRM_Z_DQ', 'H1:PEM-CS_ACC_PSL_TABLE1_X_DQ', 'H1:LSC-MCL_IN1_DQ', 'H1:IMC-WFS_A_DC_PIT_OUT_DQ', 'H1:IMC-F_OUT_DQ', 'H1:LSC-REFL_SERVO_CTRL_OUT_DQ', 'H1:PSL-PMC_MIXER_OUT_DQ', 'H1:IMC-WFS_B_DC_YAW_OUT_DQ', 'H1:PEM-CS_ACC_PSL_PERISCOPE_Y_DQ', 'H1:IMC-DOF_2_Y_IN1_DQ', 'H1:LSC-MCL_OUT_DQ', 'H1:HPI-HAM1_BLND_L4C_RX_IN1_DQ', 'H1:HPI-HAM6_BLND_L4C_VP_IN1_DQ', 'H1:IMC-DOF_2_P_IN1_DQ', 'H1:IMC-L_OUT_DQ', 'H1:PEM-CS_ACC_LVEAFLOOR_BS_Z_DQ', 'H1:PEM-CS_ACC_HAM2_PRM_Y_DQ', 'H1:PEM-

In [None]:
# Show the two scaler objects held by the dataset. Plain attribute access
# replaces getattr() with a constant name, which is equivalent but harder to
# read and to grep for.
print(dc_dataset.X_scaler)
print(dc_dataset.y_scaler)

ChannelWiseScaler()
ChannelWiseScaler()


In [None]:
# Call setup() without a stage argument; the validation and test sections
# further down call it again with stage='fit' and stage='test'.
dc_dataset.setup()

ChannelWiseScaler()
ChannelWiseScaler()


In [17]:
# Inspect the 'mean' / 'std' entries stored in each scaler's state_dict.
# state_dict() is called once per scaler and the result reused for each print.
y_scaler = dc_dataset.y_scaler
y_scaler_state = y_scaler.state_dict()
print(y_scaler_state.keys())
print(y_scaler_state['mean'])
print(y_scaler_state['std'])

# For the X scaler only the shapes are printed, not the values.
X_scaler = dc_dataset.X_scaler
X_scaler_state = X_scaler.state_dict()
print(X_scaler_state['mean'].shape)
print(X_scaler_state['std'].shape)

odict_keys(['mean', 'std'])
tensor([-4.1139e-05])
tensor([30.9883])
torch.Size([34, 1])
torch.Size([34, 1])


### DeepClean Model

In [4]:
# Assemble the DeepClean model from an architecture, a loss and a metric.
# `freq_low` / `freq_high` are the values defined in the data-args cell and
# are used as-is (the former `freq_low = freq_low` lines were no-ops).

# arch
num_witnesses = dc_dataset.num_witnesses
hidden_channels = [8, 16, 32, 64]
arch = Autoencoder(
    num_witnesses=num_witnesses,
    hidden_channels=hidden_channels,
)

# loss
# sample_rate is read from the dataset once and shared by loss and metric
sample_rate = dc_dataset.sample_rate
fftlength = 2
asd = False
loss = PsdRatio(
    sample_rate=sample_rate,
    fftlength=fftlength,
    freq_low=freq_low,
    freq_high=freq_high,
    asd=asd,
)

# metric
# The cleaning settings, bandpass object and y scaler are taken from the
# dataset so the metric uses the same objects the dataset holds.
clean_kernel_length = dc_dataset.hparams.clean_kernel_length
clean_stride = dc_dataset.hparams.clean_stride
window = "hann"
bandpass = dc_dataset.bandpass
y_scaler = dc_dataset.y_scaler
metric = OfflinePsdRatio(
    sample_rate=sample_rate,
    clean_kernel_length=clean_kernel_length,
    clean_stride=clean_stride,
    window=window,
    bandpass=bandpass,
    y_scaler=y_scaler,
)

# patience / checkpointing
patience = 20
save_top_k_models = 10

# model
dc_model = DeepClean(
    arch=arch,
    loss=loss,
    metric=metric,
    patience=patience,
    save_top_k_models=save_top_k_models,
)

### Validation loop

In [5]:
# Prepare the 'fit' stage, then print the shapes of the first two entries of
# every batch yielded by the validation dataloader.
dc_dataset.setup(stage='fit')
val_dataloader = dc_dataset.val_dataloader()
for batch in val_dataloader:
    for position in (0, 1):
        print(batch[position].shape)

torch.Size([24, 34, 32768])
torch.Size([102, 1, 4096])


In [6]:
# Run the validation steps by hand, outside a Trainer. Lightning's Trainer
# switches the model to eval mode and disables gradients around validation;
# do the same here so the forward passes build no autograd graphs and
# train-mode layer behaviour is not used.
# NOTE(review): assumes DeepClean is a LightningModule / nn.Module (has .eval()).
dc_model.metric.reset()
dc_model.eval()
with torch.no_grad():
    for batch in val_dataloader:
        # second positional argument is passed as None, as before
        # (presumably batch_idx, which does not exist outside a Trainer)
        dc_model.validation_step(batch, None)

In [7]:
# After the validation steps above, metric.clean() returns two values,
# unpacked here; their shapes are printed in the next cell.
predicted_noise, raw_strain = dc_model.metric.clean()

In [8]:
# Print the shape of each of the two validation tensors, in order.
for tensor in (predicted_noise, raw_strain):
    print(tensor.shape)

torch.Size([1, 409600])
torch.Size([1, 409600])


### Test loop

In [9]:
# Prepare the 'test' stage and walk the test dataloader, printing shapes.
# Either of the first two batch entries may be None: a missing first entry
# is skipped silently, a missing second entry is reported as "None".
dc_dataset.setup(stage='test')
test_dataloader = dc_dataset.test_dataloader()
for batch in test_dataloader:
    if batch[0] is not None:
        print(batch[0].shape)
    print("None" if batch[1] is None else batch[1].shape)

torch.Size([128, 34, 32768])
torch.Size([128, 1, 4096])
torch.Size([128, 34, 32768])
torch.Size([128, 1, 4096])
torch.Size([128, 34, 32768])
torch.Size([128, 1, 4096])
torch.Size([127, 34, 32768])
torch.Size([128, 1, 4096])
torch.Size([128, 1, 4096])
torch.Size([128, 1, 4096])
torch.Size([128, 1, 4096])
torch.Size([128, 1, 4096])
torch.Size([128, 1, 4096])
torch.Size([128, 1, 4096])
torch.Size([128, 1, 4096])
torch.Size([128, 1, 4096])
torch.Size([128, 1, 4096])
torch.Size([128, 1, 4096])
torch.Size([128, 1, 4096])
torch.Size([128, 1, 4096])


In [10]:
# Run the test steps by hand, outside a Trainer. Lightning's Trainer switches
# the model to eval mode and disables gradients around testing; do the same
# here so the forward passes build no autograd graphs and train-mode layer
# behaviour is not used.
# NOTE(review): assumes DeepClean is a LightningModule / nn.Module (has .eval()).
dc_model.metric.reset()
dc_model.eval()
with torch.no_grad():
    for batch in test_dataloader:
        # second positional argument is passed as None, as before
        # (presumably batch_idx, which does not exist outside a Trainer)
        dc_model.test_step(batch, None)

In [11]:
# After the test steps above, metric.clean() returns two values, unpacked
# here; their shapes are printed in the next cell.
predict, raw = dc_model.metric.clean()

In [12]:
# Print the shape of each of the two test tensors, in order.
for tensor in (predict, raw):
    print(tensor.shape)

torch.Size([1, 8388608])
torch.Size([1, 8388608])
