In [13]:
import os
import torch
import numpy as np
import pandas as pd
import IPython.display as ipd
from IPython.display import Audio, HTML
from sklearn.metrics.pairwise import cosine_similarity
from argparse import ArgumentParser, Namespace, ArgumentTypeError
from speech_to_music.preprocessing.audio_utils import load_audio
from speech_to_music.metric_learning.infer import load_audio_backbone, load_music_backbone, projection_model

In [2]:
def str2bool(value):
    """Parse a command-line boolean flag value.

    ``type=bool`` is a trap with argparse: ``bool("False")`` is ``True``, so any
    explicit value would switch the flag on. Accepts true/false, yes/no, t/f,
    y/n and 1/0 (case-insensitive); raises ``ArgumentTypeError`` otherwise.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise ArgumentTypeError(f"boolean value expected, got {value!r}")


# Configuration handed to the model loaders below. parse_args([]) ignores the
# Jupyter kernel's own argv, so only these defaults are used in the notebook.
parser = ArgumentParser()
parser.add_argument("--root", default="../", type=str)
parser.add_argument("--inference_type", default="speech_extractor", type=str)
parser.add_argument("--branch_type", default="3branch", type=str)
parser.add_argument("--fusion_type", default="audio", type=str)
parser.add_argument("--word_model", default="glove", type=str)
parser.add_argument("--freeze_type", default="feature", type=str)
parser.add_argument("--is_augmentation", default=False, type=str2bool)
# One or more GPU ids, e.g. `--gpus 0 1` -> [0, 1] (type=list would split "01" into chars).
parser.add_argument("--gpus", default=[0], type=int, nargs="+")
parser.add_argument("--reproduce", default=True, type=str2bool)
args = parser.parse_args([])

## Load models

Build the speech backbone, the music backbone and the joint projection model from `args`.

In [14]:
# Build the speech backbone, the music backbone and the joint projection model,
# all from the same argument namespace (evaluated left to right, as before).
audio_backbone, music_backbone, joint_backbone = (
    load_audio_backbone(args),
    load_music_backbone(args),
    projection_model(args),
)

In [4]:
# Test splits, resolved relative to --root (default "../" gives the original
# "../dataset/split/..." paths). Later cells treat each column as a label and the
# index as the feature-file stem — presumably one-hot rows; confirm against the split files.
SPLIT_DIR = os.path.join(args.root, "dataset", "split")
df_speech = pd.read_csv(os.path.join(SPLIT_DIR, "IEMOCAP", "test.csv"), index_col=0)
df_music = pd.read_csv(os.path.join(SPLIT_DIR, "Audioset", "test.csv"), index_col=0)

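## Music embedding database

Encode every Audioset test clip with the music extractor and project it through the joint model's music MLP. Keep each clip's waveform and highest-scoring label for display.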

In [60]:
# NOTE(review): inputs are sent to this device, so the backbones are assumed to
# live there too — confirm against load_*_backbone, which also receive args.gpus.
DEVICE = "cuda:0"
MUSIC_FEATURE_DIR = os.path.join(args.root, "dataset", "feature", "Audioset", "npy")

# music_samples is keyed by fname and must stay row-aligned with music_embs;
# a duplicated fname would overwrite a dict entry and break that alignment.
assert df_music.index.is_unique, "duplicate fnames in the Audioset split"

music_samples = {}
music_embs = []
for fname in df_music.index:
    music_wav = np.load(os.path.join(MUSIC_FEATURE_DIR, fname + ".npy"))
    with torch.no_grad():
        # The extractor is fed a batch containing this single clip.
        audio_emb = music_backbone.model.extractor(
            torch.from_numpy(music_wav).unsqueeze(0).to(DEVICE)
        )
        # Project with the joint model's music head; the result is compared
        # against the projected speech embeddings further down.
        audio_emb = joint_backbone.model.music_mlp(audio_emb)
    # Under no_grad the tensor carries no graph, so no .detach() is needed.
    music_embs.append(audio_emb.squeeze(0).cpu().numpy())
    music_samples[fname] = {
        "fname": fname,
        "wav": music_wav,
        # Column with the largest value in this clip's label row (first on ties).
        "label": df_music.loc[fname].idxmax(),
    }

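## Speech embeddings

Pick one IEMOCAP test utterance per emotion label and project it through the joint model's speech MLP.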

In [101]:
# Fixed seed so the same query utterance is drawn per label on every rerun.
SPEECH_SAMPLE_SEED = 42
SPEECH_FEATURE_DIR = os.path.join(args.root, "dataset", "feature", "IEMOCAP", "npy")

speech_samples = {}
speech_embs = []
for label in df_speech.columns:
    # One random utterance whose column for this label equals 1.
    fname = (
        df_speech[df_speech[label] == 1]
        .sample(1, random_state=SPEECH_SAMPLE_SEED)
        .index[0]
    )
    # speech_samples (keyed by fname) must stay row-aligned with speech_embs;
    # a file drawn for two labels would overwrite its entry and desync them.
    assert fname not in speech_samples, f"{fname} was sampled for more than one label"
    speech_wav = np.load(os.path.join(SPEECH_FEATURE_DIR, fname + ".npy"))
    with torch.no_grad():
        # NOTE(review): unlike the music branch no batch dim is added here — the
        # stored speech features presumably already carry one; confirm.
        embs = audio_backbone.model.pooling_extractor(torch.from_numpy(speech_wav).to(DEVICE))
        embs = joint_backbone.model.speech_audio_mlp(embs)
    speech_embs.append(embs.squeeze(0).cpu().numpy())
    speech_samples[fname] = {
        "fname": fname,
        "wav": speech_wav,
        "label": label
    }

In [102]:
# Turn each list of per-sample embeddings into one array with a row per sample.
speech_embs, music_embs = np.stack(speech_embs), np.stack(music_embs)

In [103]:
# Cosine similarity of every speech query (rows) against every music clip
# (columns). Both dicts were filled in the same order as their embedding
# lists, so their keys line up with the matrix rows/columns.
sim_matrix = cosine_similarity(X=speech_embs, Y=music_embs)
df_sim = pd.DataFrame(
    data=sim_matrix,
    index=speech_samples.keys(),
    columns=music_samples.keys(),
)

In [109]:
df_sim.transpose().head(5)  # rows: Audioset clips (target), columns: IEMOCAP utterances (query)

Unnamed: 0,Ses05F_impro05_F023,Ses05M_impro03_M010,Ses05M_script01_2_M001,Ses05F_script01_3_F006
WO_Y7djT2k4,0.713043,-0.490283,-0.354062,-0.427027
cCyfADwHiWs,0.752776,-0.646087,-0.358029,-0.300436
OaV-ZyjNDFE,-0.270472,-0.118214,-0.145704,0.446434
EaGhKzpkNso,0.200053,0.388821,-0.098282,-0.685633
S_Z7o4OmU30,-0.461797,0.51214,-0.020668,0.041813


## Nearest-neighbor search

For each speech query, list the music clips with the highest cosine similarity.

In [104]:
def _audio_player_html(wav, rate):
    """Return an HTML ``<audio>`` element that plays ``wav`` at sample rate ``rate``."""
    src = ipd.Audio(wav, rate=rate).src_attr()
    # No trailing </td>: DataFrame.to_html already wraps each cell in <td>...</td>.
    return f"""<audio controls><source src="{src}" type="audio/wav"></audio>"""


def demo(df_sim, speech_samples, music_samples, audio_viz=False,
         top_k=3, speech_rate=16000, music_rate=22050):
    """Display, for each speech query, its ``top_k`` most similar music clips.

    Args:
        df_sim: similarity table with one row per speech fname and one column
            per music fname.
        speech_samples: speech fname -> dict with at least "wav" and "label".
        music_samples: music fname -> dict with at least "wav" and "label".
        audio_viz: if True, show audio players (including one for the speech
            query) instead of the music labels.
        top_k: number of nearest music clips shown per query.
        speech_rate: sample rate used to play speech waveforms.
        music_rate: sample rate used to play music waveforms.

    Renders an HTML table indexed by speech fname; returns None.
    """
    n_top = min(top_k, df_sim.shape[1])
    music_cols = [f"top{rank} music" for rank in range(1, n_top + 1)]
    # Explicit column order also keeps set_index("speech") valid for an empty df_sim.
    columns = ["speech", "speech_emotion"]
    if audio_viz:
        columns.append("speech_wav")
    columns += music_cols

    rows = []
    for speech_fname, sims in df_sim.iterrows():
        speech = speech_samples[speech_fname]
        row = {"speech": speech_fname, "speech_emotion": speech["label"]}
        if audio_viz:
            # Audio objects are only built when they are actually rendered.
            row["speech_wav"] = _audio_player_html(speech["wav"], speech_rate)
        ranked = sims.sort_values(ascending=False).head(n_top).index
        for col, music_fname in zip(music_cols, ranked):
            music = music_samples[music_fname]
            row[col] = (_audio_player_html(music["wav"], music_rate)
                        if audio_viz else music["label"])
        rows.append(row)

    table = pd.DataFrame(rows, columns=columns).set_index("speech")
    # escape=False so the <audio> markup is rendered rather than shown as text.
    ipd.display(HTML(table.to_html(escape=False)))

In [105]:
demo(df_sim=df_sim, speech_samples=speech_samples, music_samples=music_samples, audio_viz=False)  # label view

Unnamed: 0_level_0,speech_emotion,top1 music,top2 music,top3 music
speech,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
Ses05F_impro05_F023,angry,scary,angry,angry
Ses05M_impro03_M010,happy,happy,happy,happy
Ses05M_script01_2_M001,neutral,noise,noise,noise
Ses05F_script01_3_F006,sad,sad,sad,sad


In [107]:
# Set to True to render playable audio players instead of music labels
# (the audio is inlined into the HTML output, so the notebook grows large).
SHOW_AUDIO_PLAYERS = False
if SHOW_AUDIO_PLAYERS:
    demo(df_sim, speech_samples, music_samples, audio_viz=True)