In [17]:
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import soundfile as sf
import numpy as np
from ctc_segmentation import ctc_segmentation, prepare_text, CtcSegmentationParameters
import librosa
import pandas as pd
import os

In [19]:
df = pd.read_csv("/blue/ufdatastudios/ahmed.waseem/ctc/meta_speaker.csv")

In [20]:
df = df[df['duration'].apply(lambda x: x >= 5)]
audio_folder = "/blue/ufdatastudios/ahmed.waseem/processed_audio"
df["audio_filepath"] = df["audio_filepath"].apply(lambda x: os.path.join(audio_folder, x))



In [21]:
df.head()

Unnamed: 0,audio_filepath,duration,text,gender,age-group,primary_language,native_place_state,native_place_district,highest_qualification,job_category,occupation_domain
3852,/blue/ufdatastudios/ahmed.waseem/processed_aud...,5.0,and they went off into the space when earth wa...,Female,45-60,Telugu,Odisha,Sambalpur,Graduate,Full Time,Education and Research
3853,/blue/ufdatastudios/ahmed.waseem/processed_aud...,5.0,I should say addictive because once I started ...,Female,45-60,Telugu,Odisha,Sambalpur,Graduate,Full Time,Education and Research
3854,/blue/ufdatastudios/ahmed.waseem/processed_aud...,5.0,and with the kind of dishes that you have in I...,Female,18-30,Marathi,Maharashtra,Ratnagiri,Graduate,Part Time,Social service
3855,/blue/ufdatastudios/ahmed.waseem/processed_aud...,5.0,"So in order to bring all these back, in order ...",Female,18-30,Malayalam,Kerala,Pathanamthitta,Post Graduate,Full Time,Education and Research
3856,/blue/ufdatastudios/ahmed.waseem/processed_aud...,5.0,It's a well-known fact that India is a country...,Female,30-45,Hindi,Goa,South Goa,Post Graduate,Full Time,Technology and Services


In [8]:
df.columns

Index(['audio_filepath', 'duration', 'text', 'gender', 'age-group',
       'primary_language', 'native_place_state', 'native_place_district',
       'highest_qualification', 'job_category', 'occupation_domain'],
      dtype='object')

In [5]:
audio_filepaths = df['audio_filepath'].tolist()
transcripts = df['text'].tolist()
text = df['text'].tolist()

In [31]:
from transformers import AutoProcessor, AutoModelForCTC

model_name = "facebook/wav2vec2-large-960h"
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForCTC.from_pretrained(model_name).to("cuda" if torch.cuda.is_available() else "cpu")

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:

# Initialize CTC segmentation parameters
config = CtcSegmentationParameters()
config.blank = processor.tokenizer.pad_token_id

In [36]:
# Function to get logits from the audio file
def get_logits(audio_path):
    waveform, sample_rate = torchaudio.load(audio_path)
    
    # Ensure the waveform is resampled if needed
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    
    # Remove any unnecessary dimensions and process the audio
    waveform = waveform.squeeze()  # Make sure waveform is 1D before feeding it to the processor
    
    # Pass waveform as a list to keep the batch dimension (1, sequence_length)
    inputs = processor(waveform.tolist(), sampling_rate=16000, return_tensors="pt", padding=True)
    
    # Get logits from model
    with torch.no_grad():
        logits = model(inputs.input_values.to(model.device)).logits.cpu()
    
    return logits

In [37]:

# Loop over DataFrame to process each audio file and align with CTC segmentation
aligned_data = []
for _, row in df.iterrows():
    audio_path = row["audio_filepath"]
    text = row["text"]
    
    # Generate logits
    logits = get_logits(audio_path)
    
    # Convert logits to log probabilities (required for CTC segmentation)
    lpz = torch.nn.functional.log_softmax(logits, dim=-1).numpy()
    
    # Tokenize the text and prepare ground truth
    ground_truth = processor.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).input_ids[0].tolist()
    
    # Run CTC segmentation
    timing, char_list = ctc_segmentation(config, lpz, ground_truth)
    
    # Aggregate subword alignments to approximate word-level alignments
    word_alignments = []
    current_word = ""
    word_start_time = None
    
    for idx, (start, end) in enumerate(zip(timing["begin"], timing["end"])):
        subword = processor.tokenizer.decode([ground_truth[idx]])
        
        # Start a new word if the previous one ended
        if subword.strip() != "":
            if current_word == "":
                word_start_time = start
            
            current_word += subword
            
            # If the subword completes a word (based on space in text)
            if " " in text[len(" ".join([w["word"] for w in word_alignments])) :].strip().split(" ", 1)[0]:
                word_alignments.append({
                    "word": current_word.strip(),
                    "start_time": word_start_time,
                    "end_time": end
                })
                current_word = ""
                word_start_time = None
    
    # Append to aligned_data with filename reference
    aligned_data.append({"audio_filepath": audio_path, "text": text, "word_alignments": word_alignments})

# Save the aligned data in a new DataFrame
aligned_df = pd.DataFrame(aligned_data)

# Display the aligned DataFrame
aligned_df

AssertionError: Audio is shorter than text!

In [24]:
# Assuming your MFA dictionary is in a file called `mfa_dict.txt`
mfa_dict_path = "/home/ahmed.waseem/Documents/MFA/pretrained_models/dictionary/english_india_mfa.dict"

unique_chars = set()

# Read the MFA dictionary file
with open(mfa_dict_path, 'r', encoding='utf-8') as f:
    for line in f:
        # Split each line by whitespace to extract the word (first part)
        if line.strip():  # Make sure the line is not empty
            word = line.split()[0]  # The word is the first element before the phonemes
            # Add characters of the word to the set
            unique_chars.update(list(word))

# Convert the set to a sorted list to create a consistent char_list
char_list = sorted(unique_chars)

# Print the character list for verification
print(char_list)

["'", '<', '>', '[', ']', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


In [33]:
print("Processor vocabulary size:", len(processor.tokenizer))

Processor vocabulary size: 32


In [34]:
from ctc_segmentation import ctc_segmentation, CtcSegmentationParameters, prepare_text, determine_utterance_segments

# Load pre-trained Wav2Vec2 model and processor
model_name = "facebook/wav2vec2-large-960h"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)

# Set the device to GPU if available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to the appropriate device
model.to(device)
model.eval()

# Directory to save resampled audio files
resampled_audio_folder = 'processed_audio'
os.makedirs(resampled_audio_folder, exist_ok=True)

# Function to resample and save audio files
def resample_and_save_audio(audio_path, target_sample_rate=16000):
    resampled_audio_path = os.path.join(resampled_audio_folder, os.path.basename(audio_path))
    if os.path.exists(resampled_audio_path):
        return resampled_audio_path
    
    waveform, sample_rate = torchaudio.load(audio_path)
    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        waveform = resampler(waveform)
    
    torchaudio.save(resampled_audio_path, waveform, target_sample_rate)
    return resampled_audio_path

# CTC segmentation setup
params = CtcSegmentationParameters()
char_list = processor.tokenizer.convert_ids_to_tokens(range(processor.tokenizer.vocab_size))

# Iterate over each row in the DataFrame
for index, row in df.iterrows():
    audio_path = row['audio_filepath']
    transcript = row['text']

    # Resample and save the audio file if necessary
    resampled_audio_path = resample_and_save_audio(audio_path)

    # Load the resampled audio
    waveform, _ = torchaudio.load(resampled_audio_path)
    waveform = waveform.squeeze()

    if waveform.ndim > 1:
        waveform = waveform[0]  # Take the first channel if stereo

    # Prepare input features for the Wav2Vec2 model
    inputs = processor(waveform, sampling_rate=16000, return_tensors="pt", padding=True)
    input_values = inputs.input_values.to(device)

    # Extract logits from the model without computing gradients
    with torch.no_grad():
        logits = model(input_values).logits

    # Prepare text for CTC segmentation
    prepared_text, char_list = prepare_text(params, [transcript], char_list)

    # Perform CTC segmentation
    params.char_list = char_list  # Update character list in segmentation parameters
    segmentation = ctc_segmentation(params, logits.squeeze(0).cpu().numpy(), prepared_text)

    # Print alignment details
    for start, end, word in segmentation:
        print(f"Word: '{word}', Start: {start:.2f}s, End: {end:.2f}s")

    print(f"CTC segmentation completed for file {index + 1}/{len(df)}: {audio_path}")

print("CTC segmentation completed for all audio files.")


Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ValueError: not enough values to unpack (expected 3, got 2)

In [30]:
print("Text length:", len(prepared_text))
print("Logits length:", logits_list[2].shape)

Text length: 222198
Logits length: torch.Size([1, 249, 32])


tensor([[[  2.6421, -17.2083, -16.9143,  ...,  -3.9510,  -4.9777,  -3.8921],
         [  8.2383, -24.9574, -24.4365,  ...,  -7.0936,  -7.7027,  -3.9959],
         [  9.4998, -26.5420, -26.2310,  ...,  -7.4192,  -6.5949,  -5.1694],
         ...,
         [  6.4110, -21.9786, -21.2738,  ...,  -5.9054,  -3.6497,  -5.7050],
         [  7.6090, -22.5000, -21.7588,  ...,  -7.2212,  -4.0831,  -6.0867],
         [  6.8192, -22.1714, -21.4689,  ...,  -6.6844,  -3.8365,  -5.9234]]])