In [1]:
from transformers import Wav2Vec2FeatureExtractor
from transformers import AutoModel
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchaudio
import torchaudio.transforms as T

import librosa

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# loading our model weights
model = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True)
model = model.to("cuda")

# loading the corresponding preprocessor config
processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-95M",trust_remote_code=True)



Some weights of the model checkpoint at m-a-p/MERT-v1-95M were not used when initializing MERTModel: ['encoder.pos_conv_embed.conv.weight_g', 'encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing MERTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MERTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of MERTModel were not initialized from the model checkpoint at m-a-p/MERT-v1-95M and are newly initialized: ['encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
So

In [3]:
# load audio files
audio_files = ['rockin_around.mp3', 'test_song.mp3']
input_audio = []

for audio_file in audio_files:
    # audio, sampling_rate = librosa.load(audio_file, sr=None)
    audio, sampling_rate = torchaudio.load(audio_file, backend='ffmpeg')
    audio = audio[0] # only one channel
    
    first_nonzero_index = next((index for index, value in enumerate(audio) if value != 0), None)
    print(first_nonzero_index)

    resample_rate = processor.sampling_rate
    # make sure the sample_rate aligned
    if resample_rate != sampling_rate:
        print(f'setting rate from {sampling_rate} to {resample_rate}')
        resampler = T.Resample(sampling_rate, resample_rate)
    else:
        resampler = None

    # audio file is decoded on the fly
    if resampler is not None:
        audio = resampler(audio)
    
    input_audio.append(audio.numpy())
        

8641
setting rate from 44100 to 24000
6960
setting rate from 44100 to 24000


In [4]:
input_audio[0].shape

(3021845,)

In [25]:
inputs = processor(input_audio, sampling_rate=resample_rate, return_tensors="pt", padding=True)
inputs = inputs.to("cuda")

with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

In [26]:
inputs.keys()

dict_keys(['input_values', 'attention_mask'])

In [27]:
# take a look at the output shape, there are 13 layers of representation
# each layer performs differently in different downstream tasks, you should choose empirically
all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze()
print(all_layer_hidden_states.shape) # [13 layer, Time steps, 768 feature_dim]

# for utterance level classification tasks, you can simply reduce the representation in time
time_reduced_hidden_states = all_layer_hidden_states.mean(-2)
print(time_reduced_hidden_states.shape) # [13, 768]


torch.Size([13, 9022, 768])
torch.Size([13, 768])


In [22]:
trunc = time_reduced_hidden_states.permute(1,0,2)[1]

In [28]:
non_trunc = time_reduced_hidden_states.cpu().numpy()

In [32]:
trunc

tensor([[-0.1839,  0.0560,  0.3645,  ..., -0.1801, -0.2498,  0.0766],
        [-0.1308, -0.1409,  0.5574,  ..., -0.2917, -0.0786,  0.0187],
        [ 0.0600, -0.0027,  0.3828,  ..., -0.0912, -0.0442, -0.0882],
        ...,
        [-0.0450, -0.1034,  0.0135,  ..., -0.1094, -0.0755, -0.0220],
        [ 0.0359, -0.0826, -0.0972,  ..., -0.1007, -0.0177,  0.0244],
        [-0.0394, -0.0164, -0.0542,  ...,  0.1010, -0.1002,  0.0485]],
       device='cuda:0')

In [33]:
non_trunc

array([[-0.18856402,  0.05180721,  0.4061588 , ..., -0.1926618 ,
        -0.25924715,  0.07684468],
       [-0.13712476, -0.13524362,  0.60284626, ..., -0.30100065,
        -0.07869767,  0.0167464 ],
       [ 0.05161031,  0.00725114,  0.42961738, ..., -0.09042618,
        -0.03980322, -0.09716687],
       ...,
       [-0.0262141 , -0.09254253,  0.02330886, ..., -0.12008666,
        -0.06740732, -0.01618211],
       [ 0.0374712 , -0.0803475 , -0.08498251, ..., -0.1117202 ,
        -0.01279834,  0.03500533],
       [-0.04584039, -0.01758612, -0.03769425, ...,  0.09866826,
        -0.10444974,  0.05512224]], dtype=float32)

In [35]:
F.mse_loss(trunc, non_trunc)

TypeError: 'int' object is not callable