In [3]:
from dcase_evaluator import DCASEEvaluator
from models.audiosep import AudioSep
# from models.one_peace_encoder import ONE_PEACE_Encoder

import argparse
import os
from utils import parse_yaml, load_ss_model
import torch
import numpy as np

import librosa
import pandas as pd
import soundfile as sf
import scipy.io.wavfile as wf
from scipy.io import wavfile
from scipy.signal import spectrogram
from tqdm import tqdm
import matplotlib.pyplot as plt

In [4]:
# Load the synthetic validation index (columns: source, noise, snr, caption)
# and preview the first few rows.
validation_index_path = 'lass_synthetic_validation.csv'
df = pd.read_csv(validation_index_path)
df.head(n=5)

Unnamed: 0,source,noise,snr,caption
0,692211_12333864-hq,701692_6014995-hq,10,Someone is playing a kind of musical instrumen...
1,692211_12333864-hq,708399_14710576-hq,1,"Someone is playing the musical instrument, whi..."
2,692211_12333864-hq,708360_14781196-hq,-13,The musical instrument is producing buzzing so...
3,699830_15175263-hq,692865_2250422-hq,-6,There are some motorcycles and cars on the roa...
4,699830_15175263-hq,695808_5812781-hq,10,A race car engine roars as it accelerates and ...


In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
config_yaml = 'config/audiosep_onepeace.yaml'

# TODO: checkpoints live in a shared scratch dir for now; move to the class
# project dir when we get one. The environment variables let the notebook run
# on other machines without editing this cell.
encoder_checkpoint_path = os.environ.get(
    'ONE_PEACE_CHECKPOINT',
    '/fs/nexus-scratch/vla/finetune_al_retrieval.pt',
)

# NOTE: best checkpoint on validation set
ssnet_checkpoint_path = os.environ.get(
    'AUDIOSEP_CHECKPOINT',
    '/fs/nexus-scratch/vla/checkpoints/train/audiosep_onepeace,devices=1/step=140000.ckpt',
)
configs = parse_yaml(config_yaml)

# Sampling rate (Hz) used by the evaluation cells to load and write audio.
# The previous hard-coded value was 1600, which looks like a typo for 16000.
# Prefer the rate the model was configured with. NOTE(review): assumes
# parse_yaml returns a dict with an optional data.sampling_rate entry; confirm
# against the YAML.
sampling_rate = configs.get('data', {}).get('sampling_rate', 16000)

# The module-level import of the encoder is commented out at the top of the
# notebook, but the class is needed here. Import it locally so this cell works
# on a fresh kernel.
from models.one_peace_encoder import ONE_PEACE_Encoder

# ONE_PEACE modelhub expects some paths to be relative to this dir. Always
# restore the original working directory, even if loading fails. Otherwise
# every relative path in later cells (and re-runs of this cell) would resolve
# against ONE-PEACE/.
original_cwd = os.getcwd()
os.chdir('ONE-PEACE/')
try:
    query_encoder = ONE_PEACE_Encoder(pretrained_path=encoder_checkpoint_path)
finally:
    os.chdir(original_cwd)

# Put the ONE-PEACE model in eval mode (probably unnecessary).
query_encoder.model.model.eval()

pl_model = load_ss_model(
    configs=configs,
    checkpoint_path=ssnet_checkpoint_path,
    query_encoder=query_encoder
).to(device)
# Inference only: switch the separation model to eval mode so any dropout or
# batch-norm layers use their inference behaviour.
pl_model.eval()

In [5]:
# Re-display the first rows of the validation index.
df.head(n=5)

Unnamed: 0,source,noise,snr,caption
0,692211_12333864-hq,701692_6014995-hq,10,Someone is playing a kind of musical instrumen...
1,692211_12333864-hq,708399_14710576-hq,1,"Someone is playing the musical instrument, whi..."
2,692211_12333864-hq,708360_14781196-hq,-13,The musical instrument is producing buzzing so...
3,699830_15175263-hq,692865_2250422-hq,-6,There are some motorcycles and cars on the roa...
4,699830_15175263-hq,695808_5812781-hq,10,A race car engine roars as it accelerates and ...


In [None]:
validation_csv = 'lass_synthetic_validation.csv'

# Map each source file name to a caption. The CSV has no 'file_name' column
# (its columns are source, noise, snr, caption), so set_index('file_name')
# raised KeyError. The same source also appears in several rows with
# different captions, which makes a 'source' index non-unique; to_dict() would
# raise on it. Keep the first caption listed for each source.
dict_eval = (
    pd.read_csv(validation_csv)
    .groupby('source', sort=False)['caption']
    .first()
    .to_dict()
)

output_dir = 'lass_evaluation_real_output'
audio_dir = 'lass_validation'

In [None]:

def text_audio_similarity(audio_path, conditions):
    """Score a text query embedding against an audio file with ONE-PEACE.

    Args:
        audio_path: path of the audio file to embed with the ONE-PEACE audio
            branch of ``pl_model.query_encoder``.
        conditions: text embedding returned by ``get_query_embed`` for a single
            caption (presumably shape (1, D); TODO confirm).

    Returns:
        The single entry of ``conditions @ audio_features.T`` as a numpy scalar.
    """
    src_audios, audio_padding_masks = pl_model.query_encoder.model.process_audio([audio_path])
    audio_features = pl_model.query_encoder.model.extract_audio_features(src_audios, audio_padding_masks)
    similarity = conditions @ audio_features.T
    return similarity.squeeze(0).cpu().numpy()[0]


# The separated audio is written here and read back for scoring, so the
# directory must exist.
os.makedirs(output_dir, exist_ok=True)

gather = []
with torch.no_grad():
    for index, row in tqdm(df.iterrows(), total=len(df)):
        # Source file name (stored without the .wav extension) and the caption
        # for this particular row. Captions differ between rows that share a
        # source, so use the row's own caption rather than a per-source lookup.
        filename = row['source']
        caption = row['caption']
        input_path = os.path.join(audio_dir, f'{filename}.wav')

        # NOTE(review): the "mixture" fed to the separator is the clean source
        # file only. The CSV's noise/snr columns are not used here.
        source, _ = librosa.load(input_path, sr=sampling_rate, mono=True)

        # Compute the text embedding with the ONE-PEACE query encoder.
        conditions = pl_model.query_encoder.get_query_embed(
            modality='text',
            text=[caption],
            device=device,
        )

        input_dict = {
            "mixture": torch.as_tensor(source, dtype=torch.float32)[None, None, :].to(device),
            "condition": conditions,
        }

        # sep_segment: (batch_size=1, channels_num=1, segment_samples)
        sep_segment = pl_model.ss_model(input_dict)["waveform"]
        sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()

        # Write the separated audio so ONE-PEACE can embed it below. The same
        # source appears in several rows, so prefix the row index to keep
        # output names unique.
        output_path = os.path.join(output_dir, f'{index}_{filename}.wav')
        sf.write(output_path, sep_segment, sampling_rate)

        # Text prompt vs. input audio, and text prompt vs. separated output.
        gather.append({
            'index': index,
            'filename': filename,
            'input_similarity': text_audio_similarity(input_path, conditions),
            'output_similarity': text_audio_similarity(output_path, conditions),
        })

similarity_df = pd.DataFrame(gather)
similarity_df.head()
