In [1]:
from transformers import Wav2Vec2FeatureExtractor
from transformers import AutoModel
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchaudio
import torchaudio.transforms as T

import librosa

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# loading our model weights
model = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True)
model = model.to("cuda")

# loading the corresponding preprocessor config
processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-95M",trust_remote_code=True)



Some weights of the model checkpoint at m-a-p/MERT-v1-95M were not used when initializing MERTModel: ['encoder.pos_conv_embed.conv.weight_v', 'encoder.pos_conv_embed.conv.weight_g']
- This IS expected if you are initializing MERTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MERTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of MERTModel were not initialized from the model checkpoint at m-a-p/MERT-v1-95M and are newly initialized: ['encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# load audio files
audio_files = ['rockin_around.mp3', 'test_song.mp3']
input_audio = []

for audio_file in audio_files:
    # audio, sampling_rate = librosa.load(audio_file, sr=None)
    audio, sampling_rate = torchaudio.load(audio_file, backend='ffmpeg')
    audio = audio[0] # only one channel
    
    first_nonzero_index = next((index for index, value in enumerate(audio) if value != 0), None)
    print(first_nonzero_index)

    resample_rate = processor.sampling_rate
    # make sure the sample_rate aligned
    if resample_rate != sampling_rate:
        print(f'setting rate from {sampling_rate} to {resample_rate}')
        resampler = T.Resample(sampling_rate, resample_rate)
    else:
        resampler = None

    # audio file is decoded on the fly
    if resampler is not None:
        audio = resampler(audio)
    
    input_audio.append(audio.numpy())
        

8641
setting rate from 44100 to 24000
6960
setting rate from 44100 to 24000


In [7]:
input_audio[0].shape

(3021845,)

In [8]:
inputs = processor(input_audio, sampling_rate=resample_rate, return_tensors="pt", padding=True)
inputs = inputs.to("cuda")

with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

In [8]:
inputs.keys()

dict_keys(['input_values', 'attention_mask'])

In [9]:
# take a look at the output shape, there are 13 layers of representation
# each layer performs differently in different downstream tasks, you should choose empirically
all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze()
print(all_layer_hidden_states.shape) # [13 layer, Time steps, 768 feature_dim]

# for utterance level classification tasks, you can simply reduce the representation in time
time_reduced_hidden_states = all_layer_hidden_states.mean(-2)
print(time_reduced_hidden_states.shape) # [13, 768]


torch.Size([13, 2, 9443, 768])
torch.Size([13, 2, 768])


In [11]:
audio_embeds = all_layer_hidden_states.permute(1, 0, 2, 3)
audio_embeds_avg_pool = F.avg_pool1d(audio_embeds.permute(0, 2, 1), kernel_size=4, padding=1).permute(0, 2, 1)
audio_embeds_max_pool = F.max_pool1d(audio_embeds.permute(0, 2, 1), kernel_size=4, padding=1).permute(0, 2, 1)

RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 4 is not equal to len(dims) = 3

In [22]:
trunc = time_reduced_hidden_states.permute(1,0,2)[1]

In [28]:
non_trunc = time_reduced_hidden_states.cpu().numpy()

In [35]:
F.mse_loss(trunc, non_trunc)

TypeError: 'int' object is not callable