## Import Libraries

In [1]:
import os
import sys
import torch
import pandas as pd
from datetime import timedelta
import torchaudio
import pandas as pd
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, WhisperConfig
from jiwer import wer

In [2]:
# Check if MPS is available
# device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
device = torch.device("cpu")
print(device)

cpu


In [3]:
def create_file_mapping(transcript_dir, audio_dir):
    # Create lists of transcript and audio file paths
    transcript_paths = [os.path.join(transcript_dir, filename) for filename in os.listdir(transcript_dir) if filename.endswith('.txt')]
    audio_paths = [os.path.join(audio_dir, filename) for filename in os.listdir(audio_dir) if filename.endswith('.wav')]

    # Extract identifiers from filenames
    transcript_ids = [os.path.splitext(os.path.basename(path))[0] for path in transcript_paths]
    audio_ids = [os.path.splitext(os.path.basename(path))[0] for path in audio_paths]

    # Create a dictionary to map identifiers to file paths
    file_mapping = {
        'Interview': [],
        'Transcript Path': [],
        'Audio Path': []
    }

    # Match transcript files with audio files using their identifiers
    for transcript_id in transcript_ids:
        if transcript_id in audio_ids:
            transcript_path = os.path.join(transcript_dir, transcript_id + '.txt')
            audio_path = os.path.join(audio_dir, transcript_id + '.wav')
            file_mapping['Interview'].append(transcript_id)
            file_mapping['Transcript Path'].append(transcript_path)
            file_mapping['Audio Path'].append(audio_path)

    # Create DataFrame from the file mapping dictionary
    df = pd.DataFrame(file_mapping)
    
    return df


In [4]:
# Example usage
transcript_dir = '../data/coraal/transcript/text/'
audio_dir = '../data/coraal/audio/wav/'

paths_df = create_file_mapping(transcript_dir, audio_dir)
display(paths_df)

Unnamed: 0,Interview,Transcript Path,Audio Path
0,ATL_se0_ag2_f_02_1,../data/coraal/transcript/text/ATL_se0_ag2_f_0...,../data/coraal/audio/wav/ATL_se0_ag2_f_02_1.wav
1,ATL_se0_ag2_m_02_1,../data/coraal/transcript/text/ATL_se0_ag2_m_0...,../data/coraal/audio/wav/ATL_se0_ag2_m_02_1.wav
2,ATL_se0_ag2_f_01_1,../data/coraal/transcript/text/ATL_se0_ag2_f_0...,../data/coraal/audio/wav/ATL_se0_ag2_f_01_1.wav
3,ATL_se0_ag2_m_03_1,../data/coraal/transcript/text/ATL_se0_ag2_m_0...,../data/coraal/audio/wav/ATL_se0_ag2_m_03_1.wav
4,ATL_se0_ag2_m_01_1,../data/coraal/transcript/text/ATL_se0_ag2_m_0...,../data/coraal/audio/wav/ATL_se0_ag2_m_01_1.wav
5,ATL_se0_ag1_m_05_1,../data/coraal/transcript/text/ATL_se0_ag1_m_0...,../data/coraal/audio/wav/ATL_se0_ag1_m_05_1.wav
6,ATL_se0_ag1_m_03_1,../data/coraal/transcript/text/ATL_se0_ag1_m_0...,../data/coraal/audio/wav/ATL_se0_ag1_m_03_1.wav
7,ATL_se0_ag1_f_01_1,../data/coraal/transcript/text/ATL_se0_ag1_f_0...,../data/coraal/audio/wav/ATL_se0_ag1_f_01_1.wav
8,ATL_se0_ag1_f_03_1,../data/coraal/transcript/text/ATL_se0_ag1_f_0...,../data/coraal/audio/wav/ATL_se0_ag1_f_03_1.wav
9,ATL_se0_ag1_m_01_1,../data/coraal/transcript/text/ATL_se0_ag1_m_0...,../data/coraal/audio/wav/ATL_se0_ag1_m_01_1.wav


In [5]:
def process_transcripts(paths_df):
    all_transcript_data = []

    # Iterate over each row in the file_mapping_df
    for _, row in paths_df.iterrows():
        transcript_path = row['Transcript Path']
        
        # Load transcript data
        transcript_df = pd.read_csv(transcript_path, delimiter="\t")
        
        # Extract relevant columns
        transcript_data = transcript_df[['StTime', 'EnTime', 'Content']].copy()
        
        # Convert times from minutes to milliseconds for easier calculations
        transcript_data['StTime'] *= 1000
        transcript_data['EnTime'] *= 1000
        transcript_data['Content'] = transcript_data['Content'].str.upper()
        
        # Add identifier column (optional)
        transcript_data['Interview'] = row['Interview']

        # Append to list
        all_transcript_data.append(transcript_data)
    
    # Combine all transcript data into a single DataFrame
    combined_transcript_df = pd.concat(all_transcript_data, ignore_index=True)
    
    return combined_transcript_df

In [9]:
# Example usage
combined_transcript_df = process_transcripts(paths_df)
display(combined_transcript_df.head(10), combined_transcript_df.tail(10))

Unnamed: 0,StTime,EnTime,Content,Interview
0,752.6,2511.3,HEY WHAT'S GOING ON?,ATL_se0_ag2_f_02_1
1,2511.3,3144.7,(PAUSE 0.63),ATL_se0_ag2_f_02_1
2,3144.7,4165.9,I'M HERE WITH,ATL_se0_ag2_f_02_1
3,4165.9,5083.0,(PAUSE 0.92),ATL_se0_ag2_f_02_1
4,5083.0,5853.6,/RD-NAME-2/.,ATL_se0_ag2_f_02_1
5,5853.6,6312.2,(PAUSE 0.46),ATL_se0_ag2_f_02_1
6,6312.2,7844.7,THIS IS /RD-NAME-2/ BY THE WAY.,ATL_se0_ag2_f_02_1
7,7844.7,8799.6,(PAUSE 0.95),ATL_se0_ag2_f_02_1
8,8799.6,10241.5,"UM, /RD-NAME-2/, IF YOU COULD",ATL_se0_ag2_f_02_1
9,10241.5,10534.6,(PAUSE 0.29),ATL_se0_ag2_f_02_1


Unnamed: 0,StTime,EnTime,Content,Interview
23827,2423303.1,2424396.2,ABOUT YOUR FUTURE.,ATL_se0_ag1_m_02_1
23828,2424396.2,2425081.9,(PAUSE 0.69),ATL_se0_ag1_m_02_1
23829,2425081.9,2425860.4,DEFINITELY GONNA,ATL_se0_ag1_m_02_1
23830,2425937.8,2426912.2,I APPRECIATE [THAT.],ATL_se0_ag1_m_02_1
23831,2426726.6,2428514.3,"[BE IN] TOUCH AFTER THIS INTERVIEW,",ATL_se0_ag1_m_02_1
23832,2429019.6,2429658.9,"AND, UM,",ATL_se0_ag1_m_02_1
23833,2429658.9,2430942.7,(PAUSE 1.28),ATL_se0_ag1_m_02_1
23834,2430942.7,2432412.1,"IT'S /RD-NAME-3/, AND",ATL_se0_ag1_m_02_1
23835,2432412.1,2433113.3,(PAUSE 0.70),ATL_se0_ag1_m_02_1
23836,2433113.3,2434628.8,I'M SIGNING OFF. THIS IS A WRAP.,ATL_se0_ag1_m_02_1


In [10]:
# Define the regex pattern for unwanted characters
pattern = r'[\(\)\[\]/<>]'

# Filter out rows with unwanted characters in the 'Content' column
filtered_transcript_df = combined_transcript_df[~combined_transcript_df['Content'].str.contains(pattern)].reset_index(drop=True)

display(filtered_transcript_df.head(10), filtered_transcript_df.tail(10))

# import re

# # Define the regex pattern for unwanted characters and enclosed content
# pattern = r'\([^)]*\)|\[[^\]]*\]|<[^>]*>|\/[^\/]*\/'

# def remove_enclosed_content(text):
#     return re.sub(pattern, '', text)

# # Apply the function to remove enclosed content
# combined_transcript_df['Content'] = combined_transcript_df['Content'].apply(remove_enclosed_content)

# # Optionally, you can filter out rows with empty 'Content' after removal
# filtered_transcript_df = combined_transcript_df[combined_transcript_df['Content'].str.strip() != ''].reset_index(drop=True)

# display(filtered_transcript_df.head(10), filtered_transcript_df.tail(10))

Unnamed: 0,StTime,EnTime,Content,Interview
0,752.6,2511.3,HEY WHAT'S GOING ON?,ATL_se0_ag2_f_02_1
1,3144.7,4165.9,I'M HERE WITH,ATL_se0_ag2_f_02_1
2,10534.6,11289.9,SAY HELLO.,ATL_se0_ag2_f_02_1
3,12142.4,12846.2,HELLO.,ATL_se0_ag2_f_02_1
4,13458.1,13962.9,"OKAY,",ATL_se0_ag2_f_02_1
5,14227.6,15991.0,MIC SEEMS TO BE PICKING YOU UP WELL.,ATL_se0_ag2_f_02_1
6,16841.9,17230.4,"UM,",ATL_se0_ag2_f_02_1
7,17306.0,17618.0,MM.,ATL_se0_ag2_f_02_1
8,17646.4,18227.9,WE'LL GO,ATL_se0_ag2_f_02_1
9,18932.3,21206.2,"AND, UM, CONTINUE WITH THIS INTERVIEW.",ATL_se0_ag2_f_02_1


Unnamed: 0,StTime,EnTime,Content,Interview
9958,2405803.3,2407020.0,"THAT WAS DOPE AS WELL, MAN.",ATL_se0_ag1_m_02_1
9959,2409658.4,2411158.8,"OH MAN, WELL,",ATL_se0_ag1_m_02_1
9960,2415648.1,2417452.7,AND THE WORDS OF WISDOM,ATL_se0_ag1_m_02_1
9961,2417782.6,2418267.3,AND,ATL_se0_ag1_m_02_1
9962,2419437.7,2420386.3,"JUST, YOU KNOW,",ATL_se0_ag1_m_02_1
9963,2421273.1,2422947.4,EVERYTHING. I'M EXCITED,ATL_se0_ag1_m_02_1
9964,2423303.1,2424396.2,ABOUT YOUR FUTURE.,ATL_se0_ag1_m_02_1
9965,2425081.9,2425860.4,DEFINITELY GONNA,ATL_se0_ag1_m_02_1
9966,2429019.6,2429658.9,"AND, UM,",ATL_se0_ag1_m_02_1
9967,2433113.3,2434628.8,I'M SIGNING OFF. THIS IS A WRAP.,ATL_se0_ag1_m_02_1


In [11]:
data_subset = filtered_transcript_df.sample(50)
data_subset

Unnamed: 0,StTime,EnTime,Content,Interview
645,1903747.8,1907220.5,YOUR GRIND AND WHAT YOU GET A- OUT OF IT. IT'S...,ATL_se0_ag2_f_02_1
1438,1382617.7,1383209.3,FUCKING J-,ATL_se0_ag2_m_02_1
5791,1187951.0,1188831.8,MM-MM.,ATL_se0_ag1_f_01_1
3012,1839011.7,1840170.1,"I DON'T EVEN KNOW, BRO.",ATL_se0_ag2_m_03_1
9662,1572775.3,1575791.5,"FOOTBALL JOCK DOING THE THEATER SHIT, COOL AS ...",ATL_se0_ag1_m_02_1
9357,744470.2,745670.1,SOME RHYME AND SOME DOESN'T?,ATL_se0_ag1_m_02_1
9564,1297661.0,1298302.0,"YOU KNOW,",ATL_se0_ag1_m_02_1
529,1449021.3,1449545.7,"BUT, UM,",ATL_se0_ag2_f_02_1
3765,1738973.3,1740904.1,SO-CALLED INDIAN HAIR LOOK.,ATL_se0_ag2_m_01_1
7669,531112.7,532345.0,AS OF RIGHT NOW,ATL_se0_ag1_m_04_2


In [27]:
def extract_audio_segment(audio_file_path, start_time_ms, end_time_ms, target_sample_rate=16000, segment_length_ms=30000):
    """
    Extracts a segment from an audio file based on start and end times and resamples it to a target sample rate.
    
    Parameters:
    - audio_file_path (str): Path to the audio file.
    - start_time_ms (int): Start time in milliseconds.
    - end_time_ms (int): End time in milliseconds.
    - target_sample_rate (int): The target sample rate for the audio segment.
    - segment_length_ms (int): The desired length of the audio segment in milliseconds.
    
    Returns:
    - torch.Tensor: Extracted and resampled audio segment.
    """
    # Load the audio file
    waveform, sr = torchaudio.load(audio_file_path)
    
    # Resample if the sample rate is not the target sample rate
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
        waveform = resampler(waveform)
        sr = target_sample_rate
    
    # Convert milliseconds to sample indices
    start_sample = int(start_time_ms * sr / 1000)
    end_sample = int(end_time_ms * sr / 1000)
    
    # Extract the segment
    segment = waveform[:, start_sample:end_sample]
    
    # # Ensure the segment is of the desired length
    # target_length_samples = int(segment_length_ms * sr / 1000)
    # if segment.size(1) < target_length_samples:
    #     # Pad the segment if it's shorter than the target length
    #     padding = target_length_samples - segment.size(1)
    #     segment = torch.nn.functional.pad(segment, (0, padding))
    # else:
    #     # Trim the segment if it's longer than the target length
    #     segment = segment[:, :target_length_samples]
    
    return segment

In [28]:
# Initialize a list to store extracted audio segments
audio_segments = []

for idx, row in data_subset.iterrows():
    start_tm = row['StTime']
    end_tm = row['EnTime']
    content = row['Content']
    interview = row['Interview']

    transcript_path = transcript_dir+interview+".txt"
    audio_path = audio_dir+interview+".wav"

    segment = extract_audio_segment(audio_path, start_tm, end_tm)
    
    # Optionally, save the segment or process it further
    audio_segments.append({
        'StTime': start_tm,
        'EnTime': end_tm,
        'Interview': interview,
        'Content': content,
        'Segment': segment
    })

audio_segments_subset_df = pd.DataFrame(audio_segments)
audio_segments_subset_df.head(10)

Unnamed: 0,StTime,EnTime,Interview,Content,Segment
0,1903747.8,1907220.5,ATL_se0_ag2_f_02_1,YOUR GRIND AND WHAT YOU GET A- OUT OF IT. IT'S...,"[[tensor(-0.0038), tensor(-0.0038), tensor(-0...."
1,1382617.7,1383209.3,ATL_se0_ag2_m_02_1,FUCKING J-,"[[tensor(-0.0011), tensor(-0.0019), tensor(-0...."
2,1187951.0,1188831.8,ATL_se0_ag1_f_01_1,MM-MM.,"[[tensor(-0.0001), tensor(0.0002), tensor(0.00..."
3,1839011.7,1840170.1,ATL_se0_ag2_m_03_1,"I DON'T EVEN KNOW, BRO.","[[tensor(0.0005), tensor(0.0005), tensor(0.000..."
4,1572775.3,1575791.5,ATL_se0_ag1_m_02_1,"FOOTBALL JOCK DOING THE THEATER SHIT, COOL AS ...","[[tensor(-0.0004), tensor(-1.3209e-06), tensor..."
5,744470.2,745670.1,ATL_se0_ag1_m_02_1,SOME RHYME AND SOME DOESN'T?,"[[tensor(6.8611e-05), tensor(-0.0001), tensor(..."
6,1297661.0,1298302.0,ATL_se0_ag1_m_02_1,"YOU KNOW,","[[tensor(-0.0005), tensor(-0.0004), tensor(-0...."
7,1449021.3,1449545.7,ATL_se0_ag2_f_02_1,"BUT, UM,","[[tensor(-0.0005), tensor(-0.0005), tensor(-0...."
8,1738973.3,1740904.1,ATL_se0_ag2_m_01_1,SO-CALLED INDIAN HAIR LOOK.,"[[tensor(0.0006), tensor(0.0005), tensor(0.000..."
9,531112.7,532345.0,ATL_se0_ag1_m_04_2,AS OF RIGHT NOW,"[[tensor(-0.0007), tensor(-0.0009), tensor(-0...."


In [29]:
for idx, row in audio_segments_subset_df.iterrows():
    if idx == 10:
        break
    print(row['Segment'].shape)

torch.Size([1, 55564])
torch.Size([1, 9465])
torch.Size([1, 14092])
torch.Size([1, 18534])
torch.Size([1, 48260])
torch.Size([1, 19198])
torch.Size([1, 10256])
torch.Size([1, 8391])
torch.Size([1, 30893])
torch.Size([1, 19717])


In [23]:
# Initialize processor and model for whisper-tiny
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [30]:
# Move the model to the appropriate device
model = model.to(device)

In [31]:
def preprocess_audio(audio_tensor, sample_rate, target_sample_rate=16000):
    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        audio_tensor = resampler(audio_tensor)
    return audio_tensor

def transcribe_audio(audio_tensor, model, processor):
    audio_tensor = audio_tensor.to(device)
    inputs = processor(audio_tensor.squeeze().cpu().numpy(), return_tensors="pt", sampling_rate=16000)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        predicted_ids = model.generate(**inputs)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
    return transcription[0]

In [32]:
# Initialize lists for metrics
true_transcriptions = []
predicted_transcriptions = []

sample_rate = 44100

# Process each segment and compare transcriptions
for idx in range(len(audio_segments_subset_df)):
    audio_tensor = audio_segments_subset_df['Segment'][idx]
    actual_transcript = audio_segments_subset_df['Content'][idx]
    
    # Ensure audio is in the correct format
    audio_tensor = preprocess_audio(audio_tensor, sample_rate)
    
    # Transcribe audio
    model_transcription = transcribe_audio(audio_tensor, model, processor)
    model_transcription = model_transcription.upper()
    
    # Store transcriptions for metric calculation
    true_transcriptions.append(actual_transcript)
    predicted_transcriptions.append(model_transcription)
    
    # Print segment info
    print(f"Segment Start: {audio_segments_subset_df['StTime'][idx]}")
    print(f"Actual Transcript: {actual_transcript}")
    print(f"Model Transcription: {model_transcription}")
    print("\n")


Segment Start: 1903747.8
Actual Transcript: YOUR GRIND AND WHAT YOU GET A- OUT OF IT. IT'S JUST WHAT IT IS, RIGHT?
Model Transcription:  ВОН, ВЫ ЖЕ НЕ БУДЕМ.


Segment Start: 1382617.7
Actual Transcript: FUCKING J-
Model Transcription:  YOU


Segment Start: 1187951.0
Actual Transcript: MM-MM.
Model Transcription:  YOU


Segment Start: 1839011.7
Actual Transcript: I DON'T EVEN KNOW, BRO.
Model Transcription:  YEAH.


Segment Start: 1572775.3
Actual Transcript: FOOTBALL JOCK DOING THE THEATER SHIT, COOL AS SHIT.
Model Transcription:  THANK YOU.


Segment Start: 744470.2
Actual Transcript: SOME RHYME AND SOME DOESN'T?
Model Transcription:  BYE.


Segment Start: 1297661.0
Actual Transcript: YOU KNOW,
Model Transcription:  I'M SORRY.


Segment Start: 1449021.3
Actual Transcript: BUT, UM,
Model Transcription:  YOU


Segment Start: 1738973.3
Actual Transcript: SO-CALLED INDIAN HAIR LOOK.
Model Transcription:  YOU


Segment Start: 531112.7000000001
Actual Transcript: AS OF RIGHT NOW
Model Tran

In [None]:
wer_score = wer(true_transcriptions, predicted_transcriptions)
print(f"Word Error Rate (WER): {wer_score}")

Word Error Rate (WER): 2.3282208588957056


## Training `whisper-tiny` on CORAAL:ATL data

In [22]:
import os
import sys
import torch
import pandas as pd
from datetime import timedelta
import torchaudio
import pandas as pd
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, WhisperConfig
from jiwer import wer
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

In [23]:
# Set device
# device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
device = torch.device("cpu")
print(device)

cpu


In [24]:
# Load your data
transcript_dir = '../data/coraal/transcript/text/'
audio_dir = '../data/coraal/audio/wav/'

paths_df = create_file_mapping(transcript_dir, audio_dir)
combined_transcript_df = process_transcripts(paths_df)

In [25]:
# Filter out rows with unwanted characters in the 'Content' column
pattern = r'[\(\)\[\]/<>]'
filtered_transcript_df = combined_transcript_df[~combined_transcript_df['Content'].str.contains(pattern)].reset_index(drop=True)

In [36]:
# subset_size = int(len(filtered_transcript_df)*.20)
# data_subset = filtered_transcript_df.sample(subset_size)
data_subset = filtered_transcript_df.sample(100)
print(len(data_subset))

100


In [37]:
data_subset.head()

Unnamed: 0,StTime,EnTime,Content,Interview
9531,1164719.5,1165790.6,I LOVE GOOD MUSIC.,ATL_se0_ag1_m_02_1
5617,434093.0,434514.8,MM-HM.,ATL_se0_ag1_f_01_1
6650,125283.1,125905.6,"UH,",ATL_se0_ag1_m_01_1
3173,2621873.9,2623842.1,"THAT'S STILL LIT THOUGH TO GET, UM,",ATL_se0_ag2_m_03_1
6157,552074.8,555554.6,YOU WAS TREATED DIFFERENTLY BECAUSE YOU WAS TH...,ATL_se0_ag1_f_03_1


In [38]:
# Split the data into training and testing sets
train_df, test_df = train_test_split(data_subset, test_size=0.2, random_state=42)

In [39]:
class AudioDataset(Dataset):
    def __init__(self, df, transcript_dir, audio_dir, processor):
        self.df = df
        self.transcript_dir = transcript_dir
        self.audio_dir = audio_dir
        self.processor = processor
        self.target_sample_rate = 16000

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        start_tm = row['StTime']
        end_tm = row['EnTime']
        content = row['Content']
        interview = row['Interview']
        audio_path = os.path.join(self.audio_dir, interview + '.wav')
        
        # Extract audio segment
        waveform, sr = torchaudio.load(audio_path)
        start_sample = int(start_tm * sr / 1000)
        end_sample = int(end_tm * sr / 1000)
        audio_segment = waveform[:, start_sample:end_sample]
        
        # Resample audio to target sample rate
        if sr != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.target_sample_rate)
            audio_segment = resampler(audio_segment)
        
        # Preprocess audio segment
        audio_tensor = self.processor(audio_segment.squeeze().numpy(), sampling_rate=self.target_sample_rate, return_tensors="pt").input_features
        
        return {
            'input_features': audio_tensor.squeeze(),
            'labels': self.processor.tokenizer(content, return_tensors="pt").input_ids.squeeze()
        }

In [40]:
# Initialize processor and model
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
config = WhisperConfig.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").to(device)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [44]:
# Create datasets and dataloaders
train_dataset = AudioDataset(train_df, transcript_dir, audio_dir, processor)
test_dataset = AudioDataset(test_df, transcript_dir, audio_dir, processor)

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [45]:
# Training function
def train(model, train_loader, processor, optimizer, scheduler, device):
    model.train()
    total_loss = 0
    true_transcriptions = []
    predicted_transcriptions = []
    
    for batch in tqdm(train_loader, total=len(train_loader), desc='Training'):
        inputs = batch['input_features'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_features=inputs, labels=labels)
        loss = outputs.loss

        outputs = model.generate(input_features=inputs)
        transcriptions = processor.batch_decode(outputs, skip_special_tokens=True)
        true_transcriptions.extend(processor.batch_decode(labels, skip_special_tokens=True))
        predicted_transcriptions.extend(transcriptions)

        predicted_transcriptions = [x.upper() for x in predicted_transcriptions]
        wer_score = wer(true_transcriptions, predicted_transcriptions)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        
        total_loss += loss.item()
    total_loss = total_loss / len(train_loader)
    
    return total_loss, wer_score

# Evaluation function
def evaluate(model, test_loader, processor, device):
    model.eval()
    total_loss = 0
    true_transcriptions = []
    predicted_transcriptions = []
    
    with torch.no_grad():
        # for batch in tqdm(test_loader, total=len(test_loader), desc='Testing'):
        for batch in test_loader:
            inputs = batch['input_features'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_features=inputs, labels=labels)
            loss = outputs.loss
            total_loss += loss.item()
            
            outputs = model.generate(input_features=inputs)
            transcriptions = processor.batch_decode(outputs, skip_special_tokens=True)
            true_transcriptions.extend(processor.batch_decode(labels, skip_special_tokens=True))
            predicted_transcriptions.extend(transcriptions)

    total_loss = total_loss / len(train_loader)
    predicted_transcriptions = [x.upper() for x in predicted_transcriptions]
    wer_score = wer(true_transcriptions, predicted_transcriptions)
    
    return total_loss, wer_score

In [50]:
# Fine-tuning
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
num_epochs = 5

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    
    train_loss, train_wer = train(model, train_loader, processor, optimizer, scheduler, device)
    print(f"Training Loss: {train_loss:.4f}, WER: {train_wer:.4f}")

Epoch 1/5


Training: 100%|██████████| 80/80 [01:36<00:00,  1.20s/it]


Training Loss: 2.4433, WER: 0.5142
Epoch 2/5


Training: 100%|██████████| 80/80 [01:49<00:00,  1.37s/it]


Training Loss: 2.2080, WER: 0.4953
Epoch 3/5


Training: 100%|██████████| 80/80 [01:41<00:00,  1.26s/it]


Training Loss: 2.2080, WER: 0.4953
Epoch 4/5


Training:  80%|████████  | 64/80 [01:28<00:23,  1.45s/it]

: 

In [None]:
test_loss, test_wer = evaluate(model, test_loader, processor, device)
print(f"Test Loss: {test_loss:.4f}, WER: {test_wer:.4f}")

Test Loss: 0.7490, WER: 0.4286


In [None]:
k_test = []
for k in range(10):
    test_loss, test_wer = evaluate(model, test_loader, processor, device)
    k_test.append([k, test_loss, test_wer])

In [None]:
k_test_df = pd.DataFrame(k_test, columns=['k', 'Loss', 'Word Error Rate']).sort_values(by="Word Error Rate", ascending=False).groupby('k')[:5]
k_test_df