## This code should be run on ACCRE; otherwise it will take a very long time. It also uses a lot of storage space.

In [None]:
import logging
import os
import threading
from io import BytesIO
from multiprocessing import Pool, Manager
from pathlib import Path

import librosa
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm

from gwpy.timeseries import TimeSeries

import matplotlib.pyplot as plt

In [None]:
# Event metadata table, one row per event. The processing functions in this
# file read the 'ifo', 'event_time', 'label' and 'sample_type' entries of each
# row. The path is relative, so the CSV must be in the current working directory.
df = pd.read_csv("trainingset_v1d1_metadata.csv") # SD: dataset that I got from Gravity Spy. Check the paper for the link to the original source.

In [None]:
def _normalized_log_spectrogram(q_transform):
    """Turn a Q-transform result into a log-compressed array scaled to roughly [0, 1]."""
    spec = np.array(q_transform.data).T  # SD: some pre-processing to make data look better
    spec = np.clip(spec, 0, None)  # drop negative values before the log
    spec = np.log1p(spec)
    spec -= spec.min()
    spec /= spec.max() + 1e-12  # epsilon avoids 0/0 on a constant spectrogram
    return spec


def _save_spectrogram_image(spec, out_file):
    """Render *spec* with matplotlib and save it to *out_file* as a 224x224 RGB PNG.

    The figure is rendered into an in-memory buffer and resized from there, so
    only the final image is written to disk (avoids a save / reopen / resave
    round trip, which was slowing generation down a lot).
    """
    buf = BytesIO()
    fig, ax = plt.subplots(figsize=(4, 4), dpi=100)
    try:
        ax.imshow(spec, aspect='auto', origin='lower', cmap='viridis')
        ax.axis('off')
        fig.tight_layout(pad=0)
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        # Close even when rendering fails; otherwise figures pile up in a
        # long-running worker process.
        plt.close(fig)

    buf.seek(0)
    img = Image.open(buf).convert("RGB").resize((224, 224), resample=Image.Resampling.BICUBIC)

    out_file.parent.mkdir(parents=True, exist_ok=True)
    img.save(out_file)


def process_event(event, q_values):
    """Generate Q-transform image(s) for one event and write them to disk.

    Strain data is fetched with ``TimeSeries.fetch_open_data`` for the window
    ``[gps - 16, gps + 16]`` and the Q-transform is evaluated on the central
    ``[gps - 2, gps + 2]`` segment.

    Args:
        event: Mapping-like row providing ``'ifo'``, ``'event_time'``,
            ``'label'`` and ``'sample_type'``.
        q_values: ``None`` to run a single Q-transform without a ``qrange``
            (output under ``optimal/``), or an iterable of Q values; each
            value ``q`` is rendered with ``qrange=(q, q)`` under ``q{q}/``.

    Returns:
        None. Images are saved as
        ``<top>/<sample_type>/<label>/<ifo>_<int(gps)>.png``.

    Processing is best-effort: any exception raised while fetching, transforming
    or saving is logged as a warning and the event is skipped, so one bad event
    does not abort a large batch. As before, a failure part-way through
    ``q_values`` skips the remaining Q values for that event.
    """
    ifo = event['ifo']
    gps = event['event_time']
    glitch_type = event['label']
    sample_type = event['sample_type']

    start = gps - 2
    end = gps + 2

    # NOTE(review): the file name uses int(gps), so two events of the same
    # ifo/label/sample_type within the same GPS second would overwrite each
    # other -- confirm this cannot happen in the metadata.
    file_name = f"{ifo}_{int(gps)}.png"

    try:
        strain = TimeSeries.fetch_open_data(ifo, start-14, end+14)

        if q_values is None:
            # SD: no qrange! This is the "optimal" Q-value variant.
            variants = [("optimal", {})]
        else:
            # SD: min and max of qrange both set to q, pinning the transform to that Q.
            variants = [(f"q{q}", {"qrange": (q, q)}) for q in q_values]

        for top_dir, q_kwargs in variants:
            q_transform = strain.q_transform(outseg=(start, end), **q_kwargs)
            spec = _normalized_log_spectrogram(q_transform)
            out_dir = Path(f"{top_dir}/{sample_type}/{glitch_type}")
            _save_spectrogram_image(spec, out_dir / file_name)
    except Exception as exc:
        # Deliberately broad: this is the per-event boundary of a batch job.
        # Report the failure instead of silently dropping the event.
        logging.getLogger(__name__).warning(
            "Skipping Q-transform for %s event at GPS %s (%s/%s): %r",
            ifo, gps, sample_type, glitch_type, exc,
        )

def process_event_with_queue(event, q_values, queue): # SD: helper function simply so that I can have a tqdm progress bar while it runs.
    """Run ``process_event`` for one event and report completion on *queue*.

    Args:
        event: Event row forwarded to ``process_event``.
        q_values: Q values forwarded to ``process_event``.
        queue: Object with a ``put`` method; receives ``1`` once the event has
            been handled.

    The tick is sent from a ``finally`` block so the event is still counted if
    ``process_event`` raises (for example on a missing key, which it looks up
    outside its own ``try``). The exception itself still propagates.
    """
    try:
        process_event(event, q_values)
    finally:
        queue.put(1) # SD: updates queue when an event is done processing

def parallel_generate(events_df, q_values, num_workers):
    """Generate Q-transform images for every row of *events_df* in parallel.

    Args:
        events_df: DataFrame with one event per row.
        q_values: Forwarded unchanged to ``process_event_with_queue``.
        num_workers: Number of processes in the ``multiprocessing.Pool``.

    Each worker puts a tick on a manager queue when it finishes an event; a
    listener thread in this process drains the queue and advances a tqdm
    progress bar. The listener stops on a ``None`` sentinel that is always sent
    once the pool call returns or raises, so an error in a worker can no longer
    leave the listener blocked forever. The manager process and the progress
    bar are released on exit, including on error.
    """
    events_list = [events_df.iloc[i] for i in range(len(events_df))]

    def listener(q, pbar):
        # Advance the bar once per completed event until the sentinel arrives.
        while q.get() is not None:
            pbar.update(1)

    # SD: multi-threading! The part that makes the tqdm progress bar work was written with help from ChatGPT.
    with Manager() as manager, Pool(num_workers) as pool:
        queue = manager.Queue()

        with tqdm(total=len(events_list), desc="Generating Q-transforms") as pbar:
            # SD: start the progress bar thread
            t = threading.Thread(target=listener, args=(queue, pbar), daemon=True)
            t.start()

            try:
                args = [(event, q_values, queue) for event in events_list]
                pool.starmap(process_event_with_queue, args)
            finally:
                # starmap has returned (or raised), so every tick that will
                # ever be sent is already queued ahead of the sentinel.
                queue.put(None)
                t.join()

In [6]:
# parallel_generate(events_df=df, q_values=None, num_workers=32)

Generating Q-transforms: 100%|██████████| 7966/7966 [1:12:52<00:00,  1.82it/s]


In [None]:
def process_event_logmel(event): # SD: this is literally the same code but instead of generating q-transforms, it does log-mel spectrograms
    """Generate a log-mel spectrogram image for one event and write it to disk.

    Strain data is fetched with ``TimeSeries.fetch_open_data`` for the window
    ``[gps - 16, gps + 16]``, whitened, and cropped to ``[gps - 2, gps + 2]``
    before the spectrogram is computed.

    Args:
        event: Mapping-like row providing ``'ifo'``, ``'event_time'``,
            ``'label'`` and ``'sample_type'``.

    Returns:
        None. The image is saved as
        ``mel/<sample_type>/<label>/<ifo>_<int(gps)>.png`` (224x224 RGB PNG).

    Processing is best-effort: any exception raised while fetching, transforming
    or saving is logged as a warning and the event is skipped, so one bad event
    does not abort a large batch.
    """
    ifo = event['ifo']
    gps = event['event_time']
    glitch_type = event['label']
    sample_type = event['sample_type']

    start = gps - 2
    end = gps + 2

    try:
        # SD: fetch strain data
        strain = TimeSeries.fetch_open_data(ifo, start-14, end+14)

        # SD: whiten and crop to the event window! Unlike q_transform, nothing
        # here whitens the data automatically, so it has to be done explicitly.
        y = strain.whiten().crop(start, end).value
        sr = strain.sample_rate.value  # SD: sampling rate

        D = librosa.stft(y, n_fft=2048, hop_length=512)

        # SD: compute log-mel spectrogram
        # NOTE(review): S is the STFT magnitude here, while power_to_db expects
        # a power spectrogram (librosa's own example passes np.abs(D)**2).
        # Left unchanged because fixing it would alter every generated image;
        # confirm whether magnitude input is intended.
        mel_spec = librosa.feature.melspectrogram(S=np.abs(D), sr=sr, n_mels=128)
        log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)

        # SD: normalize to [0, 1]
        log_mel_spec -= log_mel_spec.min()
        log_mel_spec /= log_mel_spec.max() + 1e-12  # epsilon avoids 0/0 on constant input

        # SD: plot to an in-memory buffer so only the final resized image hits the disk
        buf = BytesIO()
        fig, ax = plt.subplots(figsize=(4, 4), dpi=100)
        try:
            ax.imshow(log_mel_spec, aspect='auto', origin='lower', cmap='viridis')
            ax.axis('off')
            fig.tight_layout(pad=0)
            fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
        finally:
            # Close even when rendering fails; otherwise figures pile up in a
            # long-running worker process.
            plt.close(fig)

        # SD: resize and save image to disk
        buf.seek(0)
        img = Image.open(buf).convert("RGB").resize((224, 224), resample=Image.Resampling.BICUBIC)

        out_dir = Path(f"mel/{sample_type}/{glitch_type}")
        out_dir.mkdir(parents=True, exist_ok=True)
        out_file = out_dir / f"{ifo}_{int(gps)}.png"
        img.save(out_file)
    except Exception as exc:
        # Deliberately broad: this is the per-event boundary of a batch job.
        # Report the failure instead of silently dropping the event.
        logging.getLogger(__name__).warning(
            "Skipping log-mel spectrogram for %s event at GPS %s (%s/%s): %r",
            ifo, gps, sample_type, glitch_type, exc,
        )

# SD: same code as above, just without q_values
def process_event_with_queue(event, queue):
    """Run ``process_event_logmel`` for one event and report completion on *queue*.

    Args:
        event: Event row forwarded to ``process_event_logmel``.
        queue: Object with a ``put`` method; receives ``1`` once the event has
            been handled.

    The tick is sent from a ``finally`` block so the event is still counted if
    ``process_event_logmel`` raises (for example on a missing key, which it
    looks up outside its own ``try``). The exception itself still propagates.
    """
    try:
        process_event_logmel(event)
    finally:
        queue.put(1)

def parallel_generate(events_df, num_workers=6):
    """Generate log-mel spectrogram images for every row of *events_df* in parallel.

    Args:
        events_df: DataFrame with one event per row.
        num_workers: Number of processes in the ``multiprocessing.Pool``.

    Each worker puts a tick on a manager queue when it finishes an event; a
    listener thread in this process drains the queue and advances a tqdm
    progress bar. The listener stops on a ``None`` sentinel that is always sent
    once the pool call returns or raises, so an error in a worker can no longer
    leave the listener blocked forever. The manager process and the progress
    bar are released on exit, including on error.
    """
    events_list = [events_df.iloc[i] for i in range(len(events_df))]

    def listener(q, pbar):
        # Advance the bar once per completed event until the sentinel arrives.
        while q.get() is not None:
            pbar.update(1)

    with Manager() as manager, Pool(num_workers) as pool:
        queue = manager.Queue()

        with tqdm(total=len(events_list), desc="Generating Log-Mel Spectrograms") as pbar:
            t = threading.Thread(target=listener, args=(queue, pbar), daemon=True)
            t.start()

            try:
                args = [(event, queue) for event in events_list]
                pool.starmap(process_event_with_queue, args)
            finally:
                # starmap has returned (or raised), so every tick that will
                # ever be sent is already queued ahead of the sentinel.
                queue.put(None)
                t.join()

In [None]:
# parallel_generate(events_df=df, num_workers=32)

Generating Log-Mel Spectrograms: 100%|██████████| 7966/7966 [1:12:44<00:00,  1.83it/s]
