In [1]:
from transformers import SpeechEncoderDecoderModel, Speech2Text2Processor, Speech2TextProcessor
import soundfile as sf

from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

from transformers import Trainer, TrainingArguments

from torch.utils.data import Dataset
import librosa
import json
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set, Union

import torch


In [2]:
# Load the pretrained checkpoint into a CTC model. The log printed below reports that
# `lm_head` is newly initialized, i.e. the model must be fine-tuned before use.
wav2vec2_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-xlsr-53")
# NOTE(review): the processor is loaded from a different checkpoint
# ("facebook/wav2vec2-base") than the model — confirm that its feature-extractor
# settings and tokenizer vocabulary match what the model expects.
wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")

Some weights of the model checkpoint at facebook/wav2vec2-large-xlsr-53 were not used when initializing Wav2Vec2ForCTC: ['project_q.bias', 'quantizer.weight_proj.weight', 'project_hid.bias', 'quantizer.codevectors', 'project_hid.weight', 'project_q.weight', 'quantizer.weight_proj.bias']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-xlsr-53 and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to u

In [4]:
class LibriSpeechDataset(Dataset):
	"""Dataset pairing audio files with tokenized transcripts.

	The manifest is a JSON file whose top-level ``"data"`` key holds a list of
	entries; each entry is read for its ``"file"`` (audio path) and ``"text"``
	(transcript) keys.

	Args:
		json_path: Path to the JSON manifest.
		processor: Object exposing ``feature_extractor`` and ``tokenizer``
			callables. When ``None`` (the default) the notebook-level
			``wav2vec2_processor`` is looked up at item-access time, which
			matches the behaviour of the original implementation.
		sampling_rate: Rate in Hz that audio is resampled to on load and that
			is reported to the feature extractor. Defaults to 16000.
	"""

	def __init__(self, json_path, processor=None, sampling_rate=16000):
		self.json_path = json_path
		self.processor = processor
		self.sampling_rate = sampling_rate
		self.data = self.load_data_from_json(json_path)

	def load_data_from_json(self, json_path):
		"""Read the manifest and return the list stored under its ``"data"`` key.

		Raises:
			OSError: If the file cannot be opened.
			json.JSONDecodeError: If the file is not valid JSON.
			KeyError: If the top-level ``"data"`` key is missing.
		"""
		# Explicit UTF-8 so the result does not depend on the platform's
		# default locale encoding.
		with open(json_path, "r", encoding="utf-8") as f:
			manifest = json.load(f)

		return manifest["data"]

	def _resolve_processor(self):
		"""Return the injected processor, or fall back to the notebook global."""
		if self.processor is not None:
			return self.processor
		return wav2vec2_processor

	def __getitem__(self, idx):
		"""Load, featurize and tokenize the entry at position ``idx``.

		Returns:
			A dict with ``"input_values"`` (first element of the feature
			extractor's ``"input_values"`` output), ``"labels"`` (token ids of
			the transcript) and ``"lang"`` (always ``"en"``).
		"""
		processor = self._resolve_processor()
		entry = self.data[idx]

		# `sr` must be passed by keyword: newer librosa releases no longer
		# accept the sampling rate as a positional argument.
		audio, _ = librosa.load(entry["file"], sr=self.sampling_rate)
		features = processor.feature_extractor(audio, sampling_rate=self.sampling_rate)

		# Do some text preprocessing here
		text = entry["text"]
		# Call the tokenizer directly. This is what the processor dispatches to
		# inside `as_target_processor()`, which is deprecated in recent
		# transformers releases.
		label = processor.tokenizer(text).input_ids

		return {
			"input_values": features["input_values"][0],
			"labels": label,
			"lang": "en",
		}

	def __len__(self):
		"""Return the number of entries in the manifest."""
		return len(self.data)

In [5]:
# Build the dataset from a JSON manifest stored on the local machine.
# NOTE(review): this is an absolute, machine-specific path — consider deriving it
# from a configurable data directory so the notebook runs elsewhere.
# NOTE(review): the variable is called `train_dataset`, but the manifest file name
# contains "test-clean" — confirm the intended split.
train_dataset = LibriSpeechDataset("/root/develop/KIWI-module/code/wav2byte-pipeline/data/en-librispeech-test-clean-pure-99.0-local-wav.json")


In [6]:
# Sanity check: fetch the first sample to inspect its "input_values" and "labels".
train_dataset[0]

{'input_values': array([ 0.00410041,  0.00410041,  0.00681573, ..., -0.0325565 ,
        -0.03595066, -0.03323534], dtype=float32),
 'labels': [6,
  8,
  4,
  19,
  5,
  15,
  5,
  24,
  13,
  7,
  6,
  5,
  4,
  6,
  11,
  5,
  4,
  7,
  13,
  13,
  10,
  25,
  7,
  15,
  4,
  8,
  20,
  4,
  11,
  5,
  13,
  4,
  12,
  8,
  9,
  4,
  12,
  10,
  15,
  25,
  10,
  7,
  4,
  21,
  7,
  25,
  5,
  4,
  7,
  4,
  12,
  23,
  15,
  5,
  9,
  14,
  10,
  14,
  4,
  12,
  16,
  23,
  23,
  5,
  13,
  4,
  6,
  8,
  4,
  18,
  11,
  10,
  19,
  11,
  4,
  12,
  11,
  5,
  4,
  11,
  7,
  14,
  4,
  10,
  9,
  25,
  10,
  6,
  5,
  14,
  4,
  7,
  15,
  15,
  4,
  11,
  5,
  13,
  4,
  13,
  5,
  15,
  7,
  6,
  10,
  25,
  5,
  12,
  4,
  7,
  9,
  14,
  4,
  10,
  6,
  4,
  18,
  7,
  12,
  4,
  7,
  4,
  21,
  8,
  8,
  14,
  4,
  8,
  23,
  23,
  8,
  13,
  6,
  16,
  9,
  10,
  6,
  22,
  4,
  20,
  8,
  13,
  4,
  17,
  5,
  4,
  6,
  8,
  4,
  17,
  7,
  26,
  5,
  4,
  6,
  11,
  5,
 

In [7]:
# Use the first two samples as a miniature batch for experimenting with padding.
features = [train_dataset[position] for position in (0, 1)]

# Separate the audio arrays from the label ids; the labels are re-keyed
# from "labels" to "input_ids".
input_features = [{"input_values": sample["input_values"]} for sample in features]
label_features = [{"input_ids": sample["labels"]} for sample in features]

In [9]:
# Pad the per-sample audio features into a single batch returned as PyTorch tensors
# (the output below shows the first row padded with trailing zeros).
# NOTE(review): padding=True presumably pads to the longest sample in the batch, in
# which case max_length=1024 would have no effect — confirm against the pad() docs.
batch = wav2vec2_processor.feature_extractor.pad(
	input_features,
	padding=True,
	max_length=1024,
	return_tensors="pt"
)

In [11]:
# Inspect the padded audio tensor (one row per sample in the batch).
batch['input_values'] 

tensor([[ 0.0041,  0.0041,  0.0068,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0098, -0.0153, -0.0098,  ...,  0.0104,  0.0160,  0.0117]])