In [1]:
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import soundfile as sf
import numpy as np
from ctc_segmentation import ctc_segmentation, prepare_text, CtcSegmentationParameters
import librosa
import pandas as pd
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
df = pd.read_csv("/blue/ufdatastudios/ahmed.waseem/ctc/meta_speaker.csv")

In [3]:
df = df[df['duration'].apply(lambda x: x >= 5)]
audio_folder = "/blue/ufdatastudios/ahmed.waseem/processed_audio"
df["audio_filepath"] = df["audio_filepath"].apply(lambda x: os.path.join(audio_folder, x))



In [21]:
df.head()

Unnamed: 0,audio_filepath,duration,text,gender,age-group,primary_language,native_place_state,native_place_district,highest_qualification,job_category,occupation_domain
3852,/blue/ufdatastudios/ahmed.waseem/processed_aud...,5.0,and they went off into the space when earth wa...,Female,45-60,Telugu,Odisha,Sambalpur,Graduate,Full Time,Education and Research
3853,/blue/ufdatastudios/ahmed.waseem/processed_aud...,5.0,I should say addictive because once I started ...,Female,45-60,Telugu,Odisha,Sambalpur,Graduate,Full Time,Education and Research
3854,/blue/ufdatastudios/ahmed.waseem/processed_aud...,5.0,and with the kind of dishes that you have in I...,Female,18-30,Marathi,Maharashtra,Ratnagiri,Graduate,Part Time,Social service
3855,/blue/ufdatastudios/ahmed.waseem/processed_aud...,5.0,"So in order to bring all these back, in order ...",Female,18-30,Malayalam,Kerala,Pathanamthitta,Post Graduate,Full Time,Education and Research
3856,/blue/ufdatastudios/ahmed.waseem/processed_aud...,5.0,It's a well-known fact that India is a country...,Female,30-45,Hindi,Goa,South Goa,Post Graduate,Full Time,Technology and Services


In [8]:
df.columns

Index(['audio_filepath', 'duration', 'text', 'gender', 'age-group',
       'primary_language', 'native_place_state', 'native_place_district',
       'highest_qualification', 'job_category', 'occupation_domain'],
      dtype='object')

In [5]:
audio_filepaths = df['audio_filepath'].tolist()
transcripts = df['text'].tolist()
text = df['text'].tolist()

In [45]:
model_name = "facebook/wav2vec2-large-960h"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name).to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()


Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Wav2Vec2ForCTC(
  (wav2vec2): Wav2Vec2Model(
    (feature_extractor): Wav2Vec2FeatureEncoder(
      (conv_layers): ModuleList(
        (0): Wav2Vec2GroupNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
          (activation): GELUActivation()
          (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
        )
        (1-4): 4 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
        (5-6): 2 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
      )
    )
    (feature_projection): Wav2Vec2FeatureProjection(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (projection): Linear(in_features=512, out_features=1024, bias=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder

In [46]:
# Initialize CTC segmentation parameters with skip_prob adjustment
config = CtcSegmentationParameters()
config.blank = processor.tokenizer.pad_token_id
config.skip_prob = 1.0  # Allows skipping parts of the text if audio is shorter
config.max_prob = 0.95  # Adjust as needed based on segmentation tolerance

: 

In [43]:
# Function to get logits from the audio file
def get_logits(audio_path):
    waveform, sample_rate = torchaudio.load(audio_path)
    
    # Resample if necessary
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    
    # Ensure waveform is 2D with shape [1, sequence_length]
    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)  # Add a channel dimension if missing
    
    # Pass waveform as a list to maintain the batch dimension (1, sequence_length)
    inputs = processor(waveform.tolist(), sampling_rate=16000, return_tensors="pt", padding=True)
    
    # Get logits from model
    with torch.no_grad():
        logits = model(inputs.input_values.to(model.device)).logits.cpu()
    
    # Confirm logits are 2D with shape [time_steps, vocab_size]
    if logits.ndim == 3:
        logits = logits.squeeze(0)  # Remove the batch dimension if it exists and only if it’s a 3D tensor
    
    return logits

In [44]:

# Loop over DataFrame to process each audio file and align with CTC segmentation
aligned_data = []
for _, row in df.iterrows():
    audio_path = row["audio_filepath"]
    text = row["text"]
    
    # Generate logits
    logits = get_logits(audio_path)
    
    # Convert logits to log probabilities (required for CTC segmentation)
    lpz = torch.nn.functional.log_softmax(logits, dim=-1).numpy()
    
    # Tokenize the text and prepare ground truth
    ground_truth = processor.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).input_ids[0].tolist()
    
    # Run CTC segmentation
    timing, char_list = ctc_segmentation(config, lpz, ground_truth)
    
    # Aggregate subword alignments to approximate word-level alignments
    word_alignments = []
    current_word = ""
    word_start_time = None
    
    for idx, (start, end) in enumerate(zip(timing["begin"], timing["end"])):
        subword = processor.tokenizer.decode([ground_truth[idx]])
        
        # Start a new word if the previous one ended
        if subword.strip() != "":
            if current_word == "":
                word_start_time = start
            
            current_word += subword
            
            # If the subword completes a word (based on space in text)
            if " " in text[len(" ".join([w["word"] for w in word_alignments])) :].strip().split(" ", 1)[0]:
                word_alignments.append({
                    "word": current_word.strip(),
                    "start_time": word_start_time,
                    "end_time": end
                })
                current_word = ""
                word_start_time = None
    
    # Append to aligned_data with filename reference
    aligned_data.append({"audio_filepath": audio_path, "text": text, "word_alignments": word_alignments})

# Save the aligned data in a new DataFrame
aligned_df = pd.DataFrame(aligned_data)

# Display the aligned DataFrame
aligned_df

ValueError: Buffer has wrong number of dimensions (expected 2, got 1)

["'", '<', '>', '[', ']', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


In [33]:
print("Processor vocabulary size:", len(processor.tokenizer))

Processor vocabulary size: 32


Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ValueError: not enough values to unpack (expected 3, got 2)

In [30]:
print("Text length:", len(prepared_text))
print("Logits length:", logits_list[2].shape)

Text length: 222198
Logits length: torch.Size([1, 249, 32])


tensor([[[  2.6421, -17.2083, -16.9143,  ...,  -3.9510,  -4.9777,  -3.8921],
         [  8.2383, -24.9574, -24.4365,  ...,  -7.0936,  -7.7027,  -3.9959],
         [  9.4998, -26.5420, -26.2310,  ...,  -7.4192,  -6.5949,  -5.1694],
         ...,
         [  6.4110, -21.9786, -21.2738,  ...,  -5.9054,  -3.6497,  -5.7050],
         [  7.6090, -22.5000, -21.7588,  ...,  -7.2212,  -4.0831,  -6.0867],
         [  6.8192, -22.1714, -21.4689,  ...,  -6.6844,  -3.8365,  -5.9234]]])