In [8]:
import pandas as pd
from sklearn.model_selection import train_test_split
import re
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from tqdm.auto import tqdm
import torchaudio
from dataclasses import dataclass, field
import os

In [9]:
# Dataset root, relative to the notebook's working directory.
ROOT_DIR = '../data'

# Each split has an audio folder and an index CSV sharing the same stem,
# so the CSV path is derived from the folder path.
TRAIN_DATA_BASE = ROOT_DIR + "/cv-valid-train"
TRAIN_INDICES = pd.read_csv(TRAIN_DATA_BASE + ".csv")

TEST_DATA_BASE = ROOT_DIR + "/cv-valid-test"
TEST_INDICES = pd.read_csv(TEST_DATA_BASE + ".csv")

In [10]:
# Preview the training index: one row per clip, with its relative file path and transcript.
TRAIN_INDICES

Unnamed: 0,filename,text,up_votes,down_votes,age,gender,accent,duration
0,cv-valid-train/sample-000000.mp3,learn to recognize omens and follow them the o...,1,0,,,,
1,cv-valid-train/sample-000001.mp3,everything in the universe evolved he said,1,0,,,,
2,cv-valid-train/sample-000002.mp3,you came so that you could learn about your dr...,1,0,,,,
3,cv-valid-train/sample-000003.mp3,so now i fear nothing because it was those ome...,1,0,,,,
4,cv-valid-train/sample-000004.mp3,if you start your emails with greetings let me...,3,2,,,,
...,...,...,...,...,...,...,...,...
195771,cv-valid-train/sample-195771.mp3,the englishman said nothing,1,0,thirties,male,england,
195772,cv-valid-train/sample-195772.mp3,the irish man sipped his tea,1,0,,,,
195773,cv-valid-train/sample-195773.mp3,what do you know about that,1,0,,,,
195774,cv-valid-train/sample-195774.mp3,the phone rang while she was awake,2,0,twenties,male,us,


## Dataset and Data Collator

1. Load the audio from the corresponding file path.
2. Resample the audio to 16 kHz.
3. Ensure the audio is mono; if it has several channels, take their mean.
4. Use the Wav2Vec2 processor to process the audio.
5. Remove all punctuation from the transcription and lowercase it (to match the original training data style).

In [11]:
class AudioDataset(Dataset):
    """Map-style dataset pairing processed audio with a cleaned transcript.

    Each item is looked up in ``dataframe`` (columns ``filename`` and ``text``),
    the audio file is loaded, resampled to 16 kHz, mixed down to mono and run
    through ``processor``. Items that cannot be used yield ``None`` so that the
    collator can drop them.

    Args:
        dataframe: pandas DataFrame with ``filename`` and ``text`` columns.
        processor: callable audio processor; called with a 1-D numpy waveform
            and ``sampling_rate=16000`` and expected to expose ``input_values``.
        audio_base_path: directory that ``filename`` values are joined onto.
    """

    def __init__(self, dataframe, processor, audio_base_path):
        self.dataframe = dataframe
        self.processor = processor
        self.audio_base_path = audio_base_path
        # Raw string: identical pattern to the previous non-raw literal, but
        # without invalid escape sequences (which newer Python versions warn on).
        self.chars_to_ignore_regex = r'[\,\?\.\!\-\;\:"“%”\\]'
        # One Resample transform per source rate, built lazily and reused
        # instead of being re-created for every sample.
        self._resamplers = {}

    def __len__(self):
        """Return the number of rows in the index dataframe."""
        return len(self.dataframe)

    def _resampler_for(self, sampling_rate):
        """Return a cached transform converting ``sampling_rate`` to 16 kHz."""
        resampler = self._resamplers.get(sampling_rate)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)
            self._resamplers[sampling_rate] = resampler
        return resampler

    def __getitem__(self, idx):
        """Return ``{"input_values": ..., "labels": str}`` for row ``idx``.

        Returns ``None`` (after printing a message) when the transcript is not
        a string or the audio file cannot be loaded.
        """
        row = self.dataframe.iloc[idx]
        audio_filename = row['filename']
        text = row['text']

        # A missing transcript (e.g. NaN) cannot be cleaned or tokenized; skip
        # it before paying for the audio decode.
        if not isinstance(text, str):
            print(f"Skipping {audio_filename}: transcript is missing or not a string")
            return None

        # Construct full audio file path
        audio_path = os.path.join(self.audio_base_path, audio_filename)

        # Load audio file
        try:
            speech_array, sampling_rate = torchaudio.load(audio_path)
        except Exception as e:
            print(f"Error loading audio file {audio_path}: {e}")
            return None # Return None for problematic samples, will be filtered by collator

        # Resample if necessary (Wav2Vec2 expects 16kHz)
        if sampling_rate != 16000:
            speech_array = self._resampler_for(sampling_rate)(speech_array)

        # Ensure mono audio
        if speech_array.shape[0] > 1:
            speech_array = speech_array.mean(dim=0, keepdim=True)

        # Process audio (normalize, etc.)
        # Squeeze to remove the channel dimension (e.g., from (1, samples) to (samples,))
        input_values = self.processor(speech_array.squeeze(0).numpy(), sampling_rate=16000).input_values[0]

        # Strip punctuation and lowercase the transcript.
        # NOTE(review): confirm the checkpoint's tokenizer accepts lower-case
        # text; if its vocabulary is upper-case only and it does not re-case
        # input, these labels would map to the unknown token.
        text = re.sub(self.chars_to_ignore_regex, '', text).lower()

        return {"input_values": input_values, "labels": text}

In [12]:
class DataCollatorCTCWithPadding:
    """Collate dataset items into a padded batch with CTC label masking.

    Args:
        processor: callable used both for audio (positional input plus
            ``sampling_rate``) and for text (``text=`` keyword); the text
            result must provide ``input_ids`` and ``attention_mask``.
        padding: padding argument forwarded to both processor calls.
    """

    def __init__(self, processor, padding=True):
        self.processor = processor
        self.padding = padding

    def __call__(self, features):
        """Build one batch from a list of dataset items.

        Args:
            features: list of ``{"input_values": ..., "labels": str}`` dicts;
                ``None`` entries (failed samples) are dropped.

        Returns:
            The processor's audio batch with an added ``"labels"`` tensor in
            which padded positions are set to -100, or ``None`` when every
            entry of ``features`` was ``None``.
        """
        # Filter out None values from __getitem__
        features = [f for f in features if f is not None]
        if not features: # If all features were None, return empty batch
            return None

        # Separate input_values and labels
        input_features = [feature["input_values"] for feature in features]
        # Each label is the full transcript string. Indexing it with [0] would
        # keep only its first character as the training target.
        label_features = [feature["labels"] for feature in features]

        batch = self.processor(
            input_features,
            padding=self.padding,
            sampling_rate=16000,
            return_tensors="pt",
        )

        labels_batch = self.processor(
            text=label_features,
            padding=self.padding,
            return_tensors="pt",
        )

        # Replace padded label positions with -100 so they are excluded from the loss.
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch["attention_mask"].ne(1), -100
        )

        batch["labels"] = labels

        return batch

In [13]:
def _mean_loss(model, dataloader, device, desc, max_batches=None):
    """Return the mean loss of ``model`` over ``dataloader`` with gradients disabled.

    Puts the model in eval mode (the caller is responsible for switching back).
    Batches that the collator returned as ``None`` are skipped. When
    ``max_batches`` is given, at most that many batches are drawn. Returns NaN
    if no batch could be evaluated.
    """
    from itertools import islice

    total_batches = len(dataloader)
    if max_batches is not None:
        total_batches = min(total_batches, max_batches)

    model.eval()
    loss_sum = 0.0
    batches_used = 0
    with torch.no_grad():
        for val_batch in tqdm(islice(dataloader, max_batches), total=total_batches, desc=desc):
            if val_batch is None:
                continue
            val_batch = {k: v.to(device) for k, v in val_batch.items()}
            outputs = model(input_values=val_batch["input_values"], labels=val_batch["labels"])
            loss_sum += outputs.loss.item()
            batches_used += 1
    return loss_sum / batches_used if batches_used else float("nan")


def train(max_val_batches=None):
    """Fine-tune the pretrained CTC model on the training index and log losses.

    Splits ``TRAIN_INDICES`` into train/validation parts, freezes the model's
    feature encoder, trains with AdamW and gradient-norm clipping, and every
    ``LOG_STEP`` training steps records the current batch loss together with
    the mean validation loss.

    Args:
        max_val_batches: optional cap on the number of validation batches used
            at each logging step. ``None`` (default) evaluates the whole
            validation split, as before.

    Returns:
        Tuple ``(train_losses, val_losses, log_steps_recorded)`` of equal-length
        lists: batch training loss, mean validation loss, and the 1-based step
        within the epoch at which each pair was recorded.
    """
    MODEL_NAME = "facebook/wav2vec2-large-960h"
    processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME, sampling_rate=16000)
    model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME)
    RANDOM_STATE = 42 # to ensure reproducibility
    TRAIN_SPLIT_RATIO = 0.7
    LEARNING_RATE = 1e-4
    BATCH_SIZE = 1 # Keep batch size as 1 for simplicity in this example
    NUM_EPOCHS = 1
    LOG_STEP = 200

    # Lists to store losses
    train_losses = []
    val_losses = []
    log_steps_recorded = [] # To keep track of which steps the losses correspond to

    train_indices, val_indices = train_test_split(
        TRAIN_INDICES, test_size=1-TRAIN_SPLIT_RATIO, random_state=RANDOM_STATE
    )

    print(f"Train samples: {len(train_indices)}, Validation samples: {len(val_indices)}")

    # Freeze feature extractor layers
    model.freeze_feature_encoder()

    # Move model to GPU if available
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model.to(device)
    print(f"Using device: {device}")

    train_dataset = AudioDataset(train_indices, processor, TRAIN_DATA_BASE)
    val_dataset = AudioDataset(val_indices, processor, TRAIN_DATA_BASE)

    data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

    # Pinned host memory only helps host-to-GPU copies, so enable it only with CUDA.
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        collate_fn=data_collator,
        num_workers=0,
        pin_memory=use_cuda,
        shuffle=True
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        collate_fn=data_collator,
        num_workers=0,
        pin_memory=use_cuda,
        shuffle=False
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    print("\nStarting training...")
    for epoch in range(NUM_EPOCHS):
        # --- Training Phase ---
        model.train()
        total_train_loss_epoch = 0.0 # Accumulated loss for the epoch average
        train_batches_used = 0 # Batches that actually produced a loss

        for i, batch in enumerate(tqdm(train_dataloader, desc=f"Epoch {epoch+1} Training")):
            # The collator returns None when every sample in the batch failed
            # to load; there is nothing to train (or log) on for that step.
            if batch is None:
                continue

            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(input_values=batch["input_values"], labels=batch["labels"])
            loss = outputs.loss
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

            current_batch_train_loss = loss.item()
            total_train_loss_epoch += current_batch_train_loss
            train_batches_used += 1

            # --- Logging and Validation Phase (Every LOG_STEP) ---
            if (i + 1) % LOG_STEP == 0:
                # Log the loss of the current batch/step directly
                train_losses.append(current_batch_train_loss)

                avg_val_loss = _mean_loss(
                    model,
                    val_dataloader,
                    device,
                    desc=f"Step {i+1} Validation",
                    max_batches=max_val_batches,
                )
                val_losses.append(avg_val_loss)
                log_steps_recorded.append(i + 1)

                print(f"Step {i+1} - Train Loss (current batch): {current_batch_train_loss:.4f} - Validation Loss: {avg_val_loss:.4f}")
                model.train() # Set model back to training mode

        # Average over batches that were actually trained on, not skipped ones.
        if train_batches_used:
            avg_train_loss_epoch = total_train_loss_epoch / train_batches_used
        else:
            avg_train_loss_epoch = float("nan")
        print(f"Epoch {epoch+1} - Average Epoch Train Loss: {avg_train_loss_epoch:.4f}")

    print("\nTraining complete!")

    return train_losses, val_losses, log_steps_recorded

In [14]:
# Run fine-tuning. train() returns (train_losses, val_losses, log_steps_recorded);
# as the cell's last expression, that tuple is shown as the cell output.
train()



Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize


KeyboardInterrupt: 