## Model Creation

Let's create a model first, with some vocab.
The output is 

In [27]:
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import WavLMModel, AutoFeatureExtractor
from datasets import load_dataset
import numpy as np

# ————————————————————————————————————————————————————————————————————————
# PhonemeRecognizer: WavLM + CTC for phoneme speech recognition
# ————————————————————————————————————————————————————————————————————————

# Load vocab from file
with open("phoneme_tokenizer/vocab.json") as vocab_file:
    vocab = json.load(vocab_file)

# IT + FR phonemes + blank
VOCAB = {
  "0": "ʒ",
  "1": "ɹ",
  "2": "j",
  "3": "d",
  "4": "ɲ",
  "5": "ʌ",
  "6": "[UNK]",
  "7": "ɒ",
  "8": "ɐ",
  "9": "ʃ",
  "10": "ɔ",
  "11": "f",
  "12": "ø",
  "13": "z",
  "14": "ŋ",
  "15": "i",
  "16": "u",
  "17": "̃",
  "18": "o",
  "19": "œ",
  "20": "a",
  "21": "(",
  "22": "ə",
  "23": "ɜ",
  "24": "ɾ",
  "25": "ː",
  "26": "̪",
  "27": "e",
  "28": "b",
  "29": "ʁ",
  "30": "w",
  "31": "n",
  "32": "p",
  "33": "y",
  "34": "ɡ",
  "35": "ɪ",
  "36": "r",
  "37": "v",
  "38": "t",
  "39": ")",
  "40": "m",
  "41": "k",
  "42": "ʊ",
  "43": "ʎ",
  "44": "ɑ",
  "45": "s",
  "46": "l",
  "47": "[PAD]",
  "48": "ɛ",
  "49": '<blank>' # blank token for CTC
}
PHONEME_DICT = {v: int(k) for k, v in VOCAB.items()}

NUM_PHONEMES = len(PHONEME_DICT)

class PhonemeRecognizer(nn.Module):
    def __init__(self, wavlm_model, num_phonemes=NUM_PHONEMES):
        super().__init__()
        self.wavlm = wavlm_model

        # Get the hidden size from the WavLM model
        hidden_size = self.wavlm.config.hidden_size

        # Add a dropout layer for regularization
        self.dropout = nn.Dropout(0.1)

        # Linear layer to map from WavLM hidden states to phoneme classes (including blank)
        self.phoneme_classifier = nn.Linear(hidden_size, num_phonemes)

    def forward(self, inputs):
        # Get WavLM embeddings
        outputs = self.wavlm(**inputs)
        hidden_states = outputs.last_hidden_state

        # Apply dropout
        hidden_states = self.dropout(hidden_states)

        # Apply the linear layer to get logits for each time step
        logits = self.phoneme_classifier(hidden_states)

        # Apply log softmax for CTC loss
        log_probs = F.log_softmax(logits, dim=-1)

        return log_probs

    def recognize(self, inputs, beam_width=100):
        """Perform phoneme recognition with beam search decoding"""
        self.eval()
        with torch.no_grad():
            # Forward pass to get log probabilities
            log_probs = self(inputs)

            # Simple greedy decoding (for demonstration)
            # In a real system, you would use beam search with ctcdecode
            predictions = torch.argmax(log_probs, dim=-1).cpu().numpy()

            # Convert to phoneme sequences with CTC decoding rules (merge repeats, remove blanks)
            phoneme_sequences = []
            for pred_seq in predictions:
                seq = []
                prev = -1
                for p in pred_seq:
                    # Skip blanks (index 0) and repeated phonemes (CTC rules)
                    if p != 0 and p != prev:
                        # Convert index back to phoneme
                        phoneme = list(PHONEME_DICT.keys())[list(PHONEME_DICT.values()).index(p)]
                        seq.append(phoneme)
                    prev = p
                phoneme_sequences.append(seq)

            return phoneme_sequences

# ————————————————————————————————————————————————————————————————————————
# Method A: Using the PhonemeRecognizer for speech-to-phoneme ASR
# ————————————————————————————————————————————————————————————————————————

# 1. Load the feature extractor and model
feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/wavlm-base-plus")
wavlm_model = WavLMModel.from_pretrained("microsoft/wavlm-base-plus")

# Create the phoneme recognizer with the WavLM model
phoneme_recognizer = PhonemeRecognizer(wavlm_model)
phoneme_recognizer.eval()  # disable dropout, etc.

# 2. Load an example audio file (here using a small demo from `datasets`)
#    The `audio["array"]` is a NumPy array of floats; sampling_rate is an int.
ds = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
audio_sample = ds[0]["audio"]["array"]
sr = ds[0]["audio"]["sampling_rate"]

def run_inference(data_row):
    audio_sample = data_row["audio"]["array"]
    sr = data_row["audio"]["sampling_rate"]
    
    # 3. Preprocess (pad/truncate + batch‐dim)
    inputs = feature_extractor(
        audio_sample,
        sampling_rate=sr,
        return_tensors="pt",        # => PyTorch tensors
        padding=True,               # pad to longest in batch
    )

    # 4. Inference for phoneme recognition
    with torch.no_grad():
        # Get phoneme log probabilities
        log_probs = phoneme_recognizer(inputs)

        # Recognize phoneme sequence
        phoneme_sequences = phoneme_recognizer.recognize(inputs)

    return log_probs, phoneme_sequences

log_probs, phoneme_sequences = run_inference(ds[0])

# Print output
print("Log probabilities shape:", log_probs.shape)  # (batch_size, seq_len, num_phonemes)
print("Recognized phoneme sequence:", phoneme_sequences[0])
print("Transcript for reference:", ds[0]["text"])

Log probabilities shape: torch.Size([1, 292, 50])
Recognized phoneme sequence: [')', 'w', 'm', 'w', '[UNK]', 'l', 'ɛ', 'n', 'ɪ', 'ʌ', 'ɒ', 'w', 'f', 'œ', 'ŋ', 'œ', 'ʁ', 'w', 'p', 'v', 'ʌ', 'd', 'f', 't', ')', 'w', '[UNK]', 'œ', 'p', '[UNK]', 'ɛ', 'ɔ', 'ʃ', '[UNK]', 'v', 'ʌ', 'e', 'w', 'ɲ', ')', 'i', 'ɾ', 'n', 'l', 'n', 'f', 'l', 'k', 't', '[PAD]', 'ɛ', 'l', 'j', 'f', 'n', 'ɔ', 'm', 'd', 'ɐ', 'p', 'v', 'ʎ', 'v', 'ŋ', 'f', 'k', 'ʌ', 'p', ')', '[PAD]', 'ʌ', 'w', 'ɲ', 'f', 't', 'i', ')', 'p', 'l', 'p', 'ɔ', 'f', 'ə', 'f', 'w', '[PAD]', 'j', 'd', 'n', '[PAD]', 'ɜ', 'w', 'r', 'v', 'ɐ', 'p', 's', 'a', 'j', '[UNK]', 'n', 'ʌ', 'w', 'ɜ', 'ɹ', 'p', 't', ')', 'ɐ', 's', 'ɛ', 'l', 'ɒ', 'w', 'm', 'ɾ', 'n', 'ɑ', 'w', 'ɛ', ')', 'ɐ', 'w', 'ʌ', 'l', 'i', 'j', 'ɹ', 'n', 'ŋ', 'ɜ', 'k', 'ɜ', '(', 'd', 'k', '[PAD]', 'm', 'ʃ', 'ɜ', 'w', 'ŋ', 'ɹ', ')', 'ɔ', 'ə', 'l', 'n', 'i', 'w', ')', 't', '[PAD]', 'w', 'i', 'k', 'p', 'ɑ', 'j', 'ɑ', 'e', 'p', 'l', 'ɡ', 't', ')']
Transcript for reference: MISTER QUILTER IS TH

## Dataset

Let's load our data in a Hugging Face dataset.

In [None]:
from datasets import load_dataset, Audio, Features, Sequence, Value

# 1. Location of your CSV
# csv_file = "train_phonemes_clean.csv"  # replace with your path
csv_file = "ground_truth_it_coder_2.csv"  # replace with your path


# 2. Define initial features: audio paths as plain strings, phonemes as plain strings
features = Features({
    "file_name": Value("string"),
    "phoneme_sequence": Value("string"),
})

# 3. Load the CSV into a DatasetDict (default split is 'train')
ds_dict = load_dataset("csv", data_files=csv_file, features=features)
dataset = ds_dict["train"]

# 4. Rename the audio-path column to 'audio' (required by Audio feature)
dataset = dataset.rename_column("file_name", "audio")

# 5. Cast 'audio' to the Audio type (will load the file when you access it)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000))

# 6. Map + split phoneme strings into lists
def split_phonemes(example):
    # assume phonemes are space-separated, e.g. "AH0 T EH1 S T"
    example["phoneme_sequence"] = example["phoneme_sequence"].split()
    return example

dataset = dataset.map(split_phonemes)

# 7. Cast the phoneme_sequence column to a Sequence of strings
dataset = dataset.cast_column(
    "phoneme_sequence",
    Sequence(feature=Value("string"))
)

# Now 'dataset' has:
#   - dataset[i]["audio"] → { "array": np.ndarray, "sampling_rate": 16000 }
#   - dataset[i]["phoneme_sequence"] → list of strings
print(dataset)
print(dataset[0]["audio"])
print(dataset[0]["phoneme_sequence"])


Dataset({
    features: ['audio', 'phoneme_sequence'],
    num_rows: 932
})
{'path': 'Hackathon_ASR/2_Audiofiles/Decoding_IT_T1/1001_edugame2023_59aa8ecf74c44db2adf56d71d1705cf5_1de23ac3deaf4b4d8c7db6d0cc9d6bfe.wav', 'array': array([0.        , 0.        , 0.        , ..., 0.00378418, 0.00424194,
       0.        ], shape=(364544,)), 'sampling_rate': 16000}
['vuzo[PAD]seɡa[PAD]klofɛno[PAD]raviʎo[PAD]da[PAD]pe[PAD]tarse[PAD]doridzːa[PAD]prateʎa[PAD]aː[PAD]ɛrɾe[PAD]lo[PAD]beɲole[PAD]fla[PAD]vɛstro[PAD]kʊɲaripːo']


Just a simple test.

In [12]:
import torchaudio
import warnings

wavfile, sampling_rate = torchaudio.load("Hackathon_ASR/2_Audiofiles/Decoding_IT_T1/1001_edugame2023_59aa8ecf74c44db2adf56d71d1705cf5_1de23ac3deaf4b4d8c7db6d0cc9d6bfe.wav")
if sampling_rate != 16000:
    warnings.warn(f"Sampling rate should be 16000 Hz, is {sampling_rate} Hz")

"same values" if torch.all(dataset[0]["audio"]["array"] == wavfile) else "different data"

'same values'

## Putting stuff together

Now we run the model on our in-house dataset.

In [30]:
for data in dataset:
    run_inference(data)
    break

