# Training a Speech Tokenizer with WavLM and K-means

## Introduction

This notebook was created by [Jupyter AI](https://github.com/jupyterlab/jupyter-ai) with the following prompt:

> /generate I want to train a speech tokenizer based on WavLM and k-means algorithm. The dataset is the LibriSpeech. 

## Summary

This Jupyter notebook walks through training a speech tokenizer with the WavLM model and the k-means algorithm on the LibriSpeech dataset. It begins with setup and installation of the required libraries, then loads and preprocesses the dataset. Next it extracts audio features, implements k-means clustering, and trains the tokenizer on those features. The remaining sections evaluate the tokenizer (accuracy, silhouette score, and a PCA plot), save the trained model, and apply the tokenizer to new audio data with example visualizations of the tokenization process.

## 2. Data Loading

In [None]:
# Section 2: Data Loading

In [None]:
# Import necessary libraries
import os
import torchaudio
from torchaudio.datasets import LIBRISPEECH

In [None]:
# Define a function to load the LibriSpeech dataset
def load_librispeech_data(root_dir, url='train-clean-100', download=True):
    """
    Load a LibriSpeech subset into memory.

    Parameters:
    root_dir (str): The root directory where the dataset will be stored.
    url (str): The specific dataset subset to load (e.g., 'train-clean-100').
    download (bool): If True, download the dataset if not already present.

    Returns:
    list: A list of (waveform, sample_rate, transcript) tuples.
    """
    # Make sure the target directory exists before torchaudio writes to it
    os.makedirs(root_dir, exist_ok=True)

    # Load the dataset
    dataset = LIBRISPEECH(root=root_dir, url=url, download=download)

    # Collect the waveform, sample rate, and transcript of every utterance.
    # Note: this decodes the whole subset into RAM; 'train-clean-100' holds
    # about 100 hours of audio, so consider iterating over `dataset` lazily.
    data = []
    for waveform, sample_rate, utterance, speaker_id, chapter_id, utterance_id in dataset:
        data.append((waveform, sample_rate, utterance))

    return data

In [None]:
# Set the root directory for the dataset
root_directory = './librispeech_data'

In [None]:
# Load the dataset
librispeech_data = load_librispeech_data(root_directory)

In [None]:
# Display the number of samples loaded and an example entry
print(f"Number of audio samples loaded: {len(librispeech_data)}")
waveform, sample_rate, transcript = librispeech_data[0]  # Unpack the first entry
print("Example entry (waveform shape, sample rate, transcript):")
print(tuple(waveform.shape), sample_rate, transcript)
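
Holding every utterance of a subset in a list can exhaust memory, and a handful of utterances is usually enough while experimenting. The sketch below uses `itertools.islice`, which works on any iterable, to take only the first few items. `take_utterances` is a hypothetical helper, and the list of tuples is a stand-in shaped like LIBRISPEECH items so the example runs without downloading data.

```python
from itertools import islice

def take_utterances(dataset, limit):
    """Return the first `limit` (waveform, sample_rate, transcript) triples."""
    subset = []
    for waveform, sample_rate, utterance, *_ in islice(dataset, limit):
        subset.append((waveform, sample_rate, utterance))
    return subset

# Stand-in records shaped like LIBRISPEECH items:
# (waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id)
fake_dataset = [([0.0] * 4, 16000, f"TRANSCRIPT {i}", 1, 2, i) for i in range(100)]
small = take_utterances(fake_dataset, 5)
print(len(small), small[0][1], small[-1][2])
```

In the notebook, the same helper could be called on the `LIBRISPEECH` object directly instead of the stand-in list.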

## 3. Feature Extraction

In [None]:
# Section 3: Feature Extraction

In [None]:
# Import necessary libraries
import torch
import torchaudio
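from transformers import Wav2Vec2FeatureExtractor, WavLMModel

In [None]:
# Load the pre-trained WavLM model and its feature extractor.
# WavLM checkpoints are paired with a Wav2Vec2FeatureExtractor for input preprocessing.
# "microsoft/wavlm-base" is one published checkpoint; "microsoft/wavlm-base-plus"
# and "microsoft/wavlm-large" are alternatives.
model_name = "microsoft/wavlm-base"
processor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
model = WavLMModel.from_pretrained(model_name)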

In [None]:
# Set the model to evaluation mode
model.eval()

In [None]:
# Function to extract features from audio files
def extract_features(audio_file):
    # Load the audio file; the tensor has shape (channels, num_samples)
    audio_input, sample_rate = torchaudio.load(audio_file)

    # Downmix multi-channel audio to mono
    audio_input = audio_input.mean(dim=0)

    # Resample if necessary (WavLM expects 16kHz)
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        audio_input = resampler(audio_input)

    # Process the audio input for WavLM (the feature extractor takes NumPy arrays)
    inputs = processor(audio_input.numpy(), sampling_rate=16000, return_tensors="pt", padding=True)

    # Move inputs to the same device as the model (CPU/GPU)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Extract features using WavLM
    with torch.no_grad():  # Disable gradient calculation
        features = model(**inputs).last_hidden_state

    # Return the features
    return features.squeeze(0)  # Remove the batch dimension -> (num_frames, hidden_size)

In [None]:
# Example usage: Extract features from an audio file
audio_file_path = "path/to/your/audio/file.wav"  # Replace with your audio file path
features = extract_features(audio_file_path)

In [None]:
# Print the shape of the extracted features
print(f"Extracted features shape: {features.shape}")  # Should print (sequence_length, feature_dimension)
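
The sequence length printed above follows from WavLM's convolutional front end. As a rough sketch, assuming the kernel sizes and strides of the default `WavLMConfig` in transformers (`conv_kernel=(10, 3, 3, 3, 3, 2, 2)` and `conv_stride=(5, 2, 2, 2, 2, 2, 2)`), the helper below estimates how many frames a clip yields. `expected_num_frames` is a hypothetical name, not a library function, and the values should be checked against `model.config` for the checkpoint actually loaded.

```python
# Assumed defaults from transformers' WavLMConfig; verify against model.config.
CONV_KERNEL = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDE = (5, 2, 2, 2, 2, 2, 2)

def expected_num_frames(num_samples, kernels=CONV_KERNEL, strides=CONV_STRIDE):
    """Apply the unpadded 1-D convolution length formula layer by layer."""
    length = num_samples
    for kernel, stride in zip(kernels, strides):
        length = (length - kernel) // stride + 1
    return length

one_second = 16000  # samples in 1 s of 16 kHz audio
print(expected_num_frames(one_second))       # 49
print(expected_num_frames(10 * one_second))  # 499
```

The strides multiply to 320 samples, so under these assumptions each frame advances by 20 ms of audio.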

## 4. K-means Clustering

In [None]:
# Section 4: K-means Clustering

In [None]:
import numpy as np

In [None]:
# Function to initialize centroids randomly from the dataset
def initialize_centroids(X, k):
    """
    Initialize k centroids randomly from the dataset.
    
    Parameters:
    X : np.ndarray
        The input dataset of shape (n_samples, n_features).
    k : int
        The number of clusters (centroids) to initialize.

    Returns:
    centroids : np.ndarray
        Initialized centroids of shape (k, n_features).
    """
    n_samples = X.shape[0]
    random_indices = np.random.choice(n_samples, size=k, replace=False)
    centroids = X[random_indices]
    return centroids

In [None]:
# Function to assign clusters based on the nearest centroid
def assign_clusters(X, centroids):
    """
    Assign each sample to the nearest centroid.

    Parameters:
    X : np.ndarray
        The input dataset of shape (n_samples, n_features).
    centroids : np.ndarray
        Current centroids of shape (k, n_features).

    Returns:
    labels : np.ndarray
        Array of cluster labels for each sample, shape (n_samples,).
    """
    distances = np.linalg.norm(X[:, np.newaxis] - centroids, axis=2)  # Compute distances to centroids
    labels = np.argmin(distances, axis=1)  # Assign clusters based on the nearest centroid
    return labels

In [None]:
# Function to update centroids based on the assigned clusters
def update_centroids(X, labels, k):
    """
    Update centroids to be the mean of the samples assigned to each cluster.

    Parameters:
    X : np.ndarray
        The input dataset of shape (n_samples, n_features).
    labels : np.ndarray
        Array of cluster labels for each sample, shape (n_samples,).
    k : int
        The number of clusters.

    Returns:
    centroids : np.ndarray
        Updated centroids of shape (k, n_features).
    """
    centroids = np.zeros((k, X.shape[1]))  # Initialize centroids array
    for i in range(k):
        if np.any(labels == i):
            centroids[i] = X[labels == i].mean(axis=0)  # Compute mean of the samples in each cluster
        else:
            centroids[i] = X[np.random.choice(X.shape[0])]  # Reinitialize centroid if cluster is empty
    return centroids

In [None]:
# Main K-means algorithm function
def kmeans(X, k, max_iters=100):
    """
    Perform K-means clustering.

    Parameters:
    X : np.ndarray
        The input dataset of shape (n_samples, n_features).
    k : int
        The number of clusters.
    max_iters : int
        The maximum number of iterations to run the algorithm.

    Returns:
    centroids : np.ndarray
        Final centroids of shape (k, n_features).
    labels : np.ndarray
        Array of cluster labels for each sample, shape (n_samples,).
    """
    centroids = initialize_centroids(X, k)  # Step 1: Initialize centroids
    for _ in range(max_iters):
        labels = assign_clusters(X, centroids)  # Step 2: Assign clusters
        new_centroids = update_centroids(X, labels, k)  # Step 3: Update centroids
        # Check for convergence: stop once the centroids no longer move
        # (np.allclose tolerates floating-point round-off, unlike exact equality)
        if np.allclose(centroids, new_centroids):
            break
        centroids = new_centroids  # Update centroids for next iteration
    return centroids, labels
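
`assign_clusters` materializes an array of shape `(n_samples, k, n_features)`, which grows quickly once many WavLM frames are clustered. One alternative, sketched here under the assumption that plain Euclidean nearest-centroid assignment is all that is needed, expands the squared distance as ‖x‖² − 2x·c + ‖c‖² and processes the data in chunks. `assign_clusters_chunked` and `chunk_size` are hypothetical names introduced for illustration.

```python
import numpy as np

def assign_clusters_chunked(X, centroids, chunk_size=4096):
    """Nearest-centroid labels computed chunk by chunk to bound memory use."""
    centroid_sq_norms = (centroids ** 2).sum(axis=1)  # shape (k,)
    labels = np.empty(X.shape[0], dtype=np.int64)
    for start in range(0, X.shape[0], chunk_size):
        chunk = X[start:start + chunk_size]
        # ||x||^2 is constant within a row, so it cannot change the argmin and is omitted.
        scores = centroid_sq_norms[None, :] - 2.0 * chunk @ centroids.T  # shape (chunk, k)
        labels[start:start + chunk_size] = scores.argmin(axis=1)
    return labels

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(1000, 16))
centroids_demo = rng.normal(size=(8, 16))
labels_demo = assign_clusters_chunked(X_demo, centroids_demo, chunk_size=128)

# Reference: the brute-force distances used by assign_clusters
brute_force = np.linalg.norm(X_demo[:, None] - centroids_demo, axis=2).argmin(axis=1)
print((labels_demo == brute_force).all())
```

The largest temporary array is now `(chunk_size, k)`, independent of the feature dimension.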

## 5. Tokenizer Training

In [None]:
# Section 5: Tokenizer Training

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.cluster import KMeans

In [None]:
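# Assuming features and labels are already prepared from previous steps
# features: a float tensor of shape (num_frames, feature_dim), e.g. WavLM frames stacked across utterances
# labels: the k-means cluster label of each frame, as a long tensor of shape (num_frames,)
#         (NumPy labels can be converted with torch.from_numpy(labels).long())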
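
K-means needs a single matrix of frames, whereas feature extraction yields one `(num_frames, feature_dim)` array per utterance. One way to bridge the two is to concatenate the frames and remember the utterance boundaries, so the labels can later be split back into per-utterance token sequences. The sketch below uses random arrays as stand-ins for WavLM features; the shapes, variable names, and placeholder labels are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-ins for per-utterance WavLM features with different lengths
utterance_features = [rng.normal(size=(n, 8)) for n in (49, 120, 75)]

# Stack all frames into one (total_frames, feature_dim) matrix for clustering
all_frames = np.concatenate(utterance_features, axis=0)
lengths = [f.shape[0] for f in utterance_features]
boundaries = np.cumsum(lengths)[:-1]  # split points between utterances

# Placeholder labels; in the notebook these would come from kmeans(all_frames, k)
frame_labels = rng.integers(0, 10, size=all_frames.shape[0])

# Recover one token sequence per utterance
token_sequences = np.split(frame_labels, boundaries)
print(all_frames.shape, [len(seq) for seq in token_sequences])
```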

In [None]:
# Hyperparameters
learning_rate = 0.001
num_epochs = 100
batch_size = 32

In [None]:
# Create a TensorDataset and DataLoader for training
dataset = TensorDataset(features, labels)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Define a simple neural network for the tokenizer training
class TokenizerNN(nn.Module):
    def __init__(self, input_dim, num_clusters):
        super(TokenizerNN, self).__init__()
        self.fc = nn.Linear(input_dim, num_clusters)  # Map features to cluster space

    def forward(self, x):
        return self.fc(x)

In [None]:
# Initialize the model, loss function, and optimizer
num_clusters = labels.max().item() + 1  # Number of clusters from k-means
input_dim = features.shape[1]  # Feature dimension

In [None]:
# Use a distinct name so the WavLM `model` from Section 3 is not overwritten
tokenizer = TokenizerNN(input_dim, num_clusters)
criterion = nn.CrossEntropyLoss()  # Loss function for multi-class classification
optimizer = optim.Adam(tokenizer.parameters(), lr=learning_rate)

In [None]:
# Training loop
for epoch in range(num_epochs):
    tokenizer.train()  # Set the model to training mode
    running_loss = 0.0

    for batch_features, batch_labels in dataloader:
        optimizer.zero_grad()  # Zero the gradients
        outputs = tokenizer(batch_features)  # Forward pass: (batch, num_clusters) logits
        loss = criterion(outputs, batch_labels)  # Compute loss
        loss.backward()  # Backward pass
        optimizer.step()  # Update weights

        running_loss += loss.item() * batch_features.size(0)  # Accumulate loss

    epoch_loss = running_loss / len(dataset)  # Average loss for the epoch
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

In [None]:
# Save the trained model
torch.save(tokenizer.state_dict(), 'tokenizer_model.pth')
print("Tokenizer model trained and saved.")

## 6. Evaluation

In [None]:
# Section 6: Evaluation

In [None]:
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, silhouette_score
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from scipy.stats import mode

In [None]:
# Assuming `X_test` is the test feature set obtained from the tokenizer
# and `y_test` is the true label set for evaluation.

In [None]:
# Load the trained tokenizer's output on the test dataset
# For example, this could be the output of the WavLM model
# Here, we just assume `X_test` and `y_test` are already defined.

In [None]:
# Perform clustering on the test dataset using k-means
n_clusters = 10  # Set the number of clusters based on the training phase
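# Named kmeans_model so the kmeans() function from Section 4 stays available
kmeans_model = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
y_pred = kmeans_model.fit_predict(X_test)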

In [None]:
# Evaluate the clustering with a purity-style accuracy:
# each cluster is mapped to the most common true label among its members
def map_clusters_to_labels(y_true, y_pred):
    mapping = {}
    for label in np.unique(y_pred):
        mask = (y_pred == label)
        values, counts = np.unique(y_true[mask], return_counts=True)
        mapping[label] = values[np.argmax(counts)]  # Most common true label in the cluster
    return mapping

In [None]:
# Create a mapping from predicted cluster labels to true labels
label_mapping = map_clusters_to_labels(y_test, y_pred)

In [None]:
# Map predicted labels to true labels
y_pred_mapped = np.array([label_mapping[label] for label in y_pred])

In [None]:
# Calculate accuracy
accuracy = accuracy_score(y_test, y_pred_mapped)
print(f"Accuracy of the tokenizer: {accuracy * 100:.2f}%")

In [None]:
# Calculate silhouette score for clustering quality
silhouette_avg = silhouette_score(X_test, y_pred)
print(f"Silhouette Score: {silhouette_avg:.4f}")
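
The accuracy above needs frame-level reference labels, which LibriSpeech transcripts do not provide directly. A label-free complement is to check how evenly the clusters are used. The sketch below computes the perplexity of the token histogram, that is, the exponential of its entropy. `codebook_perplexity` is a hypothetical helper, not a scikit-learn function, and the two label arrays are made up for illustration.

```python
import numpy as np

def codebook_perplexity(labels, num_clusters):
    """exp(entropy) of the empirical token distribution, between 1 and num_clusters."""
    counts = np.bincount(labels, minlength=num_clusters).astype(np.float64)
    probs = counts / counts.sum()
    nonzero = probs[probs > 0]  # 0 * log(0) is treated as 0
    entropy = -(nonzero * np.log(nonzero)).sum()
    return float(np.exp(entropy))

balanced = np.repeat(np.arange(10), 50)    # every token used 50 times
collapsed = np.zeros(500, dtype=np.int64)  # only token 0 is ever used
print(codebook_perplexity(balanced, 10))
print(codebook_perplexity(collapsed, 10))
```

A value near the number of clusters means all tokens are used about equally; a value near 1 means a few tokens dominate. In the notebook this could be called as `codebook_perplexity(y_pred, n_clusters)`.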

In [None]:
# Visualization of clustering results using PCA
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X_test)

In [None]:
# Plotting the clusters
plt.figure(figsize=(10, 7))
plt.scatter(X_pca[:, 0], X_pca[:, 1], c=y_pred, cmap='viridis', alpha=0.5)
plt.title('PCA of Test Set with K-Means Clustering')
plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')
plt.colorbar(label='Cluster Label')
plt.grid()
plt.show()

## 7. Saving the Model

In [None]:
import os
import pickle

In [None]:
model_save_path = 'trained_tokenizer.pkl'

In [None]:
def save_model(tokenizer, save_path):
    """
    Saves the trained tokenizer model to a specified path using pickle.
    
    Parameters:
    tokenizer: The trained tokenizer model to be saved.
    save_path: The file path where the model should be saved.
    """
    with open(save_path, 'wb') as f:
        pickle.dump(tokenizer, f)
    print(f"Model saved to {save_path}")

In [None]:
# Assuming 'tokenizer' is defined and is a trained tokenizer model
if 'tokenizer' in locals():
    save_model(tokenizer, model_save_path)

    if os.path.exists(model_save_path):
        print("Model saved successfully!")
    else:
        print("Model saving failed.")
else:
    print("Tokenizer model is not defined.")
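
If the tokenizer is only the k-means codebook, the centroid matrix alone is enough to reproduce token assignments, and it can be stored without pickle. A minimal sketch, assuming the centroids are a NumPy array such as the one returned by `kmeans` in Section 4; the random array and the file name are placeholders.

```python
import os
import tempfile
import numpy as np

rng = np.random.default_rng(0)
centroids = rng.normal(size=(10, 8))  # stand-in for trained centroids

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "kmeans_centroids.npy")  # placeholder file name
    np.save(path, centroids)   # write the array in NumPy's .npy format
    restored = np.load(path)   # read it back before the directory is removed

print(np.array_equal(centroids, restored))
```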

## 8. Inference

In [None]:
# Section 8: Inference
# This section demonstrates how to use the trained tokenizer on new audio data.
# We will visualize the tokenization process using the trained model.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import librosa
import librosa.display

In [None]:
# Load the trained tokenizer model
# Assuming 'tokenizer' is the trained tokenizer object and 'wavlm_model' is the WavLM model
# (These should be defined in previous sections of the notebook)

In [None]:
# Uncomment and define these variables in your actual code
# tokenizer = ...  # Load your trained tokenizer
# wavlm_model = ...  # Load your WavLM model

In [None]:
def load_audio(file_path):
    """Load an audio file as mono, resampled to the 16 kHz rate WavLM expects."""
    waveform, sample_rate = librosa.load(file_path, sr=16000)  # librosa resamples and downmixes on load
    return waveform, sample_rate

In [None]:
def visualize_waveform(waveform, sample_rate):
    """Plot the audio waveform."""
    plt.figure(figsize=(12, 4))
    librosa.display.waveshow(waveform, sr=sample_rate)
    plt.title('Audio Waveform')
    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude')
    plt.show()

In [None]:
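def tokenize_audio(waveform):
    """Tokenize a 16 kHz mono waveform: WavLM features -> cluster logits -> token IDs.

    Assumes `tokenizer` maps WavLM features to cluster logits (as TokenizerNN does in Section 5).
    """
    with torch.no_grad():  # Disable gradient calculation
        input_tensor = torch.as_tensor(waveform, dtype=torch.float32).unsqueeze(0)  # (1, num_samples)
        features = wavlm_model(input_tensor).last_hidden_state  # (1, num_frames, feature_dim)
        logits = tokenizer(features)  # (1, num_frames, num_clusters)
        tokens = logits.argmax(dim=-1)  # (1, num_frames) token IDs
    return tokens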

In [None]:
def visualize_tokens(tokens):
    """Visualize the tokenized output as a sequence."""
    token_array = tokens.cpu().numpy()  # Ensure tokens are on CPU
    token_array = token_array.reshape(-1, token_array.shape[-1])  # imshow needs a 2-D array
    plt.figure(figsize=(12, 4))
    plt.imshow(token_array, aspect='auto', cmap='viridis', interpolation='nearest')
    plt.title('Tokenized Output')
    plt.xlabel('Time Steps')
    plt.ylabel('Token Indices')
    plt.colorbar(label='Token Value')
    plt.show()

In [None]:
# Example usage:
audio_file_path = 'path/to/new/audio/file.wav'  # Replace with the path to your audio file
waveform, sample_rate = load_audio(audio_file_path)  # Load the new audio file

In [None]:
# Visualize the waveform
visualize_waveform(waveform, sample_rate)

In [None]:
# Tokenize the audio waveform
tokens = tokenize_audio(waveform)

In [None]:
# Visualize the tokenization process
visualize_tokens(tokens)
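
Frame-level tokens often repeat over consecutive frames, because one sound usually spans several frames. Downstream models are sometimes fed the deduplicated sequence instead. The sketch below shows that post-processing step with `itertools.groupby`; `collapse_repeats` is a hypothetical helper and the token sequence is made up.

```python
from itertools import groupby

def collapse_repeats(token_ids):
    """Merge runs of identical consecutive tokens into (token, run_length) pairs."""
    return [(token, sum(1 for _ in run)) for token, run in groupby(token_ids)]

frame_tokens = [7, 7, 7, 2, 2, 9, 7, 7]  # made-up frame-level token IDs
runs = collapse_repeats(frame_tokens)
print(runs)                      # [(7, 3), (2, 2), (9, 1), (7, 2)]
print([tok for tok, _ in runs])  # [7, 2, 9, 7]
```

Keeping the run lengths preserves duration information, which is lost if only the deduplicated IDs are stored.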