In [184]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
import numpy as np
import pandas as pd
from transformers import WavLMForCTC, WavLMConfig, Wav2Vec2Processor, Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
import json
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Union, Tuple

# Add import for local WavLM model
import sys
sys.path.append('/home/zohreh/ownCloud/hackathon2025/local_test/unilm/wavlm')
from WavLM import WavLM, WavLMConfig as LocalWavLMConfig

# Install required packages if not already installed
# !pip install transformers torchaudio datasets

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path to local WavLM model
LOCAL_MODEL_PATH = '/home/zohreh/ownCloud/hackathon2025/local_test/models/WavLM-Base+.pt'

# Function to load local WavLM model
def load_local_wavlm_model():
    """Load WavLM model from local file"""
    checkpoint = torch.load(LOCAL_MODEL_PATH, map_location='cpu')
    cfg = LocalWavLMConfig(checkpoint['cfg'])
    model = WavLM(cfg)
    model.load_state_dict(checkpoint['model'])
    return model, cfg

In [185]:
def create_and_get_processor(phoneme_list_path, output_dir="./phoneme_tokenizer"):
    """Create a phoneme tokenizer and processor or load existing ones."""
    # Create tokenizer
    if os.path.exists(output_dir):
        # Load existing tokenizer
        tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(output_dir)
    else:
        # Create new tokenizer from phoneme list
        with open(phoneme_list_path, "r") as f:
            phoneme_dict = json.load(f)

        vocab_dict = {p: i for i, p in enumerate(phoneme_dict["phonemes"])}
        # Add special tokens
        vocab_dict["<pad>"] = len(vocab_dict)
        vocab_dict["<unk>"] = len(vocab_dict)
        vocab_dict["<s>"] = len(vocab_dict)
        vocab_dict["</s>"] = len(vocab_dict)

        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "vocab.json"), "w") as f:
            json.dump(vocab_dict, f)

        tokenizer = Wav2Vec2CTCTokenizer(
            os.path.join(output_dir, "vocab.json"),
            unk_token="<unk>",
            pad_token="<pad>",
            word_delimiter_token="|",
        )

        tokenizer.save_pretrained(output_dir)

    # Get feature extractor
    # Instead of downloading from Microsoft repo, we just create a new feature extractor
    # with the same parameters as WavLM
    feature_extractor = Wav2Vec2FeatureExtractor(
        feature_size=1,
        sampling_rate=16000,
        padding_value=0.0,
        do_normalize=True,
        return_attention_mask=False,
    )

    # Create processor (combines feature extractor and tokenizer)
    processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

    return processor, len(tokenizer.get_vocab())

In [186]:
class PhonemeDataset(Dataset):
    def __init__(self, data_path, processor, sample_rate=16000):
        """
        Args:
            data_path: Path to CSV file with columns 'file_path', 'phoneme_sequence'
            processor: WavLM processor
            sample_rate: Target sample rate
        """
        self.df = pd.read_csv(data_path)
        self.processor = processor
        self.sample_rate = sample_rate

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Load audio
        file_path = self.df.iloc[idx]['file_name']
        waveform, sample_rate = torchaudio.load(file_path)

        # Resample if necessary
        if sample_rate != self.sample_rate:
            resampler = torchaudio.transforms.Resample(sample_rate, self.sample_rate)
            waveform = resampler(waveform)

        # Convert to mono if necessary
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Get phoneme sequence
        phoneme_sequence = self.df.iloc[idx]['phoneme_sequence']

        # Process inputs - return raw audio data
        input_values = waveform.squeeze().numpy()

        # Process labels
        with self.processor.as_target_processor():
            labels = self.processor(phoneme_sequence).input_ids

        return {
            "input_values": input_values,  # Return raw audio data
            "labels": labels,
            "phoneme_sequence": phoneme_sequence
        }

# Completely rewritten data collator
@dataclass
class DataCollatorCTCWithPadding:
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor, np.ndarray]]]) -> Dict[str, torch.Tensor]:
        # Extract raw audio waveforms and process them with the feature extractor
        waveforms = [feature["input_values"] for feature in features]
        batch = self.processor(
            waveforms,
            padding=self.padding,
            return_tensors="pt",
            sampling_rate=16000,
            return_attention_mask=False,
        )

        # Process text labels
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt"
            )

        # Replace padding with -100 for loss calculation
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # Add labels to the batch
        batch["labels"] = labels

        # Add original phoneme sequences for later use
        batch["phoneme_sequences"] = [feature["phoneme_sequence"] for feature in features]

        return batch

In [187]:
class WavLMForPhonemeRecognition(nn.Module):
    def __init__(self, num_phonemes):
        super().__init__()
        # Load pre-trained WavLM model from local file
        local_model, local_cfg = load_local_wavlm_model()

        # Create a standard HuggingFace config for the CTC head
        self.config = WavLMConfig()
        self.config.ctc_loss_reduction = "mean"
        self.config.ctc_zero_infinity = True
        self.config.vocab_size = num_phonemes
        self.config.num_labels = num_phonemes
        print(num_phonemes)

        # Initialize WavLM with CTC head for phoneme recognition
        # We use the default WavLMForCTC model initially, but replace its WavLM encoder
        # with our locally loaded model
        self.model = WavLMForCTC(self.config)

        # Replace the wavlm model inside WavLMForCTC with our local model
        # (This assumes the model structure is compatible)
        self.model.wavlm = local_model

        # Initialize the CTC head properly
        self.model.lm_head = nn.Linear(local_cfg.encoder_embed_dim, num_phonemes)

    def forward(self, input_values, labels=None):
        return self.model(
            input_values=input_values,
            labels=labels
        )

In [188]:
def train_model(model, train_loader, val_loader, optimizer, num_epochs, device):
    """Training loop with validation"""
    model.train()
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        # Training
        train_loss = 0
        model.train()

        for batch in train_loader:
            optimizer.zero_grad()

            input_values = batch["input_values"].to(device)
            print(input_values.shape)
            attention_mask = batch.get("attention_mask", None)
            print(attention_mask)

            if attention_mask is not None:
                attention_mask = attention_mask.to(device)
            labels = batch["labels"].to(device)
            print(labels.shape)

            outputs = model(
                input_values=input_values,
                labels=labels
            )

            loss = outputs.loss
            train_loss += loss.item()

            loss.backward()
            optimizer.step()

        avg_train_loss = train_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}")

        # Validation
        val_loss = 0
        model.eval()

        with torch.no_grad():
            for batch in val_loader:
                input_values = batch["input_values"].to(device)
                attention_mask = batch.get("attention_mask", None)
                if attention_mask is not None:
                    attention_mask = attention_mask.to(device)
                labels = batch["labels"].to(device)

                outputs = model(
                    input_values=input_values,

                    labels=labels
                )

                loss = outputs.loss
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        print(f"Epoch {epoch+1}/{num_epochs}, Val Loss: {avg_val_loss:.4f}")

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), "best_phoneme_model.pt")
            print(f"Saved best model with validation loss: {best_val_loss:.4f}")

In [189]:
def predict_with_time_indices(model, processor, audio_path, device, sample_rate=16000):
    """
    Predict phonemes with their time indices from an audio file
    using the custom WavLM model
    """
    model.eval()

    # Load and preprocess audio
    waveform, file_sample_rate = torchaudio.load(audio_path)
    if file_sample_rate != sample_rate:
        resampler = torchaudio.transforms.Resample(file_sample_rate, sample_rate)
        waveform = resampler(waveform)

    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Process audio
    input_values = processor(waveform.squeeze().numpy(),
                             sampling_rate=sample_rate,
                             return_tensors="pt").input_values.to(device)

    # Get logits
    with torch.no_grad():
        # Just use input_values without attention_mask
        outputs = model(input_values=input_values)
        logits = outputs.logits

    # Get predicted tokens
    predicted_ids = torch.argmax(logits, dim=-1).squeeze()

    # Decode tokens to phonemes
    predicted_phonemes = processor.tokenizer.convert_ids_to_tokens(predicted_ids.tolist())

    # Get time indices
    # First, remove repeated tokens (CTC decoding)
    decoded_phonemes = []
    time_indices = []

    prev_phoneme = None
    current_start = 0

    for i, phoneme in enumerate(predicted_phonemes):
        # Skip blank tokens (represented by special tokens) and repeated tokens
        if phoneme in processor.tokenizer.all_special_tokens or phoneme == prev_phoneme:
            continue

        if prev_phoneme is not None:
            # Convert frame index to time in seconds
            # (20ms per frame at 16kHz with stride of 320 samples)
            start_time = current_start * 0.02
            end_time = i * 0.02
            decoded_phonemes.append(prev_phoneme)
            time_indices.append([start_time, end_time])

        prev_phoneme = phoneme
        current_start = i

    # Don't forget the last phoneme
    if prev_phoneme is not None and prev_phoneme not in processor.tokenizer.all_special_tokens:
        start_time = current_start * 0.02
        end_time = len(predicted_phonemes) * 0.02
        decoded_phonemes.append(prev_phoneme)
        time_indices.append([start_time, end_time])

    return decoded_phonemes, time_indices

training csv should look like this
file_path,phoneme_sequence
/path/to/audio1.wav,"a b c"
/path/to/audio2.wav,"d e f"

In [190]:

# Example paths - replace with your own
train_data_path = "train_phonemes_clean.csv"  # Should contain file_path and phoneme_sequence
val_data_path = "val_phonemes_clean.csv"      # Should contain file_path and phoneme_sequence
phoneme_list_path = "phoneme_list.json"

# Create processor from the phoneme list using the function defined above
processor, vocab_size = create_and_get_processor(phoneme_list_path)

# Create datasets (without time indices)
train_dataset = PhonemeDataset(train_data_path, processor)
val_dataset = PhonemeDataset(val_data_path, processor)

# Create data loaders
data_collator = DataCollatorCTCWithPadding(processor=processor)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=data_collator)
val_loader = DataLoader(val_dataset, batch_size=4, collate_fn=data_collator)
print(vocab_size)

# Initialize model
model = WavLMForPhonemeRecognition(vocab_size).to(device)

# Setup optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-4)


# Train model
train_model(model, train_loader, val_loader, optimizer, num_epochs=10, device=device)

# Load best model for inference
model.load_state_dict(torch.load("best_phoneme_model.pt"))

# Example inference with time indices
audio_path = "test_audio.wav"
predicted_phonemes, time_indices = predict_with_time_indices(model, processor, audio_path, device)

# Print results
print("Predicted phonemes with time indices:")
for phoneme, (start, end) in zip(predicted_phonemes, time_indices):
    print(f"Phoneme: {phoneme}, Start: {start:.2f}s, End: {end:.2f}s")

# Save model
torch.save(model.state_dict(), "./phoneme_recognition_model.pt")
processor.save_pretrained("./phoneme_recognition_processor")

36


  checkpoint = torch.load(LOCAL_MODEL_PATH, map_location='cpu')


  WeightNorm.apply(module, name, dim)


36
torch.Size([4, 479232])
None
torch.Size([4, 202])
torch.Size([4, 479232])
None
torch.Size([4, 202])




TypeError: _forward_unimplemented() got an unexpected keyword argument 'attention_mask'

In [None]:
model

WavLMForPhonemeRecognition(
  (model): WavLMForCTC(
    (wavlm): WavLM(
      (feature_extractor): ConvFeatureExtractionModel(
        (conv_layers): ModuleList(
          (0): Sequential(
            (0): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
            (1): Dropout(p=0.0, inplace=False)
            (2): Fp32GroupNorm(512, 512, eps=1e-05, affine=True)
            (3): GELU(approximate='none')
          )
          (1-4): 4 x Sequential(
            (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
            (1): Dropout(p=0.0, inplace=False)
            (2): GELU(approximate='none')
          )
          (5-6): 2 x Sequential(
            (0): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
            (1): Dropout(p=0.0, inplace=False)
            (2): GELU(approximate='none')
          )
        )
      )
      (post_extract_proj): Linear(in_features=512, out_features=768, bias=True)
      (dropout_input): Dropout(p=0.1, inplace=Fal

In [None]:
vocab_size

36

In [None]:
waveform, sample_rate = torchaudio.load("Hackathon_ASR/2_Audiofiles/Decoding_IT_T1/1001_edugame2023_59aa8ecf74c44db2adf56d71d1705cf5_1de23ac3deaf4b4d8c7db6d0cc9d6bfe.wav")

In [None]:
phoneme_list_path

'phoneme_list.json'