# Whisper Inference for Punjabi

In [None]:
# Check GPU availability
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

In [2]:
# All the imports
from pathlib import Path
import pandas as pd
import numpy as np
from copy import deepcopy

from datasets import DatasetDict, Dataset, Audio
import soundfile as sf
import pysrt

In [3]:
# Dataset layout, resolved relative to the notebook's working directory.
data_path = Path("../data/Punjabi")
# NOTE: later cells rebind `audio_path` to a plain string for a single file.
audio_path = data_path / "Audio"
transcript_path = data_path / "Text"
srt_path = data_path / "Ground Truth SRT"

In [None]:
# Read the benchmark list; the CSV's first column is used as the index and then
# discarded in favour of a fresh 0..n-1 index.
benchmark_df = pd.read_csv(data_path / "benchmark_list.csv", index_col=0).reset_index(drop=True)
benchmark_df.head()

In [None]:
# Distinct story names; the split loop below sends any audio file whose name
# contains one of these to the "val" split.
benchmark_paths = benchmark_df["Story Name"].unique().tolist()
benchmark_paths

In [None]:
# Per-split containers of file paths (audio, transcript, SRT).
# deepcopy gives each split its own independent lists.
base_lists = {"audio_path": [], "transcript_path": [], "srt_path": []}
dataset = {"train": deepcopy(base_lists), "val": deepcopy(base_lists)}
dataset

In [None]:
# Walk every folder under the audio root and assign each WAV file to the "val"
# split when its name contains a benchmark story name, otherwise to "train".
# Folders and files are visited in sorted order so the lists are reproducible.
for dir_path in sorted(audio_path.glob("*")):
    for file_path in sorted(dir_path.glob("*.wav")):
        print(file_path.name)
        is_benchmark = any(benchmark_path in file_path.name for benchmark_path in benchmark_paths)
        dataset_type = "val" if is_benchmark else "train"
        print(dataset_type)

        # The transcript and SRT trees mirror the audio tree with "Videos"
        # swapped in the folder name. with_suffix() replaces only the extension;
        # str.replace(".wav", ...) would also rewrite ".wav" inside the stem.
        transcript_dir = transcript_path / dir_path.name.replace("Videos", "Text")
        srt_dir = srt_path / dir_path.name.replace("Videos", "SRT")

        # Store every path as str so the three columns share one type.
        file_audio_path = str(file_path)
        file_transcript_path = str(transcript_dir / file_path.with_suffix(".txt").name)
        file_srt_path = str(srt_dir / file_path.with_suffix(".srt").name)

        dataset[dataset_type]["audio_path"].append(file_audio_path)
        dataset[dataset_type]["transcript_path"].append(file_transcript_path)
        dataset[dataset_type]["srt_path"].append(file_srt_path)
        
        # # Load and decode audio
        # decoded_audio, sampling_rate = sf.read(file_audio_path)
        # dataset[dataset_type]["array"].append(decoded_audio)
        # dataset[dataset_type]["sampling_rate"].append(sampling_rate)

        # # Read transcript and append to dataset
        # with open(file_transcript_path, 'r', encoding='utf-8') as file:
        #     transcript = file.read()
        #     dataset[dataset_type]["transcript"].append(transcript)

## Create Whisper Feature Extractor, Tokenizer and Processor

In [None]:
from transformers import WhisperFeatureExtractor

# Load the feature extractor published with the large-v2 checkpoint.
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-large-v2")

In [9]:
from transformers import WhisperTokenizer

# Tokenizer for the same checkpoint, configured for Punjabi transcription.
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large-v2", language="Punjabi", task="transcribe")

In [10]:
from transformers import WhisperProcessor

# Processor for the same checkpoint; later cells use it both to build input
# features and to decode generated ids.
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2", language="Punjabi", task="transcribe")

In [None]:
from transformers import WhisperForConditionalGeneration

# NOTE: the next cell loads the same checkpoint into `model` again (in eval mode
# and on the GPU when available), so this load is redundant if both cells run.
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2")

In [10]:
from transformers import WhisperForConditionalGeneration
import torch
import librosa


# Load the Whisper model in eval mode and place it on the GPU when one is
# available, otherwise on the CPU.
model_name = "openai/whisper-large-v2"
model = WhisperForConditionalGeneration.from_pretrained(model_name).eval().to("cuda" if torch.cuda.is_available() else "cpu")

In [12]:
import torchaudio

def load_audio(file_path, target_sr=16000):
    """Load an audio file as a mono waveform sampled at ``target_sr`` Hz.

    Args:
        file_path: Path of the audio file, passed to ``torchaudio.load``.
        target_sr: Sampling rate of the returned waveform. Defaults to the
            16 kHz used throughout this notebook.

    Returns:
        A ``(waveform, sampling_rate)`` tuple: a 1-D tensor and ``target_sr``.
    """
    speech_array, sampling_rate = torchaudio.load(file_path)
    # Downmix first: squeezing a multi-channel tensor would leave it 2-D, while
    # callers pass the result on as a single 1-D signal.
    if speech_array.shape[0] > 1:
        speech_array = speech_array.mean(dim=0, keepdim=True)
    if sampling_rate != target_sr:
        resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=target_sr)
        speech_array = resampler(speech_array)
    # Drop only the channel axis so a one-sample clip still stays 1-D.
    return speech_array.squeeze(0), target_sr

In [13]:
# Source recording to split into chunks.
# NOTE: this rebinds `audio_path`, which an earlier cell defined as the Path of
# the audio root folder; re-running the dataset split cell after this one will
# therefore see a string instead.
audio_path = "/home/arjun/naren/alignment/data/Punjabi/Audio/AniBook Videos/Abdul_Kalam,_Missile_Man_Punjabi.wav"
import os
import math
from pydub import AudioSegment

def chunk_audio(file_path, chunk_length_ms=30000):
    """Split a WAV file into consecutive pieces of ``chunk_length_ms`` each.

    The final piece holds whatever remains and may be shorter. Returns the
    pieces as a list of ``AudioSegment`` slices in playback order.
    """
    recording = AudioSegment.from_wav(file_path)
    total_ms = len(recording)
    return [
        recording[offset:offset + chunk_length_ms]
        for offset in range(0, total_ms, chunk_length_ms)
    ]

def save_chunks(chunks, base_path):
    """Export each chunk as ``<base_path>/chunk<i>.wav``.

    Args:
        chunks: Iterable of objects exposing ``export(path, format=...)``
            (the ``AudioSegment`` slices produced by ``chunk_audio``).
        base_path: Output directory; created, including parents, if missing.
    """
    # The export fails if the target directory does not exist, so create it.
    os.makedirs(base_path, exist_ok=True)
    for i, chunk in enumerate(chunks):
        chunk_name = f"{base_path}/chunk{i}.wav"
        chunk.export(chunk_name, format="wav")

# Chunk the audio file into 30-second segments and write them next to the
# source recording as chunk0.wav, chunk1.wav, ...
chunks = chunk_audio(audio_path, chunk_length_ms=30000)
save_chunks(chunks, "/home/arjun/naren/alignment/data/Punjabi/Audio/AniBook Videos/Abdul_Kalam,_Missile_Man_Punjabi_chunks")

## Transcription

In [None]:
audio_chunk_path = "/home/arjun/naren/alignment/data/Punjabi/Audio/AniBook Videos/Abdul_Kalam,_Missile_Man_Punjabi_chunks/chunk0.wav"

# load_audio resamples to 16 kHz when needed; the returned rate is ignored here.
speech, _ = load_audio(audio_chunk_path)
# Build the model input features and move them to the model's device.
input_features = processor(speech, sampling_rate=16000, return_tensors="pt").input_features.to(model.device)

In [None]:
# forced_decoder_ids = processor.get_decoder_prompt_ids(language="punjabi", task="transcribe")

# Generate a transcription WITHOUT timestamps (return_timestamps=False).
with torch.no_grad():
    generated_ids = model.generate(input_features, return_timestamps=False)

# Decode the generated ids; special tokens are left in the output.
# NOTE(review): `output_word_offsets` does not look like a Whisper decode option
# and is presumably ignored here - confirm against the tokenizer documentation.
transcription = processor.batch_decode(generated_ids, output_word_offsets=True)
# cleaned_transcription = transcription[0].replace("<|startoftranscript|>", "").replace("<|endoftranscript|>", "").strip()

# print("Transcription:", cleaned_transcription)
print(transcription)

In [None]:
# forced_decoder_ids = processor.get_decoder_prompt_ids(language="punjabi", task="transcribe")

# Generate transcription with timestamp prediction enabled.
with torch.no_grad():
    generated_ids = model.generate(input_features, return_timestamps=True)

# decode_with_timestamps=True renders the predicted timestamp tokens as
# <|x.xx|> markers in the text. The previous `output_word_offsets` and
# `max_length` keyword arguments are not decode options and were dropped.
transcription = processor.batch_decode(generated_ids, decode_with_timestamps=True)
# cleaned_transcription = transcription[0].replace("<|startoftranscript|>", "").replace("<|endoftranscript|>", "").strip()

# print("Transcription:", cleaned_transcription)
print(transcription)

In [None]:
# Define your initial prompt
initial_prompt = "ਸਤ ਸ੍ਰੀ ਅਕਾਲ, ਇਹ ਇੱਕ ਪੰਜਾਬੀ ਆਡੀਓ ਫਾਇਲ ਹੈ, ਇਸ ਦਾ ਸਮੱਗਰੀ ਹੇਠ ਲਿਖੀ ਹੈ۔"
# get_decoder_prompt_ids() only accepts task / language / no_timestamps; it has
# no `text` argument, so the language/task prefix and the free-text prompt are
# built separately.
forced_decoder_ids = processor.get_decoder_prompt_ids(language="pa", task="transcribe")
# Token ids of the textual prompt; pass them as `prompt_ids=` to `model.generate`.
prompt_ids = processor.get_prompt_ids(initial_prompt, return_tensors="pt")

In [None]:
from transformers import pipeline

# Load whisper-large-v2 through the high-level ASR pipeline, on the first GPU
# when one is available and on the CPU (-1) otherwise.
model_name = "openai/whisper-large-v2"
whisper_pipeline = pipeline(
    task="automatic-speech-recognition",
    model=model_name,
    device=0 if torch.cuda.is_available() else -1,
)

# Transcribe audio with timestamps. Language and task are generation options,
# so they are passed through `generate_kwargs`; `model_kwargs` is forwarded to
# `from_pretrained` when the model is loaded, not to generation.
output = whisper_pipeline(
    "/home/arjun/naren/alignment/data/Punjabi/Audio/AniBook Videos/Abdul_Kalam,_Missile_Man_Punjabi_chunks/chunk0.wav",
    return_timestamps=True,
    chunk_length_s=10,  # Chunk size in seconds
    stride_length_s=1,  # Overlapping stride to prevent cutoff
    generate_kwargs={"language": "punjabi", "task": "transcribe"},
)

# Extract timestamps and transcriptions
for segment in output["chunks"]:
    print(f"Text: {segment['text']} | Start: {segment['timestamp'][0]}s | End: {segment['timestamp'][1]}s")


In [None]:
from transformers import pipeline, AutoProcessor, AutoModelForSpeechSeq2Seq

# Load the Whisper model and processor
model_name = "openai/whisper-large-v2"
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name)

# Define the initial prompt
initial_prompt = "ਸਤ ਸ੍ਰੀ ਅਕਾਲ, ਇਹ ਇੱਕ ਪੰਜਾਬੀ ਆਡੀਓ ਫਾਇਲ ਹੈ, ਇਸ ਦਾ ਸਮੱਗਰੀ ਹੇਠ ਲਿਖੀ ਹੈ۔"

tokenizer = processor.tokenizer

# Build the pipeline around the already-loaded model. `model_kwargs` only
# applies when the pipeline loads the model itself, so it is not used here.
whisper_pipeline = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=tokenizer,
    feature_extractor=processor.feature_extractor,
    device=0 if torch.cuda.is_available() else -1,  # GPU when available, else CPU
)

# A text prompt is supplied to Whisper as `prompt_ids`. Raw token ids are not
# valid `forced_decoder_ids`, which must be (position, token_id) pairs.
prompt_ids = processor.get_prompt_ids(initial_prompt, return_tensors="pt").to(whisper_pipeline.device)

# Transcribe audio with the pipeline
output = whisper_pipeline(
    "/home/arjun/naren/alignment/data/Punjabi/Audio/AniBook Videos/Abdul_Kalam,_Missile_Man_Punjabi_chunks/chunk0.wav",
    return_timestamps=True,
    chunk_length_s=5,  # Specify chunk size in seconds
    stride_length_s=1,  # Specify overlapping stride
    generate_kwargs={"prompt_ids": prompt_ids, "language": "punjabi", "task": "transcribe"},
)

# Extract and print timestamps and transcriptions
for segment in output["chunks"]:
    print(f"Text: {segment['text']} | Start: {segment['timestamp'][0]}s | End: {segment['timestamp'][1]}s")


In [None]:
output["chunks"]

In [None]:
# Load audio
audio_path = "/home/arjun/naren/alignment/data/Punjabi/Audio/AniBook Videos/Abdul_Kalam,_Missile_Man_Punjabi_chunks/chunk0.wav"
audio, sr = librosa.load(audio_path, sr=16000)

# Encode input and place it on whatever device the model lives on.
inputs = processor(audio, sampling_rate=sr, return_tensors="pt", padding=True).to(model.device)

# Decoder prompt for Punjabi speech -> English text. no_timestamps=False keeps
# <|notimestamps|> out of the prompt; the default (True) would insert it and
# contradict return_timestamps=True below.
forced_decoder_ids = processor.get_decoder_prompt_ids(language="pa", task="translate", no_timestamps=False)
print(forced_decoder_ids)
# Generate output
with torch.no_grad():
    generated_ids = model.generate(
        inputs.input_features,
        forced_decoder_ids=forced_decoder_ids,
        return_timestamps=True,
    )

# Decode output, keeping the <|x.xx|> timestamp markers that the next cell
# splits on while dropping the other special tokens.
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True, decode_with_timestamps=True)
print(transcription)

In [None]:

# Split the first transcription at every '><' boundary, then give each piece
# back the bracket the split removed: '>' on all but the last piece, '<' on all
# but the first.
pieces = transcription[0].split('><')
last_index = len(pieces) - 1
segments = [
    ('<' if index > 0 else '') + piece + ('>' if index < last_index else '')
    for index, piece in enumerate(pieces)
]
print(segments)

In [None]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Earlier attempt, kept for reference:
# model_name = "ai4bharat/indic-bert"
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Load model directly
# NOTE(review): this rebinds `model` to the indic-bert checkpoint via AutoModel,
# while `tokenizer` still refers to the Whisper tokenizer from an earlier cell.
# indic-bert is presumably not a translation (seq2seq) model - confirm before
# relying on `model.generate` below.
from transformers import AutoModel
model = AutoModel.from_pretrained("ai4bharat/indic-bert")


# Function to extract text and timestamps
def extract_text_and_timestamps(segment):
    """Split ``"<|s|>text<|e|>"`` into ``(start_marker, text, end_marker)``.

    Both markers are returned complete, including their angle brackets. A
    marker that is absent from ``segment`` is returned as an empty string, so
    plain text without markers yields ``("", text, "")``.
    """
    timestamp_start = ""
    rest = segment
    # Leading marker: everything up to and including the first '>'.
    if segment.startswith('<') and '>' in segment:
        head, _, rest = segment.partition('>')
        timestamp_start = head + '>'

    timestamp_end = ""
    text = rest
    # Trailing marker: from the last '<' to the end, closing '>' included.
    if rest.endswith('>') and '<' in rest:
        text, _, tail = rest.rpartition('<')
        timestamp_end = '<' + tail
    return timestamp_start, text, timestamp_end

# For every timestamped segment: run its text through `model.generate` and put
# the decoded result back between the original start/end markers.
# NOTE(review): depends on the `model` and `tokenizer` globals bound in earlier
# cells - confirm they form a matching translation model/tokenizer pair.
punjabi_transcripts = []
for segment in segments:
    timestamp_start, text, timestamp_end = extract_text_and_timestamps(segment)
    # Tokenize the input text
    inputs = tokenizer(text, return_tensors="pt")
    # Generate translation
    translated_tokens = model.generate(**inputs)
    # Decode the translated tokens
    punjabi_transcript = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]

    punjabi_transcripts.append(f"{timestamp_start}{punjabi_transcript}{timestamp_end}")

# Print the Punjabi transcripts with timestamps
for transcript in punjabi_transcripts:
    print(transcript)


In [None]:
# Load audio
audio_path = "/home/arjun/naren/alignment/data/Punjabi/Audio/AniBook Videos/Abdul_Kalam,_Missile_Man_Punjabi_chunks/chunk0.wav"
audio, sr = librosa.load(audio_path, sr=16000)

# Encode input and place it on whatever device the model lives on.
inputs = processor(audio, sampling_rate=sr, return_tensors="pt", padding=True).to(model.device)

# Decoder prompt for Punjabi transcription. no_timestamps=False keeps
# <|notimestamps|> out of the prompt (the default would insert it).
forced_decoder_ids = processor.get_decoder_prompt_ids(language="punjabi", task="transcribe", no_timestamps=False)

# Generate output
with torch.no_grad():
    generated = model.generate(
        inputs.input_features,
        forced_decoder_ids=forced_decoder_ids,
        return_timestamps=True,  # Ensuring timestamps are returned
    )

# Decode with the <|x.xx|> timestamp markers rendered in the text.
transcription = processor.batch_decode(generated, skip_special_tokens=False, decode_with_timestamps=True)

# Print transcription with timestamps
for i, segment in enumerate(transcription):
    print(f"Segment {i}: {segment}")

# `generated` is a tensor of token ids here, so it cannot be indexed with a
# string key. The timestamps are recovered at decode time instead:
# output_offsets=True returns, per sequence, a dict with "text" and "offsets".
decoded_with_offsets = processor.batch_decode(generated, skip_special_tokens=True, output_offsets=True)
timestamps = [item["offsets"] for item in decoded_with_offsets]
print(f"Timestamps: {timestamps}")


## Pipeline code

In [None]:
from transformers import pipeline

# Replace with the path to your Punjabi audio file
AUDIO_FILE = "/home/arjun/naren/alignment/data/Punjabi/Audio/AniBook Videos/Abdul_Kalam,_Missile_Man_Punjabi_chunks/chunk0.wav"

# Load the Whisper large-v2 model from Hugging Face.
# Note: "openai/whisper-large-v3" is also published on the Hub; swap the repo
# name below to try it.
# Alternative configuration that passes language/task as generation arguments:
# whisper_asr = pipeline(
#     task="automatic-speech-recognition",
#     model="openai/whisper-large-v2",
#     chunk_length_s=30,          # Process audio in 30 second chunks
#     return_timestamps=True,     # Return timestamps for each chunk
#     generate_kwargs={
#         "language": "<|pa|>",   # Force the model to use Punjabi
#         "task": "transcribe"    # Instruct the model to perform transcription
#     }
# )

whisper_asr = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-large-v2",
    chunk_length_s=30,
    return_timestamps=True,
)

# Overwrite forced_decoder_ids at the model level
forced_decoder_ids = whisper_asr.tokenizer.get_decoder_prompt_ids(
    language="pa",
    task="transcribe",
    no_timestamps=False
)
whisper_asr.model.generation_config.forced_decoder_ids = forced_decoder_ids

# Run transcription
result = whisper_asr(AUDIO_FILE)

# 'result' will have a 'chunks' key containing segments with timestamps
print("Transcription with timestamps:\n")
for chunk in result["chunks"]:
    start_time, end_time = chunk["timestamp"]
    text = chunk["text"]
    # Either bound may be None (e.g. a final segment without a closing
    # timestamp); formatting None with ":.2f" would raise TypeError.
    start_label = "?" if start_time is None else f"{start_time:.2f}s"
    end_label = "?" if end_time is None else f"{end_time:.2f}s"
    print(f"[{start_label} - {end_label}] {text}")

## Second attempt: manual chunking with timestamp tokens

In [65]:
import torch
import torchaudio
import torchaudio.transforms as T
import re
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# ------------------------------------------------
# 1. Configuration
# ------------------------------------------------
AUDIO_FILE = "/home/arjun/naren/alignment/data/Punjabi/Audio/AniBook Videos/Abdul_Kalam,_Missile_Man_Punjabi.wav"  # Replace with your audio path
MODEL_ID = "openai/whisper-large-v2"
CHUNK_LENGTH_S = 30       # 30-second chunks
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ------------------------------------------------
# 2. Load Model & Processor
# ------------------------------------------------
# Both names rebind the `processor` / `model` globals used by earlier cells.
processor = WhisperProcessor.from_pretrained(MODEL_ID)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID).to(DEVICE)

# Force the model to produce timestamps in Punjabi transcription:
# - no_timestamps=False ensures the model does NOT insert <|notimestamps|>.
forced_decoder_ids = processor.get_decoder_prompt_ids(
    language="pa",
    task="transcribe",
    no_timestamps=False      # crucial!
)


In [79]:
# ------------------------------------------------
# 3. Load & Resample Audio to 16 kHz
# ------------------------------------------------
audio_waveform, original_sr = torchaudio.load(AUDIO_FILE)
if original_sr != 16000:
    resampler = T.Resample(orig_freq=original_sr, new_freq=16000)
    audio_waveform = resampler(audio_waveform)
# From here on the waveform is at 16 kHz regardless of the file's native rate.
sr = 16000

# Convert to mono if needed (average the channels, keep the channel axis)
if audio_waveform.shape[0] > 1:
    audio_waveform = torch.mean(audio_waveform, dim=0, keepdim=True)

# Total length and per-chunk length, both in samples.
num_samples = audio_waveform.shape[1]
chunk_size = CHUNK_LENGTH_S * sr

In [80]:
# ------------------------------------------------
# 4. Chunking Helper
# ------------------------------------------------
def audio_chunks_generator(waveform, chunk_samples, sample_rate=16000):
    """Yield consecutive, non-overlapping chunks of ``waveform``.

    Args:
        waveform: 2-D array or tensor sliced along its last axis
            (``waveform[:, start:end]``).
        chunk_samples: Length of each chunk in samples; must be positive.
        sample_rate: Sampling rate used to convert a chunk's start offset to
            seconds. Defaults to the 16 kHz used in this notebook.

    Yields:
        ``(chunk_waveform, chunk_start_sec)`` tuples; the last chunk may be
        shorter than ``chunk_samples``.

    Raises:
        ValueError: If ``chunk_samples`` is not positive.
    """
    if chunk_samples <= 0:
        raise ValueError("chunk_samples must be a positive integer")
    # Measure the argument itself rather than relying on notebook globals.
    total_samples = waveform.shape[-1]
    for start in range(0, total_samples, chunk_samples):
        yield waveform[:, start:start + chunk_samples], start / sample_rate

# ------------------------------------------------
# 5. Parsing Timestamp Tokens
# ------------------------------------------------
# We want to keep <|xx.xx|> but remove other special tokens
# like <|startoftranscript|>, <|pa|>, <|transcribe|>, etc.

# Matches any <|...|> token whose content is NOT a timestamp of the form xx.xx.
_NON_TIMESTAMP_TOKEN_RE = re.compile(r"<\|(?!\d+\.\d+\|>)[^|<>]*\|>")


def clean_special_tokens(text: str) -> str:
    """Remove every ``<|...|>`` special token except timestamps (``<|xx.xx|>``).

    Besides ``<|startoftranscript|>``, ``<|pa|>``, ``<|transcribe|>``,
    ``<|translate|>``, ``<|notimestamps|>`` and ``<|endoftext|>``, this also
    drops any other token of the same shape (for example another language
    tag), so only text and timestamp markers remain.

    Returns:
        The cleaned text with surrounding whitespace stripped.
    """
    return _NON_TIMESTAMP_TOKEN_RE.sub("", text).strip()

def parse_timestamped_text(text_with_ts):
    """Turn text containing ``<|xx.xx|>`` markers into timed segments.

    Each non-empty stretch of text is paired with the value of the marker that
    most recently preceded it; text before the first marker gets ``0.0``.
    The input is expected to contain no special tokens other than timestamps.

    Returns:
        ``[(time_in_s, segment_text), ...]`` in order of appearance.
    """
    # With one capture group, re.split alternates: text, time, text, time, ...
    pieces = re.split(r"<\|(\d+\.\d+)\|>", text_with_ts)

    segments = []
    leading_text = pieces[0].strip()
    if leading_text:
        segments.append((0.0, leading_text))

    # Pair every captured time with the text that follows it.
    for raw_time, raw_text in zip(pieces[1::2], pieces[2::2]):
        body = raw_text.strip()
        if body:
            segments.append((float(raw_time), body))
    return segments

In [69]:
# Earlier version of the chunked transcription loop (the next cell repeats it).
all_segments = []

# Process each chunk independently. The helper defined above is named
# `audio_chunks_generator`; the old name `get_audio_chunks` is not defined.
for chunk_waveform, chunk_start_sec in audio_chunks_generator(audio_waveform, chunk_size):
    # 1) Prepare input features
    input_features = processor(
        chunk_waveform.squeeze(0),
        sampling_rate=sr,
        return_tensors="pt"
    ).input_features.to(DEVICE)

    # 2) Inference
    with torch.no_grad():
        generated_tokens = model.generate(
            input_features,
            forced_decoder_ids=forced_decoder_ids,
            max_length=448,   # you can adjust
            # You can set other generation parameters here
        )

    # 3) Decode *with* special tokens so we can see <|xx.xx|>
    raw_transcription = processor.batch_decode(generated_tokens, skip_special_tokens=False)[0]

    # 4) Clean out all special tokens EXCEPT <|xx.xx|>
    cleaned_text = clean_special_tokens(raw_transcription)

    # 5) Parse the cleaned text to extract (timestamp, text)
    segments_in_chunk = parse_timestamped_text(cleaned_text)

    # 6) Offset each timestamp by the chunk start
    for rel_t, seg_text in segments_in_chunk:
        abs_t = chunk_start_sec + rel_t
        all_segments.append((abs_t, seg_text))

# ------------------------------------------------
# 7. Print Final Results
# ------------------------------------------------
print("Transcription with approximate timestamps:\n")
for t, txt in all_segments:
    print(f"[{t:.2f}s] {txt}")

In [81]:
# ------------------------------------------------
# 6. Perform Transcription with Timestamps
# ------------------------------------------------
# Collects (absolute_time_in_seconds, text) tuples across all chunks.
all_segments = []

for chunk_waveform, chunk_start_sec in audio_chunks_generator(audio_waveform, chunk_size):
    # 1) Convert waveform to input features (channel axis removed first)
    inputs = processor(
        chunk_waveform.squeeze(0),
        sampling_rate=sr,
        return_tensors="pt"
    )

    input_features = inputs.input_features.to(DEVICE)

    # 2) Generate tokens using the timestamp-enabled decoder prompt
    with torch.no_grad():
        generated_tokens = model.generate(
            input_features,
            forced_decoder_ids=forced_decoder_ids,
            max_length=448,  # can adjust
            # You can set other decoding params: temperature, top_k, etc.
        )

    # 3) Decode *with* special tokens so we can see <|xx.xx|>
    raw_transcription = processor.batch_decode(generated_tokens, skip_special_tokens=False)[0]
    
    # 4) Clean out all special tokens EXCEPT <|xx.xx|>
    cleaned_text = clean_special_tokens(raw_transcription)

    # 5) Parse the cleaned text to extract (timestamp, text)
    segments_in_chunk = parse_timestamped_text(cleaned_text)

    # 6) Offset each timestamp by the chunk start, since the parsed times are
    #    relative to the beginning of the chunk
    for rel_t, seg_text in segments_in_chunk:
        abs_t = chunk_start_sec + rel_t
        all_segments.append((abs_t, seg_text))

# ------------------------------------------------
# 7. Print Final Results
# ------------------------------------------------
print("Transcription with approximate timestamps:\n")
for t, txt in all_segments:
    print(f"[{t:.2f}s] {txt}")

In [77]:
import transformers
transformers.__version__

## Post processing

In [14]:
def format_timestamp(seconds):
    """Convert a duration in seconds to ``HH:MM:SS.ss`` (two decimals).

    The value is rounded to whole centiseconds before being split into fields,
    so a value such as 59.999 becomes ``00:01:00.00`` rather than overflowing
    the seconds field to ``00:00:60.00``.
    """
    total_centiseconds = int(round(float(seconds) * 100))
    hours, remainder = divmod(total_centiseconds, 360000)
    minutes, centiseconds = divmod(remainder, 6000)
    return f"{hours:02d}:{minutes:02d}:{centiseconds / 100:05.2f}"

# Process the decoded output into clean timestamped segments.
# Expects transcription[0] to contain segments of the form
# "<|start|>text<|end|>". The previous split-based parser sent every piece
# containing '>' down the timestamp branch and never collected its text, so it
# always produced an empty result.
import re

_SEGMENT_RE = re.compile(r"<\|(\d+\.\d+)\|>(.*?)<\|(\d+\.\d+)\|>", re.DOTALL)

chunks = [
    {
        "text": match.group(2).strip(),
        "start": float(match.group(1)),
        "end": float(match.group(3)),
    }
    for match in _SEGMENT_RE.finditer(transcription[0])
]

# Format the output, keeping only segments that contain text.
formatted_transcript = [
    {
        "text": chunk["text"],
        "start": format_timestamp(chunk["start"]),
        "end": format_timestamp(chunk["end"]),
        "start_seconds": chunk["start"],
        "end_seconds": chunk["end"],
    }
    for chunk in chunks
    if chunk["text"]
]

# Print formatted output
for segment in formatted_transcript:
    print(f"[{segment['start']} --> {segment['end']}] {segment['text']}")

# Whisper library

In [None]:
!pip install openai-whisper
!sudo apt-get install ffmpeg

In [None]:
import whisper
# Load the "medium" checkpoint with the openai-whisper package.
# This rebinds the `model` global used by the Hugging Face cells above.
model = whisper.load_model("medium")

# Run the Punjabi audio chunk with task="translate"
result = model.transcribe("/home/arjun/naren/alignment/data/Punjabi/Audio/AniBook Videos/Abdul_Kalam,_Missile_Man_Punjabi_chunks/chunk0.wav", task="translate")

# Print the resulting text
print(result["text"])
# !pip install scipy --no-cache-dir

In [None]:
# Load the "turbo" checkpoint with the openai-whisper package.
# NOTE(review): the openai-whisper README appears to say turbo is not trained
# for translation - confirm before relying on task="translate" here.
model = whisper.load_model("turbo")

# Run the Punjabi audio chunk with task="translate"
result = model.transcribe("/home/arjun/naren/alignment/data/Punjabi/Audio/AniBook Videos/Abdul_Kalam,_Missile_Man_Punjabi_chunks/chunk0.wav", task="translate")

# Print the resulting text
print(result["text"])


# !pip install --upgrade scipy
# !pip install --upgrade transformers

In [None]:
# Pin generation to Punjabi transcription through the generation config and
# clear any forced decoder ids stored on it.
# NOTE(review): in file order `model` was last bound by whisper.load_model(...);
# `generation_config` looks like a Hugging Face model attribute - confirm which
# `model` this cell is meant to run against.
model.generation_config.language = "punjabi"
model.generation_config.task = "transcribe"

model.generation_config.forced_decoder_ids = None

In [None]:
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate speech-to-text examples into one padded batch.

    Audio features and label ids are padded separately, by the processor's
    feature extractor and tokenizer respectively, because they have different
    lengths and padding rules.
    """

    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Return a batch holding the padded inputs plus a ``labels`` tensor.

        Each feature must provide ``"input_features"`` and ``"labels"``.
        Padded label positions are set to -100.
        """
        audio_inputs = [{"input_features": example["input_features"]} for example in features]
        token_inputs = [{"input_ids": example["labels"]} for example in features]

        # Audio side: the feature extractor returns the batch as torch tensors.
        batch = self.processor.feature_extractor.pad(audio_inputs, return_tensors="pt")

        # Label side: pad the token ids, then mask padding out of the loss.
        padded_labels = self.processor.tokenizer.pad(
            token_inputs,
            return_tensors="pt",
            max_length=self.processor.tokenizer.model_max_length,
        )
        padding_positions = padded_labels.attention_mask.ne(1)
        labels = padded_labels["input_ids"].masked_fill(padding_positions, -100)

        # If every sequence already starts with the decoder start token, drop
        # it here because it is added again later.
        all_start_with_bos = (labels[:, 0] == self.decoder_start_token_id).all().cpu().item()
        if all_start_with_bos:
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

In [None]:
# Build the collator from the current `processor` and the start token id
# stored in the model config.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

In [None]:
import evaluate

metric = evaluate.load("wer")

In [None]:
def compute_metrics(pred):
    """Decode predictions and labels, print both, and return the WER in percent.

    Reads ``pred.predictions`` and ``pred.label_ids``. Label positions equal to
    -100 are overwritten in place with the tokenizer's pad token id before
    decoding. Uses the module-level ``tokenizer`` and ``metric``.
    """
    predicted_ids = pred.predictions
    reference_ids = pred.label_ids

    # Swap -100 for the pad token id so the labels can be decoded.
    ignored_positions = reference_ids == -100
    reference_ids[ignored_positions] = tokenizer.pad_token_id

    # Special tokens are skipped on both sides before scoring.
    hypotheses = tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)
    references = tokenizer.batch_decode(reference_ids, skip_special_tokens=True)

    print("Predictions: \n", hypotheses)
    print("Labels: \n", references)

    error_rate = metric.compute(predictions=hypotheses, references=references)
    return {"wer": 100 * error_rate}

In [None]:
from transformers import Seq2SeqTrainingArguments

# Short fine-tuning run: 100 optimisation steps, evaluation every 10 steps,
# checkpoints every 50 steps, best checkpoint chosen by lowest WER.
# NOTE: the output directory name says "v3", while the checkpoints loaded in
# this notebook are whisper-large-v2.
training_args = Seq2SeqTrainingArguments(
    output_dir="whisper-large-v3-pa",  # change to a repo name of your choice
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-4,
    warmup_steps=50,
    max_steps=100,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=50,
    eval_steps=10,
    logging_steps=10,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,  # lower WER is better
    push_to_hub=False,
)

In [None]:
from transformers import Seq2SeqTrainer

# Train on the "train" split and evaluate on the held-out "val" split.
# Evaluating on the training data would make the WER used for best-model
# selection meaningless.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["val"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

In [None]:
trainer.train()

In [None]:
import torch
import gc

# Collect unreachable Python objects first so the tensors they hold are
# released, then return the now-unused cached blocks to the CUDA driver.
gc.collect()
torch.cuda.empty_cache()

In [None]:
!nvidia-smi