## Import Libraries

In [29]:
import os
import torch
import pandas as pd
import torchaudio
import pandas as pd
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, WhisperConfig, WhisperFeatureExtractor
from jiwer import wer
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

In [30]:
# Set device
# device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
device = torch.device("cpu")
print(device)

cpu


## Utils

In [32]:
def create_file_mapping(transcript_dir, audio_dir):
    """
    Creates a DataFrame mapping transcript files to their corresponding audio files based on filenames.
    
    Args:
        transcript_dir (str): Directory containing transcript text files.
        audio_dir (str): Directory containing audio files.
    
    Returns:
        pd.DataFrame: DataFrame with columns 'Interview', 'Transcript Path', and 'Audio Path'.
    """
    # Get lists of file paths
    transcript_files = [f for f in os.listdir(transcript_dir) if f.endswith('.txt')]
    audio_files = [f for f in os.listdir(audio_dir) if f.endswith('.wav')]
    
    # Extract file identifiers from filenames
    transcript_ids = {os.path.splitext(f)[0] for f in transcript_files}
    audio_ids = {os.path.splitext(f)[0] for f in audio_files}
    
    # Determine the intersection of transcript and audio IDs
    common_ids = transcript_ids.intersection(audio_ids)
    
    # Create the mapping dictionary
    file_mapping = {
        'Interview': [],
        'Transcript Path': [],
        'Audio Path': []
    }
    
    for file_id in common_ids:
        transcript_path = os.path.join(transcript_dir, file_id + '.txt')
        audio_path = os.path.join(audio_dir, file_id + '.wav')
        file_mapping['Interview'].append(file_id)
        file_mapping['Transcript Path'].append(transcript_path)
        file_mapping['Audio Path'].append(audio_path)
    
    # Convert the dictionary to a DataFrame
    return pd.DataFrame(file_mapping)


def process_transcripts(paths_df):
    """
    Processes transcript files to extract and format relevant data, and combines them into a single DataFrame.
    
    Args:
        paths_df (pd.DataFrame): DataFrame with columns 'Interview', 'Transcript Path', and 'Audio Path'.
    
    Returns:
        pd.DataFrame: DataFrame containing concatenated and processed transcript data.
    """
    all_transcript_data = []
    
    for _, row in paths_df.iterrows():
        transcript_path = row['Transcript Path']
        
        # Load transcript data with appropriate delimiter and columns
        transcript_df = pd.read_csv(transcript_path, delimiter="\t", usecols=['StTime', 'EnTime', 'Content'])
        
        # Convert times from minutes to milliseconds
        transcript_df['StTime'] *= 1000
        transcript_df['EnTime'] *= 1000
        
        # Standardize content to uppercase
        transcript_df['Content'] = transcript_df['Content'].str.upper()
        
        # Add the identifier column
        transcript_df['Interview'] = row['Interview']
        
        # Append to the list
        all_transcript_data.append(transcript_df)
    
    # Combine all transcript data into a single DataFrame
    return pd.concat(all_transcript_data, ignore_index=True)


In [33]:
class AudioDataset(Dataset):
    def __init__(self, df, transcript_dir, audio_dir, processor, target_sample_rate=16000, target_length_ms=15000, padding_value=0):
        """
        Audio Dataset with Padding for ASR tasks.

        Parameters:
        - df (pd.DataFrame): DataFrame containing the dataset information.
        - transcript_dir (str): Directory containing transcript files.
        - audio_dir (str): Directory containing audio files.
        - processor: Processor for audio and text processing.
        - target_sample_rate (int): Target sample rate for audio segments.
        - target_length_ms (int): Target length for audio segments in milliseconds.
        - padding_value (float): Value to use for padding audio segments.
        """
        self.df = df
        self.transcript_dir = transcript_dir
        self.audio_dir = audio_dir
        self.processor = processor
        self.target_sample_rate = target_sample_rate
        self.target_length_samples = int(target_length_ms * target_sample_rate / 1000)
        self.padding_value = padding_value

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        start_tm = row['StTime']
        end_tm = row['EnTime']
        content = row['Content']
        interview = row['Interview']
        audio_path = os.path.join(self.audio_dir, interview + '.wav')
        
        # Extract and process audio segment
        audio_segment = self.extract_audio_segment(audio_path, start_tm, end_tm)
        audio_tensor = self.processor(audio_segment.squeeze().numpy(), sampling_rate=self.target_sample_rate, return_tensors="pt").input_features
        
        # Tokenize and pad transcription
        encodings = self.processor.tokenizer(
            content,
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=448
        )
        
        labels = encodings.input_ids.squeeze()
        attention_mask = encodings.attention_mask.squeeze()
        
        return {
            'input_features': audio_tensor.squeeze(),
            'labels': labels,
            'attention_mask': attention_mask
        }
    
    def extract_audio_segment(self, audio_file_path, start_time_ms, end_time_ms):
        """
        Extracts and pads a segment from an audio file based on start and end times.
        
        Parameters:
        - audio_file_path (str): Path to the audio file.
        - start_time_ms (int): Start time in milliseconds.
        - end_time_ms (int): End time in milliseconds.
        
        Returns:
        - torch.Tensor: Extracted and padded audio segment.
        """
        waveform, sr = torchaudio.load(audio_file_path)
        
        # Convert milliseconds to sample indices
        start_sample = int(start_time_ms * sr / 1000)
        end_sample = int(end_time_ms * sr / 1000)
        
        # Extract segment and resample if necessary
        segment = waveform[:, start_sample:end_sample]
        if sr != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.target_sample_rate)
            segment = resampler(segment)
        
        # Pad or truncate segment
        return self.pad_audio_segment(segment)
    
    def pad_audio_segment(self, segment):
        """
        Pads or truncates an audio segment to the target length.
        
        Parameters:
        - segment (torch.Tensor): The audio segment to be padded.
        
        Returns:
        - torch.Tensor: Padded or truncated audio segment.
        """
        current_length = segment.size(1)
        if current_length < self.target_length_samples:
            padding_size = self.target_length_samples - current_length
            padding = torch.full((segment.size(0), padding_size), self.padding_value)
            padded_segment = torch.cat((segment, padding), dim=1)
        else:
            padded_segment = segment[:, :self.target_length_samples]
        return padded_segment

In [34]:
# Load your data
transcript_dir = '../data/coraal/transcript/text/'
audio_dir = '../data/coraal/audio/wav/'

paths_df = create_file_mapping(transcript_dir, audio_dir)
combined_transcript_df = process_transcripts(paths_df)
display(combined_transcript_df.sample(10))

Unnamed: 0,StTime,Content,EnTime,Interview
11640,2718918.0,(PAUSE 0.14),2719059.1,ATL_se0_ag1_m_03_1
20990,1489578.4,(PAUSE 1.50),1491083.3,ATL_se0_ag1_f_03_1
11291,2338974.0,(PAUSE 0.10),2339077.6,ATL_se0_ag1_m_03_1
1659,1837608.8,AND A BIG TOY FROG FOR THE KIDS.,1839870.8,ATL_se0_ag2_m_02_1
10048,907384.2,JORDAN WOULD WHOOP LEBRON ASS,908894.9,ATL_se0_ag1_m_03_1
2486,339232.3,"YOU WOULD KNOW THAT, SINCE I GOT PREGNANT.",340745.1,ATL_se0_ag2_f_02_1
18839,873271.2,YOU KIND OF GET WISDOM FROM THESE TYPE OF PEOP...,877731.1,ATL_se0_ag2_f_01_1
2383,193721.5,MAKING THAT ADJUSTMENT TO A DIFFERENT AREA?,196303.7,ATL_se0_ag2_f_02_1
7346,437000.2,WHAT THAT BOY NAME IS?,438250.0,ATL_se0_ag2_m_03_1
8114,1407805.9,[YOU MIGHT DROWN] IN THAT BITCH. [YOU MIGHT GE...,1411648.5,ATL_se0_ag2_m_03_1


In [35]:
# Filter out rows with unwanted characters in the 'Content' column
pattern = r'[\(\)\[\]/<>]'
filtered_transcript_df = combined_transcript_df[~combined_transcript_df['Content'].str.contains(pattern)].reset_index(drop=True)
display(filtered_transcript_df.sample(10))

Unnamed: 0,StTime,Content,EnTime,Interview
8541,1233479.3,THAT'S IT?,1234205.0,ATL_se0_ag1_f_03_1
3640,2471853.3,"I TRY TO DROP MY SHIT. EVEN IF IT JUST ONE TRACK,",2474208.1,ATL_se0_ag2_m_03_1
2627,1149491.3,THAT'S WHAT- THAT'S WHAT-,1150584.3,ATL_se0_ag1_m_02_1
1275,822792.2,WRESTLING?,823547.6,ATL_se0_ag2_f_02_1
8374,744923.1,THE LEAST IS LIKE TWO FIFTY.,746560.6,ATL_se0_ag1_f_03_1
2399,563400.2,IT WAS,564145.4,ATL_se0_ag1_m_02_1
6692,470958.9,NEEDLESS TO SAY I GOT KICKED OUT OF THE CHARTE...,472974.9,ATL_se0_ag1_f_02_1
1474,1444149.2,OF MY OWN AND STUFF LIKE THAT. LIKE,1446193.1,ATL_se0_ag2_f_02_1
1170,506861.4,"OKAY,",507333.4,ATL_se0_ag2_f_02_1
4847,1048828.3,"JJ WANTED TO, UM, WHERE ONE OF THEM DRESSES LI...",1051452.1,ATL_se0_ag1_f_01_1


In [36]:
# subset_size = int(len(filtered_transcript_df)*.25)
subset_size = 150
data_subset = filtered_transcript_df.sample(subset_size)
print(len(data_subset))

150


In [37]:
# Split the data into training and testing sets
# train_df, test_df = train_test_split(filtered_transcript_df, test_size=0.2, random_state=42)
train_df, test_df = train_test_split(data_subset, test_size=0.2, random_state=42)

In [38]:
print(f"Train Dataset: {len(train_df)}\t Test Dataset: {len(test_df)}")

Train Dataset: 120	 Test Dataset: 30


In [39]:
# Initialize processor and model
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
config = WhisperConfig.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en").to(device)

In [40]:
# Create datasets and dataloaders
train_dataset = AudioDataset(train_df, transcript_dir, audio_dir, processor)
test_dataset = AudioDataset(test_df, transcript_dir, audio_dir, processor)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [41]:
# Training function
def train(model, train_loader, processor, optimizer, scheduler, device):
    model.train()
    total_loss = 0
    true_transcriptions = []
    predicted_transcriptions = []
    
    for batch in tqdm(train_loader, total=len(train_loader), desc='Training'):
        inputs = batch['input_features'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_features=inputs, labels=labels)
        loss = outputs.loss

        outputs = model.generate(input_features=inputs)
        transcriptions = processor.batch_decode(outputs, skip_special_tokens=True)
        true_transcriptions.extend(processor.batch_decode(labels, skip_special_tokens=True))
        predicted_transcriptions.extend(transcriptions)

        predicted_transcriptions = [x.upper() for x in predicted_transcriptions]
        wer_score = wer(true_transcriptions, predicted_transcriptions)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        
        total_loss += loss.item()
    total_loss = total_loss / len(train_loader)
    
    return total_loss, wer_score

# Evaluation function
def test(model, test_loader, processor, device):
    model.eval()
    total_loss = 0
    true_transcriptions = []
    predicted_transcriptions = []
    
    with torch.no_grad():
        # for batch in tqdm(test_loader, total=len(test_loader), desc='Testing'):
        for batch in test_loader:
            inputs = batch['input_features'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_features=inputs, labels=labels)
            loss = outputs.loss
            total_loss += loss.item()
            
            outputs = model.generate(input_features=inputs)
            transcriptions = processor.batch_decode(outputs, skip_special_tokens=True)
            true_transcriptions.extend(processor.batch_decode(labels, skip_special_tokens=True))
            predicted_transcriptions.extend(transcriptions)

    total_loss = total_loss / len(train_loader)
    predicted_transcriptions = [x.upper() for x in predicted_transcriptions]
    wer_score = wer(true_transcriptions, predicted_transcriptions)
    
    return total_loss, wer_score

In [42]:
# Fine-tuning
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
num_epochs = 10
train_loss_per_epoch = []
train_wer_per_epoch = []

for epoch in range(num_epochs):
    train_loss, train_wer = train(model, train_loader, processor, optimizer, scheduler, device)
    print(f"Epoch {epoch+1}/{num_epochs}\t Training Loss: {train_loss:.4f}, Word Error Rate (WER): {train_wer:.4f}")
    train_loss_per_epoch.append(train_loss)
    train_wer_per_epoch.append(train_wer)

Epoch 1/10


Training: 100%|██████████| 8/8 [03:36<00:00, 27.02s/it]


Training Loss: 0.5724, Word Error Rate (WER): 0.3935
Epoch 2/10


Training: 100%|██████████| 8/8 [03:27<00:00, 25.90s/it]


Training Loss: 0.1297, Word Error Rate (WER): 0.3949
Epoch 3/10


Training: 100%|██████████| 8/8 [03:23<00:00, 25.47s/it]


Training Loss: 0.1216, Word Error Rate (WER): 0.4147
Epoch 4/10


Training: 100%|██████████| 8/8 [03:23<00:00, 25.44s/it]


Training Loss: 0.1208, Word Error Rate (WER): 0.4147
Epoch 5/10


Training: 100%|██████████| 8/8 [03:20<00:00, 25.09s/it]


Training Loss: 0.1191, Word Error Rate (WER): 0.4147
Epoch 6/10


Training: 100%|██████████| 8/8 [03:50<00:00, 28.76s/it]


Training Loss: 0.1192, Word Error Rate (WER): 0.4133
Epoch 7/10


Training: 100%|██████████| 8/8 [03:30<00:00, 26.27s/it]


Training Loss: 0.1196, Word Error Rate (WER): 0.4133
Epoch 8/10


Training: 100%|██████████| 8/8 [03:47<00:00, 28.50s/it]


Training Loss: 0.1178, Word Error Rate (WER): 0.4133
Epoch 9/10


Training: 100%|██████████| 8/8 [03:24<00:00, 25.52s/it]


Training Loss: 0.1187, Word Error Rate (WER): 0.4133
Epoch 10/10


Training: 100%|██████████| 8/8 [03:18<00:00, 24.81s/it]

Training Loss: 0.1209, Word Error Rate (WER): 0.4133





In [43]:
test_loss, test_wer = test(model, test_loader, processor, device)
print(f"Running inference...\t Test Loss: {test_loss:.4f}, Word Error Rate (WER): {test_wer:.4f}")

Running inference...	 Test Loss: 0.0305, Word Error Rate (WER): 0.4022


In [44]:
from datetime import datetime

timestamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S') # Format: YYYY-MM-DD_HH:MM:SS
weights_filepath = "../weights/whisper-weights_"

torch.save(model.state_dict(), weights_filepath+timestamp)
print(f"Model weights saved to {weights_filepath+timestamp}")

Model weights saved to ../weights/whisper-weights_2024-07-23_16:16:07
