In [3]:
from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector
from datasets import load_dataset
import torch

In [4]:
dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")

feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('microsoft/wavlm-base-plus-sv')
model = WavLMForXVector.from_pretrained('microsoft/wavlm-base-plus-sv')

# audio files are decoded on the fly
audio = [x["array"] for x in dataset[:2]["audio"]]
inputs = feature_extractor(audio, padding=True, return_tensors="pt")
embeddings = model(**inputs).embeddings
embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu()

# the resulting embeddings can be used for cosine similarity-based retrieval
cosine_sim = torch.nn.CosineSimilarity(dim=-1)
similarity = cosine_sim(embeddings[0], embeddings[1])
threshold = 0.86  # the optimal threshold is dataset-dependent
if similarity < threshold:
    print("Speakers are not the same!")

Downloading builder script:   0%|          | 0.00/5.17k [00:00<?, ?B/s]

Downloading and preparing dataset librispeech_asr_demo/clean to /home/yangwenhao/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_demo/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/9.08M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Dataset librispeech_asr_demo downloaded and prepared to /home/yangwenhao/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_demo/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b. Subsequent calls will reuse this data.


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


In [7]:
inputs['input_values'].shape

torch.Size([2, 93680])

In [10]:
audio[0].shape

(93680,)

In [17]:
outputs = model.wavlm(**inputs,
                      output_attentions=None,
                      output_hidden_states=True,
                      return_dict=True,)

In [45]:
model.wavlm

WavLMModel(
  (feature_extractor): WavLMFeatureEncoder(
    (conv_layers): ModuleList(
      (0): WavLMGroupNormConvLayer(
        (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
        (activation): GELUActivation()
        (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
      )
      (1): WavLMNoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
      (2): WavLMNoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
      (3): WavLMNoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
      (4): WavLMNoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
      (5): WavLMNoLayerNormConvLayer(
       

In [30]:
hidden_states = torch.stack(outputs[2], dim=1) #.shape

In [14]:
model.config.use_return_dict

True

In [15]:
model.config.use_weighted_layer_sum

True

In [28]:
norm_weights = torch.nn.functional.softmax(model.layer_weights, dim=-1)
print(norm_weights.shape)

torch.Size([13])


In [32]:
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)

In [34]:
print(hidden_states.shape)

torch.Size([2, 292, 768])


In [37]:
model.projector

Linear(in_features=768, out_features=512, bias=True)

In [35]:
hidden_states = model.projector(hidden_states)
print(hidden_states.shape)

torch.Size([2, 292, 512])


In [44]:
model.tdnn[0]

TDNNLayer(
  (kernel): Linear(in_features=2560, out_features=512, bias=True)
  (activation): ReLU()
)

In [43]:
print(hidden_states.shape)
model.tdnn[0](hidden_states).shape

torch.Size([2, 292, 512])


torch.Size([2, 288, 512])