In [1]:
import torch
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
from datasets import load_dataset

In [2]:
# Run on the first CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

In [3]:
# Hugging Face Hub id of the MMS language-identification checkpoint; both the feature
# extractor and the classification model below are loaded from it.
version: str = "facebook/mms-lid-126"

# Dataset

Audio I/O note (adapted from the [ImageBind usage instructions](https://github.com/facebookresearch/ImageBind#usage)):

Windows users may need to install `librosa` and `soundfile` to read and write audio files (thanks @congyue1977):

`pip install soundfile librosa`

In [4]:
# Small LibriSpeech "clean" validation dummy split (73 rows), ordered by utterance id so
# that positional access such as dataset[0] / dataset[:2] is deterministic across runs.
dataset = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
).sort("id")
dataset

Dataset({
    features: ['file', 'audio', 'text', 'speaker_id', 'chapter_id', 'id'],
    num_rows: 73
})

In [5]:
# Sampling rate declared by the dataset's audio column; passed to the feature extractor below.
audio_feature = dataset.features["audio"]
sampling_rate = audio_feature.sampling_rate
sampling_rate

16000

In [6]:
# Inspect one example: file path, decoded audio (array + sampling rate), transcript and ids.
first_example = dataset[0]
first_example

{'file': 'C:/Users/Administrator/.cache/huggingface/datasets/downloads/extracted/b49df5cb4e26d70a35c542fbe0eadc8bfee0f971809886d2131859668faeba1c/dev_clean/1272/128104\\1272-128104-0000.flac',
 'audio': {'path': 'C:/Users/Administrator/.cache/huggingface/datasets/downloads/extracted/b49df5cb4e26d70a35c542fbe0eadc8bfee0f971809886d2131859668faeba1c/dev_clean/1272/128104\\1272-128104-0000.flac',
  'array': array([0.00238037, 0.0020752 , 0.00198364, ..., 0.00042725, 0.00057983,
         0.0010376 ]),
  'sampling_rate': 16000},
 'text': 'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL',
 'speaker_id': 1272,
 'chapter_id': 128104,
 'id': '1272-128104-0000'}

In [7]:
# Decoded waveforms (the "array" entry of each audio dict) for the first two examples.
first_two_audio = dataset[:2]["audio"]
[clip["array"] for clip in first_two_audio]

[array([0.00238037, 0.0020752 , 0.00198364, ..., 0.00042725, 0.00057983,
        0.0010376 ]),
 array([-1.52587891e-04, -9.15527344e-05, -1.83105469e-04, ...,
         9.76562500e-04,  9.46044922e-04, -4.88281250e-04])]

In [8]:
# Transcripts of the first two examples. Slicing the dataset already yields a dict of
# per-column lists, so the "text" column is a list as-is; no comprehension is needed.
dataset[:2]["text"]

['MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL',
 "NOR IS MISTER QUILTER'S MANNER LESS INTERESTING THAN HIS MATTER"]

# AutoFeatureExtractor

In [9]:
# AutoFeatureExtractor.from_pretrained is a factory: it returns the concrete extractor class
# named in the checkpoint's preprocessor config (Wav2Vec2FeatureExtractor for this model, as
# the output shows), never an AutoFeatureExtractor instance. The variable is therefore left
# unannotated instead of carrying a misleading type hint.
feature_extractor = AutoFeatureExtractor.from_pretrained(version)
feature_extractor

Downloading (…)rocessor_config.json:   0%|          | 0.00/212 [00:00<?, ?B/s]

Wav2Vec2FeatureExtractor {
  "do_normalize": true,
  "feature_extractor_type": "Wav2Vec2FeatureExtractor",
  "feature_size": 1,
  "padding_side": "right",
  "padding_value": 0,
  "return_attention_mask": true,
  "sampling_rate": 16000
}

## Preprocess: batch the waveforms with the feature extractor

In [10]:
# Normalise the first two waveforms and right-pad them into one batch.
# BatchFeature.to casts only floating-point tensors (input_values) to the given dtype;
# integer tensors such as attention_mask keep their dtype and are only moved to `device`.
# float16 is used only on CUDA. Half-precision CPU kernels are limited (some ops are not
# implemented, others are slow), so the CPU fallback stays in float32.
# Keep this dtype in line with the one the model is loaded in.
compute_dtype = torch.float16 if device.type == "cuda" else torch.float32
inputs = feature_extractor(
    [d["array"] for d in dataset[:2]["audio"]],
    sampling_rate=sampling_rate,
    padding=True,
    return_tensors="pt"
).to(device, compute_dtype)
inputs

{'input_values': tensor([[ 0.0386,  0.0337,  0.0322,  ...,  0.0070,  0.0095,  0.0169],
        [-0.0015, -0.0008, -0.0019,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0', dtype=torch.int32)}

In [11]:
# (batch, padded_num_samples): both clips are right-padded to the length of the longest one.
inputs["input_values"].size()

torch.Size([2, 93680])

# Wav2Vec2ForSequenceClassification (language identification)

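MMS provides several language-identification (LID) checkpoints that differ in how many languages they can recognize; this notebook uses `facebook/mms-lid-126`, which recognizes 126 languages.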

In [14]:
# Load the MMS language-identification classifier.
# Weights are loaded in float16 on CUDA to halve memory. On the CPU fallback they stay
# float32 because half-precision CPU kernels are limited (some ops are unimplemented or slow).
# Keep this dtype in line with the one the inputs are cast to.
compute_dtype = torch.float16 if device.type == "cuda" else torch.float32
model: Wav2Vec2ForSequenceClassification = Wav2Vec2ForSequenceClassification.from_pretrained(
    version, torch_dtype=compute_dtype
).to(device)
model

Downloading model.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

Wav2Vec2ForSequenceClassification(
  (wav2vec2): Wav2Vec2Model(
    (feature_extractor): Wav2Vec2FeatureEncoder(
      (conv_layers): ModuleList(
        (0): Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
        (1-4): 4 x Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
        (5-6): 2 x Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
      )
    )
    (feature_projection): Wav2Vec2FeatureProjection(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=

In [15]:
# Inference only: eval() switches layers such as dropout to evaluation behaviour, and
# inference_mode() disables autograd tracking for the forward pass.
model.eval()
with torch.inference_mode():
    outputs = model(
        input_values=inputs["input_values"],
        attention_mask=inputs["attention_mask"],
    )
outputs

SequenceClassifierOutput(loss=None, logits=tensor([[ 6.3095e-01,  5.6210e-01,  1.1455e+01,  4.6605e+00,  2.0735e+00,
         -1.7886e+00,  2.5402e-01,  1.3036e+00, -1.4132e+00, -1.0144e+00,
         -2.1347e+00,  8.1125e-01, -1.2697e+00,  4.8693e-01,  4.3894e-02,
         -9.0640e-01, -1.4628e-01, -2.1953e+00,  8.0087e-01, -1.2776e+00,
          1.2712e+00, -2.1888e-01,  2.4856e-01, -2.3980e+00, -3.3244e-01,
         -5.9368e-01, -1.3560e+00,  3.4993e-01, -1.4715e+00, -1.1132e+00,
         -2.2107e+00, -8.6330e-02, -2.7320e+00, -4.6908e-01,  3.4111e-01,
          8.3798e-01, -2.7987e+00, -1.6982e+00,  2.0136e-02, -6.3702e-01,
          1.7377e+00, -1.9903e-01,  1.2074e-01,  7.2755e-01, -6.1422e-01,
         -4.7519e-01, -2.1832e+00, -1.5885e+00,  1.6448e+00,  1.3297e+00,
          1.3466e+00,  3.1740e+00, -9.3507e-01,  3.1203e-01,  1.3561e+00,
          2.4167e+00, -1.8261e-01, -1.1498e+00,  7.0182e-01, -9.3038e-01,
         -1.8628e+00,  4.6539e+00, -9.2118e-01,  1.4500e+00, -2.4243e

In [16]:
# Raw per-language scores: one row per clip, one column per class label.
logits = outputs["logits"]
print(logits.size())

torch.Size([2, 126])


In [17]:
# Index of the highest-scoring language for each clip in the batch.
predicted_class_ids = torch.argmax(logits, dim=-1)
predicted_class_ids

tensor([2, 2], device='cuda:0')

In [18]:
# Mapping from class index to three-letter language code (e.g. 2 -> 'eng').
id_to_language = model.config.id2label
id_to_language

{0: 'ara',
 1: 'cmn',
 2: 'eng',
 3: 'spa',
 4: 'fra',
 5: 'mlg',
 6: 'swe',
 7: 'por',
 8: 'vie',
 9: 'ful',
 10: 'sun',
 11: 'asm',
 12: 'ben',
 13: 'zlm',
 14: 'kor',
 15: 'ind',
 16: 'hin',
 17: 'tuk',
 18: 'urd',
 19: 'aze',
 20: 'slv',
 21: 'mon',
 22: 'hau',
 23: 'tel',
 24: 'swh',
 25: 'bod',
 26: 'rus',
 27: 'tur',
 28: 'heb',
 29: 'mar',
 30: 'som',
 31: 'tgl',
 32: 'tat',
 33: 'tha',
 34: 'cat',
 35: 'ron',
 36: 'mal',
 37: 'bel',
 38: 'pol',
 39: 'yor',
 40: 'nld',
 41: 'bul',
 42: 'hat',
 43: 'afr',
 44: 'isl',
 45: 'amh',
 46: 'tam',
 47: 'hun',
 48: 'hrv',
 49: 'lit',
 50: 'cym',
 51: 'fas',
 52: 'mkd',
 53: 'ell',
 54: 'bos',
 55: 'deu',
 56: 'sqi',
 57: 'jav',
 58: 'nob',
 59: 'uzb',
 60: 'snd',
 61: 'lat',
 62: 'nya',
 63: 'grn',
 64: 'mya',
 65: 'orm',
 66: 'lin',
 67: 'hye',
 68: 'yue',
 69: 'pan',
 70: 'jpn',
 71: 'kaz',
 72: 'npi',
 73: 'kat',
 74: 'guj',
 75: 'kan',
 76: 'tgk',
 77: 'ukr',
 78: 'ces',
 79: 'lav',
 80: 'bak',
 81: 'khm',
 82: 'fao',
 83: 'glg',
 8

In [19]:
# Print the predicted language code for every clip in the batch. This works for any
# batch size, not just the two clips hard-coded before.
for class_id in predicted_class_ids.tolist():
    print(model.config.id2label[class_id])

eng
eng
