In [12]:
from dcase_evaluator import DCASEEvaluator
from models.audiosep import AudioSep
from models.one_peace_encoder import ONE_PEACE_Encoder

import argparse
import os
from utils import parse_yaml, load_ss_model
import torch

import librosa
import pandas as pd
import soundfile as sf
import scipy.io.wavfile as wf

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
config_yaml = 'config/audiosep_onepeace.yaml'
# TODO: paths in shared scratch dir for now..., move to class project dir whenever we get that
encoder_checkpoint_path = '/fs/nexus-scratch/vla/finetune_al_retrieval.pt'
ssnet_checkpoint_path = '/fs/nexus-scratch/vla/checkpoints/train/audiosep_onepeace,devices=1/step=140000.ckpt'
configs = parse_yaml(config_yaml)

# Rate used to load mixtures and to write separated outputs.
# This was previously hard-coded as 1600 Hz (a typo for 16 kHz). Prefer the
# model config's value and fall back to 16 kHz.
# NOTE(review): assumes the yaml stores it under `data.sampling_rate` — confirm.
sampling_rate = (configs.get('data') or {}).get('sampling_rate', 16000)

# ONE_PEACE modelhub expects some paths to be relative to its own dir.
# Always restore the original cwd, even if loading raises, so later cells'
# relative paths (csv, audio dirs) still resolve.
prev_cwd = os.getcwd()
os.chdir('ONE-PEACE/')
try:
    query_encoder = ONE_PEACE_Encoder(pretrained_path=encoder_checkpoint_path)
finally:
    os.chdir(prev_cwd)

# Put the ONE-PEACE model in eval mode (probably unnecessary).
query_encoder.model.model.eval()

pl_model = load_ss_model(
    configs=configs,
    checkpoint_path=ssnet_checkpoint_path,
    query_encoder=query_encoder,
).to(device)

  state = torch.load(f, map_location=torch.device("cpu"))


In [5]:
eval_csv = 'lass_real_evaluation.csv'
eval_df = pd.read_csv(eval_csv)

# Text queries are looked up by file name during separation, so each file
# must appear exactly once in the CSV.
assert eval_df['file_name'].is_unique, f'duplicate file_name entries in {eval_csv}'

# Map file_name -> text query (only the needed column is converted).
dict_eval = eval_df.set_index('file_name')['query'].to_dict()

In [19]:
# Separated waveforms are written here, one wav per input file.
output_dir = 'lass_evaluation_real_output'
# Create it up front; writing a wav into a missing directory raises.
os.makedirs(output_dir, exist_ok=True)

In [None]:
audio_dir = 'lass_evaluation_real'
num_test_samples = 200  # evaluation files are test-real-case-0.wav ... test-real-case-199.wav

filenames = [f'test-real-case-{idx}.wav' for idx in range(num_test_samples)]

# Inference: switch the separator (and its submodules) to eval mode so
# normalization/dropout layers use inference behavior, and disable autograd so
# no computation graph is retained for each file.
pl_model.eval()

with torch.no_grad():
    for filename in filenames:

        print(filename)

        # Resampled on load to `sampling_rate`, the same rate the output is written at.
        mixture, _ = librosa.load(os.path.join(audio_dir, filename), sr=sampling_rate, mono=True)

        conditions = pl_model.query_encoder.get_query_embed(
            modality='text',
            text=[dict_eval[filename]],
            device=device,
        )

        input_dict = {
            # add batch and channel dims: (samples,) -> (1, 1, samples)
            "mixture": torch.Tensor(mixture)[None, None, :].to(device),
            "condition": conditions,
        }

        sep_segment = pl_model.ss_model(input_dict)["waveform"]

        # TODO: compute ONE-PEACE embedding on sep_segment and dot w/ conditions for comparison in embedding space
        # sep_segment: (batch_size=1, channels_num=1, segment_samples) -> 1-D numpy array
        sep_segment = sep_segment.squeeze(0).squeeze(0).cpu().numpy()

        wf.write(os.path.join(output_dir, filename), sampling_rate, sep_segment)

test-real-case-0.wav
test-real-case-1.wav
test-real-case-2.wav
test-real-case-3.wav
test-real-case-4.wav
test-real-case-5.wav
test-real-case-6.wav
test-real-case-7.wav
test-real-case-8.wav
test-real-case-9.wav
test-real-case-10.wav
test-real-case-11.wav
test-real-case-12.wav
test-real-case-13.wav
test-real-case-14.wav
test-real-case-15.wav
test-real-case-16.wav
test-real-case-17.wav
test-real-case-18.wav
test-real-case-19.wav
test-real-case-20.wav
test-real-case-21.wav
test-real-case-22.wav
test-real-case-23.wav
test-real-case-24.wav
test-real-case-25.wav
test-real-case-26.wav
test-real-case-27.wav
test-real-case-28.wav
test-real-case-29.wav
test-real-case-30.wav
test-real-case-31.wav
test-real-case-32.wav
test-real-case-33.wav
test-real-case-34.wav
test-real-case-35.wav
test-real-case-36.wav
test-real-case-37.wav
test-real-case-38.wav
test-real-case-39.wav
test-real-case-40.wav
test-real-case-41.wav
test-real-case-42.wav
test-real-case-43.wav
test-real-case-44.wav
test-real-case-45.wa