Toward Universal Text-to-Music Retrieval

This is a PyTorch implementation of Toward Universal Text-to-Music Retrieval for multi-modal music representation learning. Check out our demo.

Toward Universal Text-to-Music Retrieval
SeungHeon Doh, Minz Won, Keunwoo Choi, Juhan Nam
To appear at ICASSP 2023

TL;DR

  • We introduce effective design choices for universal text-to-music retrieval and assess recent text-music representation learning frameworks with a carefully designed dataset and downstream tasks.
  • Our proposed stochastic text representation achieves robust performance across tag-level, caption-level, and zero-shot query retrieval.
  • Contrastive models achieve better performance than triplet models on both retrieval and downstream tasks.
  • Reproducible code, pre-trained models, the MSD-ECALS music-caption dataset, and the downstream benchmark are available online for future research.

Main Results

The following results are based on pre-training with the MSD-ECALS dataset. Pre-trained models and configs can be found at Zenodo-Pretrained.

Tag-based retrieval is evaluated on 50-tag and 1054-tag vocabularies (ROC-AUC/PR-AUC); language-based retrieval is evaluated on 1000 music-caption pairs (R@1, R@5, R@10, mAP, and MedR, i.e. median rank, lower is better).

| Model Type | Text Enc. | Text Rep. | 50 Tag (ROC/PR) | 1054 Tag (ROC/PR) | R@1 | R@5 | R@10 | mAP | MedR |
|---|---|---|---|---|---|---|---|---|---|
| Classification | Binary | Tag | 90.2/39.5 | 86.4/8.8 | 4.0 | 13.8 | 20.1 | 8.3 | 86 |
| Triplet | GloVe | Tag | 89.2/36.0 | 82.6/6.1 | 2.8 | 11.2 | 18.6 | 6.6 | 51.5 |
| Triplet | GloVe | Caption | 88.6/37.1 | 76.8/5.3 | 5.4 | 22.1 | 35.0 | 13.0 | 17.0 |
| Triplet | GloVe | Stochastic | 89.2/37.6 | 81.6/6.2 | 6.4 | 21.8 | 32.7 | 12.8 | 19.5 |
| Triplet | BERT | Tag | 86.9/30.2 | 81.7/5.1 | 1.6 | 6.2 | 12.0 | 3.9 | 68.0 |
| Triplet | BERT | Caption | 87.7/35.0 | 78.9/5.4 | 6.7 | 23.6 | 36.6 | 14.1 | 16.0 |
| Triplet | BERT | Stochastic | 88.4/35.0 | 83.6/6.3 | 6.6 | 25.1 | 39.4 | 14.6 | 16.0 |
| Contrastive | BERT | Tag | 90.6/40.2 | 86.4/8.8 | 2.5 | 13.7 | 22.5 | 7.4 | 47.0 |
| Contrastive | BERT | Caption | 87.0/32.5 | 77.6/5.1 | 6.8 | 25.4 | 38.4 | 15.3 | 17.0 |
| Contrastive | BERT | Stochastic | 89.8/38.0 | 84.8/7.7 | 10.2 | 29.8 | 42.8 | 18.7 | 13.0 |

Note:

  • See our paper for more results on different benchmarks, including MTAT, MTG-Jamendo, FMA, GTZAN, Emotify, and KVT.

Requirements

  1. Install Python and PyTorch:

    • python==3.8
    • torch==1.12.1 (please install the build that matches your CUDA version)

  2. Create an environment and install the remaining requirements:

conda create -n YOUR_ENV_NAME python=3.8
conda activate YOUR_ENV_NAME
pip install -e .

Using Pretrained Model & Inference

wget https://zenodo.org/record/7322135/files/mtr.tar.gz
tar -zxvf mtr.tar.gz 

Please refer to notebook/demo.ipynb for MSD test-set tag, sentence, and unseen-query retrieval. Below is the audio and text embedding extraction code.

import numpy as np
import torch

from mtr.utils.demo_utils import get_model
from mtr.utils.audio_utils import load_audio, STR_CH_FIRST

framework = "contrastive"
text_type = "bert"
text_rep = "stochastic"

# load the pre-trained model, its tokenizer, and its config
model, tokenizer, config = get_model(framework=framework, text_type=text_type, text_rep=text_rep)

def text_infer(query, model, tokenizer):
    # tokenize the text query and encode it with the text encoder
    text_input = tokenizer(query, return_tensors="pt")['input_ids']
    with torch.no_grad():
        text_embs = model.encode_bert_text(text_input, None)
    return text_embs

def audio_infer(audio_path, model, sr=16000, duration=9.91):
    # load the waveform at the target sample rate and downmix to mono
    audio, _ = load_audio(
        path=audio_path,
        ch_format=STR_CH_FIRST,
        sample_rate=sr,
        downmix_to_mono=True
    )
    # split the waveform into non-overlapping chunks of `duration` seconds
    input_size = int(duration * sr)
    n_chunks = int(len(audio) // input_size)
    audio = np.stack([np.array(audio[i * input_size : (i + 1) * input_size]) for i in range(n_chunks)]).astype('float32')
    audio_tensor = torch.from_numpy(audio)
    # encode every chunk and average into a single track-level embedding
    with torch.no_grad():
        z_audio = model.encode_audio(audio_tensor)
    audio_embs = z_audio.mean(0).detach().cpu()
    return audio_embs

query = "fusion jazz with synth, bass, drums, saxophone"
audio_path = "your_audio"
text_embs = text_infer(query, model, tokenizer)
audio_embs = audio_infer(audio_path, model)
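
Since the contrastive model places text and audio in a shared embedding space, retrieval amounts to ranking tracks by similarity to the query embedding. The snippet below is a minimal sketch rather than repository code: the audio paths are placeholders and cosine similarity is assumed as the ranking score (see notebook/demo.ipynb for the full retrieval demo).

import torch
import torch.nn.functional as F

# placeholder audio files to search over (replace with your own tracks)
audio_paths = ["track_a.wav", "track_b.wav", "track_c.wav"]

# build a (num_tracks, dim) matrix of track-level audio embeddings
track_embs = torch.stack([audio_infer(path, model) for path in audio_paths])

# embed the text query and rank tracks by cosine similarity
query_embs = text_infer("fusion jazz with synth, bass, drums, saxophone", model, tokenizer)
scores = F.cosine_similarity(query_embs, track_embs)  # shape: (num_tracks,)
for rank, idx in enumerate(torch.argsort(scores, descending=True).tolist(), start=1):
    print(f"{rank}. {audio_paths[idx]} (similarity={scores[idx].item():.3f})")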

Text Representation

From our empirical study, we find a strong association between the text representation used at the training stage and the text query type seen at the test stage. We therefore propose a stochastic text representation: during training, we select K words from an L-word caption, where K is sampled uniformly at random from the integers between 1 (tag length) and L (caption length). Unlike dropout, which determines the kept length through a fixed probability, stochastic sampling yields a dynamic input length.

def text_load(self, tag_list):
    """
    input:  tag_list = list of tags
    output: text = text string (a single-tag list when text_rep == "tag")
    """
    if self.text_rep == "caption":
        # full caption: all tags, shuffled during training
        if self.split == "TRAIN":
            random.shuffle(tag_list)
        text = ", ".join(tag_list)
    elif self.text_rep == "tag":
        # a single randomly chosen tag
        text = [random.choice(tag_list)]
    elif self.text_rep == "stochastic":
        # sample k ~ Uniform{1, ..., L} tags and join them
        k = random.choice(range(1, len(tag_list) + 1))
        sampled_tag_list = random.sample(tag_list, k)
        text = ", ".join(sampled_tag_list)
    return text
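
For illustration, the stochastic branch can be run standalone; the tag list below is hypothetical, and because the output varies per call, the model sees a mixture of tag-like and caption-like inputs during training.

import random

tag_list = ["jazz", "saxophone", "mellow", "night"]  # hypothetical tags for one track

for _ in range(3):
    # k is drawn uniformly from {1, ..., len(tag_list)}, so the text length is dynamic
    k = random.choice(range(1, len(tag_list) + 1))
    print(", ".join(random.sample(tag_list, k)))
# possible outputs: "mellow" / "night, jazz, saxophone" / "saxophone, jazz, mellow, night"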

1. Text-Music Pre-training (Quick start: mtr/contrastive/main.sh)

Download the ECALS (Extended Cleaned tag and Artist-Level Stratified split) dataset & MSD audio: Link

cd mtr/{triplet or contrastive}
# pre-train the model
python train.py --text_type {bert,glove} --text_rep {tag,caption,stochastic} --data_dir {msd-subsets} --multiprocessing-distributed

# evaluation on ECALS dataset (single, multi query)
python test.py --text_type {bert,glove} --text_rep {tag,caption,stochastic} --data_dir {msd-subsets}

Following the MoCo v3 repo, only multi-GPU DistributedDataParallel training is supported; single-GPU or DataParallel training is not supported. The code has been modified to better suit the multi-node setting.

Other pre-training settings are:

parser.add_argument("--duration", default=9.91, type=float)  # audio chunk length in seconds
parser.add_argument("--sr", default=16000, type=int)  # sample rate
parser.add_argument("--mel_dim", default=128, type=int)  # number of mel bins
parser.add_argument("--n_fft", default=1024, type=int)
parser.add_argument("--win_length", default=1024, type=int)
parser.add_argument("--frontend", default="cnn", type=str)
parser.add_argument("--mix_type", default="cf", type=str)
parser.add_argument("--audio_rep", default="mel", type=str)
parser.add_argument("--cos", default=True, type=bool)
parser.add_argument("--attention_nlayers", default=4, type=int)
parser.add_argument("--attention_ndim", default=256, type=int)
parser.add_argument("--temperature", default=0.2, type=float)  # contrastive loss temperature
parser.add_argument("--mlp_dim", default=128, type=int)  # joint embedding dimension
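
To make a couple of these settings concrete, the sketch below shows how the temperature and the joint embedding dimension (--mlp_dim) enter a generic InfoNCE-style contrastive loss between a batch of paired audio and text embeddings. It is a simplified illustration, not the repository's exact loss implementation.

import torch
import torch.nn.functional as F

def contrastive_loss(z_audio, z_text, temperature=0.2):
    """Symmetric InfoNCE-style loss over a batch of paired embeddings.

    z_audio, z_text: (batch, mlp_dim) projections, e.g. mlp_dim=128.
    """
    z_audio = F.normalize(z_audio, dim=-1)
    z_text = F.normalize(z_text, dim=-1)
    logits = z_audio @ z_text.t() / temperature  # (batch, batch) similarity matrix
    labels = torch.arange(z_audio.size(0))       # matching pairs lie on the diagonal
    return 0.5 * (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels))

# toy usage with random projections of the default joint embedding size
loss = contrastive_loss(torch.randn(8, 128), torch.randn(8, 128), temperature=0.2)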

2. Zero-shot Transfer and Probing (Quick start: mtr/transfer/main.sh)

Download the downstream datasets and preprocessing code from GitHub; the data splits and metadata annotations are released on Zenodo.

The downstream benchmark consists of MTAT, FMA, MTG-JAMENDO, GTZAN, KVT, and Emotify.

cd mtr/transfer
# extract embedding
python extractor.py --framework {classification, triplet, contrastive} --text_type {binary, glove, bert} --text_rep {tag,caption,stochastic} --eval_dataset $DATASET

# eval zero-shot transfer
python eval_zs.py --framework {triplet, contrastive} --text_type {binary, glove, bert} --text_rep {tag,caption,stochastic} --eval_dataset $DATASET

# train shallow classifier
python train_probing.py --probe_type {linear, mlp} --framework {classification, triplet, contrastive} --text_type {binary, glove, bert} --text_rep {tag,caption,stochastic} --eval_dataset $DATASET

# eval shallow classifier
python eval_probing.py --probe_type {linear, mlp} --framework {classification, triplet, contrastive} --text_type {binary, glove, bert} --text_rep {tag,caption,stochastic} --eval_dataset $DATASET
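
Conceptually, zero-shot transfer uses the downstream dataset's class labels as text queries: each label is embedded with the text encoder, and a track is scored by its similarity to every label embedding. The sketch below illustrates this with the text_infer/audio_infer helpers defined above; the label vocabulary and the cosine-similarity scoring are assumptions for illustration, not the exact eval_zs.py pipeline.

import torch
import torch.nn.functional as F

# hypothetical genre vocabulary of a downstream dataset (a GTZAN-like label set)
labels = ["blues", "classical", "country", "disco", "hiphop",
          "jazz", "metal", "pop", "reggae", "rock"]

# embed every class label with the text encoder: (num_labels, dim)
label_embs = torch.cat([text_infer(label, model, tokenizer) for label in labels])

# embed a track and pick the most similar label as the zero-shot prediction
track_emb = audio_infer("your_audio", model)                       # (dim,)
scores = F.cosine_similarity(track_emb.unsqueeze(0), label_embs)   # (num_labels,)
print("zero-shot prediction:", labels[scores.argmax().item()])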

License

This project is under the CC-BY-NC 4.0 license. See LICENSE for details.

Acknowledgement

We would like to thank MoCo v3 for its training code and jukemir-CodifiedLM for its evaluation protocol.

Citation

Please consider citing our paper in your publications if this project helps your research. The BibTeX reference is as follows.

@inproceedings{doh2023toward,
  title={Toward Universal Text-to-Music Retrieval},
  author={Doh, SeungHeon and Won, Minz and Choi, Keunwoo and Nam, Juhan},
  booktitle={ICASSP 2023-2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
  year={2023}
}
