In [15]:
from dcase_evaluator import DCASEEvaluator
from models.audiosep import AudioSep
from models.one_peace_encoder import ONE_PEACE_Encoder

import argparse
import os
from utils import parse_yaml, load_ss_model
import torch

import librosa
import pandas as pd
import soundfile as sf
import scipy.io.wavfile as wf
from tqdm import tqdm

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
config_yaml = 'config/audiosep_onepeace.yaml'
encoder_checkpoint_path = '/fs/nexus-scratch/vla/finetune_al_retrieval.pt'

# NOTE: best checkpoint on validation set
ssnet_checkpoint_path = '/fs/nexus-scratch/vla/checkpoints/train/audiosep_onepeace,devices=1/step=140000.ckpt'
sampling_rate=1600
configs = parse_yaml(config_yaml)

# ONE_PEACE modelhub expects some paths to be relative to this dir
os.chdir('ONE-PEACE/')
# TODO:path in shared scratch dir for now..., move to class project dir whenever we get that
query_encoder = ONE_PEACE_Encoder(pretrained_path=encoder_checkpoint_path)
os.chdir('..')

# put ONE-PEACE model in eval model (probably unecessary)
query_encoder.model.model.eval()

pl_model = load_ss_model(
    configs=configs,
    checkpoint_path=ssnet_checkpoint_path,
    query_encoder=query_encoder
).to(device)

  state = torch.load(f, map_location=torch.device("cpu"))


In [3]:
eval_csv = 'lass_real_evaluation.csv'
dict_eval = pd.read_csv(eval_csv).set_index('file_name').to_dict()['query']
output_dir = 'lass_evaluation_real_output'
audio_dir = 'lass_evaluation_real'

In [31]:
test_samples = [i for i in range(200)]
filenames = [os.path.join(f'test-real-case-{s}.wav') for s in test_samples]

gather = []
with torch.no_grad():
    for filename in tqdm(filenames):
        
        # load audio from test set
        input_path = os.path.join(audio_dir, filename)
        source, fs = librosa.load(input_path, sr=sampling_rate, mono=True)

        # compute text embedding with ONE-PEACE query encoder
        conditions = pl_model.query_encoder.get_query_embed(
                        modality='text',
                        text=[dict_eval[filename]],
                        device=device 
        )

        input_dict = {
                        "mixture": torch.Tensor(source)[None, None, :].to(device),
                        "condition": conditions,
                    } 

        # output audio
        sep_segment = pl_model.ss_model(input_dict)["waveform"]

        # TODO: compute ONE-PEACE embedding on sep_segment and dot w/ conditions for comparison in embedding space
        # sep_segment_embd = pl_model.query_encoder.model
        # sep_segment: (batch_size=1, channels_num=1, segment_samples)
        sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()

        # write out .wav file
        output_path = os.path.join(output_dir, filename)
        # wf.write(output_path, sampling_rate, sep_segment)

        similarities = dict(
            filename = filename
        )

        # COMPUTE SIMILARITIES
        src_audios, audio_padding_masks = pl_model.query_encoder.model.process_audio([input_path])
        audio_features = pl_model.query_encoder.model.extract_audio_features(src_audios, audio_padding_masks)
        input_similarity = conditions @ audio_features.T
        similarities['input_similarity'] = input_similarity.squeeze(0).cpu().numpy()[0]
        # print(f'Text Prompt - Mixed Audio Input Similarity: {input_similarity}')

        src_audios, audio_padding_masks = pl_model.query_encoder.model.process_audio([output_path])
        audio_features = pl_model.query_encoder.model.extract_audio_features(src_audios, audio_padding_masks)
        output_similarity = conditions @ audio_features.T
        similarities['output_similarity'] = output_similarity.squeeze(0).cpu().numpy()[0]
        # print(f'Text Prompt - Seperated Audio Output Similarity: {output_similarity}')

        gather.append(similarities)

100%|██████████| 200/200 [00:56<00:00,  3.53it/s]


In [37]:
df = pd.DataFrame(gather)
df['delta_similarity'] = df.output_similarity - df.input_similarity
df[:5]

Unnamed: 0,filename,input_similarity,output_similarity,delta_similarity
0,test-real-case-0.wav,0.353482,0.174945,-0.178537
1,test-real-case-1.wav,0.40859,0.104871,-0.303719
2,test-real-case-2.wav,0.29939,0.052828,-0.246562
3,test-real-case-3.wav,0.306069,0.057904,-0.248165
4,test-real-case-4.wav,0.186257,0.071551,-0.114706


In [None]:
# Top 10 increase in similarity
df.sort_values(by = 'delta_similarity', ascending=False)[:10]

Unnamed: 0,filename,input_similarity,output_similarity,delta_similarity
89,test-real-case-89.wav,0.032961,0.252458,0.219498
186,test-real-case-186.wav,0.04055,0.243609,0.203058
142,test-real-case-142.wav,0.096658,0.281703,0.185045
12,test-real-case-12.wav,-0.033666,0.120087,0.153754
68,test-real-case-68.wav,0.089508,0.226485,0.136977
69,test-real-case-69.wav,0.150574,0.28117,0.130596
52,test-real-case-52.wav,0.210967,0.335514,0.124547
118,test-real-case-118.wav,0.189311,0.297989,0.108679
87,test-real-case-87.wav,0.22657,0.32852,0.101951
187,test-real-case-187.wav,-0.061375,0.023337,0.084712
