## Import Libraries

In [1]:
import os
import sys
import torch
import pandas as pd
from datetime import timedelta
import torchaudio
import pandas as pd
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, WhisperConfig
from jiwer import wer

In [3]:
# Check if MPS is available
# device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
device = torch.device("cpu")
print(device)

cpu


In [5]:
def create_file_mapping(transcript_dir, audio_dir):
    # Create lists of transcript and audio file paths
    transcript_paths = [os.path.join(transcript_dir, filename) for filename in os.listdir(transcript_dir) if filename.endswith('.txt')]
    audio_paths = [os.path.join(audio_dir, filename) for filename in os.listdir(audio_dir) if filename.endswith('.wav')]

    # Extract identifiers from filenames
    transcript_ids = [os.path.splitext(os.path.basename(path))[0] for path in transcript_paths]
    audio_ids = [os.path.splitext(os.path.basename(path))[0] for path in audio_paths]

    # Create a dictionary to map identifiers to file paths
    file_mapping = {
        'Interview': [],
        'Transcript Path': [],
        'Audio Path': []
    }

    # Match transcript files with audio files using their identifiers
    for transcript_id in transcript_ids:
        if transcript_id in audio_ids:
            transcript_path = os.path.join(transcript_dir, transcript_id + '.txt')
            audio_path = os.path.join(audio_dir, transcript_id + '.wav')
            file_mapping['Interview'].append(transcript_id)
            file_mapping['Transcript Path'].append(transcript_path)
            file_mapping['Audio Path'].append(audio_path)

    # Create DataFrame from the file mapping dictionary
    df = pd.DataFrame(file_mapping)
    
    return df


In [5]:
# Example usage
transcript_dir = '../data/coraal/transcript/text/'
audio_dir = '../data/coraal/audio/wav/'

paths_df = create_file_mapping(transcript_dir, audio_dir)
display(paths_df)

Unnamed: 0,Interview,Transcript Path,Audio Path
0,ATL_se0_ag2_f_02_1,../data/coraal/transcript/text/ATL_se0_ag2_f_0...,../data/coraal/audio/wav/ATL_se0_ag2_f_02_1.wav
1,ATL_se0_ag2_m_02_1,../data/coraal/transcript/text/ATL_se0_ag2_m_0...,../data/coraal/audio/wav/ATL_se0_ag2_m_02_1.wav
2,ATL_se0_ag2_f_01_1,../data/coraal/transcript/text/ATL_se0_ag2_f_0...,../data/coraal/audio/wav/ATL_se0_ag2_f_01_1.wav
3,ATL_se0_ag2_m_03_1,../data/coraal/transcript/text/ATL_se0_ag2_m_0...,../data/coraal/audio/wav/ATL_se0_ag2_m_03_1.wav
4,ATL_se0_ag2_m_01_1,../data/coraal/transcript/text/ATL_se0_ag2_m_0...,../data/coraal/audio/wav/ATL_se0_ag2_m_01_1.wav
5,ATL_se0_ag1_m_05_1,../data/coraal/transcript/text/ATL_se0_ag1_m_0...,../data/coraal/audio/wav/ATL_se0_ag1_m_05_1.wav
6,ATL_se0_ag1_m_03_1,../data/coraal/transcript/text/ATL_se0_ag1_m_0...,../data/coraal/audio/wav/ATL_se0_ag1_m_03_1.wav
7,ATL_se0_ag1_f_01_1,../data/coraal/transcript/text/ATL_se0_ag1_f_0...,../data/coraal/audio/wav/ATL_se0_ag1_f_01_1.wav
8,ATL_se0_ag1_f_03_1,../data/coraal/transcript/text/ATL_se0_ag1_f_0...,../data/coraal/audio/wav/ATL_se0_ag1_f_03_1.wav
9,ATL_se0_ag1_m_01_1,../data/coraal/transcript/text/ATL_se0_ag1_m_0...,../data/coraal/audio/wav/ATL_se0_ag1_m_01_1.wav


In [6]:
def process_transcripts(paths_df):
    all_transcript_data = []

    # Iterate over each row in the file_mapping_df
    for _, row in paths_df.iterrows():
        transcript_path = row['Transcript Path']
        
        # Load transcript data
        transcript_df = pd.read_csv(transcript_path, delimiter="\t")
        
        # Extract relevant columns
        transcript_data = transcript_df[['StTime', 'EnTime', 'Content']].copy()
        
        # Convert times from minutes to milliseconds for easier calculations
        transcript_data['StTime'] *= 1000
        transcript_data['EnTime'] *= 1000
        transcript_data['Content'] = transcript_data['Content'].str.upper()
        
        # Add identifier column (optional)
        transcript_data['Interview'] = row['Interview']

        # Append to list
        all_transcript_data.append(transcript_data)
    
    # Combine all transcript data into a single DataFrame
    combined_transcript_df = pd.concat(all_transcript_data, ignore_index=True)
    
    return combined_transcript_df

In [7]:
# Example usage
combined_transcript_df = process_transcripts(paths_df)
display(combined_transcript_df.head(10), combined_transcript_df.tail(10))

Unnamed: 0,StTime,EnTime,Content,Interview
0,752.6,2511.3,HEY WHAT'S GOING ON?,ATL_se0_ag2_f_02_1
1,2511.3,3144.7,(PAUSE 0.63),ATL_se0_ag2_f_02_1
2,3144.7,4165.9,I'M HERE WITH,ATL_se0_ag2_f_02_1
3,4165.9,5083.0,(PAUSE 0.92),ATL_se0_ag2_f_02_1
4,5083.0,5853.6,/RD-NAME-2/.,ATL_se0_ag2_f_02_1
5,5853.6,6312.2,(PAUSE 0.46),ATL_se0_ag2_f_02_1
6,6312.2,7844.7,THIS IS /RD-NAME-2/ BY THE WAY.,ATL_se0_ag2_f_02_1
7,7844.7,8799.6,(PAUSE 0.95),ATL_se0_ag2_f_02_1
8,8799.6,10241.5,"UM, /RD-NAME-2/, IF YOU COULD",ATL_se0_ag2_f_02_1
9,10241.5,10534.6,(PAUSE 0.29),ATL_se0_ag2_f_02_1


Unnamed: 0,StTime,EnTime,Content,Interview
23827,2423303.1,2424396.2,ABOUT YOUR FUTURE.,ATL_se0_ag1_m_02_1
23828,2424396.2,2425081.9,(PAUSE 0.69),ATL_se0_ag1_m_02_1
23829,2425081.9,2425860.4,DEFINITELY GONNA,ATL_se0_ag1_m_02_1
23830,2425937.8,2426912.2,I APPRECIATE [THAT.],ATL_se0_ag1_m_02_1
23831,2426726.6,2428514.3,"[BE IN] TOUCH AFTER THIS INTERVIEW,",ATL_se0_ag1_m_02_1
23832,2429019.6,2429658.9,"AND, UM,",ATL_se0_ag1_m_02_1
23833,2429658.9,2430942.7,(PAUSE 1.28),ATL_se0_ag1_m_02_1
23834,2430942.7,2432412.1,"IT'S /RD-NAME-3/, AND",ATL_se0_ag1_m_02_1
23835,2432412.1,2433113.3,(PAUSE 0.70),ATL_se0_ag1_m_02_1
23836,2433113.3,2434628.8,I'M SIGNING OFF. THIS IS A WRAP.,ATL_se0_ag1_m_02_1


In [8]:
# Define the regex pattern for unwanted characters
pattern = r'[\(\)\[\]/<>]'

# Filter out rows with unwanted characters in the 'Content' column
filtered_transcript_df = combined_transcript_df[~combined_transcript_df['Content'].str.contains(pattern)].reset_index(drop=True)

display(filtered_transcript_df.head(10), filtered_transcript_df.tail(10))

# import re

# # Define the regex pattern for unwanted characters and enclosed content
# pattern = r'\([^)]*\)|\[[^\]]*\]|<[^>]*>|\/[^\/]*\/'

# def remove_enclosed_content(text):
#     return re.sub(pattern, '', text)

# # Apply the function to remove enclosed content
# combined_transcript_df['Content'] = combined_transcript_df['Content'].apply(remove_enclosed_content)

# # Optionally, you can filter out rows with empty 'Content' after removal
# filtered_transcript_df = combined_transcript_df[combined_transcript_df['Content'].str.strip() != ''].reset_index(drop=True)

# display(filtered_transcript_df.head(10), filtered_transcript_df.tail(10))

Unnamed: 0,StTime,EnTime,Content,Interview
0,752.6,2511.3,HEY WHAT'S GOING ON?,ATL_se0_ag2_f_02_1
1,3144.7,4165.9,I'M HERE WITH,ATL_se0_ag2_f_02_1
2,10534.6,11289.9,SAY HELLO.,ATL_se0_ag2_f_02_1
3,12142.4,12846.2,HELLO.,ATL_se0_ag2_f_02_1
4,13458.1,13962.9,"OKAY,",ATL_se0_ag2_f_02_1
5,14227.6,15991.0,MIC SEEMS TO BE PICKING YOU UP WELL.,ATL_se0_ag2_f_02_1
6,16841.9,17230.4,"UM,",ATL_se0_ag2_f_02_1
7,17306.0,17618.0,MM.,ATL_se0_ag2_f_02_1
8,17646.4,18227.9,WE'LL GO,ATL_se0_ag2_f_02_1
9,18932.3,21206.2,"AND, UM, CONTINUE WITH THIS INTERVIEW.",ATL_se0_ag2_f_02_1


Unnamed: 0,StTime,EnTime,Content,Interview
9958,2405803.3,2407020.0,"THAT WAS DOPE AS WELL, MAN.",ATL_se0_ag1_m_02_1
9959,2409658.4,2411158.8,"OH MAN, WELL,",ATL_se0_ag1_m_02_1
9960,2415648.1,2417452.7,AND THE WORDS OF WISDOM,ATL_se0_ag1_m_02_1
9961,2417782.6,2418267.3,AND,ATL_se0_ag1_m_02_1
9962,2419437.7,2420386.3,"JUST, YOU KNOW,",ATL_se0_ag1_m_02_1
9963,2421273.1,2422947.4,EVERYTHING. I'M EXCITED,ATL_se0_ag1_m_02_1
9964,2423303.1,2424396.2,ABOUT YOUR FUTURE.,ATL_se0_ag1_m_02_1
9965,2425081.9,2425860.4,DEFINITELY GONNA,ATL_se0_ag1_m_02_1
9966,2429019.6,2429658.9,"AND, UM,",ATL_se0_ag1_m_02_1
9967,2433113.3,2434628.8,I'M SIGNING OFF. THIS IS A WRAP.,ATL_se0_ag1_m_02_1


In [9]:
data_subset = filtered_transcript_df.sample(50)
data_subset

Unnamed: 0,StTime,EnTime,Content,Interview
9105,118264.1,119268.1,"YOU KNOW, THAT- JUST THAT-",ATL_se0_ag1_m_02_1
1146,642318.4,642748.4,"IF NOT,",ATL_se0_ag2_m_02_1
1513,1643770.0,1644977.7,THE ENTIRE PASSAGE?,ATL_se0_ag2_m_02_1
9845,2082168.7,2084210.4,AND QUENCHED MY THIRST SO MANY TIMES.,ATL_se0_ag1_m_02_1
9385,814607.5,816630.0,"BUT, YOU KNOW, AROUND THAT- THAT CONCEPT.",ATL_se0_ag1_m_02_1
5852,1414333.1,1417254.7,SOMEBODY NEED TO PUT OBAMA ON THE TWENTY-FIVE ...,ATL_se0_ag1_f_01_1
3732,1569677.6,1571222.5,LIKE WHEN THEM S- KIDS USED TO COME,ATL_se0_ag2_m_01_1
6182,612226.9,616651.3,THAT YOU DIDN'T SEE FOR A WHILE COME AROUND. A...,ATL_se0_ag1_f_03_1
7623,358665.4,359614.1,"WHAT ABOUT, UM,",ATL_se0_ag1_m_04_2
6041,223694.1,224934.6,LET'S SEE,ATL_se0_ag1_f_03_1


In [42]:
def pad_audio_segment(segment, target_length, padding_value=0):
    """
    Pads an audio segment to the target length.
    
    Parameters:
    - segment (torch.Tensor): The audio segment to be padded.
    - target_length (int): The desired length in samples.
    - padding_value (float): The value to use for padding. Default is 0.
    
    Returns:
    - torch.Tensor: Padded audio segment.
    """
    current_length = segment.size(1)
    if current_length < target_length:
        padding_size = target_length - current_length
        padding = torch.full((segment.size(0), padding_size), padding_value)
        padded_segment = torch.cat((segment, padding), dim=1)
    else:
        padded_segment = segment[:, :target_length]  # Truncate if longer than target length
    return padded_segment

In [45]:
def extract_audio_segment(audio_file_path, start_time_ms, end_time_ms, padding_value=0):
    """
    Extracts and pads a segment from an audio file based on start and end times.
    
    Parameters:
    - audio_file_path (str): Path to the audio file.
    - start_time_ms (int): Start time in milliseconds.
    - end_time_ms (int): End time in milliseconds.
    - padding_value (float): The value to use for padding. Default is 0.
    
    Returns:
    - torch.Tensor: Extracted and padded audio segment.
    """
    # Load the audio file
    waveform, sr = torchaudio.load(audio_file_path)
    
    # Convert milliseconds to sample indices
    start_sample = int(start_time_ms * sr / 1000)
    end_sample = int(end_time_ms * sr / 1000)
    target_length_ms = 15000  # Target length of 15 seconds (15000 milliseconds)
    
    # Extract the segment
    segment = waveform[:, start_sample:end_sample]
    
    # Convert target length from milliseconds to samples
    target_length_samples = int(target_length_ms * sr / 1000)
    
    # Pad the segment
    padded_segment = pad_audio_segment(segment, target_length_samples, padding_value)
    
    return padded_segment

In [46]:
# Initialize a list to store extracted audio segments
audio_segments = []

for idx, row in data_subset.iterrows():
    start_tm = row['StTime']
    end_tm = row['EnTime']
    content = row['Content']
    interview = row['Interview']

    transcript_path = transcript_dir+interview+".txt"
    audio_path = audio_dir+interview+".wav"

    segment = extract_audio_segment(audio_path, start_tm, end_tm)
    
    # Optionally, save the segment or process it further
    audio_segments.append({
        'StTime': start_tm,
        'EnTime': end_tm,
        'Interview': interview,
        'Content': content,
        'Segment': segment
    })

audio_segments_subset_df = pd.DataFrame(audio_segments)
audio_segments_subset_df.head(10)

Unnamed: 0,StTime,EnTime,Interview,Content,Segment
0,1486802.6,1490400.2,ATL_se0_ag2_m_01_1,"YOU KNOW, I AIN'T KNOW WHAT FREAKNIK WAS. I JU...","[[tensor(-0.0004), tensor(-0.0003), tensor(-0...."
1,411784.0,413218.5,ATL_se0_ag2_f_02_1,DON'T REALLY THINK ABOUT HIM.,"[[tensor(-0.0026), tensor(-0.0025), tensor(-0...."
2,2101098.7,2102099.0,ATL_se0_ag1_m_03_1,"I MET HER,","[[tensor(0.0003), tensor(0.0004), tensor(0.000..."
3,2749747.4,2750289.2,ATL_se0_ag1_m_01_1,FOR SURE.,"[[tensor(-0.0004), tensor(-0.0005), tensor(-0...."
4,2445679.0,2446416.5,ATL_se0_ag1_m_01_1,"WOW,","[[tensor(0.0020), tensor(0.0020), tensor(0.001..."
5,359799.5,360743.1,ATL_se0_ag1_f_02_1,"SHUT UP, SPARKY.","[[tensor(-0.0014), tensor(-0.0014), tensor(-0...."
6,288361.0,289381.8,ATL_se0_ag1_m_04_1,"DEFINITELY, DEFINITELY.","[[tensor(-0.0005), tensor(-0.0005), tensor(-0...."
7,485938.8,486962.3,ATL_se0_ag1_f_01_1,"OH, OKAY.","[[tensor(-3.0518e-05), tensor(9.1553e-05), ten..."
8,274038.7,276081.3,ATL_se0_ag1_f_03_1,MIDDLE SCHOOL AND- AND HIGH SCHOOL HERE.,"[[tensor(0.0005), tensor(0.0005), tensor(0.000..."
9,1806153.9,1806682.6,ATL_se0_ag2_f_02_1,WOW.,"[[tensor(-0.0005), tensor(-0.0006), tensor(-0...."


In [47]:
for idx, row in audio_segments_subset_df.iterrows():
    if idx == 10:
        break
    print(row['Segment'].shape)

torch.Size([1, 661500])
torch.Size([1, 661500])
torch.Size([1, 661500])
torch.Size([1, 661500])
torch.Size([1, 661500])
torch.Size([1, 661500])
torch.Size([1, 661500])
torch.Size([1, 661500])
torch.Size([1, 661500])
torch.Size([1, 661500])


In [48]:
# Initialize processor and model for whisper-tiny
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

In [49]:
# Move the model to the appropriate device
model = model.to(device)

In [50]:
def preprocess_audio(audio_tensor, sample_rate, target_sample_rate=16000):
    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        audio_tensor = resampler(audio_tensor)
    return audio_tensor

def transcribe_audio(audio_tensor, model, processor):
    audio_tensor = audio_tensor.to(device)
    inputs = processor(audio_tensor.squeeze().cpu().numpy(), return_tensors="pt", sampling_rate=16000)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        predicted_ids = model.generate(**inputs)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
    return transcription[0]

In [51]:
# Initialize lists for metrics
true_transcriptions = []
predicted_transcriptions = []

sample_rate = 44100

# Process each segment and compare transcriptions
for idx in range(len(audio_segments_subset_df)):
    audio_tensor = audio_segments_subset_df['Segment'][idx]
    actual_transcript = audio_segments_subset_df['Content'][idx]
    
    # Ensure audio is in the correct format
    audio_tensor = preprocess_audio(audio_tensor, sample_rate)
    
    # Transcribe audio
    model_transcription = transcribe_audio(audio_tensor, model, processor)
    model_transcription = model_transcription.upper()
    
    # Store transcriptions for metric calculation
    true_transcriptions.append(actual_transcript)
    predicted_transcriptions.append(model_transcription)
    
    # Print segment info
    print(f"Segment Start: {audio_segments_subset_df['StTime'][idx]}")
    print(f"Actual Transcript: {actual_transcript}")
    print(f"Model Transcription: {model_transcription}")
    print("\n")


Segment Start: 1486802.6
Actual Transcript: YOU KNOW, I AIN'T KNOW WHAT FREAKNIK WAS. I JUST HAPPENED TO GO DOWNTOWN.
Model Transcription:  YOU KNOW, I KNOW WHAT FREAKIN' MEEK WAS. I JUST HAPPENED TO GO DOWNTOWN.


Segment Start: 411784.0
Actual Transcript: DON'T REALLY THINK ABOUT HIM.
Model Transcription:  DON'T REALLY THINK ABOUT IT.


Segment Start: 2101098.7
Actual Transcript: I MET HER,
Model Transcription:  I MET HER.


Segment Start: 2749747.4000000004
Actual Transcript: FOR SURE.
Model Transcription:  YOU


Segment Start: 2445679.0
Actual Transcript: WOW,
Model Transcription:  WOW.


Segment Start: 359799.5
Actual Transcript: SHUT UP, SPARKY.
Model Transcription:  SHOUT OUT, SWAGGY!


Segment Start: 288361.0
Actual Transcript: DEFINITELY, DEFINITELY.
Model Transcription:  DEFINITELY.


Segment Start: 485938.8
Actual Transcript: OH, OKAY.
Model Transcription:  OH, OKAY.


Segment Start: 274038.7
Actual Transcript: MIDDLE SCHOOL AND- AND HIGH SCHOOL HERE.
Model Transcription:  M

In [52]:
wer_score = wer(true_transcriptions, predicted_transcriptions)
print(f"Word Error Rate (WER): {wer_score}")

Word Error Rate (WER): 0.41649484536082476


## Training `whisper-tiny` on CORAAL:ATL data

In [13]:
import os
import sys
import torch
import pandas as pd
from datetime import timedelta
import torchaudio
import pandas as pd
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, WhisperConfig, WhisperFeatureExtractor
from jiwer import wer
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

In [2]:
# Set device
# device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
device = torch.device("cpu")
print(device)

cpu


In [None]:
def create_file_mapping(transcript_dir, audio_dir):
    # Create lists of transcript and audio file paths
    transcript_paths = [os.path.join(transcript_dir, filename) for filename in os.listdir(transcript_dir) if filename.endswith('.txt')]
    audio_paths = [os.path.join(audio_dir, filename) for filename in os.listdir(audio_dir) if filename.endswith('.wav')]

    # Extract identifiers from filenames
    transcript_ids = [os.path.splitext(os.path.basename(path))[0] for path in transcript_paths]
    audio_ids = [os.path.splitext(os.path.basename(path))[0] for path in audio_paths]

    # Create a dictionary to map identifiers to file paths
    file_mapping = {
        'Interview': [],
        'Transcript Path': [],
        'Audio Path': []
    }

    # Match transcript files with audio files using their identifiers
    for transcript_id in transcript_ids:
        if transcript_id in audio_ids:
            transcript_path = os.path.join(transcript_dir, transcript_id + '.txt')
            audio_path = os.path.join(audio_dir, transcript_id + '.wav')
            file_mapping['Interview'].append(transcript_id)
            file_mapping['Transcript Path'].append(transcript_path)
            file_mapping['Audio Path'].append(audio_path)

    # Create DataFrame from the file mapping dictionary
    df = pd.DataFrame(file_mapping)
    
    return df

def process_transcripts(paths_df):
    all_transcript_data = []

    # Iterate over each row in the file_mapping_df
    for _, row in paths_df.iterrows():
        transcript_path = row['Transcript Path']
        
        # Load transcript data
        transcript_df = pd.read_csv(transcript_path, delimiter="\t")
        
        # Extract relevant columns
        transcript_data = transcript_df[['StTime', 'EnTime', 'Content']].copy()
        
        # Convert times from minutes to milliseconds for easier calculations
        transcript_data['StTime'] *= 1000
        transcript_data['EnTime'] *= 1000
        transcript_data['Content'] = transcript_data['Content'].str.upper()
        
        # Add identifier column (optional)
        transcript_data['Interview'] = row['Interview']

        # Append to list
        all_transcript_data.append(transcript_data)
    
    # Combine all transcript data into a single DataFrame
    combined_transcript_df = pd.concat(all_transcript_data, ignore_index=True)
    
    return combined_transcript_df

In [7]:
# Load your data
transcript_dir = '../data/coraal/transcript/text/'
audio_dir = '../data/coraal/audio/wav/'

paths_df = create_file_mapping(transcript_dir, audio_dir)
combined_transcript_df = process_transcripts(paths_df)

In [8]:
# Filter out rows with unwanted characters in the 'Content' column
pattern = r'[\(\)\[\]/<>]'
filtered_transcript_df = combined_transcript_df[~combined_transcript_df['Content'].str.contains(pattern)].reset_index(drop=True)

In [9]:
# subset_size = int(len(filtered_transcript_df)*.20)
# data_subset = filtered_transcript_df.sample(subset_size)
data_subset = filtered_transcript_df.sample(250)
print(len(data_subset))

250


In [10]:
data_subset.head()

Unnamed: 0,StTime,EnTime,Content,Interview
864,2497451.9,2498142.3,AWESOME.,ATL_se0_ag2_f_02_1
7067,1442971.0,1446815.2,"FREEDOM, YOU KNOW WHAT I'M SAYING, KIND OF ENE...",ATL_se0_ag1_m_01_1
4492,1803180.4,1805892.4,THEY JUST DON'T GET THAT RESPECT BECAUSE,ATL_se0_ag1_m_05_1
2833,980297.7,981786.0,WHAT IT WAS IN THE PAST,ATL_se0_ag2_m_03_1
6160,563491.7,563938.3,OKAY.,ATL_se0_ag1_f_03_1


In [11]:
# Split the data into training and testing sets
train_df, test_df = train_test_split(data_subset, test_size=0.2, random_state=42)

In [12]:
class AudioDataset(Dataset):
    def __init__(self, df, transcript_dir, audio_dir, processor, target_sample_rate=16000, target_length_ms=15000, padding_value=0):
        """
        Audio Dataset with Padding for ASR tasks.

        Parameters:
        - df (pd.DataFrame): DataFrame containing the dataset information.
        - transcript_dir (str): Directory containing transcript files.
        - audio_dir (str): Directory containing audio files.
        - processor: Processor for audio and text processing.
        - target_sample_rate (int): Target sample rate for audio segments.
        - target_length_ms (int): Target length for audio segments in milliseconds.
        - padding_value (float): Value to use for padding audio segments.
        """
        self.df = df
        self.transcript_dir = transcript_dir
        self.audio_dir = audio_dir
        self.processor = processor
        self.target_sample_rate = target_sample_rate
        self.target_length_ms = target_length_ms
        self.padding_value = padding_value

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        start_tm = row['StTime']
        end_tm = row['EnTime']
        content = row['Content']
        interview = row['Interview']
        audio_path = os.path.join(self.audio_dir, interview + '.wav')
        
        # Extract audio segment
        audio_segment = self.extract_audio_segment(audio_path, start_tm, end_tm)
        
        # Preprocess audio segment
        audio_tensor = self.processor(audio_segment.squeeze().numpy(), sampling_rate=self.target_sample_rate, return_tensors="pt").input_features
        
        # Tokenize and pad transcription
        encodings = self.processor.tokenizer(content, return_tensors="pt", padding='max_length', truncation=True, max_length=448)  # Adjust max_length as needed
        labels = encodings.input_ids.squeeze()
        attention_mask = encodings.attention_mask.squeeze()
        
        return {
            'input_features': audio_tensor.squeeze(),
            'labels': labels,
            'attention_mask': attention_mask
        }
    
    def extract_audio_segment(self, audio_file_path, start_time_ms, end_time_ms):
        """
        Extracts and pads a segment from an audio file based on start and end times.
        
        Parameters:
        - audio_file_path (str): Path to the audio file.
        - start_time_ms (int): Start time in milliseconds.
        - end_time_ms (int): End time in milliseconds.
        
        Returns:
        - torch.Tensor: Extracted and padded audio segment.
        """
        # Load the audio file
        waveform, sr = torchaudio.load(audio_file_path)
        
        # Convert milliseconds to sample indices
        start_sample = int(start_time_ms * sr / 1000)
        end_sample = int(end_time_ms * sr / 1000)
        
        # Extract the segment
        segment = waveform[:, start_sample:end_sample]
        
        # Resample audio to target sample rate if necessary
        if sr != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.target_sample_rate)
            segment = resampler(segment)
        
        # Convert target length from milliseconds to samples
        target_length_samples = int(self.target_length_ms * self.target_sample_rate / 1000)
        
        # Pad the segment
        padded_segment = self.pad_audio_segment(segment, target_length_samples)
        
        return padded_segment
    
    def pad_audio_segment(self, segment, target_length, padding_value=0):
        """
        Pads an audio segment to the target length.
        
        Parameters:
        - segment (torch.Tensor): The audio segment to be padded.
        - target_length (int): The desired length in samples.
        - padding_value (float): The value to use for padding. Default is 0.
        
        Returns:
        - torch.Tensor: Padded audio segment.
        """
        current_length = segment.size(1)
        if current_length < target_length:
            padding_size = target_length - current_length
            padding = torch.full((segment.size(0), padding_size), padding_value)
            padded_segment = torch.cat((segment, padding), dim=1)
        else:
            padded_segment = segment[:, :target_length]  # Truncate if longer than target length
        return padded_segment

In [14]:
# Initialize processor and model
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny.en")
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
config = WhisperConfig.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en").to(device)

In [15]:
# Create datasets and dataloaders
train_dataset = AudioDataset(train_df, transcript_dir, audio_dir, processor)
test_dataset = AudioDataset(test_df, transcript_dir, audio_dir, processor)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

In [16]:
# Training function
def train(model, train_loader, processor, optimizer, scheduler, device):
    model.train()
    total_loss = 0
    true_transcriptions = []
    predicted_transcriptions = []
    
    for batch in tqdm(train_loader, total=len(train_loader), desc='Training'):
        inputs = batch['input_features'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_features=inputs, labels=labels)
        loss = outputs.loss

        outputs = model.generate(input_features=inputs)
        transcriptions = processor.batch_decode(outputs, skip_special_tokens=True)
        true_transcriptions.extend(processor.batch_decode(labels, skip_special_tokens=True))
        predicted_transcriptions.extend(transcriptions)

        predicted_transcriptions = [x.upper() for x in predicted_transcriptions]
        wer_score = wer(true_transcriptions, predicted_transcriptions)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        
        total_loss += loss.item()
    total_loss = total_loss / len(train_loader)
    
    return total_loss, wer_score

# Evaluation function
def evaluate(model, test_loader, processor, device):
    model.eval()
    total_loss = 0
    true_transcriptions = []
    predicted_transcriptions = []
    
    with torch.no_grad():
        # for batch in tqdm(test_loader, total=len(test_loader), desc='Testing'):
        for batch in test_loader:
            inputs = batch['input_features'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_features=inputs, labels=labels)
            loss = outputs.loss
            total_loss += loss.item()
            
            outputs = model.generate(input_features=inputs)
            transcriptions = processor.batch_decode(outputs, skip_special_tokens=True)
            true_transcriptions.extend(processor.batch_decode(labels, skip_special_tokens=True))
            predicted_transcriptions.extend(transcriptions)

    total_loss = total_loss / len(train_loader)
    predicted_transcriptions = [x.upper() for x in predicted_transcriptions]
    wer_score = wer(true_transcriptions, predicted_transcriptions)
    
    return total_loss, wer_score

In [17]:
# Fine-tuning
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
num_epochs = 5

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    
    train_loss, train_wer = train(model, train_loader, processor, optimizer, scheduler, device)
    print(f"Training Loss: {train_loss:.4f}, WER: {train_wer:.4f}")

Epoch 1/5


Training:   0%|          | 0/25 [00:00<?, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Training: 100%|██████████| 25/25 [03:38<00:00,  8.73s/it]


Training Loss: 0.2713, WER: 0.4554
Epoch 2/5


Training: 100%|██████████| 25/25 [03:39<00:00,  8.79s/it]


Training Loss: 0.1214, WER: 0.4536
Epoch 3/5


Training: 100%|██████████| 25/25 [03:47<00:00,  9.11s/it]


Training Loss: 0.1208, WER: 0.4536
Epoch 4/5


Training: 100%|██████████| 25/25 [03:48<00:00,  9.13s/it]


Training Loss: 0.1208, WER: 0.4536
Epoch 5/5


Training: 100%|██████████| 25/25 [03:47<00:00,  9.08s/it]

Training Loss: 0.1208, WER: 0.4536





In [18]:
test_loss, test_wer = evaluate(model, test_loader, processor, device)
print(f"Test Loss: {test_loss:.4f}, WER: {test_wer:.4f}")

Test Loss: 0.0324, WER: 0.3930


In [None]:
k_test = []
for k in range(10):
    test_loss, test_wer = evaluate(model, test_loader, processor, device)
    k_test.append([k, test_loss, test_wer])

In [None]:
k_test_df = pd.DataFrame(k_test, columns=['k', 'Loss', 'Word Error Rate']).sort_values(by="Word Error Rate", ascending=False).groupby('k')[:5]
k_test_df