## Model Creation

Let's create a model first, with some vocab.
The output is 

In [2]:
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import WavLMModel, AutoFeatureExtractor
from datasets import load_dataset
import numpy as np

# ————————————————————————————————————————————————————————————————————————
# PhonemeRecognizer: WavLM + CTC for phoneme speech recognition
# ————————————————————————————————————————————————————————————————————————

# Load vocab from file
with open("phoneme_tokenizer/vocab.json") as vocab_file:
    vocab = json.load(vocab_file)

# IT + FR phonemes + blank
VOCAB = {
  "0": "ʒ",
  "1": "ɹ",
  "2": "j",
  "3": "d",
  "4": "ɲ",
  "5": "ʌ",
  "6": "[UNK]",
  "7": "ɒ",
  "8": "ɐ",
  "9": "ʃ",
  "10": "ɔ",
  "11": "f",
  "12": "ø",
  "13": "z",
  "14": "ŋ",
  "15": "i",
  "16": "u",
  "17": "̃",
  "18": "o",
  "19": "œ",
  "20": "a",
  "21": "(",
  "22": "ə",
  "23": "ɜ",
  "24": "ɾ",
  "25": "ː",
  "26": "̪",
  "27": "e",
  "28": "b",
  "29": "ʁ",
  "30": "w",
  "31": "n",
  "32": "p",
  "33": "y",
  "34": "ɡ",
  "35": "ɪ",
  "36": "r",
  "37": "v",
  "38": "t",
  "39": ")",
  "40": "m",
  "41": "k",
  "42": "ʊ",
  "43": "ʎ",
  "44": "ɑ",
  "45": "s",
  "46": "l",
  "47": "[PAD]",
  "48": "ɛ",
  "49": '<blank>' # blank token for CTC
}
PHONEME_DICT = {v: int(k) for k, v in VOCAB.items()}

NUM_PHONEMES = len(PHONEME_DICT)

class PhonemeRecognizer(nn.Module):
    def __init__(self, wavlm_model, num_phonemes=NUM_PHONEMES):
        super().__init__()
        self.wavlm = wavlm_model

        # Get the hidden size from the WavLM model
        hidden_size = self.wavlm.config.hidden_size

        # Add a dropout layer for regularization
        self.dropout = nn.Dropout(0.1)

        # Linear layer to map from WavLM hidden states to phoneme classes (including blank)
        self.phoneme_classifier = nn.Linear(hidden_size, num_phonemes)

    def forward(self, inputs):
        # Get WavLM embeddings
        outputs = self.wavlm(**inputs)
        hidden_states = outputs.last_hidden_state

        # Apply dropout
        hidden_states = self.dropout(hidden_states)

        # Apply the linear layer to get logits for each time step
        logits = self.phoneme_classifier(hidden_states)

        # Apply log softmax for CTC loss
        log_probs = F.log_softmax(logits, dim=-1)

        return log_probs
    
    def classify_to_phonemes(self, log_probs):
        # Simple greedy decoding (for demonstration)
        # In a real system, you would use beam search with ctcdecode
        predictions = torch.argmax(log_probs, dim=-1).cpu().numpy()

        # Convert to phoneme sequences with CTC decoding rules (merge repeats, remove blanks)
        phoneme_sequences = []
        for pred_seq in predictions:
            seq = []
            prev = -1
            for p in pred_seq:
                # Skip blanks (index 0) and repeated phonemes (CTC rules)
                if p != 0 and p != prev:
                    # Convert index back to phoneme
                    phoneme = list(PHONEME_DICT.keys())[list(PHONEME_DICT.values()).index(p)]
                    seq.append(phoneme)
                prev = p
            phoneme_sequences.append(seq)

        return phoneme_sequences


    def recognize(self, inputs, beam_width=100):
        """Perform phoneme recognition with beam search decoding"""
        self.eval()
        with torch.no_grad():
            # Forward pass to get log probabilities
            log_probs = self(inputs)

            return self.classify_to_phonemes(log_probs)

    def tokenize(self, char_list):
        """Go from a list of characters to a list of indices."""
        return torch.tensor([PHONEME_DICT[x] for x in char_list])
    
    def get_embedding(self, char_list):
        tokens = self.tokenize(char_list)
        out_tensor = torch.zeros((len(tokens), len(PHONEME_DICT)))
        for i, token_id in enumerate(tokens):
            out_tensor[i, token_id] = 1
        return out_tensor

# ————————————————————————————————————————————————————————————————————————
# Method A: Using the PhonemeRecognizer for speech-to-phoneme ASR
# ————————————————————————————————————————————————————————————————————————

# 1. Load the feature extractor and model
feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/wavlm-base-plus")
wavlm_model = WavLMModel.from_pretrained("microsoft/wavlm-base-plus")

# Create the phoneme recognizer with the WavLM model
phoneme_recognizer = PhonemeRecognizer(wavlm_model)
phoneme_recognizer.eval()  # disable dropout, etc.

# 2. Load an example audio file (here using a small demo from `datasets`)
#    The `audio["array"]` is a NumPy array of floats; sampling_rate is an int.
ds = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
audio_sample = ds[0]["audio"]["array"]
sr = ds[0]["audio"]["sampling_rate"]

def get_features(data_row):
    audio_sample = data_row["audio"]["array"]
    sr = data_row["audio"]["sampling_rate"]

    # 3. Preprocess (pad/truncate + batch‐dim)
    inputs = feature_extractor(
        audio_sample,
        sampling_rate=sr,
        return_tensors="pt",        # => PyTorch tensors
        padding=True,               # pad to longest in batch
    )
    return inputs

def run_inference(data_row, model):
    """Return log probs and most likely phonemes."""
    inputs = get_features(data_row)

    # 4. Inference for phoneme recognition
    with torch.no_grad():
        # Get phoneme log probabilities
        log_probs = model(inputs)

        # Recognize phoneme sequence
        phoneme_sequences = model.recognize(inputs)

    return log_probs, phoneme_sequences

log_probs, phoneme_sequences = run_inference(ds[0], phoneme_recognizer)

# Print output
print("Log probabilities shape:", log_probs.shape)  # (batch_size, seq_len, num_phonemes)
print("Recognized phoneme sequence:", phoneme_sequences[0])
print("Transcript for reference:", ds[0]["text"])



Log probabilities shape: torch.Size([1, 292, 50])
Recognized phoneme sequence: ['ø', 'p', 'ø', 'p', '<blank>', 'n', '<blank>', 'j', '[UNK]', ')', 'u', 'ɹ', 'ɲ', 'ɪ', 'd', 'n', 'ɲ', 'n', 'ŋ', '[UNK]', 'k', 'j', 'ʊ', '[UNK]', 't', 'ɐ', 'j', 'œ', 'ɲ', ')', 'r', 'u', 'y', 'u', 'z', '<blank>', 'e', 'p', 'e', 'ɲ', 'r', 'n', 'ʌ', 'n', 't', 'v', 'w', 'v', 'n', 'ɔ', 'ɒ', 'ɜ', 'f', 'œ', 'ɲ', 'e', '[PAD]', 'ɡ', 'e', 'r', 'p', 'u', 'ʊ', 'ɡ', 'e', 'ː', 'ɲ', 't', 'y', '[UNK]', 'i', 'ɒ', 'j', 'ɐ', 'v', 'ɪ', 'n', 'ɲ', 'r', 'ŋ', '[UNK]', 'ʊ', 'ɔ', 'e', 'n', '[PAD]', 'ɜ', 'r', 'ɡ', 'f', 'ɜ', 'ː', '(', 'ɲ', '[PAD]', 'ɒ', 'm', 'ɡ', 'ɾ', 'y', 'ɾ', 'p', 'u', ')', 'ɾ', 'n', 'j', '[UNK]', 'e', 'ɔ', '[PAD]', 'ʊ', ')', 'ʊ', 'y', 'ɲ', 'ʊ', '[UNK]', 'ɾ', '(', 'p', 'v', 'œ', 'k', 'ʊ', 'ɾ', 'ʊ', 'ɾ', 'n', 'œ', 'ː', 'r', '[UNK]', 'ɡ', 'y', 'f', 'v', 'œ', 'u', 'e', 'f', 'ɲ', 'n', 'œ', 'ː', 'ɜ', 'e', 'p', 'k', 'n', 'ɜ', 'ɾ', 'ɲ', 'r', 'ɐ', 'œ', 'e', 'v', 'ʌ', 'v', 'ʊ', 'ə', '<blank>', 'j', 'y', 'n', 'ɜ', 'w', 'ɜ', 'w'

## Dataset

Let's load our data in a Hugging Face dataset.

In [3]:
from datasets import load_dataset, Audio, Features, Sequence, Value

# 1. Location of your CSV
# csv_file = "train_phonemes_clean.csv"  # replace with your path
csv_file = "ground_truth_it_coder_2.csv"  # replace with your path


# 2. Define initial features: audio paths as plain strings, phonemes as plain strings
features = Features({
    "file_name": Value("string"),
    "phoneme_sequence": Value("string"),
})

# 3. Load the CSV into a DatasetDict (default split is 'train')
ds_dict = load_dataset("csv", data_files=csv_file, features=features)
dataset = ds_dict["train"]

# 4. Rename the audio-path column to 'audio' (required by Audio feature)
dataset = dataset.rename_column("file_name", "audio")

# 5. Cast 'audio' to the Audio type (will load the file when you access it)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000))

# 6. Map + split phoneme strings into lists
def split_phonemes(example):
    # assume phonemes are space-separated, e.g. "AH0 T EH1 S T"
    example["phoneme_sequence"] = example["phoneme_sequence"].split()
    return example

dataset = dataset.map(split_phonemes)

# 7. Cast the phoneme_sequence column to a Sequence of strings
dataset = dataset.cast_column(
    "phoneme_sequence",
    Sequence(feature=Value("string"))
)

# Now 'dataset' has:
#   - dataset[i]["audio"] → { "array": np.ndarray, "sampling_rate": 16000 }
#   - dataset[i]["phoneme_sequence"] → list of strings
print(dataset)
print(dataset[0]["audio"])
print(dataset[0]["phoneme_sequence"])


Dataset({
    features: ['audio', 'phoneme_sequence'],
    num_rows: 932
})
{'path': 'Hackathon_ASR/2_Audiofiles/Decoding_IT_T1/1001_edugame2023_59aa8ecf74c44db2adf56d71d1705cf5_1de23ac3deaf4b4d8c7db6d0cc9d6bfe.wav', 'array': array([0.        , 0.        , 0.        , ..., 0.00378418, 0.00424194,
       0.        ], shape=(364544,)), 'sampling_rate': 16000}
['vuzo[PAD]seɡa[PAD]klofɛno[PAD]raviʎo[PAD]da[PAD]pe[PAD]tarse[PAD]doridzːa[PAD]prateʎa[PAD]aː[PAD]ɛrɾe[PAD]lo[PAD]beɲole[PAD]fla[PAD]vɛstro[PAD]kʊɲaripːo']


Just a simple test.

In [4]:
import torchaudio
import warnings

wavfile, sampling_rate = torchaudio.load("Hackathon_ASR/2_Audiofiles/Decoding_IT_T1/1001_edugame2023_59aa8ecf74c44db2adf56d71d1705cf5_1de23ac3deaf4b4d8c7db6d0cc9d6bfe.wav")
if sampling_rate != 16000:
    warnings.warn(f"Sampling rate should be 16000 Hz, is {sampling_rate} Hz")

"same values" if torch.all(dataset[0]["audio"]["array"] == wavfile) else "different data"

'same values'

## Putting stuff together

Now we run the model on our in-house dataset. We also define a scoring.

In [5]:
transcriptions = []
losses = []
scored_dataset = dataset.select(range(3))


def smart_split_coder(sentence):
    output = []
    in_brackets = False
    for char in sentence:
        if in_brackets:
            output[-1] += char
        else:
            output.append(char)

        if char == '[':
            in_brackets = True
        elif char == ']':
            in_brackets = False
    return output


for data in scored_dataset:
    log_probs, phonemes = run_inference(data, phoneme_recognizer)
    coding = smart_split_coder(data["phoneme_sequence"][0])
    loss = len(phonemes[0]) - np.sum([p in coding for p in phonemes[0]])
    transcriptions.append(phonemes[0])
    losses.append(loss)
    
scored_dataset = scored_dataset.add_column("transcription", transcriptions).add_column("loss", losses)


## Training

Now, we train the model.

In [6]:
for data in scored_dataset.select_columns(["loss", "transcription"]):
    print(data)


{'loss': 224, 'transcription': ['<blank>', 'ʃ', '[UNK]', ')', 'ɲ', ')', 'ɲ', ')', 'n', ')', 'n', 'j', 'n', 'ø', 'ɛ', 'd', 'ɛ', 'ʌ', 'ɛ', 'ʌ', 'ɾ', 't', 'f', 'ɾ', 'k', 'f', 'ɜ', ')', 'n', 'r', 'ɾ', 'ŋ', 'n', 'œ', 'ː', 'v', ')', 'a', 'ɹ', 'ɜ', 'p', 'ŋ', 'ɡ', 'ɪ', 'v', 'i', 'ɜ', 'n', 'f', 'ɹ', 'ɜ', 'ŋ', 'ɪ', 'ɔ', '[PAD]', 'ɔ', '<blank>', 'ɲ', 'ɐ', 'j', 'ɪ', 'ʃ', 'u', 'e', 'n', 'ɔ', 'y', 'œ', 'ɜ', 'r', 'y', '[UNK]', 'œ', 'j', 'e', 'u', 'v', 'ɾ', 'ɜ', 'ŋ', 'œ', '<blank>', 'u', '[PAD]', 'v', '[PAD]', 'ɒ', 'ɔ', 'o', 'ɐ', 'œ', '<blank>', 'v', 't', 'n', 'k', '(', 'ɜ', ')', 'k', 'ʃ', 'd', 'f', 's', 'k', ')', 'j', 'd', 'ɔ', 'ø', 'ɔ', 'ɾ', 'l', 'p', 'ɪ', 'v', 'ʊ', 'œ', 'ɾ', 'œ', 'l', 'œ', 'j', 'ɾ', 'u', 'v', 'n', '(', 'ɜ', 'k', 'ɾ', 'f', 'ɲ', ')', 'œ', 'ː', 'ʊ', 'ɪ', 'r', 'a', 'e', 'n', 'ɲ', 'f', 'ɐ', 'j', 'ɪ', ')', 'n', 'ɾ', 'ɲ', 's', 'n', 'k', '(', 'k', 'ɪ', 'ɹ', 'ɜ', '<blank>', 'ɜ', 'ɲ', 'ɐ', 'j', 'ɪ', 'ʁ', 'r', 'z', 'e', ')', 'ɡ', 'ʌ', 'n', 'ɲ', 'b', 'ɲ', 'u', 'ɹ', 'ɜ', 'n', 'ɾ', 'ɹ', 'ə', '<b

## Model training

Now we want a phoneme recognition.
It means to train the last layer of the model to the ground truth.

In [None]:
import csv
import os
import re

phoneme_recognizer.train()
linear_optimizer = torch.optim.Adam(phoneme_recognizer.phoneme_classifier.parameters(), lr=1e-3)


def calculate_ctc_loss(log_probs, target_sequence):
    """Calculates CTC loss."""
    # Create input_lengths and target_lengths tensors
    input_lengths = torch.tensor([1])  # Batch size of 1
    target_lengths = torch.tensor([1])  # Batch size of 1

    # Calculate CTC loss
    loss = F.ctc_loss(
        log_probs,
        target_sequence,
        input_lengths=input_lengths,
        target_lengths=target_lengths
    )
    return loss

MODEL_DIR = "models"
OUTPUTS_DIR = "outputs"


def prepare_folders():
    if not os.path.exists(MODEL_DIR):
        os.makedirs(MODEL_DIR)
    if not os.path.exists(OUTPUTS_DIR):
        os.makedirs(OUTPUTS_DIR)
    

def load_last_checkpoint(model_dir):
    increment = -1
    # Load the latest version
    pth_files = [f for f in os.listdir(model_dir) if f.endswith(".pth")]
    increment = len(pth_files)

    if not pth_files:
        warnings.warn("No .pth files found in the model directory! Starting from scratch!")
    else:
        # Sort the files by their index (last number)
        pth_files.sort(key=lambda x: int(re.search(r"(\d+)\.pth$", x)[1]))

        # Load the latest version
        checkpoint = pth_files[-1]  # Load the last element (highest index)
        match = re.search(r"(\d+)\.pth$", checkpoint)
        if match:
            increment = int(match[1])
            # Load the linear layer's parameters
            phoneme_recognizer.phoneme_classifier.load_state_dict(
                torch.load(f"{model_dir}/{checkpoint}")
            )
        else:
            warnings.warn("Couldn't find a model! Starting from scratch!")
    return increment

prepare_folders()
increment = load_last_checkpoint(MODEL_DIR)

# Freeze the wavlm model
for param in phoneme_recognizer.wavlm.parameters():
    param.requires_grad = False


def write_to_csv(row):
    with open(f'{OUTPUTS_DIR}/phonemes_training.csv', 'a') as file:
        writer = csv.writer(file)
        writer.writerow(row)


# Training loop
for epoch in range(10):
    for i, data in enumerate(dataset.shuffle().select(range(50))):
        inputs = get_features(data)
        log_probs = phoneme_recognizer(inputs)
        split_phonemes = smart_split_coder(data["phoneme_sequence"][0])
        target = phoneme_recognizer.tokenize(split_phonemes)
        loss = calculate_ctc_loss(log_probs[0], target.reshape([1, -1]))
        linear_optimizer.zero_grad()
        loss.backward()
        linear_optimizer.step()
        write_to_csv(
            [
                increment, epoch, i, loss.item(),
                "".join(phoneme_recognizer.classify_to_phonemes(log_probs)[0]),
                "".join(split_phonemes)
            ]
        )
        print(f"Epoch {epoch}, Loss: {loss.item()}")
        increment += 1
        torch.save(
            phoneme_recognizer.phoneme_classifier.state_dict(),
            f"{MODEL_DIR}/phoneme_classifier_epoch_{epoch}_step_{i}_{increment}.pth"
        )
    

## Binary classification

We have a model roughly trained for phonemes.
We want a binary classification though.
We won't do that for now as it would be an end-to-end pipeline, defeating the purpose of the created pipeline.