In [66]:
import torch
from transformers import WhisperForAudioClassification, WhisperFeatureExtractor, WhisperProcessor
from datasets import load_dataset

In [67]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda', index=0)

In [68]:
# Hugging Face Hub checkpoint id, used below for both the feature extractor and the
# classification model (a Whisper-medium language-ID model, judging by its name).
version = "sanchit-gandhi/whisper-medium-fleurs-lang-id"

# Dataset

Note (from the [ImageBind usage instructions](https://github.com/facebookresearch/ImageBind#usage)): Windows users may need to install `librosa` and `soundfile` to read and write audio files (thanks @congyue1977):

`pip install soundfile librosa`

In [69]:
# Tiny LibriSpeech dummy split (73 rows). Sorting by utterance id makes the row
# order deterministic, so dataset[0] / dataset[:2] below always pick the same clips.
dataset = (
    load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    .sort("id")
)
dataset

Using the latest cached version of the module from C:\Users\Administrator\.cache\huggingface\modules\datasets_modules\datasets\hf-internal-testing--librispeech_asr_dummy\d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b (last modified on Mon Aug  7 16:20:21 2023) since it couldn't be found locally at hf-internal-testing/librispeech_asr_dummy., or remotely on the Hugging Face Hub.


Dataset({
    features: ['file', 'audio', 'text', 'speaker_id', 'chapter_id', 'id'],
    num_rows: 73
})

In [70]:
# Sampling rate declared by the dataset's `audio` column; it is passed to the
# feature extractor so it can check the input rate.
audio_feature = dataset.features["audio"]
sampling_rate = audio_feature.sampling_rate
sampling_rate

16000

In [71]:
# Inspect one example: file path, decoded audio dict (array + sampling_rate),
# transcript text, and speaker / chapter / utterance ids.
dataset[0]

{'file': 'C:/Users/Administrator/.cache/huggingface/datasets/downloads/extracted/b49df5cb4e26d70a35c542fbe0eadc8bfee0f971809886d2131859668faeba1c/dev_clean/1272/128104\\1272-128104-0000.flac',
 'audio': {'path': 'C:/Users/Administrator/.cache/huggingface/datasets/downloads/extracted/b49df5cb4e26d70a35c542fbe0eadc8bfee0f971809886d2131859668faeba1c/dev_clean/1272/128104\\1272-128104-0000.flac',
  'array': array([0.00238037, 0.0020752 , 0.00198364, ..., 0.00042725, 0.00057983,
         0.0010376 ]),
  'sampling_rate': 16000},
 'text': 'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL',
 'speaker_id': 1272,
 'chapter_id': 128104,
 'id': '1272-128104-0000'}

In [72]:
# Decoded waveforms (NumPy arrays) of the first two examples.
[clip["array"] for clip in dataset[:2]["audio"]]

[array([0.00238037, 0.0020752 , 0.00198364, ..., 0.00042725, 0.00057983,
        0.0010376 ]),
 array([-1.52587891e-04, -9.15527344e-05, -1.83105469e-04, ...,
         9.76562500e-04,  9.46044922e-04, -4.88281250e-04])]

In [73]:
# Transcripts of the first two examples. The column slice is already iterable, so an
# identity comprehension is unnecessary; list() keeps the displayed value a list.
list(dataset[:2]["text"])

['MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL',
 "NOR IS MISTER QUILTER'S MANNER LESS INTERESTING THAN HIS MATTER"]

# WhisperFeatureExtractor

In [74]:
# Original note: WhisperProcessor has no corresponding model for this checkpoint,
# so only the feature extractor is loaded.
# A feature extractor is not an nn.Module: it has no device or dtype of its own, so
# `torch_dtype=` and `.to(device)` do not apply to it (`.to` does not exist there).
# The *returned feature batch* is moved and cast instead, in the next cell.
feature_extractor: WhisperFeatureExtractor = WhisperFeatureExtractor.from_pretrained(version)
feature_extractor

WhisperFeatureExtractor {
  "chunk_length": 30,
  "feature_extractor_type": "WhisperFeatureExtractor",
  "feature_size": 80,
  "hop_length": 160,
  "n_fft": 400,
  "n_samples": 480000,
  "nb_max_frames": 3000,
  "padding_side": "right",
  "padding_value": 0.0,
  "processor_class": "WhisperProcessor",
  "return_attention_mask": false,
  "sampling_rate": 16000
}

## Feature extraction

In [79]:
# Turn the first two waveforms into model input features. No explicit `padding` is
# given, and both clips come out at the fixed nb_max_frames=3000 width (shape
# [2, 80, 3000]). The batch is then moved to `device` and cast to float16 to match
# the half-precision model.
audio_batch = [clip["array"] for clip in dataset[:2]["audio"]]
inputs = feature_extractor(audio_batch, sampling_rate=sampling_rate, return_tensors="pt")
inputs = inputs.to(device, torch.float16)
inputs

{'input_features': tensor([[[ 1.1933e-01, -9.4576e-02, -1.0978e-01,  ..., -8.0603e-01,
          -8.0603e-01, -8.0603e-01],
         [ 4.9347e-04, -8.9271e-02, -6.7290e-02,  ..., -8.0603e-01,
          -8.0603e-01, -8.0603e-01],
         [-1.5326e-01, -2.0804e-01, -2.2227e-01,  ..., -8.0603e-01,
          -8.0603e-01, -8.0603e-01],
         ...,
         [-8.0603e-01, -8.0603e-01, -7.9997e-01,  ..., -8.0603e-01,
          -8.0603e-01, -8.0603e-01],
         [-8.0603e-01, -7.7211e-01, -8.0603e-01,  ..., -8.0603e-01,
          -8.0603e-01, -8.0603e-01],
         [-8.0603e-01, -8.0603e-01, -8.0603e-01,  ..., -8.0603e-01,
          -8.0603e-01, -8.0603e-01]],

        [[-4.6956e-01, -7.5109e-02,  2.7610e-02,  ..., -7.0427e-01,
          -7.0427e-01, -7.0427e-01],
         [-1.2772e-01, -2.0680e-02, -3.2390e-02,  ..., -7.0427e-01,
          -7.0427e-01, -7.0427e-01],
         [-3.1414e-01, -9.7058e-02, -1.8364e-01,  ..., -7.0427e-01,
          -7.0427e-01, -7.0427e-01],
         ...,
      

In [80]:
# (batch, feature_size=80, nb_max_frames=3000), per the extractor config above.
inputs["input_features"].size()

torch.Size([2, 80, 3000])

# WhisperForAudioClassification (language identification)

Whisper Encoder Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like SUPERB Keyword Spotting.

In [81]:
# Load the audio-classification checkpoint in half precision and place it on `device`.
model: WhisperForAudioClassification = (
    WhisperForAudioClassification
    .from_pretrained(version, torch_dtype=torch.float16)
    .to(device)
)
model

WhisperForAudioClassification(
  (encoder): WhisperEncoder(
    (conv1): Conv1d(80, 1024, kernel_size=(3,), stride=(1,), padding=(1,))
    (conv2): Conv1d(1024, 1024, kernel_size=(3,), stride=(2,), padding=(1,))
    (embed_positions): Embedding(1500, 1024)
    (layers): ModuleList(
      (0-23): 24 x WhisperEncoderLayer(
        (self_attn): WhisperAttention(
          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (activation_fn): GELUActivation()
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (final_layer_norm): LayerNorm((1024,), eps=1e-05

In [82]:
# Switch to inference behaviour (train(False) is what eval() does) and run a forward
# pass without autograd bookkeeping.
model.train(False)
with torch.inference_mode():
    outputs = model(**inputs)
outputs

SequenceClassifierOutput(loss=None, logits=tensor([[ 1.8480e+00,  4.7897e-01, -2.3965e+00, -2.4014e+00,  4.5726e-01,
         -1.6911e+00, -1.4168e+00,  4.8689e-01, -1.6355e+00,  1.3438e+00,
          9.4267e-02, -6.3596e-01,  1.6233e+00, -2.5041e-01, -1.5681e+00,
          1.2616e+00,  2.6256e-01,  2.5062e-01,  1.1798e+00,  1.5179e+01,
          1.2782e+00,  1.4662e+00,  7.8219e-01,  1.1033e+00, -3.3376e-01,
         -2.4659e+00,  5.6272e-01,  9.8932e-01, -4.0164e-01, -1.7247e+00,
          8.7482e-01,  2.8559e-01, -1.4991e+00,  1.8257e-01,  6.5450e-01,
          1.4379e+00,  2.2578e+00,  1.2676e+00,  1.2480e+00, -9.0694e-01,
         -9.7850e-01, -1.3821e-01, -4.2959e-01,  8.8256e-01, -1.8457e+00,
          1.1700e+00,  2.8078e-01, -5.6256e-01,  1.4715e+00, -7.5645e-01,
          2.9968e-01,  6.5087e-01,  3.2428e-01, -7.2840e-02, -1.7354e+00,
         -7.1813e-01,  2.5838e-01,  8.1692e-01, -7.9979e-01, -1.9048e+00,
         -4.5969e-01,  9.5149e-02,  8.2318e-01,  8.9126e-01, -3.0588e

In [None]:
# One row per clip; the column index is the class id looked up in model.config.id2label.
logits = outputs.logits
logits.shape

torch.Size([2, 102])


In [None]:
# Highest-scoring class index for each clip.
predicted_class_ids = torch.argmax(logits, dim=-1)
predicted_class_ids

tensor([19, 19], device='cuda:0')

In [None]:
# Mapping from class index (logit column) to language name.
model.config.id2label

{0: 'Afrikaans',
 1: 'Amharic',
 2: 'Arabic',
 3: 'Assamese',
 4: 'Asturian',
 5: 'Azerbaijani',
 6: 'Belarusian',
 7: 'Bulgarian',
 8: 'Bengali',
 9: 'Bosnian',
 10: 'Catalan',
 11: 'Cebuano',
 12: 'Sorani-Kurdish',
 13: 'Mandarin Chinese',
 14: 'Czech',
 15: 'Welsh',
 16: 'Danish',
 17: 'German',
 18: 'Greek',
 19: 'English',
 20: 'Spanish',
 21: 'Estonian',
 22: 'Persian',
 23: 'Fula',
 24: 'Finnish',
 25: 'Filipino',
 26: 'French',
 27: 'Irish',
 28: 'Galician',
 29: 'Gujarati',
 30: 'Hausa',
 31: 'Hebrew',
 32: 'Hindi',
 33: 'Croatian',
 34: 'Hungarian',
 35: 'Armenian',
 36: 'Indonesian',
 37: 'Igbo',
 38: 'Icelandic',
 39: 'Italian',
 40: 'Japanese',
 41: 'Javanese',
 42: 'Georgian',
 43: 'Kamba',
 44: 'Kabuverdianu',
 45: 'Kazakh',
 46: 'Khmer',
 47: 'Kannada',
 48: 'Korean',
 49: 'Kyrgyz',
 50: 'Luxembourgish',
 51: 'Ganda',
 52: 'Lingala',
 53: 'Lao',
 54: 'Lithuanian',
 55: 'Luo',
 56: 'Latvian',
 57: 'Maori',
 58: 'Macedonian',
 59: 'Malayalam',
 60: 'Mongolian',
 61: 'Mara

In [None]:
# Print the language name for every predicted class index. A loop replaces the
# copy-pasted per-index prints, so this works for any batch size.
for class_id in predicted_class_ids.tolist():
    print(model.config.id2label[class_id])

English
English
