In [23]:
import requests
from tqdm import tqdm
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from pycbc.waveform import get_td_waveform
from gwpy.timeseries import TimeSeries
from scipy.signal import spectrogram, welch, butter, filtfilt, stft
from scipy.fftpack import fft, ifft
from scipy.signal.windows import tukey
import torch.nn.functional as F
from tqdm import tqdm
from astropy.cosmology import z_at_value, Planck18 as cosmo
from astropy import units as u


def download_ET_noise_files(n_downloads=5, file_list='./MDC1_v2_noise_E1.txt', outdir="data"):
    '''
    Download file from url to outdir with a progress bar

    for Downloading ET noise data files 
    '''
    

    with open(file_list) as file:
        filenames = [line.rstrip() for line in file][300:]
    print(f'Number of downloadable files: {len(filenames)}')
    print(f'Starting download of {n_downloads} files...')

    os.makedirs(outdir, exist_ok=True)

    local_filenames = []
    for filename in filenames[:n_downloads]:
        local_filename = os.path.join(outdir, filename.split('/')[-1])
        
        # Skip if file already exists
        if os.path.exists(local_filename):
            local_filenames.append(local_filename)
            print(f"Skipping {local_filename}, already exists.")
            continue
        
        url = f'http://et-origin.cism.ucl.ac.be/{filename}'

        # Stream download so we don’t load the whole file into memory
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            total_size = int(r.headers.get("content-length", 0))
            block_size = 1024  # 1 KB
            progress = tqdm(total=total_size, unit="iB", unit_scale=True, desc=local_filename)
            
            with open(local_filename, "wb") as f:
                for chunk in r.iter_content(block_size):
                    f.write(chunk)
                    progress.update(len(chunk))
            progress.close()
        local_filenames.append(local_filename)
        
    return local_filenames

### DOWNLOAD ET NOISE FILES ###
raw_data_dir = './ET_noise/raw_waveforms/'
ET_noise_file_locations = download_ET_noise_files(n_downloads=10, outdir=raw_data_dir) # you will need to download more than 10, this is just to demonstrate


Number of downloadable files: 1000
Starting download of 10 files...
Skipping ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000614400-2048.gwf, already exists.
Skipping ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000616448-2048.gwf, already exists.
Skipping ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000618496-2048.gwf, already exists.
Skipping ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000620544-2048.gwf, already exists.
Skipping ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000622592-2048.gwf, already exists.
Skipping ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000624640-2048.gwf, already exists.
Skipping ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000626688-2048.gwf, already exists.
Skipping ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000628736-2048.gwf, already exists.
Skipping ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000630784-2048.gwf, already exists.
Skipping ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000632832-2048.gwf, already exists.


In [8]:
ET_noise_file_locations

['./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000614400-2048.gwf',
 './ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000616448-2048.gwf',
 './ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000618496-2048.gwf',
 './ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000620544-2048.gwf',
 './ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000622592-2048.gwf',
 './ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000624640-2048.gwf',
 './ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000626688-2048.gwf',
 './ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000628736-2048.gwf',
 './ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000630784-2048.gwf',
 './ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000632832-2048.gwf']

In [12]:

# whitening
def whiten(data, f_psd, psd, fs):
    psd = np.maximum(psd, 1e-20)
    freqs = np.fft.fftfreq(len(data), d=1/fs)
    psd_interp = np.interp(freqs, f_psd, psd)
    white_fft = fft(data) / np.sqrt(psd_interp)
    return np.real(ifft(white_fft))


def process_file(data_file, out_dir, train_test='train'):

    '''Process a single ET noise data file: read, cut into 2s segments, whiten the noise, and save as tensors.'''

    print(f'Processing: {data_file}')
    try:
        strain = TimeSeries.read(data_file, channel=channel)
        f_psd, psd = welch(strain.value, fs=sample_rate, nperseg=sample_rate*4)

        num_segments = int((strain.times.value[-1] - strain.times.value[0]) // segment_duration)
        segments = []

        os.makedirs(f'{out_dir}/{train_test}/', exist_ok=True)
        #for i in range(num_segments):
        for i in tqdm(range(num_segments), desc=f'Processing {data_file}', unit='seg'):

            start = strain.times.value[0] + i * segment_duration
            end = start + segment_duration
            if end > strain.times.value[-1]:
                break

            segment = strain.crop(start, end).value
            segment = whiten(segment, f_psd, psd, sample_rate)
            tensor = torch.tensor(segment, dtype=torch.float32).unsqueeze(0)
            segments.append(tensor)

        batch_tensor = torch.cat(segments, dim=0)
        torch.save(batch_tensor, f'{out_dir}/{train_test}/' + data_file.split('/')[-1][:-4] + '.pt')

    except Exception as e:
        print(f"Failed to process {data_file}: {e}")
        return None

segmented_data_dir = './ET_noise/segmented_data'
channel = 'E1:STRAIN'
sample_rate = 8192
segment_duration = 2
segment_length = sample_rate * segment_duration

for file_loc in ET_noise_file_locations[:8]:
    process_file(file_loc, out_dir=segmented_data_dir, train_test='train')

for file_loc in ET_noise_file_locations[8:]:
    process_file(file_loc, out_dir=segmented_data_dir, train_test='test')

Processing: ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000614400-2048.gwf


Processing ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000614400-2048.gwf: 100%|██████████| 1023/1023 [00:07<00:00, 132.23seg/s]


Processing: ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000616448-2048.gwf


Processing ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000616448-2048.gwf: 100%|██████████| 1023/1023 [00:07<00:00, 139.90seg/s]


Processing: ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000618496-2048.gwf


Processing ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000618496-2048.gwf: 100%|██████████| 1023/1023 [00:07<00:00, 140.74seg/s]


Processing: ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000620544-2048.gwf


Processing ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000620544-2048.gwf: 100%|██████████| 1023/1023 [00:07<00:00, 140.83seg/s]


Processing: ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000622592-2048.gwf


Processing ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000622592-2048.gwf: 100%|██████████| 1023/1023 [00:07<00:00, 140.39seg/s]


Processing: ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000624640-2048.gwf


Processing ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000624640-2048.gwf: 100%|██████████| 1023/1023 [00:07<00:00, 137.39seg/s]


Processing: ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000626688-2048.gwf


Processing ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000626688-2048.gwf: 100%|██████████| 1023/1023 [00:07<00:00, 141.09seg/s]


Processing: ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000628736-2048.gwf


Processing ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000628736-2048.gwf: 100%|██████████| 1023/1023 [00:07<00:00, 141.80seg/s]


Processing: ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000630784-2048.gwf


Processing ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000630784-2048.gwf: 100%|██████████| 1023/1023 [00:07<00:00, 141.14seg/s]


Processing: ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000632832-2048.gwf


Processing ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000632832-2048.gwf: 100%|██████████| 1023/1023 [00:07<00:00, 141.42seg/s]


In [29]:
def inject_signal_into_noise(data_file, mass_range, distance_range, out_dir, train_test='train', channel = 'E1:STRAIN'):
    ''' Process raw waveforms into spectrograms with injected signals '''

    print(f'Processing: {data_file}')
    try:
        strain = TimeSeries.read(data_file, channel=channel)
        f_psd, psd = welch(strain.value, fs=sample_rate, nperseg=sample_rate*4)

        num_segments = int((strain.times.value[-1] - strain.times.value[0]) // segment_duration)
        segments = []
        injected_segments = []

        #for i in range(num_segments):
        for i in tqdm(range(num_segments), desc=f'Processing {data_file}', unit='seg'):

            start = strain.times.value[0] + i * segment_duration
            end = start + segment_duration
            if end > strain.times.value[-1]:
                break

            m1 = np.random.uniform(mass_range[0], mass_range[1])
            m2 = np.random.uniform(mass_range[0], m1)
            distance = np.random.uniform(distance_range[0], distance_range[1])

            lumi_distance = distance * u.Gpc
            z = z_at_value(cosmo.luminosity_distance, lumi_distance)
            m1_det = m1 * (1 + z)
            m2_det = m2 * (1 + z)

            segment = strain.crop(start, end).value
            injected_segment = segment.copy()
            hp, _ = get_td_waveform(
                    approximant="IMRPhenomHM",
                    mass1=m1,#m1_det, 
                    mass2=m2,#m2_det,
                    delta_t=1 / sample_rate,
                    f_lower=3,
                    distance=lumi_distance.to(u.Mpc).value,
                    spin1z=0.0,
                    spin2z=0.0,
                    eccentricity=0.0,
                    inclination=0.0
                )

            gw_tensor = torch.tensor(hp.numpy(), dtype=torch.float32)
            peak = torch.argmax(gw_tensor).item()
            new_peak = np.random.randint(int(0.2 * segment_length), int(0.8 * segment_length))
            shift = new_peak - peak
            gw_tensor = torch.roll(gw_tensor, shifts=shift)

            start_idx = max(0, (new_peak - segment_length) // 2)
            end_idx = start_idx + segment_length
            gw_tensor = gw_tensor[start_idx:end_idx]

            if len(gw_tensor) < segment_length:
                gw_tensor = F.pad(gw_tensor, (0, segment_length - len(gw_tensor)))

            #taper = tukey(segment_length, alpha=0.2)
            gw_tapered = gw_tensor.numpy() #* taper

            injected_segment += gw_tapered

            injected_segment = whiten(injected_segment, f_psd, psd, sample_rate)
            injected_tensor = torch.tensor(injected_segment, dtype=torch.float32).unsqueeze(0)

            injected_segments.append(injected_tensor)

        injected_batch_tensor = torch.cat(injected_segments, dim=0)
        os.makedirs(out_dir + f'/{train_test}/', exist_ok=True)
        torch.save(injected_batch_tensor, out_dir + f'/{train_test}/' + data_file.split('/')[-1][:-4] + '.pt')


    except Exception as e:
        print(f"Failed to process {data_file}: {e}")
        return None
    

channel = 'E1:STRAIN'
sample_rate = 8192
segment_duration = 2
segment_length = sample_rate * segment_duration
mass_range = [100, 200] # masses in solar mass
distance_range = [5, 20] # distance in Gpc
injected_data_dir = f'./ET_noise/injected_data/mass_{mass_range[0]}_{mass_range[1]}__distance_{distance_range[0]}_{distance_range[1]}Gpc'

for file_loc in ET_noise_file_locations[:8]:
    inject_signal_into_noise(file_loc, mass_range=mass_range, distance_range=distance_range, out_dir=injected_data_dir, train_test='train')

for file_loc in ET_noise_file_locations[8:]:
    inject_signal_into_noise(file_loc, mass_range=mass_range, distance_range=distance_range, out_dir=injected_data_dir, train_test='test')

Processing: ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000614400-2048.gwf


Processing: ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000614400-2048.gwf


Processing ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000614400-2048.gwf: 100%|██████████| 1023/1023 [03:55<00:00,  4.34seg/s]



Processing: ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000614400-2048.gwf


Processing ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000614400-2048.gwf: 100%|██████████| 1023/1023 [03:55<00:00,  4.34seg/s]



Processing: ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000616448-2048.gwf


Processing: ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000614400-2048.gwf


Processing ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000614400-2048.gwf: 100%|██████████| 1023/1023 [03:55<00:00,  4.34seg/s]



Processing: ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000616448-2048.gwf


Processing ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000616448-2048.gwf: 100%|██████████| 1023/1023 [03:59<00:00,  4.28seg/s]


Processing: ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000618496-2048.gwf


Processing ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000618496-2048.gwf: 100%|██████████| 1023/1023 [03:53<00:00,  4.37seg/s]


Processing: ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000620544-2048.gwf


Processing ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000620544-2048.gwf: 100%|██████████| 1023/1023 [03:53<00:00,  4.38seg/s]


Processing: ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000622592-2048.gwf


Processing ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000622592-2048.gwf: 100%|██████████| 1023/1023 [03:57<00:00,  4.31seg/s]


Processing: ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000624640-2048.gwf


Processing ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000624640-2048.gwf: 100%|██████████| 1023/1023 [03:52<00:00,  4.40seg/s]


Processing: ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000626688-2048.gwf


Processing ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000626688-2048.gwf: 100%|██████████| 1023/1023 [03:55<00:00,  4.35seg/s]


Processing: ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000628736-2048.gwf


Processing ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000628736-2048.gwf: 100%|██████████| 1023/1023 [03:55<00:00,  4.35seg/s]


Processing: ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000630784-2048.gwf


Processing ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000630784-2048.gwf: 100%|██████████| 1023/1023 [03:52<00:00,  4.39seg/s]


Processing: ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000632832-2048.gwf


Processing ./ET_noise/raw_waveforms/E-E1_STRAIN_NOISE-1000632832-2048.gwf: 100%|██████████| 1023/1023 [03:55<00:00,  4.35seg/s]


### Load in processed data and train model ###

In [30]:
def compute_spectrogram(waveform, nperseg=512, noverlap=256, sample_rate=8192):
    f, t, Zxx = stft(waveform, nperseg=nperseg, noverlap=noverlap, fs=sample_rate)
    spec = np.abs(Zxx)
    if spec.shape[1] > 4:
        spec = spec[:, 1:-1]
    return spec

def preprocess_waveforms_to_spectrograms(waveforms, sample_rate=8192, nperseg=512, noverlap=256, norm=True):
    processed = []
    for waveform in waveforms:
        spectrogram = compute_spectrogram(waveform.numpy(), sample_rate=sample_rate)
        if norm:
            min_val, max_val = np.min(spectrogram), np.max(spectrogram)
            spectrogram = (spectrogram - min_val) / (max_val - min_val)

        processed.append(spectrogram)
    return torch.tensor(np.array(processed), dtype=torch.float32)

def load_all_from_folder(folder_path):
    all_waveforms = []
    for filename in os.listdir(folder_path):
        if filename.endswith('.pt'):
            data = torch.load(os.path.join(folder_path, filename))
            waveforms = data.get('waveforms', data) if isinstance(data, dict) else data
            all_waveforms.append(waveforms)
    return torch.cat(all_waveforms)

class DeeperAutoencoderWithBottleneck(nn.Module):
    def __init__(self, d=0.08):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.drop1 = nn.Dropout(d)
        self.pool1 = nn.MaxPool2d(2, stride=2, return_indices=True)

        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.drop2 = nn.Dropout(d)
        self.pool2 = nn.MaxPool2d(2, stride=2, return_indices=True)

        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.drop3 = nn.Dropout(d)
        self.pool3 = nn.MaxPool2d(2, stride=2, return_indices=True)
        
        self.flatten = nn.Flatten()
        self.fc_enc = nn.Linear(128 * 32 * 3, 1024)
        self.fc_dec = nn.Linear(1024, 128 * 32 * 3)
        self.unflatten= nn.Unflatten(1, (128, 32, 3))
        
        self.unpool1 = nn.MaxUnpool2d(2, stride=2)
        self.deconv1 = nn.ConvTranspose2d(128, 64, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.drop4 = nn.Dropout(d)

        self.unpool2 = nn.MaxUnpool2d(2, stride=2)
        self.deconv2 = nn.ConvTranspose2d(64, 32, 3, padding=1)
        self.bn5 = nn.BatchNorm2d(32)
        self.drop5 = nn.Dropout(d)

        self.unpool3 = nn.MaxUnpool2d(2, stride=2)
        self.deconv3 = nn.ConvTranspose2d(32, 1, 3, padding=1)

    def forward(self, x):
        orig = x.size()
        x = self.conv1(x); x = self.bn1(x); x = self.drop1(x); x, i1 = self.pool1(x); s1 = x.size()
        x = self.conv2(x); x = self.bn2(x); x = self.drop2(x); x, i2 = self.pool2(x); s2 = x.size()
        x = self.conv3(x); x = self.bn3(x); x = self.drop3(x); x, i3 = self.pool3(x); s3 = x.size()
        x = self.unpool1(x, i3, output_size=s2); x = self.deconv1(x); x = self.bn4(x); x = self.drop4(x)
        x = self.unpool2(x, i2, output_size=s1); x = self.deconv2(x); x = self.bn5(x); x = self.drop5(x)
        x = self.unpool3(x, i1, output_size=orig); x = self.deconv3(x)
        return x


In [20]:
# --------------------------
# Load and Prepare Data
# --------------------------
noise_folder = "./ET_noise/segmented_data/train/" # path to processed ET noise spectrograms
raw_waveforms = torch.cat([torch.load(os.path.join(noise_folder, filename)) for filename in os.listdir(noise_folder)])
spectrograms = preprocess_waveforms_to_spectrograms(raw_waveforms, sample_rate=8192, nperseg=512, noverlap=256)
spectrograms = torch.nn.functional.interpolate(spectrograms.unsqueeze(1), size=(256, 31))

train_data, val_data = train_test_split(spectrograms, test_size=0.3, random_state=42)
print(f"Training with {train_data.shape[0]} training samples and {val_data.shape[0]} validation samples ")
train_loader = DataLoader(TensorDataset(train_data), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(val_data), batch_size=32, shuffle=True)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DeeperAutoencoderWithBottleneck(d=0.08).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=0.005
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5, threshold=0.005
)


Training with 5728 training samples and 2456 validation samples 


In [21]:
# --------------------------
# Training Loop
# --------------------------

num_epochs = 5
train_losses, val_losses = [], []
train_errors_all, val_errors_all = [], []

for epoch in range(num_epochs):
    model.train()
    epoch_train_loss, epoch_train_errors = 0.0, []

    for batch in train_loader:
        x = batch[0].to(device)
        optimizer.zero_grad()
        y = model(x)
        loss = criterion(y, x)
        loss.backward()
        optimizer.step()
        epoch_train_loss += loss.item()
        epoch_train_errors.extend(((y - x)**2).mean(dim=[1, 2, 3]).detach().cpu().numpy())

    train_losses.append(epoch_train_loss / len(train_loader))
    train_errors_all.extend(epoch_train_errors)

    model.eval()
    epoch_val_loss, epoch_val_errors = 0.0, []

    with torch.no_grad():
        for batch in val_loader:
            x = batch[0].to(device)
            y = model(x)
            loss = criterion(y, x)
            epoch_val_loss += loss.item()
            epoch_val_errors.extend(((y - x)**2).mean(dim=[1, 2, 3]).cpu().numpy())

    val_losses.append(epoch_val_loss / len(val_loader))
    val_errors_all.extend(epoch_val_errors)
    scheduler.step(epoch_val_loss)

    print(f"[Epoch {epoch+1}/{num_epochs}] Train Loss: {train_losses[-1]:.6f}, Val Loss: {val_losses[-1]:.6f}")


# Save Model
model_path = "./models/v3.0__unsupervised__ET_model.pt" #path to where the model will be saved
torch.save(model.state_dict(), model_path)
print(f"Model saved to {model_path}")

# --------------------------
# Plot Losses
# --------------------------
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label="Train Loss")
plt.plot(val_losses, label="Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training and Validation Loss")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

KeyboardInterrupt: 

In [33]:
# --------------------------
# Load and Prepare Data
# --------------------------
noise_folder = "./ET_noise/segmented_data/train/" # path to processed ET noise spectrograms
raw_waveforms = torch.cat([torch.load(os.path.join(noise_folder, filename)) for filename in os.listdir(noise_folder)])
spectrograms = preprocess_waveforms_to_spectrograms(raw_waveforms)

noise_train_data, noise_val_data = train_test_split(spectrograms, test_size=0.5, random_state=42)

noise_train_spectrograms = torch.nn.functional.interpolate(noise_train_data.unsqueeze(1), size=(256, 31))
noise_train_loader = DataLoader(TensorDataset(noise_train_spectrograms), batch_size=64, shuffle=True)

noise_val_spectrograms = torch.nn.functional.interpolate(noise_val_data.unsqueeze(1), size=(256, 31))
noise_val_loader = DataLoader(TensorDataset(noise_val_spectrograms), batch_size=64, shuffle=False)

anomaly_folder = "./ET_noise/injected_data/mass_100_200__distance_5_20Gpc/train/" #noFrameConversion__m_50_100__5_20Gpc/" # path to processed ET noise spectrograms
anomaly_waveforms = torch.cat([torch.load(os.path.join(anomaly_folder, filename)) for filename in os.listdir(anomaly_folder)])
anomaly_spectrograms = preprocess_waveforms_to_spectrograms(anomaly_waveforms)

anomaly_train_data, anomaly_val_data = train_test_split(anomaly_spectrograms, test_size=0.5, random_state=42)
print(f"Training with {anomaly_train_data.shape[0]} training samples and {anomaly_val_data.shape[0]} validation samples ")

anomaly_train_spectrograms = torch.nn.functional.interpolate(anomaly_train_data.unsqueeze(1), size=(256, 31))
anomaly_train_loader = DataLoader(TensorDataset(anomaly_train_spectrograms), batch_size=64, shuffle=True)

anomaly_val_spectrograms = torch.nn.functional.interpolate(anomaly_val_data.unsqueeze(1), size=(256, 31))
anomaly_val_loader = DataLoader(TensorDataset(anomaly_val_spectrograms), batch_size=64, shuffle=False)



device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DeeperAutoencoderWithBottleneck(d=0.08).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=0.005
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5, threshold=0.005
)


def weakly_supervised_loss(model, x_norm, x_anom, margin=0.05):
    recon_norm = model(x_norm)
    recon_anom = model(x_anom)

    L_norm = mse_loss(recon_norm, x_norm)
    L_anom = mse_loss(recon_anom, x_anom)

    separation = torch.relu(margin - (L_anom - L_norm))
    total_loss = L_norm + separation

    return total_loss, L_norm.item(), L_anom.item()



mse_loss = nn.MSELoss()
num_epochs = 10
margin = 0.05

epoch_losses = []
epoch_anomaly_losses = []
epoch_noise_losses = []

val_epoch_losses = []
val_epoch_anomaly_losses = []
val_epoch_noise_losses = []

for epoch in range(num_epochs):
    model.train()
    losses = []
    noise_losses = []
    anomaly_losses = []
    for (batch_noise, batch_anomaly) in zip(noise_train_loader, anomaly_train_loader):
        x = batch_noise[0].to(device)             # unsupervised input
        anomaly_x = batch_anomaly[0].to(device)          # weakly supervised inputs
        optimizer.zero_grad()
        loss, noise_loss, anomaly_loss = weakly_supervised_loss(model, x, anomaly_x, margin=margin)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        noise_losses.append(noise_loss)
        anomaly_losses.append(anomaly_loss)

    epoch_losses.append(np.mean(losses))
    epoch_anomaly_losses.append(np.mean(anomaly_losses))
    epoch_noise_losses.append(np.mean(noise_losses))

    model.eval()
    val_losses = []
    val_noise_losses = []
    val_anomaly_losses = []
    for (val_batch_noise, val_batch_anomaly) in zip(noise_val_loader, anomaly_val_loader):
        val_x = val_batch_noise[0].to(device)             # unsupervised input
        val_anomaly_x = val_batch_anomaly[0].to(device)          # weakly supervised inputs

        with torch.no_grad():
            val_loss, val_noise_loss, val_anomaly_loss = weakly_supervised_loss(model, val_x, val_anomaly_x, margin=margin)

        val_losses.append(val_loss.item())
        val_noise_losses.append(val_noise_loss)
        val_anomaly_losses.append(val_anomaly_loss)

    val_epoch_losses.append(np.mean(val_losses))
    val_epoch_anomaly_losses.append(np.mean(val_anomaly_losses))
    val_epoch_noise_losses.append(np.mean(val_noise_losses))

    scheduler.step(np.mean(val_losses))
    print(f"[Epoch {epoch+1}/{num_epochs}] Train Loss: {loss:.6f}, Val Loss: {val_loss:.6f}, noise train/val: {noise_loss:.6f}/{val_noise_loss:.6f}, anomaly train/val: {anomaly_loss:.6f}/{val_anomaly_loss:.6f}")


# Save Model
model_path = "./models/v3.0__weakly_supervised__ET_model.pt" #path to where the model will be saved
torch.save(model.state_dict(), model_path)
print(f"Model saved to {model_path}")

Training with 4092 training samples and 4092 validation samples 


KeyboardInterrupt: 

!pip i