In [8]:
import gradio as gr
import torchaudio
from transformers import AutoModelForAudioClassification, AutoFeatureExtractor
import librosa
import torch

In [9]:
def dump_pickle(file_path: str, file, mode: str = "wb"):
    """Pickle an object and write it to disk.

    Args:
        file_path: Destination path of the pickle file.
        file: The Python object to serialize (despite the name, this is the
            payload, not a file handle).
        mode: Mode used to open ``file_path``; binary write by default.
    """
    from pickle import dump as write_pickle

    sink = open(file_path, mode=mode)
    try:
        write_pickle(file, sink)
    finally:
        # Always release the handle, even if serialization fails.
        sink.close()


def load_pickle(file_path: str, mode: str = "rb", encoding=""):
    """Read a pickle file from disk and return the deserialized object.

    Args:
        file_path: Path of the pickle file to read.
        mode: Mode used to open ``file_path``; binary read by default.
        encoding: Codec used to decode 8-bit string instances pickled by
            Python 2 (``"bytes"`` keeps them as ``bytes``). An empty value
            falls back to pickle's own default, ``"ASCII"``; passing the empty
            string straight through would make such pickles fail with
            ``LookupError`` because ``""`` is not a codec name.

    Returns:
        The object stored in the pickle file.
    """
    import pickle

    # NOTE(security): unpickling can execute arbitrary code. Only load files
    # that come from a trusted source.
    with open(file_path, mode=mode) as f:
        return pickle.load(f, encoding=encoding or "ASCII")

In [10]:
# Load the label <-> id lookup tables saved alongside the best training run.
label2id, id2label = (
    load_pickle(mapping_path)
    for mapping_path in (
        '/data/audio-classification-pytorch/wav2vec2/results/best/label2id.pkl',
        '/data/audio-classification-pytorch/wav2vec2/results/best/id2label.pkl',
    )
)

In [16]:
# Instantiate the wav2vec2 base architecture with a classification head sized
# to the label set. The head starts out randomly initialized here; the
# fine-tuned weights are applied by the later `load_state_dict` cell.
model = AutoModelForAudioClassification.from_pretrained(
    "facebook/wav2vec2-base",
    num_labels=len(label2id),
    label2id=label2id,
    id2label=id2label,
)

Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2ForSequenceClassification: ['quantizer.weight_proj.weight', 'project_hid.weight', 'project_q.bias', 'project_q.weight', 'quantizer.weight_proj.bias', 'project_hid.bias', 'quantizer.codevectors']
- This IS expected if you are initializing Wav2Vec2ForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['classifier.bias', 'projector.weight', 'classifier

In [17]:
# Feature extractor for the same base checkpoint as the model; `predict` uses
# it to turn a raw waveform into the model's `input_values`.
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")

In [18]:
# Read the fine-tuned weights from the best training run.
# map_location="cpu" remaps every tensor onto the CPU, so a checkpoint that was
# saved from a GPU still loads on a machine without CUDA (the model in this
# notebook is never moved off the CPU).
# NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
checkpoint = torch.load(
    '/data/audio-classification-pytorch/wav2vec2/results/best/pytorch_model.bin',
    map_location="cpu",
)

In [19]:
# Replace the model's parameters with the fine-tuned checkpoint. Left as the
# cell's last expression so the key-matching report is displayed.
model.load_state_dict(checkpoint)

<All keys matched successfully>

In [21]:
def predict(input):
    """Classify an audio recording and return the predicted label name.

    Args:
        input: Path to an audio file readable by ``librosa.load`` (the Gradio
            audio component in this notebook passes a temp-file path).

    Returns:
        The label name looked up in ``id2label`` for the highest-scoring class.

    Raises:
        ValueError: If no audio path was supplied or the recording is empty.
    """
    if not input:
        # Happens when the form is submitted without a recording.
        raise ValueError("No audio was provided; record a clip before submitting.")

    # Decode and resample in a single step, directly to the rate the feature
    # extractor expects (librosa returns a mono 1-D float array by default).
    target_sr = feature_extractor.sampling_rate
    waveform, _ = librosa.load(input, sr=target_sr)
    if waveform.size == 0:
        raise ValueError("The recording contains no audio samples.")

    # Pass the waveform as a 1-D array: with a leading batch axis the
    # extractor treats that axis as the sequence, so `max_length`/`truncation`
    # would never shorten the audio.
    inputs = feature_extractor(
        waveform,
        sampling_rate=target_sr,
        max_length=16000,  # samples kept per clip
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        output = model(inputs["input_values"])
    logits = output["logits"][0]
    label_id = torch.argmax(logits).item()

    # `id2label` is indexed with the string form of the class id.
    return id2label[str(label_id)]

In [5]:
# Gradio front end: one microphone input wired to `predict`, plain-text output.
demo = gr.Interface(
    fn=predict,
    # Record audio and save it to a temp file whose path is fed to the
    # inference function.
    inputs=gr.Audio(
        source="microphone",
        type="filepath",
        label="Speak to classify your voice!",
    ),
    outputs="text",
)

In [6]:
# Start the Gradio app; per the hint printed below, pass `share=True` to get a
# public link instead of a local-only URL.
demo.launch()

Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


