In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to local .py files take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [8]:
from training_subset_analysis import TrainingSubsetAnalysis
from models.audiosep import AudioSep
import argparse
import os
from utils import parse_yaml, load_ss_model
from scipy.signal import spectrogram


def eval(evaluator,
         encoder_checkpoint_path = None,
         ssnet_checkpoint_path = None,
         config_yaml=None,
         device = "cuda",
         encoder_type = None):
    """Build the query encoder and the separation model for an evaluation run.

    The evaluation itself is not executed here: the caller receives the
    evaluator together with the loaded model and is expected to invoke
    ``evaluator(pl_model)``.

    NOTE: this function shadows the builtin ``eval``; the name is kept because
    other cells call it by this name.

    Args:
        evaluator: Object handed back unchanged as the first element of the
            returned tuple.
        encoder_checkpoint_path: Forwarded as ``pretrained_path`` to the
            selected encoder class.
        ssnet_checkpoint_path: Forwarded as ``checkpoint_path`` to
            ``load_ss_model``.
        config_yaml: Path passed to ``parse_yaml``.
        device: Target passed to ``.to(device)`` on the loaded model.
        encoder_type: Either ``'ONE-PEACE'`` or ``'CLAP'``.

    Returns:
        tuple: ``(evaluator, pl_model)``.

    Raises:
        AssertionError: If ``encoder_type`` is ``None``.
        ValueError: If ``encoder_type`` is not one of the supported values.
    """
    supported_encoders = ('ONE-PEACE', 'CLAP')

    # Explicit raise (rather than ``assert``) so the check survives ``python -O``.
    if encoder_type is None:
        raise AssertionError('define encoder type')

    # Fail fast: previously an unknown type fell through both branches and only
    # surfaced later as an unbound ``query_encoder``.
    if encoder_type not in supported_encoders:
        raise ValueError(
            f'unsupported encoder type {encoder_type!r}; '
            f'expected one of {supported_encoders}'
        )

    configs = parse_yaml(config_yaml)

    if encoder_type == 'ONE-PEACE':
        # Imported lazily (and before changing directory) so the ONE-PEACE
        # dependency is only needed when this encoder is selected.
        from models.one_peace_encoder import ONE_PEACE_Encoder

        # ONE_PEACE modelhub expects some paths to be relative to this dir
        previous_cwd = os.getcwd()
        os.chdir('ONE-PEACE/')
        try:
            # TODO:path in shared scratch dir for now..., move to class project dir whenever we get that
            query_encoder = ONE_PEACE_Encoder(pretrained_path=encoder_checkpoint_path)
        finally:
            # Always restore the working directory, even if loading fails,
            # so later cells do not run from inside ONE-PEACE/.
            os.chdir(previous_cwd)

        # put ONE-PEACE model in eval mode (probably unnecessary)
        query_encoder.model.model.eval()
    else:
        # Only 'CLAP' can reach this branch after the validation above.
        from models.clap_encoder import CLAP_Encoder
        query_encoder = CLAP_Encoder(pretrained_path=encoder_checkpoint_path).eval()

    pl_model = load_ss_model(
        configs=configs,
        checkpoint_path=ssnet_checkpoint_path,
        query_encoder=query_encoder
    ).to(device)

    return evaluator, pl_model
    

    


In [9]:

# Every option is a required string; declare them as (flag, help text) pairs
# and register them in one pass.
_required_string_options = (
    ("--config_yaml",
     "Path of config file for AudioSep model"),
    ("--encoder_checkpoint_path",
     "Path of pretrained checkpoint for QueryEncoder (ONE-PEACE/CLAP)"),
    ('--ssnet_checkpoint_path',
     "Path of pretrained checkpoint for Seperation Network (ResUNet)"),
    ('--encoder_type',
     'type of Query Encoder'),
)

parser = argparse.ArgumentParser()
for _flag, _help_text in _required_string_options:
    parser.add_argument(_flag, type=str, required=True, help=_help_text)

# The notebook emulates a command line: swap which `cli` assignment is active
# to switch between the ONE-PEACE and the CLAP configuration.
# cli = '--config_yaml config/audiosep_onepeace.yaml --encoder_checkpoint_path /fs/nexus-scratch/vla/finetune_al_retrieval.pt --ssnet_checkpoint_path /fs/nexus-scratch/vla/checkpoints/train/audiosep_onepeace,devices=1/step=140000.ckpt --encoder_type ONE-PEACE'
cli = '--config_yaml config/audiosep_base.yaml --encoder_checkpoint_path ./checkpoint/music_speech_audioset_epoch_15_esc_89.98.pt  --ssnet_checkpoint_path checkpoint/audiosep_baseline.ckpt --encoder_type CLAP'

# Whitespace-split the emulated command line and parse it.
argv = cli.split()
args = parser.parse_args(argv)
print(args)


Namespace(config_yaml='config/audiosep_base.yaml', encoder_checkpoint_path='./checkpoint/music_speech_audioset_epoch_15_esc_89.98.pt', ssnet_checkpoint_path='checkpoint/audiosep_baseline.ckpt', encoder_type='CLAP')


In [10]:

# Set up the training-subset evaluator (per-sample metrics and similarity
# scores); the evaluation itself is started in a later cell.
dcase_evaluator = TrainingSubsetAnalysis(
    encoder_type=args.encoder_type,
    config_yaml=args.config_yaml,
    sampling_rate=16000,
    eval_indexes='lass_training_subset.csv',
    audio_dir='',       # the eval_indexes csv file holds absolute paths
    output_dir=None,    # None: do not write separated audio as .wav files
)

# Load the query encoder and separation network; the evaluator is handed back
# unchanged alongside the model.
evaluator, pl_model = eval(
    dcase_evaluator,
    config_yaml=args.config_yaml,
    encoder_type=args.encoder_type,
    encoder_checkpoint_path=args.encoder_checkpoint_path,
    ssnet_checkpoint_path=args.ssnet_checkpoint_path,
    device="cuda",
)



Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  checkpoint = torch.load(checkpoint_path, map_location=map_location)
/fs/nexus-scratch/vla/micromamba/envs/LASS/lib/python3.9/site-packages/lightning/pytorch/core/saving.py:195: Found keys that are not in the model state dict but in the checkpoint: ['query_encoder.model.text_branch.embeddings.position_ids']


In [11]:
# Run the evaluator on the loaded model. The result is used like a DataFrame
# in the following cells (`df_results[:5]`, `df_results.loc[:, 'sisdr']`).
df_results = evaluator(pl_model)

Evaluating on lass_training_subset.csv


  2%|▏         | 999/50000 [00:50<41:30, 19.68it/s]


In [12]:
# Preview the first five result rows.
df_results.head()

Unnamed: 0,caption,source_path,noise_path,sisdr,sdri,sdr
0,"A trumpet sounds with bright, bold notes.",/fs/nexus-scratch/vla/FSD50K/FSD50K.eval_audio...,/fs/nexus-scratch/vla/FSD50K/FSD50K.dev_audio/...,-3.777549,-5.525764,1.474236
1,A printer operates with mechanical sounds.,/fs/nexus-scratch/vla/FSD50K/FSD50K.dev_audio/...,/fs/nexus-scratch/vla/Clotho/development/Glass...,22.576742,30.263355,22.263356
2,Breathing sounds are audible.,/fs/nexus-scratch/vla/FSD50K/FSD50K.dev_audio/...,/fs/nexus-scratch/vla/FSD50K/FSD50K.eval_audio...,14.866004,6.976233,14.976232
3,A single bell rings.,/fs/nexus-scratch/vla/FSD50K/FSD50K.dev_audio/...,/fs/nexus-scratch/vla/Clotho/development/20080...,22.267787,19.24417,22.24417
4,A man's singing voice resonates.,/fs/nexus-scratch/vla/FSD50K/FSD50K.dev_audio/...,/fs/nexus-scratch/vla/FSD50K/FSD50K.dev_audio/...,22.752757,12.771955,22.771955


## CLAP results below (mean `sisdr`, `sdri`, `sdr` over `df_results`)

In [13]:
# Average of the 'sisdr' column over all result rows.
df_results['sisdr'].mean()

9.119818571870418

In [14]:
# Average of the 'sdri' column over all result rows.
df_results['sdri'].mean()

10.837878851696653

In [15]:
# Average of the 'sdr' column over all result rows.
df_results['sdr'].mean()

10.498848509130836

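## ONE-PEACE results below (mean `sisdr`, `sdri`, `sdr`; the restarted execution counts suggest these come from a separate run with `--encoder_type ONE-PEACE`)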

In [7]:
# Average of the 'sisdr' column over all result rows.
df_results['sisdr'].mean()

9.728438988414236

In [8]:
# Average of the 'sdri' column over all result rows.
df_results['sdri'].mean()

10.888474036001798

In [9]:
# Average of the 'sdr' column over all result rows.
df_results['sdr'].mean()

10.54944369343598